Turns out I can use `Dataset.map` if I make the generator super lightweight (only generating metadata) and then move the actual heavy lifting into a stateless function. This way I can parallelise just the heavy-lifting part with `.map` using a `py_func`.

Works, but feels a tad clumsy... Would be great to be able to just add `num_parallel_calls` to `from_generator` :)
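```python
import tensorflow as tf  # TF 1.x API (tf.py_func, make_one_shot_iterator)

def pure_numpy_and_pil_complex_calculation(metadata, label):
    # some complex PIL and NumPy work, nothing to do with TF
    # must return (uint8 array of shape (H, W, 3), label) to match Tout below
    ...

# lightweight_generator yields only (metadata, label) string pairs
dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                         output_types=(tf.string,   # metadata
                                                       tf.string))  # label

def wrapped_complex_calculation(metadata, label):
    # py_func wraps the Python function as a TF op, so map() can call it in parallel
    return tf.py_func(func=pure_numpy_and_pil_complex_calculation,
                      inp=(metadata, label),
                      Tout=(tf.uint8,    # (H, W, 3) img
                            tf.string))  # label

dataset = dataset.map(wrapped_complex_calculation,
                      num_parallel_calls=8)
dataset = dataset.batch(64)
iterator = dataset.make_one_shot_iterator()
imgs, labels = iterator.get_next()
```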
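For comparison, here is the same split sketched outside TensorFlow with only the standard library: a cheap generator yields metadata, and a stateless function does the heavy work on a thread pool via `concurrent.futures.ThreadPoolExecutor.map`. This is an analogue, not the tf.data implementation. `lightweight_generator`'s body, `heavy_calculation` (a SHA-256 hash standing in for the PIL/NumPy work) and `batched` are hypothetical placeholders. As with tf.data's parallel map over a `py_func`, threads only help where the heavy work releases the GIL, as many NumPy and PIL operations do.

```python
from concurrent.futures import ThreadPoolExecutor
import hashlib


def lightweight_generator(n):
    """Hypothetical cheap generator: yields only (metadata, label) pairs, no pixel data."""
    for i in range(n):
        yield f"image_{i:04d}.jpg", "cat" if i % 2 == 0 else "dog"


def heavy_calculation(item):
    """Hypothetical stand-in for the PIL/NumPy work: stateless, depends only on its input."""
    metadata, label = item
    digest = hashlib.sha256(metadata.encode()).digest()
    # Pretend the first 12 digest bytes are a tiny "image".
    return list(digest[:12]), label


def batched(iterable, size):
    """Group results into lists of `size` (last batch may be smaller), like Dataset.batch."""
    batch = []
    for x in iterable:
        batch.append(x)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        yield batch


def pipeline(n, num_workers=8, batch_size=4):
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # The generator is read sequentially, heavy_calculation runs concurrently,
        # and results come back in input order.
        results = pool.map(heavy_calculation, lightweight_generator(n))
        yield from batched(results, batch_size)


for batch in pipeline(10, num_workers=3, batch_size=4):
    print(len(batch), [label for _, label in batch])
```

Because `heavy_calculation` holds no state, the pool can run many calls at once while the generator itself stays sequential, which is why the split works. One difference: `Executor.map` collects the whole input iterable up front, whereas tf.data's parallel map only keeps a bounded number of elements in flight. In TF 2.x the equivalent pipeline would use `output_signature` in `from_generator`, `tf.numpy_function` instead of the deprecated `tf.py_func`, and `num_parallel_calls=tf.data.AUTOTUNE`.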