本文整理汇总了Python中tensorflow.python.framework.ops.get_gradient_function函数的典型用法代码示例。如果您正苦于以下问题:Python get_gradient_function函数的具体用法?Python get_gradient_function怎么用?Python get_gradient_function使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_gradient_function函数的13个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testGradientFunction
def testGradientFunction(self):
    """Checks that get_gradient_function() returns None for py_func ops."""
    # py_func is given an input on purpose: without one,
    # get_gradient_function() returns None by default and the assertions
    # below would not be exercising anything.
    zero = constant_op.constant(0)
    default_out, = script_ops.py_func(lambda a: 0, [zero], [dtypes.int64])
    stateless_out, = script_ops.py_func(
        lambda a: 0, [zero], [dtypes.int64], stateful=False)
    # Same check for both variants, in the original order.
    for out in (default_out, stateless_out):
        self.assertEqual(None, ops.get_gradient_function(out.op))
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:8,代码来源:py_func_test.py
示例2: testOverrideGradients
def testOverrideGradients(self):
    """Checks that an op built under a gradient override map resolves to
    _CopyOverrideGrad.

    The op is created by copy_op() inside
    gradient_override_map({"copy": "copy_override"}).
    """
    graph = ops.Graph()
    source = an_op(graph)
    with graph.gradient_override_map({"copy": "copy_override"}):
        copied = copy_op(source)
    # NOTE(review): indentation was lost in this snippet; the lookup is
    # assumed to sit after the `with` block -- confirm against upstream.
    self.assertEqual(_CopyOverrideGrad, ops.get_gradient_function(copied.op))
开发者ID:4chin,项目名称:tensorflow,代码行数:7,代码来源:ops_test.py
示例3: testNonExistentOverride
def testNonExistentOverride(self):
    """Checks that get_gradient_function() raises LookupError for an
    unknown override name.

    The op is created by copy_op() inside
    gradient_override_map({"copy": "unknown_override"}); the error message
    must match "unknown_override".
    """
    g = ops.Graph()
    x = an_op(g)
    with g.gradient_override_map({"copy": "unknown_override"}):
        y = copy_op(x)
    # NOTE(review): assertRaisesRegexp is a deprecated alias of
    # assertRaisesRegex on Python 3 (removed in 3.12). It is kept because the
    # base TestCase / Python version this file targets is not visible here.
    with self.assertRaisesRegexp(LookupError, "unknown_override"):
        # Only the raised exception matters, so the result is not bound to
        # a (previously unused) local.
        ops.get_gradient_function(y.op)
开发者ID:4chin,项目名称:tensorflow,代码行数:7,代码来源:ops_test.py
示例4: find_non_differentiable
def find_non_differentiable(inputs, outputs):
    """
    Searches through a TensorFlow graph to find non-differentiable elements
    between ``inputs`` and ``outputs`` (elements that would prevent us from
    computing ``d_outputs / d_inputs``).

    Parameters
    ----------
    inputs : list of ``tf.Tensor``
        Input tensors
    outputs : list of ``tf.Tensor``
        Output tensors

    Raises
    ------
    SimulationError
        If an op reached while walking backwards from ``outputs`` has no
        gradient function (``get_gradient_function`` raises ``LookupError``),
        or its gradient function is ``None`` while the op has inputs.
    """
    # Depth-first walk with an explicit stack and a visited set. Unlike plain
    # recursion, this checks each op only once (shared sub-graphs are not
    # re-traversed) and cannot exhaust the recursion limit on deep graphs.
    # Tensors are pushed in reverse so they are popped in the same
    # left-to-right, depth-first order the recursive form would visit them.
    pending = list(outputs)
    pending.reverse()
    # NOTE(review): assumes ops are hashable -- confirm for the op type used.
    visited = set()

    while pending:
        tensor = pending.pop()
        if tensor in inputs:
            # Reached one of the tensors we differentiate with respect to;
            # nothing upstream of it needs to be checked.
            continue

        op = tensor.op
        if op in visited:
            continue
        visited.add(op)

        try:
            grad = get_gradient_function(op)
        except LookupError:
            differentiable = False
        else:
            # note: technically we're not sure that this op is
            # on the path to inputs. we could wait and propagate this
            # until we find inputs, but that can take a long time for
            # large graphs. it seems more useful to fail quickly, and
            # risk some false positives
            differentiable = grad is not None or len(op.inputs) == 0

        if not differentiable:
            raise SimulationError(
                "Graph contains non-differentiable "
                "elements: %s" % op)

        pending.extend(reversed(list(op.inputs)))
开发者ID:nengo,项目名称:nengo_deeplearning,代码行数:33,代码来源:utils.py
示例5: _Gradient
def _Gradient(tensors, devices):
    """Applies each reduced tensor's gradient function to its device tensor.

    Returns a 2-tuple: the list of gradient function results (one per
    reduced tensor) and an empty list.
    """
    reduced, _ = nccl_reduce(tensors, devices)
    reduce_ops = [tensor.op for tensor in reduced]
    per_device = _DeviceTensors(tensors, devices)

    grad_tensors = []
    for reduce_op, loss in zip(reduce_ops, per_device):
        grad_fn = ops.get_gradient_function(reduce_op)
        grad_tensors.append(grad_fn(reduce_op, loss))
    return grad_tensors, []
开发者ID:1000sprites,项目名称:tensorflow,代码行数:9,代码来源:nccl_ops_test.py
示例6: gradients
#.........这里部分代码省略.........
# Add the ops in 'to_ops' into the queue.
to_ops_set = set()
for op in to_ops:
# 'ready' handles the case where one output gradient relies on
# another output's gradient.
# pylint: disable=protected-access
ready = (pending_count[op._id] == 0)
if ready and op._id not in to_ops_set:
to_ops_set.add(op._id)
queue.append(op)
# pylint: enable=protected-access
if loop_state:
loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
for y in loop_exits:
if _IsTrainable(y):
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
queue.append(y.op)
# The set of 'from_ops'.
stop_ops = _StopOps(from_ops, pending_count)
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
with _maybe_colocate_with(op, colocate_gradients_with_ops):
if loop_state:
loop_state.EnterGradWhileContext(op, before=True)
out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
if loop_state:
loop_state.ExitGradWhileContext(op, before=True)
grad_fn = None
# pylint: disable=protected-access
is_func_call = ops.get_default_graph()._is_function(op.type)
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op._id not in stop_ops):
if is_func_call:
grad_fn = ops.get_default_graph()._get_function(
op.type).python_grad_func
# pylint: enable=protected-access
else:
# A grad_fn must be defined, either as a function or as None
# for ops that do not have gradients.
try:
grad_fn = ops.get_gradient_function(op)
except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
if loop_state:
loop_state.EnterGradWhileContext(op, before=False)
if (grad_fn or is_func_call) and has_out_grads:
# NOTE: If _AggregatedGrads didn't compute a value for the i'th
# output, it means that the cost does not depend on output[i],
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
if (not isinstance(out_grad, ops.Tensor) and
not out_grad) and _IsTrainable(op.outputs[i]):
# Only floating-point outputs get a zero gradient. Gradient
# functions should ignore the gradient for other outputs.
if loop_state:
out_grads[i] = loop_state.ZerosLike(op, i)
else:
out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
with ops.name_scope(op.name + "_grad"):
# pylint: disable=protected-access
with ops.get_default_graph()._original_op(op):
# pylint: enable=protected-access
if grad_fn:
# If grad_fn was found, do not use SymbolicGradient even for
# functions.
in_grads = grad_fn(op, *out_grads)
else:
# For function call ops, we add a 'SymbolicGradient'
# node to the graph to compute gradients.
in_grads = _SymGrad(op, out_grads)
in_grads = _AsList(in_grads)
_VerifyGeneratedGradients(in_grads, op)
if gate_gradients and len(
[x for x in in_grads if x is not None]) > 1:
in_grads = control_flow_ops.tuple(in_grads)
_LogOpGradients(op, out_grads, in_grads)
else:
# If no grad_fn is defined or none of out_grads is available,
# just propagate a list of None backwards.
in_grads = [None] * len(op.inputs)
for t_in, in_grad in zip(op.inputs, in_grads):
if in_grad is not None:
if isinstance(in_grad, ops.Tensor):
in_grad.set_shape(t_in.get_shape())
_SetGrad(grads, t_in, in_grad)
if loop_state:
loop_state.ExitGradWhileContext(op, before=False)
# Update pending count for the inputs of op and enqueue ready ops.
_UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state)
if loop_state:
loop_state.PostProcessing()
return [_GetGrad(grads, x) for x in xs]
开发者ID:kdavis-mozilla,项目名称:tensorflow,代码行数:101,代码来源:gradients_impl.py
示例7: testRegisterGradients
def testRegisterGradients(self):
    """Checks that get_gradient_function() returns _CopyGrad for the op
    produced by copy_op()."""
    graph = ops.Graph()
    copied = copy_op(an_op(graph))
    self.assertEqual(_CopyGrad, ops.get_gradient_function(copied.op))
开发者ID:4chin,项目名称:tensorflow,代码行数:6,代码来源:ops_test.py
示例8: gradients
#.........这里部分代码省略.........
to_ops = [t.op for t in ys]
from_ops = [t.op for t in xs]
pending_count, has_control_flow = _PendingCount(ops.get_default_graph(),
to_ops, from_ops)
# Iterate over the collected ops.
#
# grads: op => list of gradients received on each output endpoint of the
# op. The gradients for each endpoint are initially collected as a list.
# When it is time to call the op's gradient function, for each endpoint we
# aggregate the list of received gradients into a Add() Operation if there
# is more than one.
grads = {}
# Add the initial gradients for the ys.
for y, grad_y in zip(ys, grad_ys):
_SetGrad(grads, y, grad_y)
# Initialize queue with to_ops.
queue = collections.deque()
# Add the ops in 'to_ops' into the queue.
to_ops_set = set()
for op in to_ops:
# 'ready' handles the case where one output gradient relies on
# another output's gradient.
ready = (pending_count[op._id] == 0)
if ready and op._id not in to_ops_set: # pylint: disable=protected-access
to_ops_set.add(op._id)
queue.append(op)
# The set of 'from_ops'.
stop_ops = _StopOps(from_ops, pending_count)
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
with ops.device(_GetGradsDevice(op, colocate_gradients_with_ops)):
if has_control_flow:
control_flow_ops.EnterGradWhileContext(op)
out_grads = _AggregatedGrads(grads, op, has_control_flow,
aggregation_method)
grad_fn = None
if any(out_grads) and op._id not in stop_ops:
# A grad_fn must be defined, either as a function or as None
# for ops that do not have gradients.
try:
grad_fn = ops.get_gradient_function(op)
except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
if grad_fn and any(out_grads):
# NOTE: If _AggregatedGrads didn't compute a value for the i'th
# output, it means that the cost does not depend on output[i],
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
if (not out_grad and
dtypes.as_dtype(op.outputs[i].dtype).base_dtype in
(dtypes.float32, dtypes.float64)):
# Only floating-point outputs get a zero gradient. Gradient
# functions should ignore the gradient for other outputs.
out_grads[i] = array_ops.zeros_like(op.outputs[i])
with ops.name_scope(op.name + "_grad"):
# pylint: disable=protected-access
with ops.get_default_graph()._original_op(op):
# pylint: enable=protected-access
op_wrapper = op
if has_control_flow:
op_wrapper = control_flow_ops.MakeWrapper(op)
in_grads = _AsList(grad_fn(op_wrapper, *out_grads))
_VerifyGeneratedGradients(in_grads, op)
if gate_gradients and len(in_grads) > 1:
in_grads = control_flow_ops.tuple(in_grads)
logging.vlog(1, "Gradient for '" + op.name + "'")
logging.vlog(1, " in --> %s",
", ".join([x.name for x in out_grads if x]))
logging.vlog(1, " out --> %s",
", ".join([x.name for x in in_grads if x]))
else:
# If no grad_fn is defined or none of out_grads is available,
# just propagates a list of None backwards.
in_grads = [None] * len(op.inputs)
for t_in, in_grad in zip(op.inputs, in_grads):
if in_grad:
_SetGrad(grads, t_in, in_grad)
if has_control_flow:
control_flow_ops.ExitGradWhileContext(op)
# update pending count for the inputs of op.
for x in op.inputs:
pending_count[x.op._id] -= 1
ready = (pending_count[x.op._id] == 0)
if has_control_flow and not ready:
ready = (pending_count[x.op._id] > 0 and
control_flow_ops.IsLoopSwitch(x.op))
if ready:
queue.append(x.op)
for x in op.control_inputs:
pending_count[x._id] -= 1
if pending_count[x._id] is 0:
queue.append(x)
return [_GetGrad(grads, x) for x in xs]
开发者ID:rmt1,项目名称:tensorflow,代码行数:101,代码来源:gradients.py
示例9: _GradientsHelper
#.........这里部分代码省略.........
if ready and op not in to_ops_set and op in reachable_to_ops:
to_ops_set.add(op)
queue.append(op)
if loop_state:
loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
for y in loop_exits:
if IsTrainable(y):
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
queue.append(y.op)
stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
if loop_state:
loop_state.EnterGradWhileContext(op, before=True)
out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
aggregation_method)
if loop_state:
loop_state.ExitGradWhileContext(op, before=True)
grad_fn = None
func_call = None
is_partitioned_call = _IsPartitionedCall(op)
# pylint: disable=protected-access
is_func_call = (
src_graph._is_function(op.type) or is_partitioned_call)
# pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op not in stop_ops):
try:
grad_fn = ops.get_gradient_function(op)
except LookupError:
if is_func_call:
if is_partitioned_call:
func_call = src_graph._get_function( # pylint: disable=protected-access
compat.as_bytes(op.get_attr("f").name))
else:
func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
# Note that __defun is not set if the graph is
# imported. If it's set, we prefer to access the original
# defun.
func_call = getattr(op, "__defun", func_call)
grad_fn = func_call.python_grad_func
else:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
if loop_state:
loop_state.EnterGradWhileContext(op, before=False)
# NOTE(skyewm): We don't support computing gradients wrt a loop variable
# unless it's within the context of a single iteration (i.e. the
# gradient is wrt to the loop parameter in the body function, not wrt or
# through the initial value). This means if we're in a while loop
# context, we should never see a switch node from this context.
# pylint: disable=protected-access
if (control_flow_util.IsSwitch(op) and
op._control_flow_context is not None and
op._control_flow_context.IsWhileContext() and
op._control_flow_context ==
ops.get_default_graph()._get_control_flow_context()):
_RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
# pylint: enable=protected-access
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:67,代码来源:gradients_util.py
示例10: gradients
#.........这里部分代码省略.........
if loop_state:
loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
for y in loop_exits:
if _IsTrainable(y):
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
queue.append(y.op)
stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
with _maybe_colocate_with(op, colocate_gradients_with_ops):
if loop_state:
loop_state.EnterGradWhileContext(op, before=True)
out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
if loop_state:
loop_state.ExitGradWhileContext(op, before=True)
grad_fn = None
# pylint: disable=protected-access
func_call = None
is_func_call = ops.get_default_graph()._is_function(op.type)
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op._id not in stop_ops):
if is_func_call:
func_call = ops.get_default_graph()._get_function(op.type)
grad_fn = func_call.python_grad_func
# pylint: enable=protected-access
else:
# A grad_fn must be defined, either as a function or as None
# for ops that do not have gradients.
try:
grad_fn = ops.get_gradient_function(op)
except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
if loop_state:
loop_state.EnterGradWhileContext(op, before=False)
if (grad_fn or is_func_call) and has_out_grads:
# NOTE: If _AggregatedGrads didn't compute a value for the i'th
# output, it means that the cost does not depend on output[i],
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
(not grad_fn and is_func_call) or _IsTrainable(op.outputs[i])):
# Only trainable outputs or outputs for a function call that
# will use SymbolicGradient get a zero gradient. Gradient
# functions should ignore the gradient for other outputs.
# TODO(apassos) gradients of resource handles might be an
# issue here because of zeros.
if loop_state:
out_grads[i] = loop_state.ZerosLike(op, i)
else:
out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
with ops.name_scope(op.name + "_grad"):
# pylint: disable=protected-access
with ops.get_default_graph()._original_op(op):
# pylint: enable=protected-access
if grad_fn:
# If grad_fn was found, do not use SymbolicGradient even for
# functions.
in_grads = _MaybeCompile(grad_scope, op, func_call,
lambda: grad_fn(op, *out_grads))
else:
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:67,代码来源:gradients_impl.py
示例11: create_op
#.........这里部分代码省略.........
if self._return_as_is or op_type in _PASS_THROUGH_OPS:
return self._wrap(super(ImperativeGraph, self).create_op(*args, **kwargs))
if not output_dtypes:
return self._wrap(
super(ImperativeGraph, self).create_op(*args, **kwargs))
output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes]) # pylint: disable=protected-access
if output_has_ref:
if op_type not in _REF_OPS_WHITELIST:
raise errors.UnimplementedError(None, None,
op_type + ' op not supported in '
'imperative graph')
ret = super(ImperativeGraph, self).create_op(*args, **kwargs)
if self._in_variable_creation:
if op_type == 'Assign':
self.add_pending_init(ret)
return self._wrap(ret)
with self.return_as_is():
# Declares the variables to hold the output values of this op.
op_output_var = [state_ops.variable_op_v2(
tensor_shape.TensorShape(None), dtype, container=self._name)
for dtype in output_dtypes]
# Ops to free the resources used by the temporary cache variables.
# The following two ops are created for each cache variable,
# having no control dependencies on any other ops :
# var_handle_op ----> destroy_resource_op
for dtype, v in zip(output_dtypes, op_output_var):
with ops.control_dependencies(None):
self._variable_cleanup_ops += [
gen_resource_variable_ops.destroy_resource_op(
gen_resource_variable_ops.var_handle_op(
dtype, tensor_shape.TensorShape(None),
container=self._name, shared_name=v.op.name),
ignore_lookup_error=True)]
# Create the conditional to run the original op only when the variable
# corresponding to the first output is not initialized.
inited = state_ops.is_variable_initialized(op_output_var[0])
v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
# pylint: disable=protected-access
v_f_op = gen_array_ops._ref_identity(v_f)
v_t_op = gen_array_ops._ref_identity(v_t)
# pylint: enable=protected-access
with ops.control_dependencies([v_f_op.op]):
# Create the original op
orig_op = self._wrap(
super(ImperativeGraph, self).create_op(*args, **kwargs))
shapes = [val.get_shape() for val in orig_op.outputs]
controls = []
for var, val in zip(op_output_var, orig_op.outputs):
if (not val.get_shape().is_fully_defined() or
val.get_shape().num_elements() > 0):
assign_op = state_ops.assign(var, val, validate_shape=False)
assign_op.set_shape(val.get_shape())
controls.append(assign_op)
values = []
if len(controls) > 1:
if control_flow_ops.IsSwitch(orig_op):
# pylint: disable=protected-access
controls = gen_control_flow_ops._ref_merge(controls)
# pylint: enable=protected-access
else:
controls = control_flow_ops.tuple(controls)
for var, val in zip(op_output_var, orig_op.outputs):
with ops.control_dependencies(controls):
with self.colocate_with(v_f_op):
real_val = array_ops.identity(val)
with ops.control_dependencies([v_t_op.op]):
with self.colocate_with(v_t_op):
stored_val = array_ops.identity(var)
stored_val.set_shape(val.get_shape())
real_val, _ = control_flow_ops.merge([real_val, stored_val])
real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
attr_value_pb2.AttrValue(s=compat.as_bytes(self._merge_op_type)))
values.append(real_val)
for i, _ in enumerate(shapes):
values[i].set_shape(shapes[i])
self._outputs_map[orig_op.name] = values
try:
self._gradient_function_map[orig_op.name] = ops.get_gradient_function(
orig_op)
except (KeyError, LookupError):
pass
else:
orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
attr_value_pb2.AttrValue(
s=compat.as_bytes(self._imperative_op_type)))
return MultiOutputOperation(values, orig_op)
开发者ID:chdinh,项目名称:tensorflow,代码行数:101,代码来源:imperative_graph.py
示例12: gradients
#.........这里部分代码省略.........
derivatives using a different initial gradient for each y (e.g., if
one wanted to weight the gradient differently for each value in
each y).
Args:
ys: A `Tensor` or list of tensors to be differentiated.
xs: A `Tensor` or list of tensors to be used for differentiation.
grad_ys: Optional. A `Tensor` or list of tensors the same size as
`ys` and holding the gradients computed for each y in `ys`.
name: Optional name to use for grouping all the gradient ops together.
defaults to 'gradients'.
colocate_gradients_with_ops: If True, try colocating gradients with
the corresponding op.
gate_gradients: If True, add a tuple around the gradients returned
for an operations. This avoids some race conditions.
aggregation_method: Specifies the method used to combine gradient terms.
Accepted values are constants defined in the class `AggregationMethod`.
Returns:
A list of `sum(dy/dx)` for each x in `xs`.
Raises:
LookupError: if one of the operations between `x` and `y` does not
have a registered gradient function.
ValueError: if the arguments are invalid.
"""
ys = _AsList(ys)
xs = _AsList(xs)
if grad_ys is None:
grad_ys = [None] * len(ys)
else:
grad_ys = _AsList(grad_ys)
with ops.op_scope(ys + xs + grad_ys, name, "gradients"):
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
xs = ops.convert_n_to_tensor_or_indexed_slices(xs, name="x")
grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
# The approach we take here is as follows: Create a list of all ops in the
# subgraph between the ys and xs. Visit these ops in reverse order of ids
# to ensure that when we visit an op the gradients w.r.t its outputs have
# been collected. Then aggregate these gradients if needed, call the op's
# gradient function, and add the generated gradients to the gradients for
# its input.
# Initialize the pending count for ops in the connected subgraph from ys
# to the xs.
to_ops = [t.op for t in ys]
from_ops = [t.op for t in xs]
pending_count, has_control_flow = _PendingCount(
ops.get_default_graph(), to_ops, from_ops)
# Iterate over the collected ops.
#
# grads: op => list of gradients received on each output endpoint of the
# op. The gradients for each endpoint are initially collected as a list.
# When it is time to call the op's gradient function, for each endpoint we
# aggregate the list of received gradients into a Add() Operation if there
# is more than one.
grads = {}
# Add the initial gradients for the ys.
for y, grad_y in zip(ys, grad_ys):
_SetGrad(grads, y, grad_y)
# Initialize queue with to_ops.
queue = collections.deque()
# Add the ops in 'to_ops' into the queue.
to_ops_set = set()
for op in to_ops:
if op._id not in to_ops_set:
to_ops_set.add(op._id)
queue.append(op)
# The set of 'from_ops'.
stop_ops = _StopOps(from_ops, pending_count)
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
with ops.device(_GetGradsDevice(op, colocate_gradients_with_ops)):
if has_control_flow:
control_flow_ops.EnterGradWhileContext(op)
out_grads = _AggregatedGrads(grads, op, has_control_flow,
aggregation_method)
grad_fn = None
if any(out_grads) and op._id not in stop_ops:
# A grad_fn must be defined, either as a function or as None
# for ops that do not have gradients.
try:
grad_fn = ops.get_gradient_function(op)
except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
if grad_fn and any(out_grads):
# NOTE: If _AggregatedGrads didn't compute a value for the i'th
# output, it means that the cost does not depend on output[i],
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
if (not out_grad
and types.as_dtype(op.outputs[i].dtype).base_dtype in
开发者ID:njustboy,项目名称:tensorflow,代码行数:101,代码来源:gradients.py
示例13: _GradientsHelper
#.........这里部分代码省略.........
if loop_state:
loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
for y in loop_exits:
if _IsTrainable(y):
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
queue.append(y.op)
stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
with _maybe_colocate_with(op, colocate_gradients_with_ops):
if loop_state:
loop_state.EnterGradWhileContext(op, before=True)
out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
if loop_state:
loop_state.ExitGradWhileContext(op, before=True)
grad_fn = None
# pylint: disable=protected-access
func_call = None
is_func_call = ops.get_default_graph()._is_function(op.type)
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op._id not in stop_ops):
if is_func_call:
func_call = ops.get_default_graph()._get_function(op.type)
grad_fn = func_call.python_grad_func
# pylint: enable=protected-access
else:
# A grad_fn must be defined, either as a function or as None
# for ops that do not have gradients.
try:
grad_fn = ops.get_gradient_function(op)
except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
if loop_state:
loop_state.EnterGradWhileContext(op, before=False)
if (grad_fn or is_func_call) and has_out_grads:
# NOTE: If _AggregatedGrads didn't compute a value for the i'th
# output, it means that the cost does not depend on output[i],
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
(not grad_fn and is_func_call) or _IsTrainable(op.outputs[i])):
# Only trainable outputs or outputs for a function call that
# will use SymbolicGradient get a zero gradient. Gradient
# functions should ignore the gradient for other outputs.
# TODO(apassos) gradients of resource handles might be an
# issue here because of zeros.
if loop_state:
out_grads[i] = loop_state.ZerosLike(op, i)
else:
out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
with ops.name_scope(op.name + "_grad"):
# pylint: disable=protected-access
with ops.get_default_graph()._original_op(op):
# pylint: enable=protected-access
if grad_fn:
# If grad_fn was found, do not use SymbolicGradient even for
# functions.
in_grads = _MaybeCompile(grad_scope, op, func_call,
lambda: grad_fn(op, *out_grads))
else:
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:67,代码来源:gradients_impl.py
注:本文中的tensorflow.python.framework.ops.get_gradient_function函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论