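本文整理汇总了 Python 中 tensorflow.set_random_seed 函数的典型用法代码示例。如果您想了解 set_random_seed 函数的具体用法、调用方式或使用示例,这里精选的代码示例或许能为您提供帮助。注意:tf.set_random_seed 是 TensorFlow 1.x API,它设置的是调用时默认图(graph-level)的随机种子,因此下文示例常在 with tf.Graph().as_default(): 内部调用它;在 TensorFlow 2.x 中请改用 tf.random.set_seed(或兼容接口 tf.compat.v1.set_random_seed)。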
在下文中一共展示了set_random_seed函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: main
def main():
    """Build a small dynamic-LSTM graph with a squared-error loss and pass the
    requested output tensors to ``run_model``.

    Reads command-line arguments from ``argv`` (presumably ``sys.argv``
    imported at module level -- TODO confirm):
      argv[1] -- forwarded to ``run_model`` as its second argument
      argv[2] -- comma-separated names of tensors to fetch
      argv[3] -- the string 'True' enables ``run_model``'s last flag
    """
    tf.set_random_seed(10)
    with tf.Session() as sess:
        rnn_cell = tf.nn.rnn_cell.LSTMCell(10)

        # defining initial state (batch size 4)
        initial_state = rnn_cell.zero_state(4, dtype=tf.float32)

        inputs = tf.Variable(tf.random_uniform(shape=(4, 30, 100)), name='input')
        inputs = tf.identity(inputs, "input_node")

        # NOTE(review): with LSTMCell's default state_is_tuple=True, ``state``
        # is a (c, h) tuple rather than a single [batch_size, cell_state_size]
        # tensor; tf.identity presumably auto-packs it into one tensor below.
        # Confirm that this shape is intended.
        outputs, state = tf.nn.dynamic_rnn(
            rnn_cell, inputs, initial_state=initial_state, dtype=tf.float32)

        y1 = tf.identity(outputs, 'outputs')
        y2 = tf.identity(state, 'state')

        # Regress both outputs and state towards all-ones targets.
        t1 = tf.ones([4, 30, 10])
        t2 = tf.ones([4, 10])
        loss = tf.reduce_sum((y1 - t1) * (y1 - t1)) + tf.reduce_sum((y2 - t2) * (y2 - t2))
        tf.identity(loss, name="lstm_loss")
        # tf.summary.FileWriter('/tmp/log', tf.get_default_graph())

        # Build a real list. In Python 3, ``map`` returns a one-shot lazy
        # iterator, and Session.run does not accept that as a fetch argument.
        # It would also be exhausted after the first pass.
        graph = tf.get_default_graph()
        net_outputs = [graph.get_tensor_by_name(name) for name in argv[2].split(',')]
        run_model(net_outputs, argv[1], None, argv[3] == 'True')
开发者ID:ru003ar,项目名称:BigDL,代码行数:26,代码来源:dynamic_lstm.py
示例2: gradient_memory_mbs
def gradient_memory_mbs():
    """Build the training graph, run one training step, and report timing and
    peak GPU memory.

    Prints the graph-construction time, the training-step time, the loss and
    the peak memory recorded for '/gpu:0'.

    Returns:
      Peak memory used on '/gpu:0' in megabytes (bytes / 1e6).

    Raises:
      AssertionError: if the whole run takes 100 seconds or longer.
    """
    start_time0 = time.perf_counter()
    start_time = start_time0
    tf.reset_default_graph()
    tf.set_random_seed(1)

    train_op, loss = create_train_op_and_loss()
    print("Graph construction: %.2f ms" % (1000 * (time.perf_counter() - start_time)))

    g = tf.get_default_graph()
    ops = g.get_operations()
    # Put the first output of every op matching "block_layer" into the
    # "checkpoints" collection.
    for op in ge.filter_ops_from_regex(ops, "block_layer"):
        tf.add_to_collection("checkpoints", op.outputs[0])

    sess = create_session()
    sessrun(tf.global_variables_initializer())

    # Time the training step itself. Previously the timer was reset again
    # right after this call, so "Compute time" covered only the loss
    # evaluation and the train-step measurement was thrown away.
    start_time = time.perf_counter()
    sessrun(train_op)
    print("Compute time: %.2f ms" % (1000 * (time.perf_counter() - start_time)))
    print("loss %f" % (sess.run(loss),))

    # run_metadata is a module-level object, presumably filled by sessrun().
    mem_use = mem_util.peak_memory(run_metadata)['/gpu:0'] / 1e6
    print("Memory used: %.2f MB " % (mem_use))
    total_time = time.perf_counter() - start_time0
    assert total_time < 100
    return mem_use
开发者ID:BhaskarNallani,项目名称:gradient-checkpointing,代码行数:30,代码来源:imagenet_test.py
示例3: testTrainWithTrace
def testTrainWithTrace(self):
    """Training 300 steps with trace_every_n_steps=100 leaves trace files for
    steps 1, 101 and 201 in the log directory."""
    log_dir = os.path.join(
        tempfile.mkdtemp(prefix=self.get_temp_dir()), 'tmp_logs')

    with tf.Graph().as_default():
        tf.set_random_seed(0)
        inputs_t = tf.constant(self._inputs, dtype=tf.float32)
        labels_t = tf.constant(self._labels, dtype=tf.float32)
        predictions_t = LogisticClassifier(inputs_t)
        slim.losses.log_loss(predictions_t, labels_t)

        total_loss = slim.losses.get_total_loss()
        tf.summary.scalar('total_loss', total_loss)

        sgd = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        train_op = slim.learning.create_train_op(total_loss, sgd)
        final_loss = slim.learning.train(
            train_op, log_dir,
            number_of_steps=300,
            log_every_n_steps=10,
            trace_every_n_steps=100)

    self.assertIsNotNone(final_loss)
    for step in (1, 101, 201):
        trace_path = os.path.join(log_dir, 'tf_trace-%d.json' % step)
        self.assertTrue(os.path.isfile(trace_path))
开发者ID:moolighty,项目名称:tensorflow,代码行数:28,代码来源:learning_test.py
示例4: initialize_parameters
def initialize_parameters():
    """Create the weights and biases of a three-layer network.

    Shapes:
        W1 : [25, 12288]    b1 : [25, 1]
        W2 : [12, 25]       b2 : [12, 1]
        W3 : [6, 12]        b3 : [6, 1]

    Weights use Xavier initialisation with a fixed seed of 1, and biases start
    at zero. The graph-level seed is also fixed to 1.

    Returns:
        dict mapping "W1", "b1", "W2", "b2", "W3", "b3" to the created variables.
    """
    tf.set_random_seed(1)  # so that your "random" numbers match ours

    # (weight name, bias name, rows, cols) per layer, in creation order.
    layer_specs = (
        ("W1", "b1", 25, 12288),
        ("W2", "b2", 12, 25),
        ("W3", "b3", 6, 12),
    )
    parameters = {}
    for w_name, b_name, rows, cols in layer_specs:
        parameters[w_name] = tf.get_variable(
            w_name, [rows, cols],
            initializer=tf.contrib.layers.xavier_initializer(seed=1))
        parameters[b_name] = tf.get_variable(
            b_name, [rows, 1],
            initializer=tf.zeros_initializer())
    return parameters
开发者ID:shriavi,项目名称:datasciencecoursera,代码行数:33,代码来源:Tensorflow+Tutorial.py
示例5: testEmptyUpdateOps
def testEmptyUpdateOps(self):
    """Training with update_ops=[] must leave the BatchNorm moving statistics
    at their initial values (mean 0, variance 1)."""
    with tf.Graph().as_default():
        tf.set_random_seed(0)
        inputs_t = tf.constant(self._inputs, dtype=tf.float32)
        labels_t = tf.constant(self._labels, dtype=tf.float32)
        slim.losses.log_loss(BatchNormClassifier(inputs_t), labels_t)
        total_loss = slim.losses.get_total_loss()

        sgd = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        train_op = slim.learning.create_train_op(total_loss, sgd, update_ops=[])

        get_by_name = tf.contrib.framework.get_variables_by_name
        moving_mean = get_by_name('moving_mean')[0]
        moving_variance = get_by_name('moving_variance')[0]

        with tf.Session() as sess:
            def expect_initial_stats(mean, variance):
                self.assertAllClose(mean, [0] * 4)
                self.assertAllClose(variance, [1] * 4)

            sess.run(tf.global_variables_initializer())
            # Right after initialization: moving_mean == 0, moving_variance == 1.
            expect_initial_stats(*sess.run([moving_mean, moving_variance]))

            for _ in range(10):
                sess.run([train_op])

            # The update ops were skipped, so the moving statistics never moved.
            expect_initial_stats(moving_mean.eval(), moving_variance.eval())
开发者ID:moolighty,项目名称:tensorflow,代码行数:33,代码来源:learning_test.py
示例6: _do_sampling
def _do_sampling(self, logits, num_samples, sampler):
    """Draw samples with ``sampler`` and return per-class empirical frequencies.

    Args:
        logits: Numpy ndarray of shape [batch_size, num_classes].
        num_samples: Int; number of samples to draw.
        sampler: A sampler function that takes (1) a [batch_size, num_classes]
            Tensor, (2) num_samples and returns a [batch_size, num_samples] Tensor.

    Returns:
        A list of batch_size lists of num_classes frequencies
        (count / num_samples; a class never drawn is reported as 0).
    """
    with self.test_session() as sess:
        tf.set_random_seed(1618)
        samples = sess.run(sampler(tf.constant(logits), num_samples))

    batch_size, num_classes = logits.shape
    freqs_mat = []
    for row_idx in range(batch_size):
        counts = collections.Counter(samples[row_idx, :])
        freqs_mat.append([
            counts[cls] * 1. / num_samples if cls in counts else 0
            for cls in range(num_classes)
        ])
    return freqs_mat
开发者ID:0-T-0,项目名称:tensorflow,代码行数:26,代码来源:multinomial_op_test.py
示例7: wide_model
def wide_model(numeric_input, category_input, vocabs):
    """Build the linear ("wide") part of the model.

    Each categorical column i is looked up in its own embedding table
    ``wideem<i>`` of shape [vocabs[i], 8]. The looked-up vectors are
    concatenated with ``numeric_input`` and projected to a single logit by
    the variable ``W``.

    Args:
        numeric_input: 2-D tensor of numeric features (assumed
            [batch, n_numeric] -- TODO confirm).
        category_input: 2-D tensor of category ids with one column per entry
            of ``vocabs`` (assumed [batch, len(vocabs)] -- TODO confirm).
        vocabs: vocabulary size of each categorical column. May be empty, in
            which case only the numeric features are used.

    Returns:
        Tensor of logits, shape [batch, 1].
    """
    # Set the graph-level seed BEFORE creating any variable so that the
    # embedding tables are reproducible too. Previously it was set only after
    # the embeddings had been created, so only ``W`` was seeded.
    tf.set_random_seed(1)

    columns = tf.transpose(category_input)
    features = [numeric_input]
    for i, vocab_size in enumerate(vocabs):
        embedding = tf.get_variable(
            "wideem" + str(i), [vocab_size, 8],
            initializer=tf.contrib.layers.xavier_initializer())
        # Pick column i of the category input, then look up its embeddings.
        col = tf.gather(columns, [i])[0]
        features.append(embedding_ops.embedding_lookup_unique(embedding, col))

    # A single concat replaces the previous chain of pairwise concats.
    # It yields the same values and also works when ``vocabs`` is empty
    # (previously ``category_sum`` stayed None and ``.shape`` crashed).
    wide_features = tf.concat(features, 1)
    w = tf.get_variable("W", [wide_features.shape[1], 1],
                        initializer=tf.contrib.layers.xavier_initializer())
    return tf.matmul(wide_features, w)
开发者ID:ShifuML,项目名称:shifu,代码行数:28,代码来源:wnp_ssgd_not_embadding.py
示例8: testCreateOnecloneWithPS
def testCreateOnecloneWithPS(self):
    """One clone with one PS task: gradients on the worker GPU and variables
    on the PS task's CPU."""
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        clone_args = (tf.constant(self._inputs, dtype=tf.float32),
                      tf.constant(self._labels, dtype=tf.float32))
        deploy_config = model_deploy.DeploymentConfig(num_clones=1, num_ps_tasks=1)

        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(deploy_config, BatchNormClassifier, clone_args)
        self.assertEqual(len(slim.get_variables()), 5)
        self.assertEqual(len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)), 2)

        sgd = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        total_loss, grads_and_vars = model_deploy.optimize_clones(clones, sgd)
        self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
        self.assertEqual(total_loss.op.name, 'total_loss')

        # Use distinct names so the loop does not shadow the graph handle.
        for grad, var in grads_and_vars:
            self.assertDeviceEqual(grad.device, '/job:worker/device:GPU:0')
            self.assertDeviceEqual(var.device, '/job:ps/task:0/CPU:0')
开发者ID:ALISCIFP,项目名称:models,代码行数:26,代码来源:model_deploy_test.py
示例9: __init__
def __init__(self, env, discount=0.90, learning_rate=0.008):
    """Build a value-function network for ``env`` in its own graph and session.

    Args:
        env: environment exposing ``observation_space`` (with ``high``) and a
            discrete ``action_space`` (with ``n``).
        discount: discount factor, stored as ``self.discount``.
        learning_rate: learning rate for the Adam optimizer.
    """
    self.env = env
    self.observation_space = env.observation_space
    self.action_space = env.action_space
    self.action_space_n = self.action_space.n
    self.n_input = len(self.observation_space.high)
    self.n_hidden_1 = 20
    # Learning parameters
    self.learning_rate = learning_rate
    self.discount = discount
    self.num_epochs = 20
    self.batch_size = 32
    self.graph = tf.Graph()
    # Multi-layer perceptron with one hidden layer. The original comment says
    # tanh units; the activation is applied inside multilayer_perceptron.
    with self.graph.as_default():
        tf.set_random_seed(1234)
        self.weights = {
            'h1': tf.Variable(tf.random_normal([self.n_input, self.n_hidden_1])),
            'out': tf.Variable(tf.random_normal([self.n_hidden_1, 1])),
        }
        self.biases = {
            'b1': tf.Variable(tf.random_normal([self.n_hidden_1])),
            'out': tf.Variable(tf.random_normal([1])),
        }
        self.state_input = self.x = tf.placeholder("float", [None, self.n_input])  # State input
        self.return_input = tf.placeholder("float")  # Target return
        self.value_pred = self.multilayer_perceptron(self.state_input, self.weights, self.biases)
        # Mean squared error between the predicted value and the target return.
        self.loss = tf.reduce_mean(tf.pow(self.value_pred - self.return_input, 2))
        self.optim = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)
        # tf.initialize_all_variables is deprecated; global_variables_initializer
        # is its drop-in replacement.
        init = tf.global_variables_initializer()
    print("Value Graph Constructed")
    self.sess = tf.Session(graph=self.graph)
    self.sess.run(init)
开发者ID:mohakbhardwaj,项目名称:reinforcement-learning,代码行数:33,代码来源:cartpole-policy-gradient.py
示例10: testCreateMulticlone
def testCreateMulticlone(self):
    """Four clones share CPU-resident variables. Each clone gets its own name
    scope, its own GPU and two update ops."""
    num_clones = 4
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        clone_args = (tf.constant(self._inputs, dtype=tf.float32),
                      tf.constant(self._labels, dtype=tf.float32))
        deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones)

        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(deploy_config, BatchNormClassifier, clone_args)

        variables = slim.get_variables()
        self.assertEqual(len(variables), 5)
        for var in variables:
            self.assertDeviceEqual(var.device, 'CPU:0')
            self.assertDeviceEqual(var.value().device, 'CPU:0')

        self.assertEqual(len(clones), num_clones)
        for idx, clone in enumerate(clones):
            expected_output = 'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % idx
            self.assertEqual(clone.outputs.op.name, expected_output)
            clone_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
            self.assertEqual(len(clone_update_ops), 2)
            self.assertEqual(clone.scope, 'clone_%d/' % idx)
            self.assertDeviceEqual(clone.device, 'GPU:%d' % idx)
开发者ID:ALISCIFP,项目名称:models,代码行数:27,代码来源:model_deploy_test.py
示例11: testCreateMulticloneWithPS
def testCreateMulticloneWithPS(self):
    """Two clones with two PS tasks: variables alternate between PS tasks
    (index % 2), and each clone runs on its own worker GPU."""
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        clone_args = (tf.constant(self._inputs, dtype=tf.float32),
                      tf.constant(self._labels, dtype=tf.float32))
        deploy_config = model_deploy.DeploymentConfig(num_clones=2, num_ps_tasks=2)

        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(deploy_config, BatchNormClassifier, clone_args)

        variables = slim.get_variables()
        self.assertEqual(len(variables), 5)
        for idx, var in enumerate(variables):
            ps_task = idx % 2
            self.assertDeviceEqual(var.device, '/job:ps/task:%d/device:CPU:0' % ps_task)
            self.assertDeviceEqual(var.device, var.value().device)

        self.assertEqual(len(clones), 2)
        for idx, clone in enumerate(clones):
            expected_output = 'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % idx
            self.assertEqual(clone.outputs.op.name, expected_output)
            self.assertEqual(clone.scope, 'clone_%d/' % idx)
            self.assertDeviceEqual(clone.device, '/job:worker/device:GPU:%d' % idx)
开发者ID:ALISCIFP,项目名称:models,代码行数:26,代码来源:model_deploy_test.py
示例12: testCreateLogisticClassifier
def testCreateLogisticClassifier(self):
    """A single LogisticClassifier clone: two CPU variables, an unscoped
    clone on GPU:0, exactly one loss and no update ops."""
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        clone_args = (tf.constant(self._inputs, dtype=tf.float32),
                      tf.constant(self._labels, dtype=tf.float32))
        deploy_config = model_deploy.DeploymentConfig(num_clones=1)

        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(deploy_config, LogisticClassifier, clone_args)
        first_clone = clones[0]

        variables = slim.get_variables()
        self.assertEqual(len(variables), 2)
        for var in variables:
            self.assertDeviceEqual(var.device, 'CPU:0')
            self.assertDeviceEqual(var.value().device, 'CPU:0')

        self.assertEqual(first_clone.outputs.op.name,
                         'LogisticClassifier/fully_connected/Sigmoid')
        self.assertEqual(first_clone.scope, '')
        self.assertDeviceEqual(first_clone.device, 'GPU:0')
        self.assertEqual(len(slim.losses.get_losses()), 1)
        self.assertEqual(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
开发者ID:ALISCIFP,项目名称:models,代码行数:25,代码来源:model_deploy_test.py
示例13: main
def main(hps):
    """Set up Horovod, data and the model, then train or run inference.

    Args:
        hps: hyper-parameter namespace. Reads ``seed``, ``logdir`` and
            ``inference``, and sets ``train_its``, ``test_its`` and
            ``full_test_its``.
    """
    # Initialize Horovod.
    hvd.init()

    # Create tensorflow session
    sess = tensorflow_session()

    # Give every Horovod rank a distinct but reproducible seed.
    seed = hvd.rank() + hvd.size() * hps.seed
    tf.set_random_seed(seed)
    np.random.seed(seed)

    # Load the dataset and set train_its / test_its.
    train_iterator, test_iterator, data_init = get_data(hps, sess)
    hps.train_its, hps.test_its, hps.full_test_its = get_its(hps)

    # Create the log dir. Every Horovod rank runs this concurrently, so a
    # separate exists()/mkdir() pair can race and raise FileExistsError on
    # some ranks. makedirs(exist_ok=True) is idempotent and also creates
    # missing parent directories.
    logdir = os.path.abspath(hps.logdir) + "/"
    os.makedirs(logdir, exist_ok=True)

    # Create model. Import the module under another name so that the local
    # ``model`` does not shadow the module itself.
    import model as model_module
    model = model_module.model(sess, hps, train_iterator, test_iterator, data_init)

    # Initialize visualization functions
    visualise = init_visualizations(hps, model, logdir)

    if not hps.inference:
        # Perform training
        train(sess, model, hps, logdir, visualise)
    else:
        infer(sess, model, hps, test_iterator)
开发者ID:chinatian,项目名称:glow,代码行数:33,代码来源:train.py
示例14: simple_test
def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
    """Train a model on a seeded environment and check that it earns enough reward.

    Args:
        env_fn: zero-argument environment factory. Every instance is seeded with 0.
        learn_fn: callable that takes the DummyVecEnv and returns a model
            exposing ``initial_state`` and ``step``.
        min_reward_fraction: required average reward per trial.
        n_trials: number of environment steps to roll out.

    Raises:
        AssertionError: if the summed reward does not exceed
            ``min_reward_fraction * n_trials``.
    """
    def seeded_env_fn():
        env = env_fn()
        env.seed(0)
        return env

    np.random.seed(0)
    env = DummyVecEnv([seeded_env_fn])
    config = tf.ConfigProto(allow_soft_placement=True)
    # Enter the Session itself rather than Session.as_default(). as_default()
    # does not close the session on exit, so the old form leaked one session
    # per call. Entering the Session also installs it as the default session.
    with tf.Graph().as_default(), tf.Session(config=config):
        tf.set_random_seed(0)
        model = learn_fn(env)
        sum_rew = 0
        done = True
        for _ in range(n_trials):
            if done:
                obs = env.reset()
                state = model.initial_state
            # Recurrent models carry state through step(); others do not.
            if state is not None:
                a, _, state, _ = model.step(obs, S=state, M=[False])
            else:
                a, _, _, _ = model.step(obs)
            obs, rew, done, _ = env.step(a)
            sum_rew += float(rew)
        print("Reward in {} trials is {}".format(n_trials, sum_rew))
        assert sum_rew > min_reward_fraction * n_trials, \
            'sum of rewards {} is less than {} of the total number of trials {}'.format(sum_rew, min_reward_fraction, n_trials)
开发者ID:MrGoogol,项目名称:baselines,代码行数:26,代码来源:util.py
示例15: __init__
def __init__(self, input_dim=None, output_dim=1, init_path=None, opt_algo='gd', learning_rate=1e-2, l2_weight=0,
             random_seed=None):
    """Build a logistic-regression graph over sparse inputs and open a session for it.

    Args:
        input_dim: number of input features (rows of ``w``).
        output_dim: number of outputs (columns of ``w``, length of ``b``).
        init_path: optional source of initial values, passed to utils.init_var_map.
        opt_algo: optimizer name, passed to utils.get_optimizer.
        learning_rate: optimizer learning rate.
        l2_weight: coefficient of the L2 penalty term.
        random_seed: graph-level seed; left unset when None.
    """
    Model.__init__(self)
    # Variable specs: (name, shape, initializer kind, dtype).
    init_vars = [('w', [input_dim, output_dim], 'xavier', dtype),
                 ('b', [output_dim], 'zero', dtype)]
    self.graph = tf.Graph()
    with self.graph.as_default():
        if random_seed is not None:
            tf.set_random_seed(random_seed)
        self.X = tf.sparse_placeholder(dtype)
        self.y = tf.placeholder(dtype)
        self.vars = utils.init_var_map(init_vars, init_path)  # initialize variables w, b

        xw = tf.sparse_tensor_dense_matmul(self.X, self.vars['w'])
        logits = tf.reshape(xw + self.vars['b'], [-1])
        self.y_prob = tf.sigmoid(logits)

        log_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=self.y, logits=logits))
        # NOTE(review): the L2 penalty is applied to x*w (activations), not to
        # w itself. Confirm that this is intentional.
        self.loss = log_loss + l2_weight * tf.nn.l2_loss(xw)
        self.optimizer = utils.get_optimizer(opt_algo, learning_rate, self.loss)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        tf.global_variables_initializer().run(session=self.sess)
开发者ID:zgcgreat,项目名称:WSDM,代码行数:28,代码来源:models.py
示例16: testProbabilitiesCanBeChanged
def testProbabilitiesCanBeChanged(self):
    """Changing the fed target distribution changes which class gets sampled.

    Feeding probabilities concentrated on class 0 must yield only label-0
    examples. Switching the feed to class 3 must yield only label-3 examples.
    """
    # Set up graph.
    tf.set_random_seed(1234)
    lbl1 = 0
    lbl2 = 3
    # This cond allows the necessary class queues to be populated.
    label = tf.cond(
        tf.greater(0.5, tf.random_uniform([])),
        lambda: tf.constant(lbl1),
        lambda: tf.constant(lbl2))
    val = [np.array([1, 4]) * label]
    probs = tf.placeholder(tf.float32, shape=[5])
    batch_size = 2

    data_batch, labels = tf.contrib.training.stratified_sample_unknown_dist(
        val, label, probs, batch_size)

    with self.test_session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        # Stop and join the queue-runner threads even if an assertion fails.
        # Otherwise they keep running after the test (and its session) ends.
        try:
            for _ in range(5):
                [data], lbls = sess.run(
                    [data_batch, labels], feed_dict={probs: [1, 0, 0, 0, 0]})
                for data_example in data:
                    self.assertListEqual([0, 0], list(data_example))
                self.assertListEqual([0, 0], list(lbls))

            # Now change distribution and expect different output.
            for _ in range(5):
                [data], lbls = sess.run(
                    [data_batch, labels], feed_dict={probs: [0, 0, 0, 1, 0]})
                for data_example in data:
                    self.assertListEqual([3, 12], list(data_example))
                self.assertListEqual([3, 3], list(lbls))
        finally:
            coord.request_stop()
            coord.join(threads)
开发者ID:rhuangq,项目名称:tensorflow,代码行数:32,代码来源:sampling_ops_test.py
示例17: testGradientWithZeroWeight
def testGradientWithZeroWeight(self):
    """Gradients of mean_pairwise_squared_error with weight 0 must not be NaN."""
    with tf.Graph().as_default():
        tf.set_random_seed(0)

        inputs = tf.ones((2, 3))
        weights = tf.get_variable('weights',
                                  shape=[3, 4],
                                  initializer=tf.truncated_normal_initializer())
        predictions = tf.matmul(inputs, weights)

        optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9)
        # The third positional argument (0) is presumably the loss weight,
        # which is the case under test. Confirm against the loss signature.
        loss = tf.contrib.losses.mean_pairwise_squared_error(
            predictions,
            predictions,
            0)

        gradients_to_variables = optimizer.compute_gradients(loss)

        # tf.initialize_all_variables is deprecated; global_variables_initializer
        # is its drop-in replacement.
        init_op = tf.global_variables_initializer()

        with self.test_session() as sess:
            sess.run(init_op)
            for grad, _ in gradients_to_variables:
                np_grad = sess.run(grad)
                self.assertFalse(np.isnan(np_grad).any())
开发者ID:apollos,项目名称:tensorflow,代码行数:25,代码来源:loss_ops_test.py
示例18: _train_model
def _train_model(self, checkpoint_dir, num_steps):
    """Train a simple classification model and write its checkpoints.

    The data has been configured so that after around 300 steps the model
    has memorized the dataset (i.e. we can expect 100% accuracy).

    Args:
        checkpoint_dir: The directory where the checkpoint is written to.
        num_steps: The number of steps to train for; enforced by StopAtStepHook.
    """
    with tf.Graph().as_default():
        tf.set_random_seed(0)
        tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
        tf_labels = tf.constant(self._labels, dtype=tf.float32)

        tf_predictions = logistic_classifier(tf_inputs)
        loss = tf.contrib.losses.log_loss(tf_predictions, tf_labels)

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        train_op = tf.contrib.training.create_train_op(loss, optimizer)

        # Only the checkpoints written to checkpoint_dir matter here. The old
        # code rebound ``loss`` (the loss tensor) to train()'s return value,
        # which was never used.
        tf.contrib.training.train(
            train_op, checkpoint_dir,
            hooks=[tf.train.StopAtStepHook(num_steps)])
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:25,代码来源:evaluation_test.py
示例19: construct_graph
def construct_graph(self, training, seed):
    """Build a fresh tf.Graph for this model and wrap it in a TensorflowGraph.

    Args:
        training: when true, add the training cost. Otherwise leave the loss
            as None and add the output ops (softmax heads, per the original
            comment).
        seed: optional graph-level random seed; skipped when None.

    Returns:
        A TensorflowGraph bundling the graph, a None session (created lazily
        later), the name-scope cache, output, labels, weights and loss.
    """
    graph = tf.Graph()
    # Cache of TensorFlow scopes, to prevent '_1' appended scope names
    # when subclass-overridden methods use the same scopes.
    name_scopes = {}

    with graph.as_default():
        if seed is not None:
            tf.set_random_seed(seed)
        output, labels, weights = self.build(graph, name_scopes, training)
        if training:
            loss = self.add_training_cost(graph, name_scopes, output, labels, weights)
        else:
            loss = None
            output = self.add_output_ops(graph, output)  # add softmax heads

    return TensorflowGraph(
        graph=graph,
        session=None,  # Lazily created by _get_shared_session().
        name_scopes=name_scopes,
        output=output,
        labels=labels,
        weights=weights,
        loss=loss)
开发者ID:joegomes,项目名称:deepchem,代码行数:30,代码来源:__init__.py
示例20: main
def main(_):
    """Seed every RNG, validate flags, ensure the checkpoint dir exists,
    then build the training model and train it."""
    # Fixed seed for repeatability (TensorFlow, NumPy, then Python's random).
    seed = 8964
    for seed_fn in (tf.set_random_seed, np.random.seed, random.seed):
        seed_fn(seed)

    if FLAGS.legacy_mode and FLAGS.seq_length < 3:
        raise ValueError('Legacy mode supports sequence length > 2 only.')

    if not gfile.Exists(FLAGS.checkpoint_dir):
        gfile.MakeDirs(FLAGS.checkpoint_dir)

    train_model = model.Model(
        data_dir=FLAGS.data_dir,
        is_training=True,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        reconstr_weight=FLAGS.reconstr_weight,
        smooth_weight=FLAGS.smooth_weight,
        ssim_weight=FLAGS.ssim_weight,
        icp_weight=FLAGS.icp_weight,
        batch_size=FLAGS.batch_size,
        img_height=FLAGS.img_height,
        img_width=FLAGS.img_width,
        seq_length=FLAGS.seq_length,
        legacy_mode=FLAGS.legacy_mode,
    )
    train(train_model, FLAGS.pretrained_ckpt, FLAGS.checkpoint_dir,
          FLAGS.train_steps, FLAGS.summary_freq)
开发者ID:ALISCIFP,项目名称:models,代码行数:29,代码来源:train.py
注:本文中的tensorflow.set_random_seed函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论