• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python mnist.training函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.examples.tutorials.mnist.mnist.training函数的典型用法代码示例。如果您正苦于以下问题:Python training函数的具体用法?Python training怎么用?Python training使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了training函数的11个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: run_training

def run_training():
    """Build the MNIST graph and run the training op in a loop.

    Batches come from `inputs(...)`, configured with `FLAGS.batch_size` and
    `FLAGS.num_epochs`.  The loop keeps stepping until the coordinator asks
    to stop or the session raises `tf.errors.OutOfRangeError`; the loss is
    printed every 100 steps.
    """
    with tf.Graph().as_default():
        # Batched input tensors for the training split.
        image_batch, label_batch = inputs(
            train=True,
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
        )

        # Model pipeline: predictions -> loss -> training op.
        logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2)
        loss = mnist.loss(logits, label_batch)
        train_op = mnist.training(loss, FLAGS.learning_rate)

        # Variable initializer; it is executed as soon as the session exists.
        init_op = tf.initialize_all_variables()

        sess = tf.Session()
        sess.run(init_op)

        # The coordinator owns the threads that feed the input queues.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        step = 0
        try:
            while not coord.should_stop():
                tick = time.time()

                # train_op's own result is of no interest; only the loss
                # value fetched alongside it is kept for reporting.
                _, batch_loss = sess.run([train_op, loss])

                elapsed = time.time() - tick

                # Progress report every hundredth step.
                if not step % 100:
                    print("Step %d: loss = %.2f (%.3f sec)" % (step, batch_loss, elapsed))
                step += 1
        except tf.errors.OutOfRangeError:
            print("Done training for %d epochs, %d steps." % (FLAGS.num_epochs, step))
        finally:
            # Whatever ended the loop, tell the input threads to stop.
            coord.request_stop()

        # Block until the input threads are done, then release the session.
        coord.join(threads)
        sess.close()
开发者ID:MingxuanChen,项目名称:tensorflow,代码行数:59,代码来源:fully_connected_reader.py


示例2: run_training

def run_training():
  """Train the MNIST model until the session raises OutOfRangeError.

  Batches come from `inputs(...)`, configured with `FLAGS.batch_size` and
  `FLAGS.num_epochs`.  The loss is printed every 100 steps and a summary
  line is printed once `tf.errors.OutOfRangeError` ends the loop.
  """
  with tf.Graph().as_default():
    # Batched input tensors for the training split.
    images, labels = inputs(
        train=True,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.num_epochs)

    # Model pipeline: predictions -> loss -> training op.
    logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
    loss = mnist.loss(logits, labels)
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # A single op that runs both the global and the local initializers.
    global_init = tf.global_variables_initializer()
    local_init = tf.local_variables_initializer()
    init_op = tf.group(global_init, local_init)

    # The session is closed automatically when the `with` block exits.
    with tf.Session() as sess:
      sess.run(init_op)

      step = 0
      try:
        # No explicit exit condition: the loop is left only through the
        # OutOfRangeError handled below (or any other exception).
        while True:
          began = time.time()

          # train_op's own result is discarded; the loss value is kept.
          _, batch_loss = sess.run([train_op, loss])

          elapsed = time.time() - began

          # Progress report every hundredth step.
          if not step % 100:
            print('Step %d: loss = %.2f (%.3f sec)' % (step, batch_loss,
                                                       elapsed))
          step += 1
      except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:51,代码来源:fully_connected_reader.py


示例3: run_training

def run_training():
  """Rebuild the MNIST graph, restore a checkpoint if any, and evaluate on test data.

  Despite the name, the training op is only constructed here and never run.
  Variables are initialized first; when `FLAGS.train_dir` holds a checkpoint
  state with a model path, the saved values are restored over them.
  Otherwise a notice is printed and evaluation proceeds on the freshly
  initialized variables.
  """
  # Train / validation / test splits read from FLAGS.train_dir.
  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

  with tf.Graph().as_default():
    # Feed points for one batch of images and labels.
    images_ph, labels_ph = placeholder_inputs(FLAGS.batch_size)

    # Model pipeline: predictions -> loss.
    logits = mnist.inference(images_ph, FLAGS.hidden1, FLAGS.hidden2)
    loss = mnist.loss(logits, labels_ph)

    # The training op is built but never executed in this function.
    # NOTE(review): presumably kept so the graph's variables match the ones
    # stored by the training script -- confirm before removing.
    mnist.training(loss, FLAGS.learning_rate)

    # Op comparing the logits with the labels.
    eval_correct = mnist.evaluation(logits, labels_ph)

    # Saver covering every variable currently in the graph.
    saver = tf.train.Saver(tf.all_variables())

    sess = tf.Session()

    # Initialize everything first; a restore below overwrites these values.
    sess.run(tf.initialize_all_variables())

    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
      print('...no checkpoint found...')
    else:
      saver.restore(sess, ckpt.model_checkpoint_path)

    # Only the test split is evaluated.
    print('Test Data Eval:')
    do_eval(sess, eval_correct, images_ph, labels_ph, data_sets.test)
开发者ID:j-pong,项目名称:tensorflow_test,代码行数:44,代码来源:mnist_eval.py


示例4: run_training

def run_training():
  """Train the MNIST model for `FLAGS.max_steps` steps using feed dictionaries.

  Every 100 steps the loss is printed and a summary is written.  Every
  1000 steps, and after the final step, a checkpoint is saved and the
  model is evaluated on the training, validation and test sets.
  """
  # Train / validation / test splits read from FLAGS.train_dir.
  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

  with tf.Graph().as_default():
    # Feed points for one batch of images and labels.
    images_ph, labels_ph = placeholder_inputs(FLAGS.batch_size)

    # Model pipeline: predictions -> loss -> training op, plus the op that
    # compares the logits with the labels for evaluation.
    logits = mnist.inference(images_ph, FLAGS.hidden1, FLAGS.hidden2)
    loss = mnist.loss(logits, labels_ph)
    train_op = mnist.training(loss, FLAGS.learning_rate)
    eval_correct = mnist.evaluation(logits, labels_ph)

    # Merged summary op and checkpoint saver.
    merged_summaries = tf.merge_all_summaries()
    saver = tf.train.Saver()

    # Open the session and initialize every variable.
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())

    # Writer for summaries and the graph definition.
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

    # Banner / dataset pairs evaluated at every checkpoint, in this order.
    eval_rounds = (
        ('Training Data Eval:', data_sets.train),
        ('Validation Data Eval:', data_sets.validation),
        ('Test Data Eval:', data_sets.test),
    )

    for step in xrange(FLAGS.max_steps):
      began = time.time()

      # Next batch of training data, keyed by the two placeholders.
      feed_dict = fill_feed_dict(data_sets.train, images_ph, labels_ph)

      # train_op's own result is discarded; the loss value is kept.
      _, batch_loss = sess.run([train_op, loss], feed_dict=feed_dict)

      elapsed = time.time() - began

      # Periodic console report and summary event.
      if not step % 100:
        print('Step %d: loss = %.2f (%.3f sec)' % (step, batch_loss, elapsed))
        summary_str = sess.run(merged_summaries, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)

      # Checkpoint and evaluate every 1000 steps and after the last step.
      completed = step + 1
      if completed % 1000 == 0 or completed == FLAGS.max_steps:
        saver.save(sess, FLAGS.train_dir, global_step=step)
        for banner, split in eval_rounds:
          print(banner)
          do_eval(sess, eval_correct, images_ph, labels_ph, split)
开发者ID:PseudoAj,项目名称:MyTensorFlow,代码行数:95,代码来源:fully_connected_feed.py


示例5: run_training

def run_training():
  """Train MNIST for a number of epochs from data preloaded into the graph.

  The training images and labels are stored as graph constants, sliced one
  example at a time for `FLAGS.num_epochs` epochs and batched into
  `FLAGS.batch_size` examples.  Every 100 steps the loss is printed and a
  summary is written.  A checkpoint is saved whenever `(step + 1)` is a
  multiple of 1000, and once more when the input queue raises
  `tf.errors.OutOfRangeError`.
  """
  # Get the sets of images and labels for training, validation, and
  # test on MNIST.
  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    with tf.name_scope('input'):
      # Input data, pin to CPU because rest of pipeline is CPU-only
      with tf.device('/cpu:0'):
        input_images = tf.constant(data_sets.train.images)
        input_labels = tf.constant(data_sets.train.labels)

      # Single examples sliced from the constants, then grouped into batches.
      image, label = tf.train.slice_input_producer(
          [input_images, input_labels], num_epochs=FLAGS.num_epochs)
      label = tf.cast(label, tf.int32)
      images, labels = tf.train.batch(
          [image, label], batch_size=FLAGS.batch_size)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels)

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create the op for initializing variables.
    init_op = tf.initialize_all_variables()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Run the Op to initialize the variables.
    sess.run(init_op)

    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    # Start input enqueue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # And then after everything is built, start the training loop.
    try:
      step = 0
      while not coord.should_stop():
        start_time = time.time()

        # Run one step of the model.
        _, loss_value = sess.run([train_op, loss])

        duration = time.time() - start_time

        # Write the summaries and print an overview fairly often.
        if step % 100 == 0:
          # Print status to stdout.
          print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                     duration))
          # Update the events file.
          summary_str = sess.run(summary_op)
          summary_writer.add_summary(summary_str, step)

        # Save a checkpoint periodically.
        if (step + 1) % 1000 == 0:
          print('Saving')
          saver.save(sess, FLAGS.train_dir, global_step=step)

        # The counter advances exactly once per iteration, and only here.
        # An extra increment in the reporting branch would keep `step` even,
        # so the `(step + 1) % 1000` checkpoint test could never succeed.
        step += 1
    except tf.errors.OutOfRangeError:
      # Raised by the session once the input queue is exhausted; save a
      # final checkpoint before reporting.
      print('Saving')
      saver.save(sess, FLAGS.train_dir, global_step=step)
      print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
    finally:
      # When done, ask the threads to stop.
      coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)
    sess.close()
开发者ID:0-T-0,项目名称:tensorflow,代码行数:92,代码来源:fully_connected_preloaded.py


示例6: run_training

def run_training():
  """Train the MNIST model together with a k-means op and track test precision.

  The data sets are read into a fresh temporary directory.  Each step runs
  the grouped training op (classifier update plus `kmeans_training_op`).
  Every 1000 steps, and after the final step, the model is evaluated on
  the training, validation and test sets.

  Returns:
    The largest value `do_eval` returned for the test set, or 0 if no
    evaluation round produced a larger one.
  """
  # Data sets live in a throw-away temporary directory.
  scratch_dir = tempfile.mkdtemp()
  data_sets = input_data.read_data_sets(scratch_dir, FLAGS.fake_data)

  with tf.Graph().as_default():
    # Feed points for images and labels.
    images_ph, labels_ph = placeholder_inputs()

    # The local `inference` returns the logits plus the clustering loss and
    # the op that trains the k-means part.
    logits, clustering_loss, kmeans_training_op = inference(
        images_ph, FLAGS.num_clusters, FLAGS.hidden1, FLAGS.hidden2)

    # Classification loss on the logits.
    loss = mnist.loss(logits, labels_ph)

    # One op that applies the classifier gradients and runs the k-means op.
    classifier_step = mnist.training(loss, FLAGS.learning_rate)
    train_op = tf.group(classifier_step, kmeans_training_op)

    # Op comparing the logits with the labels for evaluation.
    eval_correct = mnist.evaluation(logits, labels_ph)

    init = tf.initialize_all_variables()
    sess = tf.Session()

    # The initializer is run with a 5000-example batch fed in.
    # NOTE(review): presumably the k-means initialization reads the image
    # placeholder -- confirm against `inference`.
    init_feed = fill_feed_dict(data_sets.train,
                               images_ph,
                               labels_ph,
                               batch_size=5000)
    sess.run(init, feed_dict=init_feed)

    best_test_prec = 0
    for step in xrange(FLAGS.max_steps):
      began = time.time()

      # Next training batch, keyed by the two placeholders.
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_ph,
                                 labels_ph,
                                 FLAGS.batch_size)

      # One optimization step; both loss values are fetched for reporting.
      fetches = [train_op, loss, clustering_loss]
      _, loss_value, clustering_loss_value = sess.run(fetches,
                                                      feed_dict=feed_dict)

      elapsed = time.time() - began
      if not step % 100:
        print('Step %d: loss = %.2f, clustering_loss = %.2f (%.3f sec)' % (
            step, loss_value, clustering_loss_value, elapsed))

      # Evaluate every 1000 steps and after the last step.
      completed = step + 1
      if completed % 1000 == 0 or completed == FLAGS.max_steps:
        print('Training Data Eval:')
        do_eval(sess, eval_correct, images_ph, labels_ph, data_sets.train)

        print('Validation Data Eval:')
        do_eval(sess, eval_correct, images_ph, labels_ph,
                data_sets.validation)

        print('Test Data Eval:')
        test_prec = do_eval(sess, eval_correct, images_ph, labels_ph,
                            data_sets.test)

        # Keep the best test result seen so far.
        if test_prec > best_test_prec:
          best_test_prec = test_prec
    return best_test_prec
开发者ID:2020zyc,项目名称:tensorflow,代码行数:90,代码来源:mnist.py


示例7: run_training

def run_training():
  """Train MNIST for a number of steps, tracing step 500.

  Batches come from `inputs(...)`, configured with `FLAGS.batch_size` and
  `FLAGS.num_epochs`.  At step 500 the session run is executed with
  `FULL_TRACE` options; the collected run metadata is written to
  `run_metadata.pbtxt` and a Chrome trace to `timeline.reader-1thread.json`,
  both relative to the current working directory.  The loss is printed
  every 100 steps.
  """

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Input images and labels.
    images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                            num_epochs=FLAGS.num_epochs)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the loss calculation.
    loss = mnist.loss(logits, labels)

    # Add to the Graph operations that train the model.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # The op for initializing the variables.
    init_op = tf.group(tf.initialize_all_variables(),
                       tf.initialize_local_variables())

    # Create a session for running operations in the Graph.
    sess = tf.Session()

    # Initialize the variables (the trained variables and the
    # epoch counter).
    sess.run(init_op)

    # List the registered queue runners before starting them.
    print("Queue runners: %s" %([qr.name for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)]))

    # Start input enqueue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # waiting for queue to get loaded
    time.sleep(15)
    run_metadata = tf.RunMetadata()

    try:
      step = 0
      while not coord.should_stop():
        start_time = time.time()

        # Run one step of the model.  The return values are
        # the activations from the `train_op` (which is
        # discarded) and the `loss` op.  To inspect the values
        # of your ops or variables, you may include them in
        # the list passed to sess.run() and the value tensors
        # will be returned in the tuple from the call.
        if step == 500:
          # Step 500 is the single traced step.
          _, loss_value = sess.run([train_op, loss],
                                   options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                                   run_metadata=run_metadata)
          with open("run_metadata.pbtxt", "w") as out:
            out.write(str(run_metadata))

          from tensorflow.python.client import timeline
          trace = timeline.Timeline(step_stats=run_metadata.step_stats)
          # Write through a context manager so the trace file is flushed
          # and closed even if generating or writing the trace raises.
          with open('timeline.reader-1thread.json', 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        else:
          _, loss_value = sess.run([train_op, loss])

        duration = time.time() - start_time

        # Print an overview fairly often.
        if step % 100 == 0:
          print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                     duration))
        step += 1
    except tf.errors.OutOfRangeError:
      print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
    finally:
      # When done, ask the threads to stop.
      coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)
    sess.close()
开发者ID:yaroslavvb,项目名称:stuff,代码行数:81,代码来源:fully_connected_reader.py


示例8:

                             'Must divide evenly into the dataset sizes.')
# Command-line flags: the data directory and a switch for fake data.
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
                             'for unit testing.')

## Load the data sets from FLAGS.train_dir (fake data when FLAGS.fake_data
## is set).
## NOTE(review): presumably this downloads and unpacks MNIST on first use and
## returns a custom DataSet container -- confirm against input_data.
data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

## Build everything below inside a fresh graph installed as the default graph.
with tf.Graph().as_default():
    ## Placeholders to be fed with one batch of images and labels.
    images_placeholder = tf.placeholder(tf.float32, shape=(FLAGS.batch_size,
                                                            mnist.IMAGE_PIXELS))
    ## NOTE: `(FLAGS.batch_size)` is a parenthesized scalar, not a 1-tuple.
    labels_placeholder = tf.placeholder(tf.int32, shape=(FLAGS.batch_size))

    ## mnist.inference() builds the feed-forward portion of the graph.
    ## It takes the images placeholder and two integers, each representing the
    ## number of neurons for the respective hidden layer, and returns logits.
    ## The loss, training and evaluation ops are then stacked on those logits.
    logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2)
    loss = mnist.loss(logits, labels_placeholder)
    train_op = mnist.training(loss, FLAGS.learning_rate)
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    ## Merge the summaries, create the initializer op, open a session, attach
    ## a summary writer for FLAGS.train_dir and the session graph, and finally
    ## run the initializer.
    summary_op = tf.merge_all_summaries()
    init = tf.initialize_all_variables()
    sess = tf.Session()
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
    sess.run(init)
开发者ID:cannonja,项目名称:tensorflow,代码行数:30,代码来源:build_ffgraph.py


示例9: run_training

def run_training():
  """Train MNIST for a number of steps."""
  # Get the sets of images and labels for training, validation, and
  # test on MNIST. If input_path is specified, download the data from GCS to
  # the folder expected by read_data_sets.
  data_dir = tempfile.mkdtemp()
  if FLAGS.input_path:
    files = [os.path.join(FLAGS.input_path, file_name)
             for file_name in INPUT_FILES]
    subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files +
                          [data_dir])
  data_sets = input_data.read_data_sets(data_dir, FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels.
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary operation based on the TF collection of Summaries.
    # Remove this if once Tensorflow 0.12 is standard.
    try:
      summary_op = tf.contrib.deprecated.merge_all_summaries()
    except AttributeError:
      summary_op = tf.merge_all_summaries()

    # Add the variable initializer Op.
    # Remove this if once Tensorflow 0.12 is standard.
    try:
      init = tf.global_variables_initializer()
    except AttributeError:
      init = tf.initialize_all_variables()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Instantiate a SummaryWriter to output summaries and the Graph.
    # Remove this if once Tensorflow 0.12 is standard.
    try:
      summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
    except AttributeError:
      summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    # And then after everything is built:

    # Run the Op to initialize the variables.
    sess.run(init)

    # Start the training loop.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # Fill a feed dictionary with the actual set of images and labels
      # for this particular training step.
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # Run one step of the model.  The return values are the activations
      # from the `train_op` (which is discarded) and the `loss` Op.  To
      # inspect the values of your Ops or variables, you may include them
      # in the list passed to sess.run() and the value tensors will be
      # returned in the tuple from the call.
      _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

      duration = time.time() - start_time

      # Write the summaries and print an overview fairly often.
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary_op, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

      # Save a checkpoint and evaluate the model periodically.
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
        saver.save(sess, checkpoint_file, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
#.........这里部分代码省略.........
开发者ID:cottrell,项目名称:notebooks,代码行数:101,代码来源:task.py


示例10: run_training

def run_training():
  """Train MNIST for a number of steps."""
  # Get the sets of images and labels for training, validation, and
  # test on MNIST.
  data_sets = input_data.read_data_sets(tempfile.mkdtemp(), FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels and mark as input.
    placeholders = placeholder_inputs()
    keys_placeholder, images_placeholder, labels_placeholder = placeholders
    inputs = {'key': keys_placeholder.name, 'image': images_placeholder.name}
    tf.add_to_collection('inputs', json.dumps(inputs))

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

    # To be able to extract the id, we need to add the identity function.
    keys = tf.identity(keys_placeholder)

    # The prediction will be the index in logits with the highest score.
    # We also use a softmax operation to produce a probability distribution
    # over all possible digits.
    prediction = tf.argmax(logits, 1)
    scores = tf.nn.softmax(logits)

    # Mark the outputs.
    outputs = {'key': keys.name,
               'prediction': prediction.name,
               'scores': scores.name}
    tf.add_to_collection('outputs', json.dumps(outputs))

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Add the variable initializer Op.
    init = tf.initialize_all_variables()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    # And then after everything is built:

    # Run the Op to initialize the variables.
    sess.run(init)

    # Start the training loop.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # Fill a feed dictionary with the actual set of images and labels
      # for this particular training step.
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # Run one step of the model.  The return values are the activations
      # from the `train_op` (which is discarded) and the `loss` Op.  To
      # inspect the values of your Ops or variables, you may include them
      # in the list passed to sess.run() and the value tensors will be
      # returned in the tuple from the call.
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      duration = time.time() - start_time

      # Write the summaries and print an overview fairly often.
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary_op, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

      # Save a checkpoint and evaluate the model periodically.
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
        saver.save(sess, checkpoint_file, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
#.........这里部分代码省略.........
开发者ID:obulpathi,项目名称:cloud,代码行数:101,代码来源:task.py


示例11: run_training

def run_training():
    """Train from the queue-based input and report a validation loss per epoch.

    Training batches come from `inputs(...)`.  After every
    `num_examples // batch_size` steps the training loss and elapsed time
    are printed.  The loss is then evaluated on one fixed, randomly chosen
    validation batch of `cfg.FLAGS.batch_size` examples, fed in place of
    the queue outputs.
    """
    with tf.Graph().as_default():
        # Queue-backed training batches and the model built on top of them.
        images, labels = inputs(train=True,
                                batch_size=cfg.FLAGS.batch_size,
                                nb_epochs=cfg.FLAGS.nb_epochs)
        logits = mnist.inference(images, cfg.FLAGS.hidden1, cfg.FLAGS.hidden2)
        loss = mnist.loss(logits, labels)
        train_op = mnist.training(loss, cfg.FLAGS.learning_rate)

        # A single op that runs both the global and the local initializers.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        sess = tf.Session()
        sess.run(init_op)

        # The coordinator owns the threads that feed the input queues.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        data_sets = mnist_datasets.read_data_sets(
            cfg.FLAGS.train_dir,
            dtype=tf.uint8,
            reshape=False,
            validation_size=cfg.FLAGS.validation_size)

        nb_train_samples = data_sets.train.num_examples

        # Validation batch: flatten the images, draw batch_size distinct
        # indices, then scale the picked images with x / 255 - 0.5.
        valid_images = data_sets.validation.images.reshape(
            (cfg.FLAGS.validation_size, mnist.IMAGE_PIXELS))
        valid_labels = data_sets.validation.labels
        picked = np.random.choice(cfg.FLAGS.validation_size,
                                  cfg.FLAGS.batch_size,
                                  replace=False)
        valid_images = valid_images[picked, :] * (1. / 255) - 0.5
        valid_labels = valid_labels[picked]

        step = 0
        epoch_idx = 0
        try:
            start_time = time.time()
            while not coord.should_stop():
                _, loss_value = sess.run([train_op, loss])
                step += 1

                # Nothing more to do until a full epoch's worth of steps ran.
                if step < nb_train_samples // cfg.FLAGS.batch_size:
                    continue

                epoch_idx += 1
                end_time = time.time()
                duration = end_time - start_time
                print('Training Epoch {}, Step {}: loss = {:.02f} ({:.03f} sec)'
                      .format(epoch_idx, step, loss_value, duration))

                # Restart the timer and the per-epoch step counter.
                start_time = end_time
                step = 0

                # Validation loss: feed the fixed validation batch in place
                # of the tensors normally produced by the input queue.
                valid_feed = {images: valid_images, labels: valid_labels}
                loss_valid_value = sess.run(loss, feed_dict=valid_feed)
                print('Validation Epoch {}: loss = {:.02f}'
                      .format(epoch_idx, loss_valid_value))
        except tf.errors.OutOfRangeError:
            print('Done training for epoch {}, {} steps'.format(epoch_idx, step))
        finally:
            # Whatever ended the loop, tell the input threads to stop.
            coord.request_stop()

        # Block until the input threads are done, then release the session.
        coord.join(threads)
        sess.close()
开发者ID:jamescfli,项目名称:PythonTest,代码行数:81,代码来源:read_tfrecords.py



注:本文中的tensorflow.examples.tutorials.mnist.mnist.training函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python interpreter.Interpreter类代码示例发布时间:2022-05-27
下一篇:
Python mnist.loss函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap