• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python tf_decorator.unwrap函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.util.tf_decorator.unwrap函数的典型用法代码示例。如果您正苦于以下问题：Python unwrap函数的具体用法？Python unwrap怎么用？Python unwrap使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了unwrap函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: __call__

  def __call__(self, func):
    """Records this decorator's API names on the symbol underneath `func`.

    Args:
      func: decorated symbol (function or class).

    Returns:
      `func` itself, unchanged; the names are stored on the undecorated
      symbol found by `tf_decorator.unwrap`.

    Raises:
      SymbolAlreadyExposedError: if the undecorated symbol already carries
        the API names attribute in its own `__dict__`.
    """
    names_attr = API_ATTRS[self._api_name].names

    # Overridden symbols give up the names that were recorded on them.
    for overridden in self._overrides:
      overridden_target = tf_decorator.unwrap(overridden)[1]
      delattr(overridden_target, names_attr)

    target = tf_decorator.unwrap(func)[1]

    # Membership in __dict__ (rather than hasattr) makes sure a subclass
    # that merely inherits the attribute from its parent is not treated as
    # already exported.
    if names_attr in target.__dict__:
      symbol_name = target.__name__
      existing_names = getattr(target, names_attr)
      raise SymbolAlreadyExposedError(
          'Symbol %s is already exposed as %s.' %
          (symbol_name, existing_names))

    setattr(target, names_attr, self._names)
    return func
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:32,代码来源:tf_export.py


示例2: __call__

  def __call__(self, func):
    """Stores this decorator's names as `_tf_api_names` on the wrapped symbol.

    Args:
      func: decorated symbol (function or class).

    Returns:
      `func` itself, unchanged; the attribute is written to the undecorated
      symbol found by `tf_decorator.unwrap`.

    Raises:
      SymbolAlreadyExposedError: if the undecorated symbol already has its
        own `_tf_api_names` entry.
    """
    # Overridden symbols lose their previously recorded names.
    for overridden in self._overrides:
      overridden_target = tf_decorator.unwrap(overridden)[1]
      delattr(overridden_target, '_tf_api_names')

    target = tf_decorator.unwrap(func)[1]

    # Look in the symbol's own __dict__ instead of using hasattr so that a
    # subclass inheriting _tf_api_names from its parent can still be
    # exported under names of its own.
    if '_tf_api_names' in target.__dict__:
      symbol_name = target.__name__
      existing_names = getattr(target, '_tf_api_names')
      raise SymbolAlreadyExposedError(
          'Symbol %s is already exposed as %s.' %
          (symbol_name, existing_names))

    # Finish the export by creating (or replacing) the attribute.
    setattr(target, '_tf_api_names', self._names)
    return func
开发者ID:keveman,项目名称:tensorflow,代码行数:34,代码来源:tf_export.py


示例3: __call__

  def __call__(self, func):
    """Records the current and V1 API names on the symbol under `func`.

    Args:
      func: decorated symbol (function or class).

    Returns:
      `func` itself, unchanged; both name attributes are written to the
      undecorated symbol through `self.set_attr`.

    Raises:
      SymbolAlreadyExposedError: presumably raised by `self.set_attr` when
        the symbol is already exported -- that helper is not visible here,
        confirm against its definition.
    """
    names_attr = API_ATTRS[self._api_name].names
    names_attr_v1 = API_ATTRS_V1[self._api_name].names

    # Overridden symbols drop both the current and the V1 name attributes.
    for overridden in self._overrides:
      overridden_target = tf_decorator.unwrap(overridden)[1]
      for stale_attr in (names_attr, names_attr_v1):
        delattr(overridden_target, stale_attr)

    target = tf_decorator.unwrap(func)[1]
    self.set_attr(target, names_attr, self._names)
    self.set_attr(target, names_attr_v1, self._names_v1)
    return func
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:25,代码来源:tf_export.py


示例4: _op_is_in_tf_version

def _op_is_in_tf_version(op, version):
  """Reports whether `op` is exported in the given TensorFlow major version.

  Args:
    op: The op to look up; TF decorators are unwrapped before the lookup.
    version: 1 or 2.

  Returns:
    For version 2, the result of `tf_export.get_v2_names` on the unwrapped
    op. For version 1, the result of `tf_export.get_v1_names` if it is
    truthy, otherwise whether `op` is in `_V1_OPS_THAT_DELEGATE_TO_V2_OPS`.

  Raises:
    ValueError: if `version` is neither 1 nor 2.
  """
  # Reject unsupported versions before touching the op at all.
  if version not in (1, 2):
    raise ValueError('Expected version 1 or 2.')

  undecorated_op = tf_decorator.unwrap(op)[1]
  if version == 2:
    return tf_export.get_v2_names(undecorated_op)
  return (tf_export.get_v1_names(undecorated_op) or
          op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:8,代码来源:ragged_dispatch.py


示例5: fn_args

def fn_args(fn):
  """Get argument names for function-like object.

  TF decorators are unwrapped first. Three shapes of object are handled:
  objects whose `__call__` is a method, partial-like objects (anything with
  `func`, `keywords` and `args` attributes), and plain functions.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names. For partial-like objects the names
    consumed by the positionally bound `args` and the names present in
    `keywords` are left out.
  """
  _, fn = tf_decorator.unwrap(fn)

  # Handle callables.
  if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
    return tuple(tf_inspect.getargspec(fn.__call__).args)

  # Handle functools.partial and similar objects.
  if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
    # Handle nested partial.
    original_args = fn_args(fn.func)
    if not original_args:
      return tuple()

    # Positionally bound arguments consume the leading names.
    remaining_args = original_args[len(fn.args):]
    # Build the keyword lookup once instead of once per candidate name.
    bound_keywords = frozenset(fn.keywords or {})
    return tuple(arg for arg in remaining_args if arg not in bound_keywords)

  # Handle function.
  return tuple(tf_inspect.getargspec(fn).args)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:32,代码来源:util.py


示例6: testUnwrapBoundMethods

 def testUnwrapBoundMethods(self):
   # Through the decorator the first argument comes back incremented.
   instance = TestDecoratedClass()
   self.assertEqual([2, 2, 3], instance.return_params(1, 2, 3))

   # Unwrapping the bound method exposes the decorator chain and the raw
   # target, which takes the instance explicitly and leaves the arguments
   # untouched.
   decorator_chain, raw_target = tf_decorator.unwrap(instance.return_params)
   outermost = decorator_chain[0]
   self.assertEqual('test_decorator_increment_first_int_arg',
                    outermost.decorator_name)
   self.assertEqual([1, 2, 3], raw_target(instance, 1, 2, 3))
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:7,代码来源:tf_decorator_test.py


示例7: testUnwrapReturnsDecoratorListFromOutermostToInnermost

 def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self):
   decorator_chain = tf_decorator.unwrap(test_decorated_function)[0]

   # Index 0 is the outermost decorator, index 2 the innermost.
   expected_names = (
       'decorator 1',
       'test_decorator_increment_first_int_arg',
       'decorator 3',
   )
   for position, expected_name in enumerate(expected_names):
     self.assertEqual(expected_name,
                      decorator_chain[position].decorator_name)

   innermost = decorator_chain[2]
   self.assertEqual('decorator 3 documentation', innermost.decorator_doc)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:7,代码来源:tf_decorator_test.py


示例8: visit

 def visit(unused_path, unused_parent, children):
   """Adds the TF 2.0 API names of every child symbol to `v2_names`."""
   for child in children:
     # child[1] is the symbol itself; strip TF decorators before the lookup.
     symbol = tf_decorator.unwrap(child[1])[1]
     for exported_name in tf_export.get_v2_names(symbol):
       v2_names.add(exported_name)
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:generate_v2_renames_map.py


示例9: testReorderFileNeedsUpdate

  def testReorderFileNeedsUpdate(self):
    reordered_function_names = (
        tf_upgrade_v2.TFAPIChangeSpec().reordered_function_names)
    function_reorders = (
        tf_upgrade_v2.TFAPIChangeSpec().function_reorders)

    added_names_message = """Some function names in
self.reordered_function_names are not in reorders_v2.py.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    removed_names_message = """%s in self.reorders_v2 does not match
any name in self.reordered_function_names.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    self.assertTrue(
        reordered_function_names.issubset(function_reorders),
        added_names_message)
    # function_reorders should contain reordered_function_names
    # and their TensorFlow V1 aliases.
    for name in function_reorders:
      # get other names for this function
      attr = get_symbol_for_name(tf.compat.v1, name)
      _, attr = tf_decorator.unwrap(attr)
      v1_names = tf_export.get_v1_names(attr)
      self.assertTrue(v1_names)
      v1_names = ["tf.%s" % n for n in v1_names]
      # check if any other name is in
      self.assertTrue(
          any(n in reordered_function_names for n in v1_names),
          removed_names_message % name)
开发者ID:rmlarsen,项目名称:tensorflow,代码行数:34,代码来源:tf_upgrade_v2_test.py


示例10: getfullargspec

def getfullargspec(obj):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for `inspect.getfullargspec`/`getargspec`.

  This wrapper uses `inspect.getfullargspec` if available and falls back to
  `inspect.getargspec` in Python 2.

  Args:
    obj: A callable, possibly decorated.

  Returns:
    The `decorator_argspec` of the outermost decorator that defines one,
    returned exactly as stored on that decorator. If no decorator defines
    one (or the callable is not decorated), the `FullArgSpec` of the
    undecorated target.
  """
  decorators, target = tf_decorator.unwrap(obj)

  # `decorators` is ordered outermost first, so the first decorator that
  # overrides the signature wins.
  for decorator in decorators:
    if decorator.decorator_argspec is not None:
      return decorator.decorator_argspec

  # Inspect the undecorated target only when no decorator supplied a spec.
  # Doing it unconditionally wastes work and can raise for targets that
  # `inspect` cannot introspect, even though the answer is already known.
  if not six.PY2:
    return _inspect.getfullargspec(target)

  # Python 2 has no getfullargspec; widen the ArgSpec into a FullArgSpec
  # with empty keyword-only and annotation fields.
  argspecs = _inspect.getargspec(target)
  return FullArgSpec(
      args=argspecs.args,
      varargs=argspecs.varargs,
      varkw=argspecs.keywords,
      defaults=argspecs.defaults,
      kwonlyargs=[],
      kwonlydefaults=None,
      annotations={})
开发者ID:moses-sun,项目名称:tensorflow,代码行数:33,代码来源:tf_inspect.py


示例11: get_canonical_name_for_symbol

def get_canonical_name_for_symbol(symbol, api_name=TENSORFLOW_API_NAME):
  """Looks up the canonical API name recorded on `symbol`.

  The exported and deprecated names stored on the undecorated symbol are
  handed to `get_canonical_name`, which picks the canonical one.

  Args:
    symbol: API function or class.
    api_name: API name (tensorflow or estimator).

  Returns:
    The value returned by `get_canonical_name` (for e.g. initializers.zeros),
    or None when `symbol` has no `__dict__` or the undecorated symbol does
    not carry the API names attribute itself.
  """
  if hasattr(symbol, '__dict__'):
    names_attr = API_ATTRS[api_name].names
    target = tf_decorator.unwrap(symbol)[1]
    own_attrs = target.__dict__
    if names_attr in own_attrs:
      exported_names = getattr(target, names_attr)
      # TODO(annarev): may be add a separate deprecated attribute
      # for estimator names.
      deprecated_names = own_attrs.get('_tf_deprecated_api_names', [])
      return get_canonical_name(exported_names, deprecated_names)
  return None
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:25,代码来源:tf_export.py


示例12: get_api_init_text

def get_api_init_text(packages,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None):
  """Builds the __init__.py text for every destination module of the API.

  Walks every module currently in `sys.modules` that belongs to one of
  `packages`, and registers each of its attributes with a
  `_ModuleInitCodeBuilder` through `add_imports_for_symbol`.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.

  Returns:
    The result of the builder's `build()`: a dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
  compat_versions = [] if compat_api_versions is None else compat_api_versions
  builder = _ModuleInitCodeBuilder(output_package)

  def belongs_to_packages(module_name):
    for package in packages:
      if package in module_name:
        return True
    return False

  def should_skip(module):
    # Only named modules inside the requested packages are of interest.
    if not module or not hasattr(module, '__name__'):
      return True
    module_name = module.__name__
    if module_name is None or not belongs_to_packages(module_name):
      return True
    # Do not generate __init__.py files for contrib modules for now
    # (modules with '.lite' in their name are kept).
    is_contrib = ('.contrib.' in module_name or
                  module_name.endswith('.contrib'))
    return is_contrib and '.lite' not in module_name

  # Iterate over a snapshot of everything imported so far.
  for module in list(sys.modules.values()):
    if should_skip(module):
      continue
    module_name = module.__name__

    for symbol_name in dir(module):
      qualified_name = module_name + '.' + symbol_name
      if qualified_name in _SYMBOLS_TO_SKIP_EXPLICITLY:
        continue
      symbol = tf_decorator.unwrap(getattr(module, symbol_name))[1]

      add_imports_for_symbol(
          builder, symbol, module_name, symbol_name, api_name, api_version)
      # Register the symbol again for each extra version under compat/.
      for compat_version in compat_versions:
        add_imports_for_symbol(
            builder, symbol, module_name, symbol_name, api_name,
            compat_version, _COMPAT_MODULE_TEMPLATE % compat_version)

  return builder.build()
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:59,代码来源:create_python_api.py


示例13: visit

 def visit(unused_path, unused_parent, children):
   """Adds the names stored under `_TENSORFLOW_API_ATTR` to `v2_names`."""
   for child in children:
     symbol = tf_decorator.unwrap(child[1])[1]
     # Objects without a __dict__ cannot carry the API names attribute.
     if hasattr(symbol, '__dict__'):
       for exported_name in symbol.__dict__.get(_TENSORFLOW_API_ATTR, []):
         v2_names.add(exported_name)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:9,代码来源:generate_v2_renames_map.py


示例14: testUnwrapReturnsListOfUniqueTFDecorators

 def testUnwrapReturnsListOfUniqueTFDecorators(self):
   decorator_chain = tf_decorator.unwrap(test_decorated_function)[0]
   self.assertEqual(3, len(decorator_chain))

   # Every entry is a TFDecorator ...
   for decorator in decorator_chain:
     self.assertTrue(isinstance(decorator, tf_decorator.TFDecorator))

   # ... and no two entries are the same object.
   first, second, third = decorator_chain
   self.assertIsNot(first, second)
   self.assertIsNot(second, third)
   self.assertIsNot(third, first)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:9,代码来源:tf_decorator_test.py


示例15: testRewrapMutatesAffectedFunction

  def testRewrapMutatesAffectedFunction(self):

    def new_target(x):
      return x * 3

    # Before rewrapping, the decorated chain computes (x * 2 + 1) ** 2.
    expected_before = (1 * 2 + 1) ** 2
    self.assertEqual(expected_before, test_rewrappable_decorated(1))

    # NOTE(review): the first element of unwrap()'s result is passed as the
    # "previous target" here; if unwrap() returns (decorators, target), this
    # is the decorator list rather than the target -- confirm this is
    # intended.
    unwrapped = tf_decorator.unwrap(test_rewrappable_decorated)
    previous = unwrapped[0]
    tf_decorator.rewrap(test_rewrappable_decorated, previous, new_target)

    # After rewrapping, the innermost function multiplies by 3 instead.
    expected_after = (1 * 3 + 1) ** 2
    self.assertEqual(expected_after, test_rewrappable_decorated(1))
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:9,代码来源:tf_decorator_test.py


示例16: serialize_keras_object

def serialize_keras_object(instance):
  """Serializes a Keras object after stripping TF decorators from it.

  Args:
    instance: The object to serialize, possibly decorated, or None.

  Returns:
    None if the unwrapped object is None; the result of
    `serialize_keras_class_and_config` for objects with a `get_config`
    method; otherwise the object's `__name__`.

  Raises:
    ValueError: if the object has neither `get_config` nor `__name__`.
  """
  instance = tf_decorator.unwrap(instance)[1]
  if instance is None:
    return None

  # Configurable objects are stored as class name plus config.
  if hasattr(instance, 'get_config'):
    class_name = instance.__class__.__name__
    config = instance.get_config()
    return serialize_keras_class_and_config(class_name, config)

  # Everything else is identified by name only.
  if not hasattr(instance, '__name__'):
    raise ValueError('Cannot serialize', instance)
  return instance.__name__
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:11,代码来源:generic_utils.py


示例17: conversion_visitor

    def conversion_visitor(unused_path, unused_parent, children):
      """Checks that upgraded V1 calls only use arguments the target accepts.

      For every function among `children`, a keyword-only call is
      synthesized for each of its V1 names, run through `self._upgrade`, and
      the arguments of the converted call are checked against the argument
      lists of the resulting function in `self.v2_symbols` and
      `self.v1_symbols`.
      """
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        if not tf_inspect.isfunction(attr):
          continue
        names_v1 = tf_export.get_v1_names(attr)
        arg_names_v1 = get_args(attr)

        for name in names_v1:
          tf_name = "tf.%s" % name
          if tf_name in function_warnings or tf_name in function_transformers:
            continue  # These require manual change
          if tf_name in v1_name_exceptions:
            continue
          # Assert that arg names after converting to v2 are present in
          # v2 function.
          # 1. First, create an input of the form:
          #    tf.foo(arg1=val1, arg2=val2, ...)
          args = ",".join(
              ["%s=%d" % (from_name, from_index)
               for from_index, from_name in enumerate(arg_names_v1)])
          text_input = "%s(%s)" % (tf_name, args)
          # 2. Convert the input to V2.
          _, _, _, text = self._upgrade(text_input)
          new_function_name, new_args = get_func_and_args_from_str(text)
          if new_function_name == "tf.compat.v1.%s" % name:
            if tf_name in keyword_renames:
              # If we rename arguments, new function must be available in 2.0.
              # We should not be using compat.v1 in this case.
              # Fail explicitly: passing the message as the first argument
              # of assertFalse only fails because the string is truthy and
              # buries the text in a repr-escaped "... is not false".
              self.fail(
                  "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
                  (new_function_name, text_input, text))
            continue
          # 3. Verify V2 function and arguments.
          args_v2 = get_args(self.v2_symbols[new_function_name])
          args_v2.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v2,
                "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v2)))
          # 4. Verify that the argument exists in v1 as well.
          # These two names are skipped for the V1 check.
          if new_function_name in set(["tf.nn.ctc_loss",
                                       "tf.saved_model.save"]):
            continue
          args_v1 = get_args(self.v1_symbols[new_function_name])
          args_v1.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v1,
                "Invalid argument '%s' in 1.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v1)))
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:54,代码来源:tf_upgrade_v2_test.py


示例18: _get_func_name

def _get_func_name(func):
  """Returns a printable name for a callable, looking through TF decorators.

  Args:
    func: A function, a bound method, or any other callable object.

  Returns:
    A string: the function's `__name__` for plain functions,
    `"<owner>.<method>"` for bound methods, and `str(type(func))` for any
    other callable (such as an instance defining `__call__`).

  Raises:
    ValueError: if the unwrapped `func` is not callable.
  """
  _, func = tf_decorator.unwrap(func)
  if not callable(func):
    raise ValueError("Argument must be callable")

  if tf_inspect.isfunction(func):
    return func.__name__

  if tf_inspect.ismethod(func):
    owner = func.__self__
    # A method bound to a class exposes the class name through `__name__`,
    # but an ordinary instance has no `__name__`; fall back to the name of
    # its type instead of raising AttributeError.
    if hasattr(owner, "__name__"):
      owner_name = owner.__name__
    else:
      owner_name = type(owner).__name__
    return "%s.%s" % (owner_name, func.__name__)

  # Probably a class instance with __call__. Return the type's string form
  # so every branch yields a string rather than a type object.
  return str(type(func))
开发者ID:cameronphchen,项目名称:tensorflow,代码行数:11,代码来源:function.py


示例19: visit

 def visit(unused_path, unused_parent, children):
   """Adds one rename line per deprecated API name to `rename_line_set`."""
   for child in children:
     symbol = tf_decorator.unwrap(child[1])[1]
     # Objects without a __dict__ cannot carry API name attributes.
     if not hasattr(symbol, '__dict__'):
       continue
     symbol_attrs = symbol.__dict__
     exported_names = symbol_attrs.get(tensorflow_api_attr, [])
     deprecated_names = symbol_attrs.get('_tf_deprecated_api_names', [])
     canonical = tf_export.get_canonical_name(
         exported_names, deprecated_names)
     # Each deprecated name maps to the canonical name of the same symbol.
     for old_name in deprecated_names:
       rename_line_set.add('    \'tf.%s\': \'tf.%s\'' % (old_name, canonical))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:12,代码来源:generate_v2_renames_map.py


示例20: conversion_visitor

 def conversion_visitor(unused_path, unused_parent, children):
   """Fails if upgrading a V1 name yields a symbol missing from the V2 API.

   Each V1 name of every child is run through `self._upgrade`. A non-empty
   result must either start with "tf.compat.v1" or be present in
   `self.v2_symbols`.
   """
   for child in children:
     _, attr = tf_decorator.unwrap(child[1])
     api_names = tf_export.get_v1_names(attr)
     for name in api_names:
       _, _, _, text = self._upgrade("tf." + name)
       if (text and
           not text.startswith("tf.compat.v1") and
           text not in self.v2_symbols):
         # Fail directly instead of asserting that True is false, so the
         # report carries only the real message.
         self.fail(
             "Symbol %s generated from %s not in v2 API" % (text, name))
开发者ID:rmlarsen,项目名称:tensorflow,代码行数:12,代码来源:tf_upgrade_v2_test.py



注:本文中的tensorflow.python.util.tf_decorator.unwrap函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python tf_export.tf_export函数代码示例发布时间:2022-05-27
下一篇:
Python tf_decorator.make_decorator函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap