本文整理汇总了Python中tensorflow.python.training.distribution_strategy_context.get_tower_context函数的典型用法代码示例。如果您正苦于以下问题：Python get_tower_context函数的具体用法？Python get_tower_context怎么用？Python get_tower_context使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_tower_context函数的19个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _assign_func
def _assign_func(self, *args, **kwargs):
  """Apply an assign-style function `f` to the wrapped variable.

  The callable is passed as keyword argument `f`; it is popped from `kwargs`
  and the remaining positional/keyword arguments are forwarded.

  Behavior by context:
    * Cross-tower context with an update device set: `f(self._v, ...)` is
      called directly.
    * Cross-tower context otherwise: the call is wrapped in
      `strategy.update(self, f, ...)`.
    * Tower context: the value is reduced with `self._aggregation` inside a
      `merge_call`, and the reduced value is applied with `strategy.update`.

  Raises:
    ValueError: If called in tower context while `self._aggregation` is
      `VariableAggregation.NONE`.
  """
  f = kwargs.pop("f")
  if distribution_strategy_context.get_cross_tower_context():
    update_device = distribute_lib.get_update_device()
    if update_device is not None:
      # We are calling an assign function in an update context.
      return f(self._v, *args, **kwargs)

    # We are calling an assign function in cross tower context, wrap it in an
    # update call.
    return distribution_strategy_context.get_distribution_strategy().update(
        self, f, *args, **kwargs)

  # We are calling an assign function in tower context. Fetch the context once
  # and reuse it for the merge_call below.
  tower_context = distribution_strategy_context.get_tower_context()
  assert tower_context
  # We reduce the value we want to assign/add/sub. More details about how we
  # handle the different use cases can be found in the _reduce method.
  # We call the function with the reduced value.
  if self._aggregation == vs.VariableAggregation.NONE:
    raise ValueError("You must specify an aggregation method to update a "
                     "variable in Tower Context.")

  def merge_fn(strategy, value, *other_args, **other_kwargs):
    return strategy.update(
        self, f,
        strategy.reduce(
            aggregation=self._aggregation, value=value, destinations=self),
        *other_args, **other_kwargs)

  return tower_context.merge_call(merge_fn, *args, **kwargs)
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:31,代码来源:values.py
示例2: model_fn
def model_fn():
  """Create constants "a" and "b" in a "foo" name scope around a merge_call.

  Returns:
    Tuple `(a, b)` of the two constants.
  """
  def _identity(value):
    return value

  with ops.name_scope(None, "foo"):
    first = constant_op.constant(1.0, name="a")
    # Leave tower mode via merge_call; the merge function simply returns the
    # argument it receives.
    tower_context = distribution_strategy_context.get_tower_context()
    tower_context.merge_call(_identity)
    second = constant_op.constant(2.0, name="b")
    return first, second
开发者ID:mrlittlepig,项目名称:tensorflow,代码行数:7,代码来源:mirrored_strategy_multigpu_test.py
示例3: set_non_tensor_output
def set_non_tensor_output(self, name, output):
  """Set `output` with `name` to be captured as a non tensor output.

  In cross-tower context `output` is stored unchanged. In tower context the
  values are collected through `merge_call`, and the result of
  `distribution.unwrap(value)` is stored instead.
  """
  if not distribution_strategy_context.get_cross_tower_context():
    def merge_fn(distribution, value):
      # NOTE(priyag): Non tensor outputs are not aggregated (aggregation
      # doesn't make sense on non tensors); all values are kept via unwrap.
      self._non_tensor_outputs[name] = distribution.unwrap(value)

    tower_context = distribution_strategy_context.get_tower_context()
    tower_context.merge_call(merge_fn, output)
    return

  self._non_tensor_outputs[name] = output
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:11,代码来源:values.py
示例4: model_fn
def model_fn(device_id):
  """Create a variable through a creator that tags it with `device_id`.

  Args:
    device_id: Integer appended as ":thread_<device_id>" to whatever the next
      creator in the chain returns.

  Returns:
    The result of `variable_scope.variable(1.0)` under the creator scope.
  """
  assert isinstance(device_id, int)

  def thread_creator_fn(next_creator, *args, **kwargs):
    created = next_creator(*args, **kwargs)
    return created + ":thread_" + str(device_id)

  with variable_scope.variable_creator_scope(thread_creator_fn):
    # Create a variable in this scope.
    var = variable_scope.variable(1.0)
    # Pauses the current thread so the other thread gets to execute.
    distribution_strategy_context.get_tower_context().merge_call(
        lambda arg: arg)
    return var
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:13,代码来源:mirrored_strategy_test.py
示例5: _assert_in_default_state
def _assert_in_default_state(t):
  """Assert that the default tower context and default strategy are active.

  Args:
    t: Test case object providing `assertIs` and `assertFalse`.
  """
  ctx = distribution_strategy_context
  default_tower = ctx._get_default_tower_context()  # pylint: disable=protected-access
  t.assertIs(default_tower, ctx.get_tower_context())
  t.assertIs(None, ctx.get_cross_tower_context())
  default_strategy = ctx._get_default_distribution_strategy()  # pylint: disable=protected-access
  t.assertIs(default_strategy, ctx.get_distribution_strategy())
  t.assertFalse(ctx.has_distribution_strategy())
开发者ID:AnishShah,项目名称:tensorflow,代码行数:7,代码来源:distribute_test.py
示例6: skip_summary
def skip_summary():
  """Return whether summaries should be skipped on the current tower.

  With multiple towers under a distribution strategy, summaries are kept only
  on the first tower (tower_id == 0). If `get_tower_context()` returns a falsy
  value, that value is returned as-is.
  """
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last tower,
  # compute sum or mean across towers).
  tower_context = distribution_strategy_context.get_tower_context()
  if not tower_context:
    return tower_context
  return tower_context.tower_id > 0
开发者ID:AnishShah,项目名称:tensorflow,代码行数:8,代码来源:summary_op_util.py
示例7: increment_var
def increment_var(v, amount=1):
  """`v += amount`, distributed-aware version.

  Returns:
    Whatever `merge_call` returns, i.e. the result of `dist.update`.
  """
  def _add_amount(var_component):
    return var_component.assign_add(amount, read_value=False)

  def merge_fn(dist, vm):
    # Runs in cross-tower mode; apply the increment through `dist.update`.
    return dist.update(vm, _add_amount)

  return distribution_strategy_context.get_tower_context().merge_call(
      merge_fn, v)
开发者ID:gunan,项目名称:tensorflow,代码行数:10,代码来源:distribute.py
示例8: set_last_step_output
def set_last_step_output(self, name, output,
                         aggregation=variables_lib.VariableAggregation.NONE):
  """Set `output` with `name` to be outputted from the last step.

  Args:
    name: String, name to identify the output. Doesn't need to match tensor
      name.
    output: The tensors that should be outputted with `name`. See below for
      actual types supported.
    aggregation: Aggregation method to use to aggregate outputs from multiple
      towers. Required if `set_last_step_output` is called in a tower context.
      Optional in cross_tower_context.
      When present, the outputs from all the towers are aggregated using the
      current distribution strategy's `reduce` method. Hence, the type of
      `output` must be what's supported by the corresponding `reduce` method.
      E.g. if using MirroredStrategy and aggregation is set, output
      must be a `PerDevice` value.
      The aggregation method is also recorded in a dictionary
      `_last_step_outputs_aggregations` for later interpreting of the
      outputs as already reduced or not.

  Raises:
    ValueError: If called in a tower context while `aggregation` is
      `VariableAggregation.NONE`.
  """
  if distribution_strategy_context.get_cross_tower_context():
    self._last_step_outputs_aggregations[name] = aggregation
    if aggregation is variables_lib.VariableAggregation.NONE:
      self._last_step_outputs[name] = output
    else:
      distribution = distribution_strategy_context.get_distribution_strategy()
      # Reduce onto the "/device:CPU:0" device.
      self._last_step_outputs[name] = distribution.reduce(
          aggregation, output, destinations="/device:CPU:0")
    return

  # Tower context: an aggregation method is mandatory. Raise explicitly rather
  # than `assert`, which is removed when Python runs with -O.
  if aggregation is variables_lib.VariableAggregation.NONE:
    raise ValueError(
        "An aggregation method is required when `set_last_step_output` is "
        "called in a tower context (output name: %r)." % (name,))

  def merge_fn(distribution, value):
    self._last_step_outputs[name] = distribution.reduce(
        aggregation, value, destinations="/device:CPU:0")
    # Setting this inside the `merge_fn` because all towers share the same
    # context object, so it's more robust to set it only once (even if all
    # the towers are trying to set the same value).
    self._last_step_outputs_aggregations[name] = aggregation

  distribution_strategy_context.get_tower_context().merge_call(
      merge_fn, output)
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:42,代码来源:values.py
示例9: merge_fn
def merge_fn(dist, s):
  """Check the state visible inside `merge_call` and return "foo_" + s.

  `self` is taken from the enclosing test method.
  """
  ctx = distribution_strategy_context
  default_strategy = ctx._get_default_distribution_strategy()  # pylint: disable=protected-access
  self.assertIs(default_strategy, dist)
  self.assertIs(None, ctx.get_tower_context())
  self.assertIs(dist, ctx.get_cross_tower_context())
  self.assertIs(dist, ctx.get_distribution_strategy())
  self.assertFalse(ctx.has_distribution_strategy())
  return "foo_" + s
开发者ID:AnishShah,项目名称:tensorflow,代码行数:12,代码来源:distribute_test.py
示例10: increment_var
def increment_var(v, amount=1):
  """`v += amount`, distributed-aware version.

  Resource variables are incremented with `assign_add(read_value=False)`;
  other variables go through `state_ops.assign_add`. The per-component
  updates are combined with `dist.group`.
  """
  def update(vu):
    if not isinstance(vu, resource_variable_ops.ResourceVariable):
      return state_ops.assign_add(vu, amount)
    return vu.assign_add(amount, read_value=False)

  def merge_fn(dist, vm):
    updates = dist.update(vm, update)
    return dist.group(updates)

  tower_context = distribution_strategy_context.get_tower_context()
  return tower_context.merge_call(merge_fn, v)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:13,代码来源:distribute.py
示例11: run_fn
def run_fn():
  """Check tower-context state and variable creation while running under `dist`.

  `self` and `dist` come from the enclosing test.
  """
  ctx = distribution_strategy_context
  tower_context = ctx.get_tower_context()
  self.assertIsNotNone(tower_context)
  self.assertIs(None, ctx.get_cross_tower_context())
  self.assertTrue(ctx.has_distribution_strategy())
  self.assertIs(dist, ctx.get_distribution_strategy())
  # The test strategy is expected to hand `test_arg` back from merge_call.
  self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
  expected = _get_test_variable(
      "bar", variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  created = variable_scope.variable(1.0, name="bar")
  self.assertDictEqual(expected, created)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:14,代码来源:distribute_test.py
示例12: testScope
def testScope(self):
  """Inside `strategy.scope()` the cross-tower context is the strategy."""
  _assert_in_default_state(self)
  strategy = _TestStrategy()
  ctx = distribution_strategy_context
  with strategy.scope():
    self.assertIs(None, ctx.get_tower_context())
    self.assertIs(strategy, ctx.get_cross_tower_context())
    self.assertTrue(ctx.has_distribution_strategy())
    self.assertIs(strategy, ctx.get_distribution_strategy())
    expected = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected, created)
  _assert_in_default_state(self)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:16,代码来源:distribute_test.py
示例13: get
def get(self, device=None):
  """Returns the value for the current device or raises a ValueError.

  If `device` is None, it is taken from the tower context when one exists,
  otherwise from the current update device. If neither is available, the
  cross-tower value is returned.
  """
  if device is None:
    tower_context = distribution_strategy_context.get_tower_context()
    if not tower_context:
      device = distribute_lib.get_update_device()
      if device is None:
        return self._get_cross_tower()
    else:
      device = tower_context.device

  canonical = device_util.canonicalize(device)
  try:
    return self._index[canonical]
  except KeyError as err:
    message = ("Device %s not found in %s (current device %s)" %
               (canonical, self._index.keys(), device_util.current()))
    six.raise_from(ValueError(message), err)
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:17,代码来源:values.py
示例14: testMergeCall
def testMergeCall(self):
  """merge_call from the default tower context runs under the default strategy."""
  _assert_in_default_state(self)
  ctx = distribution_strategy_context

  def merge_fn(dist, s):
    self.assertIs(ctx._get_default_distribution_strategy(), dist)  # pylint: disable=protected-access
    self.assertIs(None, ctx.get_tower_context())
    self.assertIs(dist, ctx.get_cross_tower_context())
    self.assertIs(dist, ctx.get_distribution_strategy())
    self.assertFalse(ctx.has_distribution_strategy())
    return "foo_" + s

  tower_ctx = ctx.get_tower_context()
  self.assertIs(ctx._get_default_tower_context(), tower_ctx)  # pylint: disable=protected-access
  self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar"))
  _assert_in_default_state(self)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:21,代码来源:distribute_test.py
示例15: decorated
def decorated(metric_obj, *args):
  """Decorated function with merge_call.

  In tower context, `result_fn(*args)` is evaluated in cross tower mode via
  `merge_call`; with no tower context it is called directly. The result is
  checked with `check_is_tensor_or_operation` before being returned.
  """
  tower_context = distribution_strategy_context.get_tower_context()
  if tower_context is not None:
    # TODO(psv): Test distribution of metrics using different distribution
    # strategies.
    # merge_call passes the distribution object as the first argument; this
    # wrapper absorbs it so `result_fn` need not accept it.
    def merge_fn_wrapper(distribution, per_device_fn, *fn_args):
      # `per_device_fn` arrives as a `PerDevice` value whose entries are
      # identical copies of `result_fn`, so the first one is used.
      return distribution.unwrap(per_device_fn)[0](*fn_args)

    result_t = tower_context.merge_call(merge_fn_wrapper, result_fn, *args)
  else:
    # Already in cross tower context.
    result_t = result_fn(*args)

  check_is_tensor_or_operation(result_t,
                               'Metric {0}\'s result'.format(metric_obj.name))
  return result_t
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:23,代码来源:metrics.py
示例16: model_fn
def model_fn():
  """Create ops and variables in a tower and assert their device placement.

  Closes over `self` (the test case), `num_gpus`, `worker_device` and `d`
  (the distribution strategy) from the enclosing test. Ordinary ops are
  expected on `worker_device`. Variables are expected on '/job:ps/...'
  devices, and the asserted ps task for each depends on which variables were
  created before it.

  NOTE(review): the '/job:ps/task:N' assertions appear to depend on the order
  in which variables are created here; do not reorder variable creation.

  Returns:
    Tuple `(y_add, z_add, f)`.
  """
  if num_gpus == 0:
    last_part_device = 'device:CPU:0'
  else:
    last_part_device = (
        'device:GPU:%d' %
        distribution_strategy_context.get_tower_context().tower_id)

  a = constant_op.constant(1.0)
  b = constant_op.constant(2.0)
  c = a + b
  self.assertEqual(a.device, worker_device + '/' + last_part_device)
  self.assertEqual(b.device, worker_device + '/' + last_part_device)
  self.assertEqual(c.device, worker_device + '/' + last_part_device)

  # The device scope is ignored for variables but not for normal ops.
  with ops.device('/job:worker/task:0'):
    x = variable_scope.get_variable(
        'x', initializer=10.0,
        aggregation=variable_scope.VariableAggregation.SUM)
    x_add = x.assign_add(c)
    e = a + c
  # The variable x is on the task 1 since the device_function has been
  # called once before the model_fn.
  self.assertEqual(x.device, '/job:ps/task:1')
  self.assertEqual(x_add.device, x.device)
  self.assertEqual(e.device,
                   '/job:worker/replica:0/task:0/%s' % last_part_device)

  # The colocate_vars_with can override the distribution's device.
  with d.colocate_vars_with(x):
    y = variable_scope.get_variable(
        'y', initializer=20.0,
        aggregation=variable_scope.VariableAggregation.SUM)
  # We add an identity here to avoid complaints about summing
  # non-distributed values.
  y_add = y.assign_add(array_ops.identity(x_add))
  self.assertEqual(y.device, '/job:ps/task:1')
  self.assertEqual(y_add.device, y.device)
  self.assertEqual(y.device, x.device)

  z = variable_scope.get_variable(
      'z', initializer=10.0,
      aggregation=variable_scope.VariableAggregation.SUM)
  self.assertEqual(z.device, '/job:ps/task:0')
  self.assertNotEqual(z.device, x.device)

  with ops.control_dependencies([y_add]):
    # We add an identity here to avoid complaints about summing
    # non-distributed values.
    z_add = z.assign_add(array_ops.identity(y))
  with ops.control_dependencies([z_add]):
    f = z + c
  self.assertEqual(f.device, worker_device + '/' + last_part_device)

  # The device scope would merge with the default worker device.
  with ops.device('/CPU:1'):
    g = e + 1.0
  self.assertEqual(g.device, worker_device + '/device:CPU:1')

  # The ops.colocate_with will be ignored when defining a variable but not
  # for a normal tensor.
  with ops.colocate_with(x):
    u = variable_scope.get_variable('u', initializer=30.0)
    v = variable_scope.get_variable('v', initializer=30.0)
    h = f + 1.0
  self.assertIn('/job:ps/', u.device)
  self.assertIn('/job:ps/', v.device)
  # u and v are on different parameter servers.
  self.assertTrue(u.device != x.device or v.device != x.device)
  self.assertTrue(u.device == x.device or v.device == x.device)
  # Here h is colocated with x, so it is on a parameter server rather than a
  # worker. Note h.device is canonical while x.device is not, hence the
  # substring check instead of equality.
  self.assertIn('/job:ps/', h.device)
  return y_add, z_add, f
开发者ID:AnishShah,项目名称:tensorflow,代码行数:75,代码来源:parameter_server_strategy_test.py
示例17: _merge_call_merge_raises_fn
def _merge_call_merge_raises_fn():
  """Invoke `_call_merge_raises_fn` through the current tower's merge_call."""
  tower_context = distribution_strategy_context.get_tower_context()
  tower_context.merge_call(_call_merge_raises_fn)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:3,代码来源:strategy_test_lib.py
示例18: apply_gradients
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This is the second part of `minimize()`. It returns an `Operation` that
  applies gradients.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      `compute_gradients()`.
    global_step: Optional `Variable` to increment by one after the
      variables have been updated.
    name: Optional name for the returned operation. Defaults to the
      name passed to the `Optimizer` constructor.

  Returns:
    An `Operation` that applies the specified gradients. If `global_step`
    was not None, that operation also increments `global_step`.
    When a distribution strategy is active, the result of
    `merge_call(self._distributed_apply, ...)` is returned instead.

  Raises:
    TypeError: If `grads_and_vars` is malformed.
    ValueError: If `grads_and_vars` is empty or none of the variables have
      gradients.
    RuntimeError: If you should use `_distributed_apply()` instead.
  """
  # This is a default implementation of apply_gradients() that can be shared
  # by most optimizers. It relies on the subclass implementing the following
  # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

  # Handle DistributionStrategy case.
  if distribution_strategy_context.get_cross_tower_context():
    raise RuntimeError("Use `_distributed_apply()` instead of "
                       "`apply_gradients()` in a cross-tower context.")
  # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
  # always calling _distributed_apply(), using the default distribution
  # as needed.
  if distribution_strategy_context.has_distribution_strategy():
    grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
    return distribution_strategy_context.get_tower_context().merge_call(
        self._distributed_apply, grads_and_vars, global_step, name)

  # No DistributionStrategy case.
  grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
  if not grads_and_vars:
    raise ValueError("No variables provided.")
  # Each entry becomes a (gradient, variable, processor) triple.
  converted_grads_and_vars = []
  for g, v in grads_and_vars:
    if g is not None:
      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError(
            "Gradient must be convertible to a Tensor"
            " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
    p = _get_processor(v)
    converted_grads_and_vars.append((g, v, p))

  converted_grads_and_vars = tuple(converted_grads_and_vars)
  var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
  if not var_list:
    # NOTE(review): the comprehension below unpacks the third element of each
    # (g, v, p) triple into `v`, so this message lists the processors rather
    # than the variables. Probably intended `for _, v, _ in ...`; confirm.
    raise ValueError("No gradients provided for any variable: %s." %
                     ([str(v) for _, _, v in converted_grads_and_vars],))
  with ops.init_scope():
    self._create_slots(var_list)
  update_ops = []
  with ops.name_scope(name, self._name) as name:
    self._prepare()
    for grad, var, processor in converted_grads_and_vars:
      if grad is None:
        continue
      # We colocate all ops created in _apply_dense or _apply_sparse
      # on the same device as the variable.
      # TODO(apassos): figure out how to get the variable name here.
      # Precedence: `A or (isinstance(...) and not var._in_graph_mode)`.
      if context.executing_eagerly() or isinstance(
          var,
          resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
        scope_name = ""
      else:
        scope_name = var.op.name
      with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
        update_ops.append(processor.update_op(self, grad))
    if global_step is None:
      apply_updates = self._finish(update_ops, name)
    else:
      # Increment global_step only after all variable updates have finished.
      with ops.control_dependencies([self._finish(update_ops, "update")]):
        with ops.colocate_with(global_step):
          if isinstance(global_step, resource_variable_ops.ResourceVariable):
            # TODO(apassos): the implicit read in assign_add is slow; consider
            # making it less so.
            apply_updates = resource_variable_ops.assign_add_variable_op(
                global_step.handle,
                ops.convert_to_tensor(1, dtype=global_step.dtype),
                name=name)
          else:
            apply_updates = state_ops.assign_add(global_step, 1, name=name)

    if not context.executing_eagerly():
      if isinstance(apply_updates, ops.Tensor):
        # ......... the rest of this function is omitted in this excerpt .........
开发者ID:HughKu,项目名称:tensorflow,代码行数:101,代码来源:optimizer.py
示例19: mark_devices_fn
def mark_devices_fn():
  """Mark the current tower id as seen, failing if it was already marked.

  `self`, `d` and `expected_devices` come from the enclosing test.
  """
  current_id = distribution_strategy_context.get_tower_context().tower_id
  self.assertLess(current_id, len(d.worker_devices))
  already_marked = expected_devices[current_id]
  self.assertFalse(already_marked)
  expected_devices[current_id] = True
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:5,代码来源:strategy_test_lib.py
注：本文中的tensorflow.python.training.distribution_strategy_context.get_tower_context函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论