Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Load tensorflow images and create patches

I am using image_dataset_from_directory to load a very large RGB imagery dataset from disk into a Dataset. For example,

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,
    label_mode=None,
    seed=1,
    subset='training',
    validation_split=0.1)

The Dataset has, say, 100000 images grouped into batches of size 32, yielding a tf.data.Dataset with spec (batch=32, height=256, width=256, channels=3).

I would like to extract patches from the images to create a new tf.data.Dataset with image spatial dimensions of, say, 64x64.

Therefore, I would like to create a new Dataset with 400000 patches, still in batches of 32, i.e. a tf.data.Dataset with spec (batch=32, height=64, width=64, channels=3).

I've looked at the window method and the extract_patches function, but it's not clear from the documentation how to use them to create the new Dataset I need to start training on the patches. The window method seems to be geared toward 1D tensors, and extract_patches seems to work with arrays rather than with Datasets.

Any suggestions on how to accomplish this?

UPDATE:

Just to clarify my needs: I am trying to avoid manually creating the patches on disk. First, that would be untenable in terms of disk space. Second, the patch size is not fixed; the experiments will be conducted over several patch sizes. So I do not want to create the patches on disk, nor load the images into memory and patch them by hand. I would prefer to have TensorFlow handle the patch creation as part of the pipeline to minimize disk and memory usage.

1 Answer

What you're looking for is tf.image.extract_patches, applied to each element through Dataset.map. Here's an example:

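import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# MNIST test split: each element is a (28, 28, 1) image and its label.
data = tfds.load('mnist', split='test', as_supervised=True)

def get_patches(x, y):
    # extract_patches expects a 4-D batch, so add a leading batch axis.
    patches = tf.image.extract_patches(
        images=tf.expand_dims(x, 0),
        sizes=[1, 14, 14, 1],
        strides=[1, 14, 14, 1],  # stride == size gives non-overlapping tiles
        rates=[1, 1, 1, 1],
        padding='VALID')
    # Result is (1, 2, 2, 14*14*1); unpack it into 4 patches of 14x14x1.
    return tf.reshape(patches, (4, 14, 14, 1)), y

data = data.map(get_patches)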

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(data))
for index, image in enumerate(images):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()

[Image: 2x2 grid showing the four 14x14 patches of the first MNIST test image]
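For the setup in the question (unlabeled batches of 256x256 RGB images, a patch size that changes between experiments), the same idea might be wrapped up roughly as below. This is only a sketch under some assumptions: make_patch_dataset and to_patches are names made up for illustration, the input is assumed to yield tensors of shape (batch, height, width, channels) with a statically known channel count, the patches are non-overlapping, and tf.data.AUTOTUNE assumes a reasonably recent TensorFlow 2.x. A small random tensor stands in for the real directory.

```python
import tensorflow as tf


def make_patch_dataset(images, patch_size, batch_size=32):
    """Hypothetical helper: image batches in, batches of square patches out.

    Assumes `images` yields tensors shaped (batch, height, width, channels),
    e.g. from image_dataset_from_directory(..., label_mode=None).
    """
    def to_patches(batch):
        channels = batch.shape[-1]  # static channel count, e.g. 3 for RGB
        patches = tf.image.extract_patches(
            images=batch,
            sizes=[1, patch_size, patch_size, 1],
            strides=[1, patch_size, patch_size, 1],  # non-overlapping tiles
            rates=[1, 1, 1, 1],
            padding='VALID')
        # (batch, rows, cols, p*p*C) -> (batch*rows*cols, p, p, C)
        return tf.reshape(patches, [-1, patch_size, patch_size, channels])

    return (images
            .map(to_patches, num_parallel_calls=tf.data.AUTOTUNE)
            .unbatch()              # flatten to a stream of single patches
            .batch(batch_size)      # regroup the patches into batches
            .prefetch(tf.data.AUTOTUNE))


# Stand-in for the real data: 8 random 256x256 RGB images in batches of 4.
source = tf.random.uniform([8, 256, 256, 3])
fake_images = tf.data.Dataset.from_tensor_slices(source).batch(4)

patch_ds = make_patch_dataset(fake_images, patch_size=64, batch_size=32)
first = next(iter(patch_ds))
print(first.shape)  # (32, 64, 64, 3)
```

Nothing is written to disk and the patch size is just an argument, so each experiment can build its own pipeline from the same source dataset. Two things to keep in mind: with 'VALID' padding, any border that does not fit a whole patch is dropped, and patches from the same image come out next to each other, so you would probably want a .shuffle(...) between .unbatch() and .batch(...) for training. Also, with non-overlapping tiles a 256x256 image gives 16 patches of 64x64, so 100000 images would produce 1.6 million patches; 400000 would match 128x128 patches.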

