本文整理汇总了Python中tensorflow.python.framework.ops.add_to_collections函数的典型用法代码示例。如果您正苦于以下问题:Python add_to_collections函数的具体用法?Python add_to_collections怎么用?Python add_to_collections使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了add_to_collections函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: var_creator
def var_creator(*args, **kwargs):
  """Create an AggregatingVariable and fix up collections.

  Pops the requested "collections" so the underlying variable is created in
  no collections at all, then registers the AggregatingVariable wrapper in
  the requested collections instead (so consumers of the collections see the
  wrapper, not the contained variable).
  """
  # Record what collections this variable should be added to.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  # Create the contained variable outside of any collection.
  kwargs["collections"] = []
  # Create and wrap the variable.
  v = next_creator(*args, **kwargs)
  wrapped = values.AggregatingVariable(v, aggregation)
  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the contained
    # variable to the TRAINABLE_VARIABLES collection, so we manually
    # remove it and replace with the wrapper. We can't set "trainable"
    # to False for next_creator() since that causes functions like
    # implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      l.remove(v)
    g.add_to_collections(collections, wrapped)
  elif ops.GraphKeys.GLOBAL_STEP in collections:
    # Eager mode: only the global step is still tracked via a collection.
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)
  return wrapped
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:31,代码来源:parameter_server_strategy.py
示例2: variable_creator_scope
def variable_creator_scope(self, next_creator, **kwargs):
  """Creates variables & adds them to collections to match legacy code.

  If variable sharing is enabled, an existing variable with the same
  fully-qualified name is reused instead of creating a new one. Any newly
  created variable is recorded on this object and added to the requested
  graph collections (plus TRAINABLE_VARIABLES when trainable).
  """
  collections = kwargs.pop("collections", None)
  v = None
  # Get expected variable name: resolve the uniquified name the variable
  # would receive inside the current name scope.
  name = kwargs.get("name", None)
  with ops.name_scope(name, "Variable") as name_scope:
    name = name_scope
  if self._share_variables:
    # Reuse a previously created variable with this name, if any.
    v = self._variables_by_name.get(name, None)
  if v is None:
    v = next_creator(**kwargs)
    self._variables.append(v)
    if self._share_variables:
      self._variables_by_name[name] = v
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
    # Copy before appending so the caller's list is not mutated.
    collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
  ops.add_to_collections(collections, v)
  return v
开发者ID:perfmjs,项目名称:tensorflow,代码行数:27,代码来源:wrap_function.py
示例3: __init__
def __init__(self, initial_value, trainable=True, collections=None,
             validate_shape=True, name=None):
  """Creates a new variable with value `initial_value`.

  The new variable is added to the graph collections listed in `collections`,
  which defaults to `[GraphKeys.VARIABLES]`.

  If `trainable` is `True` the variable is also added to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`.

  This constructor creates both a `variable` Op and an `assign` Op to set the
  variable to its initial value.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
      The initial value for the Variable. Must have a shape specified unless
      `validate_shape` is set to False.
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.VARIABLES]`.
    validate_shape: If `False`, allows the variable to be initialized with a
      value of unknown shape. If `True`, the default, the shape of
      `initial_value` must be known.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.

  Returns:
    A Variable.

  Raises:
    ValueError: If the initial value does not have a shape and
      `validate_shape` is `True`.
  """
  if collections is None:
    collections = [ops.GraphKeys.VARIABLES]
  if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
    # Copy before appending so the caller's list is not mutated.
    collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
  with ops.control_dependencies(None):
    with ops.op_scope([initial_value], name, "Variable") as name:
      self._initial_value = ops.convert_to_tensor(initial_value,
                                                  name="initial_value")
      initial_value_shape = self._initial_value.get_shape()
      if validate_shape and not initial_value_shape.is_fully_defined():
        raise ValueError("initial_value must have a shape specified: %s"
                         % self._initial_value)
      # With validate_shape=False, the scalar shape [] stands in for an
      # unknown shape in the variable op.
      shape_to_set = initial_value_shape if validate_shape else []
      self._variable = state_ops.variable_op(
          shape_to_set, self._initial_value.dtype.base_dtype,
          set_shape=validate_shape, name=name)
      # Keep the initializer and the read snapshot on the variable's device.
      with ops.device(self._variable.device):
        self._initializer_op = state_ops.assign(
            self._variable, self._initial_value,
            validate_shape=validate_shape).op
        self._snapshot = array_ops.identity(self._variable, name="read")
  ops.add_to_collections(collections, self)
  # Set when this variable is a slice of a larger partitioned variable.
  self._save_slice_info = None
开发者ID:Mandar-Shinde,项目名称:tensorflow,代码行数:59,代码来源:variables.py
示例4: _init_from_args
def _init_from_args(self, initial_value=None, trainable=True,
                    collections=None, validate_shape=True,
                    caching_device=None, name=None):
  """Creates a new variable from arguments.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
      The initial value for the Variable. Must have a shape specified unless
      `validate_shape` is set to False.
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.VARIABLES]`.
    validate_shape: If `False`, allows the variable to be initialized with a
      value of unknown shape. If `True`, the default, the shape of
      `initial_value` must be known.
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device. If not `None`, caches on another device. Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.

  Raises:
    ValueError: If the initial value is not specified, or does not have a
      shape and `validate_shape` is `True`.
  """
  if initial_value is None:
    raise ValueError("initial_value must be specified.")
  if collections is None:
    collections = [ops.GraphKeys.VARIABLES]
  if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
    # Copy before appending so the caller's list is not mutated.
    collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
  with ops.control_dependencies(None):
    with ops.op_scope([initial_value], name, "Variable") as name:
      self._initial_value = ops.convert_to_tensor(initial_value,
                                                  name="initial_value")
      initial_value_shape = self._initial_value.get_shape()
      if validate_shape and not initial_value_shape.is_fully_defined():
        raise ValueError("initial_value must have a shape specified: %s"
                         % self._initial_value)
      # With validate_shape=False, the scalar shape [] stands in for an
      # unknown shape in the variable op.
      shape_to_set = initial_value_shape if validate_shape else []
      self._variable = state_ops.variable_op(
          shape_to_set, self._initial_value.dtype.base_dtype,
          set_shape=validate_shape, name=name)
      # The initializer always runs on the variable's own device.
      with ops.device(self._variable.device):
        self._initializer_op = state_ops.assign(
            self._variable, self._initial_value,
            validate_shape=validate_shape).op
      # The read snapshot may live on a separate caching device.
      with ops.device(caching_device if caching_device is not None
                      else self._variable.device):
        self._snapshot = array_ops.identity(self._variable, name="read")
  ops.add_to_collections(collections, self)
  self._caching_device = caching_device
  # Set when this variable is a slice of a larger partitioned variable.
  self._save_slice_info = None
开发者ID:chintanpanchamia,项目名称:tensorflow,代码行数:58,代码来源:variables.py
示例5: _register_dense_variable_read
def _register_dense_variable_read(read, collections, trainable):
  """Helper function to put a read from a dense variable in the collections.

  When `trainable` is set, the TRAINABLE_RESOURCE_VARIABLES key is included
  as well (without duplicating it if the caller already passed it).
  """
  # Normalize to a fresh list so the caller's argument is never mutated.
  target_collections = [] if collections is None else list(collections)
  trainable_key = ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES
  if trainable and trainable_key not in target_collections:
    target_collections.append(trainable_key)
  ops.add_to_collections(target_collections, read)
开发者ID:brchiu,项目名称:tensorflow,代码行数:9,代码来源:resource_variable_ops.py
示例6: _init_from_args
def _init_from_args(self, name):
  """Initialize the CriticalSection from constructor arguments.

  Builds the critical-section resource op outside any control dependencies
  and, in graph mode, registers this object in the CRITICAL_SECTIONS
  collection so executions can be introspected later.
  """
  with ops.name_scope(name, "CriticalSection", []) as name:
    with ops.control_dependencies(None):
      # pylint: disable=protected-access
      handle_name = ops._name_from_scope_name(name)
      # NOTE(review): container is read here but not passed to
      # critical_section_op below — presumably intentional in this
      # version; confirm against the op's signature.
      container = ops.get_default_graph()._container
      # pylint: enable=protected-access
      if container is None:
        container = ""
      self._handle = gen_resource_variable_ops.critical_section_op(
          shared_name=handle_name, name=name)
  if context.in_graph_mode():
    ops.add_to_collections(CRITICAL_SECTIONS, self)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:14,代码来源:critical_section_ops.py
示例7: variable_creator_scope
def variable_creator_scope(self, next_creator, **kwargs):
  """Creates variables & adds them to collections to match legacy code.

  Delegates creation to `next_creator`, records the new variable on this
  object, and adds it to the requested graph collections (defaulting to
  GLOBAL_VARIABLES, plus TRAINABLE_VARIABLES when trainable).
  """
  variable = next_creator(**kwargs)
  self._variables.append(variable)
  requested = kwargs.get("collections")
  if requested is None:
    requested = [ops.GraphKeys.GLOBAL_VARIABLES]
  trainable_key = ops.GraphKeys.TRAINABLE_VARIABLES
  if variable.trainable and trainable_key not in requested:
    # Copy before extending so the caller's list is not mutated.
    requested = list(requested) + [trainable_key]
  ops.add_to_collections(requested, variable)
  return variable
开发者ID:kylin9872,项目名称:tensorflow,代码行数:16,代码来源:wrap_function.py
示例8: _init_from_args
def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
  """Initialize the CriticalSection from constructor arguments.

  Creates the underlying mutex resource inside an init scope (so it is
  lifted out of functions/control flow) and, in graph mode, registers this
  object in the CRITICAL_SECTIONS collection.
  """
  with ops.name_scope(name, "CriticalSection", []) as name:
    with ops.init_scope():
      # pylint: disable=protected-access
      container = ops.get_default_graph()._container
      # pylint: enable=protected-access
      if shared_name is None:
        # Default the shared resource name to the (uniquified) scope name.
        shared_name = name
      if container is None:
        container = ""
      self._handle = gen_resource_variable_ops.mutex_v2(
          shared_name=shared_name, container=container, name=name)
  if not context.executing_eagerly():
    ops.add_to_collections(CRITICAL_SECTIONS, self)
开发者ID:DILASSS,项目名称:tensorflow,代码行数:16,代码来源:critical_section_ops.py
示例9: _init_from_args
def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
  """Initialize the Notification from constructor arguments.

  Creates the underlying notification resource inside an init scope and,
  in graph mode, registers this object in the NOTIFICATIONS collection.
  """
  with ops.name_scope(name, "Notification", []) as name:
    with ops.init_scope():
      # pylint: disable=protected-access
      container = ops.get_default_graph()._container
      # pylint: enable=protected-access
      if shared_name is None:
        # Default the shared resource name to the (uniquified) scope name.
        shared_name = name
      if container is None:
        container = ""
      # Build the notification resource outside of any control dependencies.
      with ops.control_dependencies(None):
        self._handle = gen_resource_variable_ops.notification(
            shared_name=shared_name, container=container, name=name)
  if not context.executing_eagerly():
    ops.add_to_collections(NOTIFICATIONS, self)
开发者ID:ebrevdo,项目名称:tensorflow,代码行数:18,代码来源:notification_ops.py
示例10: collect_named_outputs
def collect_named_outputs(collections, alias, outputs):
  """Add `Tensor` outputs tagged with alias to collections.

  It is useful to collect end-points or tags for summaries. Example of usage:

    logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
    assert 'inception_v3/logits' in logits.aliases

  Args:
    collections: A collection or list of collections. If None skip collection.
    alias: String to append to the list of aliases of outputs, for example,
      'inception_v3/conv1'.
    outputs: Tensor, an output tensor to collect

  Returns:
    The outputs Tensor to allow inline call.
  """
  # Always record the alias on the tensor, even when not collecting.
  append_tensor_alias(outputs, alias)
  if collections:
    ops.add_to_collections(collections, outputs)
  return outputs
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:21,代码来源:utils.py
示例11: collect_named_outputs
def collect_named_outputs(collections, name, outputs):
  """Add tuple (name, outputs) to collections.

  It is useful to collect end-points or tags for summaries. Example of usage:

    logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)

  Args:
    collections: A collection or list of collections. If None skip collection.
    name: String, name to represent the outputs, ex. 'inception_v3/conv1'
    outputs: Tensor, an output tensor to collect

  Returns:
    The outputs Tensor to allow inline call.
  """
  if collections:
    # Remove ending '/' if present. endswith() is safe for an empty name,
    # whereas indexing name[-1] would raise IndexError on ''.
    if name.endswith('/'):
      name = name[:-1]
    ops.add_to_collections(collections, (name, outputs))
  return outputs
开发者ID:AI-MR-Related,项目名称:tensorflow,代码行数:21,代码来源:utils.py
示例12: collect_named_outputs
def collect_named_outputs(collections, alias, outputs):
  """Add `Tensor` outputs tagged with alias to collections.

  It is useful to collect end-points or tags for summaries. Example of usage:

    logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
    assert logits.alias == 'inception_v3/logits'

  Args:
    collections: A collection or list of collections. If None skip collection.
    alias: String, alias to name the outputs, ex. 'inception_v3/conv1'
    outputs: Tensor, an output tensor to collect

  Returns:
    The outputs Tensor to allow inline call.
  """
  # Remove ending '/' if present. endswith() is safe for an empty alias,
  # whereas indexing alias[-1] would raise IndexError on ''.
  if alias.endswith('/'):
    alias = alias[:-1]
  outputs.alias = alias
  if collections:
    ops.add_to_collections(collections, outputs)
  return outputs
开发者ID:DavidNemeskey,项目名称:tensorflow,代码行数:23,代码来源:utils.py
示例13: _init_from_args
#.........这里部分代码省略.........
shape and `validate_shape` is `True`.
"""
if initial_value is None:
raise ValueError("initial_value must be specified.")
init_from_fn = callable(initial_value)
if init_from_fn and dtype is None:
raise ValueError("dtype must also be specified when initial_value is callable.")
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
if not isinstance(collections, (list, tuple, set)):
raise ValueError(
"collections argument to Variable constructor must be a list, tuple, "
"or set. Got %s of type %s" % (collections, type(collections))
)
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
expected_shape = tensor_shape.as_shape(expected_shape)
with ops.control_dependencies(None):
with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name:
# Get the initial value from a callable function. The real shape of the
# variable will be set later, since under the init_from_fn case, the
# shape won't be known until after the function is invoked.
#
# NOTE: The current Variable OpKernel does not support
# partially defined shapes, so we only set the shape if it is
# fully defined. For historical reasons, we use the scalar
# shape (`[]`) to represent an unknown or partially known
# shape. A future version of the Variable ops will remove this
# limitation.
def full_shape_to_list(shape):
"""Returns shape as a list if shape is fully defined."""
if shape and shape.is_fully_defined():
return shape.as_list()
else:
return []
def assert_expected_shape():
"""Asserts that the initial value has the expected shape."""
if expected_shape:
expected_shape.assert_is_compatible_with(self._initial_value.get_shape())
if init_from_fn:
expected_shape_list = full_shape_to_list(expected_shape)
set_shape = validate_shape and expected_shape.is_fully_defined()
self._variable = state_ops.variable_op(
expected_shape_list, dtype.base_dtype, set_shape=set_shape, name=name
)
with ops.colocate_with(self._variable.op):
with ops.name_scope("Initializer"):
# Colocate the tensors created by the initial_value() function
# with the variable itself.
self._initial_value = ops.convert_to_tensor(
initial_value(), name="initial_value", dtype=dtype
)
assert_expected_shape()
# Or get the initial value from a Tensor or Python object.
else:
self._initial_value = ops.convert_to_tensor(initial_value, name="initial_value", dtype=dtype)
assert_expected_shape()
set_shape = validate_shape and self._initial_value.get_shape().is_fully_defined()
# In this case, the variable op can't be created until after the
# initial_value has been converted to a Tensor with a known type.
self._variable = state_ops.variable_op(
full_shape_to_list(self._initial_value.get_shape()),
self._initial_value.dtype.base_dtype,
set_shape=set_shape,
name=name,
)
# Manually overrides the variable's shape with the initial value's.
if validate_shape:
initial_value_shape = self._initial_value.get_shape()
if not initial_value_shape.is_fully_defined():
raise ValueError("initial_value must have a shape specified: %s" % self._initial_value)
self._variable.set_shape(initial_value_shape)
# TODO(b/28152992): Remove the below hack modifying the node_def shape
# directly once set_shape() handles it.
self._variable.op.node_def.attr["shape"].shape.CopyFrom(initial_value_shape.as_proto())
# Assigns initial value.
self._initializer_op = state_ops.assign(
self._variable, self._initial_value, validate_shape=validate_shape
).op
# TODO(vrv): Change this class to not take caching_device, but
# to take the op to colocate the snapshot with, so we can use
# colocation rather than devices.
if caching_device is not None:
with ops.device(caching_device):
self._snapshot = array_ops.identity(self._variable, name="read")
else:
with ops.colocate_with(self._variable.op):
self._snapshot = array_ops.identity(self._variable, name="read")
ops.add_to_collections(collections, self)
self._caching_device = caching_device
self._save_slice_info = None
开发者ID:shakamunyi,项目名称:tensorflow,代码行数:101,代码来源:variables.py
示例14: _init_from_args
#.........这里部分代码省略.........
with ops.name_scope("Initializer"):
initial_value = ops.convert_to_tensor(
initial_value, name="initial_value", dtype=dtype)
self._handle = _eager_safe_variable_handle(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=handle_name,
name=name,
graph_mode=False)
self._handle_device = (
self._handle.device if self._in_graph_mode else
context.get_default_context().device_name)
self._shape = initial_value.get_shape()
# pylint: enable=protected-access
# Or get the initial value from a Tensor or Python object.
else:
with ops.name_scope("Initializer"):
initial_value = ops.convert_to_tensor(
initial_value, name="initial_value", dtype=dtype)
# pylint: disable=protected-access
if (self._in_graph_mode and initial_value is not None and
initial_value.op._get_control_flow_context() is not None):
raise ValueError(
"Initializer for variable %s is from inside a control-flow "
"construct, such as a loop or conditional. When creating a "
"variable inside a loop or conditional, use a lambda as the "
"initializer." % name)
# pylint: enable=protected-access
self._handle = _eager_safe_variable_handle(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=handle_name,
name=name,
graph_mode=self._in_graph_mode)
self._handle_device = (self._handle.device if self._in_graph_mode else
context.get_default_context().device_name)
self._shape = initial_value.get_shape()
self._initial_value = initial_value if self._in_graph_mode else None
self._handle_name = handle_name + ":0"
self._dtype = initial_value.dtype.base_dtype
self._constraint = constraint
if self._in_graph_mode:
with ops.name_scope("IsInitialized"):
self._is_initialized_op = (
gen_resource_variable_ops.var_is_initialized_op(self._handle))
if initial_value is not None:
with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
self._initializer_op = (
gen_resource_variable_ops.assign_variable_op(
self._handle,
self._try_guard_against_uninitialized_dependencies(
initial_value),
name=n))
with ops.name_scope("Read"), ops.colocate_with(self._handle):
# Manually assign reads to the handle's device to avoid log
# messages.
with ops.device(self._handle_device):
value = self._read_variable_op()
self._graph_element = value
if caching_device is not None:
# Variables may be created in a tf.device() or ops.colocate_with()
# context. At the same time, users would expect caching device to
# be independent of this context, and/or would not expect the
# current device context to be merged with the caching device
# spec. Therefore we reset the colocation stack before creating
# the cached value. Note that resetting the colocation stack will
# also reset the device stack.
with ops.colocate_with(None, ignore_existing=True):
with ops.device(caching_device):
self._cached_value = array_ops.identity(value)
else:
self._cached_value = None
else:
gen_resource_variable_ops.assign_variable_op(self._handle,
initial_value)
self._is_initialized_op = None
self._initializer_op = None
self._graph_element = None
if caching_device:
with ops.device(caching_device):
self._cached_value = self._read_variable_op()
else:
self._cached_value = None
if context.in_graph_mode():
ops.add_to_collections(collections, self)
elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
if not self._in_graph_mode:
# After the handle has been created, set up a way to clean it up when
# executing eagerly. We'll hold the only reference to the deleter, so that
# when this object is garbage collected the deleter will be too. This
# means ResourceVariables can be part of reference cycles without those
# cycles being uncollectable, and means that no __del__ will be defined at
# all in graph mode.
self._handle_deleter = EagerResourceDeleter(
handle=self._handle, handle_device=self._handle_device)
开发者ID:keithc61,项目名称:tensorflow,代码行数:101,代码来源:resource_variable_ops.py
示例15: _init_from_args
#.........这里部分代码省略.........
which is the initial value for the Variable. The initial value must have
a shape specified unless `validate_shape` is set to False. Can also be a
callable with no argument that returns the initial value when called.
(Note that initializer functions from init_ops.py must first be bound
to a shape before being used here.)
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
collections: List of graph collections keys. The new variable is added to
these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
validate_shape: Ignored. Provided for compatibility with tf.Variable.
caching_device: Optional device string or function describing where the
Variable should be cached for reading. Defaults to the Variable's
device. If not `None`, caches on another device. Typical use is to
cache on the device where the Ops using the Variable reside, to
deduplicate copying through `Switch` and other conditional statements.
name: Optional name for the variable. Defaults to `'Variable'` and gets
uniquified automatically.
dtype: If set, initial_value will be converted to the given type.
If None, either the datatype will be kept (if initial_value is
a Tensor) or float32 will be used (if it is a Python object convertible
to a Tensor).
Raises:
ValueError: If the initial value is not specified, or does not have a
shape and `validate_shape` is `True`.
"""
if initial_value is None:
raise ValueError("initial_value must be specified.")
init_from_fn = callable(initial_value)
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
if not isinstance(collections, (list, tuple, set)):
raise ValueError(
"collections argument to Variable constructor must be a list, tuple, "
"or set. Got %s of type %s" % (collections, type(collections)))
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
self._save_slice_info = None
with ops.control_dependencies(None):
with ops.name_scope(name, "Variable", [] if init_from_fn else
[initial_value]) as name:
# pylint: disable=protected-access
true_name = ops._name_from_scope_name(name)
if init_from_fn:
# Use attr_scope and device(None) to simulate the behavior of
# colocate_with when the variable we want to colocate with doesn't
# yet exist.
attr = attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(
s=[compat.as_bytes("loc:@%s" % true_name)]))
with ops.get_default_graph()._attr_scope({"_class": attr}):
with ops.name_scope("Initializer"), ops.device(None):
self._initial_value = ops.convert_to_tensor(
initial_value(), name="initial_value", dtype=dtype)
self._handle = gen_resource_variable_ops.var_handle_op(
shape=self._initial_value.get_shape(),
dtype=self._initial_value.dtype.base_dtype,
shared_name=true_name, name=name)
# pylint: enable=protected-access
# Or get the initial value from a Tensor or Python object.
else:
self._initial_value = ops.convert_to_tensor(
initial_value, name="initial_value", dtype=dtype)
self._handle = gen_resource_variable_ops.var_handle_op(
shape=self._initial_value.get_shape(),
dtype=self._initial_value.dtype.base_dtype,
shared_name=true_name, name=name)
self._dtype = self._initial_value.dtype.base_dtype
with ops.name_scope("IsInitialized"):
self._is_initialized_op = (
gen_resource_variable_ops.var_is_initialized_op(self._handle))
if initial_value is not None:
with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
self._initialize_op = gen_resource_variable_ops.assign_variable_op(
self._handle, self._initial_value, name=n)
with ops.name_scope("Read"), ops.colocate_with(self._handle):
# Manually assign reads to the handle's device to avoid log messages.
with ops.device(self._handle.device):
value = gen_resource_variable_ops.read_variable_op(
self._handle, dtype=self._dtype)
self._graph_element = value
if caching_device is not None:
# Variables may be created in a tf.device() or ops.colocate_with()
# context. At the same time, users would expect caching device to be
# independent of this context, and/or would not expect the current
# device context to be merged with the caching device spec.
# Therefore we reset the colocation stack before creating the cached
# value. Note that resetting the colocation stack will also reset
# the device stack.
with ops.colocate_with(None, ignore_existing=True):
with ops.device(caching_device):
self._cached_value = array_ops.identity(value)
else:
self._cached_value = None
ops.add_to_collections(collections, self)
开发者ID:chenjun0210,项目名称:tensorflow,代码行数:101,代码来源:resource_variable_ops.py
示例16: execute
def execute(self, fn, *args, **kwargs):
  """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

  Args:
    fn: The function to execute. Must return at least one tensor.
    *args: Additional positional arguments to `fn`.
    **kwargs: Additional keyword arguments to `fn`.
      Several keywords are reserved for `execute`. These are:

      - name; The name to use when creating the execute operation.
      - exclusive_resource_access; Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`. Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.

  Returns:
    The tensors returned from `fn(*args, **kwargs)`.

  Raises:
    ValueError: If `fn` attempts to use this `CriticalSection` in any nested
      way.
    ValueError: If `exclusive_resource_access` is not provided (is `True`) and
      another `CriticalSection` has an execution requesting the same
      resources as in `*args`, `**kwargs`, and any additionally captured
      inputs in `fn`. Note, even if `exclusive_resource_access` is `True`,
      if another execution in another `CriticalSection` was created without
      `exclusive_resource_access=True`, a `ValueError` will be raised.
  """
  name = kwargs.pop("name", None)
  exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)
  args = nest.map_structure(ops.convert_to_tensor, args)
  with ops.name_scope(name, "critical_section_execute", []):
    # Compile fn into a single op so its body runs atomically inside the
    # critical section.
    fn_op = function.make_defun_op(fn, *args, **kwargs)
    flat_dtypes = nest.flatten(fn_op.output_dtypes)
    flat_shapes = nest.flatten(fn_op.output_shapes)
    all_inputs = nest.flatten(args) + fn_op.captured_inputs
    if self._handle in all_inputs:
      raise ValueError("The function fn attempts to access the "
                       "CriticalSection in which it would be running. This "
                       "is illegal and would cause deadlocks. "
                       "CriticalSection: %s." % self._handle)
    if context.in_graph_mode():
      # Collections and op introspection does not work in eager
      # mode. This is generally ok; since eager mode (as of
      # writing) executes sequentially anyway.
      all_input_resources = [
          x for x in all_inputs if x.dtype == dtypes.resource]
      # Check every previously registered execution for a resource conflict.
      for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
        if sg.op.inputs[0].name == self._handle.name:
          # Other executions in the same critical section are allowed.
          continue
        if not (exclusive_resource_access or sg.exclusive_resource_access):
          # Neither execution requested exclusive access.
          continue
        sg_input_names = [y.name for y in sg.op.inputs[1:]]
        for res in all_input_resources:
          if res.name in sg_input_names:
            raise ValueError(
                "This execution would access resource %s; but either this "
                "execution (CriticalSection: %s) or Execution '%s' "
                "(CriticalSection: %s) requested exclusive resource access "
                "of this resource for their critical section. Did you mean "
                "to call execute with keyword argument "
                "exclusive_resource_access=False?"
                % (res.name,
                   self.name,
                   sg.op.name,
                   sg.op.inputs[0].op.name))
    flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
        critical_section=self._handle,
        arguments=all_inputs,
        f=fn_op,
        output_types=flat_dtypes,
        output_shapes=flat_shapes)
    if context.in_graph_mode():
      # Normalize: a single Operation result becomes a one-element list.
      if isinstance(flat_outputs, ops.Operation):
        flat_outputs = [flat_outputs]
      op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
            else flat_outputs[0])
      # Register this execution so later executions can detect conflicts.
      signature = _ExecutionSignature(
          op=op,
          exclusive_resource_access=exclusive_resource_access)
      ops.add_to_collections(
          CRITICAL_SECTION_EXECUTIONS, signature)
    # Return a bare Operation as-is; otherwise repack to fn's structure.
    return (flat_outputs[0]
            if (len(flat_outputs) == 1
                and isinstance(flat_outputs[0], ops.Operation))
            else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))
示例17: execute
#.........这里部分代码省略.........
calling `fn` in the critical section, create a lambda:
```python
critical_section.execute(lambda: fn(*my_args, **my_kwargs))
```
Args:
fn: The function to execute. Must return at least one tensor.
exclusive_resource_access: Whether the resources required by
`fn` should be exclusive to this `CriticalSection`. Default: `True`.
You may want to set this to `False` if you will be accessing a
resource in read-only mode in two different CriticalSections.
name: The name to use when creating the execute operation.
Returns:
The tensors returned from `fn()`.
Raises:
ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
or lazy way that may cause a deadlock.
ValueError: If `exclusive_resource_access == True` and
another `CriticalSection` has an execution requesting the same
resources as `fn``. Note, even if `exclusive_resource_access` is
`True`, if another execution in another `CriticalSection` was created
without `exclusive_resource_access=True`, a `ValueError` will be raised.
"""
with ops.name_scope(name, "critical_section_execute", []):
# Ensure that mutex locking only happens *after* all args and
# kwargs have been executed. This avoids certain types of deadlocks.
lock = gen_resource_variable_ops.mutex_lock(self._handle)
if not context.executing_eagerly():
# NOTE(ebrevdo): This is to ensure we don't pick up spurious
# Operations created by other threads.
with ops.get_default_graph()._lock: # pylint: disable=protected-access
existing_ops = ops.get_default_graph().get_operations()
with ops.control_dependencies([lock]):
r = fn()
# TODO(ebrevdo): If creating critical sections in a python loop, this
# makes graph creation time quadratic. Revisit if this
# becomes a problem.
created_ops = (set(ops.get_default_graph().get_operations())
.difference(existing_ops))
else:
with ops.control_dependencies([lock]):
r = fn()
if not context.executing_eagerly():
self._add_control_dependencies_to_lock(created_ops, lock.op)
# captured_resources is a list of resources that are directly
# accessed only by ops created during fn(), not by any
# ancestors of those ops in the graph.
captured_resources = set([
input_ for op in created_ops
for input_ in op.inputs
if input_.dtype == dtypes.resource
])
# NOTE(ebrevdo): The only time self._is_self_handle() is True
# in this call is if one of the recently created ops, within
# the execute(), themselves attempt to access the
# CriticalSection. This will cause a deadlock.
if any(self._is_self_handle(x) for x in captured_resources):
raise ValueError("The function fn attempts to directly access the "
"CriticalSection in which it would be running. "
"This is illegal and would cause deadlocks.")
self._check_multiple_access_to_resources(
captured_resources, exclusive_resource_access)
r_flat = [_identity(x) for x in nest.flatten(r)]
with ops.control_dependencies(r_flat):
# The identity must run on the same machine as self._handle
with ops.colocate_with(self._handle):
# Do not use array_ops.identity as there are special
# optimizations within TensorFlow which seem to elide it
# even when optimizations are disabled(!).
ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
lock)
# Make sure that if any element of r is accessed, all of
# them are executed together.
r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))
with ops.control_dependencies([ensure_lock_exists]):
outputs = nest.map_structure(_identity, r)
if not context.executing_eagerly():
signature = _ExecutionSignature(
op=lock.op,
handle=self._handle,
resources=list(captured_resources),
exclusive_resource_access=exclusive_resource_access)
ops.add_to_collections(
CRITICAL_SECTION_EXEC
|
请发表评论