• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python def_function.function函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.eager.def_function.function函数的典型用法代码示例。如果您想了解Python function函数的具体用法、调用方式或使用示例，下面精选的代码示例或许能为您提供帮助。



在下文中一共展示了function函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: test_nested_functions

  def test_nested_functions(self):
    """A tf.function that calls another tf.function survives a save/load cycle.

    `g(x)` computes `f(x) + 1.0` where `f(x)` is `x * 2.0`, so for an input of
    `[1.0]` the restored `g` must return `[3.0]`.
    """
    f = def_function.function(
        lambda x: x*2.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    g = def_function.function(
        lambda x: f(x) + 1.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

    root = tracking.AutoCheckpointable()
    root.g = g
    imported = self.cycle(root)
    # Check the restored function's output, not just that calling it does not
    # raise: a wrong nested-function restore must fail the test.
    self.assertAllEqual([3.0], imported.g(constant_op.constant([1.0])))
开发者ID:gautam1858,项目名称:tensorflow,代码行数:12,代码来源:load_test.py


示例2: test_nested_functions

  def test_nested_functions(self, cycles):
    """A tf.function that calls another tf.function can be saved and restored.

    `g(x)` is `f(x) + 1.0` with `f(x) = x * 2.0`, so the restored `g` must map
    `[1.0]` to `[3.0]`.

    Args:
      cycles: number of save/load cycles requested for this test (presumably
        supplied by a parameterization decorator outside this view). It is
        currently ignored: the body always runs exactly one cycle, see the
        TODO below.
    """
    f = def_function.function(
        lambda x: x*2.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    g = def_function.function(
        lambda x: f(x) + 1.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

    root = tracking.AutoCheckpointable()
    root.g = g
    # TODO(vbardiovsky): Enable this test. For this to work, we must ensure that
    # concrete_function._inference_function._graph._functions contains all
    # functions that were on the graph before saving.
    # Until then a single cycle is run regardless of `cycles`.
    imported = self.cycle(root, 1)
    # Check the value as well, so a wrong result fails and not only a crash.
    self.assertAllEqual([3.0], imported.g(constant_op.constant([1.0])))
开发者ID:tthhee,项目名称:tensorflow,代码行数:15,代码来源:load_test.py


示例3: test_callable

  def test_callable(self):
    """Restored objects are callable exactly when they carried a `__call__`.

    Covers a `__call__` tf.function defined on the class (`M1`) and one
    assigned on an instance (`root.m2`). It also checks that an object
    without `__call__` stays non-callable after a save/load cycle.
    """
    class M1(tracking.AutoCheckpointable):

      # Identity function with a float32 input signature of unknown shape.
      @def_function.function(
          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
      def __call__(self, x):
        return x

    root = tracking.AutoCheckpointable()
    root.m1 = M1()
    root.m2 = tracking.AutoCheckpointable()
    # Instance-level `__call__`: Python's call syntax ignores it on the
    # original object, which is why the comparison below calls
    # `root.m2.__call__` explicitly.
    root.m2.__call__ = def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
            lambda x: x*3.0)
    imported = self.cycle(root)
    x = constant_op.constant(1.0)

    self.assertTrue(callable(imported.m1))
    self.assertAllEqual(root.m1(x), imported.m1(x))

    # Note: `root.m2` was not callable since `__call__` attribute was set
    # into the instance and not on the class. But after a serialization cycle
    # that starts to work.
    self.assertTrue(callable(imported.m2))
    self.assertAllEqual(root.m2.__call__(x), imported.m2(x))

    # Verify that user objects without `__call__` attribute are not callable.
    self.assertFalse(callable(imported))
开发者ID:gautam1858,项目名称:tensorflow,代码行数:28,代码来源:load_test.py


示例4: test_functools_partial_keywords

  def test_functools_partial_keywords(self):
    """A partial that binds every argument by keyword can be wrapped as-is."""
    def add(x, y):
      return x + y

    # Both arguments are bound by keyword, so the wrapped call takes no args.
    bound = functools.partial(add, x=array_ops.zeros([1]),
                              y=array_ops.zeros([1]))
    wrapped = def_function.function(bound)
    self.assertAllEqual(wrapped(), [0.0])
开发者ID:perfmjs,项目名称:tensorflow,代码行数:7,代码来源:def_function_test.py


示例5: test_functools_partial_new_default

  def test_functools_partial_new_default(self):
    """A keyword bound by a partial becomes a new, overridable default."""
    def add(x=3, y=7):
      return x + y

    wrapped = def_function.function(functools.partial(add, y=6))
    # First call uses the partial's y=6 (3 + 6); second overrides it (3 + 8).
    for call_kwargs, expected in (({}, 9), ({'y': 8}, 11)):
      self.assertEqual(wrapped(**call_kwargs).numpy(), expected)
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:def_function_test.py


示例6: testConstSavedModel

  def testConstSavedModel(self):
    """Freezing a loaded, variable-free SavedModel inlines its functions.

    Saves a tf.function computing `2. * x`, then reloads its
    "serving_default" signature and freezes it with
    `convert_variables_to_constants_v2`. The frozen GraphDef must have no
    variables and an empty function library, and must evaluate to the same
    value as the original function.
    """
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    # Before freezing: no variables, but the graph still has a non-empty
    # function library.
    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(variable_graph_def))
    self.assertTrue(variable_graph_def.library.function)

    # After freezing, the function library must be empty, i.e. the functions
    # were inlined into the top-level graph.
    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    constant_graph_def = output_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(constant_graph_def))
    self.assertFalse(constant_graph_def.library.function)

    # Check value.
    # NOTE(review): assertEqual on a one-element numpy array relies on the
    # array's truth value; assertAllClose would be sturdier if
    # `_evaluateGraphDef` returns a matching shape. Confirm before changing.
    expected_value = root.f(input_data)
    actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
                                          [input_data.numpy()])
    self.assertEqual(expected_value.numpy(), actual_value)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:27,代码来源:convert_to_constants_test.py


示例7: testConstructConcreteFunction

  def testConstructConcreteFunction(self):
    """A ConcreteFunction rebuilt from its GraphDef can still be frozen.

    Rebuilds `func` from its own GraphDef via the private
    `convert_to_constants._construct_concrete_function` and checks that both
    variables are still present. It then freezes the rebuilt function and
    compares its output with the original tf.function (3. * 2. * x).
    """
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    func = root.f.get_concrete_function(input_data)

    input_func = convert_to_constants._construct_concrete_function(
        func, func.graph.as_graph_def())

    # Test if model has enough metadata to be frozen afterwards.
    variable_graph_def = input_func.graph.as_graph_def()
    self.assertEqual(2, self._getNumVariables(variable_graph_def))

    # Freezing must remove every variable and any StatefulPartitionedCall op.
    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    constant_graph_def = output_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(constant_graph_def))
    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))

    # Check value.
    expected_value = root.f(input_data)
    actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
                                          [input_data.numpy()])
    self.assertEqual(expected_value.numpy(), actual_value)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:26,代码来源:convert_to_constants_test.py


示例8: testVariableSavedModel

  def testVariableSavedModel(self):
    """Freezing a loaded SavedModel that uses variables removes them.

    Saves a tf.function computing `v1 * v2 * x` (3. * 2. * x) and reloads
    its "serving_default" signature. Before freezing, the loaded graph has a
    StatefulPartitionedCall op. After `convert_variables_to_constants_v2` it
    must have neither variables nor that op, and must give the same value.
    """
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    input_func = saved_model.signatures["serving_default"]

    variable_graph_def = input_func.graph.as_graph_def()
    self.assertTrue(self._hasStatefulPartitionedCallOp(variable_graph_def))

    output_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    constant_graph_def = output_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(constant_graph_def))
    self.assertFalse(self._hasStatefulPartitionedCallOp(constant_graph_def))

    # Check value.
    expected_value = root.f(input_data)
    actual_value = self._evaluateGraphDef(constant_graph_def, input_func,
                                          [input_data.numpy()])
    self.assertEqual(expected_value.numpy(), actual_value)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:28,代码来源:convert_to_constants_test.py


示例9: add_metric_step

def add_metric_step(defun):
  """Run one optimizer step on a model whose first layer reports metrics.

  Args:
    defun: if truthy, the step is wrapped with `def_function.function`;
      otherwise it runs as a plain Python function.

  Returns:
    The list of `result()` values of the model's two metrics after the step.
  """
  optimizer = keras.optimizer_v2.rmsprop.RMSprop()
  layers = [
      LayerWithMetrics(),
      keras.layers.Dense(1, kernel_initializer='zeros', activation='softmax'),
  ]
  model = testing_utils.get_model_from_layers(layers, input_shape=(10,))

  def step(inputs, targets):
    with backprop.GradientTape() as tape:
      # The model is applied twice (to x and 2 * x) within one step.
      prediction = model(inputs) + model(2 * inputs)
      loss = keras.losses.mean_squared_error(targets, prediction)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    assert len(model.metrics) == 2
    return [metric.result() for metric in model.metrics]

  train_step = def_function.function(step) if defun else step

  inputs, targets = array_ops.ones((10, 10)), array_ops.zeros((10, 1))
  metrics = train_step(inputs, targets)
  assert np.allclose(metrics[0], 1.5)
  assert np.allclose(metrics[1], 1.5)
  return metrics
开发者ID:aritratony,项目名称:tensorflow,代码行数:27,代码来源:custom_training_loop_test.py


示例10: testDecorate

  def testDecorate(self):
    """`_decorate` applies a decorator to the function's Python callable.

    The wrapped lambda returns 1 and the decorator adds 1, so the decorated
    function must return 2.
    """
    def add_one(inner):
      return lambda: 1 + inner()

    func = def_function.function(lambda: 1)
    func._decorate(add_one)
    self.assertEqual(func().numpy(), 2)
开发者ID:perfmjs,项目名称:tensorflow,代码行数:7,代码来源:def_function_test.py


示例11: test_functools_partial_single_positional

  def test_functools_partial_single_positional(self):
    """A partial binding one positional argument leaves the other to callers."""
    def add(x, y):
      return x + y

    # `x` is bound to a constant 1; the wrapped call supplies `y`.
    one = constant_op.constant(1)
    func = def_function.function(functools.partial(add, one))
    self.assertAllEqual(func(5), 6)
开发者ID:perfmjs,项目名称:tensorflow,代码行数:7,代码来源:def_function_test.py


示例12: test_structured_output

  def test_structured_output(self):
    """Namedtuple/list/dict outputs keep their structure across save/load."""

    # Fields are declared in non-alphabetical order so that any re-sorting of
    # the namedtuple fields would be caught by the key-order assertion.
    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])

    def func(input1, input2):
      named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
      return [named_tuple, input2, {"x": 0.5}]

    def check_result(result, expected_sum, expected_product, expected_second):
      """Assert the [namedtuple, tensor, dict] output of `func`."""
      self.assertEqual(expected_sum, result[0].a.numpy())
      self.assertEqual(expected_product, result[0].b.numpy())
      self.assertEqual(["b", "a"], list(result[0]._asdict().keys()))
      self.assertEqual(expected_second, result[1].numpy())
      self.assertEqual(0.5, result[2]["x"].numpy())

    root = tracking.AutoCheckpointable()
    root.f = def_function.function(func)

    # 2 + 3 = 5, 2 * 3 = 6, second output echoes 3.
    check_result(root.f(constant_op.constant(2), constant_op.constant(3)),
                 5, 6, 3)

    imported = self.cycle(root)

    # 2 + 5 = 7, 2 * 5 = 10, second output echoes 5.
    check_result(
        imported.f(constant_op.constant(2), constant_op.constant(5)),
        7, 10, 5)
开发者ID:gautam1858,项目名称:tensorflow,代码行数:28,代码来源:load_test.py


示例13: test_table

 def test_table(self):
   """A vocabulary-backed HashTable keeps working after save, load and a move.

   The source vocabulary file is deleted right after saving, and the export
   directory is later renamed. Lookups must still return the expected ids in
   both cases.
   """
   # Keys are whole lines of the vocab file, values are line numbers.
   # Assumes the vocab file lists alpha, beta, gamma on lines 0-2 (inferred
   # from the expected outputs below).
   initializer = lookup_ops.TextFileInitializer(
       self._vocab_path,
       key_dtype=dtypes.string,
       key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
       value_dtype=dtypes.int64,
       value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
   root = util.Checkpoint(table=lookup_ops.HashTable(
       initializer, default_value=-1))
   root.table_user = def_function.function(
       root.table.lookup,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
   self.assertEqual(
       2,
       self.evaluate(root.table_user(constant_op.constant("gamma"))))
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(root, save_dir)
   # Delete the original vocabulary so that, presumably, the loaded model can
   # only use the asset copy written into the SavedModel directory.
   file_io.delete_file(self._vocab_path)
   self.assertAllClose(
       {"output_0": [2, 0]},
       _import_and_infer(save_dir, {"keys": ["gamma", "alpha"]}))
   second_dir = os.path.join(self.get_temp_dir(), "second_dir")
   # Asset paths should track the location the SavedModel is loaded from.
   file_io.rename(save_dir, second_dir)
   self.assertAllClose(
       {"output_0": [2, 1]},
       _import_and_infer(second_dir, {"keys": ["gamma", "beta"]}))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:27,代码来源:save_test.py


示例14: _apply_all_reduce

def _apply_all_reduce(reduction, tensors):
  """Build one NCCL all-reduce op per input tensor.

  Args:
    reduction: the reduction to apply, passed unchanged to
      `gen_nccl_ops.nccl_all_reduce`.
    tensors: non-empty sequence of tensors. Each tensor is checked with
      `_check_device`, and its op is placed on that tensor's device.

  Returns:
    The outputs of the per-tensor all-reduce ops, one per input, in input
    order.

  Raises:
    ValueError: if `tensors` is empty.
  """
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')

  shared_name = _get_shared_name()

  def _all_reduce():
    """Call nccl allreduce."""
    outputs = []
    for tensor in tensors:
      _check_device(tensor)
      with ops.device(tensor.device):
        outputs.append(gen_nccl_ops.nccl_all_reduce(
            input=tensor,
            reduction=reduction,
            num_devices=len(tensors),
            shared_name=shared_name))
    return outputs

  if not context.executing_eagerly():
    return _all_reduce()
  # NCCL ops block unless all of them run concurrently, as they do inside a
  # graph or a tf.function, so eager mode wraps the builder in one.
  return def_function.function(_all_reduce)()
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:27,代码来源:nccl_ops.py


示例15: test_structured_inputs

  def test_structured_inputs(self):
    """Nested-structure inputs are matched against the traced signature.

    Only `input1` is traced before saving. After the cycle, an input with the
    same structure but different tensors (`input3`) must work. An input that
    differs in the leading Python value (`input2`) must be rejected.
    """

    def func(x, training=True):
      # x is a nested structure, we care about one particular tensor.
      _, (a, b) = x
      if training:
        return 2 * a["a"] + b
      else:
        return 7

    root = tracking.AutoCheckpointable()
    root.f = def_function.function(func)

    x = constant_op.constant(10)
    y = constant_op.constant(11)

    # The leading element is a plain Python int; input2 differs from input1
    # only there (7 vs 6), which is presumably part of the trace signature.
    input1 = [6, ({"a": x}, y)]
    input2 = [7, ({"a": x}, y)]  # Not compatible with input1 signature.
    input3 = [6, ({"a": y}, x)]  # Compatible with input1 signature.

    # Note: by only calling f(input1) before serialization, only inputs with
    # matching signature will be valid on the loaded model.
    # 2 * 10 + 11 = 31.
    self.assertEqual(31, root.f(input1).numpy())

    imported = self.cycle(root)

    # NOTE(review): `assertRaisesRegexp` is a deprecated alias of
    # `assertRaisesRegex` and was removed from unittest in Python 3.12 —
    # switch once the base TestCase no longer needs the old name.
    with self.assertRaisesRegexp(AssertionError,
                                 "Could not find matching function to call.*"):
      imported.f(input2)

    self.assertEqual(31, imported.f(input1).numpy())
    # 2 * 11 + 10 = 32.
    self.assertEqual(32, imported.f(input3).numpy())
开发者ID:gautam1858,项目名称:tensorflow,代码行数:32,代码来源:load_test.py


示例16: testRequestNotToCompile

  def testRequestNotToCompile(self):
    """`_XlaCompile: False` lets a function honour inner device requests.

    `f` asks for `y` to be computed on CPU. When `f` runs as a plain
    tf.function, it is compiled as a whole and both outputs land on the test
    device. When it runs through
    `defun_with_attributes(..., {'_XlaCompile': False})`, it executes op by
    op and `y` stays on CPU.
    """
    with self.test_scope():
      def f(x):
        # Explicit CPU placement that only op-by-op execution respects.
        with ops.device('device:CPU:0'):
          y = 2.0 * x
        return x, y

      wholly_compiled_f = def_function.function(f)
      op_by_op_f = function.defun_with_attributes(
          f, attributes={'_XlaCompile': False})

      x = constant_op.constant([0.0, 2.0], name='data')

      # When function is wholly compiled, all outputs will be on the
      # device on which it is run.
      r_x, r_y = wholly_compiled_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        # NOTE(review): `assertRegexpMatches` is a deprecated alias of
        # `assertRegex` and was removed from unittest in Python 3.12.
        self.assertRegexpMatches(r_x.backing_device, self.device)
        self.assertRegexpMatches(r_y.backing_device, self.device)

      # When function is executed op-by-op, requested devices will be
      # respected.
      r_x, r_y = op_by_op_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegexpMatches(r_x.backing_device, self.device)
        self.assertRegexpMatches(r_y.backing_device, 'device:CPU:0')
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:32,代码来源:eager_test.py


示例17: test_single_function_default_signature

 def test_single_function_default_signature(self):
   """Saving a model with one traced zero-argument function exports it.

   Loading and inferring with no inputs must yield {"output_0": 3.}.
   """
   model = tracking.AutoCheckpointable()
   model.f = def_function.function(lambda: 3., input_signature=())
   # Trace the function once before saving.
   model.f()
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(model, export_dir)
   outputs = _import_and_infer(export_dir, {})
   self.assertAllClose({"output_0": 3.}, outputs)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:8,代码来源:save_test.py


示例18: test_non_concrete_error

 def test_non_concrete_error(self):
   """Passing a non-concrete tf.function as save's third argument fails."""
   module = tracking.AutoCheckpointable()
   module.f = def_function.function(lambda x: 2. * x)
   # Trace once so the function has a concrete trace, yet is still not a
   # ConcreteFunction itself.
   module.f(constant_op.constant(1.))
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   expected_message = "must be converted to concrete functions"
   with self.assertRaisesRegexp(ValueError, expected_message):
     save.save(module, export_dir, module.f)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:8,代码来源:save_test.py


示例19: test_non_concrete_error

 def test_non_concrete_error(self):
   """Passing a non-concrete tf.function as save's third argument fails."""
   module = tracking.AutoTrackable()
   module.f = def_function.function(lambda x: 2. * x)
   # Trace once; the object passed to save is still the tf.function wrapper.
   module.f(constant_op.constant(1.))
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   expected_message = "Expected a TensorFlow function"
   with self.assertRaisesRegexp(ValueError, expected_message):
     save.save(module, export_dir, module.f)
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:save_test.py


示例20: testTypeInvalid

  def testTypeInvalid(self):
    """`from_concrete_function` rejects a tf.function that is not concrete.

    The ValueError message must tell the caller to call
    from_concrete_function with a concrete function.
    """
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)

    with self.assertRaises(ValueError) as raised:
      lite.TFLiteConverterV2.from_concrete_function(root.f)
    message = str(raised.exception)
    self.assertIn('call from_concrete_function', message)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:9,代码来源:lite_v2_test.py



注:本文中的tensorflow.python.eager.def_function.function函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python execute.execute函数代码示例发布时间:2022-05-27
下一篇:
Python core._status_to_exception函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap