Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

tensorflow - parallelising tf.data.Dataset.from_generator

I have a non-trivial input pipeline for which from_generator is perfect:

import tensorflow as tf  # TF 1.x API (make_one_shot_iterator)

dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
                                        (tf.int32, tf.string))
dataset = dataset.batch(64)
iterator = dataset.make_one_shot_iterator()
imgs, labels = iterator.get_next()

Here complex_img_label_generator dynamically generates images and yields a numpy array representing an (H, W, 3) image together with a simple string label. The processing is not something I can express as reading from files plus tf.image operations.

My question is how to parallelise the generator: how can I have N of these generators, each running in its own thread?

One thought was to use dataset.map with num_parallel_calls to handle the threading, but map operates on tensors... Another thought was to create multiple generators, each with its own prefetch, and somehow join them, but I can't see how I'd join N generator streams.

Any canonical examples I could follow?
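The second idea (N generators, each in its own thread, joined into one stream) can be sketched in plain Python with `threading` and a bounded `queue.Queue`. This is a hypothetical illustration, not a tf.data API: `merge_threaded` and the toy `make_stream` factory are assumptions. Items come out in whatever order the threads produce them (each generator's own order is kept), and Python threads only run truly in parallel while the work releases the GIL, as I/O and many NumPy/PIL operations do.

```python
import queue
import threading
from typing import Callable, Iterator, List, TypeVar

T = TypeVar("T")

_DONE = object()  # sentinel a worker enqueues when its generator is exhausted


def merge_threaded(factories: List[Callable[[], Iterator[T]]],
                   maxsize: int = 64) -> Iterator[T]:
    """Drain each generator in its own thread; yield items as they arrive."""
    # Bounded queue: workers block once `maxsize` items are waiting,
    # so it acts like a shared prefetch buffer.
    buf = queue.Queue(maxsize=maxsize)

    def worker(factory: Callable[[], Iterator[T]]) -> None:
        try:
            for item in factory():
                buf.put(item)
        finally:
            # Always signal completion, even if the generator raised
            # (the exception itself is only printed to stderr by threading).
            buf.put(_DONE)

    # Daemon threads: if the consumer stops early, blocked workers
    # will not keep the process alive.
    for factory in factories:
        threading.Thread(target=worker, args=(factory,), daemon=True).start()

    remaining = len(factories)
    while remaining:
        item = buf.get()
        if item is _DONE:
            remaining -= 1
        else:
            yield item


def make_stream(worker_id: int) -> Iterator[tuple]:
    # Hypothetical stand-in for complex_img_label_generator.
    for k in range(3):
        yield worker_id, k


factories = [lambda i=i: make_stream(i) for i in range(4)]
items = list(merge_threaded(factories))
print(len(items))  # 12
```

Since `from_generator` takes a callable that returns an iterator, the merged stream could then be passed as `lambda: merge_threaded(factories)`. Within tf.data itself, `Dataset.interleave` over N `from_generator` datasets is the built-in way to join several datasets into one stream.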

He who fights with dragons too long becomes a dragon himself; gaze into the abyss too long, and the abyss gazes back into you…

1 Answer

Turns out I can use Dataset.map if I make the generator super lightweight (only generating metadata) and move the actual heavy lifting into a stateless function. That way I can parallelise just the heavy-lifting part with .map, using a py_func.

Works, but feels a tad clumsy... It would be great to be able to just pass num_parallel_calls to from_generator :)

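import tensorflow as tf  # TF 1.x API (tf.py_func, make_one_shot_iterator)

def pure_numpy_and_pil_complex_calculation(metadata, label):
  # Some complex PIL and numpy work; nothing to do with tf.
  # metadata and label arrive as bytes; must return
  # (uint8 image of shape (H, W, 3), label) to match Tout below.
  ...

# The generator only yields cheap metadata, so running it serially is not a bottleneck.
dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                         output_types=(tf.string,   # metadata
                                                       tf.string))  # label

def wrapped_complex_calculation(metadata, label):
  # py_func runs the Python function as a graph op; map can run several of these
  # calls concurrently (the real speed-up depends on how much work releases the GIL).
  return tf.py_func(func=pure_numpy_and_pil_complex_calculation,
                    inp=[metadata, label],
                    Tout=(tf.uint8,    # (H,W,3) img
                          tf.string))  # label

dataset = dataset.map(wrapped_complex_calculation,
                      num_parallel_calls=8)

dataset = dataset.batch(64)
iterator = dataset.make_one_shot_iterator()
imgs, labels = iterator.get_next()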
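The same split (a cheap generator consumed serially, feeding a bounded, order-preserving parallel map over a stateless function) can be modelled without TensorFlow to see why it is safe to parallelise: the heavy function depends only on its arguments. In this plain-Python sketch, `metadata_generator`, `heavy_stateless_work` and `parallel_map` are hypothetical stand-ins, and a `ThreadPoolExecutor` plays the role of `num_parallel_calls`.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, Iterator, List, Tuple


def metadata_generator(n: int) -> Iterator[Tuple[int, str]]:
    # Hypothetical stand-in for lightweight_generator: yields only cheap
    # metadata (here an integer seed) and a label.
    for i in range(n):
        yield i, f"label-{i}"


def heavy_stateless_work(seed: int, label: str) -> Tuple[List[int], str]:
    # Hypothetical stand-in for pure_numpy_and_pil_complex_calculation.
    # It depends only on its arguments, so calls can run concurrently.
    pixels = [(seed * k) % 256 for k in range(2 * 2 * 3)]  # fake flattened 2x2x3 image
    return pixels, label


def parallel_map(items: Iterable[tuple], fn: Callable,
                 num_parallel_calls: int = 8) -> Iterator:
    """Apply fn(*item) with up to num_parallel_calls calls in flight, in input order."""
    with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
        pending = deque()
        for args in items:  # the generator itself is still consumed serially
            pending.append(pool.submit(fn, *args))
            if len(pending) >= num_parallel_calls:
                # Oldest future first, so output order matches input order.
                yield pending.popleft().result()
        while pending:
            yield pending.popleft().result()


results = list(parallel_map(metadata_generator(20), heavy_stateless_work,
                            num_parallel_calls=4))
print(results[1][1])  # label-1
```

Returning the oldest future first mirrors tf.data's `map`, which by default yields results in input order even when `num_parallel_calls` is set; the window of pending futures keeps memory bounded for long or infinite generators.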
