本文整理汇总了Python中tensorflow.local_variables_initializer函数的典型用法代码示例。如果您正苦于以下问题:Python local_variables_initializer函数的具体用法?Python local_variables_initializer怎么用?Python local_variables_initializer使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了local_variables_initializer函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test
def test(model, config, prompts):
    """Synthesize audio for ``prompts`` from the latest checkpoint and log summaries.

    Builds ``model`` in inference mode on the prompt batch, restores weights
    from ``'weights/' + config.save_path`` truncated at its last '/', and writes
    one merged audio + attention-image summary per sample (all at step 0) to
    ``'log/' + config.save_path + '/test'``.

    Args:
        model: Callable invoked as ``model(config, batch_inputs, train=False)``.
        config: Configuration object. ``r``, ``vocab_size`` and ``num_prompts``
            are assigned on it as a side effect.
        prompts: Sequence of text prompts to synthesize.

    Raises:
        ValueError: If no checkpoint is found in the weights directory.
    """
    # 24 kHz for the Blizzard data set, 16 kHz otherwise.
    sr = 24000 if 'blizzard' in config.data_path else 16000
    meta = data_input.load_meta(config.data_path)
    config.r = audio.r
    ivocab = meta['vocab']
    config.vocab_size = len(ivocab)
    with tf.device('/cpu:0'):
        batch_inputs = data_input.load_prompts(prompts, ivocab)
    config.num_prompts = len(prompts)
    with tf.Session() as sess:
        # Both statistics must use the same dtype. The original declared
        # stft_mean as float16 but stft_std as float32, which breaks restoring
        # them and the mixed-dtype arithmetic in `spec*stft_std + stft_mean`.
        stft_mean = tf.get_variable('stft_mean', shape=(1025*audio.r,), dtype=tf.float32)
        stft_std = tf.get_variable('stft_std', shape=(1025*audio.r,), dtype=tf.float32)
        # initialize model
        model = model(config, batch_inputs, train=False)
        train_writer = tf.summary.FileWriter('log/' + config.save_path + '/test', sess.graph)
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            saver = tf.train.Saver()
            print('restoring weights')
            latest_ckpt = tf.train.latest_checkpoint(
                'weights/' + config.save_path[:config.save_path.rfind('/')]
            )
            if latest_ckpt is None:
                raise ValueError('no checkpoint found to restore')
            saver.restore(sess, latest_ckpt)
            stft_mean, stft_std = sess.run([stft_mean, stft_std])
            # Loop until the input pipeline is exhausted (OutOfRangeError).
            while True:
                outputs, alignments, inputs = sess.run([
                    model.output,
                    model.alignments,
                    batch_inputs
                ])
                print('saving samples')
                for spec, words, align in zip(outputs, inputs['text'], alignments):
                    # store a sample to listen to
                    text = ''.join([ivocab[w] for w in words])
                    attention_plot = data_input.generate_attention_plot(align)
                    # Presumably undoes the training-time normalization of the
                    # spectrogram before inverting it to a waveform.
                    sample = audio.invert_spectrogram(spec*stft_std + stft_mean)
                    # NOTE(review): new summary ops are added to the graph for
                    # every sample, so the graph grows with the sample count.
                    merged = sess.run(tf.summary.merge(
                        [tf.summary.audio(text, sample[None, :], sr),
                         tf.summary.image(text, attention_plot)]
                    ))
                    train_writer.add_summary(merged, 0)
        except tf.errors.OutOfRangeError:
            pass
        finally:
            # Always stop the queue runners and flush summaries, even on errors.
            coord.request_stop()
            coord.join(threads)
            train_writer.close()
开发者ID:yhgon,项目名称:Tacotron-tf-barronalex,代码行数:60,代码来源:test.py
示例2: train
def train(model, data, gen, params):
    """Run ``params.num_steps + 1`` alternating discriminator/generator updates.

    Every ``params.log_every`` steps the two losses are printed. When
    ``params.anim_path`` is set, a frame of ``samples(...)`` is captured every
    ``params.anim_every`` steps and the frames are saved as an animation at the
    end; otherwise a single final sample set is plotted.
    """
    frames = []
    batch = params.batch_size
    column = (batch, 1)
    with tf.Session() as session:
        tf.local_variables_initializer().run()
        tf.global_variables_initializer().run()
        for step in range(params.num_steps + 1):
            # Discriminator step: real draws x against noise draws z.
            real = np.reshape(data.sample(batch), column)
            noise = np.reshape(gen.sample(batch), column)
            loss_d = session.run([model.loss_d, model.opt_d],
                                 {model.x: real, model.z: noise})[0]
            # Generator step on a fresh noise batch.
            noise = np.reshape(gen.sample(batch), column)
            loss_g = session.run([model.loss_g, model.opt_g],
                                 {model.z: noise})[0]
            if step % params.log_every == 0:
                print('{}: {:.4f}\t{:.4f}'.format(step, loss_d, loss_g))
            wants_frame = params.anim_path and (step % params.anim_every == 0)
            if wants_frame:
                frames.append(
                    samples(model, session, data, gen.range, params.batch_size)
                )
        if not params.anim_path:
            plot_distributions(
                samples(model, session, data, gen.range, params.batch_size),
                gen.range)
        else:
            save_animation(frames, params.anim_path, gen.range)
开发者ID:yashvardhan90,项目名称:gan-intro,代码行数:35,代码来源:gan.py
示例3: test_empty_labels_and_scores_gives_nan_auc
def test_empty_labels_and_scores_gives_nan_auc(self):
    """AUC over zero records is undefined, so the histogram AUC must be NaN."""
    with self.test_session():
        no_labels = tf.constant([], shape=[0], dtype=tf.bool)
        no_scores = tf.constant([], shape=[0], dtype=tf.float32)
        auc, update_op = tf.contrib.metrics.auc_using_histogram(
            no_labels, no_scores, [0, 1.])
        # The metric's state presumably lives in local variables, hence the
        # local initializer before the first update.
        tf.local_variables_initializer().run()
        update_op.run()
        auc_value = auc.eval()
        self.assertTrue(np.isnan(auc_value))
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:10,代码来源:histogram_ops_test.py
示例4: _check_auc
def _check_auc(self,
               nbins=100,
               desired_auc=0.75,
               score_range=None,
               num_records=50,
               frac_true=0.5,
               atol=0.05,
               num_updates=10):
    """Check auc accuracy against synthetic data.

    Args:
        nbins: nbins arg from contrib.metrics.auc_using_histogram.
        desired_auc: Number in [0, 1]. The desired auc for synthetic data.
        score_range: 2-tuple, (low, high), giving the range of the resultant
            scores. Defaults to [0, 1.].
        num_records: Positive integer. The number of records to return.
        frac_true: Number in (0, 1). Expected fraction of resultant labels that
            will be True. This is just in expectation...more or less may actually
            be True.
        atol: Absolute tolerance for final AUC estimate.
        num_updates: Update internal histograms this many times, each with a new
            batch of synthetic data, before computing final AUC.

    Raises:
        AssertionError: If resultant AUC is not within atol of theoretical AUC
            from synthetic data.
    """
    # The original `[0, 1.] or score_range` always evaluated to [0, 1.]
    # (a non-empty list is truthy), silently ignoring the caller's range.
    if score_range is None:
        score_range = [0, 1.]
    with self.test_session():
        labels = tf.placeholder(tf.bool, shape=[num_records])
        scores = tf.placeholder(tf.float32, shape=[num_records])
        auc, update_op = tf.contrib.metrics.auc_using_histogram(labels,
                                                                scores,
                                                                score_range,
                                                                nbins=nbins)
        tf.local_variables_initializer().run()
        # Updates, then extract auc.
        for _ in range(num_updates):
            labels_a, scores_a = synthetic_data(desired_auc, score_range,
                                                num_records, self.rng, frac_true)
            update_op.run(feed_dict={labels: labels_a, scores: scores_a})
        # NOTE(review): this draw is never fed to update_op; it only advances
        # self.rng. It is kept so the random stream seen by later calls is unchanged.
        synthetic_data(desired_auc, score_range, num_records, self.rng, frac_true)
        # Fetch current auc, and verify that fetching again doesn't change it.
        auc_eval = auc.eval()
        self.assertAlmostEqual(auc_eval, auc.eval(), places=5)
    msg = ('nbins: %s, desired_auc: %s, score_range: %s, '
           'num_records: %s, frac_true: %s, num_updates: %s') % (nbins,
                                                                 desired_auc,
                                                                 score_range,
                                                                 num_records,
                                                                 frac_true,
                                                                 num_updates)
    np.testing.assert_allclose(desired_auc, auc_eval, atol=atol, err_msg=msg)
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:55,代码来源:histogram_ops_test.py
示例5: train
def train(self, DGTrain, DGTest, saver=True):
    """Build the optimization ops, initialize variables and run training.

    Runs ``self.STEPS`` steps on batches from ``DGTrain.Batch(0, self.BATCH_SIZE)``.
    Every ``self.N_PRINT`` steps, the train summary is written, metrics from
    ``self.error_rate`` are printed, and ``self.Validation(DGTest, step)`` is run.

    Args:
        DGTrain: Training data generator; ``length`` feeds the learning-rate schedule.
        DGTest: Data generator passed to ``self.Validation``.
        saver: Unused in this body; kept for interface compatibility.
    """
    epoch = DGTrain.length
    self.LearningRateSchedule(self.LEARNING_RATE, self.K, epoch)
    trainable_var = tf.trainable_variables()
    self.regularize_model()
    self.optimization(trainable_var)
    self.ExponentialMovingAverage(trainable_var, self.DECAY_EMA)
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()
    self.summary_test_writer = tf.summary.FileWriter(self.LOG + '/test',
                                                     graph=self.sess.graph)
    self.summary_writer = tf.summary.FileWriter(self.LOG + '/train', graph=self.sess.graph)
    merged_summary = tf.summary.merge_all()
    steps = self.STEPS
    for step in range(steps):
        batch_data, batch_labels = DGTrain.Batch(0, self.BATCH_SIZE)
        feed_dict = {self.input_node: batch_data,
                     self.train_labels_node: batch_labels}
        # self.optimizer is replaced by self.training_op for the exponential moving decay
        _, l, lr, predictions, s = self.sess.run(
            [self.training_op, self.loss, self.learning_rate,
             self.train_prediction, merged_summary],
            feed_dict=feed_dict)
        if step % self.N_PRINT == 0:
            i = datetime.now()
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print(i.strftime('%Y/%m/%d %H:%M:%S: \n '))
            self.summary_writer.add_summary(s, step)
            error, acc, acc1, recall, prec, f1 = self.error_rate(predictions, batch_labels, step)
            print(' Step %d of %d' % (step, steps))
            # Format inside the call: `print(...) % lr` is a TypeError on Python 3.
            print(' Learning rate: %.5f \n' % lr)
            # `%1.f` (width 1, precision 0) was a typo for `%.1f`, as used for Accuracy.
            print(' Mini-batch loss: %.5f \n Accuracy: %.1f%% \n acc1: %.1f%% \n recall: %.1f%% \n prec: %.1f%% \n f1 : %.1f%% \n' %
                  (l, acc, acc1, recall, prec, f1))
            self.Validation(DGTest, step)
开发者ID:PeterJackNaylor,项目名称:PhD_Fabien,代码行数:50,代码来源:ObjectOriented.py
示例6: main
def main(model_config, train_config, track_config):
    """Build the Siamese model graph and write a checkpoint to ``train_dir``.

    If ``train_dir`` holds no checkpoint, all variables are initialized (the
    pretrained embedding is loaded via ``model.init_fn`` when
    ``embedding_checkpoint_file`` is set) and the checkpoint is saved at step 0.
    Otherwise the latest checkpoint is restored and re-saved at
    ``global_step + 1``.
    """
    # Create training directory
    train_dir = train_config['train_dir']
    if not tf.gfile.IsDirectory(train_dir):
        tf.logging.info('Creating training directory: %s', train_dir)
        tf.gfile.MakeDirs(train_dir)
    # Build the Tensorflow graph
    g = tf.Graph()
    with g.as_default():
        # Set fixed seed
        np.random.seed(train_config['seed'])
        tf.set_random_seed(train_config['seed'])
        # Build the model
        model = siamese_model.SiameseModel(model_config, train_config, mode='inference')
        model.build()
        # Save configurations for future reference
        save_cfgs(train_dir, model_config, train_config, track_config)
        saver = tf.train.Saver(tf.global_variables(),
                               max_to_keep=train_config['max_checkpoints_to_keep'])
        # Dynamically allocate GPU memory
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(gpu_options=gpu_options)
        # Context-managed so the session (and its device memory) is released on
        # exit; the original never closed it.
        with tf.Session(config=sess_config) as sess:
            model_path = tf.train.latest_checkpoint(train_dir)
            if not model_path:
                # Initialize all variables
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())
                start_step = 0
                # Load pretrained embedding model if needed
                if model_config['embed_config']['embedding_checkpoint_file']:
                    model.init_fn(sess)
            else:
                logging.info('Restore from last checkpoint: {}'.format(model_path))
                sess.run(tf.local_variables_initializer())
                saver.restore(sess, model_path)
                start_step = tf.train.global_step(sess, model.global_step.name) + 1
            checkpoint_path = osp.join(train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=start_step)
开发者ID:fossabot,项目名称:SiamFC-TensorFlow,代码行数:49,代码来源:convert_pretrained_model.py
示例7: test_batch_text_lines
def test_batch_text_lines(self):
    """Batches of 3 over a 5-line file yield [A, B, C], then [D, E], then OutOfRange."""
    gfile.Glob = self._orig_glob
    path = self._create_temp_file("A\nB\nC\nD\nE\n")
    with tf.Graph().as_default() as graph, self.test_session(graph=graph) as sess:
        batch = tf.contrib.learn.io.read_batch_examples(
            [path], 3, reader=tf.TextLineReader,
            randomize_input=False, num_epochs=1, queue_capacity=10,
            read_batch_size=10, name="my_batch")
        self.assertAllEqual((None,), batch.get_shape().as_list())
        # The epoch limit is presumably tracked in a local variable.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        for expected in ([b"A", b"B", b"C"], [b"D", b"E"]):
            self.assertAllEqual(sess.run(batch), expected)
        # A single epoch is exhausted after five lines.
        with self.assertRaises(errors.OutOfRangeError):
            sess.run(batch)
        coord.request_stop()
        coord.join(threads)
开发者ID:moolighty,项目名称:tensorflow,代码行数:26,代码来源:graph_io_test.py
示例8: testRoundtrip
def testRoundtrip(self, rate=0.25, count=5, n=500):
    """Tests `resample(x, weights)` and `resample(resample(x, rate), 1/rate)`.

    The first resample should match ``weights``/``rate``; resampling its output
    again at the inverse rates should approximately recover uniform counts.
    """
    foo = self.get_values(count)
    bar = self.get_values(count)
    weights = self.get_weights(count)
    resampled_in, rates = tf.contrib.training.weighted_resample(
        [foo, bar], tf.constant(weights), rate, seed=123)
    resampled_back_out = tf.contrib.training.resample_at_rate(
        resampled_in, 1.0 / rates, seed=456)
    init = tf.local_variables_initializer()
    once = collections.Counter()
    twice = collections.Counter()
    with self.test_session() as s:
        s.run(init)  # initialize
        for _ in range(n):
            first, second = s.run([resampled_in, resampled_back_out])
            # foo and bar are resampled jointly, so their outputs must agree.
            self.assertAllEqual(first[0], first[1])
            self.assertAllEqual(second[0], second[1])
            once.update(first[0])
            twice.update(second[0])
    # assert that resampling worked as expected
    self.assert_expected(weights, rate, once, n)
    # and that re-resampling gives the approx identity.
    uniform = [1.0] * len(weights)
    self.assert_expected(uniform, 1.0, twice, n, abs_delta=0.1 * n * count)
开发者ID:brchiu,项目名称:tensorflow,代码行数:34,代码来源:resample_test.py
示例9: blend_images
def blend_images(data_folder1, data_folder2, out_folder, alpha=.5):
    """Alpha-blend each JPEG under ``data_folder1`` with a rotated counterpart.

    For a file named like ``<a>-<rotation>-...-<name>`` in ``data_folder1``, the
    image ``<name>`` at the mirrored path under ``data_folder2`` is resized so
    its longer side is 224 px, rotated by ``rotation`` quarter turns
    (``tf.image.rot90``), center-cropped/padded to 224x224, and blended with the
    first image via ``Image.blend(img1, img2, alpha)``. The result is written to
    the mirrored path under ``out_folder``. Non-JPEG names are printed.

    NOTE(review): ``Image.blend`` requires both images to have the same size and
    mode, so the images in ``data_folder1`` are presumably already 224x224 RGB.
    """
    path_ph = tf.placeholder(dtype=tf.string)
    label = tf.placeholder(dtype=tf.int32)
    image = tf.image.decode_jpeg(tf.read_file(path_ph), channels=3)
    # Scale factor that maps the longer side to 224 pixels.
    multiplier = tf.div(tf.constant(224, tf.float32),
                        tf.cast(tf.maximum(tf.shape(image)[0], tf.shape(image)[1]), tf.float32))
    # tf.mul was removed in TF 1.0; tf.multiply is the replacement.
    x = tf.cast(tf.round(tf.multiply(tf.cast(tf.shape(image)[0], tf.float32), multiplier)), tf.int32)
    y = tf.cast(tf.round(tf.multiply(tf.cast(tf.shape(image)[1], tf.float32), multiplier)), tf.int32)
    image = tf.image.resize_images(image, [x, y])
    image = tf.image.rot90(image, k=label)
    image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)
    # Scratch file used to hand the TF result to PIL; ensure its folder exists.
    temp_dir = os.path.join(os.getcwd(), "temp")
    temp_file = os.path.join(temp_dir, "temp.jpg")
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)
    # Context-managed so the session is closed even if a file fails to process.
    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        for root, _, files in os.walk(data_folder1):
            for each in files:
                if each.find('.jpg') < 0:
                    print(each)
                    continue
                img1 = Image.open(os.path.join(root, each))
                img2_path = os.path.join(root.replace(data_folder1, data_folder2), each.split("-")[-1])
                rotation = int(each.split("-")[1])
                img2 = sess.run(image, feed_dict={path_ph: img2_path, label: rotation})
                imsave(temp_file, img2)
                img2 = Image.open(temp_file)
                out_image = Image.blend(img1, img2, alpha)
                outfile = os.path.join(root.replace(data_folder1, out_folder), each)
                out_dir = os.path.split(outfile)[0]
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)
                out_image.save(outfile)
开发者ID:Sabrewarrior,项目名称:PhotoOrientation,代码行数:35,代码来源:misc.py
示例10: test
def test(self, p1, p2, steps):
    """Evaluate the model on ``steps`` batches and return the mean metrics.

    Args:
        p1, p2: Passed unchanged to ``ComputeMetrics``.
        steps: Number of batches pulled from the input queue.

    Returns:
        Tuple ``(loss, acc, F1, recall, precision, roc, jac, AJI)``, each the
        sum over ``steps`` batches divided by ``steps``.
    """
    loss, roc = 0., 0.
    acc, F1, recall = 0., 0., 0.
    precision, jac, AJI = 0., 0., 0.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    self.sess.run(init_op)
    self.Saver()
    coord = tf.train.Coordinator()
    # Bind the runners to self.sess explicitly; without sess=, TF uses the
    # default session, which need not be self.sess.
    threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
    try:
        for _ in range(steps):
            feed_dict = {self.is_training: False}
            l, prob, batch_labels = self.sess.run([self.loss, self.train_prediction,
                                                   self.train_labels_node], feed_dict=feed_dict)
            loss += l
            # NOTE(review): only the first item of each batch is scored.
            out = ComputeMetrics(prob[0, :, :, 1], batch_labels[0, :, :, 0], p1, p2)
            acc += out[0]
            roc += out[1]
            jac += out[2]
            recall += out[3]
            precision += out[4]
            F1 += out[5]
            AJI += out[6]
    finally:
        # Stop and join the queue threads even if a step raises.
        coord.request_stop()
        coord.join(threads)
    loss, acc, F1 = np.array([loss, acc, F1]) / steps
    recall, precision, roc = np.array([recall, precision, roc]) / steps
    jac, AJI = np.array([jac, AJI]) / steps
    return loss, acc, F1, recall, precision, roc, jac, AJI
开发者ID:PeterJackNaylor,项目名称:PhD_Fabien,代码行数:30,代码来源:UNet.py
示例11: initialize_variables
def initialize_variables(sess, saver, logdir, checkpoint=None, resume=None):
    """Initialize all variables, then restore a checkpoint when one applies.

    Args:
        sess: Session to initialize variables in.
        saver: Saver to restore variables.
        logdir: Directory to search for checkpoints.
        checkpoint: Checkpoint name inside ``logdir``; defaults to the most
            recent one recorded there.
        resume: True to expect recovering a checkpoint, False to require a
            fresh run, None for no expectation.

    Raises:
        ValueError: If resume expected but no log directory specified.
        RuntimeError: If no resume expected but a checkpoint was found.
    """
    init_op = tf.group(
        tf.local_variables_initializer(),
        tf.global_variables_initializer())
    sess.run(init_op)
    if resume and not (logdir or checkpoint):
        raise ValueError('Need to specify logdir to resume a checkpoint.')
    if not logdir:
        return
    state = tf.train.get_checkpoint_state(logdir)
    if checkpoint:
        # An explicit name is resolved relative to logdir.
        checkpoint = os.path.join(logdir, checkpoint)
    elif state and state.model_checkpoint_path:
        checkpoint = state.model_checkpoint_path
    if not checkpoint:
        return
    if resume is False:
        raise RuntimeError('Found unexpected checkpoint when starting a new run.')
    saver.restore(sess, checkpoint)
开发者ID:shamanez,项目名称:agents,代码行数:30,代码来源:utility.py
示例12: testSummariesAreFlushedToDiskWithoutGlobalStep
def testSummariesAreFlushedToDiskWithoutGlobalStep(self):
    """slim evaluation() must write metric summaries to disk with no global step."""
    output_dir = os.path.join(self.get_temp_dir(), 'flush_test_no_global_step')
    if tf.gfile.Exists(output_dir):  # For running on jenkins.
        tf.gfile.DeleteRecursively(output_dir)
    names_to_metrics, names_to_updates = self._create_names_to_metrics(
        self._predictions, self._labels)
    for name, value in names_to_metrics.items():
        tf.summary.scalar(name, value)
    # tf.train.SummaryWriter is deprecated; tf.summary.FileWriter replaces it
    # and matches the tf.summary.* API used elsewhere in this test.
    summary_writer = tf.summary.FileWriter(output_dir)
    initial_op = tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer())
    eval_op = tf.group(*names_to_updates.values())
    with self.test_session() as sess:
        slim.evaluation.evaluation(
            sess,
            initial_op=initial_op,
            eval_op=eval_op,
            summary_op=tf.summary.merge_all(),
            summary_writer=summary_writer)
        names_to_values = {name: names_to_metrics[name].eval()
                           for name in names_to_metrics}
    # Verify before closing so the test still checks that evaluation() flushed.
    self._verify_summaries(output_dir, names_to_values)
    summary_writer.close()
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:29,代码来源:evaluation_test.py
示例13: run
def run():
    """Read 1000 (image, label) pairs from the TFRecord file(s) in ``data_path``.

    ``data_path`` is a module-level name not visible here. It is passed to
    ``tf.train.string_input_producer``, so it is presumably a list of paths.

    Returns:
        List of ``(image, label)`` tuples. ``image`` is the flat uint8 array
        produced by ``tf.decode_raw`` (no reshape is applied); ``label`` is int32.

    Raises:
        tf.errors.OutOfRangeError: If the single epoch holds fewer than 1000
            records. The queue-runner threads are stopped and joined either way.
    """
    with tf.Session() as sess:
        print("start")
        feature = {'image': tf.FixedLenFeature([], tf.string),
                   'label': tf.FixedLenFeature([], tf.int64)}
        # Create a list of filenames and pass it to a queue
        print(data_path)
        filename_queue = tf.train.string_input_producer(data_path, num_epochs=1)
        # Define a reader and read the next record
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        # Decode the record read by the reader
        features = tf.parse_single_example(serialized_example, features=feature)
        # Convert the image data from string back to the numbers
        image = tf.decode_raw(features['image'], tf.uint8)
        # Cast label data into int32
        label = tf.cast(features['label'], tf.int32)
        # num_epochs is tracked in a local variable, so local init is required.
        init_op = [tf.global_variables_initializer(), tf.local_variables_initializer()]
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        train_list = []
        try:
            for _ in range(1000):
                example, lbl = sess.run([image, label])
                train_list.append((example, lbl))
        finally:
            # Previously skipped when the reader ran out of records.
            coord.request_stop()
            coord.join(threads)
        return train_list
# run()
开发者ID:ykakde,项目名称:trash-classifier,代码行数:34,代码来源:tf_file_reader.py
示例14: main
def main(argv):
    """Smoke-test TFRecord writing/reading against ``FLAGS.gcs_bucket_url``.

    Writes ``FLAGS.num_examples`` records, checks that ``tf_record_iterator``
    reads them all back, and checks that ``TFRecordReader.read`` yields exactly
    that many records before raising OutOfRangeError. Then runs
    ``create_dir_test()`` and ``create_object_test()``. Exits with status 1 on
    any failure.
    """
    del argv  # Unused.
    # Sanity check on the GCS bucket URL.
    if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
        print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
        sys.exit(1)
    # Verify that writing to the records file in GCS works.
    print("\n=== Testing writing and reading of GCS record file... ===")
    example_data = create_examples(FLAGS.num_examples, 5)
    with tf.python_io.TFRecordWriter(FLAGS.gcs_bucket_url) as hf:
        for e in example_data:
            hf.write(e.SerializeToString())
    print("Data written to: %s" % FLAGS.gcs_bucket_url)
    # Verify that reading from the tfrecord file works and that
    # tf_record_iterator works.
    record_iter = tf.python_io.tf_record_iterator(FLAGS.gcs_bucket_url)
    read_count = sum(1 for _ in record_iter)
    print("Read %d records using tf_record_iterator" % read_count)
    if read_count != FLAGS.num_examples:
        print("FAIL: The number of records read from tf_record_iterator (%d) "
              "differs from the expected number (%d)" % (read_count,
                                                         FLAGS.num_examples))
        sys.exit(1)
    # Verify that running the read op in a session works.
    print("\n=== Testing TFRecordReader.read op in a session... ===")
    with tf.Graph().as_default():
        filename_queue = tf.train.string_input_producer([FLAGS.gcs_bucket_url],
                                                        num_epochs=1)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            # A coordinator lets the queue-runner threads be stopped and joined;
            # previously they were started and never shut down.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for index in range(FLAGS.num_examples):
                    print("Read record: %d" % index)
                    sess.run(serialized_example)
                # Reading one more record should trigger an exception.
                try:
                    sess.run(serialized_example)
                    print("FAIL: Failed to catch the expected OutOfRangeError while "
                          "reading one more record than is available")
                    sys.exit(1)
                except tf.errors.OutOfRangeError:
                    print("Successfully caught the expected OutOfRangeError while "
                          "reading one more record than is available")
            finally:
                coord.request_stop()
                coord.join(threads)
    create_dir_test()
    create_object_test()
开发者ID:DILASSS,项目名称:tensorflow,代码行数:60,代码来源:gcs_smoke.py
示例15: predict
def predict(self):
    """Segment every ``*.tif`` in ORIGIN_PREDICT_DIRECTORY with the restored model.

    Restores all variables from CHECK_POINT_PATH and, for the i-th file in
    sorted path order, writes ``<PREDICT_SAVED_DIRECTORY>/<i>.jpg`` holding the
    argmax of ``self.prediction`` over axis 3, multiplied by 255.
    """
    import cv2
    import glob
    import numpy as np
    # NOTE(review): original TODO (translated): "This should not be written this
    # way; it should predict by reading the images directly instead of from the
    # tfrecord, because the order changes and results cannot be matched up."
    # The loop below already reads images with cv2; outputs are matched to
    # inputs only by index, so the paths are sorted to make that mapping stable.
    predict_file_path = sorted(glob.glob(os.path.join(ORIGIN_PREDICT_DIRECTORY, '*.tif')))
    print(len(predict_file_path))
    ckpt_path = CHECK_POINT_PATH
    all_parameters_saver = tf.train.Saver()
    # Build the argmax op once. Creating it inside the loop added a new graph
    # node for every image.
    predicted_labels = tf.argmax(input=self.prediction, axis=3)
    with tf.Session() as sess:  # start a session
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
        for index, image_path in enumerate(predict_file_path):
            # flags=0 reads the image as single-channel grayscale.
            image = np.reshape(a=cv2.imread(image_path, flags=0),
                               newshape=(1, INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL))
            predict_image = sess.run(
                predicted_labels,
                feed_dict={
                    self.input_image: image,
                    self.keep_prob: 1.0, self.lamb: 0.004
                }
            )
            cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % index), predict_image[0] * 255)
    print('Done prediction')
开发者ID:USTCzxm,项目名称:U-net,代码行数:27,代码来源:unet-TF.py
示例16: test_input_pipeline
def test_input_pipeline(self):
    """One epoch over 100 images in batches of 10 must yield exactly 10 full batches."""
    images, _ = dsu.tiny_imagenet_load()
    batch_size = 10
    batches_seen = 0
    with tf.Graph().as_default(), tf.Session() as sess:
        next_batch = dsu.create_input_pipeline(
            images[:100], batch_size=batch_size, n_epochs=1,
            shape=(64, 64, 3), crop_shape=(64, 64, 3))
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))
        coord = tf.train.Coordinator()
        # Finalize the graph: any op created from here on raises an error.
        tf.get_default_graph().finalize()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                batch = sess.run(next_batch)
                assert batch.shape == (batch_size, 64, 64, 3)
                batches_seen += 1
        except tf.errors.OutOfRangeError:
            pass  # end of the single epoch
        finally:
            coord.request_stop()
            coord.join(threads)
    assert batches_seen == 10
开发者ID:pradeeps,项目名称:pycadl,代码行数:28,代码来源:test_dataset_utils.py
示例17: get_hit_rate_and_ndcg
def get_hit_rate_and_ndcg(self, predicted_scores_by_user, items_by_user,
                          top_k=rconst.TOP_K, match_mlperf=False):
    """Compute HR and NDCG for precomputed scores via the NeuMF eval metrics.

    Returns:
        ``[hr, ndcg]``: the values of the update ops for the two metrics
        (``rconst.HR_KEY`` / ``rconst.NDCG_KEY``) after one run.
    """
    # NOTE: these module-level constants are overwritten and not restored.
    rconst.TOP_K = top_k
    rconst.NUM_EVAL_NEGATIVES = predicted_scores_by_user.shape[1] - 1
    batch_size = items_by_user.shape[0]
    num_cols = rconst.NUM_EVAL_NEGATIVES + 1
    user_ids = np.repeat(np.arange(batch_size)[:, np.newaxis], num_cols, axis=1)
    # The last column of items_by_user is passed as the positive item, the
    # remaining columns as negatives; only the duplicate mask is used here.
    _, _, duplicate_mask = data_pipeline.BaseDataConstructor._assemble_eval_batch(
        user_ids, items_by_user[:, -1:], items_by_user[:, :-1], batch_size)
    graph = tf.Graph()
    with graph.as_default():
        logits = tf.convert_to_tensor(
            predicted_scores_by_user.reshape((-1, 1)), tf.float32)
        zero_col = tf.zeros(logits.shape, dtype=logits.dtype)
        softmax_logits = tf.concat([zero_col, logits], axis=1)
        mask = tf.convert_to_tensor(duplicate_mask, tf.float32)
        eval_ops = neumf_model.compute_eval_loss_and_metrics(
            logits=logits, softmax_logits=softmax_logits,
            duplicate_mask=mask, num_training_neg=NUM_TRAIN_NEG,
            match_mlperf=match_mlperf).eval_metric_ops
        # Each metric is a (value, update_op) pair; keep the update ops.
        hr_update = eval_ops[rconst.HR_KEY][1]
        ndcg_update = eval_ops[rconst.NDCG_KEY][1]
        init = [tf.global_variables_initializer(),
                tf.local_variables_initializer()]
    with self.test_session(graph=graph) as sess:
        sess.run(init)
        return sess.run([hr_update, ndcg_update])
开发者ID:pooyadavoodi,项目名称:models,代码行数:34,代码来源:ncf_test.py
示例18: test_smoke
def test_smoke(self):
    """Smoke test for a full pipeline.

    Writes 100 random examples to a temp file, reads them for 2 epochs through
    ``data.read_from_files`` + ``lin.shuffle_batch`` and checks that exactly
    ``num * num_epochs`` examples come out.
    """
    fd, tname = tempfile.mkstemp()
    # mkstemp returns an open OS-level descriptor; the examples are written via
    # the path, so close the descriptor instead of leaking it.
    os.close(fd)
    try:
        num = 100
        num_epochs = 2
        self._write_examples(tname, [self._random_io_data() for _ in range(num)])
        tensors = data.read_from_files([tname], shuffle=True, num_epochs=num_epochs)
        batches = lin.shuffle_batch(tensors=tensors, batch_size=5)
        count = 0
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while True:
                    actual = sess.run(batches)
                    count += len(actual[0])
            except tf.errors.OutOfRangeError as ex:
                coord.request_stop(ex=ex)
            finally:
                coord.request_stop()
                coord.join(threads)
        self.assertEqual(num * num_epochs, count)
    finally:
        # Remove the temp file even when the pipeline or assertion fails.
        os.remove(tname)
开发者ID:usman776,项目名称:dket,代码行数:26,代码来源:test_data.py
示例19: test_keyed_read_text_lines
def test_keyed_read_text_lines(self):
    """Keyed reads of a 3-line file yield ``(<file>:<line>, text)`` one row at a time."""
    gfile.Glob = self._orig_glob
    path = self._create_temp_file("ABC\nDEF\nGHK\n")
    key_prefix = path.encode("utf-8")
    expected_rows = ((b":1", b"ABC"), (b":2", b"DEF"), (b":3", b"GHK"))
    with tf.Graph().as_default() as graph, self.test_session(graph=graph) as sess:
        keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
            path, 1,
            reader=tf.TextLineReader, randomize_input=False,
            num_epochs=1, queue_capacity=5, name="my_batch")
        for tensor in (keys, inputs):
            self.assertAllEqual((None,), tensor.get_shape().as_list())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        for suffix, text in expected_rows:
            self.assertAllEqual(sess.run([keys, inputs]),
                                [[key_prefix + suffix], [text]])
        # The single epoch is exhausted after three lines.
        with self.assertRaises(errors.OutOfRangeError):
            sess.run(inputs)
        coord.request_stop()
        coord.join(threads)
开发者ID:moolighty,项目名称:tensorflow,代码行数:31,代码来源:graph_io_test.py
示例20: predict
def predict(self, weights_dir='/Users/shashank/TensorFlow/SPN/weights/'):
    """Run inference over the whole input queue and save predictions and truth.

    Images are normalized as ``(x - 127.5) / 128.0`` before being fed to
    ``self.net_output``. Per-batch network outputs and the concatenated
    ground-truth fields are stacked and saved as ``test_results.npy`` and
    ``truth.npy`` in the working directory.

    Args:
        weights_dir: Directory passed to ``self.load_weights``. The default is
            the previously hard-coded path.

    Raises:
        ValueError: If the queue produced no batches, so there is nothing to save.
    """
    print('Running inference...')
    self.sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
    self.load_weights(weights_dir)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
    result = []
    truth = []
    count = 0
    try:
        while not coord.should_stop():
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print(count)
            batch_imgs, batch_labels, batch_landmarks, batch_visibility, batch_pose, batch_gender = self.sess.run(
                [self.images, self.labels, self.land, self.vis, self.po, self.gen])
            batch_imgs = (batch_imgs - 127.5) / 128.0
            net_preds = self.sess.run(self.net_output, feed_dict={self.X: batch_imgs})
            result.append(np.concatenate(net_preds, axis=1))
            truth.append(np.concatenate([batch_labels[:, np.newaxis], batch_landmarks,
                                         batch_visibility, batch_pose, batch_gender], axis=1))
            count += 1
    except tf.errors.OutOfRangeError:
        print('Done inference -- epoch limit reached')
    finally:
        coord.request_stop()
        coord.join(threads)
    if not result:
        # np.concatenate([]) would fail with an opaque message.
        raise ValueError('no batches were read; nothing to save')
    np.save('test_results', np.concatenate(result, axis=0))
    np.save('truth', np.concatenate(truth, axis=0))
开发者ID:dmehr,项目名称:HyperFace-TensorFlow-implementation,代码行数:29,代码来源:model_prediction.py
注:本文中的tensorflow.local_variables_initializer函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论