本文整理汇总了Python中tensorflow.python.util.function_utils.fn_args函数的典型用法代码示例。如果您正苦于以下问题:Python fn_args函数的具体用法?Python fn_args怎么用?Python fn_args使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了fn_args函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: add
def add(self, layer_func):
    """Appends a layer or a plain callable to this sequential container.

    Args:
      layer_func: either a `base.Layer` (its `call` signature is inspected and
        the layer is registered through `self.track_layer`) or any other
        callable (its own signature is inspected).

    Raises:
      TypeError: if `layer_func` is neither a `base.Layer` nor callable.
    """
    if isinstance(layer_func, base.Layer):
        arg_names = function_utils.fn_args(layer_func.call)
        self.track_layer(layer_func)
    elif callable(layer_func):
        arg_names = function_utils.fn_args(layer_func)
    else:
        raise TypeError(
            "Sequential.add() takes only tf.layers.Layer objects or callables; "
            "not '%s' of type '%s'." % (layer_func, type(layer_func)))
    # Each entry is (accepts_training_arg, callable).
    accepts_training = "training" in arg_names
    self._layers_funcs.append((accepts_training, layer_func))
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:11,代码来源:network.py
示例2: eval_step
def eval_step():
    """A single step of evaluation.

    Calls the model_fn in EVAL mode. It captures the spec's `scaffold_fn` and
    eval metric fn (or None when absent) and returns a tuple of the loss
    followed by the eval metric tensors.
    """
    estimator_spec = self._call_model_fn(features, labels,
                                         model_fn_lib.ModeKeys.EVAL, params)
    # The spec may not define `scaffold_fn` / `eval_metrics` (presumably a
    # plain, non-TPU EstimatorSpec), hence the AttributeError fallbacks.
    try:
        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
    except AttributeError:
        captured_scaffold_fn.capture(None)
    eval_metric_fn = None
    eval_metric_fn_tensors = []
    try:
        if estimator_spec.eval_metrics:
            (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
    except AttributeError:
        pass
    # If a dictionary is provided, we need to convert it into a list sorted
    # according to order of eval_metric_fn positional arguments.
    if isinstance(eval_metric_fn_tensors, dict):
        eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
        eval_metric_fn_tensors = [
            eval_metric_fn_tensors[arg_name] for arg_name in eval_metric_fn_args
        ]
    captured_eval_metric_fn.capture(eval_metric_fn)
    # list() also accepts tensors supplied as a tuple; `list + tuple` would
    # otherwise raise TypeError.
    return tuple([estimator_spec.loss] + list(eval_metric_fn_tensors))
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:29,代码来源:xla.py
示例3: call
def call(*args):
    """Runs method `name` of the proxied type through `tf.py_func`.

    Positional `args` are paired with the method's parameter names (minus the
    leading `self`) to look up output specs via `self._type._tensor_specs`.
    The resulting tensors receive the static shapes from those specs and are
    packed back into the specs' structure.

    Raises:
      ValueError: if `_tensor_specs` returns None for `name`.
    """
    # [1:] drops the method's first parameter (`self`).
    param_names = function_utils.fn_args(getattr(self._type, name))[1:]
    call_kwargs = dict(zip(param_names, args))
    specs = self._type._tensor_specs(name, call_kwargs, self._constructor_kwargs)
    if specs is None:
        raise ValueError(
            'No tensor specifications were provided for: %s' % name)
    flat_dtypes = nest.flatten(nest.map_structure(lambda spec: spec.dtype, specs))
    flat_shapes = nest.flatten(nest.map_structure(lambda spec: spec.shape, specs))

    def py_call(*py_args):
        # py_args is `(name,) + args`, as passed to tf.py_func below.
        try:
            self._out.send(py_args)
            reply = self._out.recv()
            if isinstance(reply, Exception):
                # Exceptions sent back over the connection are re-raised here.
                raise reply
            return reply
        except IOError:
            # NOTE(review): presumably raised when the connection is closed.
            raise StopIteration()  # Clean exit.

    outputs = tf.py_func(py_call, (name,) + tuple(args), flat_dtypes,
                         name=name)
    if isinstance(outputs, tf.Operation):
        return outputs
    for tensor, static_shape in zip(outputs, flat_shapes):
        tensor.set_shape(static_shape)
    return nest.pack_sequence_as(specs, outputs)
开发者ID:reinforcementdriving,项目名称:scalable_agent,代码行数:35,代码来源:py_process.py
示例4: run_step_fn
def run_step_fn(self, step_fn):
    """Runs ops using a step function.

    Args:
      step_fn: a function or method taking a single `StepContext` argument
        named `step_context` (plus `self` when it is a method). It may use the
        context to run computations with access to a raw session. Whatever
        `step_fn` returns is returned from `run_step_fn`, unless a stop is
        requested; in that case the next `should_stop` call returns True.

    Example usage:

    ```python
    with tf.Graph().as_default():
      c = tf.placeholder(dtypes.float32)
      v = tf.add(c, 4.0)
      w = tf.add(c, 0.5)
      def step_fn(step_context):
        a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
        if a <= 4.5:
          step_context.request_stop()
        return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

      with tf.MonitoredSession() as session:
        while not session.should_stop():
          a = session.run_step_fn(step_fn)
    ```

    Hooks interact with the `run_with_hooks()` call inside `step_fn` the same
    way they do with a `MonitoredSession.run` call.

    Returns:
      The value returned by `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`. It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn`'s arguments are neither `(step_context,)` nor
        `(self, step_context)`.
    """
    accepted_signatures = (('step_context',), ('self', 'step_context'))
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments not in accepted_signatures:
        raise ValueError(
            '`step_fn` may either have one `step_context` argument, or'
            ' `self` and `step_context` arguments if it\'s an instance'
            ' method. Got {} instead.'.format(step_fn_arguments))
    # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
    # Passing `run_with_hooks=None` makes `run_with_hooks` resolve to
    # `_CoordinatedSession.run` downstream in either case, which lets
    # `_PREEMPTION_ERRORS` propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:60,代码来源:monitored_session.py
示例5: test_bounded_method
def test_bounded_method(self):
    """fn_args on a bound method reports only the explicit parameters."""

    class Adder(object):

        def bar(self, a, b):
            return a + b

    bound_method = Adder().bar
    self.assertEqual(('a', 'b'), function_utils.fn_args(bound_method))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:function_utils_test.py
示例6: test_callable
def test_callable(self):
    """fn_args on a callable instance reports its `__call__` parameters."""

    class Adder(object):

        def __call__(self, a, b):
            return a + b

    callable_instance = Adder()
    self.assertEqual(('a', 'b'), function_utils.fn_args(callable_instance))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:function_utils_test.py
示例7: __init__
def __init__(self, type_, *constructor_args, **constructor_kwargs):
    """Records `type_` and its constructor arguments, then builds the proxy.

    Positional constructor arguments are converted to keyword arguments by
    pairing them with the parameter names of `type_.__init__` (excluding
    `self`). Explicit keyword arguments override them. The instance is added
    to the `PyProcess.COLLECTION` collection and a `_TFProxy` is created.
    """
    self._type = type_
    init_param_names = function_utils.fn_args(type_.__init__)[1:]
    merged_kwargs = dict(zip(init_param_names, constructor_args))
    merged_kwargs.update(constructor_kwargs)
    self._constructor_kwargs = merged_kwargs
    tf.add_to_collection(PyProcess.COLLECTION, self)
    self._proxy = _TFProxy(type_, self._constructor_kwargs)
开发者ID:reinforcementdriving,项目名称:scalable_agent,代码行数:9,代码来源:py_process.py
示例8: _get_standardized_predicate_fn
def _get_standardized_predicate_fn(predicate_fn):
    """Returns a predicate callable taking `(eval_results, checkpoint_path)`.

    If `predicate_fn` already declares a `checkpoint_path` argument, it is
    returned as-is. Otherwise it is wrapped so that the extra argument is
    accepted and ignored.
    """
    if "checkpoint_path" in function_utils.fn_args(predicate_fn):
        return predicate_fn

    # pylint: disable=unused-argument
    def _pred_fn_wrapper(eval_results, checkpoint_path):
        return predicate_fn(eval_results)

    return _pred_fn_wrapper
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:10,代码来源:experiment.py
示例9: _verify_estimator_spec
def _verify_estimator_spec(self, estimator_spec):
    """Verifies estimator spec contains correct data.

    Warns when `scaffold` is set and rejects `eval_metric_ops`. In EVAL mode,
    when `eval_metrics` is given with a dict of tensors, it checks that the
    dict keys match the metric fn's argument names exactly.

    Returns:
      `estimator_spec`, unchanged.

    Raises:
      ValueError: if `eval_metric_ops` is set, or the eval tensor dict is
        missing or has extra entries relative to the metric fn's arguments.
    """
    # TODO(ycao): Implement estimator spec verification for other modes.
    if getattr(estimator_spec, 'scaffold', None):
        logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
                        '. Please use TPUEstimatorSpec.scaffold_fn instead.')
    if getattr(estimator_spec, 'eval_metric_ops', None):
        raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
                         'XLA compilation. Please use '
                         'TPUEstimatorSpec.eval_metrics instead.')
    if estimator_spec.mode != model_fn_lib.ModeKeys.EVAL:
        return estimator_spec

    # Only specs carrying `eval_metrics` (e.g. TPUEstimatorSpec) are checked
    # further: the tensor dict must match the metric fn's arguments exactly.
    eval_metrics = getattr(estimator_spec, 'eval_metrics', None)
    if not eval_metrics:
        return estimator_spec
    eval_metric_fn, eval_metric_fn_tensors = eval_metrics
    eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
    if not isinstance(eval_metric_fn_tensors, dict):
        return estimator_spec

    missing_tensors = [
        arg for arg in eval_metric_fn_args if arg not in eval_metric_fn_tensors
    ]
    additional_tensors = [
        key for key in eval_metric_fn_tensors if key not in eval_metric_fn_args
    ]
    if missing_tensors:
        raise ValueError('Arguments %s are needed by metric_fn (first '
                         'element of TPUEstimatorSpec.eval_metrics) but '
                         'they are not provided by evaluation tensors '
                         '(second element of TPUEstimatorSpec.eval_metrics)'
                         '.' % missing_tensors)
    if additional_tensors:
        raise ValueError('Arguments %s are provided by evaluation tensors '
                         '(second element of TPUEstimatorSpec.eval_metrics)'
                         ' but they are not needed by metric_fn (first '
                         'element of TPUEstimatorSpec.eval_metrics).' %
                         additional_tensors)
    return estimator_spec
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:55,代码来源:xla.py
示例10: test_partial_function
def test_partial_function(self):
    """fn_args omits a keyword argument bound by functools.partial."""
    expected_test_arg = 123

    def fn(a, test_arg):
        if test_arg != expected_test_arg:
            # Raise (not return) so a wrongly bound argument fails the test.
            raise ValueError('partial fn does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, test_arg=expected_test_arg)
    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
    # Invoke the partial so the binding check inside `fn` actually runs.
    self.assertEqual(3, wrapped_fn(3))
    self.assertEqual(3, wrapped_fn(a=3))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:11,代码来源:function_utils_test.py
示例11: _call_metric_fn
def _call_metric_fn(metric_fn, features, labels, predictions, config):
    """Calls metric fn with proper arguments.

    Only the arguments whose names appear in `metric_fn`'s signature (per
    `function_utils.fn_args`) are passed, all as keyword arguments.
    """
    accepted = function_utils.fn_args(metric_fn)
    candidates = (
        ('features', features),
        ('labels', labels),
        ('predictions', predictions),
        ('config', config),
    )
    kwargs = {key: value for key, value in candidates if key in accepted}
    return metric_fn(**kwargs)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:13,代码来源:extenders.py
示例12: test_double_partial
def test_double_partial(self):
    """fn_args omits keyword arguments bound by two nested partials."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(a, test_arg1, test_arg2):
        if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
            # Raise (not return) so a wrongly bound argument fails the test.
            raise ValueError('partial does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(wrapped_fn,
                                          test_arg1=expected_test_arg1)
    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
    # Invoke the partial so the binding check inside `fn` actually runs.
    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:13,代码来源:function_utils_test.py
示例13: test_double_partial_with_positional_args_in_both_layers
def test_double_partial_with_positional_args_in_both_layers(self):
    """fn_args omits positional arguments bound by two nested partials."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, test_arg2, a):
        if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
            # Raise (not return) so a wrongly bound argument fails the test.
            raise ValueError('partial fn does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, expected_test_arg1)  # binds test_arg1
    double_wrapped_fn = functools.partial(
        wrapped_fn, expected_test_arg2)  # binds test_arg2
    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:16,代码来源:function_utils_test.py
示例14: _call_model_fn
def _call_model_fn(self, features, labels, mode, params):
    """Calls the model_fn with required parameters.

    `features` is always passed. `labels`, `mode` and `params` are passed as
    keywords only if `self._model_fn` declares them. The result is checked by
    `self._verify_estimator_spec`.

    Raises:
      ValueError: if `labels` is not None but model_fn takes no `labels`.
    """
    accepted_args = function_utils.fn_args(self._model_fn)
    optional_kwargs = {}
    if 'labels' not in accepted_args:
        if labels is not None:
            raise ValueError(
                'model_fn does not take labels, but input_fn returns labels.')
    else:
        optional_kwargs['labels'] = labels
    for key, value in (('mode', mode), ('params', params)):
        if key in accepted_args:
            optional_kwargs[key] = value
    spec = self._model_fn(features=features, **optional_kwargs)
    return self._verify_estimator_spec(spec)
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:18,代码来源:xla.py
示例15: call_logit_fn
def call_logit_fn(logit_fn, features, mode, params, config):
    """Calls logit_fn.

    A utility function that calls the provided logit_fn with the relevant subset
    of provided arguments. Similar to tf.estimator._call_model_fn().

    Args:
      logit_fn: A logit_fn as defined above.
      features: The features dict.
      mode: TRAIN / EVAL / PREDICT ModeKeys.
      params: The hyperparameter dict.
      config: The configuration object.

    Returns:
      A logit Tensor, the output of logit_fn.

    Raises:
      ValueError: if logit_fn does not return a Tensor or a dictionary mapping
        strings to Tensors.
    """
    logit_fn_args = function_utils.fn_args(logit_fn)
    kwargs = {}
    if 'mode' in logit_fn_args:
        kwargs['mode'] = mode
    if 'params' in logit_fn_args:
        kwargs['params'] = params
    if 'config' in logit_fn_args:
        kwargs['config'] = config
    logit_fn_results = logit_fn(features=features, **kwargs)

    result_is_valid_dictionary = (
        isinstance(logit_fn_results, dict) and
        all(isinstance(k, six.string_types) and isinstance(v, ops.Tensor)
            for k, v in six.iteritems(logit_fn_results)))
    result_is_tensor = isinstance(logit_fn_results, ops.Tensor)
    if not (result_is_valid_dictionary or result_is_tensor):
        # Wrap in a 1-tuple: if logit_fn returned a tuple, `%` would treat it as
        # the format arguments and raise TypeError (or misformat) instead.
        raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                         'strings to Tensors. logit_fn returned: %s' %
                         (logit_fn_results,))
    return logit_fn_results
开发者ID:AnishShah,项目名称:tensorflow,代码行数:42,代码来源:logit_fns.py
示例16: _validate_properties
def _validate_properties(run_config):
"""Validates the properties."""
def _validate(property_name, cond, message):
property_value = getattr(run_config, property_name)
if property_value is not None and not cond(property_value):
raise ValueError(message)
_validate('model_dir', lambda dir: dir,
message='model_dir should be non-empty')
_validate('save_summary_steps', lambda steps: steps >= 0,
message='save_summary_steps should be >= 0')
_validate('save_checkpoints_steps', lambda steps: steps >= 0,
message='save_checkpoints_steps should be >= 0')
_validate('save_checkpoints_secs', lambda secs: secs >= 0,
message='save_checkpoints_secs should be >= 0')
_validate('session_config',
lambda sc: isinstance(sc, config_pb2.ConfigProto),
message='session_config must be instance of ConfigProto')
_validate('keep_checkpoint_max', lambda keep_max: keep_max >= 0,
message='keep_checkpoint_max should be >= 0')
_validate('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0,
message='keep_checkpoint_every_n_hours should be > 0')
_validate('log_step_count_steps', lambda num_steps: num_steps > 0,
message='log_step_count_steps should be > 0')
_validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
message='tf_random_seed must be integer.')
_validate('device_fn', lambda device_fn: six.callable(device_fn) and
set(function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
message='device_fn must be callable with exactly'
' one argument "op".')
_validate('protocol',
lambda protocol: protocol in (None, "grpc", "grpc+verbs"),
message='protocol should be grpc or grpc+verbs')
开发者ID:AnishShah,项目名称:tensorflow,代码行数:40,代码来源:run_config.py
示例17: test_simple_function
def test_simple_function(self):
    """fn_args returns a plain function's parameter names in order."""

    def add_two(a, b):
        return a + b

    expected = ('a', 'b')
    self.assertEqual(expected, function_utils.fn_args(add_two))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:4,代码来源:function_utils_test.py
示例18: _verify_metric_fn_args
def _verify_metric_fn_args(metric_fn):
    """Checks that `metric_fn` declares only arguments in `_VALID_METRIC_FN_ARGS`.

    Raises:
      ValueError: if `metric_fn` has arguments outside `_VALID_METRIC_FN_ARGS`;
        the offending names are listed in sorted order.
    """
    args = set(function_utils.fn_args(metric_fn))
    # Sort so the error message does not depend on set iteration order
    # (string hashing is randomized per process).
    invalid_args = sorted(args - _VALID_METRIC_FN_ARGS)
    if invalid_args:
        raise ValueError('metric_fn (%s) has following not expected args: %s' %
                         (metric_fn, invalid_args))
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:6,代码来源:extenders.py
示例19: __call__
def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` came from a layer
        that generated a corresponding mask, i.e. a Keras layer with masking
        support). NOTE(review): this is handled by the base-class `__call__`,
        not by code in this method.

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value),
        if `scope` is passed while keras-style layers are enabled, or if the
        inputs' graph differs from the layer's graph.
    """
    # `scope` is removed from kwargs; it is re-added below only if `call`
    # declares a `scope` argument.
    scope = kwargs.pop('scope', None)

    if self._keras_style:
        if scope is not None:
            raise ValueError(
                'scope argument not allowed when keras style layers are enabled, '
                'but saw: {}'.format(scope))
        # Keras-style layers skip all variable-scope handling below.
        return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if not context.executing_eagerly():
        try:
            # Set layer's "graph" at build time
            self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                     graph=self._graph)
        except ValueError as e:
            raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    if self.built:
        try:
            # Some classes which inherit from Layer do not use its constructor, so
            # rather than initializing to None we check for an AttributeError.
            scope_context_manager = self._always_reuse_variable_scope
        except AttributeError:
            # From this point we will always set reuse=True, so create a "final"
            # variable scope with this setting. We avoid re-creating variable scopes
            # after this point as an optimization.
            self._always_reuse_variable_scope = vs.variable_scope(
                self._scope, reuse=True, auxiliary_name_scope=False)
            scope_context_manager = self._always_reuse_variable_scope
    else:
        # Not built yet: enter the scope with the layer's configured reuse.
        scope_context_manager = vs.variable_scope(
            self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
        self._current_scope = scope

        # Whether `call` accepts `scope` is computed once via fn_args and
        # cached on the instance (AttributeError means "not cached yet").
        try:
            call_has_scope_arg = self._call_has_scope_arg
        except AttributeError:
            self._call_fn_args = function_utils.fn_args(self.call)
            self._call_has_scope_arg = 'scope' in self._call_fn_args
            call_has_scope_arg = self._call_has_scope_arg
        if call_has_scope_arg:
            kwargs['scope'] = scope

        # Actually call layer
        outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
        # Update global default collections.
        _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:78,代码来源:base.py
示例20: _get_loss_towers
def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_devices,
                     loss_reduction,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
    """Replicate the loss computation across devices.

    Calls `model_fn` once per entry in `devices`. Each call gets its own
    variable scope, name scope, `TowerOptimizer` tower context and device
    setter. Each tower's loss is scaled via `_scale_tower_loss`.

    Args:
      model_fn: model function. It is always called with `mode`, `features`
        and `labels`. `params` / `config` are passed (deep-copied) only if it
        declares them.
      mode: passed through to `model_fn`.
      features: indexed by tower id (`features[i]`). NOTE(review): presumably
        a list of per-tower feature shards; confirm with the caller.
      labels: indexed by tower id when truthy; otherwise every tower gets
        `labels=None`.
      params: deep-copied for `model_fn` if it declares `params`.
      config: deep-copied for `model_fn` if it declares `config`.
      devices: one tower is built per device.
      local_ps_devices: parameter-server devices, assigned round-robin.
      loss_reduction: forwarded to `TowerOptimizer` and `_scale_tower_loss`.
      name_scope_pattern: format string receiving the tower index, used as
        the name scope of every tower except the first.

    Returns:
      A list with the (loss-scaled) spec returned by `model_fn` for each tower.

    Raises:
      ValueError: if a tower has a `train_op`, there is more than one device
        and `TowerOptimizer` was not used; or if the towers did not make the
        same optimizer calls.
    """
    tower_specs = []

    # Only pass optional arguments that model_fn actually declares.
    model_fn_args = function_utils.fn_args(model_fn)
    optional_params = {}
    if 'params' in model_fn_args:
        optional_params['params'] = copy.deepcopy(params)
    if 'config' in model_fn_args:
        optional_params['config'] = copy.deepcopy(config)

    # pylint: disable=protected-access
    round_robin_strategy = device_setter_lib._RoundRobinStrategy(
        num_tasks=len(local_ps_devices))
    TowerOptimizer._graph_state().set_reduction_across_towers(
        loss_reduction, len(devices))

    for i, device in enumerate(devices):
        is_the_first_tower = (i == 0)

        device_setter = _local_device_setter(
            worker_device=device,
            ps_devices=local_ps_devices,
            ps_strategy=round_robin_strategy)

        # We would like to preserve the names of the variables and ops that the user
        # might be relying on. Names without a prefix are going to resolve to
        # variables and ops of the first tower.
        name_scope = name_scope_pattern
        if is_the_first_tower:
            name_scope = ''

        # The first tower creates variables (reuse=False); later towers reuse them.
        with variable_scope.variable_scope(
            '', reuse=not is_the_first_tower) as var_scope:
            # `name_scope` is rebound here to the scope string yielded by
            # ops_lib.name_scope; it is reset from the pattern each iteration.
            with ops_lib.name_scope(name_scope.format(i)) as name_scope:
                with TowerOptimizer._graph_state().tower(
                    tower_id=i, var_scope=var_scope, name_scope=name_scope):
                    with ops_lib.device(device_setter):
                        labels_shard = None
                        if labels:
                            labels_shard = labels[i]

                        tower_spec = model_fn(
                            mode=mode,
                            features=features[i],
                            labels=labels_shard,
                            **optional_params)

                        if (tower_spec.train_op is not None and len(devices) > 1 and
                                not TowerOptimizer.has_been_used()):
                            raise ValueError('Please wrap optimizers with TowerOptimizer'
                                             ' in order to use replicate_model_fn with'
                                             ' multiple `devices`.')

                        # Scaling the loss here doesn't actually affect gradients. Another
                        # instance of scaling happens inside the TowerOptimizer.
                        tower_spec = _scale_tower_loss(
                            tower_spec, loss_reduction, number_of_towers=len(devices))
                        tower_specs.append(tower_spec)

    if not TowerOptimizer._did_towers_have_same_optimizer_calls():
        raise ValueError('Each invocation of model_fn was supposed to make the same'
                         ' optimizer calls.')
    TowerOptimizer._clear_graph_state()
    # pylint: enable=protected-access
    return tower_specs
开发者ID:AnishShah,项目名称:tensorflow,代码行数:75,代码来源:replicate_model_fn.py
注:本文中的tensorflow.python.util.function_utils.fn_args函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论