本文整理汇总了Python中tensorflow.python.ops.state_ops.scatter_update函数的典型用法代码示例。如果您正苦于以下问题：Python scatter_update函数的具体用法？Python scatter_update怎么用？Python scatter_update使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了scatter_update函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _apply_sparse
def _apply_sparse(self, grad, var):
    """Lazy Adam step that only touches the rows selected by ``grad.indices``.

    The ``m`` and ``v`` slots and ``var`` are read and rewritten only at the
    sparse gradient's indices; every other row is left untouched.

    Args:
      grad: sparse gradient exposing ``indices`` and ``values``.
      var: the variable being optimized.

    Returns:
      A grouped op that runs the ``var``, ``m`` and ``v`` updates.
    """
    dtype = var.dtype.base_dtype
    b1_pow, b2_pow = self._get_beta_accumulators()
    b1_pow = math_ops.cast(b1_pow, dtype)
    b2_pow = math_ops.cast(b2_pow, dtype)
    base_lr = math_ops.cast(self._lr_t, dtype)
    b1 = math_ops.cast(self._beta1_t, dtype)
    b2 = math_ops.cast(self._beta2_t, dtype)
    eps = math_ops.cast(self._epsilon_t, dtype)
    # Adam's bias correction is folded into the step size.
    step_size = (base_lr * math_ops.sqrt(1 - b2_pow) / (1 - b1_pow))
    rows = grad.indices
    locking = self._use_locking

    # First moment on the touched rows: m := beta1 * m + (1 - beta1) * g.
    m_slot = self.get_slot(var, "m")
    m_rows_new = b1 * array_ops.gather(m_slot, rows) + (1 - b1) * grad.values
    m_t = state_ops.scatter_update(m_slot, rows, m_rows_new,
                                   use_locking=locking)

    # Second moment on the touched rows: v := beta2 * v + (1 - beta2) * g^2.
    v_slot = self.get_slot(var, "v")
    v_rows_new = (b2 * array_ops.gather(v_slot, rows) +
                  (1 - b2) * math_ops.square(grad.values))
    v_t = state_ops.scatter_update(v_slot, rows, v_rows_new,
                                   use_locking=locking)

    # var -= step_size * m_t / (sqrt(v_t) + epsilon), again on touched rows.
    m_rows = array_ops.gather(m_t, rows)
    v_rows = array_ops.gather(v_t, rows)
    var_update = state_ops.scatter_sub(
        var, rows, step_size * m_rows / (math_ops.sqrt(v_rows) + eps),
        use_locking=locking)
    return control_flow_ops.group(var_update, m_t, v_t)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:32,代码来源:lazy_adam_optimizer.py
示例2: testScatterUpdateInvalidArgs
def testScatterUpdateInvalidArgs(self):
    """scatter_update must reject updates whose shape mismatches the indices."""
    v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update")
    # The exact error and message differ between graph construction (where the
    # error is realized during shape inference at graph construction time) and
    # eager execution (where the error is realized during kernel execution),
    # so only the shared "shape ... 2 ... 3" fragment is matched.
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias, which
    # was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(Exception, r"shape.*2.*3"):
        state_ops.scatter_update(v, [0, 1], [0, 1, 2])
开发者ID:aeverall,项目名称:tensorflow,代码行数:7,代码来源:resource_variable_ops_test.py
示例3: _apply_sparse
def _apply_sparse(self, grad, var):
    """AMSGrad step restricted to the rows selected by ``grad.indices``.

    Updates the first moment ``m``, the second moment ``v`` and the running
    element-wise maximum ``v_prime`` of ``v`` on the touched rows, then applies
    ``var -= lr * m / (sqrt(v_prime) + epsilon)`` to those rows.

    Every slot write honors ``self._use_locking``. Previously the ``v_prime``
    write silently ignored it.

    Args:
      grad: sparse gradient exposing ``indices`` and ``values``.
      var: the variable being optimized.

    Returns:
      A grouped op running the ``var``, ``m``, ``v`` and ``v_prime`` updates.
    """
    dtype = var.dtype.base_dtype
    lr_t = math_ops.cast(self._lr_t, dtype)
    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)
    indices = grad.indices
    # the following equations given in [1]
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_t = state_ops.scatter_update(
        m, indices,
        beta1_t * array_ops.gather(m, indices) + (1. - beta1_t) * grad.values,
        use_locking=self._use_locking)
    m_t_slice = array_ops.gather(m_t, indices)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_t = state_ops.scatter_update(
        v, indices,
        beta2_t * array_ops.gather(v, indices) +
        (1. - beta2_t) * math_ops.square(grad.values),
        use_locking=self._use_locking)
    # v_prime_t = max(v_prime, v_t): the non-decreasing denominator of AMSGrad.
    v_prime = self.get_slot(var, "v_prime")
    v_t_slice = array_ops.gather(v_t, indices)
    v_prime_slice = array_ops.gather(v_prime, indices)
    v_t_prime = state_ops.scatter_update(
        v_prime, indices, math_ops.maximum(v_prime_slice, v_t_slice),
        use_locking=self._use_locking)
    v_t_prime_slice = array_ops.gather(v_t_prime, indices)
    var_update = state_ops.scatter_sub(
        var, indices,
        lr_t * m_t_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
        use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t, v_t_prime)
开发者ID:jkhlot,项目名称:tensorflow-XNN,代码行数:32,代码来源:optimizer.py
示例4: _SparseUpdate
def _SparseUpdate(variable, gradients, accum, linear, base_lr,
                  lr_power, l1, l2):
    """Sparse Update "variable", "accum", "linear" based on sparse "gradients".

    See the description in _Update. Only the rows listed in
    ``gradients.indices`` are read and rewritten.

    Args:
      variable: A Variable.
      gradients: A Sparse Tensor; must be an ``ops.IndexedSlices`` (checked
        with an ``assert``).
      accum: A Variable containing the sum of the squares of gradients.
      linear: A Variable containing approximation info.
      base_lr: A constant representing the base learning rate.
      lr_power: A constant used to adjust the learning rate.
      l1: A constant representing the l1 regularization strength.
      l2: A constant representing the l2 regularization strength.

    Returns:
      A group op (named after the enclosing name scope) including three
      ScatterUpdate ops:
        1. ScatterUpdate for "accum"
        2. ScatterUpdate for "linear"
        3. ScatterUpdate for "variable"
    """
    # NOTE(review): `assert` is stripped under `python -O`; raising TypeError
    # would be a sturdier guard.
    assert isinstance(gradients, ops.IndexedSlices)
    with ops.name_scope("sparse_update_" + variable.op.name) as scope:
        # Convert all hyperparameters to the variable's base dtype.
        dtype = variable.dtype.base_dtype
        base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
        lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
        l1 = ops.convert_to_tensor(l1, dtype=dtype)
        l2 = ops.convert_to_tensor(l2, dtype=dtype)
        # Compute the new value for the accumulator on the touched rows:
        # accum_updated = accum_old + g * g
        previous_accum = array_ops.gather(accum, gradients.indices)
        sqr_grad = gradients.values * gradients.values
        accum_updated = sqr_grad + previous_accum
        # Compute the new linear:
        # sigma = (accum_updated^(-lr_power) - accum_old^(-lr_power)) / base_lr
        # linear_updated = linear_old + g - sigma * variable_old
        # NOTE(review): math_ops.neg was renamed math_ops.negative in later
        # TensorFlow releases; this code targets an older API.
        neg_lr_power = math_ops.neg(lr_power)
        sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
            previous_accum, neg_lr_power)
        sigma /= base_lr
        variable_slice = array_ops.gather(variable, gradients.indices)
        proximal_adjust = sigma * variable_slice
        linear_slice = array_ops.gather(linear, gradients.indices)
        linear_updated = linear_slice + gradients.values - proximal_adjust
        # Compute the new "variable". _Compute is defined outside this snippet;
        # it receives the updated accum/linear rows plus the hyperparameters.
        variable_updated = _Compute(accum_updated, linear_updated, base_lr,
                                    lr_power, l1, l2)
        # The three writes are created under a control dependency on `sigma`,
        # so they execute only after sigma (and therefore the pre-update read
        # `previous_accum`) has been computed.
        with ops.control_dependencies([sigma]):
            accum_update_op = state_ops.scatter_update(accum, gradients.indices,
                                                       accum_updated)
            linear_update_op = state_ops.scatter_update(linear, gradients.indices,
                                                        linear_updated)
            variable_update_op = state_ops.scatter_update(variable, gradients.indices,
                                                          variable_updated)
            # `scope` is the name-scope string, reused as the group op's name.
            group_op = control_flow_ops.group(linear_update_op, accum_update_op,
                                              variable_update_op, name=scope)
            return group_op
开发者ID:onexuan,项目名称:TensorflowAndroid,代码行数:59,代码来源:ftrl.py
示例5: testScatterBool
def testScatterBool(self):
    """Eager scatter_update into a boolean resource variable sets every slot."""
    with context.eager_mode():
        flags = resource_variable_ops.ResourceVariable(
            [False, True, False], trainable=False)
        positions = math_ops.range(3)
        all_true = constant_op.constant([True, True, True])
        state_ops.scatter_update(flags, positions, all_true)
        self.assertAllEqual(flags.read_value(), [True, True, True])
开发者ID:aeverall,项目名称:tensorflow,代码行数:8,代码来源:resource_variable_ops_test.py
示例6: shortlist_insert
def shortlist_insert():
    """Merge the above-threshold candidates into the shortlist.

    Closure: ``ids``, ``scores``, ``larger_scores`` and ``self`` are free
    variables taken from the enclosing scope (not visible in this snippet).

    Returns:
      A grouped op writing the new shortlist ids and scores.
    """
    keep = larger_scores
    candidate_ids = array_ops.boolean_mask(math_ops.to_int64(ids), keep)
    candidate_scores = array_ops.boolean_mask(scores, keep)
    slots, slot_ids, slot_scores = tensor_forest_ops.top_n_insert(
        self.sl_ids, self.sl_scores, candidate_ids, candidate_scores)
    write_ids = state_ops.scatter_update(self.sl_ids, slots, slot_ids)
    write_scores = state_ops.scatter_update(self.sl_scores, slots, slot_scores)
    return control_flow_ops.group(write_ids, write_scores)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:9,代码来源:topn.py
示例7: testBooleanScatterUpdate
def testBooleanScatterUpdate(self):
    """Scalar scatter_update on a bool Variable; runs only on CPU-only hosts."""
    # When a GPU is available this test checks nothing.
    if test.is_gpu_available():
        return
    with self.test_session(use_gpu=False) as sess:
        flags = variables.Variable([True, False])
        set_second = state_ops.scatter_update(flags, 1, True)
        first_index = constant_op.constant(0, dtype=dtypes.int64)
        clear_first = state_ops.scatter_update(flags, first_index, False)
        flags.initializer.run()
        sess.run([set_second, clear_first])
        self.assertAllEqual([False, True], flags.eval())
开发者ID:1000sprites,项目名称:tensorflow,代码行数:13,代码来源:scatter_ops_test.py
示例8: _apply_sparse
def _apply_sparse(self, grad, var):
    """Sparse momentum update using the ``mu``/``mu2``/``mu_power`` schedule.

    Duplicate indices in the sparse gradient are merged first, by summing their
    values. The ``m`` slot and ``var`` are then updated only at the unique
    indices.

    NOTE(review): the exact algorithm (Nesterov / Nadam variant) is inferred
    from the attribute names; ``self._mu_t``, ``self._mu2_t`` and
    ``self._mu_power`` are set outside this snippet — confirm their meaning.

    Args:
      grad: sparse gradient exposing ``indices`` and ``values``.
      var: the variable being optimized.

    Returns:
      A grouped op running the ``var`` and ``m`` updates.
    """
    if len(grad.indices.get_shape()) == 1:
        # Rank-1 indices can be used as-is.
        grad_indices = grad.indices
        grad_values = grad.values
    else:
        # Flatten higher-rank indices and collapse the values to 2-D
        # [-1, last_dim] so they stay aligned with the flattened indices.
        # NOTE(review): assumes each index selects a row of length last_dim
        # (i.e. var is rank 2) — confirm. `.value` is the TF1 Dimension API.
        grad_indices = array_ops.reshape(grad.indices, [-1])
        grad_values = array_ops.reshape(grad.values, [-1, grad.values.get_shape()[-1].value])
    # Sum the gradient rows that share an index, so each row is updated once.
    gidxs, metagidxs = array_ops.unique(grad_indices)
    sizegidxs = array_ops.size(gidxs)
    gvals = math_ops.unsorted_segment_sum(grad_values, metagidxs, sizegidxs)
    # m_t = mu * m + (1 - mu) * g_t, done in two steps: scale the touched rows
    # by mu, then add (1 - mu) * g.
    m = self.get_slot(var, "m")
    m_scaled_g_values = gvals * (1 - self._mu_t)
    m_t = state_ops.scatter_update(m, gidxs,
                                   array_ops.gather(m, gidxs) * self._mu_t,
                                   use_locking=self._use_locking)
    m_t = state_ops.scatter_add(m_t, gidxs, m_scaled_g_values,
                                use_locking=self._use_locking)
    # Touched rows of m_t rescaled by 1 / (1 - mu2 * mu_power).
    m_t_ = array_ops.gather(m_t, gidxs) / (1 - self._mu2_t * self._mu_power)
    # m_bar = mu * m_t + (1 - mu) * g_t
    # As implemented: m_bar = mu2 * m_t_ + (1 - mu) * g / (1 - mu_power).
    m_bar = self._mu2_t * m_t_ + m_scaled_g_values / (1 - self._mu_power)
    var_update = state_ops.scatter_sub(var, gidxs,
                                       self._lr_t * m_bar,
                                       use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t])
开发者ID:MarvinBertin,项目名称:TensorFlow-Algorithms,代码行数:25,代码来源:nesterov.py
示例9: _sparse_moving_average
def _sparse_moving_average(self, x_tm1, idxs, b_t_, name, beta=.9):
    """
    Creates a bias-corrected moving average for a sparse variable.

    Only the rows listed in ``idxs`` are updated. Each row keeps its own step
    counter, so initialization bias is corrected per row.

    Inputs:
      x_tm1: the associated parameter (e.g. a weight matrix)
      idxs: the tensor representing the indices used
      b_t_: the value to accumulate (e.g. slices of the gradient)
      name: a string to use to retrieve it later (e.g. 'm')
      beta: the decay factor (defaults to .9). With ``beta >= 1`` a plain
        running mean is kept instead: a_t = (t-1)/t * a_{t-1} + 1/t * b_t.
    Outputs:
      a_t: the average after moving (same shape as x_tm1, not b_t_)
      t: the internal timestep (used to correct initialization bias)
    """
    a_tm1 = self._zeros_slot(x_tm1, '%s' % name, self._name)
    a_tm1_ = array_ops.gather(a_tm1, idxs)
    tm1 = self._zeros_idx_slot(x_tm1, '%s/tm1' % name, self._name)
    tm1_ = array_ops.gather(tm1, idxs)
    # Advance the per-row step counter of every touched row by one.
    t = state_ops.scatter_add(tm1, idxs, tm1_*0+1, use_locking=self._use_locking)
    t_ = array_ops.gather(t, idxs)
    if beta < 1:
        beta_t = ops.convert_to_tensor(beta, name='%s/decay' % name)
        # Bias-corrected decay: beta * (1 - beta^(t-1)) / (1 - beta^t).
        beta_t_ = beta_t * (1-beta_t**tm1_) / (1-beta_t**t_)
    else:
        # Plain running mean.
        beta_t_ = tm1_/t_
    # a_t = beta_t_ * a_{t-1} + (1 - beta_t_) * b_t.
    # The weight on the new value must be the *corrected* (1 - beta_t_), which
    # equals (1 - beta) / (1 - beta^t), so that the two weights sum to 1.
    # Using the raw decay here mis-weighted b_t_, and it left `beta_t`
    # undefined (NameError) whenever beta >= 1.
    a_t = state_ops.scatter_update(a_tm1, idxs, beta_t_*a_tm1_, use_locking=self._use_locking)
    a_t = state_ops.scatter_add(a_t, idxs, (1-beta_t_)*b_t_, use_locking=self._use_locking)
    return a_t, t
开发者ID:tdozat,项目名称:Optimization,代码行数:28,代码来源:optimizers.py
示例10: _apply_sparse
def _apply_sparse(self, grad, var):
    """Sparse AdaMax step: delegate to ``_apply_sparse_shared``.

    The shared implementation receives two callbacks with signature
    ``(ref, indices, updates)``, backed by ``state_ops.scatter_add`` and
    ``state_ops.scatter_update``. Both honor ``self._use_locking``.
    """
    def add_rows(ref, indices, updates):
        return state_ops.scatter_add(
            ref, indices, updates, use_locking=self._use_locking)

    def set_rows(ref, indices, updates):
        return state_ops.scatter_update(
            ref, indices, updates, use_locking=self._use_locking)

    return self._apply_sparse_shared(
        grad.values, var, grad.indices, add_rows, set_rows)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:7,代码来源:adamax.py
示例11: create_axis_ops
def create_axis_ops(sp_input, num_items, update_fn, axis_name):
    """Creates book-keeping and training ops for a given axis.

    Args:
      sp_input: A SparseTensor corresponding to the row or column batch.
      num_items: An integer, the total number of items of this axis.
      update_fn: A function that takes one argument (`sp_input`), and that
        returns a tuple of
        * new_factors: A float Tensor of the factor values after update.
        * update_op: a TensorFlow op which updates the factors.
        * loss: A float Tensor, the unregularized loss.
        * reg_loss: A float Tensor, the regularization loss.
        * sum_weights: A float Tensor, the sum of factor weights.
      axis_name: A string that specifies the name of the axis.

    Returns:
      A tuple consisting of:
      * reset_processed_items_op: A TensorFlow op, to be run before the
        beginning of any sweep. It marks all items as not-processed.
      * axis_train_op: A TensorFlow op, to be run during this axis' sweeps.
    """
    # One boolean "processed in this sweep" flag per item, initially False.
    processed_items_init = array_ops.fill(dims=[num_items], value=False)
    with ops.colocate_with(processed_items_init):
        processed_items = variable_scope.variable(
            processed_items_init,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES],
            trainable=False,
            name="processed_" + axis_name)
    reset_processed_items_op = state_ops.assign(
        processed_items, processed_items_init,
        name="reset_processed_" + axis_name)
    # new_factors (first element) is not needed here.
    _, update_op, loss, reg, sum_weights = update_fn(sp_input)
    # The first coordinate of each sparse entry is used as its item index.
    input_indices = sp_input.indices[:, 0]
    # NOTE(review): loss_var, rwse_var, global_step_incr_op, is_sweep_done_var
    # and completed_sweeps_var are free variables from an enclosing scope that
    # is not visible in this snippet.
    # Mark items processed only after the factor update and the loss / rwse
    # assignments have run.
    with ops.control_dependencies([
        update_op,
        state_ops.assign(loss_var, loss + reg),
        state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]):
        with ops.colocate_with(processed_items):
            update_processed_items = state_ops.scatter_update(
                processed_items,
                input_indices,
                array_ops.ones_like(input_indices, dtype=dtypes.bool),
                name="update_processed_{}_indices".format(axis_name))
        # Read the flags only after this batch's items have been marked.
        with ops.control_dependencies([update_processed_items]):
            is_sweep_done = math_ops.reduce_all(processed_items)
            # completed_sweeps_var grows by 1 exactly when every item is done.
            axis_train_op = control_flow_ops.group(
                global_step_incr_op,
                state_ops.assign(is_sweep_done_var, is_sweep_done),
                state_ops.assign_add(
                    completed_sweeps_var,
                    math_ops.cast(is_sweep_done, dtypes.int32)),
                name="{}_sweep_train_op".format(axis_name))
    return reset_processed_items_op, axis_train_op
开发者ID:TianyouLi,项目名称:tensorflow,代码行数:53,代码来源:wals.py
示例12: scatter_update
def scatter_update(cls, factor, indices, values, sharding_func):
    """Helper function for doing sharded scatter update.

    Args:
      factor: list of variables holding the factor (one per shard).
      indices: ids of the rows to write.
      values: new rows, aligned with ``indices``.
      sharding_func: maps ``indices`` to ``(assignments, new_ids)``, i.e. the
        shard of each id plus the ids to use inside that shard. It is only
        consulted when there is more than one shard.

    Returns:
      An op performing the update(s).
    """
    assert isinstance(factor, list)
    shard_count = len(factor)
    if shard_count != 1:
        shard_of, new_ids = sharding_func(indices)
        assert shard_of is not None
        shard_of = math_ops.cast(shard_of, dtypes.int32)
        ids_per_shard = data_flow_ops.dynamic_partition(new_ids, shard_of,
                                                        shard_count)
        values_per_shard = data_flow_ops.dynamic_partition(values, shard_of,
                                                           shard_count)
        shard_updates = [
            state_ops.scatter_update(shard, shard_ids, shard_values)
            for shard, shard_ids, shard_values in zip(
                factor, ids_per_shard, values_per_shard)
        ]
        return control_flow_ops.group(*shard_updates)
    with ops.colocate_with(factor[0]):
        # TODO(agarwal): assign instead of scatter update for full batch update.
        return state_ops.scatter_update(factor[0], indices, values).op
开发者ID:kadeng,项目名称:tensorflow,代码行数:22,代码来源:factorization_ops.py
示例13: _resource_apply_dense
def _resource_apply_dense(self, grad, var, state):
    """Write ``grad`` (flattened) into this variable's span of the flat gradient.

    Also records ``var`` in ``self._variables``. The span begins at
    ``self.index_dict[var.name]`` and spans ``self.shape_dict[var.name]``
    entries.

    Returns:
      The result of ``scatter_update`` on the flat-gradient buffer.
    """
    self._variables.append(var)
    span = self.shape_dict[var.name]
    offset = self.index_dict[var.name]
    flat_buffer = self._get_flat_grad(state)
    flat_values = array_ops.reshape(grad, [-1])
    positions = math_ops.range(offset, offset + span)
    return state_ops.scatter_update(flat_buffer, positions, flat_values)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:13,代码来源:ggt.py
示例14: remove
def remove(self, ids):
    """Remove the ids (and their associated scores) from the TopN."""
    with ops.control_dependencies(self.last_ops):
        # Every removed id gets the lowest finite float32 score.
        lowest_scores = array_ops.ones_like(
            ids, dtype=dtypes.float32) * dtypes.float32.min
        score_op = state_ops.scatter_update(self.id_to_score, ids, lowest_scores)
        # We assume that removed ids are almost always in the shortlist,
        # so it makes no sense to hide the Op behind a tf.cond
        vacated, new_length = tensor_forest_ops.top_n_remove(self.sl_ids, ids)
        # Index 0 of sl_ids receives new_length; each vacated slot receives -1.
        id_positions = array_ops.concat([[0], vacated], 0)
        id_values = array_ops.concat(
            [new_length, array_ops.ones_like(vacated) * -1], 0)
        ids_op = state_ops.scatter_update(self.sl_ids, id_positions, id_values)
        vacated_scores = dtypes.float32.min * array_ops.ones_like(
            vacated, dtype=dtypes.float32)
        scores_op = state_ops.scatter_update(self.sl_scores, vacated,
                                             vacated_scores)
        self.last_ops = [score_op, ids_op, scores_op]
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:24,代码来源:topn.py
示例15: testResourceVariableScatterGather
def testResourceVariableScatterGather(self):
    """scatter_update / sparse_read on a resource variable of TensorLists."""
    seed = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    seed_list = list_ops.tensor_list_from_tensor(seed, element_shape=[])
    var = vs.get_variable("var", initializer=[seed_list] * 10, use_resource=True)
    row0_stacked = list_ops.tensor_list_stack(var[0], dtypes.float32)
    self.evaluate(var.initializer)
    self.assertAllEqual([1.0, 2.0], self.evaluate(row0_stacked))
    sparse_row0 = var.sparse_read(0)
    sparse_stacked = list_ops.tensor_list_stack(sparse_row0, dtypes.float32)
    self.assertAllEqual([1.0, 2.0], self.evaluate(sparse_stacked))
    replacement_3 = list_ops.tensor_list_from_tensor([3.0, 4.0], element_shape=[])
    replacement_5 = list_ops.tensor_list_from_tensor([5.0, 6.0], element_shape=[])
    updated = state_ops.scatter_update(var, [3, 5],
                                       [replacement_3, replacement_5])
    stacked_rows = [
        list_ops.tensor_list_stack(row, dtypes.float32)
        for row in array_ops.unstack(updated)
    ]
    unchanged = [1.0, 2.0]
    expected = [unchanged] * 3 + [[3.0, 4.0], unchanged, [5.0, 6.0]]
    expected += [unchanged] * 4
    self.assertAllEqual(self.evaluate(stacked_rows), expected)
开发者ID:aeverall,项目名称:tensorflow,代码行数:20,代码来源:list_ops_test.py
示例16: ScatterUpdateGrads
def ScatterUpdateGrads(op, grad):
    """Gradient function for a scatter-update op ``var[indices] = updates``.

    The gradient w.r.t. ``updates`` is ``grad`` gathered at ``indices``. The
    gradient w.r.t. ``var`` is ``grad`` with the rows at ``indices`` zeroed,
    because those rows were overwritten. ``indices`` gets no gradient.

    NOTE(review): with duplicate entries in ``indices`` the IndexedSlices
    branch subtracts a row more than once (scatter_nd sums duplicates), and
    ``updates_grad`` credits every duplicate even though only one write wins.
    Confirm that duplicates cannot occur here.

    Args:
      op: the scatter op; its inputs are ``(var, indices, updates)``.
      grad: incoming gradient, a dense Tensor or an ``ops.IndexedSlices``.

    Returns:
      ``(var_grad, None, updates_grad)``, one entry per op input.
    """
    var, indices, updates = op.inputs

    updates_grad = array_ops.gather(grad, indices)

    # Alternative dynamic stitch approach, kept for reference (this seems to
    # be a bit slower):
    # grad_range = math_ops.range(grad.get_shape()[0].value)
    # var_grad = data_flow_ops.dynamic_stitch(
    #     [grad_range, indices],
    #     [grad, array_ops.zeros(updates.get_shape())])

    if isinstance(grad, ops.IndexedSlices):
        # note: we could use this approach for everything, but the
        # temporary variable approach seems to be slightly faster (but we
        # can't use that on indexedslices)
        # Zero the overwritten rows by subtracting them back out.
        var_grad = grad - array_ops.scatter_nd(
            array_ops.expand_dims(indices, 1), updates_grad,
            var.get_shape())
    else:
        shape = tuple(grad.get_shape().as_list())
        dtype = grad.dtype.base_dtype

        # One scratch variable per (shape, dtype), reused through AUTO_REUSE
        # and named "tmp_<dim0>_..._<dtype>".
        # NOTE(review): presumably requires a fully defined static shape for
        # grad — confirm.
        with variable_scope.variable_scope(
                "gradient_vars", reuse=variable_scope.AUTO_REUSE):
            var_grad = variable_scope.get_variable(
                "tmp" + "_%s" * (len(grad.get_shape()) + 1) % (
                    shape + (dtype.name,)),
                shape=shape, dtype=dtype, trainable=False,
                collections=["gradient_vars"])

        # Copy grad into the scratch variable, then zero the overwritten rows.
        var_grad = state_ops.assign(var_grad, grad)
        var_grad = state_ops.scatter_update(
            var_grad, indices, array_ops.zeros_like(updates))

        # we need to force a copy so that any future assignments to the
        # variable will not affect the value we return here
        # TODO: check if this is still necessary in TensorFlow 2.0
        var_grad = var_grad + 0

    return var_grad, None, updates_grad
开发者ID:nengo,项目名称:nengo_deeplearning,代码行数:40,代码来源:tensorflow_patch.py
示例17: insert
def insert(self, ids, scores):
    """Insert the ids and scores into the TopN."""
    with ops.control_dependencies(self.last_ops):
        id_score_op = state_ops.scatter_update(self.id_to_score, ids, scores)
        # sl_scores[0] acts as the shortlist admission threshold.
        above_threshold = math_ops.greater(scores, self.sl_scores[0])

        def shortlist_insert():
            """Merge above-threshold (id, score) pairs into the shortlist."""
            winner_ids = array_ops.boolean_mask(
                math_ops.to_int64(ids), above_threshold)
            winner_scores = array_ops.boolean_mask(scores, above_threshold)
            slots, slot_ids, slot_scores = tensor_forest_ops.top_n_insert(
                self.sl_ids, self.sl_scores, winner_ids, winner_scores)
            return control_flow_ops.group(
                state_ops.scatter_update(self.sl_ids, slots, slot_ids),
                state_ops.scatter_update(self.sl_scores, slots, slot_scores))

        # Only touch the shortlist when some score beats the threshold.
        maybe_insert = control_flow_ops.cond(
            math_ops.reduce_any(above_threshold), shortlist_insert,
            control_flow_ops.no_op)
        self.last_ops = [id_score_op, maybe_insert]
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:23,代码来源:topn.py
示例18: _scatter_update
def _scatter_update(self, x, i, v):
    """Write ``v`` into ``x`` at indices ``i``, honoring ``self._use_locking``."""
    lock = self._use_locking
    return state_ops.scatter_update(x, i, v, use_locking=lock)
开发者ID:clsung,项目名称:tensorflow,代码行数:3,代码来源:lazy_adam_optimizer.py
示例19: testScatterUpdateCast
def testScatterUpdateCast(self):
    """An integer update written into a float variable lands as a float."""
    with context.eager_mode():
        target = resource_variable_ops.ResourceVariable(
            [1.0, 2.0], name="update")
        state_ops.scatter_update(target, [1], [3])
        self.assertAllEqual([1.0, 3.0], target.numpy())
开发者ID:aeverall,项目名称:tensorflow,代码行数:5,代码来源:resource_variable_ops_test.py
示例20: _create_hook_ops
def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops):
    """Creates ops to update is_row_sweep_var, global_step and completed_sweeps.

    Creates two boolean variables `processed_rows` and `processed_cols`, which
    keep track of which rows/cols have been processed during the current sweep.
    Returns ops that should be run after each row / col update.
      - When `self._is_row_sweep_var` is True, it sets
        processed_rows[input_row_indices] to True.
      - When `self._is_row_sweep_var` is False, it sets
        processed_cols[input_col_indices] to True.

    Args:
      input_row_indices: A Tensor. The indices of the input rows that are
        processed during the current sweep.
      input_col_indices: A Tensor. The indices of the input columns that
        are processed during the current sweep.
      train_ops: A list of ops. The ops created by this function have control
        dependencies on `train_ops`.

    Returns:
      A tuple consisting of:
        update_op: An op to be run jointly with training. It updates the state
          and increments counters (global step and completed sweeps).
        is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is
          done, i.e. all rows (during a row sweep) or all columns (during a
          column sweep) have been processed.
        switch_op: An op to be run in `self.before_run` when the sweep is done.
          It flips `self._is_row_sweep_var` and resets both processed-flag
          variables to all-False.
    """
    # Per-row and per-column "processed in this sweep" flags, all-False at
    # start, each colocated with its initializer.
    processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False)
    with ops.colocate_with(processed_rows_init):
        processed_rows = variable_scope.variable(
            processed_rows_init,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES],
            trainable=False,
            name="sweep_hook_processed_rows")
    processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False)
    with ops.colocate_with(processed_cols_init):
        processed_cols = variable_scope.variable(
            processed_cols_init,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES],
            trainable=False,
            name="sweep_hook_processed_cols")
    # Switching axes: flip the sweep direction and clear both flag vectors.
    switch_ops = control_flow_ops.group(
        state_ops.assign(
            self._is_row_sweep_var,
            math_ops.logical_not(self._is_row_sweep_var)),
        state_ops.assign(processed_rows, processed_rows_init),
        state_ops.assign(processed_cols, processed_cols_init))
    # Holds the value of is_sweep_done computed by the latest update_ops run.
    is_sweep_done_var = variable_scope.variable(
        False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
        trainable=False,
        name="is_sweep_done")

    # After running the `train_ops`, updates `processed_rows` or
    # `processed_cols` tensors, depending on whether this is a row or col sweep.
    with ops.control_dependencies(train_ops):
        with ops.colocate_with(processed_rows):
            # Writes True during a row sweep and False otherwise, via
            # logical_and with _is_row_sweep_var.
            update_processed_rows = state_ops.scatter_update(
                processed_rows,
                input_row_indices,
                math_ops.logical_and(
                    self._is_row_sweep_var,
                    array_ops.ones_like(input_row_indices, dtype=dtypes.bool)))
        with ops.colocate_with(processed_cols):
            # Mirror image: writes True only during a column sweep.
            update_processed_cols = state_ops.scatter_update(
                processed_cols,
                input_col_indices,
                math_ops.logical_and(
                    math_ops.logical_not(self._is_row_sweep_var),
                    array_ops.ones_like(input_col_indices, dtype=dtypes.bool)))
        update_processed_op = control_flow_ops.group(
            update_processed_rows, update_processed_cols)

        # Everything below runs only after this step's flags are written.
        with ops.control_dependencies([update_processed_op]):
            # Done when every row or every column has been processed.
            is_sweep_done = math_ops.logical_or(
                math_ops.reduce_all(processed_rows),
                math_ops.reduce_all(processed_cols))
            # Increments global step (a no-op if no global step exists).
            global_step = framework_variables.get_global_step()
            if global_step is not None:
                global_step_incr_op = state_ops.assign_add(
                    global_step, 1, name="global_step_incr").op
            else:
                global_step_incr_op = control_flow_ops.no_op()
            # Increments completed sweeps: +1 only when is_sweep_done is True.
            completed_sweeps_incr_op = state_ops.assign_add(
                self._completed_sweeps_var,
                math_ops.cast(is_sweep_done, dtypes.int32),
                use_locking=True).op
            update_ops = control_flow_ops.group(
                global_step_incr_op,
                completed_sweeps_incr_op,
                state_ops.assign(is_sweep_done_var, is_sweep_done))

        return update_ops, is_sweep_done_var, switch_ops
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:96,代码来源:wals.py
注：本文中的tensorflow.python.ops.state_ops.scatter_update函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论