本文整理汇总了Python中tensorflow.get_default_session函数的典型用法代码示例。如果您正苦于以下问题：Python get_default_session函数的具体用法？Python get_default_session怎么用？Python get_default_session使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
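在下文中一共展示了get_default_session函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。注意：tf.get_default_session 仅存在于 TensorFlow 1.x；在 TensorFlow 2.x 中请改用 tf.compat.v1.get_default_session。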
示例1: get_session
def get_session():
    """Return the TensorFlow session used by the backend.

    Resolution order:
      1. If a default TensorFlow session is active, return it.
      2. Otherwise return the module-level Keras session ``_SESSION``,
         creating it on first use.

    A newly created session always allows soft device placement. When the
    ``OMP_NUM_THREADS`` environment variable is set to a non-empty value, its
    integer value is also used as ``intra_op_parallelism_threads``.

    The global session can be replaced manually via ``K.set_session(sess)``.
    """
    global _SESSION
    default_session = tf.get_default_session()
    if default_session is not None:
        return default_session
    if _SESSION is None:
        omp_threads = os.environ.get("OMP_NUM_THREADS")
        config_kwargs = {"allow_soft_placement": True}
        if omp_threads:
            config_kwargs["intra_op_parallelism_threads"] = int(omp_threads)
        _SESSION = tf.Session(config=tf.ConfigProto(**config_kwargs))
    return _SESSION
开发者ID:faroit,项目名称:keras,代码行数:25,代码来源:tensorflow_backend.py
示例2: train
def train(self, obs, actions, gaes, rewards, v_preds_next):
    """Run ``self.train_op`` once in the default TensorFlow session.

    ``obs`` is fed to both the current and the old policy's observation
    placeholders; every other argument is fed to the placeholder of the
    same name on ``self``.
    """
    session = tf.get_default_session()
    feed = {
        self.Policy.obs: obs,
        self.Old_Policy.obs: obs,
        self.actions: actions,
        self.rewards: rewards,
        self.v_preds_next: v_preds_next,
        self.gaes: gaes,
    }
    session.run(self.train_op, feed_dict=feed)
开发者ID:6-Billionaires,项目名称:gail_ppo_optimizer,代码行数:7,代码来源:ppo.py
示例3: train_step
def train_step(self, cases, weights, caching):
    """Apply one gradient update computed from ``cases``.

    Args:
        cases: sequence of training cases.
        weights: per-case weights; must have the same length as ``cases``.
        caching: forwarded to ``self.compute``. Must be falsy when the batch
            has to be split into slices.

    Behaviour:
        * No cases: only ``self._increment_step`` is run.
        * No max batch size, or ``len(cases) <= self._max_batch_size``:
          ``self._take_step`` is run through ``self.compute`` in one call.
        * Otherwise: gradients are computed slice by slice, summed, applied
          once through ``self._apply_gradients``, and the step is
          incremented.

    Raises:
        ValueError: if ``cases`` and ``weights`` differ in length.
    """
    if len(cases) != len(weights):
        raise ValueError('cases and weights must have the same length.')
    num_cases = len(cases)

    # sys.stderr.write works on both Python 2 and 3; the previous
    # `print >> sys.stderr, ...` form raises TypeError on Python 3.
    # "\033[F" is the ANSI "cursor to previous line" escape.
    if num_cases == 0:
        sys.stderr.write(" WARNING: Zero cases \033[F\n")
        # still increment the step
        tf.get_default_session().run(self._increment_step)
        return

    sys.stderr.write(" Updating ({} cases) \033[F\n".format(num_cases))
    if not self._max_batch_size or num_cases <= self._max_batch_size:
        self.compute(self._take_step, cases, weights, caching)
        return

    # Slicing the batch is incompatible with caching.
    assert not caching
    batch = self._max_batch_size
    grads = None
    slices = range(0, num_cases, batch)
    for start in verboserate(slices, desc='Computing gradients ({} cases)'.format(num_cases)):
        cases_slice = cases[start:start + batch]
        weights_slice = weights[start:start + batch]
        grads_slice = self.compute(self._grad_tensors, cases_slice, weights_slice, False)
        if grads is None:
            grads = grads_slice
        else:
            # Accumulate per-tensor gradients with a separate index so the
            # outer slice offset is not shadowed.
            for k in range(len(self._grad_tensors)):
                grads[k] += grads_slice[k]
    sess = tf.get_default_session()
    feed_dict = dict(zip(self._combined_grad_placeholders, grads))
    sess.run(self._apply_gradients, feed_dict)
    sess.run(self._increment_step)
开发者ID:siddk,项目名称:lang2program,代码行数:32,代码来源:parse_model.py
示例4: main
def main(args):
    """Freeze a trained model checkpoint into a single serialized GraphDef.

    Imports the metagraph found in ``args.model_dir``, initialises global and
    local variables, restores the checkpoint, converts the graph with
    ``freeze_graph_def`` (output node ``embeddings``) and writes the result
    to ``args.output_file``.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        # Inside `with tf.Session() as sess`, `sess` is the default session.
        print('Model directory: %s' % args.model_dir)
        model_dir_exp = os.path.expanduser(args.model_dir)
        meta_file, ckpt_file = facenet.get_model_filenames(model_dir_exp)
        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)
        meta_path = os.path.join(model_dir_exp, meta_file)
        saver = tf.train.import_meta_graph(meta_path, clear_devices=True)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        saver.restore(sess, os.path.join(model_dir_exp, ckpt_file))

        input_graph_def = sess.graph.as_graph_def()
        output_graph_def = freeze_graph_def(sess, input_graph_def, 'embeddings')

        with tf.gfile.GFile(args.output_file, 'wb') as out_file:
            out_file.write(output_graph_def.SerializeToString())
        node_count = len(output_graph_def.node)
        print("%d ops in the final graph: %s" % (node_count, args.output_file))
开发者ID:citysir,项目名称:facenet,代码行数:26,代码来源:freeze_graph.py
示例5: get_session
def get_session():
    """Get the globally defined TensorFlow session.

    If a default TensorFlow session is active it is stored as the global
    session; otherwise a new ``tf.InteractiveSession`` is created and stored.
    If Keras can be imported, the session is also registered as the Keras
    backend session.

    Returns:
      _ED_SESSION: the module-level session (either the pre-existing default
      session or the newly created ``tf.InteractiveSession``).
    """
    global _ED_SESSION
    default_session = tf.get_default_session()
    if default_session is None:
        _ED_SESSION = tf.InteractiveSession()
    else:
        _ED_SESSION = default_session

    # Suppress whatever Keras writes to stderr while being imported. The
    # devnull handle is always closed, and stderr is restored on every exit
    # path, including exceptions other than ImportError (which propagate).
    import os
    save_stderr = sys.stderr
    have_keras = False
    with open(os.devnull, 'w') as devnull:
        sys.stderr = devnull
        try:
            from keras import backend as K
            have_keras = True
        except ImportError:
            pass
        finally:
            sys.stderr = save_stderr
    if have_keras:
        K.set_session(_ED_SESSION)
    return _ED_SESSION
开发者ID:JoyceYa,项目名称:edward,代码行数:29,代码来源:graphs.py
示例6: fit
def fit(self, xs, ys):
    """Fit the regressor to inputs ``xs`` and targets ``ys``.

    If ``self.normalize_inputs`` is set, the input-normalisation variables are
    first reset to the mean and standard deviation (plus 1e-8) of ``xs`` over
    axis 0. Then the trust-region optimizer is used if
    ``self.use_trust_region`` is set and a previous fit has happened;
    otherwise the plain optimizer is used. The loss before and after
    optimisation, and their difference, are recorded with ``logger``.
    """
    if self.normalize_inputs:
        # recompute normalizing constants for inputs
        new_mean = np.mean(xs, axis=0, keepdims=True)
        new_std = np.std(xs, axis=0, keepdims=True) + 1e-8
        sess = tf.get_default_session()
        # Variable.load writes the value without adding ops to the graph.
        # Building tf.assign/tf.group here on every call made the graph
        # grow with each call to fit().
        self.x_mean_var.load(new_mean, sess)
        self.x_std_var.load(new_std, sess)

    if self.use_trust_region and self.first_optimized:
        old_prob = self.f_prob(xs)
        inputs = [xs, ys, old_prob]
        optimizer = self.tr_optimizer
    else:
        inputs = [xs, ys]
        optimizer = self.optimizer

    prefix = self.name + "_" if self.name else ""
    loss_before = optimizer.loss(inputs)
    logger.record_tabular(prefix + 'LossBefore', loss_before)
    optimizer.optimize(inputs)
    loss_after = optimizer.loss(inputs)
    logger.record_tabular(prefix + 'LossAfter', loss_after)
    logger.record_tabular(prefix + 'dLoss', loss_before - loss_after)
    self.first_optimized = True
开发者ID:flyers,项目名称:rllab,代码行数:27,代码来源:categorical_mlp_regressor.py
示例7: test_lookup_activations
def test_lookup_activations(self):
    """Every looked-up activation must change a constant input of -1.0."""
    x = tf.constant(-1.0, shape=[2, 2])
    with self.test_session():
        for name in ['relu', 'prelu', 'selu', 'crelu']:
            activated = ops.lookup(name)(x)
            tf.get_default_session().run(tf.global_variables_initializer())
            self.assertNotEqual(x.eval()[0][0], activated.eval()[0][0])
开发者ID:255BITS,项目名称:hyperchamber-gan,代码行数:10,代码来源:ops_test.py
示例8: restore_trainer
def restore_trainer(self, filename):
    '''
    Load the training progress (including the model)

    Restores into the default TensorFlow session.

    Args:
        filename: path to restore from. The model is restored from
            ``filename`` with ``self.modelsaver``; the training variables
            are restored from ``filename + '_trainvars'`` with
            ``self.saver``.
    '''
    sess = tf.get_default_session()
    self.modelsaver.restore(sess, filename)
    self.saver.restore(sess, filename + '_trainvars')
开发者ID:vrenkens,项目名称:tfkaldi,代码行数:10,代码来源:trainer.py
示例9: test_logging_trainable
def test_logging_trainable(self):
    """LoggingTrainable('foo') must log the trainable variable named ``foo``."""
    with tf.Graph().as_default() as g, self.test_session(g):
        var = tf.Variable(tf.constant(42.0), name='foo')
        var.initializer.run()
        cof = tf.constant(1.0)
        # tf.sub / tf.mul are not available in TensorFlow >= 1.0; use the
        # renamed tf.subtract / tf.multiply.
        loss = tf.subtract(tf.multiply(var, cof), tf.constant(1.0))
        train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
        tf.get_default_session().run(train_step)
        self._run_monitor(learn.monitors.LoggingTrainable('foo'))
        self.assertRegexpMatches(str(self.logged_message), var.name)
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:10,代码来源:monitors_test.py
示例10: fit
def fit(self, paths, policy=None, batch_size=32, max_itrs=100, logger=None, lr=1e-3, **kwargs):
    """Train the discriminator on policy samples versus expert demonstrations.

    Action log-probabilities under ``policy`` are first written into both
    ``paths`` and ``self.expert_trajs``. Each iteration samples
    ``batch_size`` policy transitions (label 0) and ``batch_size`` expert
    transitions (label 1) and runs one optimisation step.

    Args:
        paths: policy rollouts.
        policy: policy used to evaluate action log-probabilities.
        batch_size: samples drawn from each of the policy and expert data
            per iteration (the combined batch has ``2 * batch_size`` rows).
        max_itrs: number of training iterations.
        logger: optional tabular logger; when given, discriminator
            statistics over the full policy and expert data are recorded.
        lr: learning rate fed to ``self.lr``.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        The mean loss reported at the last heartbeat. If no heartbeat
        occurred, the mean loss over all iterations run, or None if no
        iteration ran.
    """
    #self._compute_path_probs(paths, insert=True)
    self.eval_expert_probs(paths, policy, insert=True)
    self.eval_expert_probs(self.expert_trajs, policy, insert=True)
    obs, acts, path_probs = self.extract_paths(paths, keys=('observations', 'actions', 'a_logprobs'))
    expert_obs, expert_acts, expert_probs = self.extract_paths(
        self.expert_trajs, keys=('observations', 'actions', 'a_logprobs'))

    sess = tf.get_default_session()
    mean_loss = None
    unreported_losses = []  # losses recorded since the last heartbeat

    # Train discriminator
    for it in TrainingIterator(max_itrs, heartbeat=5):
        obs_batch, act_batch, lprobs_batch = \
            self.sample_batch(obs, acts, path_probs, batch_size=batch_size)
        expert_obs_batch, expert_act_batch, expert_lprobs_batch = \
            self.sample_batch(expert_obs, expert_acts, expert_probs, batch_size=batch_size)
        # Policy samples come first (label 0), expert samples second (label 1).
        labels = np.zeros((batch_size * 2, 1))
        labels[batch_size:] = 1.0
        obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
        act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
        lprobs_batch = np.expand_dims(
            np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32)
        loss, _ = sess.run([self.loss, self.step], feed_dict={
            self.act_t: act_batch,
            self.obs_t: obs_batch,
            self.labels: labels,
            self.lprobs: lprobs_batch,
            self.lr: lr
        })
        it.record('loss', loss)
        unreported_losses.append(loss)
        if it.heartbeat:
            print(it.itr_message())
            mean_loss = it.pop_mean('loss')
            print('\tLoss:%f' % mean_loss)
            unreported_losses = []

    if mean_loss is None and unreported_losses:
        # No heartbeat fired (e.g. a short run); without this fallback
        # `mean_loss` would be unbound at the return below.
        mean_loss = float(np.mean(unreported_losses))

    if logger:
        def _discriminator_outputs(acts_in, obs_in, lprobs_in):
            # Energy, value function (logZ) and D(tau) over a whole dataset.
            return sess.run([self.energy, self.value_fn, self.d_tau],
                            feed_dict={self.act_t: acts_in, self.obs_t: obs_in,
                                       self.lprobs: np.expand_dims(lprobs_in, axis=1)})

        energy, logZ, dtau = _discriminator_outputs(acts, obs, path_probs)
        logger.record_tabular('IRLLogZ', np.mean(logZ))
        logger.record_tabular('IRLAverageEnergy', np.mean(energy))
        logger.record_tabular('IRLAverageLogPtau', np.mean(-energy - logZ))
        logger.record_tabular('IRLAverageLogQtau', np.mean(path_probs))
        logger.record_tabular('IRLMedianLogQtau', np.median(path_probs))
        logger.record_tabular('IRLAverageDtau', np.mean(dtau))

        energy, logZ, dtau = _discriminator_outputs(expert_acts, expert_obs, expert_probs)
        logger.record_tabular('IRLAverageExpertEnergy', np.mean(energy))
        logger.record_tabular('IRLAverageExpertLogPtau', np.mean(-energy - logZ))
        logger.record_tabular('IRLAverageExpertLogQtau', np.mean(expert_probs))
        logger.record_tabular('IRLMedianExpertLogQtau', np.median(expert_probs))
        logger.record_tabular('IRLAverageExpertDtau', np.mean(dtau))
    return mean_loss
开发者ID:saadmahboob,项目名称:inverse_rl,代码行数:54,代码来源:imitation_learning.py
示例11: get_session
def get_session():
    """Return the module-wide TensorFlow session, creating it on first call.

    The first call adopts the current default session if one exists;
    otherwise it creates a new ``tf.Session``. The result is cached in
    ``_session``, and later calls return the cached session unchanged.
    """
    global _session
    if _session is not None:
        return _session
    default_session = tf.get_default_session()
    _session = default_session if default_session is not None else tf.Session()
    return _session
开发者ID:CloudBreadPaPa,项目名称:tensorrec,代码行数:11,代码来源:session_management.py
示例12: test_preserves_existing_session
def test_preserves_existing_session(self):
    """``self._square`` must leave an already-active default session in place."""
    with tf.Session() as sess:
        four = tf.reduce_sum([2, 2])
        self.assertIs(sess, tf.get_default_session())
        self.assertEqual(123 * 123, self._square(123))
        self.assertIs(sess, tf.get_default_session())
        self.assertEqual(sess.run(four), 4)
开发者ID:jlewi,项目名称:tensorboard,代码行数:11,代码来源:util_test.py
示例13: zero_model_gradient_accumulators
def zero_model_gradient_accumulators(cls) -> None:
    """Run every ``<scope>/zero_model_gradient_accumulators`` op in one call.

    For each listed variable scope, the op is looked up by name in the
    default graph. All ops are then executed together in the default session.
    """
    scope_names = (
        'empty_statistic',
        'move_rate',
        'game_state_as_update',
        'updated_statistic',
        'updated_update',
        'cost_function',
    )
    graph = tf.get_default_graph()
    zero_operations = []
    for scope_name in scope_names:
        op_name = '{}/zero_model_gradient_accumulators'.format(scope_name)
        zero_operations.append(graph.get_operation_by_name(op_name))
    tf.get_default_session().run(zero_operations)
开发者ID:thomasste,项目名称:ugtsa,代码行数:14,代码来源:model_builder.py
示例14: predict_with_three_models_on_hashtags
def _predict_with_model(model_dir, use_emb_model, use_char_model, hashtag_dir, hashtag_names):
    """Build a fresh HumorPredictor from ``model_dir`` and run it on each hashtag.

    The Keras session is cleared and re-pointed at the current default
    TensorFlow session before the predictor is built (presumably so that the
    ensemble members do not share graph state; confirm).

    Returns:
        A list with one ``(np_output_prob, first_tweet_ids, second_tweet_ids)``
        tuple per hashtag, in the order of ``hashtag_names``.
    """
    K.clear_session()
    K.set_session(tf.get_default_session())
    predictor = humor_predictor.HumorPredictor(
        model_dir, use_emb_model=use_emb_model, use_char_model=use_char_model)
    results = []
    for hashtag_name in hashtag_names:
        _, np_output_prob, _, first_tweet_ids, second_tweet_ids = predictor(hashtag_dir, hashtag_name)
        results.append((np_output_prob, first_tweet_ids, second_tweet_ids))
    return results


def predict_with_three_models_on_hashtags(hashtag_dir, hashtag_emb_dir, trial_hashtag_names, labels_exist=True):
    """Run the emb+char, emb-only and char-only humor models on each hashtag.

    Returns:
        all_predictions: one array per hashtag. Its three columns hold the
            output probabilities of the emb+char, emb-only and char-only
            models, in that order.
        hashtag_labels: per-hashtag labels loaded from ``hashtag_emb_dir``,
            or None when ``labels_exist`` is False.
        per_hashtag_first_tweet_ids, per_hashtag_second_tweet_ids: tweet ids
            reported by the emb+char model for each hashtag.
    """
    emb_char_results = _predict_with_model(
        EMB_CHAR_HUMOR_MODEL_DIR, True, True, hashtag_dir, trial_hashtag_names)
    emb_results = _predict_with_model(
        EMB_HUMOR_MODEL_DIR, True, False, hashtag_dir, trial_hashtag_names)
    char_results = _predict_with_model(
        CHAR_HUMOR_MODEL_DIR, False, True, hashtag_dir, trial_hashtag_names)

    # Tweet ids are taken from the emb+char model's results.
    per_hashtag_first_tweet_ids = [first_ids for _, first_ids, _ in emb_char_results]
    per_hashtag_second_tweet_ids = [second_ids for _, _, second_ids in emb_char_results]

    all_predictions = []
    for emb_char, emb, char in zip(emb_char_results, emb_results, char_results):
        columns = [np.reshape(result[0], [-1, 1]) for result in (emb_char, emb, char)]
        all_predictions.append(np.concatenate(columns, axis=1))

    hashtag_labels = None
    if labels_exist:
        hashtag_labels = []
        for hashtag_name in trial_hashtag_names:
            # Parenthesised print works both as a Python 2 statement and a
            # Python 3 function call (the bare statement form is a Py3
            # SyntaxError).
            print('Loading label for hashtag %s' % hashtag_name)
            _, _, np_labels, _, _, _ = load_hashtag_data(hashtag_emb_dir, hashtag_name)
            hashtag_labels.append(np_labels)
    return all_predictions, hashtag_labels, per_hashtag_first_tweet_ids, per_hashtag_second_tweet_ids
开发者ID:text-machine-lab,项目名称:ht_wars,代码行数:50,代码来源:humor_ensemble_processing2.py
示例15: main
def _fix_batch_norm_nodes(graph_def):
    """Rewrite stateful batch-norm update nodes of ``graph_def`` in place.

    The rewrites are:
      * ``RefSwitch`` becomes ``Switch``, and inputs whose name contains
        ``moving_`` are redirected to their ``/read`` tensor.
      * ``AssignSub`` becomes ``Sub`` and ``AssignAdd`` becomes ``Add``; in
        both cases the ``use_locking`` attribute is dropped.
    """
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op in ('AssignSub', 'AssignAdd'):
            node.op = 'Sub' if node.op == 'AssignSub' else 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


def _whitelisted_node_names(graph_def):
    """Print and return names of nodes under the model/embedding/phase_train prefixes."""
    prefixes = ('InceptionResnetV1', 'embeddings', 'phase_train')
    names = []
    for node in graph_def.node:
        if node.name.startswith(prefixes):
            print(node.name)
            names.append(node.name)
    return names


def main(args):
    """Freeze a FaceNet checkpoint into a standalone serialized GraphDef.

    Imports the metagraph found in ``args.model_dir``, initialises variables,
    restores the checkpoint, and patches batch-norm update nodes. It then
    converts the variables of the whitelisted nodes to constants (output node
    ``embeddings``) and writes the result to ``args.output_file``.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Load the model metagraph and checkpoint
            print('Model directory: %s' % args.model_dir)
            meta_file, ckpt_file = facenet.get_model_filenames(os.path.expanduser(args.model_dir))
            print('Metagraph file: %s' % meta_file)
            print('Checkpoint file: %s' % ckpt_file)
            model_dir_exp = os.path.expanduser(args.model_dir)
            saver = tf.train.import_meta_graph(os.path.join(model_dir_exp, meta_file), clear_devices=True)
            # `sess` is the default session inside `with tf.Session() as sess`.
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, os.path.join(model_dir_exp, ckpt_file))

            # Retrieve the protobuf graph definition and fix the batch norm
            # nodes. `range` replaces `xrange`, which is undefined on Python 3.
            gd = sess.graph.as_graph_def()
            _fix_batch_norm_nodes(gd)

            # Replace all the variables in the graph with constants of the same values
            output_node_names = 'embeddings'
            whitelist_names = _whitelisted_node_names(gd)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess, gd, output_node_names.split(","),
                variable_names_whitelist=whitelist_names)

            # Serialize and dump the output graph to the filesystem
            with tf.gfile.GFile(args.output_file, 'wb') as f:
                f.write(output_graph_def.SerializeToString())
            print("%d ops in the final graph." % len(output_graph_def.node))
开发者ID:billtiger,项目名称:CATANA,代码行数:48,代码来源:freeze_graph.py
示例16: get_grad
def get_grad(self, obs, actions, gaes, rewards, v_preds_next):
    """Evaluate ``self.gradients`` for one batch in the default session.

    ``obs`` is fed to both the current and the old policy's observation
    placeholders; every other argument is fed to the placeholder of the
    same name on ``self``.

    Returns:
        Whatever ``Session.run`` returns for ``self.gradients``.
    """
    session = tf.get_default_session()
    feed = {
        self.Policy.obs: obs,
        self.Old_Policy.obs: obs,
        self.actions: actions,
        self.rewards: rewards,
        self.v_preds_next: v_preds_next,
        self.gaes: gaes,
    }
    return session.run(self.gradients, feed_dict=feed)
开发者ID:6-Billionaires,项目名称:gail_ppo_optimizer,代码行数:7,代码来源:ppo.py
示例17: _trigger_epoch
def _trigger_epoch(self):
    """Save a checkpoint and point a ``latest`` symlink at it.

    The first call also exports the meta graph, with all collection keys,
    to ``logger.LOG_DIR``. Every call saves the default session to
    ``self.path`` tagged with ``self.global_step``, without writing the meta
    graph, and then re-creates a ``latest`` symlink next to the newest
    checkpoint. ``OSError``/``IOError`` raised along the way are logged via
    ``logger.exception`` and otherwise ignored.
    """
    try:
        if not self.meta_graph_written:
            meta_path = os.path.join(logger.LOG_DIR,
                                     'graph-{}.meta'.format(logger.get_time_str()))
            collections = self.graph.get_all_collection_keys()
            self.saver.export_meta_graph(meta_path, collection_list=collections)
            self.meta_graph_written = True

        session = tf.get_default_session()
        self.saver.save(session, self.path,
                        global_step=self.global_step,
                        write_meta_graph=False)

        # create a symbolic link for the latest model
        checkpoint_dir, checkpoint_name = os.path.split(self.saver.last_checkpoints[-1])
        link_path = os.path.join(checkpoint_dir, 'latest')
        try:
            os.unlink(link_path)
        except OSError:
            pass  # there may be no previous link to remove
        os.symlink(checkpoint_name, link_path)
    except (OSError, IOError):  # disk error sometimes.. just ignore it
        logger.exception("Exception in ModelSaver.trigger_epoch!")
开发者ID:Paseam,项目名称:tensorpack,代码行数:25,代码来源:common.py
示例18: get_global_step_value
def get_global_step_value():
    """
    Returns:
        int: value of the global step variable, read through
        ``tf.train.global_step`` in the current default session.
    """
    session = tf.get_default_session()
    step_var = get_global_step_var()
    return tf.train.global_step(session, step_var)
开发者ID:tobyma,项目名称:tensorpack,代码行数:7,代码来源:common.py
示例19: get_param_shapes
def get_param_shapes(self, **tags):
    """Return the shapes of the parameters selected by ``tags``.

    Shapes are cached in ``self._cached_param_shapes``, keyed by the tags
    sorted by name, so keyword order does not matter. On a cache miss the
    parameters are evaluated in the default session and the shape of each
    value is stored.
    """
    cache_key = tuple(sorted(tags.items(), key=lambda item: item[0]))
    cached = self._cached_param_shapes
    if cache_key not in cached:
        params = self.get_params(**tags)
        values = tf.get_default_session().run(params)
        cached[cache_key] = [value.shape for value in values]
    return cached[cache_key]
开发者ID:QuantCollective,项目名称:maml_rl,代码行数:7,代码来源:parameterized.py
示例20: test_outputs
def test_outputs(self, model, inputs, output_tensors, outputs):
    """Check that ``model.compute`` reproduces the expected ``outputs``.

    ``guarantee_initialized_variables`` is applied to the default session
    first. ``inputs`` is an ``(args, kwargs)`` pair forwarded to
    ``model.compute``. Results are compared to 4 decimal places.
    """
    guarantee_initialized_variables(tf.get_default_session())
    positional, keyword = inputs
    actual = model.compute(output_tensors, *positional, **keyword)
    assert_array_collections_equal(outputs, actual, decimal=4)
开发者ID:siddk,项目名称:lang2program,代码行数:7,代码来源:test_framework.py
注：本文中的tensorflow.get_default_session函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论