• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python checkpoint_management.latest_checkpoint函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.training.checkpoint_management.latest_checkpoint函数的典型用法代码示例。如果您正苦于以下问题:Python latest_checkpoint函数的具体用法?Python latest_checkpoint怎么用?Python latest_checkpoint使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了latest_checkpoint函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testRecoverSession

  def testRecoverSession(self):
    """Recovers a saved variable via checkpoint dir and via checkpoint path.

    Saves variable `v` into a freshly created checkpoint directory, then checks
    recovery through `_test_recovered_variable` using either `checkpoint_dir`
    or `checkpoint_filename_with_path`, and finally checks that supplying both
    raises `ValueError`.
    """
    # Create a checkpoint.
    checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
    try:
      gfile.DeleteRecursively(checkpoint_dir)
    except errors.OpError:
      pass  # Ignore: deletion is best-effort before the dir is recreated.
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables())
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      # The directory was just created empty, so recovery reports that the
      # session was not initialized from a checkpoint.
      self.assertFalse(initialized)
      sess.run(v.initializer)
      # `assertEqual` replaces the deprecated `assertEquals` alias.
      self.assertEqual(1, sess.run(v))
      saver.save(sess,
                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
    self._test_recovered_variable(
        checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
            checkpoint_dir))
    # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
    with self.assertRaises(ValueError):
      self._test_recovered_variable(
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
              checkpoint_dir))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:31,代码来源:session_manager_test.py


示例2: _new_layer_weight_loading_test_template

  def _new_layer_weight_loading_test_template(
      self, first_model_fn, second_model_fn, restore_init_fn):
    """Round-trips weights between two models through one checkpoint prefix.

    Args:
      first_model_fn: Zero-argument callable returning the model that is saved
        first and that produces the reference output.
      second_model_fn: Zero-argument callable returning the model that loads
        the first model's weights and then saves to the same prefix.
      restore_init_fn: Callable taking the second model; its result is passed
        to `self.evaluate` after the second model has been called on the input.
    """
    with self.cached_session() as session:
      model = first_model_fn()
      temp_dir = self.get_temp_dir()
      prefix = os.path.join(temp_dir, 'ckpt')

      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
      executing_eagerly = context.executing_eagerly()
      ref_y_tensor = model(x)
      # Variables only need an explicit initializer run when not eager.
      if not executing_eagerly:
        session.run([v.initializer for v in model.variables])
      ref_y = self.evaluate(ref_y_tensor)
      model.save_weights(prefix)
      # The prefix just written must be what latest_checkpoint reports.
      self.assertEqual(
          prefix,
          checkpoint_management.latest_checkpoint(temp_dir))
      # Overwrite every variable with random values so that a later match with
      # `ref_y` can only come from loading the checkpoint.
      for v in model.variables:
        self.evaluate(
            v.assign(random_ops.random_normal(shape=array_ops.shape(v))))

      self.addCleanup(shutil.rmtree, temp_dir)

      second_model = second_model_fn()
      second_model.load_weights(prefix)
      second_model(x)
      self.evaluate(restore_init_fn(second_model))
      second_model.save_weights(prefix)
      # Check that the second model's checkpoint loads into the original model
      model.load_weights(prefix)
      y = self.evaluate(model(x))
      self.assertAllClose(ref_y, y)
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:32,代码来源:hdf5_format_test.py


示例3: testUsageGraph

 def testUsageGraph(self):
   """Expected usage when graph building.

   Runs three training continuations against the same checkpoint directory.
   Each continuation builds a new graph, restores from the latest checkpoint
   (if any), trains for `num_training_steps` steps, and saves again.
   """
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
       with ops.Graph().as_default():
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
         root = util.Checkpoint(
             optimizer=optimizer, model=model,
             global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             model(input_value),
             global_step=root.global_step)
         checkpoint_path = checkpoint_management.latest_checkpoint(
             checkpoint_directory)
         with self.session(graph=ops.get_default_graph()) as session:
           status = root.restore(save_path=checkpoint_path)
           status.initialize_or_restore(session=session)
           if checkpoint_path is None:
             # No checkpoint is only expected on the very first continuation;
             # in that case assert_consumed() is expected to fail.
             self.assertEqual(0, training_continuation)
             with self.assertRaises(AssertionError):
               status.assert_consumed()
           else:
             status.assert_consumed()
           for _ in range(num_training_steps):
             session.run(train_op)
           root.save(file_prefix=checkpoint_prefix, session=session)
           # global_step accumulates across continuations via the checkpoint.
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
           # One save per continuation, so save_counter tracks the loop index.
           self.assertEqual(training_continuation + 1,
                            session.run(root.save_counter))
开发者ID:jackd,项目名称:tensorflow,代码行数:35,代码来源:checkpointable_utils_test.py


示例4: _restore_or_save_initial_ckpt

  def _restore_or_save_initial_ckpt(self, session):
    """Restores the latest checkpoint if one exists, otherwise saves one.

    Args:
      session: The session used to restore into, or to read the global step
        from and save.
    """
    # This logic would ideally live in after_create_session, but the order in
    # which `SessionRunHooks` run cannot currently be enforced, so the
    # `_DatasetInitializerHook` might run *after* this hook. That would be a
    # problem because:
    # 1. If a checkpoint exists and this hook restores it, the initializer hook
    #    will override it.
    # 2. If no checkpoint exists, this hook will try to save an uninitialized
    #    iterator which will result in an exception.
    #
    # Temporary fix: an implicit contract between this hook and the
    # _DatasetInitializerHook.
    # 1. The _DatasetInitializerHook initializes the iterator in the call to
    #    after_create_session.
    # 2. This hook saves the iterator on the first call to `before_run()`, which
    #    is guaranteed to happen after `after_create_session()` of all hooks
    #    have been run.

    # pylint: disable=protected-access
    saver_hook = self._checkpoint_saver_hook
    existing_ckpt = checkpoint_management.latest_checkpoint(
        saver_hook._checkpoint_dir, latest_filename=self._latest_filename)

    # An existing checkpoint wins: restore from it and stop.
    if existing_ckpt:
      saver_hook._get_saver().restore(session, existing_ckpt)
      return

    # Otherwise save the state at step "global_step".
    # Note: We do not save the GraphDef or MetaGraphDef here.
    step = session.run(saver_hook._global_step_tensor)
    saver_hook._save(session, step)
    saver_hook._timer.update_last_triggered_step(step)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:33,代码来源:iterator_ops.py


示例5: export_fn

  def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
    """Exports the given Estimator as a SavedModel.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      checkpoint_path: The checkpoint path to export. When falsy, the most
        recent checkpoint found within the estimator's model directory is used.
      eval_result: placeholder arg matching the call signature of
        ExportStrategy.

    Returns:
      The string path to the exported directory, or '' when nothing is
      exported.
    """
    # TODO(b/67425018): switch to
    #    checkpoint_path = estimator.latest_checkpoint()
    #  as soon as contrib is cleaned up and we can thus be sure that
    #  estimator is a tf.estimator.Estimator and not a
    #  tf.contrib.learn.Estimator
    source_ckpt = checkpoint_path or checkpoint_management.latest_checkpoint(
        estimator.model_dir)
    selected_ckpt, selected_result = best_model_selector.update(
        source_ckpt, eval_result)

    # Nothing to export unless the selector yields both a path and a result.
    if not selected_ckpt or selected_result is None:
      return ''

    # The export directory is named after the checkpoint's base name.
    export_dir = os.path.join(export_dir_base, os.path.basename(selected_ckpt))
    return best_model_export_strategy.export(
        estimator, export_dir, selected_ckpt, selected_result)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:32,代码来源:saved_model_export_utils.py


示例6: testAgnosticUsage

 def testAgnosticUsage(self):
   """Graph/eager agnostic usage.

   Runs three training continuations against one checkpoint directory, each
   restoring from the latest checkpoint, training for `num_training_steps`
   steps and saving again.
   """
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     with ops.Graph().as_default(), self.test_session(
         graph=ops.get_default_graph()), test_util.device(use_gpu=True):
       model = MyModel()
       optimizer = adam.AdamOptimizer(0.001)
       root = util.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = checkpoint_management.latest_checkpoint(
           checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       input_value = constant_op.constant([[3.]])
       # Zero-argument callable that performs one minimize() call.
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(model, input_value),
           global_step=root.global_step)
       if not context.executing_eagerly():
         # When not eager, call minimize() once here and re-wrap so each
         # train_fn() call evaluates the resulting value.
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       root.save(file_prefix=checkpoint_prefix)
       # global_step accumulates across continuations via the checkpoint.
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       # One save per continuation, so save_counter tracks the loop index.
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
开发者ID:jackd,项目名称:tensorflow,代码行数:32,代码来源:checkpointable_utils_test.py


示例7: wait_for_new_checkpoint

def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Polls `checkpoint_dir` until a checkpoint other than the last one appears.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep between two lookups.
    timeout: The maximum number of seconds to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    a new checkpoint path, or None if the timeout was reached.
  """
  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)

  # No deadline at all when the caller did not give a timeout.
  deadline = None
  if timeout is not None:
    deadline = time.time() + timeout

  while True:
    candidate = checkpoint_management.latest_checkpoint(checkpoint_dir)
    nothing_new = candidate is None or candidate == last_checkpoint
    if not nothing_new:
      logging.info('Found new checkpoint at %s', candidate)
      return candidate

    # Give up early if the next sleep would run past the deadline.
    if deadline is not None and time.time() + seconds_to_sleep > deadline:
      return None
    time.sleep(seconds_to_sleep)
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:29,代码来源:evaluation.py


示例8: testGraphDistributionStrategy

  def testGraphDistributionStrategy(self):
    """Checkpoint save/restore across continuations under MirroredStrategy.

    Currently skipped unconditionally (see the bug id passed to skipTest).
    """
    self.skipTest("b/121381184")
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    def _train_fn(optimizer, model):
      # `root` is looked up in the enclosing scope at call time; it is assigned
      # inside the loop below before this function is invoked.
      input_value = constant_op.constant([[3.]])
      return optimizer.minimize(
          functools.partial(model, input_value),
          global_step=root.optimizer_step)

    for training_continuation in range(3):
      with ops.Graph().as_default():
        strategy = mirrored_strategy.MirroredStrategy()
        with strategy.scope():
          model = MyModel()
          optimizer = adam.AdamOptimizer(0.001)
          root = checkpointable_utils.Checkpoint(
              optimizer=optimizer, model=model,
              optimizer_step=training_util.get_or_create_global_step())
          status = root.restore(checkpoint_management.latest_checkpoint(
              checkpoint_directory))
          train_op = strategy.extended.call_for_each_replica(
              functools.partial(_train_fn, optimizer, model))
          with self.session() as session:
            # A checkpoint from the previous continuation exists after the
            # first iteration, so the restore must be fully consumed.
            if training_continuation > 0:
              status.assert_consumed()
            status.initialize_or_restore()
            for _ in range(num_training_steps):
              session.run(train_op)
            root.save(file_prefix=checkpoint_prefix)
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         root.optimizer_step.numpy())
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:34,代码来源:util_with_v1_optimizers_test.py


示例9: testEagerTPUDistributionStrategy

  def testEagerTPUDistributionStrategy(self):
    """Checkpoint save/restore across continuations under TPUStrategy.

    Currently skipped unconditionally (see the bug id passed to skipTest).
    """
    self.skipTest("b/121387144")
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    def _train_fn(optimizer, model):
      # `root` is looked up in the enclosing scope at call time; it is assigned
      # inside the loop below before this function is invoked.
      input_value = constant_op.constant([[3.]])
      optimizer.minimize(
          functools.partial(model, input_value),
          global_step=root.optimizer_step)

    for training_continuation in range(3):
      strategy = tpu_strategy.TPUStrategy()
      with strategy.scope():
        model = Subclassed()
        optimizer = adam_v1.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, model=model,
            optimizer_step=training_util.get_or_create_global_step())
        root.restore(checkpoint_management.latest_checkpoint(
            checkpoint_directory))

        for _ in range(num_training_steps):
          strategy.extended.call_for_each_replica(
              functools.partial(_train_fn, optimizer, model))
        root.save(file_prefix=checkpoint_prefix)
        # The step count is expected to accumulate across continuations.
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         root.optimizer_step.numpy())
开发者ID:AndreasGocht,项目名称:tensorflow,代码行数:29,代码来源:checkpointing_test.py


示例10: _read_vars

 def _read_vars(self, model_dir):
   """Returns (global_step, latest_feature).

   Imports the meta graph of the latest checkpoint found in `model_dir` into a
   fresh graph, restores that checkpoint, and returns the evaluated contents
   of the 'my_vars' collection.

   Args:
     model_dir: Directory searched for the latest checkpoint.
   """
   with ops.Graph().as_default() as g:
     ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
     # NOTE(review): this concatenation assumes a checkpoint exists; confirm
     # that callers only invoke this after one has been written.
     meta_filename = ckpt_path + '.meta'
     saver_lib.import_meta_graph(meta_filename)
     saver = saver_lib.Saver()
     with self.test_session(graph=g) as sess:
       saver.restore(sess, ckpt_path)
       return sess.run(ops.get_collection('my_vars'))
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:10,代码来源:iterator_ops_test.py


示例11: testRestoreInReconstructedIterator

 def testRestoreInReconstructedIterator(self):
   """A freshly built iterator resumes from the last saved position."""
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
   dataset = Dataset.range(10)
   for i in range(5):
     # A new iterator and Checkpoint object are built on every pass; the
     # position is carried over only through the saved checkpoint.
     iterator = datasets.Iterator(dataset)
     checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
     checkpoint.restore(checkpoint_management.latest_checkpoint(
         checkpoint_directory))
     # Two elements are consumed per pass, so pass i yields 2*i and 2*i + 1.
     for j in range(2):
       self.assertEqual(i * 2 + j, iterator.get_next().numpy())
     checkpoint.save(file_prefix=checkpoint_prefix)
开发者ID:aeverall,项目名称:tensorflow,代码行数:12,代码来源:datasets_test.py


示例12: testRestoreInReconstructedIteratorInitializable

 def testRestoreInReconstructedIteratorInitializable(self):
   """An initializable iterator resumes from the last saved position."""
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   dataset = dataset_ops.Dataset.range(10)
   iterator = dataset.make_initializable_iterator()
   get_next = iterator.get_next()
   checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
   for i in range(5):
     with self.cached_session() as sess:
       # Restores from the latest checkpoint when there is one; the first pass
       # runs before anything has been saved by this test.
       checkpoint.restore(checkpoint_management.latest_checkpoint(
           checkpoint_directory)).initialize_or_restore(sess)
       # Two elements are consumed per pass, so pass i yields 2*i and 2*i + 1.
       for j in range(2):
         self.assertEqual(i * 2 + j, self.evaluate(get_next))
       checkpoint.save(file_prefix=checkpoint_prefix)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:14,代码来源:iterator_ops_test.py


示例13: testRestoreInReconstructedIteratorInitializable

 def testRestoreInReconstructedIteratorInitializable(self):
   """An iterator resumes from the last saved position, eager or not."""
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   dataset = dataset_ops.Dataset.range(10)
   # Plain Python iteration when eager; an initializable iterator otherwise.
   iterator = iter(dataset) if context.executing_eagerly(
   ) else dataset_ops.make_initializable_iterator(dataset)
   # Kept as a bound method (not called here); it is invoked per element below.
   get_next = iterator.get_next
   checkpoint = trackable_utils.Checkpoint(iterator=iterator)
   for i in range(5):
     checkpoint.restore(
         checkpoint_management.latest_checkpoint(
             checkpoint_directory)).initialize_or_restore()
     # Two elements are consumed per pass, so pass i yields 2*i and 2*i + 1.
     for j in range(2):
       self.assertEqual(i * 2 + j, self.evaluate(get_next()))
     checkpoint.save(file_prefix=checkpoint_prefix)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:15,代码来源:iterator_checkpoint_test.py


示例14: testNameCollision

  def testNameCollision(self):
    """Saving to a path named "checkpoint" fails only for the bare name.

    The save path used here has the base name "checkpoint", which collides
    with the default name of the checkpoint state file. The unsharded save
    with no global step must raise; the other variants must succeed.
    """
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:
      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train/"
        os.mkdir(traindir)
        # Collides with the default name of the checkpoint state file.
        filepath = os.path.join(traindir, "checkpoint")

        with self.test_session() as sess:
          unused_a = variables.Variable(0.0)  # So that Saver saves something.
          variables.global_variables_initializer().run()

          # Should fail.
          saver = saver_module.Saver(sharded=False)
          # NOTE(review): `assertRaisesRegexp` is a deprecated alias of
          # `assertRaisesRegex` on Python 3; kept as-is here because the
          # supported Python versions of this file are not visible.
          with self.assertRaisesRegexp(ValueError, "collides with"):
            saver.save(sess, filepath)

          # Succeeds: the file will be named "checkpoint-<step>".
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))
开发者ID:clsung,项目名称:tensorflow,代码行数:36,代码来源:checkpoint_management_test.py


示例15: __init__

  def __init__(self,
               estimator,
               prediction_input_fn,
               input_alternative_key=None,
               output_alternative_key=None,
               graph=None,
               config=None):
    """Initialize a `ContribEstimatorPredictor`.

    Builds the prediction ops in the target graph, opens a monitored session
    restored from the latest checkpoint in `estimator.model_dir`, and records
    the feed and fetch tensors for the selected input/output alternatives.

    Args:
      estimator: an instance of `tf.contrib.learn.Estimator`.
      prediction_input_fn: a function that takes no arguments and returns an
        instance of `InputFnOps`.
      input_alternative_key: Optional. Specify the input alternative used for
        prediction.
      output_alternative_key: Specify the output alternative used for
        prediction. Not needed for single-headed models but required for
        multi-headed models.
      graph: Optional. The Tensorflow `graph` in which prediction should be
        done.
      config: `ConfigProto` proto used to configure the session.
    """
    # Fall back to a brand-new graph when the caller does not supply one.
    self._graph = graph or ops.Graph()
    with self._graph.as_default():
      input_fn_ops = prediction_input_fn()
      # pylint: disable=protected-access
      model_fn_ops = estimator._get_predict_ops(input_fn_ops.features)
      # pylint: enable=protected-access
      checkpoint_path = checkpoint_management.latest_checkpoint(
          estimator.model_dir)
      self._session = monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              config=config,
              checkpoint_filename_with_path=checkpoint_path))

    # Use the module's default input alternative when no key was given.
    input_alternative_key = (
        input_alternative_key or
        saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY)
    input_alternatives, _ = saved_model_export_utils.get_input_alternatives(
        input_fn_ops)
    self._feed_tensors = input_alternatives[input_alternative_key]

    # get_output_alternatives also returns the key that is used for the lookup
    # below, replacing the caller-supplied value.
    (output_alternatives,
     output_alternative_key) = saved_model_export_utils.get_output_alternatives(
         model_fn_ops, output_alternative_key)
    _, fetch_tensors = output_alternatives[output_alternative_key]
    self._fetch_tensors = fetch_tensors
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:47,代码来源:contrib_estimator_predictor.py


示例16: _GetGraphDef

  def _GetGraphDef(self, use_trt, max_batch_size, model_dir):
    """Get the frozen mnist GraphDef.

    Builds the graph, restores weights from the latest checkpoint in
    `model_dir`, freezes variables into constants, and optionally converts the
    result with TF-TRT.

    Args:
      use_trt: whether use TF-TRT to convert the graph.
      max_batch_size: the max batch size to apply during TF-TRT conversion.
      model_dir: the model directory to load the checkpoints.

    Returns:
      The frozen mnist GraphDef.
    """
    graph = ops.Graph()
    with self.session(graph=graph) as sess:
      with graph.device('/GPU:0'):
        x = array_ops.placeholder(
            shape=(None, 28, 28, 1), dtype=dtypes.float32, name=INPUT_NODE_NAME)
        self._BuildGraph(x)
      # Load weights
      mnist_saver = saver.Saver()
      checkpoint_file = latest_checkpoint(model_dir)
      mnist_saver.restore(sess, checkpoint_file)
      # Freeze
      graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph_def, output_node_names=[OUTPUT_NODE_NAME])
    # Convert with TF-TRT
    if use_trt:
      logging.info('Number of nodes before TF-TRT conversion: %d',
                   len(graph_def.node))
      converter = trt_convert.TrtGraphConverter(
          input_graph_def=graph_def,
          nodes_blacklist=[OUTPUT_NODE_NAME],
          max_batch_size=max_batch_size,
          precision_mode='INT8',
          # There is a 2GB GPU memory limit for each test, so we set
          # max_workspace_size_bytes to 256MB to leave enough room for TF
          # runtime to allocate GPU memory.
          max_workspace_size_bytes=1 << 28,
          minimum_segment_size=2,
          use_calibration=False,
          use_function_backup=False)
      graph_def = converter.convert()
      logging.info('Number of nodes after TF-TRT conversion: %d',
                   len(graph_def.node))
      # The converted graph is expected to contain exactly one TRTEngineOp.
      num_engines = len(
          [1 for n in graph_def.node if str(n.op) == 'TRTEngineOp'])
      self.assertEqual(1, num_engines)
    return graph_def
开发者ID:kylin9872,项目名称:tensorflow,代码行数:47,代码来源:quantization_mnist_test.py


示例17: testCheckpointExists

  def testCheckpointExists(self):
    """checkpoint_exists is False before saving and True afterwards.

    Exercised for every combination of sharded/unsharded savers and V1/V2
    write versions.
    """
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          # One distinct path per (sharded, version) combination.
          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          self.assertFalse(
              checkpoint_management.checkpoint_exists(path))  # Not saved yet.

          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

          # The prefix reported by latest_checkpoint must exist as well.
          ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
开发者ID:clsung,项目名称:tensorflow,代码行数:17,代码来源:checkpoint_management_test.py


示例18: testRelativePath

  def testRelativePath(self):
    """Checkpoints saved under a relative directory can be found and restored."""
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:

      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):

        # Save training snapshots to a relative path.
        traindir = "train/"
        os.mkdir(traindir)

        filename = "snapshot"
        filepath = os.path.join(traindir, filename)

        with self.test_session() as sess:
          # Build a simple graph.
          v0 = variables.Variable(0.0)
          inc = v0.assign_add(1.0)

          save = saver_module.Saver({"v0": v0})

          # Record a short training history.
          variables.global_variables_initializer().run()
          save.save(sess, filepath, global_step=0)
          inc.eval()
          save.save(sess, filepath, global_step=1)
          inc.eval()
          save.save(sess, filepath, global_step=2)

        with self.test_session() as sess:
          # Build a new graph with different initialization.
          v0 = variables.Variable(-1.0)

          # Create a new saver.
          save = saver_module.Saver({"v0": v0})
          variables.global_variables_initializer().run()

          # Get the most recent checkpoint name from the training history file.
          name = checkpoint_management.latest_checkpoint(traindir)
          self.assertIsNotNone(name)

          # Restore "v0" from that checkpoint.
          save.restore(sess, name)
          # v0 started at 0.0 and was incremented twice before the last save.
          self.assertEqual(v0.eval(), 2.0)
开发者ID:clsung,项目名称:tensorflow,代码行数:44,代码来源:checkpoint_management_test.py


示例19: testTrainSpinn

  def testTrainSpinn(self):
    """Test with fake toy SNLI data and GloVe vectors.

    Trains a SPINN model on generated data, then checks the logged training
    losses and the variables recorded in the latest checkpoint.
    """

    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    # The same fake file serves as the train, dev and test split.
    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)

    # 2. Create a fake config.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"))

    # 3. Test training of a SPINN model.
    trainer = spinn.train_or_infer_spinn(
        embed, word2index, train_data, dev_data, test_data, config)

    # 4. Load train loss values from the summary files and verify that they
    #    decrease with training.
    summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
    events = summary_test_util.events_from_file(summary_file)
    # Keep only events whose first summary value is tagged "train/loss".
    train_losses = [event.summary.value[0].simple_value for event in events
                    if event.summary.value
                    and event.summary.value[0].tag == "train/loss"]
    # NOTE(review): only the number of loss values is asserted here, not that
    # they decrease.
    self.assertEqual(config.epochs, len(train_losses))

    # 5. Verify that checkpoints exist and contains all the expected variables.
    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
    object_graph = checkpointable_utils.object_metadata(
        checkpoint_management.latest_checkpoint(config.logdir))
    ckpt_variable_names = set()
    for node in object_graph.nodes:
      for attribute in node.attributes:
        ckpt_variable_names.add(attribute.full_name)
    self.assertIn("global_step", ckpt_variable_names)
    for v in trainer.variables:
      # Drop everything from the first ":" onward before comparing names.
      variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
      self.assertIn(variable_name, ckpt_variable_names)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:44,代码来源:spinn_test.py


示例20: after_save

  def after_save(self, session, global_step_value):
    """Runs evaluation and export once for each newly seen checkpoint.

    Args:
      session: Unused by this method.
      global_step_value: Unused by this method.
    """
    # The most recent checkpoint path is cached on the instance so that
    # duplicate searches on GCS are avoided.
    logging.info("Checking for checkpoint in %s", self._model_dir)
    newest_ckpt = checkpoint_management.latest_checkpoint(self._model_dir)

    # Nothing has been written yet.
    if not newest_ckpt:
      logging.warning("Skipping evaluation and export since model has not been "
                      "saved yet.")
      return

    # This checkpoint was already handled by an earlier call.
    if newest_ckpt == self._latest_path:
      logging.warning("Skipping evaluation due to same latest checkpoint %s.",
                      newest_ckpt)
      return

    # Record the path first, then evaluate and export from it.
    self._latest_path = newest_ckpt
    self._eval_result = self._eval_fn(
        name="intermediate_export", checkpoint_path=newest_ckpt)
    self._export_results = self._export_fn(
        self._eval_result, checkpoint_path=newest_ckpt)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:19,代码来源:experiment.py



注:本文中的tensorflow.python.training.checkpoint_management.latest_checkpoint函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap