本文整理汇总了Python中tensorflow.python.framework.ops.control_dependencies函数的典型用法代码示例。如果您正苦于以下问题：Python control_dependencies函数的具体用法？Python control_dependencies怎么用？Python control_dependencies使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了control_dependencies函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _Update_global_variables
def _Update_global_variables():
    """Average `grad_vars`, apply them to the global vars, then pull back.

    Three stages, each gated on the previous one via control dependencies:
      1. Divide every variable in `grad_vars` by `self._period` in place.
      2. Apply the averaged values to `global_center_vars` through
         `self._opt.apply_gradients`.
      3. Optionally increment `global_step`, copy each global value from
         `self._global_map` into its local variable, and reset every variable
         in `grad_vars` to zeros.

    Relies on `grad_vars`, `global_center_vars`, `local_vars`, `global_step`
    and `self` from the enclosing scope.

    Returns:
      A grouped op running all stage-3 updates.
    """
    # Stage 1: a = a / t, one assign per accumulator (not a norm).
    averaging_ops = [
        state_ops.assign(acc, acc / self._period) for acc in grad_vars
    ]

    # Stage 2: apply only after every accumulator has been averaged.
    with ops.control_dependencies(averaging_ops):
        apply_op = self._opt.apply_gradients(
            zip(grad_vars, global_center_vars))

    # Stage 3: pull the fresh global values and clear the accumulators.
    with ops.control_dependencies([apply_op]):
        pull_ops = []
        if global_step:
            with ops.colocate_with(global_step):
                pull_ops.append(state_ops.assign_add(global_step, 1))
        for local_var in local_vars:
            global_value = self._global_map[local_var].read_value()
            pull_ops.append(state_ops.assign(local_var, global_value))
        pull_ops.extend(
            state_ops.assign(acc, array_ops.zeros_like(acc))
            for acc in grad_vars)
        grouped = control_flow_ops.group(*pull_ops)
    return grouped
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:25,代码来源:agn_optimizer.py
示例2: testCaching
def testCaching(self):
    """Confirm caching of control output is recalculated between calls."""
    a = constant_op.constant(1)
    b = constant_op.constant(2)
    # `c` takes a control dependency on `a` before `a` is subscribed.
    with ops.control_dependencies([a]):
        c = constant_op.constant(42)

    call_counts = {}

    def count_and_forward(value):
        # Tally each value that flows through the subscriber, then pass it on.
        call_counts[value] = call_counts.get(value, 0) + 1
        return value

    def add_subscriber(tensor):
        return script_ops.py_func(count_and_forward, [tensor], [tensor.dtype])

    a = subscribe.subscribe(a, add_subscriber)

    with ops.control_dependencies([b]):
        d = constant_op.constant(11)

    # If it was using outdated cached control_outputs then
    # evaling would not trigger the new subscription.
    b = subscribe.subscribe(b, add_subscriber)

    with self.cached_session():
        c_out = self.evaluate([c])
        d_out = self.evaluate([d])

    self.assertEqual(c_out, [42])
    self.assertEqual(d_out, [11])
    # Each subscriber must have fired exactly once.
    self.assertEqual(call_counts, {2: 1, 1: 1})
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:31,代码来源:subscribe_test.py
示例3: testIgnoredArguments
def testIgnoredArguments(self):
    """Tests that JIT computations can ignore formal parameters.

    `w` depends only on `y` and is pulled into the jit scope purely through a
    control dependency, so `t` = (x + x) + (x + x) = 28 for x = 7, regardless
    of the value fed for `y`.
    """
    with self.session(config=NoRewriteSessionConfig()) as sess:
        x = array_ops.placeholder(dtypes.int32)
        y = array_ops.placeholder(dtypes.int32)
        with jit_scope():
            z = math_ops.add(x, x)
            w = math_ops.add(y, y)
            # Pulls 'w' into the same compilation via control dependencies.
            with ops.control_dependencies([w]):
                n = control_flow_ops.no_op()
            with ops.control_dependencies([n]):
                t = math_ops.add(z, z)

        run_metadata = config_pb2.RunMetadata()
        out = test_utils.RunWithWarmup(
            sess,
            t, {
                x: np.int32(7),
                y: np.int32(404)
            },
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))

        # assertTrue instead of the deprecated alias `assert_`, which was
        # removed from unittest in Python 3.12.
        self.assertTrue(MetadataHasXlaRunOp(run_metadata))
        self.assertAllClose(28, out)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:27,代码来源:jit_test.py
示例4: _get_train_ops
def _get_train_ops(self, features, targets):
    """See base class.

    Builds the loss (after the optional centered-bias step), takes gradients
    for the DNN and linear variables in a single `gradients` call (optionally
    clipped by global norm), applies each part with its own training ops, and
    finally increments the global step.

    Returns:
      `(train_op, loss)`, where `train_op` is the global-step increment op. It
      runs only after the combined training op.
    """
    global_step = contrib_variables.get_global_step()
    assert global_step

    logits = self._logits(features, is_training=True)
    # When enabled, the centered-bias update must run before the loss.
    centered_bias_step = (
        [self._centered_bias_step(targets, features)]
        if self._enable_centered_bias else [])
    with ops.control_dependencies(centered_bias_step):
        loss = self._loss(logits, targets, features)
    logging_ops.scalar_summary("loss", loss)

    linear_vars = self._get_linear_vars()
    dnn_vars = self._get_dnn_vars()
    # Gradients are ordered DNN vars first, then linear vars; split below.
    grads = gradients.gradients(loss, dnn_vars + linear_vars)
    if self._gradient_clip_norm:
        grads, _ = clip_ops.clip_by_global_norm(grads, self._gradient_clip_norm)

    split_at = len(dnn_vars)
    dnn_grads, linear_grads = grads[:split_at], grads[split_at:]

    linear_ops = self._get_linear_training_ops(linear_grads, linear_vars)
    dnn_ops = self._get_dnn_training_ops(dnn_grads, dnn_vars)
    train_step = control_flow_ops.group(
        *(linear_ops + dnn_ops), name="combined_training_op")

    with ops.control_dependencies([train_step]):
        with ops.get_default_graph().colocate_with(global_step):
            return state_ops.assign_add(global_step, 1).op, loss
开发者ID:285219011,项目名称:liuwenfeng,代码行数:30,代码来源:dnn_linear_combined.py
示例5: _fn
def _fn():
    """Input function returning `(batch_of_points, None)`.

    If `batch_size == num_points`, all of `points` is returned. Otherwise
    `batch_size` rows are gathered, using one of two index sources:
      * `randomize`: row indices are drawn uniformly at random with a fixed
        seed.
      * otherwise: indices are cycled sequentially through a FIFO queue of
        indices.

    Relies on `points`, `num_points`, `batch_size`, `randomize` and
    `num_epochs` from the enclosing scope.
    """
    x = constant_op.constant(points)
    if batch_size == num_points:
        return input_lib.limit_epochs(x, num_epochs=num_epochs), None
    if randomize:
        # Integer random_uniform samples from [minval, maxval). maxval must
        # therefore be num_points (not num_points - 1); otherwise the last
        # point could never be selected.
        indices = random_ops.random_uniform(
            constant_op.constant([batch_size]),
            minval=0,
            maxval=num_points,
            dtype=dtypes.int32,
            seed=10)
    else:
        # We need to cycle through the indices sequentially. We create a queue
        # to maintain the list of indices.
        q = data_flow_ops.FIFOQueue(num_points, dtypes.int32, ())

        # Conditionally initialize the Queue.
        def _init_q():
            with ops.control_dependencies(
                    [q.enqueue_many(math_ops.range(num_points))]):
                return control_flow_ops.no_op()

        init_q = control_flow_ops.cond(q.size() <= 0, _init_q,
                                       control_flow_ops.no_op)
        with ops.control_dependencies([init_q]):
            offsets = q.dequeue_many(batch_size)
            # Re-enqueue the dequeued offsets so the indices are recycled.
            with ops.control_dependencies([q.enqueue_many(offsets)]):
                indices = array_ops.identity(offsets)
    batch = array_ops.gather(x, indices)
    return (input_lib.limit_epochs(batch, num_epochs=num_epochs), None)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:30,代码来源:kmeans_test.py
示例6: testTensorArrayReadTwice
def testTensorArrayReadTwice(self):
    """A second read of index 0 fails unless clear_after_read=False."""
    with self.test_session(use_gpu=True):
        value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])

        def build_second_read(**extra_kwargs):
            # Unstack `value` into a fresh TensorArray and read index 0 twice.
            # The control dependency orders the second read after the first.
            ta = tensor_array_ops.TensorArray(
                dtype=dtypes.float32, tensor_array_name="foo", size=2,
                **extra_kwargs)
            written = ta.unstack(value)
            first_read = written.read(0)
            with ops.control_dependencies([first_read]):
                return written.read(0)

        # Default TensorArray: the element is cleared after its first read.
        reread_cleared = build_second_read()
        with self.assertRaisesOpError(
                r"Could not read index 0 twice because it was cleared after a "
                r"previous read \(perhaps try setting clear_after_read = false\?\)"):
            reread_cleared.eval()

        # With clear_after_read=False the element survives the first read.
        reread_kept = build_second_read(clear_after_read=False)
        self.assertAllEqual([1.0, -1.0], reread_kept.eval())
开发者ID:jzuern,项目名称:tensorflow,代码行数:28,代码来源:tensor_array_ops_test.py
示例7: testMultidimensionalAcculumator
def testMultidimensionalAcculumator(self):
    """Two adds with 2-D feature ids are merged per (partition, bucket, dim)."""
    with self.test_session() as sess:
        accumulator = stats_accumulator_ops.StatsAccumulator(
            stamp_token=0,
            gradient_shape=tensor_shape.scalar(),
            hessian_shape=tensor_shape.scalar())
        # Both adds must run after the accumulator resource is created.
        with ops.control_dependencies([accumulator._create_op]):
            first_add = accumulator.add(
                stamp_token=0,
                partition_ids=[1, 2, 1],
                feature_ids=[[2, 2], [3, 0], [2, 2]],
                gradients=[0.1, 0.3, 0.8],
                hessians=[0.2, 0.4, -9])
            # Positional form; presumably the same argument order as the
            # keyword call above.
            second_add = accumulator.add(
                0, [2, 1], [[3, 1], [2, 2]], [0.1, 1], [0.2, -1])

        with ops.control_dependencies([first_add, second_add]):
            flushed = accumulator.flush(stamp_token=0, next_stamp_token=1)
            num_updates, partition, bucket_ids, grads, hessians = flushed
            num_updates, partition, bucket_ids, grads, hessians = sess.run(
                [num_updates, partition, bucket_ids, grads, hessians])

        by_key = _AccumulatorResultToDict(partition, bucket_ids, grads, hessians)
        self.assertEqual(num_updates, 2)
        self.assertEqual(len(by_key), 3)
        # Key is (partition, bucket, dimension).
        self.assertAllClose(by_key[(1, 2, 2)], [1.9, -9.8])
        self.assertAllClose(by_key[(2, 3, 0)], [0.3, 0.4])
        self.assertAllClose(by_key[(2, 3, 1)], [0.1, 0.2])
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:28,代码来源:stats_accumulator_ops_test.py
示例8: test_train_max_steps_is_not_incremental
def test_train_max_steps_is_not_incremental(self):
    """The saved global step equals the requested max_steps after each run.

    Both runs share `self._output_dir`. The second run therefore resumes from
    step 10 and must stop at 15.
    """
    for max_steps in (10, 15):
        with ops.Graph().as_default() as g, self.test_session(g):
            with ops.control_dependencies(self._build_inference_graph()):
                train_op = state_ops.assign_add(
                    variables_lib.get_global_step(), 1)
            learn.graph_actions.train(
                g,
                output_dir=self._output_dir,
                train_op=train_op,
                loss_op=constant_op.constant(2.0),
                max_steps=max_steps)
            saved_step = checkpoint_utils.load_variable(
                self._output_dir, variables_lib.get_global_step().name)
            self.assertEqual(max_steps, saved_step)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:26,代码来源:graph_actions_test.py
示例9: _resource_apply_sparse
def _resource_apply_sparse(self, grad, var, indices):
    """Apply a sparse Adamax update to the rows of `var` listed in `indices`.

    Only the gathered rows of the `m` and `v` slots and of `var` are
    rewritten, via `_resource_scatter_update` / `_resource_scatter_add`.

    Args:
      grad: gradient values for the rows selected by `indices`.
      var: the resource variable being optimized.
      indices: row indices of `var` that `grad` applies to.

    Returns:
      An op grouping the `var`, `m` and `v` updates.
    """
    dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(dtype)
    beta_1_t = self._get_hyper('beta_1', dtype)
    beta_2_t = self._get_hyper('beta_2', dtype)
    step = math_ops.cast(self.iterations + 1, dtype)
    beta_1_power = math_ops.pow(beta_1_t, step)
    epsilon_t = self._get_hyper('epsilon', dtype)

    # First moment on the touched rows: m_t = beta1 * m + (1 - beta1) * g_t
    m_slot = self.get_slot(var, 'm')
    new_m_rows = (array_ops.gather(m_slot, indices) * beta_1_t
                  + grad * (1 - beta_1_t))
    with ops.control_dependencies([new_m_rows]):
        m_write = self._resource_scatter_update(m_slot, indices, new_m_rows)

    # Running max on the touched rows: u_t = max(beta2 * u, abs(g_t))
    v_slot = self.get_slot(var, 'v')
    new_v_rows = math_ops.maximum(
        array_ops.gather(v_slot, indices) * beta_2_t, math_ops.abs(grad))
    with ops.control_dependencies([new_v_rows]):
        v_write = self._resource_scatter_update(v_slot, indices, new_v_rows)

    # theta_t = theta - lr / (1 - beta1^t) * m_t / u_t, applied as an add.
    delta_rows = -lr_t / (1 - beta_1_power) * (
        new_m_rows / (new_v_rows + epsilon_t))
    with ops.control_dependencies([delta_rows]):
        var_write = self._resource_scatter_add(var, indices, delta_rows)
    return control_flow_ops.group(*[var_write, m_write, v_write])
开发者ID:aeverall,项目名称:tensorflow,代码行数:29,代码来源:adamax.py
示例10: _resource_apply_sparse
def _resource_apply_sparse(self, grad, var, indices):
    """Apply a sparse Nadam update to the rows of `var` listed in `indices`.

    The `m` and `v` slots are first decayed in full, then the scaled gradient
    terms are scatter-added into the touched rows. `var` is updated with the
    Nesterov-style `m_bar` and a bias-corrected step size.

    Args:
      grad: gradient values for the rows selected by `indices`.
      var: the resource variable being optimized.
      indices: row indices of `var` that `grad` applies to.

    Returns:
      An op grouping the `var` update, `m_bar` and the updated `v`.
    """
    dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(dtype)
    beta_1_t = self._get_hyper('beta_1', dtype)
    beta_2_t = self._get_hyper('beta_2', dtype)
    step = math_ops.cast(self.iterations + 1, dtype)
    beta_1_power = math_ops.pow(beta_1_t, step)
    beta_2_power = math_ops.pow(beta_2_t, step)
    epsilon_t = self._get_hyper('epsilon', dtype)
    # Bias-corrected step size.
    lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))

    # m_t = beta1 * m + (1 - beta1) * g_t: decay all of m, then add rows.
    m_slot = self.get_slot(var, 'm')
    m_grad_term = grad * (1 - beta_1_t)
    m_decayed = state_ops.assign(
        m_slot, m_slot * beta_1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_decayed]):
        m_t = self._resource_scatter_add(m_slot, indices, m_grad_term)
        # m_bar = (1 - beta1) * g_t + beta1 * m_t
        m_bar = m_grad_term + beta_1_t * array_ops.gather(m_t, indices)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t): decay all, then add rows.
    v_slot = self.get_slot(var, 'v')
    v_grad_term = (grad * grad) * (1 - beta_2_t)
    v_decayed = state_ops.assign(
        v_slot, v_slot * beta_2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_decayed]):
        v_t = self._resource_scatter_add(v_slot, indices, v_grad_term)

    v_sqrt = math_ops.sqrt(array_ops.gather(v_t, indices))
    var_update = self._resource_scatter_add(
        var, indices, -lr * m_bar / (v_sqrt + epsilon_t))
    return control_flow_ops.group(*[var_update, m_bar, v_t])
开发者ID:aeverall,项目名称:tensorflow,代码行数:32,代码来源:nadam.py
示例11: test_train_skip_train_if_max_step_already_saved
def test_train_skip_train_if_max_step_already_saved(self):
    """Re-running with an already-reached max_steps leaves the step at 10.

    Both runs share `self._output_dir`. The second run starts from the saved
    step 10 and must not advance it.
    """
    max_steps = 10
    for _ in range(2):
        with ops.Graph().as_default() as g, self.test_session(g):
            with ops.control_dependencies(self._build_inference_graph()):
                train_op = state_ops.assign_add(
                    variables_lib.get_global_step(), 1)
            learn.graph_actions._monitored_train(  # pylint: disable=protected-access
                g,
                output_dir=self._output_dir,
                train_op=train_op,
                loss_op=constant_op.constant(2.0),
                max_steps=max_steps)
            saved_step = checkpoint_utils.load_variable(
                self._output_dir, variables_lib.get_global_step().name)
            self.assertEqual(10, saved_step)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:26,代码来源:graph_actions_test.py
示例12: get_best
def get_best(self, n):
    """Return the indices and values of the n highest scores in the TopN.

    The shortlist keeps a header slot. `sl_ids[0]` holds the count of valid
    entries, and `sl_scores[0]` holds the smallest score that was kept. The
    shortlist is rebuilt from `id_to_score` only when `n > sl_ids[0]`.

    Args:
      n: number of results requested.

    Returns:
      `(ids, scores)` with at most `min(n, sl_ids[0])` entries. The ids are
      looked up in `sl_ids`, so they index into `id_to_score`.
    """

    def refresh_shortlist():
        """Update the shortlist with the highest scores in id_to_score."""
        top_scores, top_ids = nn_ops.top_k(self.id_to_score, self.shortlist_size)
        min_kept_score = math_ops.reduce_min(top_scores)
        # Scores equal to float32.min are not counted (presumably empty slots).
        is_filled = math_ops.greater(top_scores, dtypes.float32.min)
        kept_count = math_ops.reduce_sum(math_ops.to_int32(is_filled))
        ids_update = self.sl_ids.assign(
            math_ops.to_int64(array_ops.concat([[kept_count], top_ids], 0)))
        scores_update = self.sl_scores.assign(
            array_ops.concat([[min_kept_score], top_scores], 0))
        # Subsequent calls chain on these assigns via control dependencies.
        self.last_ops = [ids_update, scores_update]
        return control_flow_ops.group(ids_update, scores_update)

    # We only need to refresh the shortlist if n is greater than the
    # current shortlist size (which is stored in sl_ids[0]).
    with ops.control_dependencies(self.last_ops):
        maybe_refresh = control_flow_ops.cond(
            n > self.sl_ids[0], refresh_shortlist, control_flow_ops.no_op)
    with ops.control_dependencies([maybe_refresh]):
        k = math_ops.minimum(n, math_ops.to_int32(self.sl_ids[0]))
        top_values, shortlist_positions = nn_ops.top_k(self.sl_scores, k)
        # top_k returns positions within the shortlist; map them to the ids
        # stored there.
        ids = array_ops.gather(self.sl_ids, shortlist_positions)
        return ids, top_values
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:29,代码来源:topn.py
示例13: _check_labels
def _check_labels(labels, expected_labels_dimension):
    """Check labels type and shape.

    Args:
      labels: labels to validate. A `SparseTensor` is rejected.
      expected_labels_dimension: required size of the second dimension.

    Returns:
      The dense labels, passed through an identity op that runs only after
      the rank-2 and dimension assertions.

    Raises:
      ValueError: if `labels` is sparse, or if the statically known second
        dimension differs from `expected_labels_dimension`.
    """
    with ops.name_scope(None, 'labels', (labels,)) as scope:
        labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
        if isinstance(labels, sparse_tensor.SparseTensor):
            raise ValueError('SparseTensor labels are not supported.')
        dynamic_shape = array_ops.shape(labels)
        err_msg = 'labels shape must be [batch_size, {}]'.format(
            expected_labels_dimension)
        rank_check = check_ops.assert_rank(labels, 2, message=err_msg)
        with ops.control_dependencies([rank_check]):
            static_shape = labels.shape
            if static_shape is not None:
                # Fail at graph-construction time when the mismatch is
                # already visible in the static shape.
                static_dim = static_shape[1]
                if static_dim is not None and static_dim != expected_labels_dimension:
                    raise ValueError(
                        'Mismatched label shape. '
                        'Classifier configured with n_classes=%s. Received %s. '
                        'Suggested Fix: check your n_classes argument to the estimator '
                        'and/or the shape of your label.' %
                        (expected_labels_dimension, static_dim))
            # Runtime check for the case where the dimension is not static.
            dim_check = check_ops.assert_equal(
                expected_labels_dimension, dynamic_shape[1], message=err_msg)
            with ops.control_dependencies([dim_check]):
                return array_ops.identity(labels, name=scope)
开发者ID:cneeruko,项目名称:tensorflow,代码行数:25,代码来源:head.py
示例14: _get_sparse_tensors
def _get_sparse_tensors(self, inputs, weight_collections=None,
                        trainable=None):
    """Return the categorical column's ids/weights with a trailing unit dim.

    The id tensor, and the weight tensor when present, are asserted to be rank
    2 and then reshaped by appending a dimension of size 1. This keeps
    embeddings from being combined during embedding lookup.

    `weight_collections` and `trainable` are accepted but not used here.
    """

    def expand_final_dim(sparse, rank_message, shape_label):
        # Assert rank 2, then reshape dense_shape [d0, d1] -> [d0, d1, 1].
        rank_check = check_ops.assert_equal(
            array_ops.rank(sparse), 2,
            data=[rank_message, shape_label, array_ops.shape(sparse)])
        with ops.control_dependencies([rank_check]):
            return sparse_ops.sparse_reshape(
                sparse,
                shape=array_ops.concat([sparse.dense_shape, [1]], axis=0))

    pair = self.categorical_column._get_sparse_tensors(inputs)
    ids = expand_final_dim(
        pair.id_tensor,
        'Column {} expected ID tensor of rank 2. '.format(self.name),
        'id_tensor shape: ')
    weights = pair.weight_tensor
    if weights is not None:
        weights = expand_final_dim(
            weights,
            'Column {} expected weight tensor of rank 2.'.format(self.name),
            'weight_tensor shape:')
    return fc._CategoricalColumn.IdWeightPair(ids, weights)
开发者ID:DILASSS,项目名称:tensorflow,代码行数:27,代码来源:sequential_feature_column.py
示例15: update_weights
def update_weights(self, train_op):
    """Updates the model weights.

    This function must be called on at least one worker after `minimize`.
    In distributed training this call can be omitted on non-chief workers to
    speed up training.

    Args:
      train_op: The operation returned by the `minimize` call.

    Returns:
      An Operation that updates the model weights.
    """
    weight_kinds = ['sparse_features_weights', 'dense_features_weights']

    # Stage 1, after train_op: copy the unshrunk slot values into the
    # user-provided variables.
    with ops.control_dependencies([train_op]):
        copy_ops = [
            var.assign(slot_var)
            for kind in weight_kinds
            for var, slot_var in zip(self._variables[kind],
                                     self._slots['unshrinked_' + kind])
        ]

    # Stage 2, after the copies: apply the proximal step (sdca_shrink_l1)
    # to each variable on that variable's own device.
    with ops.control_dependencies(copy_ops):
        shrink_ops = []
        for kind in weight_kinds:
            for var in self._variables[kind]:
                with ops.device(var.device):
                    shrink_ops.append(
                        sdca_shrink_l1(
                            self._convert_n_to_tensor([var], as_ref=True),
                            l1=self._symmetric_l1_regularization(),
                            l2=self._symmetric_l2_regularization()))
        return control_flow_ops.group(*shrink_ops)
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:34,代码来源:sdca_ops.py
示例16: testAssertIntegerForm
def testAssertIntegerForm(self):
    """assert_integer_form accepts integral floats and rejects the others."""
    # Every entry is integral, so this one should pass the check.
    x = [1., 5, 10, 15, 20]
    # 1.1 is clearly not an integer.
    y = [1.1, 5, 10, 15, 20]
    # The fractional part of the first entry (1e-4) is not below float32
    # eps (~1e-7).
    z = [1.0001, 5, 10, 15, 20]
    # Tiny, but still not an integer, so it must be rejected too.
    w = [1e-8, 5, 10, 15, 20]
    with self.test_session():
        with ops.control_dependencies([distribution_util.assert_integer_form(x)]):
            array_ops.identity(x).eval()
        for non_integer in (y, z, w):
            with self.assertRaisesOpError("x has non-integer components"):
                with ops.control_dependencies(
                        [distribution_util.assert_integer_form(non_integer)]):
                    array_ops.identity(non_integer).eval()
开发者ID:Immexxx,项目名称:tensorflow,代码行数:26,代码来源:distribution_util_test.py
示例17: testAssertIntegerForm
def testAssertIntegerForm(self):
    """Same as the literal-list variant, but with values fed at run time."""
    # Fed with integral values, so this one should pass the check.
    x = array_ops.placeholder(dtypes.float32)
    # Fed with 1.1 as the first entry: clearly not an integer.
    y = array_ops.placeholder(dtypes.float32)
    # Fed with 1.0001: the fractional part is not below float32 eps (~1e-7).
    z = array_ops.placeholder(dtypes.float32)
    # Fed with 1e-8: tiny, but still not an integer.
    w = array_ops.placeholder(dtypes.float32)
    feed_dict = {
        x: [1., 5, 10, 15, 20],
        y: [1.1, 5, 10, 15, 20],
        z: [1.0001, 5, 10, 15, 20],
        w: [1e-8, 5, 10, 15, 20],
    }
    with self.test_session():
        with ops.control_dependencies([distribution_util.assert_integer_form(x)]):
            array_ops.identity(x).eval(feed_dict=feed_dict)
        for non_integer in (y, z, w):
            with self.assertRaisesOpError("x has non-integer components"):
                with ops.control_dependencies(
                        [distribution_util.assert_integer_form(non_integer)]):
                    array_ops.identity(non_integer).eval(feed_dict=feed_dict)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:28,代码来源:distribution_util_test.py
示例18: _apply_sparse_shared
def _apply_sparse_shared(self, grad, var, indices,
                         scatter_add, scatter_update):
    """Sparse Adamax update, parameterized by the scatter primitives.

    Args:
      grad: gradient values for the rows selected by `indices`.
      var: the variable being optimized.
      indices: row indices of `var` that `grad` applies to.
      scatter_add: callable `(x, indices, updates)` that adds into rows of x.
      scatter_update: callable `(x, indices, updates)` that overwrites rows.

    Returns:
      An op grouping the `var`, `m` and `v` updates.
    """
    dtype = var.dtype.base_dtype
    beta1_power = math_ops.cast(self._get_beta_accumulators(), dtype)
    lr_t = math_ops.cast(self._lr_t, dtype)
    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)

    # First moment on the touched rows: m_t = beta1 * m + (1 - beta1) * g_t
    m_slot = self.get_slot(var, "m")
    new_m_rows = (array_ops.gather(m_slot, indices) * beta1_t
                  + grad * (1 - beta1_t))
    with ops.control_dependencies([new_m_rows]):
        m_write = scatter_update(m_slot, indices, new_m_rows)

    # Running max on the touched rows: u_t = max(beta2 * u, abs(g_t))
    v_slot = self.get_slot(var, "v")
    new_v_rows = math_ops.maximum(
        array_ops.gather(v_slot, indices) * beta2_t, math_ops.abs(grad))
    with ops.control_dependencies([new_v_rows]):
        v_write = scatter_update(v_slot, indices, new_v_rows)

    # theta_t = theta - lr / (1 - beta1^t) * m_t / u_t, applied as an add.
    delta_rows = -lr_t / (1 - beta1_power) * (new_m_rows /
                                              (new_v_rows + epsilon_t))
    with ops.control_dependencies([delta_rows]):
        var_write = scatter_add(var, indices, delta_rows)
    return control_flow_ops.group(*[var_write, m_write, v_write])
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:26,代码来源:adamax.py
示例19: _capture_tensor_as_extra_input
def _capture_tensor_as_extra_input(self, tensor, name=None):
    """Replace an outer `tensor` with a new input placeholder.

    Records `tensor` in `self.extra_inputs`, creates a same-dtype,
    same-static-shape placeholder outside any control-flow context, copies
    resource-handle shape/type data onto it, and registers the placeholder in
    `self.inputs`, `self._captured` and `self.extra_args`.

    Returns:
      The placeholder. It is wrapped in `guarantee_const` when
      `_is_guaranteed_const(tensor)` is true.
    """
    # Substitute with a placeholder.
    self.extra_inputs.append(tensor)
    # Hoist the new input placeholder out of any control flow context
    # we're currently in.
    with ops.control_dependencies(None):
        placeholder = array_ops.placeholder(
            tensor.dtype, shape=tensor.get_shape(), name=name)
    # Copy resource-handle shape/type data from `tensor` to the placeholder.
    # pylint: disable=protected-access
    if not ops._USE_C_SHAPES:
        placeholder._handle_data = tensor._handle_data
    else:
        handle_data = c_api.GetResourceHandleShapeAndType(
            tensor.graph._c_graph, tensor._as_tf_output())
        if handle_data:
            c_api.SetResourceHandleShapeAndType(
                placeholder.graph._c_graph,
                placeholder._as_tf_output(),
                compat.as_bytes(handle_data))
    # pylint: enable=protected-access
    self.inputs.append(placeholder)
    self._captured[tensor] = placeholder
    self.extra_args.append(placeholder)
    if not _is_guaranteed_const(tensor):
        return placeholder
    with ops.control_dependencies(None):
        return array_ops.guarantee_const(placeholder)
开发者ID:didukhle,项目名称:tensorflow,代码行数:27,代码来源:function.py
示例20: testReadWrite
def testReadWrite(self):
    """Tests initialization, reading, and writing a resource variable."""
    with self.test_session() as session:
        with self.test_scope():
            with variable_scope.variable_scope("ascope", use_resource=True):
                x = variable_scope.get_variable(
                    "x",
                    shape=[],
                    dtype=dtypes.float32,
                    initializer=init_ops.constant_initializer(2))
                # Strict chain: read -> assign 47 -> read -> add 3 -> read.
                # Each step is control-dependent on the previous one.
                initial_read = x.read_value()
                with ops.control_dependencies([initial_read]):
                    assigned = state_ops.assign(x, 47)
                with ops.control_dependencies([assigned]):
                    read_after_assign = x.read_value()
                with ops.control_dependencies([read_after_assign]):
                    incremented = state_ops.assign_add(x, 3)
                with ops.control_dependencies([incremented]):
                    read_after_add = x.read_value()

        session.run(variables.global_variables_initializer())
        first, second, third = session.run(
            [initial_read, read_after_assign, read_after_add])
        self.assertAllClose(2.0, first)
        self.assertAllClose(47.0, second)
        self.assertAllClose(50.0, third)
开发者ID:Immexxx,项目名称:tensorflow,代码行数:25,代码来源:variable_ops_test.py
注：本文中的tensorflow.python.framework.ops.control_dependencies函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论