本文整理汇总了Python中tensorflow.contrib.graph_editor.reroute_ts函数的典型用法代码示例。如果您正苦于以下问题：Python reroute_ts函数的具体用法？Python reroute_ts怎么用？Python reroute_ts使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了reroute_ts函数的11个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_reroute
def test_reroute(self):
    """Calls reroute_ts on (a0, b0)/(a1, b1) in both argument orders.

    After the first call both c0 and c1 are expected to match input ops
    "a0" and "b0"; after the second call (arguments reversed) both are
    expected to match input ops "a1" and "b1".
    """

    def consumes(op_name, tensor, *producer_names):
        # Builds the matcher for `op_name` and applies it to the tensor's op.
        return match.OpMatcher(op_name).input_ops(*producer_names)(tensor.op)

    ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
    self.assertTrue(consumes("c0", self.c0, "a0", "b0"))
    self.assertTrue(consumes("c1", self.c1, "a0", "b0"))

    ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
    self.assertTrue(consumes("c0", self.c0, "a1", "b1"))
    self.assertTrue(consumes("c1", self.c1, "a1", "b1"))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:8,代码来源:reroute_test.py
示例2: _FoldUnfusedBatchNorms
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Folds each valid unfused batch norm group into the layer before it.

    Folding only affects the following layers: Conv2D, fully connected,
    depthwise convolution.

    Args:
      graph: Graph to walk and modify.
      is_training: Bool, True if training.
      freeze_batch_norm_delay: How many steps to wait before freezing moving
        mean and variance and using them for batch normalization.

    Raises:
      ValueError: When rerouting the folded output into the consumer op does
        not report exactly one modification.
    """
    consumers_by_input = input_to_ops.InputToOps(graph)
    for bn in common.BatchNormGroups(graph):
        scaled = _HasScaling(graph, consumers_by_input, bn)
        if not _IsValidUnfusedBatchNorm(graph, bn):
            continue

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(
            graph,
            bn,
            has_scaling=scaled,
            freeze_batch_norm_delay=freeze_batch_norm_delay,
            is_training=is_training)

        consumer = common.GetEndpointActivationOp(graph, bn)
        if not consumer:
            # Treat consumer ops in bypass modules differently since they have
            # Add operations instead of Relu*: look up '<parent scope>/Add'.
            parent_scope = re.search(r'^(.*)/([^/]+)', bn).group(1)
            consumer = graph.get_operation_by_name(parent_scope + '/Add')

        rerouted = graph_editor.reroute_ts(
            [folded_op.outputs[0]],
            [original_op.outputs[0]],
            can_modify=[consumer])
        if rerouted != 1:
            raise ValueError('Unexpected inputs to op: %s' % consumer.name)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:49,代码来源:fold_batch_norms.py
示例3: _FoldFusedBatchNorms
def _FoldFusedBatchNorms(graph):
    """Folds every fused batch norm match of `graph` into its preceding layer.

    Folding only affects the following layers: Conv2D, fully connected,
    depthwise convolution.

    Args:
      graph: Graph to walk and modify.

    Raises:
      ValueError: When rerouting the folded output does not report exactly
        one modification.
    """
    for fused in _FindFusedBatchNorms(graph):
        scope, sep, _ = fused.layer_op.name.rpartition('/')
        # The trailing '/' (i.e. `sep`) makes TF reuse the existing scope named
        # `scope` instead of creating a unique scope that merely starts with it.
        layer_scope = scope + sep
        # New ops are added to `graph` and placed on the same device as `bn_op`.
        with graph.as_default(), graph.name_scope(layer_scope), ops.device(
                fused.bn_op.device):
            with graph.name_scope(layer_scope + 'BatchNorm_Fold' + sep):
                # new weights = old weights * gamma / sqrt(variance + epsilon)
                # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
                epsilon = fused.bn_op.get_attr('epsilon')
                inv_std = math_ops.rsqrt(fused.variance_tensor + epsilon)
                scale = fused.gamma_tensor * inv_std
                folded_bias = math_ops.subtract(
                    fused.beta_tensor, fused.mean_tensor * scale, name='bias')

            # Depthwise weights have a different shape, so the multiplier is
            # reshaped to dims 2 and 3 of the weights for the product below.
            if fused.layer_op.type == 'DepthwiseConv2dNative':
                target_shape = [
                    fused.weight_tensor.get_shape().as_list()[2],
                    fused.weight_tensor.get_shape().as_list()[3],
                ]
                scale = array_ops.reshape(
                    scale, target_shape, name='scale_reshape')

            # TODO(suharshs): This naming of the following ops needs to
            # carefully follow the naming expected by quantize.py. Generalize
            # the quantize code to not require these delicate naming
            # conventions.
            folded_weights = math_ops.multiply(
                fused.weight_tensor, scale, name='mul_fold')
            folded_layer = _CloneWithNewOperands(
                fused.layer_op, fused.input_tensor, folded_weights)
            folded_output = math_ops.add(
                folded_layer, folded_bias, name='add_fold')

            rerouted = graph_editor.reroute_ts(
                folded_output, fused.output_tensor)
            if rerouted != 1:
                raise ValueError(
                    'Unexpected inputs to op: %s' % fused.output_tensor.name)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:58,代码来源:fold_batch_norms.py
示例4: FoldBatchNorms
def FoldBatchNorms(graph):
    """Folds the batch norm groups found in `graph` into preceding layers.

    Folding only affects the following layers: Conv2D, fully connected,
    depthwise convolution.

    Args:
      graph: Graph to walk and modify.

    Raises:
      ValueError: When the graph contains a 'FusedBatchNorm' op, or when
        rerouting the folded output into the consumer op does not report
        exactly one modification.
    """
    # Fail immediately when the graph contains unsupported fused batch norm ops.
    for op in graph.get_operations():
        if op.type == 'FusedBatchNorm':
            raise ValueError('Fused batch norm is not supported')

    consumers_by_input = input_to_ops.InputToOps(graph)
    for bn in common.BatchNormGroups(graph):
        scaled = _HasScaling(graph, consumers_by_input, bn)

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(
            graph, bn, has_scaling=scaled)

        consumer = common.GetEndpointActivationOp(graph, bn)
        if not consumer:
            # Treat consumer ops in bypass modules differently since they have
            # Add operations instead of Relu*: look up '<parent scope>/Add'.
            parent_scope = re.search(r'^(.*)/([^/]+)', bn).group(1)
            consumer = graph.get_operation_by_name(parent_scope + '/Add')

        rerouted = graph_editor.reroute_ts(
            [folded_op.outputs[0]],
            [original_op.outputs[0]],
            can_modify=[consumer])
        if rerouted != 1:
            raise ValueError('Unexpected inputs to op: %s' % consumer.name)
开发者ID:alexsax,项目名称:tensorflow,代码行数:42,代码来源:fold_batch_norms.py
示例5: _InsertQuantOp
#.........这里部分代码省略.........
quant_delay: (Optional, default None) Int, count of global steps for which
to delay quantization. This helps weights stabilize at the start of
training.
vars_collection: (Optional) Collection where to store the variables for
quantization interval ends.
narrow_range: Whether to use the narrow quantization range
[1; 2^bits - 1] or wide range [0; 2^bits - 1].
producer_scope: The restriction of producer scope. If not None, the new op
will be inserted only when the producer is in this scope.
consumer_scope: The restriction of producer scope. If not None, the new op
will be inserted only when all the consumers are in this scope.
Raises:
ValueError: When producer operation is not directly connected to the
consumer operation.
"""
if producer_scope and not producer.name.startswith(producer_scope):
logging.info(
'_InsertQuantOp ignores context="%s" name="%s" '
'because producer "%s" is not in scope "%s"',
context, name, producer.name, producer_scope)
return
if consumer_scope:
consumers_in_scope = []
for consumer in consumers:
if consumer.name.startswith(consumer_scope):
consumers_in_scope.append(consumer)
else:
logging.info(
'_InsertQuantOp context="%s" name="%s" ignores '
'consumer "%s" because it is not in scope "%s"',
context, name, consumer.name, consumer_scope)
return
consumers = consumers_in_scope
name_prefix = _AddContextToName(context, name)
# This is needed on TPU where name_scope == 'TPUReplicate/loop', and
# name_prefix starts with 'TPUReplicate/loop/'; without dropping it
# variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
# breaks things later.
name_scope = ops.get_name_scope()
if name_scope:
name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')
inputs = producer.outputs[0]
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
fake_quant_ops = set([
'FakeQuantWithMinMaxVars',
'FakeQuantWithMinMaxArgs'
])
if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
return
if moving_avg:
quant = (
quant_ops.MovingAvgQuantize(
inputs,
init_min=init_min,
init_max=init_max,
ema_decay=ema_decay,
is_training=is_training,
num_bits=bits,
narrow_range=narrow_range,
vars_collection=vars_collection,
name_prefix=name_prefix))
else:
quant = (
quant_ops.LastValueQuantize(
inputs,
init_min=init_min,
init_max=init_max,
is_training=is_training,
num_bits=bits,
narrow_range=narrow_range,
vars_collection=vars_collection,
name_prefix=name_prefix))
if quant_delay and quant_delay > 0:
activate_quant = math_ops.greater_equal(
common.CreateOrGetQuantizationStep(),
quant_delay,
name=name_prefix + '/activate_quant')
quant = control_flow_ops.cond(
activate_quant,
lambda: quant,
lambda: inputs,
name=name_prefix + '/delayed_quant')
if consumers:
tensors_modified_count = graph_editor.reroute_ts(
[quant], [inputs], can_modify=consumers)
# Some operations can have multiple output tensors going to the same
# consumer. Since consumers is a set, we need to ensure that
# tensors_modified_count is greater than or equal to the length of the set
# of consumers.
if tensors_modified_count < len(consumers):
raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
[consumer.name for consumer in consumers]))
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:101,代码来源:quantize.py
示例6: _FoldFusedBatchNorms
def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Folds every fused batch norm match of `graph` into its preceding layer.

    Folding only affects the following layers: Conv2D, fully connected,
    depthwise convolution.

    Args:
      graph: Graph to walk and modify.
      is_training: Bool, true if training. When true, batch norm correction
        tensors are computed and applied around the cloned layer.
      freeze_batch_norm_delay: How many steps to wait before freezing moving
        mean and variance and using them for batch normalization.

    Raises:
      ValueError: When rerouting the folded output reports zero modifications.
    """
    for fused in _FindFusedBatchNorms(graph):
        scope, sep, _ = fused.layer_op.name.rpartition('/')
        # The trailing '/' (i.e. `sep`) makes TF reuse the existing scope named
        # `scope` instead of creating a unique scope that merely starts with it.
        layer_scope = scope + sep
        with graph.as_default(), graph.name_scope(layer_scope):
            with graph.name_scope(layer_scope + 'BatchNorm_Fold' + sep):
                # new weights = old weights * gamma / sqrt(variance + epsilon)
                # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
                epsilon = fused.bn_op.get_attr('epsilon')
                inv_std = math_ops.rsqrt(fused.variance_tensor + epsilon)
                scale = fused.gamma_tensor * inv_std
                folded_bias = math_ops.subtract(
                    fused.beta_tensor, fused.mean_tensor * scale, name='bias')

                if is_training:
                    corr_scale, corr_recip, corr_offset = (
                        _ComputeBatchNormCorrections(
                            context='',
                            match=fused,
                            freeze_batch_norm_delay=freeze_batch_norm_delay,
                            fused_batch_norm=True))
                else:
                    corr_scale = corr_recip = corr_offset = None

            # Depthwise weights have a different shape, so the multiplier (and
            # the correction scale, if any) is reshaped to dims 2 and 3 of the
            # weights for the products below.
            weights = fused.weight_tensor
            if fused.layer_op.type == 'DepthwiseConv2dNative':
                target_shape = [
                    fused.weight_tensor.get_shape().as_list()[2],
                    fused.weight_tensor.get_shape().as_list()[3],
                ]
                scale = array_ops.reshape(
                    scale, target_shape, name='scale_reshape')
                if corr_scale is not None:
                    corr_scale = array_ops.reshape(
                        corr_scale, target_shape, name='correction_reshape')

            if corr_scale is not None:
                weights = math_ops.multiply(
                    corr_scale, weights, name='correction_mult')

            folded_weights = math_ops.multiply(weights, scale, name='mul_fold')
            layer_out = _CloneWithNewOperands(
                fused.layer_op, fused.input_tensor, folded_weights)

            if corr_recip is not None:
                layer_out = math_ops.multiply(
                    corr_recip, layer_out, name='post_conv_mul')
                layer_out = math_ops.add(
                    layer_out, (corr_offset), 'correction_add')

            folded_output = math_ops.add(
                layer_out, folded_bias, name='add_fold')

            rerouted = graph_editor.reroute_ts(
                folded_output, fused.output_tensor)
            if rerouted == 0:
                raise ValueError(
                    'Folding batch norms failed, %s had no outputs.' %
                    fused.output_tensor.name)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:79,代码来源:fold_batch_norms.py
示例7: _ComputeBatchNormCorrections
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
                                 fused_batch_norm):
    """Computes batch norm correction params.

    Before batch normalization is frozen:
      We use batch statistics for batch norm.
        correction_scale = sigma_b/sigma_mv
        correction_recip = 1/correction_scale
        correction_offset = 0

    After batch normalization is frozen:
        correction_scale = sigma_b/sigma_mv
        correction_recip = 1
        correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

    Batch norm is frozen once the step returned by
    common.CreateOrGetQuantizationStep() is >= freeze_batch_norm_delay. When
    freeze_batch_norm_delay is None it is never frozen.

    The corrections ensure that:
      a) The weights are quantized after scaling by gamma/sigma_mv. This
         enables smoother training as the scaling on the weights changes
         slowly, rather than jump across mini-batches.
      b) Changing the values of the corrections allows for one to switch
         between using batch statistics and using moving mean and variance,
         without requiring changes to batch_norm.

    As a side effect, the consumers of match.bn_decay_mean_tensor (and, when
    fused_batch_norm is False, of match.bn_decay_var_tensor) are rerouted to a
    conditional that yields 0.0 once batch norm is frozen.

    Args:
      context: The scope under which we look for batch norm params.
      match: Object containing required batch norm tensors for correction
        computation.
      freeze_batch_norm_delay: Delay in steps at which computation switches
        from regular batch norm to frozen mean and variance, or None.
      fused_batch_norm: Bool, true if fused batch norm is used.

    Returns:
      A tuple of correction_scale, correction_recip, correction_offset.
    """
    g = ops.get_default_graph()
    prefix = '' if not context else context + '/'
    with g.name_scope(prefix + 'batch_norm_correction'):
        recip_sigma_mv = math_ops.rsqrt(
            match.moving_variance_tensor + match.batch_epsilon)
        recip_sigma = math_ops.rsqrt(
            match.variance_tensor + match.batch_epsilon)
        correction_scale = math_ops.divide(
            recip_sigma_mv, recip_sigma, name='scale_compute')
        correction_scale = array_ops.identity(
            correction_scale, name='correction_scale')
        correction_recip = math_ops.reciprocal(
            correction_scale, name='reciprocal_compute')
        correction_offset = math_ops.multiply(
            match.gamma_tensor,
            match.mean_tensor * recip_sigma -
            match.moving_mean_tensor * recip_sigma_mv,
            name='offset_compute')

        if freeze_batch_norm_delay is not None:
            use_mv_avg = math_ops.greater_equal(
                common.CreateOrGetQuantizationStep(),
                freeze_batch_norm_delay,
                name='use_moving_average')
        else:
            use_mv_avg = False

        bn_decay_zero = 0.0
        # Snapshot the consumers before building the conditional below, so the
        # reroute is restricted to ops that already consumed the decay tensor.
        bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())

        bn_decay_mean_out = utils.smart_cond(
            use_mv_avg,
            lambda: bn_decay_zero,
            lambda: match.bn_decay_mean_tensor,
            name='freeze_moving_mean')
        graph_editor.reroute_ts(
            [bn_decay_mean_out], [match.bn_decay_mean_tensor],
            can_modify=bn_decay_mean_consumers)

        if fused_batch_norm is False:
            # The variance consumers are collected only here, where they are
            # used. The earlier unconditional assignment (built from the *mean*
            # tensor by mistake) was a dead store and has been removed.
            bn_decay_var_consumers = list(
                match.bn_decay_var_tensor.consumers())
            bn_decay_var_out = utils.smart_cond(
                use_mv_avg,
                lambda: bn_decay_zero,
                lambda: match.bn_decay_var_tensor,
                name='freeze_moving_var')
            graph_editor.reroute_ts(
                [bn_decay_var_out], [match.bn_decay_var_tensor],
                can_modify=bn_decay_var_consumers)

        correction_recip = utils.smart_cond(
            use_mv_avg,
            lambda: array_ops.ones(correction_scale.shape),
            lambda: correction_recip,
            name='correction_recip')

        correction_offset = utils.smart_cond(
            use_mv_avg,
            lambda: correction_offset,
            lambda: array_ops.zeros(correction_offset.shape),
            name='correction_offset')
    return correction_scale, correction_recip, correction_offset
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:99,代码来源:fold_batch_norms.py
示例8: test_compatibility
def test_compatibility(self):
    """Expects reroute_ts on (a0, b0) and (a2, b2) to raise ValueError."""
    sources = [self.a0, self.b0]
    targets = [self.a2, self.b2]
    with self.assertRaises(ValueError):
        ge.reroute_ts(sources, targets)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:3,代码来源:reroute_test.py
示例9: _InsertQuantOp
def _InsertQuantOp(
        self,
        context,
        producer,
        consumers,
        name,
        moving_avg=True,
        init_min=-6.0,
        init_max=6.0,
        delay_requested=True,
        bits=8,
        narrow_range=False,):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

    The quant op is built on `producer.outputs[0]` under the scope
    `context + '/' + name`, and the given consumers are rerouted to read from
    it instead of the producer output.

    Args:
      context: Context where producer and consumer operations are nested.
      producer: Producer operation of the pairs where quantization will be
        inserted.
      consumers: Consumer operations of the pairs.
      name: Name for the new quantization op within the context.
      moving_avg: If true, uses quant_ops.MovingAvgQuantize (with
        self.ema_decay); otherwise uses quant_ops.LastValueQuantize.
      init_min: Starting minimum value for the new quantization op.
      init_max: Starting maximum value for the new quantization op.
      delay_requested: If true, a positive self.quant_delay wraps the quant op
        in a conditional that passes the raw input through until the global
        step reaches self.quant_delay. False disables the delay.
      bits: Number of bits to use for quantization (documented upstream as
        between 2 and 8; not validated here).
      narrow_range: Whether to use the narrow quantization range
        [1; 2^bits - 1] or wide range [0; 2^bits - 1].

    Raises:
      ValueError: When the count returned by reroute_ts differs from
        len(consumers).
    """
    scope = context + '/' + name
    inputs = producer.outputs[0]

    # Keyword arguments common to both quantizer flavours.
    shared_kwargs = dict(
        init_min=init_min,
        init_max=init_max,
        is_training=self.is_training,
        num_bits=bits,
        narrow_range=narrow_range,
        updates_collection=_UPDATE_QUANT_OPS,
        vars_collection=self.vars_collection,
        scope=scope)
    if moving_avg:
        quant = quant_ops.MovingAvgQuantize(
            inputs, ema_decay=self.ema_decay, **shared_kwargs)
    else:
        quant = quant_ops.LastValueQuantize(inputs, **shared_kwargs)

    if delay_requested and self.quant_delay and self.quant_delay > 0:
        activate_quant = math_ops.greater_equal(
            training_util.get_or_create_global_step(),
            self.quant_delay,
            name=scope + '/activate_quant')
        quant = control_flow_ops.cond(
            activate_quant,
            lambda: quant,
            lambda: inputs,
            name=scope + '/delayed_quant')

    rerouted = graph_editor.reroute_ts(
        [quant], [inputs], can_modify=consumers)
    if rerouted == len(consumers):
        return
    consumer_names = ', '.join([consumer.name for consumer in consumers])
    raise ValueError(
        'Some inputs not quantized for ops: [%s]' % consumer_names)
开发者ID:SylChan,项目名称:tensorflow,代码行数:75,代码来源:quantize.py
示例10: _InsertQuantOp
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                   narrow_range=False):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

    The quant op is built on `producer.outputs[0]`, named with the prefix
    returned by _AddContextToName(context, name), and the given consumers are
    rerouted to read from it instead of the producer output.

    Args:
      context: Context where producer and consumer operations are nested.
      name: Name for the new quantization op within the context.
      producer: Producer operation of the pairs where quantization will be
        inserted.
      consumers: Consumer operations of the pairs.
      is_training: Whether quantizing training graph or eval graph.
      moving_avg: If true, uses quant_ops.MovingAvgQuantize (with ema_decay);
        otherwise uses quant_ops.LastValueQuantize.
      init_min: Starting minimum value for the new quantization op.
      init_max: Starting maximum value for the new quantization op.
      bits: Number of bits to use for quantization (documented upstream as
        between 2 and 8; not validated here).
      ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
        quantization intervals for quantizing activations (see here about EMA:
        https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
      quant_delay: (Optional, default None) Int, count of global steps for
        which to delay quantization. This helps weights stabilize at the start
        of training.
      vars_collection: (Optional) Collection where to store the variables for
        quantization interval ends.
      narrow_range: Whether to use the narrow quantization range
        [1; 2^bits - 1] or wide range [0; 2^bits - 1].

    Raises:
      ValueError: When the count returned by reroute_ts differs from
        len(consumers).
    """
    name_prefix = _AddContextToName(context, name)
    inputs = producer.outputs[0]

    # Keyword arguments common to both quantizer flavours.
    shared_kwargs = dict(
        init_min=init_min,
        init_max=init_max,
        is_training=is_training,
        num_bits=bits,
        narrow_range=narrow_range,
        vars_collection=vars_collection,
        name_prefix=name_prefix)
    if moving_avg:
        quant = quant_ops.MovingAvgQuantize(
            inputs, ema_decay=ema_decay, **shared_kwargs)
    else:
        quant = quant_ops.LastValueQuantize(inputs, **shared_kwargs)

    if quant_delay and quant_delay > 0:
        # Pass the raw input through until the quantization step reaches
        # quant_delay.
        activate_quant = math_ops.greater_equal(
            common.CreateOrGetQuantizationStep(),
            quant_delay,
            name=name_prefix + '/activate_quant')
        quant = control_flow_ops.cond(
            activate_quant,
            lambda: quant,
            lambda: inputs,
            name=name_prefix + '/delayed_quant')

    rerouted = graph_editor.reroute_ts(
        [quant], [inputs], can_modify=consumers)
    if rerouted == len(consumers):
        return
    consumer_names = ', '.join([consumer.name for consumer in consumers])
    raise ValueError(
        'Some inputs not quantized for ops: [%s]' % consumer_names)
开发者ID:dananjayamahesh,项目名称:tensorflow,代码行数:83,代码来源:quantize.py
示例11: gradients
#.........这里部分代码省略.........
# new edge cases, exclude them
if ys_intersect_checkpoints:
debug_print("Warning, some output nodes are also checkpoints nodes: {}".format(
format_ops(ys_intersect_checkpoints)))
# remove initial and terminal nodes from checkpoints list if present
checkpoints = list(set(checkpoints) - set(ys) - set(xs))
# check that we have some nodes to checkpoint
if not checkpoints:
raise Exception('no checkpoints nodes found or given as input! ')
# disconnect dependencies between checkpointed tensors
checkpoints_disconnected = {}
for x in checkpoints:
if x.op and x.op.name is not None:
grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
else:
grad_node = tf.stop_gradient(x)
checkpoints_disconnected[x] = grad_node
# partial derivatives to the checkpointed tensors and xs
ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
stop_at_ts=checkpoints, within_ops=fwd_ops)
debug_print("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints))
debug_print("ops_to_copy = {}".format(ops_to_copy))
debug_print("Processing list {}".format(ys))
_, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
for origin_op, op in info._transformed_ops.items():
op._set_device(origin_op.node_def.device)
copied_ops = info._transformed_ops.values()
debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
ge.reroute_ts(checkpoints_disconnected.values(),
checkpoints_disconnected.keys(),
can_modify=copied_ops)
debug_print("Rewired {} in place of {} restricted to {}".format(
checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops))
# get gradients with respect to current boundary + original x's
copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
boundary = list(checkpoints_disconnected.values())
dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
debug_print("Got gradients {}".format(dv))
debug_print("for %s", copied_ys)
debug_print("with respect to {}".format(boundary+xs))
inputs_to_do_before = [y.op for y in ys]
if grad_ys is not None:
inputs_to_do_before += grad_ys
wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
# partial derivatives to the checkpointed nodes
# dictionary of "node: backprop" for nodes in the boundary
d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
dv[:len(checkpoints_disconnected)])}
# partial derivatives to xs (usually the params of the neural net)
d_xs = dv[len(checkpoints_disconnected):]
# incorporate derivatives flowing through the checkpointed nodes
checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
for ts in checkpoints_sorted_lists[::-1]:
debug_print("Processing list {}".format(ts))
checkpoints_other = [r for r in checkpoints if r not in ts]
checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]
开发者ID:stonezuohui,项目名称:faceswap,代码行数:67,代码来源:memory_saving_gradients.py
注：本文中的tensorflow.contrib.graph_editor.reroute_ts函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论