• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python training.create_train_op函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.contrib.training.python.training.training.create_train_op函数的典型用法代码示例。如果您想了解Python create_train_op函数的具体用法、调用方式或使用示例，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了create_train_op函数的17个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables

  def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    """Checks that `variables_to_train` limits which variables are updated.

    Three train ops are built over the same loss: one for all trainable
    variables, one restricted to `weights` and one restricted to `biases`.
    Each op is run once and the variable values are compared before and
    after that step.
    """
    # Build the model and all three train ops in a fresh graph. (The first
    # step below updates every variable, not only the weights.)
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      # NOTE(review): assumes ModelLoss() creates exactly two variables, in the
      # order (weights, biases) -- TODO confirm.
      weights, biases = variables_lib.get_variables()

      # One train op over all trainable variables, plus one per subset.
      train_op = training.create_train_op(total_loss, optimizer)
      train_weights = training.create_train_op(
          total_loss, optimizer, variables_to_train=[weights])
      train_biases = training.create_train_op(
          total_loss, optimizer, variables_to_train=[biases])

      with session_lib.Session() as sess:
        # Initialize the variables.
        sess.run(variables_lib2.global_variables_initializer())

        # Get the initial weights and biases values: the weights start
        # non-zero and the biases start at zero.
        weights_values, biases_values = sess.run([weights, biases])
        self.assertGreater(np.linalg.norm(weights_values), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values), 0)

        # Update weights and biases. Running a train op yields the loss value.
        loss = sess.run(train_op)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights and biases have been updated.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

        # The post-step values become the baseline for the next comparison.
        weights_values, biases_values = new_weights, new_biases

        # Update only weights.
        loss = sess.run(train_weights)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights have been updated, but biases have not.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
        # Biases did not change, so only the weights baseline needs refreshing.
        weights_values = new_weights

        # Update only biases.
        loss = sess.run(train_biases)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the biases have been updated, but weights have not.
        self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:52,代码来源:training_test.py


示例2: _train_model

  def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (i.e. we can expect 100% accuracy).

    Args:
      checkpoint_dir: The directory where the checkpoint is written to.
      num_steps: The number of steps to train for.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss = loss_ops.log_loss(tf_predictions, tf_labels)

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      # The final loss value returned by `training.train` is not used here, so
      # it is discarded rather than rebinding `loss` (the loss tensor above).
      training.train(
          train_op,
          checkpoint_dir,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)])
开发者ID:Immexxx,项目名称:tensorflow,代码行数:25,代码来源:evaluation_test.py


示例3: testEmptyUpdateOps

  def testEmptyUpdateOps(self):
    """With `update_ops=[]`, batch-norm moving statistics must stay frozen."""
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs_t = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels_t = constant_op.constant(self._labels, dtype=dtypes.float32)

      predictions_t = batchnorm_classifier(inputs_t)
      loss_ops.log_loss(predictions_t, labels_t)
      total_loss = loss_ops.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      # Explicitly pass an empty list so no update ops run with the step.
      train_op = training.create_train_op(total_loss, optimizer, update_ops=[])

      moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
      moving_variance = variables_lib.get_variables_by_name(
          'moving_variance')[0]
      moving_stats = [moving_mean, moving_variance]
      zeros = [0] * 4
      ones = [1] * 4

      with session_lib.Session() as sess:
        sess.run(variables_lib2.global_variables_initializer())
        # Freshly initialized: moving_mean is all zeros, moving_variance ones.
        mean, variance = sess.run(moving_stats)
        self.assertAllClose(mean, zeros)
        self.assertAllClose(variance, ones)

        for _ in range(10):
          sess.run([train_op])

        # No update ops were run, so the statistics keep their initial values.
        mean, variance = sess.run(moving_stats)
        self.assertAllClose(mean, zeros)
        self.assertAllClose(variance, ones)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:33,代码来源:training_test.py


示例4: gen_train_op

 def gen_train_op():
   """Builds the generator train op inside the 'generator_train' name scope."""
   with ops.name_scope('generator_train'):
     generator_kwargs = dict(
         total_loss=gan_loss.generator_loss,
         optimizer=generator_optimizer,
         variables_to_train=gan_model.generator_variables,
         update_ops=gen_update_ops)
     return training.create_train_op(**generator_kwargs)
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:7,代码来源:tpu_gan_estimator_impl.py


示例5: testNoneGlobalStep

  def testNoneGlobalStep(self):
    """With `global_step=None`, running the train op leaves global_step at 0."""
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs_t = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels_t = constant_op.constant(self._labels, dtype=dtypes.float32)

      predictions_t = batchnorm_classifier(inputs_t)
      loss_ops.log_loss(predictions_t, labels_t)
      total_loss = loss_ops.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(
          total_loss, optimizer, global_step=None)

      # Created only so its value can be inspected after training.
      global_step_var = variables_lib.get_or_create_global_step()

      with session_lib.Session() as sess:
        sess.run(variables_lib2.global_variables_initializer())

        for _ in range(10):
          sess.run([train_op])

        # The train op doesn't use global_step, so it must not change.
        self.assertAllClose(global_step_var.eval(), 0)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:25,代码来源:training_test.py


示例6: testResumeTrainAchievesRoughlyTheSameLoss

  def testResumeTrainAchievesRoughlyTheSameLoss(self):
    """Trains repeatedly against one logdir; every run must end with low loss.

    Each run builds a fresh graph with its own seed and stops after the
    corresponding entry of `number_of_steps`; checkpoints go to `logdir`.
    """
    number_of_steps = [300, 1, 5]
    logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

    for run_index, steps_this_run in enumerate(number_of_steps):
      with ops.Graph().as_default():
        random_seed.set_random_seed(run_index)
        inputs_t = constant_op.constant(self._inputs, dtype=dtypes.float32)
        labels_t = constant_op.constant(self._labels, dtype=dtypes.float32)

        predictions_t = logistic_classifier(inputs_t)
        loss_ops.log_loss(predictions_t, labels_t)
        total_loss = loss_ops.get_total_loss()

        optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

        train_op = training.create_train_op(total_loss, optimizer)

        saver = saver_lib.Saver()

        run_hooks = [
            basic_session_run_hooks.StopAtStepHook(num_steps=steps_this_run),
            basic_session_run_hooks.CheckpointSaverHook(
                logdir, save_steps=50, saver=saver),
        ]
        final_loss = training.train(train_op, logdir, hooks=run_hooks)
        self.assertIsNotNone(final_loss)
        self.assertLess(final_loss, .015)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:31,代码来源:training_test.py


示例7: dis_train_op

 def dis_train_op():
   """Builds the discriminator train op inside 'discriminator_train' scope."""
   with ops.name_scope('discriminator_train'):
     discriminator_kwargs = dict(
         total_loss=gan_loss.discriminator_loss,
         optimizer=discriminator_optimizer,
         variables_to_train=gan_model.discriminator_variables,
         update_ops=dis_update_ops)
     return training.create_train_op(**discriminator_kwargs)
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:7,代码来源:tpu_gan_estimator_impl.py


示例8: testTrainOpInCollection

  def testTrainOpInCollection(self):
    """create_train_op must register its op in the TRAIN_OP collection."""
    with ops.Graph().as_default():
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      # Make sure the training op was recorded in the proper collection.
      # assertIn reports the collection contents on failure, whereas
      # assertTrue(x in y) only reports "False is not true".
      self.assertIn(train_op, ops.get_collection(ops.GraphKeys.TRAIN_OP))
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:12,代码来源:training_test.py


示例9: testTrainWithNoInitAssignCanAchieveZeroLoss

  def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
    """A batch-norm model trained for 300 steps reaches a loss below 0.1."""
    g = ops.Graph()
    with g.as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss_ops.log_loss(tf_predictions, tf_labels)
      total_loss = loss_ops.get_total_loss()

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(total_loss, optimizer)

      loss = training.train(
          train_op,
          self._logdir,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)])
      # Like the other zero-loss tests, fail with a clear message if training
      # returned no loss instead of a TypeError from comparing None to .1.
      self.assertIsNotNone(loss)
      self.assertLess(loss, .1)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:20,代码来源:training_test.py


示例10: testCanAchieveZeroLoss

  def testCanAchieveZeroLoss(self):
    """A logistic model trained for 300 steps reaches a loss below .015.

    No logdir is given, and both summary and checkpoint saving are disabled.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs_t = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels_t = constant_op.constant(self._labels, dtype=dtypes.float32)

      predictions_t = logistic_classifier(inputs_t)
      losses.log_loss(labels_t, predictions_t)
      total_loss = losses.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(total_loss, optimizer)

      stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=300)
      final_loss = training.train(
          train_op,
          None,
          hooks=[stop_hook],
          save_summaries_steps=None,
          save_checkpoint_secs=None)
      self.assertIsNotNone(final_loss)
      self.assertLess(final_loss, .015)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:20,代码来源:training_test.py


示例11: testCanAchieveZeroLoss

  def testCanAchieveZeroLoss(self):
    """A logistic model trained 300 steps into a temp logdir gets loss < .015."""
    logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')

    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs_t = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels_t = constant_op.constant(self._labels, dtype=dtypes.float32)

      predictions_t = logistic_classifier(inputs_t)
      loss_ops.log_loss(predictions_t, labels_t)
      total_loss = loss_ops.get_total_loss()

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(total_loss, optimizer)

      stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=300)
      final_loss = training.train(train_op, logdir, hooks=[stop_hook])
      self.assertIsNotNone(final_loss)
      self.assertLess(final_loss, .015)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:22,代码来源:training_test.py


示例12: testGlobalStepIsIncrementedByDefault

  def testGlobalStepIsIncrementedByDefault(self):
    """Without an explicit global_step, each train step increments it by one."""
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs_t = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels_t = constant_op.constant(self._labels, dtype=dtypes.float32)

      predictions_t = batchnorm_classifier(inputs_t)
      loss = losses.log_loss(labels_t, predictions_t)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      global_step = variables_lib.get_or_create_global_step()

      with self.test_session() as session:
        session.run(variables_lib2.global_variables_initializer())

        num_updates = 10
        for _ in range(num_updates):
          session.run(train_op)

        # One increment per executed train step.
        self.assertAllClose(session.run(global_step), num_updates)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:22,代码来源:training_test.py


示例13: create_train_op

  def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
    """Builds a logistic-classifier train op, optionally scaling gradients.

    Args:
      learning_rate: Learning rate of the gradient descent optimizer.
      gradient_multiplier: Factor applied to the gradient of every trainable
        variable; exactly 1.0 leaves the gradients untouched.

    Returns:
      The op returned by `training.create_train_op`.
    """
    inputs_t = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels_t = constant_op.constant(self._labels, dtype=dtypes.float32)

    predictions_t = logistic_classifier(inputs_t)
    loss_ops.log_loss(predictions_t, labels_t)
    total_loss = loss_ops.get_total_loss()

    optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=learning_rate)

    def transform_grads_fn(grads):
      # A multiplier of exactly 1.0 is the identity: skip the scaling ops.
      if gradient_multiplier == 1.0:
        return grads
      per_variable = dict.fromkeys(variables_lib2.trainable_variables(),
                                   gradient_multiplier)
      with ops.name_scope('multiply_grads'):
        return training.multiply_gradients(grads, per_variable)

    return training.create_train_op(
        total_loss, optimizer, transform_grads_fn=transform_grads_fn)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:23,代码来源:training_test.py


示例14: testUseUpdateOps

  def testUseUpdateOps(self):
    """With default update ops, batch-norm moving stats reach the data stats."""
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs_t = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels_t = constant_op.constant(self._labels, dtype=dtypes.float32)

      # Per-feature statistics of the raw inputs (reduced over axis 0).
      expected_mean = np.mean(self._inputs, axis=0)
      expected_var = np.var(self._inputs, axis=0)

      predictions_t = batchnorm_classifier(inputs_t)
      loss = losses.log_loss(labels_t, predictions_t)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(loss, optimizer)

      moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
      moving_variance = variables_lib.get_variables_by_name(
          'moving_variance')[0]
      moving_stats = [moving_mean, moving_variance]

      with self.test_session() as session:
        session.run(variables_lib2.global_variables_initializer())
        # Freshly initialized: moving_mean is all zeros, moving_variance ones.
        mean, variance = session.run(moving_stats)
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          session.run(train_op)

        # NOTE(review): convergence within 10 steps relies on the moving-average
        # decay configured in batchnorm_classifier (said to be 0.1) -- confirm.
        mean, variance = session.run(moving_stats)
        self.assertAllClose(mean, expected_mean)
        self.assertAllClose(variance, expected_var)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:36,代码来源:training_test.py


示例15: create_train_op

def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    clip_gradient_norm=0,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    gradient_multipliers=None,
                    check_numerics=True):
    """Builds a train op that computes and applies gradients, yielding the loss.

    Thin wrapper around `training.create_train_op`. It turns the
    `gradient_multipliers` and `clip_gradient_norm` options into a
    `transform_grads_fn` and forwards every other option unchanged.

    Args:
      total_loss: A `Tensor` holding the total loss.
      optimizer: The tf.Optimizer that computes the gradients.
      global_step: A `Tensor` for the global step variable. When left as
        `_USE_GLOBAL_STEP`, slim.variables.global_step() is used.
      update_ops: Optional list of updates to run. When `None`, the contents of
        the `tf.GraphKeys.UPDATE_OPS` collection are used. When it is not
        `None` but misses some ops from `tf.GraphKeys.UPDATE_OPS`, a warning is
        displayed.
      variables_to_train: Optional list of variables to train. When `None`, all
        tf.trainable_variables() are trained.
      clip_gradient_norm: When greater than 0, gradients are clipped by it.
      summarize_gradients: Whether to add a summary for each gradient.
      gate_gradients: How gradient computation is gated; see tf.Optimizer.
      aggregation_method: How gradient terms are combined; valid values are
        defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: Whether to try colocating each gradient with
        the op that produced it.
      gradient_multipliers: Dictionary mapping `Variables` (or `Variable` op
        names) to the coefficient that scales the associated gradient.
      check_numerics: Whether check_numerics is applied.

    Returns:
      A `Tensor` that, when evaluated, computes the gradients and returns the
      total loss value.
    """

    def _scale(grads):
        # Only build scaling ops when a non-empty multiplier dict is given.
        if gradient_multipliers:
            with ops.name_scope('multiply_grads'):
                grads = multiply_gradients(grads, gradient_multipliers)
        return grads

    def _clip(grads):
        # A non-positive norm disables clipping.
        if clip_gradient_norm > 0:
            with ops.name_scope('clip_grads'):
                grads = clip_gradient_norms(grads, clip_gradient_norm)
        return grads

    def transform_grads_fn(grads):
        # Scale first, then clip, so clipping sees the scaled gradients.
        return _clip(_scale(grads))

    return training.create_train_op(
        total_loss=total_loss,
        optimizer=optimizer,
        global_step=global_step,
        update_ops=update_ops,
        variables_to_train=variables_to_train,
        transform_grads_fn=transform_grads_fn,
        summarize_gradients=summarize_gradients,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        check_numerics=check_numerics)
开发者ID:astorfi,项目名称:tensorflow,代码行数:67,代码来源:learning.py


示例16: gan_train_ops

def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can
    be used to train a generator/discriminator pair.
  """
  if isinstance(model, namedtuples.CycleGANModel):
    saved_params = locals()
    saved_params.pop('model', None)
    saved_params.pop('loss', None)
    kwargs = saved_params.pop('kwargs', {})
    saved_params.update(kwargs)
    with ops.name_scope('cyclegan_x2y_train'):
      train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                    **saved_params)
    with ops.name_scope('cyclegan_y2x_train'):
      train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                    **saved_params)
    return namedtuples.GANTrainOps(
        (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
        (train_ops_x2y.discriminator_train_op,
         train_ops_y2x.discriminator_train_op),
        training_util.get_or_create_global_step().assign_add(1))

  # Create global step increment op.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  generator_global_step = None
  if isinstance(generator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # TODO(joelshor): Figure out a way to get this work without including the
    # dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    generator_global_step = variable_scope.get_variable(
        'dummy_global_step_generator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    gen_update_ops += [generator_global_step.assign(global_step)]
  with ops.name_scope('generator_train'):
    gen_train_op = training.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = None
  if isinstance(discriminator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # See comment above `generator_global_step`.
    discriminator_global_step = variable_scope.get_variable(
        'dummy_global_step_discriminator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    dis_update_ops += [discriminator_global_step.assign(global_step)]
  with ops.name_scope('discriminator_train'):
    disc_train_op = training.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
#.........这里部分代码省略.........
开发者ID:andrewharp,项目名称:tensorflow,代码行数:101,代码来源:train.py


示例17: testTrainAllVarsHasLowerLossThanTrainSubsetOfVars

  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
    """Training weights and biases together beats training either one alone.

    Three successive stages share one checkpoint directory: weights only
    (200 steps), biases only (300 steps), then all variables (400 steps).
    """
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
    if gfile.Exists(logdir):  # For running on jenkins.
      gfile.DeleteRecursively(logdir)

    def run_stage(seed, var_name, num_steps):
      # Each stage gets a fresh graph; `var_name=None` trains all variables.
      with ops.Graph().as_default():
        random_seed.set_random_seed(seed)
        total_loss = self.ModelLoss()
        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=1.0)
        if var_name is None:
          train_op = training.create_train_op(total_loss, optimizer)
        else:
          subset = variables_lib.get_variables_by_name(var_name)
          train_op = training.create_train_op(
              total_loss, optimizer, variables_to_train=subset)

        saver = saver_lib.Saver()
        return training.train(
            train_op,
            logdir,
            hooks=[
                basic_session_run_hooks.CheckpointSaverHook(
                    logdir, save_steps=1, saver=saver),
                basic_session_run_hooks.StopAtStepHook(num_steps=num_steps),
            ])

    # Weights only, then biases only: each subset alone ends in (.015, .05).
    for seed, var_name, num_steps in ((0, 'weights', 200), (1, 'biases', 300)):
      subset_loss = run_stage(seed, var_name, num_steps)
      self.assertGreater(subset_loss, .015)
      self.assertLess(subset_loss, .05)

    # Finally, training every variable must reach a lower loss.
    full_loss = run_stage(2, None, 400)
    self.assertIsNotNone(full_loss)
    self.assertLess(full_loss, .015)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:67,代码来源:training_test.py



注：本文中的tensorflow.contrib.training.python.training.training.create_train_op函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python training.train函数代码示例发布时间:2022-05-27
下一篇:
Python hparam.parse_values函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap