Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - How to shuffle batches with ImageDataGenerator?

I'm using ImageDataGenerator with flow_from_dataframe to load a dataset.

Using flow_from_dataframe with shuffle=True shuffles the images in the dataset.

I want to shuffle the batches. If I have 12 images and batch_size=3, then I have 4 batches:

batch1 = [image1, image2, image3]
batch2 = [image4, image5, image6]
batch3 = [image7, image8, image9]
batch4 = [image10, image11, image12]

I want to shuffle the batches without shuffling the images in each batch, so that I get for example:

batch2 = [image4, image5, image6]
batch1 = [image1, image2, image3]
batch4 = [image10, image11, image12]
batch3 = [image7, image8, image9]

Is that possible with ImageDataGenerator and flow_from_dataframe? Is there a preprocessing function I can use?

question from:https://stackoverflow.com/questions/65942274/how-to-shuffle-batches-with-imagedatagenerator


1 Answer

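Consider using the tf.data.Dataset API. If you batch first and shuffle afterwards, shuffle operates on whole batches, so the order of the batches changes while the order of the images inside each batch stays the same.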

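import tensorflow as tf

# Placeholder file names standing in for real image paths.
file_names = [f'image_{i}' for i in range(1, 10)]

# batch(3) groups consecutive elements; shuffle(3) then shuffles the 3 batches.
ds = tf.data.Dataset.from_tensor_slices(file_names).batch(3).shuffle(3)

# Iterate three times; the batch order is reshuffled on each pass.
for _ in range(3):
    for batch in ds:
        print(batch.numpy())
    print()

Example output (the order is random and changes from run to run):

[b'image_4' b'image_5' b'image_6']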
[b'image_7' b'image_8' b'image_9']
[b'image_1' b'image_2' b'image_3']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']

[b'image_1' b'image_2' b'image_3']
[b'image_4' b'image_5' b'image_6']
[b'image_7' b'image_8' b'image_9']
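
To connect this back to the question's numbers, here is a minimal sketch with 12 file names and batch_size=3. It is only an illustration: the file names are placeholders, collect_epoch is a hypothetical helper, and the seed is there just to make the run repeatable. The buffer size is set to the number of batches because shuffle only mixes the elements currently in its buffer, so a smaller buffer would give a partial shuffle.

```python
import tensorflow as tf

batch_size = 3
file_names = [f'image_{i}' for i in range(1, 13)]  # 12 placeholder names
num_batches = len(file_names) // batch_size        # 4 batches

ds = (
    tf.data.Dataset.from_tensor_slices(file_names)
    .batch(batch_size)                         # group consecutive names first
    .shuffle(buffer_size=num_batches, seed=0)  # then shuffle whole batches
)

def collect_epoch(dataset):
    """Hypothetical helper: one pass over the dataset as lists of str."""
    return [[name.decode() for name in batch.numpy()] for batch in dataset]

epoch = collect_epoch(ds)
for batch in epoch:
    print(batch)
```

Each printed row should be one of the original consecutive triples, with only the order of the rows changing between passes.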

Then, you can use a mapping operation to load the images from the file names:

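import os

# Hypothetical placeholder: replace with your own class names.
class_categories = tf.constant(['cat', 'dog'])

def read_image(file_name):
  # file_name is a scalar string tensor holding a single image path.
  image = tf.io.read_file(file_name)
  # expand_animations=False guarantees a 3-D tensor, which the resize op needs.
  image = tf.image.decode_image(image, channels=3, expand_animations=False)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
  # Assumes the first path component is the class name, e.g. 'cat/001.jpg'.
  label = tf.strings.split(file_name, os.sep)[0]
  label = tf.cast(tf.equal(label, class_categories), tf.int32)  # one-hot vector
  return image, label

# read_file takes one path at a time, so unbatch, map, and re-batch with the
# same size; the shuffled batch order and the order inside each batch are kept.
ds = ds.unbatch().map(read_image).batch(3)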
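
If you would rather stay with ImageDataGenerator, one possible workaround (not part of the answer above, just a sketch) is to reorder the DataFrame rows in batch-sized blocks yourself and then call flow_from_dataframe with shuffle=False and the same batch_size. The helper name batch_shuffled_order is hypothetical, and a trailing partial batch is kept at the end so the generator's batch boundaries still line up.

```python
import numpy as np

def batch_shuffled_order(num_rows, batch_size, seed=None):
    """Return row indices with full batches permuted and each batch intact."""
    rng = np.random.default_rng(seed)
    num_full = num_rows // batch_size
    # One row per full batch: [[0, 1, 2], [3, 4, 5], ...] for batch_size=3.
    full = np.arange(num_full * batch_size).reshape(num_full, batch_size)
    # Leftover rows that do not fill a batch stay at the end.
    remainder = np.arange(num_full * batch_size, num_rows)
    shuffled = full[rng.permutation(num_full)]
    return np.concatenate([shuffled.ravel(), remainder])

order = batch_shuffled_order(num_rows=12, batch_size=3, seed=0)
print(order.reshape(-1, 3))

# With a pandas DataFrame `df` (assumed here), the reordered frame would be
# df.iloc[order].reset_index(drop=True).
```

Note that this fixes a single batch order for the lifetime of the generator; to get a different order in each epoch you would have to rebuild the reordered DataFrame and the generator between epochs.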
