本文整理汇总了Python中tensorflow.python.estimator.util.fn_args函数的典型用法代码示例。如果您正苦于以下问题:Python fn_args函数的具体用法?Python fn_args怎么用?Python fn_args使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了fn_args函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: add
def add(self, layer_func):
  """Appends a layer or a plain callable to this sequential container.

  For a `base.Layer`, the argument names of its `call` method are inspected
  and the layer is registered through `self.track_layer`; for any other
  callable, its own argument names are inspected. Together with the callable,
  a flag recording whether it declares a `training` argument is stored.

  Args:
    layer_func: A `base.Layer` instance or any other callable.

  Raises:
    TypeError: if `layer_func` is neither a `base.Layer` nor callable.
  """
  is_layer = isinstance(layer_func, base.Layer)
  if not is_layer and not callable(layer_func):
    raise TypeError(
        "Sequential.add() takes only tf.layers.Layer objects or callables; "
        "not '%s' of type '%s'." % (layer_func, type(layer_func)))

  if is_layer:
    # For layers, the relevant signature is that of `call`.
    arg_names = estimator_util.fn_args(layer_func.call)
    self.track_layer(layer_func)
  else:
    arg_names = estimator_util.fn_args(layer_func)

  takes_training = "training" in arg_names
  self._layers_funcs.append((takes_training, layer_func))
开发者ID:bikong2,项目名称:tensorflow,代码行数:11,代码来源:network.py
示例2: _call_model_fn
def _call_model_fn(self, features, labels, add_batch_size_in_params=False):
  """Calls the model_fn, forwarding only the arguments it declares.

  `config` and `params` are deep copies of `self._config` and `self._params`,
  so model_fn never receives this wrapper's own objects.

  Args:
    features: Passed to model_fn as the `features` keyword argument.
    labels: Passed as `labels` when model_fn declares it; otherwise it must
      be None.
    add_batch_size_in_params: If True, model_fn must declare `params`, and in
      TRAIN mode the per-shard batch size is stored in
      `params[_BATCH_SIZE_KEY]`.

  Returns:
    Whatever model_fn returns.

  Raises:
    ValueError: if `labels` is not None but model_fn does not take `labels`,
      or if `add_batch_size_in_params` is set but model_fn does not take
      `params`.
  """
  accepted = util.fn_args(self._model_fn)
  config = copy.deepcopy(self._config)
  params = copy.deepcopy(self._params)

  if labels is not None and 'labels' not in accepted:
    raise ValueError(
        'model_fn does not take labels, but input_fn returns labels.')

  offered = (
      ('labels', labels),
      ('mode', self._mode),
      ('config', config),
      ('params', params),
  )
  kwargs = {name: value for name, value in offered if name in accepted}

  if add_batch_size_in_params:
    if 'params' not in accepted:
      raise ValueError(
          'model_fn ({}) does not include params argument, '
          'required by TPUEstimator to pass batch size as '
          'params[\'batch_size\']'.format(self._model_fn))
    if self._mode == model_fn_lib.ModeKeys.TRAIN:
      # For TPU training. `params` is never `None`. `kwargs['params']` is this
      # same object, so the batch size is visible to model_fn.
      params[_BATCH_SIZE_KEY] = _per_shard_batch_size(
          self._train_batch_size, config)

  return self._model_fn(features=features, **kwargs)
开发者ID:awisbith,项目名称:tensorflow,代码行数:33,代码来源:tpu_estimator.py
示例3: _call_loss_fn
def _call_loss_fn(loss_fn, labels, logits, features):
  """Invokes `loss_fn` and guards its output with a runtime shape check.

  Args:
    loss_fn: The loss function. It is always called with `labels` and
      `logits`; `features` is forwarded only if its signature names it.
    labels: Processed labels Tensor.
    logits: Logits Tensor of shape [batch_size, logits_dimension].
    features: Features dict.

  Returns:
    The loss Tensor, wrapped in an identity op that depends on an assertion
    that its shape equals [batch_size, 1].
  """
  if 'features' in util.fn_args(loss_fn):
    extra_kwargs = {'features': features}
  else:
    extra_kwargs = {}
  unweighted_loss = loss_fn(labels=labels, logits=logits, **extra_kwargs)

  batch_size = array_ops.shape(logits)[0]
  loss_shape = array_ops.shape(unweighted_loss)
  expected_shape = [batch_size, 1]
  shape_matches = math_ops.reduce_all(
      math_ops.equal(loss_shape, expected_shape))
  check_shape_op = control_flow_ops.Assert(
      shape_matches,
      data=['loss_fn must return Tensor of shape [batch_size, 1]. Given: ',
            loss_shape])

  # Tie the assertion to the returned tensor so it runs whenever loss does.
  with ops.control_dependencies([check_shape_op]):
    return array_ops.identity(unweighted_loss)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:25,代码来源:head.py
示例4: export
def export(self, estimator, export_path, checkpoint_path=None,
           eval_result=None):
  """Exports `estimator` by delegating to `self.export_fn`.

  `checkpoint_path` and `eval_result` are forwarded only when export_fn
  declares them.

  Args:
    estimator: the Estimator to export.
    export_path: A string containing a directory where to write the export.
    checkpoint_path: The checkpoint path to export. If None (the default),
      the strategy may locate a checkpoint (e.g. the most recent) by itself.
    eval_result: The output of Estimator.evaluate on this checkpoint. This
      should be set only if checkpoint_path is provided (otherwise it is
      unclear which checkpoint this eval refers to).

  Returns:
    Whatever export_fn returns (documented as the string path to the
    exported directory).

  Raises:
    ValueError: if export_fn declares `eval_result` without also declaring
      `checkpoint_path`.
  """
  accepted = util.fn_args(self.export_fn)
  takes_checkpoint = 'checkpoint_path' in accepted
  takes_eval_result = 'eval_result' in accepted

  if takes_eval_result and not takes_checkpoint:
    raise ValueError('An export_fn accepting eval_result must also accept '
                     'checkpoint_path.')

  kwargs = {}
  if takes_checkpoint:
    kwargs['checkpoint_path'] = checkpoint_path
  if takes_eval_result:
    kwargs['eval_result'] = eval_result
  return self.export_fn(estimator, export_path, **kwargs)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:33,代码来源:export_strategy.py
示例5: _call_model_fn
def _call_model_fn(self, features, labels, mode, config):
  """Invokes `self._model_fn` with only the arguments it declares.

  Args:
    features: features dict, always passed as `features`.
    labels: labels dict; must be None if model_fn has no `labels` argument.
    mode: ModeKeys
    config: RunConfig

  Returns:
    The `EstimatorSpec` produced by model_fn.

  Raises:
    ValueError: if labels are given but model_fn does not take them, or if
      model_fn returns something other than an `EstimatorSpec`.
  """
  declared = util.fn_args(self._model_fn)
  if labels is not None and 'labels' not in declared:
    raise ValueError(
        'model_fn does not take labels, but input_fn returns labels.')

  kwargs = {}
  if 'labels' in declared:
    kwargs['labels'] = labels
  if 'mode' in declared:
    kwargs['mode'] = mode
  if 'params' in declared:
    # Read only on demand; `self.params` is evaluated once, as before.
    kwargs['params'] = self.params
  if 'config' in declared:
    kwargs['config'] = config

  spec = self._model_fn(features=features, **kwargs)
  if isinstance(spec, model_fn_lib.EstimatorSpec):
    return spec
  raise ValueError('model_fn should return an EstimatorSpec.')
开发者ID:ilya-edrenkin,项目名称:tensorflow,代码行数:35,代码来源:estimator.py
示例6: call_logit_fn
def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn with the subset of arguments it declares.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments. Similar to tf.estimator._call_model_fn().
  `features` is always passed; `mode`, `params` and `config` are forwarded
  only when named in logit_fn's signature.

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor.
  """
  logit_fn_args = util.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)
  if not isinstance(logit_fn_results, ops.Tensor):
    # Name the function actually being validated (this used to say
    # "model_fn", pointing users at the wrong callable).
    raise ValueError('logit_fn should return a Tensor. Got: {}'.format(
        type(logit_fn_results).__name__))
  return logit_fn_results
开发者ID:1000sprites,项目名称:tensorflow,代码行数:33,代码来源:logit_fns.py
示例7: _call_input_fn
def _call_input_fn(self, input_fn, mode):
  """Invokes `input_fn` on '/cpu:0' with only the arguments it declares.

  Args:
    input_fn: The input function. Any of `mode`, `params` and `config` that
      appear in its signature are passed as keyword arguments.
    mode: ModeKeys

  Returns:
    Either features or (features, labels) where features and labels are:
      features - `Tensor` or dictionary of string feature name to `Tensor`.
      labels - `Tensor` or dictionary of `Tensor` with labels.

  Raises:
    ValueError: documented for invalid input_fn arguments; this body does not
      raise it directly (NOTE(review): presumably from util.fn_args or
      input_fn itself — confirm).
  """
  declared = util.fn_args(input_fn)
  # Each value is produced lazily, so `self.params` / `self.config` are read
  # only when input_fn actually declares them.
  providers = (
      ('mode', lambda: mode),
      ('params', lambda: self.params),
      ('config', lambda: self.config),
  )
  kwargs = {name: produce() for name, produce in providers
            if name in declared}
  with ops.device('/cpu:0'):
    return input_fn(**kwargs)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:25,代码来源:estimator.py
示例8: run_step_fn
def run_step_fn(self, step_fn):
  """Runs `step_fn`, which gets raw-session access through a `StepContext`.

  Args:
    step_fn: A function whose only argument is `step_context`, or an instance
      method taking `self` and `step_context`. It may use the methods of its
      `StepContext` argument to run computations on a raw session. Its return
      value is returned from `run_step_fn`, unless a stop is requested; in
      that case the next `should_stop` call will return True.

  Example usage:

  ```python
     with tf.Graph().as_default():
       c = tf.placeholder(dtypes.float32)
       v = tf.add(c, 4.0)
       w = tf.add(c, 0.5)

       def step_fn(step_context):
         a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
         if a <= 4.5:
           step_context.request_stop()
         return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

       with tf.MonitoredSession() as session:
         while not session.should_stop():
           a = session.run_step_fn(step_fn)
  ```

  Hooks interact with the `run_with_hooks()` call inside `step_fn` the same
  way they do with a `MonitoredSession.run` call.

  Returns:
    The value returned by `step_fn`.

  Raises:
    StopIteration: if `step_fn` has called `request_stop()`. It may be
      caught by `with tf.MonitoredSession()` to close the session.
    ValueError: if the argument names of `step_fn` are neither
      (`step_context`,) nor (`self`, `step_context`).
  """
  step_fn_arguments = util.fn_args(step_fn)
  valid_signatures = (('step_context',), ('self', 'step_context'))
  if step_fn_arguments not in valid_signatures:
    raise ValueError(
        '`step_fn` may either have one `step_context` argument, or'
        ' `self` and `step_context` arguments if it\'s an instance'
        ' method. Got {} instead.'.format(step_fn_arguments))

  # `self._sess` is either a `_RecoverableSession` or a `_CoordinatedSession`.
  # Passing `run_with_hooks=None` makes `run_with_hooks` resolve to
  # `_CoordinatedSession.run` downstream in both cases, which lets
  # `_PREEMPTION_ERRORS` propagate from within `step_fn` up to
  # `_RecoverableSession.run_step_fn`.
  return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
开发者ID:SylChan,项目名称:tensorflow,代码行数:60,代码来源:monitored_session.py
示例9: test_callable
def test_callable(self):
  """fn_args on a callable object reports `__call__`'s args, incl. `self`."""

  class Adder(object):

    def __call__(self, a, b):
      return a + b

  expected = ('self', 'a', 'b')
  self.assertEqual(expected, util.fn_args(Adder()))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:8,代码来源:util_test.py
示例10: test_bounded_method
def test_bounded_method(self):
  """fn_args on a bound method omits the implicit `self`."""

  class Adder(object):

    def bar(self, a, b):
      return a + b

  bound_method = Adder().bar
  self.assertEqual(('a', 'b'), util.fn_args(bound_method))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:util_test.py
示例11: _call_input_fn
def _call_input_fn(self, input_fn, mode):
  """Calls the input function, placing it on the appropriate host CPU(s).

  Args:
    input_fn: The input function. It must declare `params`; in TRAIN mode the
      batch size is passed to it as `params[_BATCH_SIZE_KEY]`.
    mode: ModeKeys

  Returns:
    Either features or (features, labels) where features and labels are:
      features - `Tensor` or dictionary of string feature name to `Tensor`.
      labels - `Tensor` or dictionary of `Tensor` with labels.

  Raises:
    ValueError: if input_fn takes invalid arguments or does not have `params`.
  """
  input_fn_args = util.fn_args(input_fn)
  config = self.config  # a deep copy.
  kwargs = {}
  if 'params' in input_fn_args:
    kwargs['params'] = self.params  # a deep copy.
  else:
    raise ValueError('input_fn ({}) does not include params argument, '
                     'required by TPUEstimator to pass batch size as '
                     'params["batch_size"]'.format(input_fn))
  if 'config' in input_fn_args:
    kwargs['config'] = config

  # Now for TPU training.
  if mode == model_fn_lib.ModeKeys.TRAIN:
    if config.tpu_config.per_host_input_for_training:
      batch_size = self._train_batch_size
    else:
      batch_size = _per_shard_batch_size(
          self._train_batch_size, config, self._use_tpu)
    kwargs['params'][_BATCH_SIZE_KEY] = batch_size

  if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
    with ops.device('/cpu:0'):
      return input_fn(**kwargs)

  job = _tpu_job(config)

  def placement_function(index):
    """Returns the host CPU device string for shard `index`."""
    if job is None:
      return '/replica:0/task:0/device:CPU:0'
    # Consecutive groups of 8 shards share one task (presumably the number
    # of TPU cores per host — confirm). Floor division states the intended
    # integer task id instead of relying on `%d` truncating a float.
    return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index // 8)

  if not config.tpu_config.per_host_input_for_training:
    num_shards = config.tpu_config.num_shards
    inputs = _InputsHolder(num_shards=num_shards)
    for i in range(num_shards):
      with ops.device(placement_function(i)):
        inputs.append_tuple(input_fn(**kwargs))
    return inputs.as_features_and_labels_tuple()

  # TODO(xiejw): Extend this to multi-host support.
  with ops.device(placement_function(0)):
    return input_fn(**kwargs)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:57,代码来源:tpu_estimator.py
示例12: _call_input_fn
def _call_input_fn(self, input_fn, mode):
  """Calls the input function once per TPU shard when training on TPU.

  Outside TPU training, this defers to the base class implementation.

  Args:
    input_fn: The input function. It must declare `params`; the per-shard
      batch size is passed to it as `params[_BATCH_SIZE_KEY]`.
    mode: ModeKeys

  Returns:
    Either features or (features, labels) where features and labels are:
      features - `Tensor` or dictionary of string feature name to `Tensor`.
      labels - `Tensor` or dictionary of `Tensor` with labels.
    For TPU training, features (and labels, unless every shard returned
    none) are wrapped in `_PerShardOutput`.

  Raises:
    ValueError: if input_fn takes invalid arguments or does not have `params`.
  """
  if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
    return super(TpuEstimator, self)._call_input_fn(input_fn, mode)

  input_fn_args = util.fn_args(input_fn)
  config = self.config  # a deep copy.
  kwargs = {}
  if 'params' in input_fn_args:
    kwargs['params'] = self.params  # a deep copy.
  else:
    raise ValueError('input_fn ({}) does not include params argument, '
                     'required by TPUEstimator to pass batch size as '
                     'params["batch_size"]'.format(input_fn))
  if 'config' in input_fn_args:
    kwargs['config'] = config

  # Now for TPU training.
  per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
  kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size
  job = _tpu_job(config)

  def placement_function(index):
    """Returns the host CPU device string for shard `index`."""
    if job is None:
      return '/replica:0/task:0/device:CPU:0'
    # Consecutive groups of 8 shards share one task (presumably the number
    # of TPU cores per host — confirm). Floor division states the intended
    # integer task id instead of relying on `%d` truncating a float.
    return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index // 8)

  features = []
  labels = []
  for i in range(config.tpu_config.num_shards):
    with ops.device(placement_function(i)):
      result = input_fn(**kwargs)
      # input_fn may return either features or (features, labels)
      if isinstance(result, tuple):
        features.append(result[0])
        labels.append(result[1])
      else:
        features.append(result)

  if not labels or all(label is None for label in labels):
    return _PerShardOutput(features), None
  return _PerShardOutput(features), _PerShardOutput(labels)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:57,代码来源:tpu_estimator.py
示例13: _verify_metric_fn_args
def _verify_metric_fn_args(metric_fn):
  """Checks that `metric_fn` only declares arguments in `_VALID_METRIC_FN_ARGS`.

  For methods, a `self` argument is ignored.

  Raises:
    ValueError: if `metric_fn` declares any other argument.
  """
  arg_names = set(estimator_util.fn_args(metric_fn))
  if tf_inspect.ismethod(metric_fn):
    arg_names.discard('self')
  unexpected = list(arg_names - _VALID_METRIC_FN_ARGS)
  if not unexpected:
    return
  raise ValueError('metric_fn (%s) has following not expected args: %s' %
                   (metric_fn, unexpected))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:9,代码来源:extenders.py
示例14: _get_standardized_predicate_fn
def _get_standardized_predicate_fn(predicate_fn):
  """Returns a predicate callable as `fn(eval_results, checkpoint_path)`.

  If `predicate_fn` already declares `checkpoint_path` it is returned as-is;
  otherwise it is wrapped so that `checkpoint_path` is accepted and ignored.
  """
  if "checkpoint_path" in estimator_util.fn_args(predicate_fn):
    return predicate_fn

  # pylint: disable=unused-argument
  def _pred_fn_wrapper(eval_results, checkpoint_path):
    return predicate_fn(eval_results)

  return _pred_fn_wrapper
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:10,代码来源:experiment.py
示例15: test_partial_function
def test_partial_function(self):
  """fn_args on a keyword-bound partial reports only the unbound args."""
  expected_test_arg = 123

  def fn(a, test_arg):
    if test_arg != expected_test_arg:
      # Must raise: returning the exception object would never fail the test.
      raise ValueError('partial fn does not work correctly')
    return a

  wrapped_fn = functools.partial(fn, test_arg=expected_test_arg)
  self.assertEqual(('a',), util.fn_args(wrapped_fn))
  # Actually invoke the partial so the check inside `fn` is exercised.
  self.assertEqual(3, wrapped_fn(3))
  self.assertEqual(3, wrapped_fn(a=3))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:11,代码来源:util_test.py
示例16: _call_metric_fn
def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Calls metric fn with proper arguments.

  Of `features`, `labels`, `predictions` and `config`, only those named in
  metric_fn's signature are passed, as keyword arguments.

  Returns:
    Whatever metric_fn returns.
  """
  accepted = estimator_util.fn_args(metric_fn)
  available = {
      'features': features,
      'labels': labels,
      'predictions': predictions,
      'config': config,
  }
  kwargs = {name: value for name, value in available.items()
            if name in accepted}
  return metric_fn(**kwargs)
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:13,代码来源:extenders.py
示例17: test_double_partial
def test_double_partial(self):
  """fn_args on a partial of a partial reports only the unbound args."""
  expected_test_arg1 = 123
  expected_test_arg2 = 456

  def fn(a, test_arg1, test_arg2):
    if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
      # Must raise: returning the exception object would never fail the test.
      raise ValueError('partial does not work correctly')
    return a

  wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
  double_wrapped_fn = functools.partial(
      wrapped_fn, test_arg1=expected_test_arg1)
  self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
  # Actually invoke the partial so the check inside `fn` is exercised.
  self.assertEqual(3, double_wrapped_fn(3))
  self.assertEqual(3, double_wrapped_fn(a=3))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:13,代码来源:util_test.py
示例18: _verify_compare_fn_args
def _verify_compare_fn_args(compare_fn):
  """Verifies compare_fn arguments.

  compare_fn must declare exactly `best_eval_result` and
  `current_eval_result`.

  Raises:
    ValueError: if either required argument is missing, or if any other
      argument is declared.
  """
  declared = set(util.fn_args(compare_fn))
  required = ('best_eval_result', 'current_eval_result')
  missing_messages = (
      'compare_fn (%s) must include best_eval_result argument.',
      'compare_fn (%s) must include current_eval_result argument.',
  )
  for name, message in zip(required, missing_messages):
    if name not in declared:
      raise ValueError(message % compare_fn)

  extra = list(declared - set(required))
  if extra:
    raise ValueError('compare_fn (%s) has following not expected args: %s' %
                     (compare_fn, extra))
开发者ID:jinxin0924,项目名称:tensorflow,代码行数:14,代码来源:exporter.py
示例19: test_partial_function_with_positional_args
def test_partial_function_with_positional_args(self):
  """fn_args on a positionally-bound partial drops the bound leading arg."""
  expected_test_arg = 123

  def fn(test_arg, a):
    if test_arg != expected_test_arg:
      # Must raise: returning the exception object would hide the failure
      # behind an unrelated assertEqual mismatch.
      raise ValueError('partial fn does not work correctly')
    return a

  wrapped_fn = functools.partial(fn, expected_test_arg)
  self.assertEqual(('a',), util.fn_args(wrapped_fn))
  self.assertEqual(3, wrapped_fn(3))
  self.assertEqual(3, wrapped_fn(a=3))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:14,代码来源:util_test.py
示例20: test_double_partial_with_positional_args_in_outer_layer
def test_double_partial_with_positional_args_in_outer_layer(self):
  """fn_args handles a positional outer partial over a keyword inner one."""
  expected_test_arg1 = 123
  expected_test_arg2 = 456

  def fn(test_arg1, a, test_arg2):
    if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
      # Must raise: returning the exception object would hide the failure
      # behind an unrelated assertEqual mismatch.
      raise ValueError('partial fn does not work correctly')
    return a

  wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
  double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg1)
  self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
  self.assertEqual(3, double_wrapped_fn(3))
  self.assertEqual(3, double_wrapped_fn(a=3))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:16,代码来源:util_test.py
注:本文中的tensorflow.python.estimator.util.fn_args函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论