• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python variables.local_variables_initializer函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.ops.variables.local_variables_initializer函数的典型用法代码示例。如果您想了解local_variables_initializer函数的具体用法、调用方式或实际使用示例，下面精选的代码示例或许能为您提供帮助。



在下文中一共展示了local_variables_initializer函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: _get_train_op_and_ensemble

  def _get_train_op_and_ensemble(self, head, config, is_classification,
                                 train_in_memory):
    """Builds the boosted-trees train op and a serialized view of its ensemble.

    Calls `boosted_trees._bt_model_fn` in TRAIN mode on the synthetic training
    input, runs the resource, global-variable and local-variable initializers,
    and asserts that exactly one shared resource is registered.

    Args:
      head: Head forwarded to `_bt_model_fn`.
      config: Run config forwarded to `_bt_model_fn`.
      is_classification: Forwarded to `_make_train_input_fn`.
      train_in_memory: Forwarded to `_bt_model_fn`.

    Returns:
      A `(train_op, ensemble_serialized)` pair. The serialization op is created
      under a control dependency on `train_op`.
    """
    train_features, train_labels = _make_train_input_fn(is_classification)()
    spec = boosted_trees._bt_model_fn(  # pylint:disable=protected-access
        features=train_features,
        labels=train_labels,
        mode=model_fn.ModeKeys.TRAIN,
        head=head,
        feature_columns=self._feature_columns,
        tree_hparams=self._tree_hparams,
        example_id_column_name=EXAMPLE_ID_COLUMN,
        n_batches_per_layer=1,
        config=config,
        train_in_memory=train_in_memory)
    # Initialize resources first, then global and local variables.
    resources.initialize_resources(resources.shared_resources()).run()
    variables.global_variables_initializer().run()
    variables.local_variables_initializer().run()

    # Exactly one shared resource is expected (presumably the tree ensemble);
    # its handle is what gets serialized.
    resource_list = resources.shared_resources()
    self.assertEqual(1, len(resource_list))
    train_op = spec.train_op
    with ops.control_dependencies([train_op]):
      serialize_outputs = gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
          resource_list[0].handle)
    ensemble_serialized = serialize_outputs[1]
    return train_op, ensemble_serialized
开发者ID:LiuCKind,项目名称:tensorflow,代码行数:28,代码来源:boosted_trees_test.py


示例2: test_sync_replicas

  def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    """Checks GAN train ops built with sync-replicas optimizers.

    Verifies that `train.gan_train_ops` adds no trainable variables and
    attaches one `_SyncReplicasOptimizerHook` per optimizer, together covering
    both optimizers. It then checks that evaluating the generator and
    discriminator train ops does not change the global step.

    Args:
      create_gan_model_fn: Zero-argument callable that returns the GAN model.
      create_global_step: If true, an int32 'custom_gstep' variable is created
        and added to the GLOBAL_STEP collection before the train ops are built.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    num_trainable_vars = len(variables_lib.get_trainable_variables())

    if create_global_step:
      gstep = variable_scope.get_variable(
          'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
      ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
    # No new trainable variables should have been added.
    self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars)

    # Sync hooks should be populated in the GANTrainOps.
    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
      self.assertIsInstance(
          hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    # The two hooks must wrap exactly the two optimizers (order-insensitive).
    sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    # Token-init ops for each sync optimizer (one token each); run below after
    # the chief init ops.
    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

    # Check that update op is run properly.
    global_step = training_util.get_or_create_global_step()
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()

      g_opt.chief_init_op.run()
      d_opt.chief_init_op.run()

      gstep_before = global_step.eval()

      # Start required queue runner for SyncReplicasOptimizer.
      # NOTE(review): create_threads is called without start=True, so these
      # threads may be created but never started; confirm whether the test
      # relies on them actually running.
      coord = coordinator.Coordinator()
      g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
      d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

      g_sync_init_op.run()
      d_sync_init_op.run()

      train_ops.generator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      train_ops.discriminator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      coord.request_stop()
      coord.join(g_threads + d_threads)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:58,代码来源:train_test.py


示例3: testAccuracy

 def testAccuracy(self):
   """Streaming accuracy over ten points (six correct) should read 0.6."""
   accuracy_op, update_op = eval_metrics._accuracy(
       constant_op.constant([0, 1, 3, 6, 5, 2, 7, 6, 4, 9]),  # predictions
       constant_op.constant([0, 1, 4, 6, 5, 1, 7, 5, 4, 8]))  # targets
   with self.test_session():
     # Reset the metric's local state before accumulating.
     variables.local_variables_initializer().run()
     # Streaming metric: the update op must run before the value is read.
     update_op.eval()
     measured = accuracy_op.eval()
     self.assertNear(0.6, measured, 0.0001)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:10,代码来源:eval_metrics_test.py


示例4: test_empty_labels_and_scores_gives_nan_auc

 def test_empty_labels_and_scores_gives_nan_auc(self):
   """AUC computed from zero-length labels and scores must be NaN."""
   with self.cached_session():
     auc, update_op = histogram_ops.auc_using_histogram(
         constant_op.constant([], shape=[0], dtype=dtypes.bool),     # labels
         constant_op.constant([], shape=[0], dtype=dtypes.float32),  # scores
         [0, 1.])                                                    # range
     variables.local_variables_initializer().run()
     update_op.run()
     auc_value = auc.eval()
     self.assertTrue(np.isnan(auc_value))
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:10,代码来源:histogram_ops_test.py


示例5: _check_auc

  def _check_auc(self,
                 nbins=100,
                 desired_auc=0.75,
                 score_range=None,
                 num_records=50,
                 frac_true=0.5,
                 atol=0.05,
                 num_updates=10):
    """Check auc accuracy against synthetic data.

    Args:
      nbins:  nbins arg from contrib.metrics.auc_using_histogram.
      desired_auc:  Number in [0, 1].  The desired auc for synthetic data.
      score_range:  2-tuple, (low, high), giving the range of the resultant
        scores.  Defaults to [0, 1.].
      num_records:  Positive integer.  The number of records to return.
      frac_true:  Number in (0, 1).  Expected fraction of resultant labels that
        will be True.  This is just in expectation...more or less may actually
        be True.
      atol:  Absolute tolerance for final AUC estimate.
      num_updates:  Update internal histograms this many times, each with a new
        batch of synthetic data, before computing final AUC.

    Raises:
      AssertionError: If resultant AUC is not within atol of theoretical AUC
        from synthetic data.
    """
    # Apply the documented default only when the caller gave no range. A
    # non-empty list literal is always truthy, so `[0, 1.] or score_range`
    # would ignore the caller's value entirely.
    if score_range is None:
      score_range = [0, 1.]
    with self.cached_session():
      labels = array_ops.placeholder(dtypes.bool, shape=[num_records])
      scores = array_ops.placeholder(dtypes.float32, shape=[num_records])
      auc, update_op = histogram_ops.auc_using_histogram(
          labels, scores, score_range, nbins=nbins)
      variables.local_variables_initializer().run()
      # Feed a fresh synthetic batch into the histograms on every update.
      for _ in range(num_updates):
        labels_a, scores_a = synthetic_data(desired_auc, score_range,
                                            num_records, self.rng, frac_true)
        update_op.run(feed_dict={labels: labels_a, scores: scores_a})
      # Fetch current auc, and verify that fetching again doesn't change it.
      auc_eval = auc.eval()
      self.assertAlmostEqual(auc_eval, auc.eval(), places=5)

    msg = ('nbins: %s, desired_auc: %s, score_range: %s, '
           'num_records: %s, frac_true: %s, num_updates: %s') % (nbins,
                                                                 desired_auc,
                                                                 score_range,
                                                                 num_records,
                                                                 frac_true,
                                                                 num_updates)
    np.testing.assert_allclose(desired_auc, auc_eval, atol=atol, err_msg=msg)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:53,代码来源:histogram_ops_test.py


示例6: testMetricsCollection

  def testMetricsCollection(self):
    """Round-trips the LOCAL_VARIABLES collection through a meta graph.

    Runs in three phases:
      1. Builds a streaming-mean metric fed from a FIFOQueue, initializes and
         updates it, then exports the meta graph.
      2. Imports that meta graph into a fresh graph and checks that a
         local-variables initializer can be built and run.
      3. Imports a legacy meta graph whose "local_variables" collection holds
         tensors (node_list type). Building the initializer must then raise
         AttributeError.
    """

    def _enqueue_vector(sess, queue, values, shape=None):
      """Enqueues `values` as one constant element, shaped (1, len) by default."""
      if not shape:
        shape = (1, len(values))
      dtype = queue.dtypes[0]
      sess.run(
          queue.enqueue(constant_op.constant(
              values, dtype=dtype, shape=shape)))

    meta_graph_filename = os.path.join(
        _TestDir("metrics_export"), "meta_graph.pb")

    graph = ops.Graph()
    with self.session(graph=graph) as sess:
      values_queue = data_flow_ops.FIFOQueue(
          4, dtypes.float32, shapes=(1, 2))
      _enqueue_vector(sess, values_queue, [0, 1])
      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
      _enqueue_vector(sess, values_queue, [6.5, 0])
      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
      values = values_queue.dequeue()

      _, update_op = metrics.mean(values)

      initializer = variables.local_variables_initializer()
      self.evaluate(initializer)
      self.evaluate(update_op)

    meta_graph.export_scoped_meta_graph(
        filename=meta_graph_filename, graph=graph)

    # Verifies that importing a meta_graph with LOCAL_VARIABLES collection
    # works correctly.
    graph = ops.Graph()
    with self.session(graph=graph):
      meta_graph.import_scoped_meta_graph(meta_graph_filename)
      initializer = variables.local_variables_initializer()
      self.evaluate(initializer)

    # Verifies that importing an old meta_graph where "local_variables"
    # collection is of node_list type works, but cannot build initializer
    # with the collection.
    graph = ops.Graph()
    with self.session(graph=graph):
      meta_graph.import_scoped_meta_graph(
          test.test_src_dir_path(
              "python/framework/testdata/metrics_export_meta_graph.pb"))
      self.assertEqual(len(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)),
                       2)
      # assertRaisesRegex is the supported name; the assertRaisesRegexp alias
      # was deprecated and removed in Python 3.12.
      with self.assertRaisesRegex(
          AttributeError, "'Tensor' object has no attribute 'initializer'"):
        variables.local_variables_initializer()
开发者ID:aeverall,项目名称:tensorflow,代码行数:53,代码来源:meta_graph_test.py


示例7: testTop2

 def testTop2(self):
   """Two of the four targets fall in their row's top-2 probabilities."""
   top_2_fn = eval_metrics._top_k_generator(2)
   probabilities = constant_op.constant([[0.1, 0.2, 0.3],
                                         [0.4, 0.7, 0.5],
                                         [0.9, 0.8, 0.2],
                                         [0.6, 0.4, 0.8]])
   targets = constant_op.constant([[0], [2], [1], [1]])
   in_top_2_op, update_op = top_2_fn(probabilities, targets)
   with self.test_session():
     # Reset the metric's local state before accumulating.
     variables.local_variables_initializer().run()
     # Streaming metric: run the update op before reading the value.
     update_op.eval()
     fraction_in_top_2 = in_top_2_op.eval()
     self.assertNear(0.5, fraction_in_top_2, 0.0001)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:13,代码来源:eval_metrics_test.py


示例8: testR2

 def testR2(self):
   """Streaming R^2 of nine scores against their targets is about 0.8136."""
   scores = constant_op.constant(
       [1.2, 3.9, 2.1, 0.9, 2.2, 0.1, 6.0, 4.0, 0.9])
   targets = constant_op.constant(
       [1.0, 4.3, 2.6, 0.5, 1.1, 0.7, 5.1, 3.4, 1.8])
   r2_op, update_op = eval_metrics._r2(scores, targets)
   with self.test_session():
     # Reset the R^2 metric's local state before accumulating.
     variables.local_variables_initializer().run()
     # Streaming metric: run the update op before reading the value.
     update_op.eval()
     r2_value = r2_op.eval()
     self.assertNear(0.813583, r2_value, 0.0001)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:13,代码来源:eval_metrics_test.py


示例9: begin

 def begin(self):
   """Creates the initialization ops used by this hook.

   Every worker gets a local-variables initializer and the model-average
   optimizer's `get_init_op()`. Only the chief also gets the global-variables
   initializer and the optimizer's chief init op. On other workers both of
   those chief-only attributes are `None`.
   """
   self._local_init_op = variables.local_variables_initializer()
   # Default the chief-only ops to None so both attributes always exist,
   # including on non-chief workers.
   self._global_init_op = None
   self._chief_init_op = None
   if self._is_chief:
     self._global_init_op = variables.global_variables_initializer()
     self._chief_init_op = self._ma_optimizer._chief_init_op  # pylint: disable=protected-access
   self._variable_init_op = self._ma_optimizer.get_init_op()
开发者ID:jinxin0924,项目名称:tensorflow,代码行数:7,代码来源:model_average_optimizer.py


示例10: export

  def export(self, last_checkpoint, output_dir):
    """Builds a prediction graph and exports it as a SavedModel.

    Args:
      last_checkpoint: Path to the latest checkpoint file from training.
      output_dir: Path to the folder to be used to output the model.
    """
    logging.info('Exporting prediction graph to %s', output_dir)
    with tf.Session(graph=tf.Graph()) as sess:
      # Build the serving graph and expose it under the default signature key.
      inputs, outputs = self.build_prediction_graph()
      signatures = {
          'serving_default':
              signature_def_utils.predict_signature_def(inputs, outputs),
      }

      # Initialize all globals, then load trained values via
      # restore_from_checkpoint.
      sess.run(tf.global_variables_initializer())
      self.restore_from_checkpoint(sess, self.inception_checkpoint_file,
                                   last_checkpoint)

      # Local-variable and table initializers, attached as legacy_init_op.
      serving_init = control_flow_ops.group(
          variables.local_variables_initializer(), tf.tables_initializer())

      builder = saved_model_builder.SavedModelBuilder(output_dir)
      builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING],
          signature_def_map=signatures,
          legacy_init_op=serving_init)
      builder.save(False)  # presumably as_text=False (binary SavedModel)
开发者ID:googledatalab,项目名称:pydatalab,代码行数:28,代码来源:_model.py


示例11: _test_metric

  def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
    """Checks a distributed metric's value after successive update steps.

    Args:
      distribution: Distribution strategy whose scope the metric is built in.
      dataset_fn: Dataset factory passed to `distribute_dataset`.
      metric_fn: Run per tower via `call_for_each_tower` on the next input.
      expected_fn: Maps the number of consumed batches to the expected value.
    """
    with ops.Graph().as_default(), distribution.scope():
      iterator = distribution.distribute_dataset(
          dataset_fn).make_one_shot_iterator()
      value, update = distribution.call_for_each_tower(
          metric_fn, iterator.get_next())
      update = distribution.group(update)
      self.evaluate(variables.local_variables_initializer())
      # TODO(josh11b): Once we switch to using a global batch size for input,
      # replace "distribution.num_towers" with "1".
      batches_per_update = distribution.num_towers

      # Each update consumes `num_towers` batches. Always check two updates;
      # with a single tower, go on to consume 4 input batches in total.
      checkpoints = [(1, "After first update"), (2, "After second update")]
      if batches_per_update == 1:
        checkpoints.extend([(3, "After third update"),
                            (4, "After fourth update")])

      for multiple, message in checkpoints:
        self.evaluate(update)
        self.assertAllClose(expected_fn(multiple * batches_per_update),
                            self.evaluate(value),
                            0.001,
                            msg=message)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:35,代码来源:metrics_v1_test.py


示例12: _random_window_input_fn_test_template

 def _random_window_input_fn_test_template(
     self, time_series_reader, window_size, batch_size, num_features,
     discard_out_of_order=False):
   """Fetches one batch from RandomWindowInputFn and validates its structure.

   Checks that TIMES has shape [batch_size, window_size], is int64, and holds
   consecutive (step 1) times within each window. VALUES must have shape
   [batch_size, window_size, num_features], and feature column `k` must equal
   `TIMES * 2. + k`. That relation presumably comes from how the test data
   was generated; confirm against callers.

   Args:
     time_series_reader: Reader passed to `RandomWindowInputFn`.
     window_size: Window length passed to `RandomWindowInputFn`.
     batch_size: Batch size passed to `RandomWindowInputFn`.
     num_features: Expected size of the last VALUES dimension.
     discard_out_of_order: Not referenced in this body.
       NOTE(review): confirm whether it was meant to be forwarded.

   Returns:
     The features dict fetched from the session.
   """
   input_fn = input_pipeline.RandomWindowInputFn(
       time_series_reader=time_series_reader,
       window_size=window_size, batch_size=batch_size)
   result, _ = input_fn()
   init_op = variables.local_variables_initializer()
   with self.cached_session() as session:
     # Queue runners feed the input pipeline; stopped via the coordinator.
     coordinator = coordinator_lib.Coordinator()
     queue_runner_impl.start_queue_runners(session, coord=coordinator)
     session.run(init_op)
     features = session.run(result)
     coordinator.request_stop()
     coordinator.join()
   self.assertAllEqual([batch_size, window_size],
                       features[TrainEvalFeatures.TIMES].shape)
   for window_position in range(window_size - 1):
     for batch_position in range(batch_size):
       # Checks that all times are contiguous
       self.assertEqual(
           features[TrainEvalFeatures.TIMES][batch_position,
                                             window_position + 1],
           features[TrainEvalFeatures.TIMES][batch_position,
                                             window_position] + 1)
   self.assertAllEqual([batch_size, window_size, num_features],
                       features[TrainEvalFeatures.VALUES].shape)
   self.assertEqual("int64", features[TrainEvalFeatures.TIMES].dtype)
   for feature_number in range(num_features):
     self.assertAllEqual(
         features[TrainEvalFeatures.TIMES] * 2. + feature_number,
         features[TrainEvalFeatures.VALUES][:, :, feature_number])
   return features
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:33,代码来源:input_pipeline_test.py


示例13: _all_window_input_fn_test_template

 def _all_window_input_fn_test_template(
     self, time_series_reader, num_samples, window_size,
     original_numpy_features=None):
   """Fetches every window from AllWindowInputFn and validates the chunks.

   Expects `num_samples - window_size + 1` windows of length `window_size`.
   When `original_numpy_features` is given, also checks that the unique
   chunked times equal the original times, and that the chunked values equal
   the original values indexed by the chunked times.

   Args:
     time_series_reader: Reader passed to `test_utils.AllWindowInputFn`.
     num_samples: Number of samples in the source series.
     window_size: Window length passed to `AllWindowInputFn`.
     original_numpy_features: Optional dict with TIMES and VALUES arrays.
   """
   input_fn = test_utils.AllWindowInputFn(
       time_series_reader=time_series_reader,
       window_size=window_size)
   features, _ = input_fn()
   init_op = variables.local_variables_initializer()
   with self.cached_session() as session:
     # Queue runners feed the input pipeline; stopped via the coordinator.
     coordinator = coordinator_lib.Coordinator()
     queue_runner_impl.start_queue_runners(session, coord=coordinator)
     session.run(init_op)
     chunked_times, chunked_values = session.run(
         [features[TrainEvalFeatures.TIMES],
          features[TrainEvalFeatures.VALUES]])
     coordinator.request_stop()
     coordinator.join()
   self.assertAllEqual([num_samples - window_size + 1, window_size],
                       chunked_times.shape)
   if original_numpy_features is not None:
     original_times = original_numpy_features[TrainEvalFeatures.TIMES]
     original_values = original_numpy_features[TrainEvalFeatures.VALUES]
     self.assertAllEqual(original_times, numpy.unique(chunked_times))
     # Times are used directly as indices into the original values, which
     # assumes they are 0-based positions (TODO confirm with callers).
     self.assertAllEqual(original_values[chunked_times],
                         chunked_values)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:25,代码来源:input_pipeline_test.py


示例14: testWhileLoop

  def testWhileLoop(self):
    """Checks that a `rate.Rate` object can be called inside a while_loop.

    The loop body calls the same `Rate` instance on every iteration, adding 2
    to `value` and 1 to `denom` each time. The test expects the final rate
    output (`loop[3]`) to be [[2]], which is presumably the ratio of those
    increments.
    """
    with self.cached_session():
      r_ = rate.Rate()

      def body(value, denom, i, ret_rate):
        i += 1
        ret_rate = r_(value, denom)
        # Advance the inputs only after this iteration's rate is computed.
        with ops.control_dependencies([ret_rate]):
          value = math_ops.add(value, 2)
          denom = math_ops.add(denom, 1)
        return [value, denom, i, ret_rate]

      def condition(v, d, i, r):
        del v, d, r  # unused vars by condition
        return math_ops.less(i, 100)

      i = constant_op.constant(0)
      value = constant_op.constant([1], dtype=dtypes.float64)
      denom = constant_op.constant([1], dtype=dtypes.float64)
      # Calling r_ once outside the loop gives the initial loop-var value.
      ret_rate = r_(value, denom)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(variables.local_variables_initializer())
      loop = control_flow_ops.while_loop(condition, body,
                                         [value, denom, i, ret_rate])
      self.assertEqual([[2]], self.evaluate(loop[3]))
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:25,代码来源:rate_test.py


示例15: testLargeCase

  def testLargeCase(self):
    """Checks precision_recall_at_equal_thresholds on many large updates.

    Runs 71 updates over a [32, 512, 256, 1] random batch. At threshold 0,
    tp + fp must equal the total number of pixels seen, within 1. That only
    holds if the accumulation does not lose precision, as float32 would.
    """
    shape = [32, 512, 256, 1]
    predictions = random_ops.random_uniform(
        shape, 0.0, 1.0, dtype=dtypes_lib.float32)
    labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5)

    result, update_op = metric_ops.precision_recall_at_equal_thresholds(
        labels=labels, predictions=predictions, num_thresholds=201)
    # Run many updates, enough to cause highly inaccurate values if the
    # code used float32 for accumulation.
    num_updates = 71

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())
      for _ in xrange(num_updates):
        sess.run(update_op)

      prdata = sess.run(result)

      # Since we use random values, we won't know the tp/fp/tn/fn values, but
      # tp and fp at threshold 0 should be the total number of positive and
      # negative labels, hence their sum should be total number of pixels.
      # np.prod replaces np.product, which is deprecated and removed in
      # NumPy 2.0.
      expected_value = 1.0 * np.prod(shape) * num_updates
      got_value = prdata.tp[0] + prdata.fp[0]
      # They should be at least within 1.
      self.assertNear(got_value, expected_value, 1.0)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:26,代码来源:metric_ops_large_test.py


示例16: test_batch_text_lines

  def test_batch_text_lines(self):
    """Reads a 5-line text file in batches of 3 for a single epoch.

    Expects [A, B, C], then the short final batch [D, E], and then an
    OutOfRangeError once the single epoch is exhausted.
    """
    # NOTE(review): presumably restores the real gfile.Glob that the fixture
    # replaced; confirm against setUp.
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")

    batch_size = 3
    queue_capacity = 10
    name = "my_batch"

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = graph_io.read_batch_examples(
          [filename],
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          queue_capacity=queue_capacity,
          read_batch_size=10,
          name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      # Local variables are initialized before the queue runners start.
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
      self.assertAllEqual(session.run(inputs), [b"D", b"E"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
开发者ID:willdzeng,项目名称:tensorflow,代码行数:31,代码来源:graph_io_test.py


示例17: testFinalOpsOnEvaluationLoop

  def testFinalOpsOnEvaluationLoop(self):
    """Runs one evaluation pass and checks it returns the expected accuracy.

    Initialized variables are first saved to a checkpoint so that
    `evaluation.evaluation_loop` has something to evaluate. The loop is given
    `update_op` as its eval op and `value_op` as its final op.
    """
    value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                        self._labels)
    init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                     variables.local_variables_initializer())

    # One directory holds the checkpoint, the other the evaluation logs.
    chkpt_dir, logdir = (os.path.join(self.get_temp_dir(), subdir)
                         for subdir in ('tmp_logs/', 'tmp_logs2/'))
    for directory in (chkpt_dir, logdir):
      gfile.MakeDirs(directory)

    # Persist the freshly initialized variables for the loop to restore.
    saver = saver_lib.Saver()
    with self.test_session() as sess:
      init_op.run()
      saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))

    accuracy_value = evaluation.evaluation_loop(
        '',
        chkpt_dir,
        logdir,
        eval_op=update_op,
        final_op=value_op,
        max_number_of_evaluations=1)
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:26,代码来源:evaluation_test.py


示例18: test_keyed_features_filter

  def test_keyed_features_filter(self):
    """Checks that filter_fn drops examples from a keyed batch reader.

    Six JSON examples with ages [2, 0, 1, 0, 3, 5] are read in batches of 2.
    The filter keeps age < 2, so the expected output is a full batch (keys
    ":2" and ":3"), then a one-element batch (":4"), then OutOfRangeError.
    """
    # NOTE(review): presumably restores the real gfile.Glob that the fixture
    # replaced; confirm against setUp.
    gfile.Glob = self._orig_glob
    lines = [
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}'
    ]
    filename = self._create_temp_file("\n".join(lines))

    batch_size = 2
    queue_capacity = 4
    name = "my_batch"
    features = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)}

    def filter_fn(keys, examples_json):
      """Returns a boolean mask that keeps examples whose age is below 2."""
      del keys
      serialized = parsing_ops.decode_json_example(examples_json)
      examples = parsing_ops.parse_example(serialized, features)
      return math_ops.less(examples["age"], 2)

    with ops.Graph().as_default() as g, self.session(graph=g) as session:
      keys, inputs = graph_io._read_keyed_batch_examples_helper(
          filename,
          batch_size,
          reader=io_ops.TextLineReader,
          randomize_input=False,
          num_epochs=1,
          read_batch_size=batch_size,
          queue_capacity=queue_capacity,
          filter_fn=filter_fn,
          name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(variables.local_variables_initializer())

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(session, coord=coord)
      # First batch of two filtered examples.
      # Keys are "<filename>:<n>"; from these assertions, lines[i] appears with
      # key suffix i + 1.
      out_keys, out_vals = session.run((keys, inputs))
      self.assertAllEqual(
          [filename.encode("utf-8") + b":2", filename.encode("utf-8") + b":3"],
          out_keys)
      self.assertAllEqual([lines[1].encode("utf-8"), lines[2].encode("utf-8")],
                          out_vals)

      # Second batch will only have one filtered example as that's the only
      # remaining example that satisfies the filtering criterion.
      out_keys, out_vals = session.run((keys, inputs))
      self.assertAllEqual([filename.encode("utf-8") + b":4"], out_keys)
      self.assertAllEqual([lines[3].encode("utf-8")], out_vals)

      # Exhausted input.
      with self.assertRaises(errors.OutOfRangeError):
        session.run((keys, inputs))

      coord.request_stop()
      coord.join(threads)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:60,代码来源:graph_io_test.py


示例19: test_example

  def test_example(self):
    """Aggregates eval metrics and losses from three towers via _eval_spec.

    Builds one estimator spec per tower (losses 2, 4, 6) and combines them
    with `replicate_model_fn._eval_spec`. It then checks that the aggregated
    accuracy/auc ops are placed on CPU, checks their values, and checks that
    the aggregated loss is 2 + 4 + 6.
    """
    with self.test_session() as session:
      # NOTE(review): under Python 3 `map` is lazy, so each tower's loss and
      # metric ops are created interleaved inside the comprehension below.
      # Under Python 2 all losses are created first. Keep this in mind if op
      # creation order ever matters.
      tower_losses = map(self.create_constant_loss, [2, 4, 6])
      tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
      tower_specs = [
          self.create_estimator_spec(l, m)
          for l, m in zip(tower_losses, tower_metrics)
      ]
      session.run(variables.local_variables_initializer())

      estimator_spec = replicate_model_fn._eval_spec(
          tower_specs, aggregation_device='/device:GPU:0')

      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
      auc, b = estimator_spec.eval_metric_ops['auc']

      # NOTE(review): the metric ops are expected on CPU even though
      # aggregation_device is GPU:0; confirm this is _eval_spec's intent.
      self.assertEqual('/device:CPU:0', accuracy.device)
      self.assertEqual('/device:CPU:0', auc.device)

      # Run the update ops first, then read the aggregated metric values.
      session.run([a, b])
      accuracy, auc = session.run([accuracy, auc])

      self.assertNear((12 - 2) / 12, accuracy, 0.01)
      self.assertEqual(0, auc)
      self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:25,代码来源:replicate_model_fn_test.py


示例20: test_handles_single_tower

  def test_handles_single_tower(self):
    """Checks that _eval_spec also works when there is only one tower.

    One tower with loss 5 is aggregated. The test checks the CPU placement of
    the accuracy/auc ops, their values, and that the aggregated loss is 5.
    """
    with self.test_session() as session:
      # NOTE(review): `map` is lazy under Python 3; ops are created while the
      # comprehension below iterates.
      tower_losses = map(self.create_constant_loss, [5])
      tower_metrics = map(self.create_eval_metrics, [0.2])
      tower_specs = [
          self.create_estimator_spec(l, m)
          for l, m in zip(tower_losses, tower_metrics)
      ]
      session.run(variables.local_variables_initializer())

      estimator_spec = replicate_model_fn._eval_spec(
          tower_specs, aggregation_device='/device:GPU:0')

      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
      auc, b = estimator_spec.eval_metric_ops['auc']

      # NOTE(review): the metric ops are expected on CPU even though
      # aggregation_device is GPU:0; confirm this is _eval_spec's intent.
      self.assertEqual('/device:CPU:0', accuracy.device)
      self.assertEqual('/device:CPU:0', auc.device)

      # Run the update ops first, then read the aggregated metric values.
      session.run([a, b])
      accuracy = session.run(accuracy)
      auc = session.run(auc)

      self.assertNear((4 - 1) / 4, accuracy, 0.01)
      self.assertEqual(0, auc)
      self.assertEqual(5, session.run(estimator_spec.loss))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:26,代码来源:replicate_model_fn_test.py



注:本文中的tensorflow.python.ops.variables.local_variables_initializer函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap