• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python ops.get_collection函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.framework.ops.get_collection函数的典型用法代码示例。如果您正苦于以下问题:Python get_collection函数的具体用法?Python get_collection怎么用?Python get_collection使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了get_collection函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: __init__

    def __init__(self,
                 checkpoint_dir,
                 display_steps=100,
                 maximum_train_steps=None,
                 do_summary=True,
                 is_chief=True):
        """ Creates the display hook and collects the values it will show.

        Args:
            checkpoint_dir: A string, base directory for the checkpoint files.
            display_steps: A python integer, display every N steps.
            maximum_train_steps: A python integer, the maximum training steps.
            do_summary: Whether to save summaries when display.
            is_chief: Whether this is the chief process.
        """
        tf.logging.info("Create DisplayHook.")

        # Constructor arguments are stored unchanged.
        self._checkpoint_dir = checkpoint_dir
        self._display_steps = display_steps
        self._maximum_train_steps = maximum_train_steps
        self._do_summary = do_summary
        self._is_chief = is_chief  # not used now

        # Pair each registered display key with its registered value, then
        # expose the global step under its own fixed key.
        current_step = training_util.get_global_step()
        names = ops.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
        values = ops.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
        shown = dict(zip(names, values))
        shown["global_step"] = current_step
        self._display_args = shown

        # Timers and the summary writer start out unset.
        self._timer = self._logging_timer = self._summary_writer = None
开发者ID:KIngpon,项目名称:NJUNMT-tf,代码行数:34,代码来源:hooks.py


示例2: testAddWeight

    def testAddWeight(self):
        """Exercises `_Layer._add_variable` for plain, non-trainable and regularized variables."""
        trainable_key = ops.GraphKeys.TRAINABLE_VARIABLES
        with self.test_session():
            layer = base_layers._Layer(name="my_layer")

            # A plain variable shows up as the layer's only (trainable)
            # variable and in the graph's trainable collection.
            first = layer._add_variable(
                "my_var", [2, 2], initializer=init_ops.zeros_initializer)
            self.assertEqual(first.name, "my_var:0")
            self.assertListEqual(layer.variables, [first])
            self.assertListEqual(layer.trainable_variables, [first])
            self.assertListEqual(layer.non_trainable_variables, [])
            self.assertListEqual(layer.variables, ops.get_collection(trainable_key))

            # layer._add_variable should work even outside `build` and `call`.
            # A non-trainable variable must not grow the trainable collection.
            frozen = layer._add_variable(
                "non_trainable_var",
                [2, 2],
                initializer=init_ops.zeros_initializer,
                trainable=False,
            )
            self.assertListEqual(layer.variables, [first, frozen])
            self.assertListEqual(layer.trainable_variables, [first])
            self.assertListEqual(layer.non_trainable_variables, [frozen])
            self.assertEqual(len(ops.get_collection(trainable_key)), 1)

            # A variable created with a regularizer should register one loss.
            scaled_sum = lambda x: math_ops.reduce_sum(x) * 1e-3
            layer._add_variable(
                "reg_var",
                [2, 2],
                initializer=init_ops.zeros_initializer,
                regularizer=scaled_sum,
            )
            self.assertEqual(len(layer.losses), 1)
开发者ID:BloodD,项目名称:tensorflow,代码行数:28,代码来源:base_test.py


示例3: after_create_session

  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    """Does first run which shows the eval metrics before training.

    Rejects graphs that register SAVEABLE_OBJECTS, narrows the train and eval
    variable maps to the names they share with the placeholders, builds the op
    that assigns each placeholder to its eval variable, then evaluates once.

    Args:
      session: Forwarded to `self._evaluate`.
      coord: Unused.

    Raises:
      ValueError: If the SAVEABLE_OBJECTS collection is not empty.
    """
    saveables = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS)
    if saveables:
      raise ValueError(
          'InMemoryEvaluator does not support saveables other than global '
          'variables.')

    self._var_name_to_train_var = dict(
        (var.name, var)
        for var in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
    shared_names = (set(self._var_name_to_placeholder.keys()) &
                    set(self._var_name_to_train_var.keys()))

    # Drop training variables that do not exist in evaluation.
    all_train_vars = self._var_name_to_train_var
    self._var_name_to_train_var = dict(
        (name, all_train_vars[name]) for name in shared_names)
    # Drop eval variables that do not exist in training.
    all_eval_vars = self._var_name_to_eval_var
    self._var_name_to_eval_var = dict(
        (name, all_eval_vars[name]) for name in shared_names)

    with self._graph.as_default():
      assign_ops = []
      for name in shared_names:
        assign_ops.append(
            state_ops.assign(self._var_name_to_eval_var[name],
                             self._var_name_to_placeholder[name]))
      self._var_feed_op = control_flow_ops.group(assign_ops)

    self._evaluate(session)
开发者ID:ChristinaEricka,项目名称:tensorflow,代码行数:30,代码来源:hooks.py


示例4: testAddWeight

  def testAddWeight(self):
    """Checks `Layer.add_variable` bookkeeping for trainable and non-trainable variables.

    When `context.in_graph_mode()` is true it additionally checks that a
    variable created with a regularizer adds one entry to `layer.losses`.
    """
    layer = base_layers.Layer(name='my_layer')

    # Test basic variable creation.
    variable = layer.add_variable(
        'my_var', [2, 2], initializer=init_ops.zeros_initializer())
    # The variable name is expected to be prefixed with the layer name.
    self.assertEqual(variable.name, 'my_layer/my_var:0')
    self.assertListEqual(layer.variables, [variable])
    self.assertListEqual(layer.trainable_variables, [variable])
    self.assertListEqual(layer.non_trainable_variables, [])
    self.assertListEqual(layer.variables,
                         ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

    # Test non-trainable variable creation.
    # layer.add_variable should work even outside `build` and `call`.
    variable_2 = layer.add_variable(
        'non_trainable_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        trainable=False)
    self.assertListEqual(layer.variables, [variable, variable_2])
    self.assertListEqual(layer.trainable_variables, [variable])
    self.assertListEqual(layer.non_trainable_variables, [variable_2])
    # The trainable collection must still hold only the first variable.
    self.assertEqual(
        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)

    if context.in_graph_mode():
      # regularizers only supported in GRAPH mode.
      regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
      variable = layer.add_variable(
          'reg_var', [2, 2],
          initializer=init_ops.zeros_initializer(),
          regularizer=regularizer)
      self.assertEqual(len(layer.losses), 1)
开发者ID:keveman,项目名称:tensorflow,代码行数:33,代码来源:base_test.py


示例5: testVariableCollections

 def testVariableCollections(self):
   """Variables created with `collections=` are retrievable from every listed collection."""
   with self.test_session():
     a = variables_lib2.variable('a', [], collections=['A', 'C'])
     b = variables_lib2.variable('b', [], collections=['B', 'C'])
     # `assertEquals` is a deprecated unittest alias (removed in Python 3.12);
     # `assertEqual` performs the same check.
     self.assertEqual(a, ops.get_collection('A')[0])
     self.assertEqual(b, ops.get_collection('B')[0])
     # Collection 'C' is expected to hold both variables in creation order.
     self.assertListEqual([a, b], ops.get_collection('C'))
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:7,代码来源:variables_test.py


示例6: _train_op_fn

  def _train_op_fn(loss):
    """Returns the op to optimize the loss.

    Minimizes `loss` separately with the DNN and the linear optimizer
    (whichever has logits), restricting each one to the trainable variables
    collected under its parent scope, and then increments the global step
    under a control dependency on the grouped minimize ops.

    Args:
      loss: Passed to each optimizer's `minimize`.

    Returns:
      The `assign_add` op that increments the global step by 1.
    """
    train_ops = []
    global_step = training_util.get_global_step()
    if dnn_logits is not None:
      train_ops.append(
          dnn_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=dnn_parent_scope)))
    if linear_logits is not None:
      train_ops.append(
          linear_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=linear_parent_scope)))

    train_op = control_flow_ops.group(*train_ops)
    # The step increment only runs after the grouped train op has run.
    with ops.control_dependencies([train_op]):
      with ops.colocate_with(global_step):
        return state_ops.assign_add(global_step, 1)

    # NOTE(review): as indented here this statement appears unreachable (the
    # `with` block above returns) and it passes `_train_op_fn` itself as
    # `train_op_fn`; it presumably belongs to the enclosing function one
    # indentation level out -- confirm against the original file.
    return head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        train_op_fn=_train_op_fn,
        logits=logits)
开发者ID:m-colombo,项目名称:tensorflow,代码行数:30,代码来源:dnn_linear_combined.py


示例7: testSaveAsText

  def testSaveAsText(self):
    """Saves two meta graphs as a text-format SavedModel and reloads each tag."""
    export_dir = self._get_export_dir("test_astext")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Tag "foo": graph with a single variable, added together with its
    # weights.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Tag "bar": graph with the same single variable; only the model is added
    # (weights are not updated).
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 43)
      builder.add_meta_graph(["bar"])

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Restore "foo" (whose variables were saved) and then "bar" (whose
    # variables were not saved); the variable is expected to read 42 in both.
    for tag in ("foo", "bar"):
      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag], export_dir)
        restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0]
        self.assertEqual(42, restored.eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:30,代码来源:saved_model_test.py


示例8: testTrainOpGroup

  def testTrainOpGroup(self):
    """Saves a grouped train op with a meta graph and checks it reloads as an Operation."""
    export_dir = self._get_export_dir("test_train_op_group")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      # Add `v1` and `v2` variables to the graph.
      v1 = variables.Variable(1, name="v1")
      ops.add_to_collection("v", v1)
      v2 = variables.Variable(2, name="v2")
      ops.add_to_collection("v", v2)

      sess.run(variables.global_variables_initializer())
      # `group()` is called with no inputs here.
      train_op = control_flow_ops.group()

      sess.run(train_op)
      # TODO(karmel): remove explicit call when in the public method.
      builder._add_train_op(train_op)
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Save the SavedModel to disk.
    builder.save()

    # Reload into a fresh graph and check both the variables and the train op.
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      self.assertEqual(1, ops.get_collection("v")[0].eval())
      self.assertEqual(2, ops.get_collection("v")[1].eval())
      self.assertIsInstance(
          ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:28,代码来源:saved_model_test.py


示例9: testCustomMainOp

  def testCustomMainOp(self):
    """Saves a custom main_op that assigns v1 + v2 to v3 and checks v3 after reload."""
    export_dir = self._get_export_dir("test_main_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      # Add `v1` and `v2` variables to the graph.
      v1 = variables.Variable(1, name="v1")
      ops.add_to_collection("v", v1)
      v2 = variables.Variable(2, name="v2")
      ops.add_to_collection("v", v2)

      # Initialize another variable `v3` to 42.
      v3 = variables.Variable(42, name="v3")
      ops.add_to_collection("v", v3)

      # Set up an assignment op to be run as part of the main_op.
      # The assignment is built under a control dependency on
      # `main_op.main_op()`.
      with ops.control_dependencies([main_op.main_op()]):
        add_v1_v2 = math_ops.add(v1._ref(), v2._ref())
        custom_main_op = control_flow_ops.group(state_ops.assign(v3, add_v1_v2))

      sess.run(custom_main_op)
      builder.add_meta_graph_and_variables(
          sess, ["foo"], main_op=custom_main_op)

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      self.assertEqual(1, ops.get_collection("v")[0].eval())
      self.assertEqual(2, ops.get_collection("v")[1].eval())
      # Evaluates to the sum of the first two variables and assigned as part of
      # the main_op, following a restore.
      self.assertEqual(3, ops.get_collection("v")[2].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:34,代码来源:saved_model_test.py


示例10: testLegacyInitOp

  def testLegacyInitOp(self):
    """Saves a legacy_init_op that assigns v1 + v2 to v3 and checks v3 after reload."""
    export_dir = self._get_export_dir("test_legacy_init_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      # Add `v1` and `v2` variables to the graph.
      v1 = variables.Variable(1, name="v1")
      ops.add_to_collection("v", v1)
      v2 = variables.Variable(2, name="v2")
      ops.add_to_collection("v", v2)

      # Initialize another variable `v3` to 42.
      # `v3` is created non-trainable with an empty `collections` list and is
      # then added to the custom "v" collection by hand.
      v3 = variables.Variable(42, name="v3", trainable=False, collections=[])
      ops.add_to_collection("v", v3)

      # Set up an assignment op to be run as part of the legacy_init_op.
      assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
      legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op")

      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], legacy_init_op=legacy_init_op)

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      self.assertEqual(1, ops.get_collection("v")[0].eval())
      self.assertEqual(2, ops.get_collection("v")[1].eval())
      # Evaluates to the sum of the first two variables and assigned as part of
      # the legacy_init_op, following a restore.
      self.assertEqual(3, ops.get_collection("v")[2].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:33,代码来源:saved_model_test.py


示例11: test_example

  def test_example(self):
    """Reduces the metric variables of three towers and checks the accumulated values."""
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      session.run(
          replicate_model_fn._reduce_metric_variables(number_of_towers=3))

      # 1st tower = 1.3, 2.3,  [3.3, 3.5, 3.7]
      # 2nd tower = 2.6, 4.6,  [6.6, 7.0, 7.4]
      # 3rd tower = 3.9, 6.9,  [9.9, 10.5, 11.1]
      # Reduced =   7.8, 13.8, [19.8, 21.0, 22.2]
      # Towers are accumulated in the first tower.
      local_metrics = session.run(
          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

      self.assertNear(7.8, local_metrics[0], 0.01)
      self.assertNear(13.8, local_metrics[1], 0.01)
      # 3.7 + 7.4 + 11.1 = 22.2, matching the table above. The earlier
      # expectation of 22.1 presumably only passed because of the tolerance.
      self.assertAllClose([19.8, 21., 22.2], local_metrics[2], 0.01)
      # The remaining six entries are expected to be zero after the reduction.
      self.assertNear(0.0, local_metrics[3], 0.01)
      self.assertNear(0.0, local_metrics[4], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
      self.assertNear(0.0, local_metrics[6], 0.01)
      self.assertNear(0.0, local_metrics[7], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:29,代码来源:replicate_model_fn_test.py


示例12: testMultipleConvMaskAdded

  def testMultipleConvMaskAdded(self):
    """Stacks several masked conv layers and checks mask and masked-weight shapes."""
    number_of_layers = 5
    kernel_size = 3
    base_depth = 4
    depth_step = 7

    # depths[0] is the last dimension of the input tensor; depths[i + 1] is the
    # depth passed to the i-th `masked_conv2d` call.
    depths = [
        base_depth + ix * depth_step for ix in range(number_of_layers + 1)
    ]

    top_layer = array_ops.ones((8, self.height, self.width, base_depth))
    for out_depth in depths[1:]:
      top_layer = layers.masked_conv2d(top_layer, out_depth, kernel_size)

    expected_shapes = [
        [kernel_size, kernel_size, in_depth, out_depth]
        for in_depth, out_depth in zip(depths[:-1], depths[1:])
    ]

    # Masks first, then masked weights: one entry per layer, each with the
    # expected shape for that layer.
    for collection_key in (core_layers.MASK_COLLECTION,
                           core_layers.MASKED_WEIGHT_COLLECTION):
      tensors = ops.get_collection(collection_key)
      self.assertEqual(len(tensors), number_of_layers)
      for ix in range(number_of_layers):
        self.assertListEqual(tensors[ix].get_shape().as_list(),
                             expected_shapes[ix])
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:30,代码来源:layers_test.py


示例13: testTags

  def testTags(self):
    """Builds a SavedModel with several tag sets and checks which tag queries load."""
    export_dir = os.path.join(test.get_temp_dir(), "test_tags")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph with a single variable. SavedModel invoked to:
    # - add with weights.
    # - a single tag (from predefined constants).
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])

    # Two graphs that update the single variable. For each, SavedModel only
    # adds the model (weights are not updated): first under a single
    # predefined tag, then under multiple custom tags.
    model_only_graphs = (
        (43, [tag_constants.SERVING]),
        (44, ["foo", "bar"]),
    )
    for value, tags in model_only_graphs:
      with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", value)
        builder.add_meta_graph(tags)

    # Save the SavedModel to disk.
    builder.save()

    # Tag queries that must load; the variable is expected to read 42 each time.
    loadable_tag_sets = (
        # A single predefined tag whose variables were saved.
        [tag_constants.TRAINING],
        # A single predefined tag whose variables were not saved.
        [tag_constants.SERVING],
        # Multiple tags, with a duplicate to test set semantics.
        ["foo", "bar", "foo"],
    )
    for tags in loadable_tag_sets:
      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, tags, export_dir)
        restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0]
        self.assertEqual(42, restored.eval())

    # Tag queries that must yield a runtime error.
    rejected_tag_sets = (
        # A non-existent tag.
        ["INVALID"],
        # Only a subset of the tags match; tag matching for meta graph defs
        # follows "all" semantics.
        ["foo", "baz"],
    )
    for tags in rejected_tag_sets:
      with self.test_session(graph=ops.Graph()) as sess:
        self.assertRaises(RuntimeError, loader.load, sess, tags, export_dir)
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:60,代码来源:saved_model_test.py


示例14: test_reduce_is_idempotent

  def test_reduce_is_idempotent(self):
    """Runs the metric reduction 20 times and checks the reduced values."""
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      # Repeating the reduction should not change the outcome.
      for _ in range(20):
        session.run(
            replicate_model_fn._reduce_metric_variables(number_of_towers=3))

      local_metrics = session.run(
          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

      self.assertNear(7.8, local_metrics[0], 0.01)
      self.assertNear(13.8, local_metrics[1], 0.01)
      # NOTE(review): assuming `create_tower_metrics` yields 3.7, 7.4 and 11.1
      # as the last vector element of the three towers (as tabulated in
      # `test_example`), the reduced value is 22.2 rather than 22.1 -- confirm.
      self.assertAllClose([19.8, 21., 22.2], local_metrics[2], 0.01)
      # The remaining six entries are expected to be zero after the reduction.
      self.assertNear(0.0, local_metrics[3], 0.01)
      self.assertNear(0.0, local_metrics[4], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
      self.assertNear(0.0, local_metrics[6], 0.01)
      self.assertNear(0.0, local_metrics[7], 0.01)
      self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:25,代码来源:replicate_model_fn_test.py


示例15: testCreateBN

  def testCreateBN(self):
    """Applies BatchNormalization on axis 1 and checks shape, attributes and collections."""
    # Call layer.
    layer = normalization_layers.BatchNormalization(axis=1)
    batch = random_ops.random_uniform((5, 4, 3), seed=1)
    is_training = array_ops.placeholder(dtype='bool')
    normalized = layer.apply(batch, training=is_training)

    # Verify shape.
    self.assertListEqual(normalized.get_shape().as_list(), [5, 4, 3])

    # Verify layer attributes: expected number of entries per attribute.
    expected_counts = (
        ('updates', 2),
        ('variables', 4),
        ('trainable_variables', 2),
        ('non_trainable_variables', 2),
    )
    for attr_name, count in expected_counts:
      self.assertEqual(len(getattr(layer, attr_name)), count)

    # Test that updates were created and added to UPDATE_OPS.
    self.assertEqual(len(layer.updates), 2)
    update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
    self.assertListEqual(update_ops, layer.updates)

    # Test that weights were created and added to TRAINABLE_VARIABLES.
    trainable = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertListEqual(trainable, layer.trainable_variables)
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:25,代码来源:normalization_test.py


示例16: _make_training_op

  def _make_training_op(training_loss):
    """Training op for the DNN linear combined model.

    Builds one `optimize_loss` op for the DNN part and one for the linear part
    (each only when its logits exist) and groups them.

    Args:
      training_loss: Passed as `loss` to each `optimize_loss` call.

    Returns:
      The result of `control_flow_ops.group` over the ops that were built.
    """
    built_ops = []

    if dnn_logits is not None:
      dnn_train_op = optimizers.optimize_loss(
          loss=training_loss,
          global_step=contrib_variables.get_global_step(),
          learning_rate=_DNN_LEARNING_RATE,
          optimizer=_get_optimizer(dnn_optimizer),
          gradient_multipliers=_extract_embedding_lr_multipliers(  # pylint: disable=protected-access
              embedding_lr_multipliers, dnn_parent_scope,
              dnn_input_scope.name),
          clip_gradients=gradient_clip_norm,
          # Variables are looked up in the collection named after the scope.
          variables=ops.get_collection(dnn_parent_scope),
          name=dnn_parent_scope,
          # Empty summaries, because head already logs "loss" summary.
          summaries=[])
      built_ops.append(dnn_train_op)

    if linear_logits is not None:
      linear_train_op = optimizers.optimize_loss(
          loss=training_loss,
          global_step=contrib_variables.get_global_step(),
          learning_rate=_linear_learning_rate(len(linear_feature_columns)),
          optimizer=_get_optimizer(linear_optimizer),
          clip_gradients=gradient_clip_norm,
          # Variables are looked up in the collection named after the scope.
          variables=ops.get_collection(linear_parent_scope),
          name=linear_parent_scope,
          # Empty summaries, because head already logs "loss" summary.
          summaries=[])
      built_ops.append(linear_train_op)

    return control_flow_ops.group(*built_ops)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:32,代码来源:dnn_linear_combined.py


示例17: _testKLPenaltyBoth

  def _testKLPenaltyBoth(self, layer_class):
    """Checks that a variational conv layer with a bias prior registers two losses.

    Before the layer is called, REGULARIZATION_LOSSES and `layer.losses` are
    both expected to be empty; after one call both are expected to hold the
    same two entries.

    Args:
      layer_class: One of `Conv1DVariational`, `Conv2DVariational` or
        `Conv3DVariational` from `prob_layers_lib`.

    Raises:
      ValueError: If `layer_class` is none of the supported classes.
    """
    def _make_normal(dtype, *args):  # pylint: disable=unused-argument
      return normal_lib.Normal(
          loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.))
    with self.test_session():
      layer = layer_class(
          filters=2,
          kernel_size=3,
          bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(),
          bias_prior_fn=_make_normal)
      # The input rank depends on the layer's dimensionality.
      if layer_class == prob_layers_lib.Conv1DVariational:
        inputs = random_ops.random_uniform([2, 3, 1], seed=1)
      elif layer_class == prob_layers_lib.Conv2DVariational:
        inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1)
      elif layer_class == prob_layers_lib.Conv3DVariational:
        inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1)
      else:
        # Without this branch `inputs` stayed unbound and the test failed
        # later with an unrelated-looking NameError.
        raise ValueError(
            'Unsupported layer_class: {!r}'.format(layer_class))

      # No keys.
      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(losses), 0)
      self.assertListEqual(layer.losses, losses)

      _ = layer(inputs)

      # Yes keys.
      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(losses), 2)
      self.assertListEqual(layer.losses, losses)
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:28,代码来源:layers_conv_variational_test.py


示例18: testTrainOpAfterVariables

  def testTrainOpAfterVariables(self):
    """Adds a train op only to the second meta graph and checks each tag after reload.

    The "pre_foo" meta graph is added (with variables) before the train op is
    registered; the "foo" meta graph is added afterwards.
    """
    export_dir = self._get_export_dir("test_train_op_after_variables")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      # Add `v1` and `v2` variables to the graph.
      v1 = variables.Variable(1, name="v1")
      ops.add_to_collection("v", v1)
      v2 = variables.Variable(2, name="v2")
      ops.add_to_collection("v", v2)

      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(sess, ["pre_foo"])

      train_op = state_ops.assign_add(v1, v2)
      sess.run(train_op)
      # TODO(karmel): remove explicit call when in the public method.
      builder._add_train_op(train_op)
      builder.add_meta_graph(["foo"])

    # Save the SavedModel to disk.
    builder.save()

    # "foo" is expected to carry the train op, reloaded as a Tensor.
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      self.assertIsInstance(
          ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)

    # "pre_foo" is expected to have an empty train-op collection.
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["pre_foo"], export_dir)
      self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY))
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:31,代码来源:saved_model_test.py


示例19: testClearExtraneousSavers

  def testClearExtraneousSavers(self):
    """Checks that a reloaded SavedModel keeps one Saver although two extra ones were added."""
    export_dir = os.path.join(test.get_temp_dir(),
                              "test_clear_extraneous_savers")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    def names_of(op_type, in_graph):
      """Returns the set of names of all `op_type` ops in `in_graph`."""
      return {op.name for op in in_graph.get_operations()
              if op.type == op_type}

    # Create a variable and a Saver.
    with ops.Graph().as_default() as graph:
      session_config = config_pb2.ConfigProto(device_count={"CPU": 2})
      with session.Session(target="", config=session_config) as sess:
        self._init_and_validate_variable(sess, "v", 42)

        # Add two Savers, which should be removed in
        # add_meta_graph_and_variables() in favor of the locally added one.
        for _ in range(2):
          graph.add_to_collection(ops.GraphKeys.SAVERS, tf_saver.Saver())

        # Confirm there are two SaverDefs.
        self.assertEqual(2, len(graph.get_collection(ops.GraphKeys.SAVERS)))

        # Confirm there are two Save and two Restore ops.
        self.assertSetEqual({"save/SaveV2", "save_1/SaveV2"},
                            names_of("SaveV2", graph))
        self.assertSetEqual({"save/RestoreV2", "save_1/RestoreV2"},
                            names_of("RestoreV2", graph))

        # The SavedModel builder adds its own Saver for a total of three.
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.TRAINING], clear_devices=True)

    # Save the SavedModel to disk.
    builder.save()

    # Restore the graph.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        loader.load(sess, [tag_constants.TRAINING], export_dir)
        restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0]
        self.assertEqual(42, restored.eval())

        # Confirm that the reloaded graph has only one SaverDef.
        self.assertEqual(1, len(ops.get_collection(ops.GraphKeys.SAVERS)))

        # The reloaded graph should have exactly one Save and one Restore op.
        self.assertSetEqual({"save_2/SaveV2"}, names_of("SaveV2", graph))
        self.assertSetEqual({"save_2/RestoreV2"}, names_of("RestoreV2", graph))
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py


示例20: testSignatureDefs

  def testSignatureDefs(self):
    """Saves two meta graphs with different signature def maps and checks each after reload.

    Tag "foo" is saved with a single "foo_key" entry; tag "bar" is saved with
    "bar_key" plus a different signature under "foo_key".
    """
    export_dir = self._get_export_dir("test_signature_defs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph with a single variable and a single entry in the signature def map.
    # SavedModel is invoked to add with weights.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      # Build and populate an empty SignatureDef for testing.
      foo_signature = signature_def_utils.build_signature_def(dict(),
                                                              dict(), "foo")
      builder.add_meta_graph_and_variables(
          sess, ["foo"], signature_def_map={"foo_key": foo_signature})

    # Graph with the same single variable and multiple entries in the signature
    # def map. No weights are saved by SavedModel.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 43)
      # Build and populate a different SignatureDef for testing.
      bar_signature = signature_def_utils.build_signature_def(dict(),
                                                              dict(), "bar")
      # Also, build a different SignatureDef corresponding to "foo_key" defined
      # in the previous graph.
      foo_new_signature = signature_def_utils.build_signature_def(dict(),
                                                                  dict(),
                                                                  "foo_new")
      builder.add_meta_graph(
          ["bar"],
          signature_def_map={
              "bar_key": bar_signature,
              "foo_key": foo_new_signature
          })

    # Save the SavedModel to disk.
    builder.save()

    # Restore the graph with tag "foo". The single entry in the SignatureDef map
    # corresponding to "foo_key" should exist.
    with self.test_session(graph=ops.Graph()) as sess:
      foo_graph = loader.load(sess, ["foo"], export_dir)
      self.assertEqual(
          42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())

      foo_signature = foo_graph.signature_def
      self.assertEqual(len(foo_signature), 1)
      self.assertEqual("foo", foo_signature["foo_key"].method_name)

    # Restore the graph with tag "bar". The SignatureDef map should have two
    # entries. One corresponding to "bar_key" and another corresponding to the
    # new value of "foo_key".
    with self.test_session(graph=ops.Graph()) as sess:
      bar_graph = loader.load(sess, ["bar"], export_dir)
      # The variable is still expected to read 42 although this graph was
      # built with 43, since no weights were saved for it.
      self.assertEqual(
          42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())

      bar_signature = bar_graph.signature_def
      self.assertEqual(len(bar_signature), 2)
      self.assertEqual("bar", bar_signature["bar_key"].method_name)
      self.assertEqual("foo_new", bar_signature["foo_key"].method_name)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py



注:本文中的tensorflow.python.framework.ops.get_collection函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python ops.get_collection_proto_type函数代码示例发布时间:2022-05-27
下一篇:
Python ops.enable_eager_execution函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap