• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python ops.add_to_collection函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Python 中 tensorflow.python.framework.ops.add_to_collection 函数的典型用法代码示例。如果您正在寻找以下问题的答案:add_to_collection 函数的具体用法是什么?add_to_collection 怎么用?有哪些 add_to_collection 的使用例子?那么这里精选的函数代码示例或许可以为您提供帮助。



下文共展示了 add_to_collection 函数的 20 个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞,您的评价将有助于我们的系统推荐出更好的 Python 代码示例。

示例1: _compute_weighted_loss

def _compute_weighted_loss(losses, weight):
    """Scales `losses` by `weight`, averages them, and records the result.

    The resulting mean loss is also added to the `ops.GraphKeys.LOSSES`
    collection before it is returned.

    Args:
        losses: A tensor of size [batch_size, d1, ... dN].
        weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.

    Returns:
        A scalar `Tensor` holding the weighted loss.

    Raises:
        ValueError: If the weight shape is not compatible with the losses
            shape or if the number of dimensions (rank) of either losses or
            weight is missing.
    """
    losses = math_ops.to_float(losses)
    weight = math_ops.to_float(ops.convert_to_tensor(weight))

    # Reject any input whose shape reports no rank (`ndims is None`).
    rank_checks = (
        (losses, "losses.get_shape().ndims cannot be None"),
        (weight, "weight.get_shape().ndims cannot be None"),
    )
    for tensor, message in rank_checks:
        if tensor.get_shape().ndims is None:
            raise ValueError(message)

    weighted_mean = _safe_mean(
        _scale_losses(losses, weight), _num_present(losses, weight))
    ops.add_to_collection(ops.GraphKeys.LOSSES, weighted_mean)
    return weighted_mean
开发者ID:passiweinberger,项目名称:tensorflow,代码行数:27,代码来源:loss_ops.py


示例2: fertile_stats_variable

def fertile_stats_variable(params, stats_config, name, container=None):
  r"""Builds a `FertileStatsVariable` and returns its resource handle.

  The variable's saveable is added to the `SAVEABLE_OBJECTS` collection and the
  handle is registered through `resources.register_resource`.

  Args:
    params: A TensorForestParams object.
    stats_config: A `Tensor` of type `string`. Serialized proto of the stats.
    name: A name for the variable.
    container: An optional `string`. Defaults to `""`.

  Returns:
    A `Tensor` of type mutable `string`. The handle to the stats.
  """
  with ops.name_scope(name, "FertileStatsVariable") as name:
    stats_var = FertileStatsVariable(params, stats_config, name, container)
    handle = stats_var.resource_handle
    init_op = stats_var.initializer
    initialized_check_op = stats_var.is_initialized()

    # Look up the saveable factory for this variable and add the resulting
    # saveable to the savable list.
    # pylint: disable=protected-access
    saveable_factories = stats_var._gather_saveables_for_checkpoint()
    # pylint: enable=protected-access
    build_saveable = saveable_factories["fertile_stats_variable"]
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS, build_saveable(name=handle.name))

    resources.register_resource(handle, init_op, initialized_check_op)
    return handle
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:25,代码来源:stats_ops.py


示例3: _get_default_variable_store

def _get_default_variable_store():
  """Returns the first store in the `_VARSTORE_KEY` collection, creating it if absent."""
  existing_stores = ops.get_collection(_VARSTORE_KEY)
  if not existing_stores:
    # Nothing registered yet: create a store and add it to the collection so
    # later calls find it.
    new_store = _VariableStore()
    ops.add_to_collection(_VARSTORE_KEY, new_store)
    return new_store
  return existing_stores[0]
开发者ID:2php,项目名称:tensorflow,代码行数:7,代码来源:variable_scope.py


示例4: _get_or_create_global_step_read

def _get_or_create_global_step_read(graph=None):
  """Returns the global step read tensor of `graph`, creating it on demand.

  Args:
    graph: The graph in which to create the global step read tensor. If missing,
      use default graph.

  Returns:
    Global step read tensor if there is global_step_tensor else return None.
  """
  if not graph:
    graph = ops.get_default_graph()

  existing_read = _get_global_step_read(graph)
  if existing_read is not None:
    return existing_read

  step_tensor = get_global_step(graph)
  if step_tensor is None:
    return None

  with graph.as_default() as g, g.name_scope(None):
    # Go through `initialized_value` for variables so that global_step is
    # initialized before this runs. This is needed, for example, when Estimator
    # builds the whole model_fn under a global_step_read_tensor dependency.
    if isinstance(step_tensor, variables.Variable):
      step_value = step_tensor.initialized_value()
    else:
      step_value = step_tensor
    # Adding zero creates a copy of the variable as a Tensor.
    ops.add_to_collection(GLOBAL_STEP_READ_KEY, step_value + 0)
  return _get_global_step_read(graph)
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:27,代码来源:training_util.py


示例5: _maybe_add_main_op

  def _maybe_add_main_op(self, main_op):
    """Registers `main_op` under the main-op collection, if one is given.

    Args:
      main_op: Main op to run as part of graph initialization. If None, no
        main op will be added to the graph.

    Raises:
      TypeError: if main op is provided but is not of type `Operation`.
      ValueError: if the Graph already contains an init op.
    """
    if main_op is None:
      return

    if not isinstance(main_op, ops.Operation):
      raise TypeError("main_op needs to be an Operation: %r" % main_op)

    # No other init op may already be present. Both the main-op and the
    # legacy-init-op collections are checked for thoroughness and explicitness.
    init_op_keys = (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY)
    occupied_key = next(
        (key for key in init_op_keys if ops.get_collection(key)), None)
    if occupied_key is not None:
      raise ValueError(
          "Graph already contains one or more main ops under the "
          "collection {}.".format(occupied_key))

    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:26,代码来源:builder_impl.py


示例6: initialize_from

    def initialize_from(self, keys, values, name=None):
        """Builds the op that initializes this table from key and value tensors.

        The keys and values are converted to tensors of the table's key and
        value dtypes, and the resulting init op is added to the
        `TABLE_INITIALIZERS` collection.

        Args:
            keys: The tensor for the keys.
            values: The tensor for the values.
            name: Optional name for the op. Defaults to
                "<table name>_initialize_table".

        Returns:
            The operation that initializes the table.

        Raises:
            TypeError: when the keys and values data types do not match the
                table key and value data types.
        """
        op_name = name if name is not None else "%s_initialize_table" % self.name

        with ops.op_scope([keys, values], None, op_name):
            key_tensor = ops.convert_to_tensor(
                keys, dtype=self.key_dtype, name="keys")
            value_tensor = ops.convert_to_tensor(
                values, dtype=self.value_dtype, name="values")

        table_init_op = gen_data_flow_ops._initialize_table(
            self.table_ref, key_tensor, value_tensor, name=op_name)
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, table_init_op)
        return table_init_op
开发者ID:swapnilashtekar,项目名称:tensorflow,代码行数:26,代码来源:data_flow_ops.py


示例7: __init__

  def __init__(self, iterator_resource, initializer, output_types,
               output_shapes, output_classes):
    """Wraps an existing iterator resource in a new iterator object.

    Note: Most users will not call this initializer directly, and will
    instead use `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this iterator.
      output_shapes: A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this iterator.
      output_classes: A nested structure of Python `type` objects corresponding
        to each component of an element of this iterator.

    Raises:
      ValueError: if any of `output_types`, `output_shapes` or `output_classes`
        is None.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    # All three legacy descriptors are required to build the structure.
    legacy_args = (output_types, output_shapes, output_classes)
    if any(arg is None for arg in legacy_args):
      raise ValueError(
          "If `structure` is not specified, all of `output_types`, "
          "`output_shapes`, and `output_classes` must be specified.")
    self._structure = structure_lib.convert_legacy_structure(*legacy_args)

    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    self._get_next_call_count = 0
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
开发者ID:perfmjs,项目名称:tensorflow,代码行数:35,代码来源:iterator_ops.py


示例8: apply_regularization

def apply_regularization(regularizer, weights_list=None):
  """Sums the penalties `regularizer` produces for every entry of `weights_list`.

  Adding a regularization penalty over the layer weights and embedding weights
  can help prevent overfitting the training data. Regularization over layer
  biases is less common/useful, but assuming proper data preprocessing/mean
  subtraction, it usually shouldn't hurt much either.

  The summed penalty is also added to the `REGULARIZATION_LOSSES` collection.

  Args:
    regularizer: A function that takes a single `Tensor` argument and returns
      a scalar `Tensor` output.
    weights_list: List of weights `Tensors` or `Variables` to apply
      `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if
      `None` (or otherwise empty).

  Returns:
    A scalar representing the overall regularization penalty.

  Raises:
    ValueError: If `regularizer` does not return a scalar output.
  """
  if not weights_list:
    weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS)

  with ops.op_scope(weights_list, 'get_regularization_penalty') as scope:
    # Build every penalty first, then validate that each one is a scalar.
    penalties = list(map(regularizer, weights_list))
    for penalty in penalties:
      rank = penalty.get_shape().ndims
      if rank != 0:
        raise ValueError('regularizer must return a scalar Tensor instead of a '
                         'Tensor with rank %d.' % rank)

    total_penalty = math_ops.add_n(penalties, name=scope)
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, total_penalty)
    return total_penalty
开发者ID:KezhiLi,项目名称:DropNeuron,代码行数:33,代码来源:regularizers.py


示例9: testKeepNodes

  def testKeepNodes(self):
    """Checks which nodes remain after optimizing a graph that uses collections."""
    graph = ops.Graph()
    with graph.as_default():
      # Must be preserved since it's in the collection 'variables'.
      kept_variable = variables.VariableV1(1.0)
      kept_constant = constant_op.constant(0, shape=[50, 50], name='keep')
      # Explicitly add to collection.
      ops.add_to_collection('a2', kept_constant)
      do_not_remove_attr = {
          '_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)
      }
      with graph._attr_scope(do_not_remove_attr):
        protected_constant = constant_op.constant(0, name='keep2')
      lhs = constant_op.constant(1, shape=[100, 10])
      rhs = constant_op.constant(0, shape=[10, 30])
      # The matmul result is the fetch node.
      fetch = math_ops.matmul(lhs, rhs)
      ops.add_to_collection('train_op', fetch)

    # Optimize the graph.
    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.min_graph_nodes = -1
    optimized_graph = tf_optimizer.OptimizeGraph(config, meta_graph_def)

    # Check that the nodes referenced in various collections have been preserved
    surviving_names = [node.name for node in optimized_graph.node]
    expected_names = [
        fetch.op.name,
        kept_variable.op.name,
        kept_constant.op.name,
        protected_constant.op.name,
        'Variable/initial_value',
        'Variable/Assign',
    ]
    self.assertEqual(len(surviving_names), len(expected_names))
    self.assertAllInSet(surviving_names, expected_names)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:30,代码来源:tf_optimizer_test.py


示例10: _CreateParamsSavable

def _CreateParamsSavable(params,
                         model,
                         base_variable_scope=None,
                         name="params_canonical"):
  """Create a RNNParamsSaveable for the weight and bias parameters.

  The saveable class is chosen from `model._rnn_mode`, and the created
  saveable is added to the `SAVEABLE_OBJECTS` collection.

  Args:
    params: a Variable for weight and bias parameters.
    model: a CudnnRNN model.
    base_variable_scope: a string, prefix of names of saved variables.
    name: a string, name of the RNNParamsSaveable object.
  Returns:
    a RNNParamsSaveable object.
  Raises:
    ValueError: if `model._rnn_mode` is not one of the supported modes.
  """
  rnn_mode = model._rnn_mode
  if rnn_mode == CUDNN_LSTM:
    fn = cudnn_rnn_ops.CudnnLSTMSaveable
  elif rnn_mode == CUDNN_GRU:
    fn = cudnn_rnn_ops.CudnnGRUSaveable
  elif rnn_mode == CUDNN_RNN_TANH:
    fn = cudnn_rnn_ops.CudnnRNNTanhSaveable
  elif rnn_mode == CUDNN_RNN_RELU:
    fn = cudnn_rnn_ops.CudnnRNNReluSaveable
  else:
    # Without this branch an unknown mode would leave `fn` unbound and surface
    # as an UnboundLocalError at the call below.
    raise ValueError("Unsupported CudnnRNN mode: %r" % (rnn_mode,))
  params_saveable = fn(
      params,
      model.num_layers,
      model.num_units,
      model.input_size,
      model.input_mode,
      model.direction,
      scope=base_variable_scope,
      name=name)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
  return params_saveable
开发者ID:1000sprites,项目名称:tensorflow,代码行数:33,代码来源:cudnn_rnn_ops_test.py


示例11: _train_model

  def _train_model(self, input_fn, hooks):
    """Builds the training graph, runs the training loop and returns the last loss.

    Args:
      input_fn: Callable returning a `(features, labels)` pair; it is invoked
        under the '/cpu:0' device.
      hooks: Iterable of extra hooks appended after the built-in NaN and
        logging hooks.

    Returns:
      The loss value from the final training step, or None if no step ran.
    """
    all_hooks = []
    config = self._config
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(config.tf_random_seed)
      global_step_tensor = training.create_global_step(g)
      with ops.device('/cpu:0'):
        features, labels = input_fn()
      spec = self._call_model_fn(features, labels,
                                 model_fn_lib.ModeKeys.TRAIN)
      ops.add_to_collection(ops.GraphKeys.LOSSES, spec.loss)

      nan_hook = training.NanTensorHook(spec.loss)
      logging_hook = training.LoggingTensorHook(
          {
              'loss': spec.loss,
              'step': global_step_tensor
          },
          every_n_iter=100)
      all_hooks.extend([nan_hook, logging_hook])
      all_hooks.extend(hooks)
      all_hooks.extend(spec.training_hooks)

      # Add a default saver only when neither the scaffold nor the SAVERS
      # collection already has one.
      if (not spec.scaffold.saver and
          not ops.get_collection(ops.GraphKeys.SAVERS)):
        default_saver = training.Saver(
            sharded=True,
            max_to_keep=config.keep_checkpoint_max,
            defer_build=True)
        ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

      chief_hooks = []
      if config.save_checkpoints_secs or config.save_checkpoints_steps:
        candidate_hooks = (
            all_hooks + chief_hooks + spec.training_chief_hooks)
        saver_hook_exists = any(
            isinstance(hook, training.CheckpointSaverHook)
            for hook in candidate_hooks)
        if not saver_hook_exists:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=config.save_checkpoints_secs,
                  save_steps=config.save_checkpoints_steps,
                  scaffold=spec.scaffold)
          ]

      with training.MonitoredTrainingSession(
          master=config.master,
          is_chief=config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=spec.scaffold,
          hooks=all_hooks,
          chief_only_hooks=chief_hooks + spec.training_chief_hooks,
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=config.save_summary_steps,
          config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess:
        last_loss = None
        fetches = [spec.train_op, spec.loss]
        while not mon_sess.should_stop():
          _, last_loss = mon_sess.run(fetches)
      return last_loss
开发者ID:LugarkPirog,项目名称:tensorflow,代码行数:60,代码来源:estimator.py


示例12: initialize

  def initialize(self, table):
    """Builds the op that loads this initializer's keys and values into `table`.

    The created op is also added to the `TABLE_INITIALIZERS` collection.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    _check_table_dtypes(table, self._keys.dtype, self._values.dtype)

    scope_inputs = (table.table_ref, self._keys, self._values)
    with ops.name_scope(self._name, values=scope_inputs) as op_name:
      if context.executing_eagerly():
        # Ensure a unique name when eager execution is enabled to avoid spurious
        # sharing issues.
        op_name = op_name + str(ops.uid())
      table_init_op = gen_lookup_ops.initialize_table_v2(
          table.table_ref, self._keys, self._values, name=op_name)

    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, table_init_op)
    return table_init_op
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:25,代码来源:lookup_ops.py


示例13: _add_iterator_ops_to_collection

 def _add_iterator_ops_to_collection(self, init_op, get_next):
   """Adds `init_op` and each flattened `get_next` tensor to "iterator_ops"."""
   collection_key = "iterator_ops"
   ops.add_to_collection(collection_key, init_op)
   # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
   # do not support tuples we flatten the tensors and restore the shape in
   # `_get_iterator_ops_from_collection`.
   flat_tensors = nest.flatten(get_next)
   for tensor in flat_tensors:
     ops.add_to_collection(collection_key, tensor)
开发者ID:dyoung418,项目名称:tensorflow,代码行数:7,代码来源:dataset_serialization_test_base.py


示例14: tree_variable

def tree_variable(params, tree_config, stats_handle, name, container=None):
  r"""Builds a decision tree resource and returns a handle to it.

  A saveable for the tree is added to the `SAVEABLE_OBJECTS` collection and the
  handle is registered through `resources.register_resource`.

  Args:
    params: A TensorForestParams object.
    tree_config: A `Tensor` of type `string`. Serialized proto of the tree.
    stats_handle: Resource handle to the stats object.
    name: A name for the variable.
    container: An optional `string`. Defaults to `""`.

  Returns:
    A `Tensor` of type mutable `string`. The handle to the tree.
  """
  with ops.name_scope(name, "TreeVariable") as name:
    handle = gen_model_ops.decision_tree_resource_handle_op(
        container, shared_name=name, name=name)
    init_op = gen_model_ops.create_tree_variable(
        handle, tree_config, params=params.serialized_params_proto)
    initialized_check_op = gen_model_ops.tree_is_initialized_op(handle)

    # Adds the variable to the savable list.
    tree_saveable = TreeVariableSavable(
        params, handle, stats_handle, init_op, handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, tree_saveable)

    resources.register_resource(handle, init_op, initialized_check_op)
    return handle
开发者ID:1000sprites,项目名称:tensorflow,代码行数:29,代码来源:model_ops.py


示例15: testCustomSaveable

  def testCustomSaveable(self):
    """Saves a model holding a `CheckpointedOp` table, reloads it, and checks its content."""
    export_dir = self._get_export_dir("custom_saveable")
    builder = saved_model_builder.SavedModelBuilder(export_dir)
    two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})

    with session.Session(graph=ops.Graph(), config=two_cpu_config) as sess:
      # CheckpointedOp is a key-value table that can be saved across sessions.
      # The table register itself in SAVEABLE_OBJECTS collection.
      saved_table = saver_test_utils.CheckpointedOp(name="v1")
      variables.global_variables_initializer().run()
      saved_table.insert("k1", 3.0).run()
      # Once the table is restored, we can access it through this reference.
      ops.add_to_collection("table_ref", saved_table.table_ref)
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Save the SavedModel to disk.
    builder.save()

    reload_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with session.Session(graph=ops.Graph(), config=reload_config) as sess:
      loader.load(sess, ["foo"], export_dir)
      # Instantiate a wrapper object from the checkpointed reference.
      restored_ref = ops.get_collection("table_ref")[0]
      restored_table = saver_test_utils.CheckpointedOp(
          name="v1", table_ref=restored_ref)
      self.assertEqual(b"k1", restored_table.keys().eval())
      self.assertEqual(3.0, restored_table.values().eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:28,代码来源:saved_model_test.py


示例16: __init__

  def __init__(self, iterator_resource, initializer, output_types,
               output_shapes, output_classes):
    """Wraps an existing iterator resource in a new iterator object.

    Note: Most users will not call this initializer directly, and will
    instead use `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this iterator.
      output_shapes: A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this iterator.
      output_classes: A nested structure of Python `type` objects corresponding
        to each component of an element of this iterator.
    """
    # Plain bookkeeping: store the constructor arguments as given.
    self._iterator_resource, self._initializer = iterator_resource, initializer
    self._output_classes = output_classes
    self._output_types = output_types
    self._output_shapes = output_shapes

    resource = self._iterator_resource
    self._string_handle = gen_dataset_ops.iterator_to_string_handle(resource)
    self._get_next_call_count = 0
    ops.add_to_collection(GLOBAL_ITERATORS, resource)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:29,代码来源:iterator_ops.py


示例17: summary_writer_function

def summary_writer_function(name, tensor, function, family=None):
  """Conditionally writes a summary through `function`.

  Returns a no-op when the context has no summary writer resource. Otherwise
  the conditional summary op is built under the "cpu:0" device and added to the
  `ops.GraphKeys._SUMMARY_COLLECTION` collection.

  Args:
    name: name of the summary
    tensor: main tensor to form the summary
    function: function taking a tag and a scope which writes the summary
    family: optional, the summary's family

  Returns:
    The result of writing the summary.
  """
  def record():
    scope_manager = summary_op_util.summary_scope(
        name, family, values=[tensor])
    with scope_manager as (tag, scope):
      write_op = function(tag, scope)
      # Return a constant that depends on the op `function` produced.
      with ops.control_dependencies([write_op]):
        return constant_op.constant(True)

  writer_resource = context.context().summary_writer_resource
  if writer_resource is None:
    return control_flow_ops.no_op()

  with ops.device("cpu:0"):
    summary_op = utils.smart_cond(
        should_record_summaries(), record, _nothing, name="")
    # pylint: disable=protected-access
    ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, summary_op)
    # pylint: enable=protected-access
  return summary_op
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:25,代码来源:summary_ops.py


示例18: tree_ensemble_variable

def tree_ensemble_variable(stamp_token,
                           tree_ensemble_config,
                           name,
                           container=None):
  r"""Builds a tree ensemble resource and returns a handle to it.

  A saveable for the ensemble is added to the `SAVEABLE_OBJECTS` collection and
  the handle is registered through `resources.register_resource`.

  Args:
    stamp_token: The initial stamp token value for the ensemble resource.
    tree_ensemble_config: A `Tensor` of type `string`.
      Serialized proto of the tree ensemble.
    name: A name for the ensemble variable.
    container: An optional `string`. Defaults to `""`.

  Returns:
    A `Tensor` of type mutable `string`. The handle to the tree ensemble.
  """
  with ops.name_scope(name, "TreeEnsembleVariable") as name:
    handle = gen_model_ops.decision_tree_ensemble_resource_handle_op(
        container, shared_name=name, name=name)
    init_op = gen_model_ops.create_tree_ensemble_variable(
        handle, stamp_token, tree_ensemble_config)
    initialized_check_op = gen_model_ops.tree_ensemble_is_initialized_op(
        handle)

    # Adds the variable to the savable list.
    ensemble_saveable = TreeEnsembleVariableSavable(
        handle, init_op, handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, ensemble_saveable)

    resources.register_resource(handle, init_op, initialized_check_op)
    return handle
开发者ID:1000sprites,项目名称:tensorflow,代码行数:29,代码来源:model_ops.py


示例19: initialize

  def initialize(self, table):
    """Builds the op that loads this initializer's keys and values into `table`.

    The created op is also added to the `TABLE_INITIALIZERS` collection.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    _check_table_dtypes(table, self._keys.dtype, self._values.dtype)

    scope_inputs = (table.resource_handle, self._keys, self._values)
    with ops.name_scope(self._name, values=scope_inputs) as op_name:
      if context.executing_eagerly():
        # Ensure a unique name when eager execution is enabled to avoid spurious
        # sharing issues.
        op_name = op_name + str(ops.uid())

      if fwd_compat.forward_compatible(2018, 9, 19):
        build_init_op = gen_lookup_ops.lookup_table_import_v2
      else:
        # To maintain forward compatibility, use the old implementation.
        build_init_op = gen_lookup_ops.initialize_table_v2
      table_init_op = build_init_op(
          table.resource_handle, self._keys, self._values, name=op_name)

    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, table_init_op)
    return table_init_op
开发者ID:aeverall,项目名称:tensorflow,代码行数:30,代码来源:lookup_ops.py


示例20: test_train_worker_monitor

 def test_train_worker_monitor(self):
   """Trains one step as a non-chief worker and checks which monitors ran."""
   # We need to explicitly set device due to check on non-chief workers
   # requiring all variables to have a device assigned.
   with tf.Graph().as_default() as g, g.device('/cpu:0'):
     step_var = tf.contrib.framework.create_global_step(g)
     increment_step = tf.assign_add(step_var, 1)
     constant_loss = tf.constant(2.0)
     tf.scalar_summary('loss', constant_loss)
     # Add explicit "local" init op to initialize all variables
     # as there's no chief to init here.
     local_init = variables.initialize_all_variables()
     ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init)
     # Create worker monitors where one should be active on the worker
     # and the other chief exclusive.
     chief_only = _BaseMonitorWrapper(False)
     every_worker = _BaseMonitorWrapper(True)
     with self.test_session(g):
       final_loss = learn.graph_actions.train(
           g,
           output_dir=self._output_dir,
           global_step_tensor=step_var,
           train_op=increment_step,
           loss_op=constant_loss,
           supervisor_is_chief=False,
           steps=1,
           monitors=[chief_only, every_worker])
     self.assertEqual(2.0, final_loss)

     only_worker_active = (
         not chief_only.is_active and every_worker.is_active)
     self.assertTrue(only_worker_active,
                     'Only non-chief runnable monitor must have been active.')

     only_worker_stepped = (
         not chief_only.has_step and every_worker.has_step)
     self.assertTrue(only_worker_stepped,
                     'Only non-chief runnable monitor must have a step.')
开发者ID:MostafaGazar,项目名称:tensorflow,代码行数:30,代码来源:graph_actions_test.py



注:本文中的tensorflow.python.framework.ops.add_to_collection函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python ops.add_to_collections函数代码示例发布时间:2022-05-27
下一篇:
Python ops._name_from_scope_name函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap