• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python state_ops.assign_add函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.ops.state_ops.assign_add函数的典型用法代码示例。如果您正苦于以下问题:Python assign_add函数的具体用法?Python assign_add怎么用?Python assign_add使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了assign_add函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks

  def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
    """Checks that send_traceback_and_source_code=False suppresses uploads.

    Wraps a session in TensorBoardDebugWrapperSession with traceback/source
    sending disabled and runs two assign_add ops four times. A breakpoint
    watch on delta_1:0 is requested before the first run only. After every
    run, the test verifies that the debug server holds no op traceback and
    no source file.
    """
    with session.Session(config=no_rewrite_session_config()) as sess:
      v_1 = variables.Variable(50.0, name="v_1")
      v_2 = variables.Variable(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run(variables.global_variables_initializer())

      # Disable the sending of traceback and source code.
      sess = grpc_wrapper.TensorBoardDebugWrapperSession(
          sess, self._debug_server_url_1, send_traceback_and_source_code=False)

      for i in xrange(4):
        self._server_1.clear_data()

        if i == 0:
          # Only the first run requests a breakpoint on delta_1:0.
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)

        output = sess.run([inc_v_1, inc_v_2])
        # After run i (0-based): v_1 == 50 + 5*(i+1), v_2 == -50 - 5*(i+1).
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

        # No op traceback or source code should have been received by the debug
        # server due to the disabling above.
        # NOTE(review): assertRaisesRegexp is a deprecated alias that was
        # removed from unittest in Python 3.12; confirm the base TestCase
        # still provides it.
        with self.assertRaisesRegexp(
            ValueError, r"Op .*delta_1.* does not exist"):
          self.assertTrue(self._server_1.query_op_traceback("delta_1"))
        with self.assertRaisesRegexp(
            ValueError, r".* has not received any source file"):
          self._server_1.query_source_file_line(__file__, 1)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:33,代码来源:session_debug_grpc_test.py


示例2: _Update_global_variables

 def _Update_global_variables():
   """Builds one elastic-averaging sync step between local and center vars.

   Reads `grads_and_vars`, `self` and `global_step` from the enclosing scope.
   For every variable in `grads_and_vars` whose gradient is not None:
     1. copy its global center variable into its local center variable;
     2. diff = local_var - local_center;
     3. local_var -= moving_rate * diff and global_center += moving_rate * diff.
   If a global step was supplied, it is also incremented by 1 (colocated with
   the step variable).

   Returns:
     A grouped op that runs all of the updates above.
   """
   local_vars = [v for g, v in grads_and_vars if g is not None]
   global_center_vars = [self._global_map[var] for var in local_vars]
   local_center_vars = [self._local_map[var] for var in local_vars]
   local_center_vars_update = []
   for lvar, var in zip(local_center_vars, global_center_vars):
     local_center_vars_update.append(lvar.assign(var))
   update_ops = []
   differences = []
   # Every op created in this context gets a control dependency on the
   # center-copy ops built above.
   with ops.control_dependencies(local_center_vars_update):
     for v, lv in zip(local_vars, local_center_vars):
       with ops.device(v.device):
         differences.append(math_ops.subtract(v, lv))
     for lvar, diff in zip(local_vars, differences):
       with ops.device(lvar.device):
         update_ops.append(
             state_ops.assign_sub(lvar,
                                  math_ops.multiply(self._moving_rate,
                                                    diff)))
     for var, diff in zip(global_center_vars, differences):
       with ops.device(var.device):
         update_ops.append(
             state_ops.assign_add(var,
                                  math_ops.multiply(self._moving_rate,
                                                    diff)))
     # Compare against None explicitly: truth-testing a TF variable/tensor is
     # not a presence check (in graph mode it can raise instead of returning).
     if global_step is not None:
       with ops.colocate_with(global_step):
         update_ops.append(state_ops.assign_add(global_step, 1))
   variable_update = control_flow_ops.group(*update_ops)
   return variable_update
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:30,代码来源:elastic_average_optimizer.py


示例3: test_train_max_steps_is_not_incremental

  def test_train_max_steps_is_not_incremental(self):
    """Checks that `max_steps` is an absolute target, not a step budget.

    Two training runs share one output_dir. The first run stops at step 10.
    The second run, with max_steps=15, must leave the checkpointed global
    step at 15 rather than 10 + 15.
    """
    for target_steps in (10, 15):
      with ops.Graph().as_default() as graph, self.test_session(graph):
        with ops.control_dependencies(self._build_inference_graph()):
          increment_step = state_ops.assign_add(
              variables_lib.get_global_step(), 1)
        learn.graph_actions.train(
            graph,
            output_dir=self._output_dir,
            train_op=increment_step,
            loss_op=constant_op.constant(2.0),
            max_steps=target_steps)
        saved_step = checkpoint_utils.load_variable(
            self._output_dir, variables_lib.get_global_step().name)
        self.assertEqual(target_steps, saved_step)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:26,代码来源:graph_actions_test.py


示例4: _dense_moving_average

 def _dense_moving_average(self, x_tm1, b_t, name, beta=.9):
   """Creates a moving average for a dense variable.

   Increments the step slot `<name>/tm1` and then updates the average slot
   `<name>` as  a_t = beta_t * a_tm1 + (1 - beta_t) * b_t, where
     beta_t = beta * (1 - beta**(t-1)) / (1 - beta**t)   if beta < 1
              (bias-corrected exponential average), and
     beta_t = (t-1) / t                                  otherwise
              (plain running mean).

   Inputs:
     x_tm1: the associated parameter (e.g. a weight matrix)
     b_t: the value to accumulate (e.g. the gradient)
     name: a string to use to retrieve it later (e.g. 'm')
     beta: the decay factor (defaults to .9)
   Outputs:
     a_t: the average after moving
     t: the internal timestep (used to correct initialization bias)
   """

   a_tm1 = self.get_slot(x_tm1, '%s' % name)
   tm1 = self.get_slot(x_tm1, '%s/tm1' % name)
   t = state_ops.assign_add(tm1, 1, use_locking=self._use_locking)
   # Derive the pre-increment step from the assign_add result rather than
   # reading the `tm1` slot again. A second read of the slot has no
   # dependency on the increment, so it could observe either the old or the
   # new value, which would break the bias correction and the running mean.
   tm1_value = t - 1
   if beta < 1:
     beta_t = ops.convert_to_tensor(beta, name='%s/decay' % name)
     beta_t = beta_t * (1 - beta**tm1_value) / (1 - beta**t)
   else:
     beta_t = tm1_value / t
   # Chain the two updates: the assign_add consumes the assign's result, so
   # the decay is applied before the new contribution is added.
   a_t = state_ops.assign(a_tm1, beta_t * a_tm1, use_locking=self._use_locking)
   a_t = state_ops.assign_add(a_t, (1 - beta_t) * b_t,
                              use_locking=self._use_locking)
   return a_t, t
开发者ID:tdozat,项目名称:Optimization,代码行数:25,代码来源:optimizers.py


示例5: model_fn_diff_modes

def model_fn_diff_modes(features, labels, mode):
  """Model fn that returns different constant loss/predictions per mode.

  TRAIN: loss 105, predictions [501], and a train_op that increments the
  global step by 1 and `some_var` by 3. EVAL: loss 106, predictions [502].
  Every other mode: loss 107, predictions [503]. In all modes the spec also
  carries an 'abs_err' mean-absolute-error metric of the predictions
  against 0.
  """
  _, _ = features, labels
  v = variables.Variable(21, name='some_var')
  # Only the TRAIN branch builds a train_op; other modes leave it as None.
  # (Every branch assigns `loss`, so no default loss tensor is created.)
  train_op = None
  if mode == model_fn_lib.ModeKeys.TRAIN:
    loss = constant_op.constant(105)
    predictions = constant_op.constant([501])
    train_op = control_flow_ops.group(
        state_ops.assign_add(training.get_global_step(), 1),
        state_ops.assign_add(v, 3))
  elif mode == model_fn_lib.ModeKeys.EVAL:
    loss = constant_op.constant(106)
    predictions = constant_op.constant([502])
  else:
    loss = constant_op.constant(107)
    predictions = constant_op.constant([503])
  return model_fn_lib.EstimatorSpec(
      mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          'abs_err': metrics_lib.mean_absolute_error(
              constant_op.constant(0), predictions)},
      predictions=predictions)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:25,代码来源:saved_model_estimator_test.py


示例6: test_train_skip_train_if_max_step_already_saved

  def test_train_skip_train_if_max_step_already_saved(self):
    """Checks that a second run with the same max_steps does not train more.

    Both _monitored_train calls target max_steps=10 in the same output_dir.
    The checkpointed global step must be 10 after each of them.
    """
    for _ in range(2):
      with ops.Graph().as_default() as graph, self.test_session(graph):
        with ops.control_dependencies(self._build_inference_graph()):
          increment_step = state_ops.assign_add(
              variables_lib.get_global_step(), 1)
        learn.graph_actions._monitored_train(  # pylint: disable=protected-access
            graph,
            output_dir=self._output_dir,
            train_op=increment_step,
            loss_op=constant_op.constant(2.0),
            max_steps=10)
        saved_step = checkpoint_utils.load_variable(
            self._output_dir, variables_lib.get_global_step().name)
        self.assertEqual(10, saved_step)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:26,代码来源:graph_actions_test.py


示例7: update_state

  def update_state(self, values, sample_weight=None):
    """Accumulates statistics for computing the mean.

    For example, if `values` is [1, 3, 5, 7] then the mean is 4. If
    the `sample_weight` is specified as [1, 1, 0, 0] then the mean would be 2.

    Adds the (optionally weighted) sum of `values` to `self.total`. It also
    adds either the number of values or the sum of the broadcast weights to
    `self.count`.

    Args:
      values: Per-example value.
      sample_weight: Optional weighting of each example. Defaults to 1.

    Returns:
      The result of the `self.count` update. It is created under a control
      dependency on the `self.total` update, so running it applies both.
    """
    # pylint: disable=g-import-not-at-top
    from tensorflow.python.framework import ops

    values = math_ops.cast(values, self._dtype)
    if sample_weight is None:
      num_values = math_ops.cast(array_ops.size(values), self._dtype)
    else:
      sample_weight = math_ops.cast(sample_weight, self._dtype)

      # Update dimensions of weights to match with values.
      values, _, sample_weight = _squeeze_or_expand_dimensions(
          values, None, sample_weight)
      sample_weight = weights_broadcast_ops.broadcast_weights(
          sample_weight, values)
      num_values = math_ops.reduce_sum(sample_weight)
      values = math_ops.multiply(values, sample_weight)
    values = math_ops.reduce_sum(values)

    # Update state variables. Return the update instead of discarding it: in
    # graph mode an assign op nobody references is never run, so the caller
    # needs a handle to it.
    update_total_op = state_ops.assign_add(self.total, values)
    with ops.control_dependencies([update_total_op]):
      return state_ops.assign_add(self.count, num_values)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:28,代码来源:metrics.py


示例8: testMultiEvalStepIncrements

  def testMultiEvalStepIncrements(self):
    """Checks evaluation when eval_ops also increment the eval step.

    Each run of `eval_ops` adds 1.0 to MyVar and adds 1 to the eval step.
    The test expects MyVar to equal num_evals // 2 after evaluation.
    Presumably _StopAfterNEvalsHook counts eval-step increments, so the
    extra increment per run halves the number of runs before it stops.
    """
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    num_evals = 6

    my_var = local_variable(0.0, name='MyVar')
    # In eval ops, we also increase the eval step one more time.
    eval_ops = [state_ops.assign_add(my_var, 1.0),
                state_ops.assign_add(
                    evaluation._get_or_create_eval_step(), 1, use_locking=True)]
    expect_eval_update_counts = num_evals // 2

    # Snapshot of MyVar fetched once evaluation finishes.
    final_ops = array_ops.identity(my_var)

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=eval_ops,
        final_ops={'value': final_ops},
        hooks=[evaluation._StopAfterNEvalsHook(num_evals),])
    self.assertEqual(final_ops_values['value'], expect_eval_update_counts)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:28,代码来源:evaluation_test.py


示例9: testClusterSpecPropagationThreeServers2Graphs

  def testClusterSpecPropagationThreeServers2Graphs(self):
    """Boots 3 servers, creates 2 sessions, ensures appropriate operations.

    We create 2 clusterspecs:
     1. server2 as the master, server1 as a worker
     2. server2 as the master, server3 as a worker

    We ensure that variables on the workers are independent.

    Each graph places its variable on task 1 of its own job: var1 on
    '/job:worker1/task:1' (server1) and var2 on '/job:worker2/task:1'
    (server3). Both sessions connect through server2. The interleaved runs
    below check that incrementing one variable never changes the other.
    """
    server1 = server_lib.Server.create_local_server()
    server2 = server_lib.Server.create_local_server()
    server3 = server_lib.Server.create_local_server()
    # Job 'worker1': task 0 -> server2 (master), task 1 -> server1.
    cluster_def1 = cluster_pb2.ClusterDef()
    job1 = cluster_def1.job.add()
    job1.name = 'worker1'
    job1.tasks[0] = server2.target[len('grpc://'):]
    job1.tasks[1] = server1.target[len('grpc://'):]

    # Job 'worker2': task 0 -> server2 (master), task 1 -> server3.
    cluster_def2 = cluster_pb2.ClusterDef()
    job2 = cluster_def2.job.add()
    job2.name = 'worker2'
    job2.tasks[0] = server2.target[len('grpc://'):]
    job2.tasks[1] = server3.target[len('grpc://'):]

    config1 = config_pb2.ConfigProto(cluster_def=cluster_def1)
    config2 = config_pb2.ConfigProto(cluster_def=cluster_def2)

    with ops.Graph().as_default() as g1:
      with ops.device('/job:worker1/task:1'):
        var1 = variables.Variable(array_ops.zeros([2]), name='var1')
        update_op1 = state_ops.assign_add(
            var1, array_ops.ones([2]), name='var1_assign_add')
        init1 = variables.global_variables_initializer()

    with ops.Graph().as_default() as g2:
      with ops.device('/job:worker2/task:1'):
        var2 = variables.Variable(array_ops.zeros([2]), name='var2')
        update_op2 = state_ops.assign_add(
            var2, array_ops.ones([2]), name='var2_assign_add')
        init2 = variables.global_variables_initializer()

    # Both sessions target server2, each with its own propagated ClusterDef.
    sess1 = session.Session(server2.target, graph=g1, config=config1)
    sess2 = session.Session(server2.target, graph=g2, config=config2)

    init1.run(session=sess1)
    init2.run(session=sess2)

    expected_zeros = np.zeros([2])
    expected_ones = np.ones([2])

    self.assertAllEqual(expected_zeros, sess1.run(var1))
    self.assertAllEqual(expected_zeros, sess2.run(var2))

    # Each update op adds ones([2]) to its own variable only.
    self.assertAllEqual(expected_ones, sess1.run(update_op1))
    self.assertAllEqual(expected_ones, sess1.run(var1))
    self.assertAllEqual(expected_zeros, sess2.run(var2))
    self.assertAllEqual(expected_ones, sess2.run(update_op2))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(update_op1))
    self.assertAllEqual(expected_ones, sess2.run(var2))
    self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:60,代码来源:session_clusterspec_prop_test.py


示例10: adder

 def adder(x, y):
   """Bumps `step`, records x, y and their sum as summaries, returns the sum.

   `step` is taken from the enclosing scope.
   """
   state_ops.assign_add(step, 1)
   for tag, value in (('x', x), ('y', y)):
     summary_ops.generic(tag, value)
   total = x + y
   summary_ops.generic('sum', total)
   return total
开发者ID:AnishShah,项目名称:tensorflow,代码行数:7,代码来源:summary_ops_test.py


示例11: testToggleBreakpointsWorks

  def testToggleBreakpointsWorks(self):
    """Checks that gated-gRPC breakpoints can be toggled between runs.

    Runs 0 and 2 request breakpoint watches on delta_1:0 and delta_2:0 and
    expect the server to receive both debug tensors and to list both
    breakpoints. Runs 1 and 3 request unwatches and expect an empty
    breakpoint set.
    """
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.VariableV1(50.0, name="v_1")
      v_2 = variables.VariableV1(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      # Watch every tensor in the graph with a gated gRPC DebugIdentity op so
      # the server can enable and disable individual watches between runs.
      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1])

      for i in xrange(4):
        self._server_1.clear_data()

        if i in (0, 2):
          # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)
          self._server_1.request_watch(
              "delta_2", 0, "DebugIdentity", breakpoint=True)
        else:
          # Disable the breakpoint in runs 1 and 3.
          self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

        output = sess.run([inc_v_1, inc_v_2],
                          options=run_options, run_metadata=run_metadata)
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

        if i in (0, 2):
          # During runs 0 and 2, the server should have received the published
          # debug tensors delta_[1,2]:0:DebugIdentity. The breakpoints should
          # have been unblocked by EventReply responses from the server.
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
          # After the runs, the server should have properly registered the
          # breakpoints due to the request_watch calls.
          self.assertSetEqual({("delta_1", 0, "DebugIdentity"),
                               ("delta_2", 0, "DebugIdentity")},
                              self._server_1.breakpoints)
        else:
          # After the end of runs 1 and 3, the server has received the requests
          # to disable the breakpoints at delta_[1,2]:0:DebugIdentity.
          self.assertSetEqual(set(), self._server_1.breakpoints)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:57,代码来源:session_debug_grpc_test.py


示例12: _model_fn_with_incremental_loss

 def _model_fn_with_incremental_loss(features, labels, mode):
   """Model fn whose loss tensor bumps a local counter each time it runs.

   The loss is 2 * (local_weight + 1) and is computed after incrementing, so
   successive evaluations yield 2, 4, 6, ... . The train_op increments the
   global step by 1.
   """
   del features, labels  # Unused.
   local_weight = variables.Variable(
       0., name='local_weight', collections=[ops.GraphKeys.LOCAL_VARIABLES])
   bumped_weight = state_ops.assign_add(local_weight, 1.)
   loss = 2 * bumped_weight
   train_op = state_ops.assign_add(training.get_global_step(), 1)
   return model_fn_lib.EstimatorSpec(mode, loss=loss, train_op=train_op)
开发者ID:Immexxx,项目名称:tensorflow,代码行数:10,代码来源:estimator_test.py


示例13: model_fn

 def model_fn(features, labels, mode):
   """Model fn whose Scaffold local-init op adds -3 to `some_var`.

   The train_op increments the global step by 1, and the loss is the
   current value of `some_var`.
   """
   del features, labels  # Unused.
   some_var = variables.Variable(21, name='some_var')
   shift_op = state_ops.assign_add(some_var, -3).op
   scaffold = monitored_session.Scaffold(local_init_op=shift_op)
   step_increment = state_ops.assign_add(training.get_global_step(), 1)
   loss = array_ops.identity(some_var)
   return model_fn_lib.EstimatorSpec(
       mode, scaffold=scaffold, train_op=step_increment, loss=loss)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:11,代码来源:saved_model_estimator_test.py


示例14: setUp

  def setUp(self):
    """Builds a tiny graph with one variable and two assign_add ops.

    `v` starts at 10.0. `inc_v` adds `delta` (1.0) to `v` and `dec_v` adds
    `eta` (-1.4). `v` is initialized in a fresh session, and a temporary
    directory is created as `session_root`.
    """
    # NOTE(review): no tearDown is visible here; presumably it closes
    # self.sess and removes self.session_root. Confirm.
    self.session_root = tempfile.mkdtemp()

    self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v")
    self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
    self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
    self.dec_v = state_ops.assign_add(self.v, self.eta, name="dec_v")

    self.sess = session.Session()
    self.sess.run(self.v.initializer)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:11,代码来源:disk_usage_test.py


示例15: test_get_updates_for

  def test_get_updates_for(self):
    """Checks that get_updates_for() filters updates by their inputs.

    One kernel update is registered against the Input tensor and one with
    inputs=None. Each query must return only the matching update.
    """
    inp = keras.layers.Input(shape=(1,))
    layer = keras.layers.Dense(1)
    layer.build((None, 1))
    update_on_input = state_ops.assign_add(layer.kernel, inp)
    update_always = state_ops.assign_add(layer.kernel, [[1.]])
    layer.add_update(update_on_input, inputs=inp)
    layer.add_update(update_always, inputs=None)

    self.assertListEqual(layer.get_updates_for(inp), [update_on_input])
    self.assertListEqual(layer.get_updates_for(None), [update_always])
开发者ID:japrogramer,项目名称:tensorflow,代码行数:11,代码来源:topology_test.py


示例16: setUp

  def setUp(self):
    """Creates a fake summary writer, a scalar summary and a train op.

    Evaluating `self.summary_op` runs an assign_add of 1.0 on the resource
    variable `var` and summarizes the tensor that assign_add returns. The
    global step is obtained with get_or_create_global_step inside the
    resource variable scope 'foo', and `self.train_op` increments it by 1.
    """
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    var = variable_scope.get_variable('var', initializer=0.0, use_resource=True)
    tensor = state_ops.assign_add(var, 1.0)
    self.summary_op = summary_lib.scalar('my_summary', tensor)

    with variable_scope.variable_scope('foo', use_resource=True):
      global_step = variables.get_or_create_global_step()
    self.train_op = state_ops.assign_add(global_step, 1)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:13,代码来源:basic_session_run_hooks_test.py


示例17: setUp

  def setUp(self):
    """Builds a tiny graph with one variable and three assign_add ops.

    `v` starts at 10.0. `inc_v` adds `delta` (1.0), `dec_v` adds `eta`
    (-1.4), and `inc_w_ph` adds whatever scalar float32 is fed to the
    placeholder `ph`. `v` is initialized in a fresh session, and a temporary
    directory is created as `session_root`.
    """
    # NOTE(review): no tearDown is visible here; presumably it closes
    # self.sess and removes self.session_root. Confirm.
    self.session_root = tempfile.mkdtemp()

    self.v = variables.VariableV1(10.0, dtype=dtypes.float32, name="v")
    self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
    self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
    self.dec_v = state_ops.assign_add(self.v, self.eta, name="dec_v")

    self.ph = array_ops.placeholder(dtypes.float32, shape=(), name="ph")
    self.inc_w_ph = state_ops.assign_add(self.v, self.ph, name="inc_w_ph")

    self.sess = session.Session()
    self.sess.run(self.v.initializer)
开发者ID:aritratony,项目名称:tensorflow,代码行数:14,代码来源:dumping_wrapper_test.py


示例18: testTensorBoardDebuggerWrapperToggleBreakpointsWorks

  def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
    """Checks breakpoint toggling through TensorBoardDebugWrapperSession.

    Runs 0 and 2 request breakpoint watches on delta_1:0 and delta_2:0 and
    expect the server to receive both debug tensors and to list both
    breakpoints. Runs 1 and 3 request unwatches and expect an empty
    breakpoint set.
    """
    with session.Session(config=no_rewrite_session_config()) as sess:
      v_1 = variables.Variable(50.0, name="v_1")
      v_2 = variables.Variable(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      # The TensorBoardDebugWrapperSession should add a DebugIdentity debug op
      # with attribute gated_grpc=True for every tensor in the graph.
      sess = grpc_wrapper.TensorBoardDebugWrapperSession(
          sess, self._debug_server_url_1)

      for i in xrange(4):
        self._server_1.clear_data()

        if i in (0, 2):
          # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)
          self._server_1.request_watch(
              "delta_2", 0, "DebugIdentity", breakpoint=True)
        else:
          # Disable the breakpoint in runs 1 and 3.
          self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

        output = sess.run([inc_v_1, inc_v_2])
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

        if i in (0, 2):
          # During runs 0 and 2, the server should have received the published
          # debug tensors delta_[1,2]:0:DebugIdentity. The breakpoints should
          # have been unblocked by EventReply responses from the server.
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
          # After the runs, the server should have properly registered the
          # breakpoints. This mirrors the check in the non-wrapper toggle test.
          self.assertSetEqual({("delta_1", 0, "DebugIdentity"),
                               ("delta_2", 0, "DebugIdentity")},
                              self._server_1.breakpoints)
        else:
          # After the end of runs 1 and 3, the server has received the requests
          # to disable the breakpoints at delta_[1,2]:0:DebugIdentity.
          self.assertSetEqual(set(), self._server_1.breakpoints)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:49,代码来源:session_debug_grpc_test.py


示例19: _createGraph

  def _createGraph(self):
    """Create graph for testing.

    Pins everything to /job:worker/task:0/cpu:0 and stores the tensors on
    self:
      a = 10.0, b = 100.0 (initial values),
      inc_a: a += 2.0,  dec_b: b += -5.0,
      p = inc_a * dec_b,  q = -p.

    Returns:
      Python Graph object.
    """
    with ops.Graph().as_default() as graph:
      with ops.device("/job:worker/task:0/cpu:0"):
        self.a = variables.VariableV1(10.0, name="a")
        self.b = variables.VariableV1(100.0, name="b")
        self.inc_a = state_ops.assign_add(self.a, 2.0, name="inc_a")
        self.dec_b = state_ops.assign_add(self.b, -5.0, name="dec_b")
        # p depends on both updates, so fetching p or q runs inc_a and dec_b.
        self.p = math_ops.multiply(self.inc_a, self.dec_b, name="p")
        self.q = math_ops.negative(self.p, name="q")
    return graph
开发者ID:perfmjs,项目名称:tensorflow,代码行数:15,代码来源:dist_session_debug_grpc_test.py


示例20: _minimize

  def _minimize(loss, global_step):
    """Fake minimize fn: checks trainable vars and loss, bumps global_step.

    Reads `testcase`, `expected_var_names` and `expected_loss` from the
    enclosing scope.

    Args:
      loss: Loss tensor from the model under test; asserted to be a scalar.
      global_step: Variable incremented by 1 by the returned op.

    Returns:
      The `Operation` of the global-step increment. When `expected_loss` is
      not None, that op has a control dependency on an assertion op built by
      `_assert_close` comparing `expected_loss` (as float) with `loss`.
    """
    trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    testcase.assertItemsEqual(
        expected_var_names,
        [var.name for var in trainable_vars])

    # Verify loss. We can't check the value directly, so we add an assert op.
    # Use assertEqual: the assertEquals alias was removed in Python 3.12.
    testcase.assertEqual(0, loss.shape.ndims)
    if expected_loss is None:
      return state_ops.assign_add(global_step, 1).op
    assert_loss = _assert_close(
        math_ops.to_float(expected_loss, name='expected'), loss,
        name='assert_loss')
    with ops.control_dependencies((assert_loss,)):
      return state_ops.assign_add(global_step, 1).op
开发者ID:cameronphchen,项目名称:tensorflow,代码行数:15,代码来源:dnn_test.py



注:本文中的tensorflow.python.ops.state_ops.assign_add函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python state_ops.assign_sub函数代码示例发布时间:2022-05-27
下一篇:
Python state_ops.assign函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap