本文整理汇总了Python中tensorflow.global_variables函数的典型用法代码示例。如果您正苦于以下问题:Python global_variables函数的具体用法?Python global_variables怎么用?Python global_variables使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了global_variables函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testBatchNorm
def testBatchNorm(self, module):
    """Build a conv net with batch norm enabled and inspect its variables.

    The module is connected twice: first with TensorFlow boolean placeholders
    for ``is_training`` / ``test_local_stats``, then with plain Python
    booleans, to check that both kinds of flag are accepted.
    """
    net = module(output_channels=self.output_channels,
                 kernel_shapes=self.kernel_shapes,
                 strides=self.strides,
                 paddings=self.paddings,
                 use_batch_norm=True)
    self.assertTrue(net.use_batch_norm)

    images = tf.placeholder(tf.float32, shape=(1, 100, 100, 3))

    # Graph-valued (placeholder) flags.
    training_flag = tf.placeholder(tf.bool)
    local_stats_flag = tf.placeholder(tf.bool)
    net(images, is_training=training_flag, test_local_stats=local_stats_flag)

    # Python boolean flags.
    net(images, is_training=False, test_local_stats=False)

    expected_count = len(self.output_channels) * 3 - 1
    self.assertEqual(len(net.get_variables()), expected_count)

    # The moving statistics variables must exist in the global collection.
    for stat_name in ("moving_variance", "moving_mean"):
        self.assertTrue(
            any(stat_name in var.name for var in tf.global_variables()))
开发者ID:TianjiPang,项目名称:sonnet,代码行数:33,代码来源:convnet_test.py
示例2: freeze_session
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The caller's list
                        is not modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        var_op_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(var_op_names).difference(keep_var_names or []))
        # Copy instead of extending in place: ``+=`` on the caller's list would
        # mutate it and make repeated calls accumulate variable names.
        output_names = list(output_names or [])
        # Every variable op is also listed as an output, so pruning keeps it.
        output_names += var_op_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
开发者ID:DXZ,项目名称:git_test,代码行数:28,代码来源:keras_to_tfsevring_final.py
示例3: get_model_params
def get_model_params(variable_prefix, split_lstm_matrices=True):
    """Evaluate global variables and return their values keyed by name.

    Uses ``Variable.eval()``, so a default session must be active.

    Args:
      variable_prefix: if non-empty, only variables whose op name starts with
        this prefix are collected, and ``<prefix>/Variable`` and
        ``<prefix>/Variable_1`` are skipped. If empty, every variable except
        ``Variable`` and ``Variable_1`` is collected.
      split_lstm_matrices: if True, every entry whose name contains
        "LSTMCell" is split into four equal parts stored under
        "LSTMCell-i/-j/-f/-o" names (the original entry is removed), "AttnV"
        entries are reshaped to ``[dim0, 1]`` and "AttnW" entries are squeezed.

    Returns:
      dict mapping variable op names (with "/" replaced by "-") to values.
    """
    if variable_prefix:
        exclude = {variable_prefix + "/Variable", variable_prefix + "/Variable_1"}
        tmp = {v.op.name: v.eval() for v in tf.global_variables()
               if v.op.name.startswith(variable_prefix) and v.op.name not in exclude}
    else:
        exclude = {"Variable", "Variable_1"}
        tmp = {v.op.name: v.eval() for v in tf.global_variables()
               if v.op.name not in exclude}
    # Rename keys
    params = {name.replace("/", "-"): param for name, param in tmp.items()}
    if split_lstm_matrices:
        # Iterate over a snapshot of the keys: the body adds and deletes
        # entries, which raises RuntimeError on a live dict view (Python 3).
        # The newly added "LSTMCell-x" keys are therefore not revisited.
        for name in list(params):
            if "LSTMCell" in name:
                # i = input_gate, j = new_input, f = forget_gate, o = output_gate
                if "Matrix" in name:
                    i, j, f, o = array_ops.split(1, 4, params[name])
                elif "Bias" in name:
                    i, j, f, o = array_ops.split(0, 4, params[name])
                else:
                    logging.error("Unknown tensor type..")
                    exit(1)
                params[name.replace("LSTMCell", "LSTMCell-i")] = i.eval()
                params[name.replace("LSTMCell", "LSTMCell-j")] = j.eval()
                params[name.replace("LSTMCell", "LSTMCell-f")] = f.eval()
                params[name.replace("LSTMCell", "LSTMCell-o")] = o.eval()
                del params[name]
            elif "AttnV" in name:
                params[name] = array_ops.reshape(params[name], [params[name].shape[0], 1]).eval()
            elif "AttnW" in name:
                # remove dims of size 1
                params[name] = tf.squeeze(params[name]).eval()
    return params
开发者ID:ehasler,项目名称:tensorflow,代码行数:35,代码来源:model_utils.py
示例4: add_saver
def add_saver(self):
    """Create ``self.saver`` covering every global variable in the graph.

    The variable names are logged first; checkpoints are written in the V1
    SaverDef format.
    """
    global_vars = list(tf.global_variables())
    name_listing = '\n\t'.join(var.name for var in global_vars)
    logging.info('Generating op to save variables:\n\t%s', name_listing)
    self.saver = tf.train.Saver(var_list=global_vars,
                                write_version=saver_pb2.SaverDef.V1)
开发者ID:ALISCIFP,项目名称:models,代码行数:7,代码来源:graph_builder.py
示例5: _add_saving_op
def _add_saving_op():
    """
    Build a Saver restricted to the variables the network needs to persist.

    Only variables whose name contains one of these fragments are kept:
    input/output layer weights and biases, LSTM kernel and bias, the global
    step and the learning rate.
    :return: a tensorflow tf.train.Saver operation
    :raises ValueError: if no matching variable exists yet (RNN not built)
    """
    all_vars = tf.global_variables()
    for var in all_vars:
        logging.debug("TF variable : %s - %s", var.name, var)

    wanted_fragments = ('/input_w:0', '/input_b:0',
                        '/output_w:0', '/output_b:0',
                        'global_step:0', 'learning_rate:0',
                        '/kernel:0', '/bias:0')
    save_list = [var for var in all_vars
                 if any(fragment in var.name for fragment in wanted_fragments)]
    if not save_list:
        raise ValueError("Trying to define the saving operation before the RNN is built")
    return tf.train.Saver(save_list)
开发者ID:inikdom,项目名称:rnn-speech,代码行数:26,代码来源:AcousticModel.py
示例6: load_vggish_slim_checkpoint
def load_vggish_slim_checkpoint(session, checkpoint_path):
    """Loads a pre-trained VGGish-compatible checkpoint.

    This function can be used as an initialization function (referred to as
    init_fn in TensorFlow documentation) which is called in a Session after
    initializing all variables. When used as an init_fn, this will load
    a pre-trained checkpoint that is compatible with the VGGish model
    definition. Only variables defined by VGGish will be loaded.

    Args:
      session: an active TensorFlow session.
      checkpoint_path: path to a file containing a checkpoint that is
        compatible with the VGGish model definition.
    """
    # Build the inference-mode VGGish model in a throwaway graph to learn the
    # names of all VGGish variables that exist in the checkpoint. A set makes
    # the membership test below O(1) per variable.
    with tf.Graph().as_default():
        define_vggish_slim(training=False)
        vggish_var_names = {v.name for v in tf.global_variables()}

    # Select the currently existing variables whose names match.
    vggish_vars = [v for v in tf.global_variables() if v.name in vggish_var_names]

    # Use a Saver to restore just the variables selected above.
    saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained')
    saver.restore(session, checkpoint_path)
开发者ID:ameerellaboudy,项目名称:models,代码行数:27,代码来源:vggish_slim.py
示例7: add_saver
def add_saver(self):
    """Create ``self.saver`` for every global variable not named 'quantized'.

    The selected variable names are logged; checkpoints use the V1 format.
    """
    kept_vars = [var for var in tf.global_variables() if 'quantized' not in var.name]
    logging.info('Saving non-quantized variables:\n\t%s',
                 '\n\t'.join(var.name for var in kept_vars))
    self.saver = tf.train.Saver(var_list=kept_vars,
                                write_version=saver_pb2.SaverDef.V1)
开发者ID:JiweiHe,项目名称:models,代码行数:9,代码来源:graph_builder.py
示例8: train
def train(hparams, event_dir=None, model_dir=None,
          restore_agent=True, epoch=0):
    """Train the PPO agent for ``hparams.epochs_num`` epochs.

    Args:
      hparams: hyperparameters; this function reads ``epochs_num``,
        ``environment_spec.simulated_env``, ``world_model_dir``,
        ``eval_every_epochs`` and ``save_models_every_epochs``.
      event_dir: if set, summaries are written to this directory.
      model_dir: if set, variables matching ``.*network_parameters.*`` are
        checkpointed here (and restored from here when ``restore_agent``).
      restore_agent: restore the agent from ``model_dir`` before training.
      epoch: outer epoch index; training is skipped when the restored step is
        already ``>= hparams.epochs_num * (epoch + 1)``.
    """
    with tf.name_scope("rl_train"):
        train_summary_op, _, initialization = define_train(hparams, event_dir)
        # Both default to None and are set independently. The previous if/else
        # left one of them unbound for some event_dir/model_dir combinations,
        # which raised NameError further down.
        summary_writer = None
        model_saver = None
        if event_dir:
            summary_writer = tf.summary.FileWriter(
                event_dir, graph=tf.get_default_graph(), flush_secs=60)
        if model_dir:
            model_saver = tf.train.Saver(
                tf.global_variables(".*network_parameters.*"))

        # TODO(piotrmilos): This should be refactored, possibly with
        # handlers for each type of env
        if hparams.environment_spec.simulated_env:
            env_model_loader = tf.train.Saver(
                tf.global_variables("next_frame*"))
        else:
            env_model_loader = None

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            initialization(sess)
            if env_model_loader:
                trainer_lib.restore_checkpoint(
                    hparams.world_model_dir, env_model_loader, sess, must_restore=True)
            start_step = 0
            if model_saver and restore_agent:
                start_step = trainer_lib.restore_checkpoint(
                    model_dir, model_saver, sess)

            # Fail-friendly, don't train if already trained for this epoch
            if start_step >= (hparams.epochs_num * (epoch + 1)):
                tf.logging.info("Skipping PPO training for epoch %d as train steps "
                                "(%d) already reached", epoch, start_step)
                return

            for epoch_index in range(hparams.epochs_num):
                summary = sess.run(train_summary_op)
                if summary_writer:
                    summary_writer.add_summary(summary, epoch_index)

                if (hparams.eval_every_epochs and
                        epoch_index % hparams.eval_every_epochs == 0):
                    # NOTE(review): this re-adds the training summary written
                    # just above; no separate eval summary is computed here.
                    if summary_writer and summary:
                        summary_writer.add_summary(summary, epoch_index)
                    else:
                        tf.logging.info("Eval summary not saved")

                if (model_saver and hparams.save_models_every_epochs and
                        (epoch_index % hparams.save_models_every_epochs == 0 or
                         (epoch_index + 1) == hparams.epochs_num)):
                    ckpt_path = os.path.join(
                        model_dir, "model.ckpt-{}".format(epoch_index + 1 + start_step))
                    model_saver.save(sess, ckpt_path)
开发者ID:kltony,项目名称:tensor2tensor,代码行数:56,代码来源:rl_trainer_lib.py
示例9: _create_initializers
def _create_initializers(self):
    """Rebuild saver/initializer ops when the number of global variables changed.

    ``self._var_count`` caches how many global variables existed when the ops
    were last built; if the count is unchanged, nothing is done.
    """
    if self._var_count == len(tf.global_variables()):
        return

    save_dir = None
    if self._save_path:
        save_dir = os.path.dirname(self._save_path)
    if save_dir and not tf.gfile.IsDirectory(save_dir):
        tf.gfile.MakeDirs(save_dir)

    self._saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    self._init = tf.global_variables_initializer()
    self._local_init = tf.local_variables_initializer()
    self._check_inited = tf.assert_variables_initialized()
    self._var_count = len(tf.global_variables())

    if not self._summary_writer:
        return
    self._summaries = tf.summary.merge_all()
    self._summary_writer.add_graph(tf.get_default_graph())
开发者ID:google,项目名称:prettytensor,代码行数:13,代码来源:local_trainer.py
示例10: testBatchNormScale
def testBatchNormScale(self):
    """With batch_norm_scale=True every BatchNorm layer must own a gamma."""
    height = width = 299
    num_classes = 1000
    inputs = tf.placeholder(tf.float32, (1, height, width, 3))
    scope = inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)
    with tf.contrib.slim.arg_scope(scope):
        inception.inception_resnet_v2(inputs, num_classes, is_training=False)

    gamma_names = {var.op.name
                   for var in tf.global_variables('.*/BatchNorm/gamma:0$')}
    self.assertGreater(len(gamma_names), 0)

    suffix_len = len('moving_mean')
    for var in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
        bn_prefix = var.op.name[:-suffix_len]
        self.assertIn(bn_prefix + 'gamma', gamma_names)
开发者ID:zhangjiulong,项目名称:models,代码行数:13,代码来源:inception_resnet_v2_test.py
示例11: get_train_op
def get_train_op(self,
                 loss,
                 learning_rate,
                 optimizer=None,
                 clip_norm=None,
                 learnable_scopes=None,
                 optimizer_scope_name=None):
    """ Get train operation for given loss

    Args:
        loss: loss, tf tensor or scalar
        learning_rate: scalar or placeholder
        optimizer: optimizer class (e.g. tf.train.AdamOptimizer, the default);
            it is instantiated here as ``optimizer(learning_rate)``
        clip_norm: if not None, clip each gradient tensor to this norm
        learnable_scopes: substrings selecting which trainable variables are
            updated (a variable is selected if any substring occurs in its
            name); None for all trainable variables
        optimizer_scope_name: variable scope for the optimizer ops
            (default 'Optimizer')

    Returns:
        train_op
    """
    scope_name = 'Optimizer' if optimizer_scope_name is None else optimizer_scope_name
    with tf.variable_scope(scope_name):
        # Only trainable variables: tf.global_variables() would also hand
        # variables created with trainable=False (and optimizer slots) to the
        # optimizer.
        trainable = tf.trainable_variables()
        if learnable_scopes is None:
            variables_to_train = trainable
        else:
            variables_to_train = []
            seen = set()
            for scope in learnable_scopes:
                for var in trainable:
                    # Skip variables already matched by an earlier scope so
                    # none is passed to the optimizer twice.
                    if scope in var.name and var.name not in seen:
                        seen.add(var.name)
                        variables_to_train.append(var)

        if optimizer is None:
            optimizer = tf.train.AdamOptimizer

        # For batch norm it is necessary to update running averages
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            opt = optimizer(learning_rate)
            grads_and_vars = opt.compute_gradients(loss, var_list=variables_to_train)
            if clip_norm is not None:
                # Variables that do not affect the loss have a None gradient;
                # those are passed through unclipped.
                grads_and_vars = [
                    (tf.clip_by_norm(grad, clip_norm) if grad is not None else None, var)
                    for grad, var in grads_and_vars]
            train_op = opt.apply_gradients(grads_and_vars)
    return train_op
开发者ID:wangzhenya,项目名称:DeepPavlov,代码行数:51,代码来源:tf_model.py
示例12: assign_weight
def assign_weight(self):
    '''
    Apply the pruning masks to the network's weights.

    Encapsulates unit-class pruning and multi-class pruning. Masks come from
    ``self.mask_class_multi_by_value()`` when ``self.multiPruning`` is True
    and more than one target class is given, otherwise from
    ``self.mask_unit_by_value(self.target_class_id[0])``. For each masked
    layer, entries on the last axis whose mask value is 0 are set to zero:
      * "C..." layers: ``Conv<N>/composite_function/kernel:0``
      * "F..." layers: ``FC<N>/W:0`` and ``FC<N>/bias:0``
    '''
    print("assign weights......")
    if self.multiPruning == True and len(self.target_class_id) > 1:
        maskDict = self.mask_class_multi_by_value()
    else:
        maskDict = self.mask_unit_by_value(self.target_class_id[0])

    with self.graph.as_default():
        # Index variables by name once instead of rescanning the full variable
        # list for every masked layer.
        vars_by_name = {var.name: var for var in tf.global_variables()}

        def _apply_mask(var_name, mask):
            """Zero last-axis entries of variable ``var_name`` where ``mask`` is 0."""
            var = vars_by_name.get(var_name)
            if var is None:
                return
            values = self.sess.run(var)
            values[..., mask == 0] = 0
            # NOTE(review): each tf.assign adds a new op (holding ``values`` as
            # a constant) to the graph on every call.
            self.sess.run(tf.assign(var, values))

        for tmpLayer in maskDict:
            layer_name = tmpLayer["name"]
            mask = np.array(tmpLayer["shape"])
            # NOTE(review): str.strip removes characters, not a prefix; this
            # works for names like "Conv3" / "FC2".
            if layer_name[0] == "C":  # convolutional layer
                layerNum = layer_name.strip("Conv")
                _apply_mask("Conv" + layerNum + "/composite_function/kernel:0", mask)
            if layer_name[0] == "F":  # fully connected layer
                layerNum = layer_name.strip("FC")
                _apply_mask("FC" + layerNum + "/W:0", mask)
                _apply_mask("FC" + layerNum + "/bias:0", mask)
    print("assign finished!")
开发者ID:sjtu-cs222,项目名称:Group_30,代码行数:51,代码来源:vggTrimmedModel.py
示例13: optimize
def optimize(loss, learning_rate, hparams, use_tpu=False):
    """Minimize loss.

    Adds weight decay/noise to ``loss``, logs (and optionally summarizes)
    trainable, non-trainable and float16 ("diet") variables, then builds the
    training op with ``tf.contrib.layers.optimize_loss``.

    Returns:
      The training op.
    """
    loss = weight_decay_and_noise(loss, hparams, learning_rate)
    loss = tf.identity(loss, name="total_loss")
    # Print trainable variables.
    log_variable_sizes(verbose=hparams.summarize_vars)
    # Print non-trainable variables. Built in graph order (rather than as a
    # set difference) so that log and summary order is deterministic.
    trainable = set(tf.trainable_variables())
    non_trainable_variables = [
        v for v in tf.global_variables() if v not in trainable]
    log_variable_sizes(non_trainable_variables, tag="Non-trainable variables",
                       verbose=hparams.summarize_vars)
    if hparams.summarize_vars:
        summarize_variables()
        # Summarize non-trainable variables as well
        summarize_variables(non_trainable_variables, tag="Non-trainable variables")
    diet_vars = [
        v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
    ]
    log_variable_sizes(
        diet_vars, "Diet Variables", verbose=hparams.summarize_vars)
    opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
    if use_tpu:
        opt = tf.contrib.tpu.CrossShardOptimizer(opt)

    opt_summaries = []
    if common_layers.should_generate_summaries():
        tf.summary.scalar("learning_rate", learning_rate)
        opt_summaries.append("loss")
        if hparams.summarize_grads:
            tf.logging.info("Summarizing gradients")
            opt_summaries.extend(
                ["gradients", "gradient_norm", "global_gradient_norm"])

    if hparams.clip_grad_norm:
        tf.logging.info("Clipping gradients, norm: %0.5f", hparams.clip_grad_norm)
    if hparams.grad_noise_scale:
        tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                        hparams.grad_noise_scale)

    train_op = tf.contrib.layers.optimize_loss(
        name="training",
        loss=loss,
        global_step=tf.train.get_or_create_global_step(),
        learning_rate=learning_rate,
        clip_gradients=hparams.clip_grad_norm or None,
        gradient_noise_scale=hparams.grad_noise_scale or None,
        optimizer=opt,
        summaries=opt_summaries,
        colocate_gradients_with_ops=True)
    return train_op
开发者ID:qixiuai,项目名称:tensor2tensor,代码行数:50,代码来源:optimize.py
示例14: getLoadVars
def getLoadVars(self):
    """Return the global variables to load, never including Adam slot variables.

    When ``self.resLoad`` is truthy, only variables whose name contains
    "class_weight", "class_bias" or "conv1" are kept.
    """
    candidates = [var for var in tf.global_variables() if "Adam" not in var.name]
    if not self.resLoad:
        return candidates
    wanted = ("class_weight", "class_bias", "conv1")
    return [var for var in candidates
            if any(fragment in var.name for fragment in wanted)]
开发者ID:slundqui,项目名称:DeepGAP,代码行数:7,代码来源:MLPVid2.py
示例15: testNotInLocalVariables
def testNotInLocalVariables(self):
    """A model variable is in the global and MODEL_VARIABLES collections, not local."""
    with self.test_session():
        with tf.variable_scope('A'):
            var_a = tf.contrib.framework.model_variable('a', [5])
            self.assertIn(var_a, tf.global_variables())
            self.assertIn(var_a, tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
            self.assertNotIn(var_a, tf.local_variables())
开发者ID:jeffzheng1,项目名称:tensorflow,代码行数:7,代码来源:variables_test.py
示例16: get_global_variable_by_name
def get_global_variable_by_name(name):
    """Returns the global variable of given name.

    Args:
      name: the variable's full name, compared against ``Variable.name``.

    Returns:
      The first global variable whose ``name`` equals ``name``.

    Raises:
      IndexError: if no global variable has that name (same exception type as
        before, now with a descriptive message).
    """
    # Stop at the first match instead of building the full list first.
    for var in tf.global_variables():
        if var.name == name:
            return var
    raise IndexError("No global variable named %r" % (name,))
开发者ID:zhangleiqss,项目名称:Distributed-TensorFlow-Guide,代码行数:7,代码来源:DOWNPOUR.py
示例17: train
def train(self, data=0, steps=-1, dropout=None, display_step=10, test_step=200, batch_size=10,
          do_resume=False):  # epochs=-1,
    """Train the network, optionally resuming from the latest checkpoint.

    Args:
      data: if truthy, replaces ``self.data`` before training.
      steps: number of training steps; -1 means 9999999.
      dropout: keep probability fed to ``self.keep_prob`` while training
        (falsy means 1.0, i.e. keep all units).
      display_step: print loss/accuracy every this many steps.
      test_step: call ``self.test(step)`` every this many steps.
      batch_size: batch size passed to ``self.next_batch``.
      do_resume: restore the latest checkpoint in ``checkpoint_dir``, if any.

    NOTE(review): relies on ``start``, ``save_step``, ``checkpoint_dir``,
    ``current_logdir`` and ``get_last_tensorboard_run_nr`` defined elsewhere
    in the module.
    """
    import math

    if data:
        self.data = data
    steps = 9999999 if steps == -1 else steps
    session = self.session
    # NOTE(review): the op returned here is never run, so these numeric checks
    # currently have no effect.
    tf.add_check_numerics_ops()
    # Each fallback uses the pre-1.0 TensorFlow name when the newer attribute
    # does not exist; other errors are no longer silently swallowed.
    try:
        self.summaries = tf.summary.merge_all()
    except AttributeError:
        self.summaries = tf.merge_all_summaries()
    try:
        self.summary_writer = tf.summary.FileWriter(current_logdir(), session.graph)
    except AttributeError:
        self.summary_writer = tf.train.SummaryWriter(current_logdir(), session.graph)
    if not dropout:
        dropout = 1.  # keep all
    x = self.x
    y = self.y
    keep_prob = self.keep_prob
    try:
        saver = tf.train.Saver(tf.global_variables())
    except AttributeError:
        saver = tf.train.Saver(tf.all_variables())
    snapshot = self.name + str(get_last_tensorboard_run_nr())

    # Initialise BEFORE restoring: running the initializer after restore (as
    # before) overwrote the restored weights, so do_resume had no effect.
    try:
        session.run([tf.global_variables_initializer()])
    except AttributeError:
        session.run([tf.initialize_all_variables()])
    checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if do_resume and checkpoint:
        print("LOADING " + checkpoint + " !!!")
        saver.restore(session, checkpoint)

    step = 0  # show first
    while step < steps:
        batch_xs, batch_ys = self.next_batch(batch_size, session)
        # Fit training using batch data
        feed_dict = {x: batch_xs, y: batch_ys, keep_prob: dropout, self.train_phase: True}
        loss, _ = session.run([self.cost, self.optimizer], feed_dict=feed_dict)
        if step % display_step == 0:
            seconds = int(time.time()) - start
            # Calculate batch accuracy, loss (keep_prob 1, train_phase False)
            feed = {x: batch_xs, y: batch_ys, keep_prob: 1., self.train_phase: False}
            acc, summary = session.run([self.accuracy, self.summaries], feed_dict=feed)
            # Only test summaries are written, for a smoother curve.
            print("\rStep {:d} Loss= {:.6f} Accuracy= {:.3f} Time= {:d}s".format(step, loss, acc, seconds), end=' ')
            if math.isnan(loss):
                print("\nLoss gradiant explosion, exiting!!!")  # restore!
                return
        if step % test_step == 0:
            self.test(step)
        if step % save_step == 0 and step > 0:
            print("SAVING snapshot %s" % snapshot)
            saver.save(session, checkpoint_dir + snapshot + ".ckpt", self.global_step)
        step += 1
    print("\nOptimization Finished!")
    self.test(step, number=10000)  # final test
开发者ID:duydb2,项目名称:tensorflow-speech-recognition,代码行数:60,代码来源:net.py
示例18: train_speech_to_text_network
def train_speech_to_text_network():
    """Build the CTC training graph and train it for 16 epochs.

    Uses module-level ``X``, ``Y``, ``sequence_len``, ``n_batch``,
    ``batch_size``, ``get_next_batches`` and ``MaxPropOptimizer`` defined
    elsewhere, and resets the module-level ``pointer`` to 0 at the start of
    every epoch. The learning rate is ``0.001 * 0.97 ** epoch``; a checkpoint
    prefixed ``speech.module`` is saved every 5 epochs.
    """
    global pointer
    logit = speech_to_text_network()

    # CTC loss: non-zero entries of Y become sparse targets, shifted down by one.
    indices = tf.where(tf.not_equal(tf.cast(Y, tf.float32), 0.))
    target = tf.SparseTensor(indices=indices, values=tf.gather_nd(Y, indices) - 1, shape=tf.cast(tf.shape(Y), tf.int64))
    loss = tf.nn.ctc_loss(logit, target, sequence_len, time_major=False)

    # optimizer
    lr = tf.Variable(0.001, dtype=tf.float32, trainable=False)
    # Build the learning-rate update op once; calling tf.assign inside the
    # epoch loop added a new op to the graph on every epoch.
    new_lr = tf.placeholder(tf.float32, shape=[])
    update_lr = tf.assign(lr, new_lr)
    optimizer = MaxPropOptimizer(learning_rate=lr, beta2=0.99)
    var_list = list(tf.trainable_variables())
    gradient = optimizer.compute_gradients(loss, var_list=var_list)
    optimizer_op = optimizer.apply_gradients(gradient)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        for epoch in range(16):
            sess.run(update_lr, feed_dict={new_lr: 0.001 * (0.97 ** epoch)})
            pointer = 0
            for batch in range(n_batch):
                batches_wavs, batches_labels = get_next_batches(batch_size)
                train_loss, _ = sess.run([loss, optimizer_op], feed_dict={X: batches_wavs, Y: batches_labels})
                print(epoch, batch, train_loss)
            if epoch % 5 == 0:
                saver.save(sess, 'speech.module', global_step=epoch)
开发者ID:luohuayong,项目名称:tensorflow,代码行数:30,代码来源:t15.py
示例19: initialize
def initialize(self, sess):
    """Initialise every global variable, then load the pretrained weights.

    All global variables are initialised first. The subset chosen by
    ``self.net.get_variables_to_restore`` (which leaves out the variables to
    fix) is restored from ``self.pretrained_model``, and then
    ``self.net.fix_variables`` is called.

    Returns:
      (rate, last_snapshot_iter, stepsizes, np_paths, ss_paths): the learning
      rate from ``cfg.TRAIN.LEARNING_RATE``, 0, ``cfg.TRAIN.STEPSIZE`` as a
      list, and two empty snapshot path lists.
    """
    # Fresh train directly from ImageNet weights
    print('Loading initial model weights from {:s}'.format(self.pretrained_model))
    all_vars = tf.global_variables()
    sess.run(tf.variables_initializer(all_vars, name='init'))

    ckpt_vars = self.get_variables_in_checkpoint_file(self.pretrained_model)
    to_restore = self.net.get_variables_to_restore(all_vars, ckpt_vars)
    tf.train.Saver(to_restore).restore(sess, self.pretrained_model)
    print('Loaded.')

    # The variables skipped by the restore above are handled here. Per the
    # original notes, this converts RGB weights to BGR and, for VGG16, turns
    # the fc6/fc7 convolutional weights into fully connected weights.
    self.net.fix_variables(sess, self.pretrained_model)
    print('Fixed.')

    return cfg.TRAIN.LEARNING_RATE, 0, list(cfg.TRAIN.STEPSIZE), [], []
开发者ID:tobechao,项目名称:FasterRcnnTF_ICPR2018,代码行数:26,代码来源:train_val.py
示例20: evaluate
def evaluate():
    """ Build evaluation graph and run.

    Restores the latest checkpoint from ``FLAGS.train_dir``, runs the model
    over every batch of ``train.cPickle`` (if ``FLAGS.train_data``) or
    ``test.cPickle`` in ``FLAGS.data_dir``, and prints mean loss and accuracy.

    Raises:
      IOError: if no checkpoint is found in ``FLAGS.train_dir``.
    """
    with tf.Graph().as_default():
        with tf.variable_scope('cnn'):
            m = model.Model(FLAGS, is_train=False)
        saver = tf.train.Saver(tf.global_variables())

        # read test files
        data_file = 'train.cPickle' if FLAGS.train_data else 'test.cPickle'
        loader = text_input.DataLoader(os.path.join(FLAGS.data_dir, data_file), batch_size=FLAGS.batch_size)
        # print(...) with a single argument prints the same under Python 2's
        # print statement and Python 3's print function.
        print('Start evaluation, %d batches needed, with %d examples per batch.' % (loader.num_batch, FLAGS.batch_size))

        true_count = 0
        avg_loss = 0
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise IOError("Loading checkpoint file failed!")

            for _ in range(loader.num_batch):
                x, y = loader.next_batch()
                true_count_value, loss_value = sess.run([m.true_count_op, m.total_loss],
                                                        feed_dict={m.inputs: x, m.labels: y})
                true_count += true_count_value
                avg_loss += loss_value

            # NOTE(review): assumes every batch holds exactly FLAGS.batch_size
            # examples — confirm against text_input.DataLoader.
            accuracy = float(true_count) / (loader.num_batch * FLAGS.batch_size)
            avg_loss = float(avg_loss) / loader.num_batch
            print('%s: test_loss = %.6f, test_accuracy = %.3f' % (datetime.now(), avg_loss, accuracy))
开发者ID:yuhaozhang,项目名称:sentence-convnet,代码行数:34,代码来源:eval.py
注:本文中的tensorflow.global_variables函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论