Python variable_scope.variable_creator_scope Function Code Examples


This article collects typical usage examples of the Python function tensorflow.python.ops.variable_scope.variable_creator_scope. If you have been wondering what variable_creator_scope does, how to use it, or what real-world usage looks like, the curated code examples below should help.



The following shows 20 code examples of the variable_creator_scope function, sorted by popularity by default. You can upvote the examples you like or find useful; your ratings help the system recommend better Python code examples.
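Before the examples, here is a minimal sketch of the contract itself, assuming TF 2.x (the `log_creator` function and variable names are illustrative, not part of the API): a variable creator is a callable that receives the next creator in the stack plus the keyword arguments of `tf.Variable`; it may delegate to `next_creator`, modify the kwargs, or replace the result entirely.

import tensorflow as tf

def log_creator(next_creator, **kwargs):
  # Inspect or modify kwargs, then delegate to the rest of the creator stack.
  print("creating variable:", kwargs.get("name"))
  return next_creator(**kwargs)

with tf.variable_creator_scope(log_creator):
  v = tf.Variable(1.0, name="v")  # prints: creating variable: v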

Example 1: testCreatorStacksAreThreadLocal

  def testCreatorStacksAreThreadLocal(self):
    devices = ["/device:CPU:0", "/device:GPU:0"]
    dist = mirrored_strategy.MirroredStrategy(devices)

    def model_fn(device_id):
      assert isinstance(device_id, int)
      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
      return v

    def main_thread_creator(next_creator, *args, **kwargs):
      # We are not using the underlying next_creator for test purposes.
      del next_creator, args, kwargs
      return "main_thread"

    with context.graph_mode(), \
        dist.scope(), \
        variable_scope.variable_creator_scope(main_thread_creator):
      result = dist.call_for_each_tower(model_fn, dist.worker_device_index)
      result = dist.unwrap(result)
      expected = ["main_thread:thread_0", "main_thread:thread_1"]
      self.assertEqual(expected, result)
Developer: Jackiefan, Project: tensorflow, Lines: 29, Source: mirrored_strategy_test.py
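The expected strings above ("main_thread:thread_0", ...) encode the stacking order: the creator from the innermost scope runs first and reaches the outer creator through next_creator. A minimal standalone sketch of the same ordering, assuming TF 2.x and with illustrative creator names:

import tensorflow as tf

def outer_creator(next_creator, **kwargs):
  del next_creator, kwargs  # ignore the default creator, as in the test above
  return "outer"

def inner_creator(next_creator, **kwargs):
  return next_creator(**kwargs) + ":inner"

with tf.variable_creator_scope(outer_creator):
  with tf.variable_creator_scope(inner_creator):
    result = tf.Variable(1.0)
print(result)  # "outer:inner": the inner creator ran first, then delegated out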


Example 2: scope

  def scope(self):
    """Returns a context manager selecting this DistributionStrategy as current.

    Inside a `with distribution_strategy.scope():` code block, this thread
    will use a variable creator set by `distribution_strategy`, and will
    enter its "cross-tower context".

    Returns:
      A context manager.
    """
    if has_distribution_strategy():
      _require_cross_tower_context(self)
      return _SameScopeAgainContext(self)

    def creator_with_resource_vars(*args, **kwargs):
      _require_distribution_strategy_scope(self)
      kwargs["use_resource"] = True
      return self._create_variable(*args, **kwargs)

    def disable_partitioned_variables(getter, *args, **kwargs):
      if kwargs.pop("partitioner", None) is not None:
        tf_logging.log_first_n(
            tf_logging.WARN, "Partitioned variables are disabled when using "
            "DistributionStrategy.", 1)
      return getter(*args, **kwargs)

    return _CurrentDistributionContext(
        self, variable_scope.variable_creator_scope(creator_with_resource_vars),
        variable_scope.variable_scope(
            variable_scope.get_variable_scope(),
            custom_getter=disable_partitioned_variables),
        self._default_device)
Developer: LiuCKind, Project: tensorflow, Lines: 32, Source: distribute.py
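For comparison, the public TF 2.x surface of this machinery is tf.distribute; the following usage sketch (not part of the snippet above) shows what such a scope() implementation backs:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Goes through the strategy's variable creator: forced to be a resource
  # variable, with any partitioner disabled, as implemented above.
  v = tf.Variable(1.0)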


Example 3: testOptimizerInsideModelFn

  def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
    created_variables = []
    trainable_variables = []

    def appending_creator(next_creator, *args, **kwargs):
      v = next_creator(*args, **kwargs)
      created_variables.append(v.name)
      if "trainable" in kwargs and kwargs["trainable"]:
        trainable_variables.append(v.name)
      return v

    # Creator scope needs to be set before it's used inside
    # `distribution.scope`.
    with variable_scope.variable_creator_scope(
        appending_creator), distribution.scope():
      model_fn, dataset_fn, _ = minimize_loss_example(
          optimizer_fn,
          use_bias=True,
          use_callable_loss=True,
          create_optimizer_inside_model_fn=True)

      def step_fn(ctx, inputs):
        del ctx  # Unused
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=1).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())
      run_step()

      def get_expected_variables(optimizer_fn, num_parameter_devices):
        optimizer = optimizer_fn()
        name = optimizer._name

        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          variables = VAR_MAP_V2[name]
        else:
          variables = VAR_MAP_V1[name]

        extended_variables = [
            v + "/replica_{}".format(replica)
            for v in variables
            for replica in range(1, num_parameter_devices)
        ]
        variables = list(variables) + extended_variables
        return set([v + ":0" for v in variables])

      self.assertEqual(
          get_expected_variables(optimizer_fn,
                                 len(distribution.extended.parameter_devices)),
          set(created_variables))
Developer: aritratony, Project: tensorflow, Lines: 60, Source: minimize_loss_test.py


Example 4: notify_about_variables

@contextlib.contextmanager
def notify_about_variables(callback):
  """Calls `callback(var)` for all `tf.{Variable,get_variable}` results.

  Callback should not modify the variable passed in. Use cases that require
  variables to be modified should use `variable_creator_scope` directly and sit
  within the variable creator stack.

  >>> variables = []
  >>> with notify_about_variables(variables.append):
  ...   v = tf.Variable(1.0, name='v')
  ...   w = tf.get_variable('w', [])
  >>> assert variables == [v, w]

  Args:
    callback: a callable taking a single argument which is a tf.Variable.

  Yields:
    `None` - used for contextmanager API.
  """
  def _tracking_creator(getter, **kwargs):
    v = getter(**kwargs)
    callback(v)
    return v

  with variable_scope_ops.variable_creator_scope(_tracking_creator):
    yield
Developer: ccchang0111, Project: sonnet, Lines: 26, Source: util.py


Example 5: call

  def call(self, inputs, mask=None, training=None):
    arguments = self.arguments
    if self._fn_expects_mask_arg:
      arguments['mask'] = mask
    if self._fn_expects_training_arg:
      arguments['training'] = training
    with variable_scope.variable_creator_scope(self._variable_creator):
      return self.function(inputs, **arguments)
Developer: adit-chandra, Project: tensorflow, Lines: 8, Source: core.py


Example 6: tower_local_var_scope

  def tower_local_var_scope(self, reduce_method):
    """Does not set to resource variables."""
    def create_tower_local_variable(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      kwargs["trainable"] = False
      return next_creator(*args, **kwargs)

    _require_distribution_strategy_scope(self)
    return variable_scope.variable_creator_scope(create_tower_local_variable)
Developer: LiuCKind, Project: tensorflow, Lines: 9, Source: distribute.py


Example 7: testOptimizerInsideModelFn

  def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
    created_variables = []
    trainable_variables = []

    def appending_creator(next_creator, *args, **kwargs):
      v = next_creator(*args, **kwargs)
      created_variables.append(v.name)
      if "trainable" in kwargs and kwargs["trainable"]:
        trainable_variables.append(v.name)
      return v

    # Creator scope needs to be set before it's used inside
    # `distribution.scope`.
    with variable_scope.variable_creator_scope(
        appending_creator), distribution.scope():
      model_fn, dataset, layer = minimize_loss_example(
          optimizer_fn,
          use_bias=True,
          use_callable_loss=True,
          create_optimizer_inside_model_fn=True)

      iterator = distribution.distribute_dataset(dataset)

      def run_step():
        return distribution.group(
            distribution.call_for_each_tower(
                model_fn, iterator.get_next(), run_concurrently=layer.built))

      if not context.executing_eagerly():
        with self.test_session() as sess:
          run_step = sess.make_callable(run_step())
        self.evaluate(variables_lib.global_variables_initializer())

      run_step()

      def get_expected_variables(optimizer_fn, num_parameter_devices):
        variables_map = {
            "GradientDescent": ["dense/kernel", "dense/bias"],
            "Adam": [
                "dense/kernel", "dense/bias", "beta1_power", "beta2_power",
                "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
                "dense/bias/Adam_1"
            ]
        }
        variables = variables_map[optimizer_fn().get_name()]
        variables.extend([
            v + "/replica_{}".format(replica)
            for v in variables
            for replica in range(1, num_parameter_devices)
        ])
        return set([v + ":0" for v in variables])

      self.assertEqual(
          get_expected_variables(optimizer_fn,
                                 len(distribution.parameter_devices)),
          set(created_variables))
Developer: bikong2, Project: tensorflow, Lines: 56, Source: minimize_loss_test.py


Example 8: save

  def save(self, session=None, checkpoint_number=None):
    """Creates a new checkpoint and manages it.

    Args:
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties.
    """
    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
    else:
      if session is None:
        session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    if checkpoint_number is None:
      checkpoint_number = save_counter
    if not isinstance(checkpoint_number, compat.integral_types):
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    save_path = self._checkpoint.write(prefix)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    self._sweep()
    self._record_state()
    return save_path
Developer: AnishShah, Project: tensorflow, Lines: 55, Source: checkpoint_management.py
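The _initializing_creator pattern above (run a new variable's initializer immediately, in a known session) also works standalone in graph mode. A minimal sketch, assuming the TF 1.x compat API; the helper name make_initializing_creator is illustrative:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

def make_initializing_creator(session):
  def _creator(next_creator, **kwargs):
    v = next_creator(**kwargs)
    session.run(v.initializer)  # initialize as soon as the variable exists
    return v
  return _creator

with tf.Session() as sess:
  with tf.variable_creator_scope(make_initializing_creator(sess)):
    counter = tf.Variable(0, name="save_counter")
  print(sess.run(counter))  # 0, already initialized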


Example 9: _call_func

  def _call_func(self, args, kwargs):
    try:
      vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
      trainable_at_start = len(
          ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
      if self._variables_created:
        result = self._func(*args, **kwargs)
      else:
        # The first time we run, restore variables if necessary (via
        # Checkpointable).
        with variable_scope.variable_creator_scope(
            self._checkpointable_custom_creator):
          result = self._func(*args, **kwargs)

      if self._variables_created:
        # Variables were previously created, implying this is not the first
        # time the template has been called. Check to make sure that no new
        # trainable variables were created this time around.
        trainable_variables = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        # If a variable that we intend to train is created as a side effect
        # of creating a template, then that is almost certainly an error.
        if trainable_at_start != len(trainable_variables):
          raise ValueError("Trainable variable created when calling a template "
                           "after the first time, perhaps you used tf.Variable "
                           "when you meant tf.get_variable: %s" %
                           (trainable_variables[trainable_at_start:],))

        # Non-trainable tracking variables are a legitimate reason why a new
        # variable would be created, but it is a relatively advanced use-case,
        # so log it.
        variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        if vars_at_start != len(variables):
          logging.info("New variables created when calling a template after "
                       "the first time, perhaps you used tf.Variable when you "
                       "meant tf.get_variable: %s",
                       variables[vars_at_start:])
      else:
        self._variables_created = True
      return result
    except Exception as exc:
      # Reraise the exception, but append the original definition to the
      # trace.
      args = exc.args
      if not args:
        arg0 = ""
      else:
        arg0 = args[0]
      trace = "".join(_skip_common_stack_elements(self._stacktrace,
                                                  traceback.format_stack()))
      arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
      new_args = [arg0]
      new_args.extend(args[1:])
      exc.args = tuple(new_args)
      raise
Developer: AndrewTwinz, Project: tensorflow, Lines: 55, Source: template.py


Example 10: scope

  def scope(self):
    """Context manager setting a variable creator and `self` as current."""
    if distribution_strategy_context.has_distribution_strategy():
      raise RuntimeError("Must not nest DistributionStrategy scopes.")

    def creator(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      return next_creator(*args, **kwargs)

    return _CurrentDistributionContext(
        self, variable_scope.variable_creator_scope(creator))
Developer: AnishShah, Project: tensorflow, Lines: 11, Source: distribute.py


Example 11: testSharedVariable

  def testSharedVariable(self):

    shared_variable_store = {}
    num_devices = 3
    creator_fns = []
    for i in range(num_devices):
      creator_fn = shared_variable_creator.make_fn(shared_variable_store, i)
      creator_fns.append(creator_fn)

    with variable_scope.variable_creator_scope(creator_fns[0]):
      v0 = variable_scope.variable(1.0, name="foo")

    with variable_scope.variable_creator_scope(creator_fns[1]):
      v1 = variable_scope.variable(1.0, name="foo")

    with variable_scope.variable_creator_scope(creator_fns[2]):
      v2 = variable_scope.variable(1.0, name="foo")

    # v1 and v2 should be the same as v0
    self.assertIs(v1, v0)
    self.assertIs(v2, v0)
Developer: Jackiefan, Project: tensorflow, Lines: 21, Source: shared_variable_creator_test.py
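shared_variable_creator.make_fn is internal to TensorFlow, but the idea is simple: the first device actually creates each variable and later devices look it up by name. A simplified, hypothetical re-implementation (make_sharing_creator is not a TF API), assuming TF 2.x:

import tensorflow as tf

def make_sharing_creator(store, device_id):
  def creator(next_creator, **kwargs):
    name = kwargs.get("name")
    if device_id == 0:
      store[name] = next_creator(**kwargs)  # first device creates it
    return store[name]  # every device returns the shared instance
  return creator

store = {}
with tf.variable_creator_scope(make_sharing_creator(store, 0)):
  v0 = tf.Variable(1.0, name="foo")
with tf.variable_creator_scope(make_sharing_creator(store, 1)):
  v1 = tf.Variable(1.0, name="foo")
assert v1 is v0  # same underlying variable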


Example 12: model_fn

    def model_fn(device_id):
      assert isinstance(device_id, int)
      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
      return v
Developer: Jackiefan, Project: tensorflow, Lines: 12, Source: mirrored_strategy_test.py


Example 13: scope

  def scope(self):
    """Context manager setting a variable creator and `self` as current."""
    if has_distribution_strategy():
      raise RuntimeError("Must not nest DistributionStrategy scopes.")

    def creator(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      if kwargs.pop("tower_local_reduce_method", None) is not None:
        kwargs["trainable"] = False
      return next_creator(*args, **kwargs)

    return _CurrentDistributionContext(
        self, variable_scope.variable_creator_scope(creator))
Developer: kimr843, Project: tensorflow, Lines: 13, Source: distribute.py


Example 14: one_host_numpy_dataset

def one_host_numpy_dataset(numpy_input, colocate_with, session):
  """Create a dataset on `colocate_with` from `numpy_input`."""
  def create_colocated_variable(next_creator, *args, **kwargs):
    kwargs["colocate_with"] = colocate_with
    return next_creator(*args, **kwargs)

  numpy_flat = nest.flatten(numpy_input)
  with variable_scope.variable_creator_scope(create_colocated_variable):
    vars_flat = tuple(variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
                                              trainable=False)
                      for i in numpy_flat)
  for v, i in zip(vars_flat, numpy_flat):
    init_var_from_numpy(v, i, session)
  vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
  return dataset_ops.Dataset.from_tensor_slices(vars_nested)
Developer: Wajih-O, Project: tensorflow, Lines: 15, Source: numpy_dataset.py


Example 15: testScopeVarCreatorNestingError

  def testScopeVarCreatorNestingError(self):

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    with variable_scope.variable_creator_scope(creator):
      with self.assertRaisesRegexp(RuntimeError,
                                   "Variable creator scope nesting error"):
        scope.__exit__(None, None, None)
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)
Developer: aritratony, Project: tensorflow, Lines: 16, Source: distribute_lib_test.py


Example 16: global_step

  def global_step(self):
    if self._global_step is None:
      # Get the default create_global_step utility to actually call
      # self.add_variable, by setting a custom creator.
      def _owned_variable_as_creator(
          next_creator, initial_value, **kwargs):
        def _creator_as_getter(initializer, **kwargs):
          return next_creator(initial_value=initializer, **kwargs)
        return self.add_variable(
            getter=_creator_as_getter, initializer=initial_value, shape=[],
            **kwargs)

      with variable_scope.variable_creator_scope(
          _owned_variable_as_creator):
        self._global_step = training_util.create_global_step()
    return self._global_step
Developer: japrogramer, Project: tensorflow, Lines: 16, Source: checkpointable_test.py


Example 17: parameter_server_scope

@contextlib.contextmanager
def parameter_server_scope(is_chief, ps_job_name, num_ps_tasks):
  """Strategy to use parameter servers in eager.

  Creates SharedVariable objects for variables created in this scope. These
  SharedVariable objects will be placed round-robin on the parameter servers
  specified by the ps_job_name and num_ps_tasks arguments.

  To use parameter servers you need only to wrap your model initialization in
  this scope:

  ```
  with tf.contrib.eager.parameter_server_scope(
      is_chief, ps_job_name, num_ps_tasks):
    my_model = tf.keras.Sequential([...])  # Or
    input = tf.keras.Input(...)
    ....
    my_model = tf.keras.Model(input, output)
  my_model.compile(...)
  # or other usages of the model.
  ```

  Args:
    is_chief: Boolean. Whether this worker is responsible for initializing
      variables.
    ps_job_name: The name of the ps job in this cluster.
    num_ps_tasks: The number of ps tasks to use.

  Yields:
    a context manager.
  """
  # Note: capturing in a list to allow assignment.
  ps_index = [0]

  def variable_creator_scope(unused_next_creator, **kwargs):
    kwargs["initialize"] = is_chief
    with ops.device(
        "/job:%s/task:%s" % (ps_job_name, ps_index[0] % num_ps_tasks)):
      ps_index[0] += 1
      v = SharedVariable(**kwargs)
      if not is_chief:
        while not resource_variable_ops.var_is_initialized_op(v.handle):
          time.sleep(10)
      return v

  with variable_scope.variable_creator_scope(variable_creator_scope):
    yield
Developer: Ajaycs99, Project: tensorflow, Lines: 46, Source: parameter_server.py


Example 18: testVariableCreatorScope

  def testVariableCreatorScope(self):
    created_variables = []
    captured_variables = []

    @def_function.function
    def f():
      if not created_variables:
        created_variables.append(variables.Variable(1.))
      return created_variables[0] + 1.

    def capture_creator(next_creator, **kwargs):
      created = next_creator(**kwargs)
      captured_variables.append(created)
      return created

    with variable_scope.variable_creator_scope(capture_creator):
      f()
    self.assertEqual(created_variables, captured_variables)
Developer: terrytangyuan, Project: tensorflow, Lines: 18, Source: def_function_test.py


Example 19: colocate_vars_with

  def colocate_vars_with(self, colocate_with_variable):
    """Scope that controls which devices variables will be created on.

    No operations should be added to the graph inside this scope; it
    should only be used when creating variables (some implementations
    work by changing variable creation, others work by using a
    tf.colocate_with() scope).

    This may only be used inside `self.scope()`.

    Example usage:

    ```
    with distribution_strategy.scope():
      var1 = tf.get_variable(...)
      with distribution_strategy.colocate_vars_with(var1):
        # var2 and var3 will be created on the same device(s) as var1
        var2 = tf.get_variable(...)
        var3 = tf.get_variable(...)

      def fn(v1, v2, v3):
        # operates on v1 from var1, v2 from var2, and v3 from var3

      # `fn` runs on every device `var1` is on; `var2` and `var3` will be there too.
      distribution_strategy.update(var1, fn, var2, var3)
    ```

    Args:
      colocate_with_variable: A variable created in `self.scope()`. Variables created
        while in the returned context manager will be on the same set of
        devices as `colocate_with_variable`.

    Returns:
      A context manager.
    """
    def create_colocated_variable(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      kwargs["use_resource"] = True
      kwargs["colocate_with"] = colocate_with_variable
      return next_creator(*args, **kwargs)

    _require_distribution_strategy_scope(self)
    return variable_scope.variable_creator_scope(create_colocated_variable)
Developer: LiuCKind, Project: tensorflow, Lines: 43, Source: distribute.py


Example 20: run

  def run(self):
    # pylint: disable=protected-access
    self.graph._variable_creator_stack = self._variable_creator_stack
    self.should_run.wait()
    self.should_run.clear()
    try:
      if self.coord.should_stop():
        return
      with self.coord.stop_on_exception(), \
          context.context()._mode(self.context_mode), \
          context.context().device_policy(self.context_device_policy), \
          _enter_graph(self.graph), \
          MirroredTowerContext(self.distribution, self.tower_id), \
          ops.device(self.device), \
          ops.name_scope(self._name_scope), \
          variable_scope.variable_scope(
              self._captured_var_scope, reuse=self.tower_id > 0), \
          variable_scope.variable_creator_scope(self.variable_creator_fn):
        self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
        self.done = True
    finally:
      self.has_paused.set()
Developer: Jordan1237, Project: tensorflow, Lines: 22, Source: mirrored_strategy.py



Note: The tensorflow.python.ops.variable_scope.variable_creator_scope examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other source-code and documentation platforms. The snippets were selected from open-source projects contributed by many developers; copyright belongs to the original authors, and any distribution or use should follow the corresponding project's license. Do not reproduce without permission.

