I'm confused about the distributed training process in TensorFlow.
My understanding is that TensorFlow feeds a batch_size worth of data to each worker, and the worker then pushes its update to the ps server. Is this right?
But while training, I noticed that the step numbers in the log look strange.
If I have only 2 workers, I think the process should look something like this:
```
[worker1] step 0 xxxxxxx
[worker2] step 100 xxxxxxx
[worker1] step 200 xxxxxxx
[worker2] step 300 xxxxxxx
...
```
That is, each worker should print a distinct set of step numbers to the log.
But the actual log looks like this:
```
[worker1] step 0 xxxxxxx
[worker2] step 100 xxxxxxx
[worker1] step 100 xxxxxxx
[worker2] step 200 xxxxxxx
[worker1] step 300 xxxxxxx
...
```
Why doesn't worker1 ever print step 200?
I'm confused about how the jobs are assigned.
How does TensorFlow do the distributed training?
Does the chief worker split the data into batches, hand each batch to a worker, and then update the ps server?
Or does every worker run over the whole dataset and update the ps server itself?
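For context, my script sets up the cluster and server before the snippet below; it looks roughly like this (hostnames, ports, and flag values here are placeholders, not my exact setup):

```
import tensorflow as tf

# Placeholder cluster: one ps task and two worker tasks.
cluster = tf.train.ClusterSpec({
    "ps": ["ps0.example.com:2222"],
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
})
# FLAGS.job_name and FLAGS.task_index come from tf.app.flags.
server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)

if FLAGS.job_name == "ps":
    server.join()  # the ps task only hosts the shared variables
```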
Here is a minimal version of my training code:
```
import datetime
import json

import tensorflow as tf

with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):
    # Read TFRecords files for training
    filename_queue = tf.train.string_input_producer(
        tf.train.match_filenames_once(FLAGS.train),
        num_epochs=epoch_number)
    serialized_example = read_and_decode(filename_queue)
    batch_serialized_example = tf.train.shuffle_batch(
        [serialized_example],
        batch_size=batch_size,
        num_threads=thread_number,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    features = tf.parse_example(
        batch_serialized_example,
        features={
            "label": tf.FixedLenFeature([], tf.float32),
            "ids": tf.VarLenFeature(tf.int64),
            "values": tf.VarLenFeature(tf.float32),
        })
    batch_labels = features["label"]
    batch_ids = features["ids"]
    batch_values = features["values"]

    # Read TFRecords file for validation
    validate_filename_queue = tf.train.string_input_producer(
        tf.train.match_filenames_once(FLAGS.eval),
        num_epochs=epoch_number)
    validate_serialized_example = read_and_decode(validate_filename_queue)
    validate_batch_serialized_example = tf.train.shuffle_batch(
        [validate_serialized_example],
        batch_size=validate_batch_size,
        num_threads=thread_number,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    validate_features = tf.parse_example(
        validate_batch_serialized_example,
        features={
            "label": tf.FixedLenFeature([], tf.float32),
            "ids": tf.VarLenFeature(tf.int64),
            "values": tf.VarLenFeature(tf.float32),
        })
    validate_batch_labels = validate_features["label"]
    validate_batch_ids = validate_features["ids"]
    validate_batch_values = validate_features["values"]

    # inference(), accuracy, auc_op, inference_softmax and inference_op
    # are defined elsewhere in the full script.
    logits = inference(batch_ids, batch_values)
    batch_labels = tf.to_int64(batch_labels)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits, batch_labels)
    loss = tf.reduce_mean(cross_entropy, name='loss')

    print("Use the optimizer: {}".format(FLAGS.optimizer))
    optimizer = tf.train.FtrlOptimizer(learning_rate)
    # global_step is incremented by every minimize() call.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = optimizer.minimize(loss, global_step=global_step)

    # Initialize saver and summary
    steps_to_validate = FLAGS.steps_to_validate
    init_op = tf.initialize_all_variables()
    saver = tf.train.Saver(max_to_keep=2)
    keys_placeholder = tf.placeholder("float")
    keys = tf.identity(keys_placeholder)
    tf.add_to_collection("inputs",
                         json.dumps({'key': keys_placeholder.name}))
    tf.add_to_collection("outputs",
                         json.dumps({'key': keys.name,
                                     'softmax': inference_softmax.name,
                                     'prediction': inference_op.name}))
    summary_op = tf.merge_all_summaries()

sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                         logdir="./train_process/",
                         init_op=init_op,
                         summary_op=summary_op,
                         saver=saver,
                         global_step=global_step,
                         save_model_secs=60)

# Create session to run graph
with sv.managed_session(server.target) as sess:
    while not sv.should_stop():
        # Get coordinator and run queues to read data
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        start_time = datetime.datetime.now()
        try:
            while not coord.should_stop():
                _, loss_value, step = sess.run([train_op, loss, global_step])
                if step % steps_to_validate == 0:
                    accuracy_value, auc_value, summary_value = sess.run(
                        [accuracy, auc_op, summary_op])
                    end_time = datetime.datetime.now()
                    print("[{}] Task: {}, Step: {}, loss: {}, accuracy: {}, auc: {}".format(
                        end_time - start_time,
                        FLAGS.task_index,
                        step, loss_value, accuracy_value,
                        auc_value))
                    start_time = end_time
        except tf.errors.OutOfRangeError:
            print("Done training after reading all data")
        finally:
            coord.request_stop()
            print("coord stopped")

        # Wait for threads to exit
        coord.join(threads)
```
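One quick sanity check (just a sketch, assuming the graph above has been built): printing the device that `global_step` was placed on. With `replica_device_setter`, variables should be pinned to the ps job, so I believe both workers share the same counter:

```
# Expected to print something like "/job:ps/task:0", i.e. the single
# global_step variable lives on the ps and is shared by all workers.
print(global_step.device)
```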
Here is an excerpt from the training log:
```
[0:00:17.115814] Task: 0, Step: 74600, loss: 0.303285002708, accuracy: 0.910000026226, auc: 0.946377456188
[0:00:03.804889] Task: 1, Step: 74700, loss: 0.287385582924, accuracy: 0.879999995232, auc: 0.946395516396
[0:00:03.778589] Task: 0, Step: 74800, loss: 0.247096762061, accuracy: 0.860000014305, auc: 0.946370542049
[0:00:03.772320] Task: 1, Step: 74900, loss: 0.264987647533, accuracy: 0.899999976158, auc: 0.946406364441
[0:00:03.795459] Task: 0, Step: 75000, loss: 0.228719010949, accuracy: 0.899999976158, auc: 0.946437120438
[0:00:01.902293] Task: 1, Step: 75000, loss: 0.217391207814, accuracy: 0.910000026226, auc: 0.946473121643
[0:00:01.942055] Task: 1, Step: 75100, loss: 0.284583866596, accuracy: 0.889999985695, auc: 0.946496844292
[0:00:03.860608] Task: 0, Step: 75200, loss: 0.273199081421, accuracy: 0.850000023842, auc: 0.946503221989
[0:00:03.800881] Task: 1, Step: 75300, loss: 0.189931258559, accuracy: 0.930000007153, auc: 0.946559965611
```
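My current guess: both workers increment the one shared `global_step` on the ps, and each worker only prints the counter value it observes after its own update, so any given multiple of 100 is printed by whichever worker happens to land on it. Here is a toy simulation in plain Python (no TensorFlow; the random worker choice is just for illustration) that reproduces the "skipped step" pattern:

```
import random

# Two workers race to increment one shared counter (standing in for
# the ps-hosted global_step). Each worker prints only the values it
# observes itself, so a given worker can miss any particular multiple
# of 100, exactly like worker1 never printing step 200 above.
global_step = 0
for _ in range(1000):
    worker = random.choice([1, 2])  # whichever worker finishes a batch first
    global_step += 1                # incremented by that worker's minimize()
    if global_step % 100 == 0:
        print("[worker%d] step %d" % (worker, global_step))
```

Is this interleaved, shared-counter behavior the real explanation, or does the chief actually split the data among the workers?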