本文整理汇总了Python中tensorflow.initialize_local_variables函数的典型用法代码示例。如果您正苦于以下问题:Python initialize_local_variables函数的具体用法?Python initialize_local_variables怎么用?Python initialize_local_variables使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了initialize_local_variables函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testOneThreadDynamicPad
def testOneThreadDynamicPad(self):
    """Check that `tf.train.batch(dynamic_pad=True)` pads ragged string rows.

    A counter is incremented up to ``n_batches * rows_per_batch``. Each counter
    value is paired with the string "string" tiled that many times, so rows
    within one batch have different lengths and the shorter ones are expected
    to be padded with ``b""``.
    """
    rows_per_batch = 10
    n_batches = 3
    with self.test_session() as sess:
        zero = tf.constant(0, dtype=tf.int64)
        examples = tf.Variable(zero)
        counter = examples.count_up_to(n_batches * rows_per_batch)
        tile_multiples = tf.to_int32(tf.pack([counter]))
        string = tf.tile(["string"], tile_multiples)
        tf.initialize_all_variables().run()
        tf.initialize_local_variables().run()
        batched = tf.train.batch(
            [counter, string], batch_size=rows_per_batch, dynamic_pad=True)
        threads = tf.train.start_queue_runners()

        for batch_index in range(n_batches):
            first = batch_index * rows_per_batch
            counts, strings = sess.run(batched)
            expected_counts = np.arange(first, first + rows_per_batch)
            # The last (largest) count in the batch fixes the padded width.
            widest = expected_counts[-1]
            self.assertAllEqual(counts, expected_counts)
            expected_strings = []
            for rep in expected_counts:
                expected_strings.append(
                    [b"string"] * rep + [b""] * (widest - rep))
            self.assertAllEqual(strings, expected_strings)

        # The counter is exhausted, so one more dequeue must fail.
        with self.assertRaises(tf.errors.OutOfRangeError):
            sess.run(batched)
        for runner in threads:
            runner.join()
开发者ID:0ruben,项目名称:tensorflow,代码行数:29,代码来源:input_test.py
示例2: testNoLimit
def testNoLimit(self):
    """`limit_epochs` without `num_epochs` keeps yielding the wrapped value."""
    with self.test_session():
        unlimited = tf.train.limit_epochs(tf.constant(7))
        tf.initialize_local_variables().run()
        remaining = 100
        while remaining > 0:
            self.assertEqual(7, unlimited.eval())
            remaining -= 1
开发者ID:0ruben,项目名称:tensorflow,代码行数:7,代码来源:input_test.py
示例3: __init__
def __init__(self, model_def_file, class_labels_file):
    """Build the classification graph, restore weights and load label names.

    Args:
        model_def_file: Checkpoint path handed to `Saver.restore`.
        class_labels_file: Path of a UTF-8 text file with one class label per
            line.
    """
    # Local import: io.open decodes text identically on Python 2 and 3,
    # whereas calling .decode() on the result of open().read() fails on 3.
    import io

    logging.info('Loading net and associated files...')
    with tf.Graph().as_default(), tf.device('cpu:0'):
        self.sess = tf.Session()
        self.image_buffer = tf.placeholder(tf.string)
        image = tf.image.decode_jpeg(self.image_buffer, channels=3)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = self.eval_image(image, 299, 299)
        # Shift and scale pixel values: (image - 0.5) * 2.0.
        image = tf.sub(image, 0.5)
        image = tf.mul(image, 2.0)
        images = tf.expand_dims(image, 0)
        # Run inference.
        logits, predictions = inception_model.inference(
            images, NUM_CLASSES + 1)
        # Transform output to topK result.
        self.values, self.indices = tf.nn.top_k(
            predictions, NUM_TOP_CLASSES)
        variable_averages = tf.train.ExponentialMovingAverage(
            inception_model.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        tf.initialize_all_variables().run(session=self.sess)
        tf.initialize_local_variables().run(session=self.sess)
        saver = tf.train.Saver(variables_to_restore)
        saver.restore(self.sess, model_def_file)
        # Index 0 is the placeholder label 'none'; file labels follow it.
        # NOTE(review): presumably matches the extra class requested via
        # NUM_CLASSES + 1 -- confirm against the model definition.
        self.label_names = ['none']
        with io.open(class_labels_file, encoding="utf-8") as f:
            self.label_names.extend(f.read().splitlines())
开发者ID:hetaoaoao,项目名称:tensorflow_web_deploy,代码行数:34,代码来源:PyClassification.py
示例4: _testRemoveSqueezableDimensions
def _testRemoveSqueezableDimensions(
    self, predictions_have_static_shape, predictions_have_extra_dim, labels_have_static_shape, labels_have_extra_dim
):
    """Exercise `remove_squeezable_dimensions` for one shape configuration.

    Args:
        predictions_have_static_shape: Build predictions as a constant (true)
            or as a fed placeholder (false).
        predictions_have_extra_dim: Wrap every prediction in its own list.
        labels_have_static_shape: Same choice as above, for labels.
        labels_have_extra_dim: Wrap every label in its own list.
    """
    # At most one of the two inputs may carry the extra dimension.
    assert not predictions_have_extra_dim or not labels_have_extra_dim
    predictions_value = (0, 1, 1, 0, 0, 1, 0)
    labels_value = (0, 0, 1, 1, 0, 0, 0)

    if predictions_have_extra_dim:
        input_predictions_value = [[p] for p in predictions_value]
    else:
        input_predictions_value = predictions_value
    if labels_have_extra_dim:
        input_labels_value = [[l] for l in labels_value]
    else:
        input_labels_value = labels_value

    with tf.Graph().as_default() as g:
        feed_dict = {}

        def build_input(value, is_static, name):
            """Return a constant, or a placeholder registered in feed_dict."""
            if is_static:
                return tf.constant(value, dtype=tf.int32)
            tensor = tf.placeholder(dtype=tf.int32, name=name)
            feed_dict[tensor] = value
            return tensor

        predictions = build_input(
            input_predictions_value, predictions_have_static_shape, "predictions")
        labels = build_input(
            input_labels_value, labels_have_static_shape, "labels")

        squeezed = tf.contrib.framework.remove_squeezable_dimensions(
            predictions, labels)
        squeezed_predictions, squeezed_labels = squeezed
        with self.test_session(g):
            tf.initialize_local_variables().run()
            self.assertAllClose(predictions_value, squeezed_predictions.eval(feed_dict=feed_dict))
            self.assertAllClose(labels_value, squeezed_labels.eval(feed_dict=feed_dict))
开发者ID:pronobis,项目名称:tensorflow,代码行数:30,代码来源:tensor_util_test.py
示例5: read_data_int64
def read_data_int64(input_fname):
    """Read every record of a TFRecord file once, inside a `tictoc()` block.

    Each record's variable-length int64 feature "bit_features" is converted
    to a dense tensor and reshaped to rows of 62 values. The values are
    evaluated and discarded; only a progress line is printed every 10000
    records.

    Args:
        input_fname: Path of the TFRecord file to read.
    """
    # NOTE(review): tictoc() is presumably a timing context manager -- confirm.
    with tictoc():
        input_fname_queue = tf.train.string_input_producer([input_fname], num_epochs=1)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(input_fname_queue)
        features = {'bit_features' : tf.VarLenFeature(tf.int64)}
        parsed_example = tf.parse_single_example(serialized_example, features)
        bit_features = parsed_example['bit_features']
        bit_features = tf.sparse_tensor_to_dense(bit_features)
        bit_features = tf.reshape(bit_features, [-1, 62])
        with tf.Session() as sess:
            # num_epochs=1 is tracked by a local variable, so the local
            # initializer must run as well as the global one.
            sess.run(tf.initialize_all_variables())
            sess.run(tf.initialize_local_variables())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                i = 0
                while not coord.should_stop():
                    # Evaluate purely to pull the record through the pipeline.
                    sess.run(bit_features)
                    if i % 10000 == 0:
                        print("substance {}".format(i))
                    i += 1
            except tf.errors.OutOfRangeError:
                # Raised once the single epoch is exhausted: normal termination.
                pass
            finally:
                coord.request_stop()
            coord.join(threads)
开发者ID:momeara,项目名称:DeepSEA,代码行数:30,代码来源:benchmark_data_reading.py
示例6: cnn_train
def cnn_train(config, data_len, embed, pf_r1, pf_r2):
config.data_len = data_len
tf.reset_default_graph()
with tf.Session() as session:
# build model
with tf.variable_scope("cnn_ch", reuse=None):
m_train = ch_model(config)
with tf.variable_scope("cnn_ch", reuse=True):
m_valid = ch_model(config)
doc_datas, pf_r1s, pf_r2s, labels = read_batch(config.csv_file, config, True)
doc_datas_v, pf_r1s_V, pf_r2s_v, labels_v = read_batch(config.csv_file, config, False)
for item in tf.all_variables():
print "var: ", item
for item in tf.local_variables():
print "local:", item
loss, _ = m_train.inference(doc_datas, pf_r1s, pf_r2s, labels)
loss_v, acc_v = m_valid.inference(doc_datas_v, pf_r1s_V, pf_r2s_v, labels_v)
train_op = m_train.train(loss)
tf.initialize_all_variables().run()
tf.initialize_local_variables().run()
m_train.assign_word_embed(session, embed)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord, sess=session)
epoch = 0
step = 0
min_cost = sys.maxint
try:
while not coord.should_stop():
_, f_l = session.run([train_op, loss])
step += 1
if step == config.data_len // config.batch_size:
cost = 0.0
acc = 0.0
for i in range(step):
v_l, acc_l = session.run([loss_v, acc_v])
cost += v_l
acc += acc_l
cost /= step
acc /= step
if cost < min_cost:
min_cost = cost
print "save model as cost:", cost
m_train.saver.save(session, config.model_path)
print "epoch: ", epoch, "loss: ", cost, "acc: ", acc, "step:", step
step = 0
epoch += 1
except tf.errors.OutOfRangeError:
print("Done training")
finally:
coord.request_stop()
coord.join(threads)
开发者ID:sww9370,项目名称:Relation_Extraction,代码行数:58,代码来源:extraction_tensorflow.py
示例7: main
def main(_):
if FLAGS.train_data:
num_labels, num_features, train_data, train_labels = extract_data(FLAGS.train_data, feature_limit=FEATURE_LIMIT)
else:
num_labels, num_features = 2, FEATURE_LIMIT
train_data, train_labels = [], []
print "labels", num_labels, "features", num_features
if FLAGS.test_data:
_, _, test_data, test_labels = extract_data(FLAGS.test_data, feature_limit=FEATURE_LIMIT)
else:
test_data, test_labels = [], []
train_size = len(train_data)
model = LinearModel(num_features, num_labels, FLAGS.learning_rate)
# Create local session to train and test
with tf.Session(graph=model.graph) as s:
ckpt = tf.train.get_checkpoint_state(FLAGS.models)
if ckpt and ckpt.model_checkpoint_path:
model.saver.restore(s, ckpt.model_checkpoint_path)
print "Model loaded from", ckpt.model_checkpoint_path
else:
model.init.run()
print "Initialized"
if test_data:
print 'testing'
correct = 0
total = 0
tf.initialize_local_variables()
for i in range(len(test_data) // BATCH_SIZE):
offset = i * BATCH_SIZE
batch_data = transform(test_data[offset:(offset + BATCH_SIZE)], num_features)
batch_labels = test_labels[offset:(offset + BATCH_SIZE)]
c = s.run(
[model.correct_sum],
feed_dict={model.x: batch_data, model.y_: batch_labels})
correct += c[0]
total += BATCH_SIZE
print correct, total, "accuracy:", float(correct) / total
return
# Iterate and train.
average_loss = 0
for step in xrange(FLAGS.train_steps * len(train_data) // BATCH_SIZE):
offset = (step * BATCH_SIZE) % train_size
batch_data = transform(train_data[offset: (offset + BATCH_SIZE)], num_features)
batch_labels = train_labels[offset: (offset + BATCH_SIZE)]
_, loss_val = s.run([model.optimizer, model.cross_entropy],
feed_dict={model.x: batch_data, model.y_: batch_labels})
average_loss += loss_val
if step > 0 and step % K == 0:
print "Average loss at step: ", model.global_step.eval(), " loss: ", average_loss / K
average_loss = 0
checkpoint_path = os.path.join(FLAGS.models, "pe.ckpt")
model.saver.save(s, checkpoint_path, global_step=model.global_step)
开发者ID:yaoyaowd,项目名称:tensorflow_demo,代码行数:57,代码来源:linear_model.py
示例8: test_empty_labels_and_scores_gives_nan_auc
def test_empty_labels_and_scores_gives_nan_auc(self):
    """With zero-length labels and scores the histogram AUC must be NaN."""
    with self.test_session():
        empty_labels = tf.constant([], shape=[0], dtype=tf.bool)
        empty_scores = tf.constant([], shape=[0], dtype=tf.float32)
        auc, update_op = tf.contrib.metrics.auc_using_histogram(
            empty_labels, empty_scores, [0, 1.])
        tf.initialize_local_variables().run()
        update_op.run()
        auc_value = auc.eval()
        self.assertTrue(np.isnan(auc_value))
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:10,代码来源:histogram_ops_test.py
示例9: testLimit
def testLimit(self):
    """`limit_epochs(num_epochs=2)` yields twice, then raises OutOfRange."""
    with self.test_session():
        love_me_two_times = tf.train.limit_epochs(
            tf.constant("Love Me"), num_epochs=2)
        tf.initialize_all_variables().run()
        tf.initialize_local_variables().run()
        # Exactly two evaluations are allowed by num_epochs=2.
        for _ in range(2):
            self.assertEqual(b"Love Me", love_me_two_times.eval())
        with self.assertRaises(tf.errors.OutOfRangeError):
            love_me_two_times.eval()
开发者ID:0ruben,项目名称:tensorflow,代码行数:10,代码来源:input_test.py
示例10: _check_auc
def _check_auc(self,
               nbins=100,
               desired_auc=0.75,
               score_range=None,
               num_records=50,
               frac_true=0.5,
               atol=0.05,
               num_updates=10):
    """Check auc accuracy against synthetic data.

    Args:
      nbins: nbins arg from contrib.metrics.auc_using_histogram.
      desired_auc: Number in [0, 1]. The desired auc for synthetic data.
      score_range: 2-tuple, (low, high), giving the range of the resultant
        scores. Defaults to [0, 1.].
      num_records: Positive integer. The number of records to return.
      frac_true: Number in (0, 1). Expected fraction of resultant labels that
        will be True. This is just in expectation...more or less may actually
        be True.
      atol: Absolute tolerance for final AUC estimate.
      num_updates: Update internal histograms this many times, each with a new
        batch of synthetic data, before computing final AUC.

    Raises:
      AssertionError: If resultant AUC is not within atol of theoretical AUC
        from synthetic data.
    """
    # Apply the documented default only when the caller passed nothing; the
    # previous `[0, 1.] or score_range` always discarded the caller's range.
    if score_range is None:
        score_range = [0, 1.]
    with self.test_session():
        labels = tf.placeholder(tf.bool, shape=[num_records])
        scores = tf.placeholder(tf.float32, shape=[num_records])
        auc, update_op = tf.contrib.metrics.auc_using_histogram(labels,
                                                                scores,
                                                                score_range,
                                                                nbins=nbins)
        tf.initialize_local_variables().run()
        # Updates, then extract auc.
        for _ in range(num_updates):
            labels_a, scores_a = synthetic_data(desired_auc, score_range,
                                                num_records, self.rng, frac_true)
            update_op.run(feed_dict={labels: labels_a, scores: scores_a})
        # NOTE(review): this extra batch is never fed to the metric. It is kept
        # because it presumably advances self.rng, and dropping it could shift
        # the data seen by later checks -- confirm before removing.
        labels_a, scores_a = synthetic_data(desired_auc, score_range, num_records,
                                            self.rng, frac_true)
        # Fetch current auc, and verify that fetching again doesn't change it.
        auc_eval = auc.eval()
        self.assertAlmostEqual(auc_eval, auc.eval(), places=5)
    msg = ('nbins: %s, desired_auc: %s, score_range: %s, '
           'num_records: %s, frac_true: %s, num_updates: %s') % (nbins,
                                                                  desired_auc,
                                                                  score_range,
                                                                  num_records,
                                                                  frac_true,
                                                                  num_updates)
    np.testing.assert_allclose(desired_auc, auc_eval, atol=atol, err_msg=msg)
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:55,代码来源:histogram_ops_test.py
示例11: run_eval
def run_eval(dataset, hps, logdir, mode, num_eval_steps):
    """Evaluate each checkpoint the loader yields and log its perplexity.

    Args:
        dataset: Object whose `iterate_once(batch, steps)` yields (x, y, w).
        hps: Hyper-parameters; `num_sampled` and `keep_prob` are overwritten
            for evaluation.
        logdir: Base directory; checkpoints are read from ``logdir/train`` and
            summaries are written to ``logdir/<mode>``.
        mode: Name of the summary sub-directory.
        num_eval_steps: Maximum number of batches evaluated per checkpoint.
    """
    with tf.variable_scope("model"):
        hps.num_sampled = 0
        hps.keep_prob = 1.0
        model = LM(hps, "eval", "/cpu:0")

    if not hps.average_params:
        saver = tf.train.Saver()
    else:
        print("Averaging parameters for evaluation.")
        saver = tf.train.Saver(model.avg_dict)

    # Evaluation thread pools: 20 intra-op threads and 1 inter-op thread.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    intra_op_parallelism_threads=20,
                                    inter_op_parallelism_threads=1)
    sess = tf.Session(config=session_config)
    summary_dir = logdir + '/' + mode
    sw = tf.train.SummaryWriter(summary_dir, sess.graph)
    ckpt_loader = CheckpointLoader(saver, model.global_step, logdir + "/train")

    with sess.as_default():
        while ckpt_loader.load_checkpoint():
            global_step = ckpt_loader.last_global_step
            data_iterator = dataset.iterate_once(hps.batch_size * hps.num_gpus,
                                                 hps.num_steps)
            tf.initialize_local_variables().run()
            loss_nom = 0.0
            loss_den = 0.0
            for i, (x, y, w) in enumerate(data_iterator):
                if i >= num_eval_steps:
                    break
                feed = {model.x: x, model.y: y, model.w: w}
                batch_loss = sess.run(model.loss, feed)
                loss_nom += batch_loss
                loss_den += w.mean()
                # Running value: accumulated loss over accumulated mean weight.
                running = loss_nom / loss_den
                sys.stdout.write("%d: %.3f (%.3f) ... " % (i, running, np.exp(running)))
                sys.stdout.flush()
            sys.stdout.write("\n")

            log_perplexity = loss_nom / loss_den
            perplexity = np.exp(log_perplexity)
            print("Results at %d: log_preplexity = %.3f perplexity = %.3f" % (
                global_step, log_perplexity, perplexity))

            summary = tf.Summary()
            summary.value.add(tag='eval/log_perplexity', simple_value=log_perplexity)
            summary.value.add(tag='eval/perplexity', simple_value=np.exp(log_perplexity))
            sw.add_summary(summary, global_step)
            sw.flush()
开发者ID:IgorWang,项目名称:RNNLM,代码行数:52,代码来源:run_utils.py
示例12: test_input_fname_producer
def test_input_fname_producer(input_fname):
    """Dequeue one filename from a `string_input_producer` as a smoke check.

    Args:
        input_fname: Filename placed into the producer queue.
    """
    # NOTE(review): interactive breakpoint left in place deliberately; remove
    # it before running this non-interactively.
    import pdb
    pdb.set_trace()
    with tf.Session() as sess:
        queue = tf.train.string_input_producer(
            [input_fname], num_epochs=None, shuffle=False)
        dequeue = queue.dequeue()
        sess.run(tf.initialize_all_variables())
        sess.run(tf.initialize_local_variables())
        # With num_epochs=None the producer never finishes on its own, so a
        # bare thread.join() would block forever. A Coordinator lets us ask
        # the queue-runner threads to stop before joining them.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            sess.run(dequeue)
        finally:
            coord.request_stop()
        coord.join(threads)
开发者ID:momeara,项目名称:DeepSEA,代码行数:14,代码来源:benchmark_data_reading.py
示例13: _testTwoThreadsHelper
def _testTwoThreadsHelper(self, use_dict):
    """Check `tf.train.shuffle_batch` over a counter, a sparse tensor and a string.

    Args:
        use_dict: If true, the inputs are passed to `shuffle_batch` as a dict
            keyed "c", "s" and "S"; otherwise they are passed as a list.
    """
    with self.test_session() as sess:
        batch_size = 10
        num_batches = 3
        total = num_batches * batch_size
        zero64 = tf.constant(0, dtype=tf.int64)
        examples = tf.Variable(zero64)
        counter = examples.count_up_to(total)
        sparse_counter = tf.SparseTensor(
            indices=tf.reshape(zero64, [1, 1]),
            values=tf.pack([tf.cast(counter, tf.float32)]),
            shape=[1])

        # Both input layouts share the same batching options.
        shuffle_options = dict(
            batch_size=batch_size, capacity=32,
            min_after_dequeue=16, seed=141421)
        if not use_dict:
            batched_fetch = tf.train.shuffle_batch(
                [counter, sparse_counter, "string"], **shuffle_options)
        else:
            batched = tf.train.shuffle_batch(
                {"c": counter, "s": sparse_counter, "S": "string"},
                **shuffle_options)
            batched_fetch = [batched["c"], batched["s"], batched["S"]]

        tf.initialize_all_variables().run()
        tf.initialize_local_variables().run()
        threads = tf.train.start_queue_runners()

        all_counts = []
        for _ in range(num_batches):
            counts, sparse, strings = sess.run(batched_fetch)
            self.assertEqual(len(counts), batch_size)
            all_counts.extend(counts)
            expected_indices = np.vstack(
                (np.arange(batch_size), np.zeros(batch_size))).T
            self.assertAllEqual(sparse.indices, expected_indices)
            self.assertAllEqual(counts, sparse.values)
            self.assertAllEqual(sparse.shape, [batch_size, 1])
            self.assertAllEqual(strings, [b"string"] * batch_size)

        # Shuffled output: consecutive differences must not all be equal,
        # yet every expected number has to appear.
        deltas = [later - earlier
                  for earlier, later in zip(all_counts, all_counts[1:])]
        self.assertFalse(all(d == deltas[0] for d in deltas))
        self.assertItemsEqual(all_counts, range(total))

        # The counter is exhausted, so one more dequeue must fail.
        with self.assertRaises(tf.errors.OutOfRangeError):
            sess.run(batched_fetch)
        for runner in threads:
            runner.join()
开发者ID:0ruben,项目名称:tensorflow,代码行数:49,代码来源:input_test.py
示例14: test_batch_text_lines
def test_batch_text_lines(self):
    """Read five text lines in batches of three for a single epoch."""
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
        inputs = tf.contrib.learn.io.read_batch_examples(
            [filename], 3, reader=tf.TextLineReader,
            randomize_input=False, num_epochs=1, queue_capacity=10,
            read_batch_size=10, name="my_batch")
        session.run(tf.initialize_local_variables())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(session, coord=coord)

        # One full batch, then the two remaining lines.
        for expected in ([b"A", b"B", b"C"], [b"D", b"E"]):
            self.assertAllEqual(session.run(inputs), expected)
        with self.assertRaises(errors.OutOfRangeError):
            session.run(inputs)

        coord.request_stop()
        coord.join(threads)
开发者ID:todda,项目名称:tensorflow,代码行数:25,代码来源:graph_io_test.py
示例15: testMultipleUpdatesWithWeightedValues
def testMultipleUpdatesWithWeightedValues(self):
    """Streaming accuracy counts only the examples with non-zero weight.

    Predictions and labels agree on the first two examples and differ on the
    last two, whose weights are 0, so the expected accuracy is 1.0.
    """
    with self.test_session() as sess:

        def filled_queue(dtype, values):
            """Create a 4-slot (1, 1) queue, enqueue values, return dequeue."""
            queue = tf.FIFOQueue(4, dtypes=dtype, shapes=(1, 1))
            for value in values:
                _enqueue_vector(sess, queue, [value])
            return queue.dequeue()

        predictions = filled_queue(tf.float32, (0, 1, 2, 1))
        labels = filled_queue(tf.float32, (0, 1, 1, 2))
        weights = filled_queue(tf.int64, (1, 1, 0, 0))

        accuracy, update_op = tf.contrib.metrics.streaming_accuracy(
            predictions, labels, weights)

        sess.run(tf.initialize_local_variables())
        updates_left = 4
        while updates_left:
            sess.run(update_op)
            updates_left -= 1
        self.assertEqual(1.0, accuracy.eval())
开发者ID:01bui,项目名称:tensorflow,代码行数:33,代码来源:metric_ops_test.py
示例16: testFinalOpsIsEvaluated
def testFinalOpsIsEvaluated(self):
    """The value of `final_op` is what `slim.evaluation.evaluation` returns."""
    update_op = slim.metrics.streaming_accuracy(
        self._predictions, self._labels)[1]
    all_init = tf.initialize_all_variables()
    local_init = tf.initialize_local_variables()
    init_op = tf.group(all_init, local_init)

    with self.test_session() as sess:
        accuracy_value = slim.evaluation.evaluation(
            sess, init_op=init_op, final_op=update_op)
        self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
开发者ID:yxiong,项目名称:tensorflow,代码行数:7,代码来源:evaluation_test.py
示例17: testMultipleMetricsOnMultipleBatchesOfSizeOne
def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
    """Update streaming MAE and MSE together over two single-row batches."""
    with self.test_session() as sess:

        def filled_queue(rows):
            """Create a 2-slot (1, 3) float queue, enqueue rows, return dequeue."""
            queue = tf.FIFOQueue(2, dtypes=tf.float32, shapes=(1, 3))
            for row in rows:
                _enqueue_vector(sess, queue, row)
            return queue.dequeue()

        predictions = filled_queue(([10, 8, 6], [-4, 3, -1]))
        labels = filled_queue(([1, 3, 2], [2, 4, 6]))

        mae, ma_update_op = tf.contrib.metrics.streaming_mean_absolute_error(
            predictions, labels)
        mse, ms_update_op = tf.contrib.metrics.streaming_mean_squared_error(
            predictions, labels)

        sess.run(tf.initialize_local_variables())
        both_updates = [ma_update_op, ms_update_op]
        for _ in range(2):
            sess.run(both_updates)

        # Absolute errors sum to 32 and squared errors to 208 over 6 values.
        self.assertAlmostEqual(32 / 6.0, mae.eval(), 5)
        self.assertAlmostEqual(208 / 6.0, mse.eval(), 5)
开发者ID:01bui,项目名称:tensorflow,代码行数:25,代码来源:metric_ops_test.py
示例18: testNullString
def testNullString(self):
    """An empty string list makes the producer's dequeue raise OutOfRange."""
    # The check is indirect: the queue runner is expected to die with an
    # assertion error on the empty input tensor, which in turn makes the
    # dequeue fail with an OutOfRangeError.
    with self.test_session():
        coord = tf.train.Coordinator()
        no_strings = tf.constant([], dtype=tf.string)
        dequeue = tf.train.string_input_producer(no_strings).dequeue()
        tf.initialize_all_variables().run()
        tf.initialize_local_variables().run()
        threads = tf.train.start_queue_runners(coord=coord)
        with self.assertRaises(tf.errors.OutOfRangeError):
            dequeue.eval()
        coord.request_stop()
        for runner in threads:
            runner.join()
开发者ID:0ruben,项目名称:tensorflow,代码行数:16,代码来源:input_test.py
示例19: test_read_csv
def test_read_csv(self):
    """Read a three-line file one line per batch for a single epoch."""
    gfile.Glob = self._orig_glob
    tempdir = tempfile.mkdtemp()
    filename = os.path.join(tempdir, "file.csv")
    # Close the file explicitly so the data is flushed before it is read,
    # instead of relying on garbage collection of the unnamed handle.
    with gfile.Open(filename, "w") as f:
        f.write("ABC\nDEF\nGHK\n")
    batch_size = 1
    queue_capacity = 5
    name = "my_batch"
    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
        inputs = tf.contrib.learn.io.read_batch_examples(
            filename, batch_size,
            reader=tf.TextLineReader, randomize_input=False,
            num_epochs=1, queue_capacity=queue_capacity, name=name)
        session.run(tf.initialize_local_variables())
        coord = tf.train.Coordinator()
        # Keep the thread handles so they can be joined after the stop request.
        threads = tf.train.start_queue_runners(session, coord=coord)
        self.assertAllEqual(session.run(inputs), [b"ABC"])
        self.assertAllEqual(session.run(inputs), [b"DEF"])
        self.assertAllEqual(session.run(inputs), [b"GHK"])
        with self.assertRaises(errors.OutOfRangeError):
            session.run(inputs)
        coord.request_stop()
        coord.join(threads)
开发者ID:343829084,项目名称:tensorflow,代码行数:27,代码来源:graph_io_test.py
示例20: compute_accuracy
def compute_accuracy(x, l, mask):
    """Run the streaming accuracy update for `FLAGS.eval_steps` batches.

    Args:
        x: Model input passed to `get_probs`.
        l: Label tensor; its dtype is used for the argmax output.
        mask: Weights passed to `tf.metrics.accuracy`.

    Returns:
        The result of the last update run: a single accuracy value, or a pair
        (main model, surrogate model) when `FLAGS.surrogate_attack` is set.
    """

    def predicted_labels(model):
        """Return the arg-max class of the model's squeezed probabilities."""
        probs = tf.squeeze(model.get_probs(x))
        return tf.argmax(probs, -1, output_type=l.dtype)

    acc_update_op = tf.metrics.accuracy(
        l, predicted_labels(ch_model), weights=mask)[1]
    if FLAGS.surrogate_attack:
        surrogate_update_op = tf.metrics.accuracy(
            l, predicted_labels(sur_ch_model), weights=mask)[1]
        acc_update_op = tf.tuple((acc_update_op, surrogate_update_op))

    sess.run(tf.initialize_local_variables())
    for batch_number in range(1, FLAGS.eval_steps + 1):
        tf.logging.info(
            "\tEvaluating batch [%d / %d]" % (batch_number, FLAGS.eval_steps))
        acc = sess.run(acc_update_op)

    if not FLAGS.surrogate_attack:
        tf.logging.info("\tFinal acc: %.4f" % acc)
    else:
        tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    return acc
开发者ID:qixiuai,项目名称:tensor2tensor,代码行数:25,代码来源:t2t_attack.py
注:本文中的tensorflow.initialize_local_variables函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论