TL;DR
- The default `DataLoader` only uses a sampler, not a batch sampler.
- You can define a sampler, plus a batch sampler; the batch sampler will override the sampler.
- The sampler only yields the sequence of dataset elements, not the actual batches (batching is handled by the data loader, depending on `batch_size`).
- To answer your initial question: working with a sampler on an iterable dataset doesn't seem to be possible, cf. this GitHub issue (still open). Also, read the following note in pytorch/dataloader.py.
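For reference, the data loader enforces this itself: passing a `sampler` together with an `IterableDataset` is rejected at construction time. A minimal sketch (the exact error message may vary between PyTorch versions):

```python
from torch.utils.data import DataLoader, IterableDataset, SequentialSampler

class ItDS(IterableDataset):
    def __iter__(self):
        return iter(range(10))

# raises something like:
# ValueError: DataLoader with IterableDataset: expected unspecified sampler option
dl = DataLoader(ItDS(), sampler=SequentialSampler(range(10)))
```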
Samplers (for map-style datasets):
That aside, if you are switching to a map-style dataset, here are some details on how samplers and batch samplers work. You can access a dataset's underlying data using indices, just like you would with a list (since `torch.utils.data.Dataset` implements `__getitem__`). In other words, your dataset elements are all `dataset[i]`, for `i` in `[0, len(dataset) - 1]`.
Here is a toy dataset:
```python
from torch.utils.data import Dataset

class DS(Dataset):
    def __getitem__(self, index):
        # index i maps to element i*10
        return index*10

    def __len__(self):
        return 10
```
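To make the indexing behaviour concrete (the values assume the `index*10` implementation above):

```python
>>> ds = DS()
>>> len(ds)
10
>>> ds[3]
30
```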
In a general use case you would just give `torch.utils.data.DataLoader` the arguments `batch_size` and `shuffle`. By default, `shuffle` is set to `False`, which means it will use `torch.utils.data.SequentialSampler`. Otherwise (if `shuffle` is `True`) `torch.utils.data.RandomSampler` will be used. The sampler defines how the data loader accesses the dataset (in which order it accesses it).
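You can also iterate over these samplers directly: they yield one dataset index at a time, never the data values or the batches (the `RandomSampler` order below is just one possible permutation):

```python
>>> from torch.utils.data import SequentialSampler, RandomSampler
>>> list(SequentialSampler(ds))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> list(RandomSampler(ds))
[7, 1, 4, 0, 9, 2, 6, 3, 8, 5]
```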
The above dataset (`DS`) has 10 elements. The indices are 0, 1, 2, 3, 4, 5, 6, 7, 8, and 9. They map to elements 0, 10, 20, 30, 40, 50, 60, 70, 80, and 90. So with a batch size of 2:
- `SequentialSampler`: `DataLoader(ds, batch_size=2)` (implicitly `shuffle=False`), identical to `DataLoader(ds, batch_size=2, sampler=SequentialSampler(ds))`. The dataloader will deliver `tensor([0, 10])`, `tensor([20, 30])`, `tensor([40, 50])`, `tensor([60, 70])`, and `tensor([80, 90])`.
- `RandomSampler`: `DataLoader(ds, batch_size=2, shuffle=True)`, identical to `DataLoader(ds, batch_size=2, sampler=RandomSampler(ds))`. The dataloader will sample randomly each time you iterate through it. For instance: `tensor([50, 40])`, `tensor([90, 80])`, `tensor([0, 60])`, `tensor([10, 20])`, and `tensor([30, 70])`. But the sequence will be different if you iterate through the dataloader a second time!
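If you need the shuffled order to be reproducible across runs, one option (a sketch, assuming a PyTorch version where `DataLoader` accepts a `generator` argument) is to seed the generator used by the random sampler:

```python
import torch
from torch.utils.data import DataLoader

g = torch.Generator()
g.manual_seed(0)  # fixed seed -> the same permutation on every run
dl = DataLoader(ds, batch_size=2, shuffle=True, generator=g)
```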
Batch sampler
Providing `batch_sampler` will override `batch_size`, `shuffle`, `sampler`, and `drop_last` altogether. It is meant to define exactly the batch elements and their content. For instance:
```python
>>> DataLoader(ds, batch_sampler=[[1,2,3], [6,5,4], [7,8], [0,9]])
```
This will yield `tensor([10, 20, 30])`, `tensor([60, 50, 40])`, `tensor([70, 80])`, and `tensor([ 0, 90])`.
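This is also essentially how the default behaviour is built: when you only pass `batch_size`, the data loader wraps its sampler in a `torch.utils.data.BatchSampler`. As a sketch, the following should behave the same as `DataLoader(ds, batch_size=2)`:

```python
from torch.utils.data import DataLoader, BatchSampler, SequentialSampler

dl = DataLoader(ds, batch_sampler=BatchSampler(SequentialSampler(ds), batch_size=2, drop_last=False))
```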
Batch sampling by class
Let's say you just want to have at most 2 elements (identical or not) of each class in a batch, and exclude any further examples of that class, i.e. ensure that no 3 examples of the same class end up inside the same batch.
Let's say you have a dataset with four classes. Here is how I would do it. First, keep track of dataset indices for each class.
```python
class DS(Dataset):
    def __init__(self, data):
        super(DS, self).__init__()
        self.data = data
        # class 0: positive odd, class 1: positive even,
        # class 2: negative odd, class 3: negative even
        self.indices = [[] for _ in range(4)]
        for i, x in enumerate(data):
            if x > 0 and x % 2: self.indices[0].append(i)
            if x > 0 and not x % 2: self.indices[1].append(i)
            if x < 0 and x % 2: self.indices[2].append(i)
            if x < 0 and not x % 2: self.indices[3].append(i)

    def classes(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]
```
For example:

```python
>>> ds = DS([1, 6, 7, -5, 10, -6, 8, 6, 1, -3, 9, -21, -13, 11, -2, -4, -21, 4])
```

will give:

```python
>>> ds.classes()
[[0, 2, 8, 10, 13], [1, 4, 6, 7, 17], [3, 9, 11, 12, 16], [5, 14, 15]]
```
Then, for the batch sampler, the easiest way is to create a list of available class indices, with each class index repeated as many times as there are dataset elements in that class.
In the dataset defined above, we have 5 items from class 0, 5 from class 1, 5 from class 2, and 3 from class 3. Therefore we want to construct `[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3]`. We will shuffle it. Then, from this list and the dataset classes content (`ds.classes()`), we will be able to construct the batches.
```python
import copy
import random

class Sampler():
    def __init__(self, classes):
        self.classes = classes

    def __iter__(self):
        # deep copy since we pop elements from the per-class index lists
        classes = copy.deepcopy(self.classes)
        # one entry per dataset element, holding that element's class index
        indices = [i for i, klass in enumerate(classes) for _ in range(len(klass))]
        random.shuffle(indices)
        # group the shuffled class indices two by two
        grouped = zip(*[iter(indices)]*2)
        # each batch takes one element from each of the two drawn classes
        res = []
        for a, b in grouped:
            res.append((classes[a].pop(), classes[b].pop()))
        return iter(res)
```
Note - deep copying the list is required since we're popping elements from it.
A possible output of this sampler would be:
```python
[(15, 14), (16, 17), (7, 12), (11, 6), (13, 10), (5, 4), (9, 8), (2, 0), (3, 1)]
```
At this point we can simply use `torch.utils.data.DataLoader`:

```python
>>> dl = DataLoader(ds, batch_sampler=Sampler(ds.classes()))
```
Which could yield something like:
```python
[tensor([ 4, -4]), tensor([-21, 11]), tensor([-13, 6]), tensor([9, 1]), tensor([ 8, -21]), tensor([-3, 10]), tensor([ 6, -2]), tensor([-5, 7]), tensor([-6, 1])]
```
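One practical note, not part of the snippet above: `len(dl)` on a data loader built from a `batch_sampler` delegates to `len()` of that batch sampler, so if you need it you could add a `__len__` to the `Sampler` class, along these lines:

```python
    def __len__(self):
        # one batch per pair of dataset elements
        return sum(len(klass) for klass in self.classes) // 2
```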
An easier approach
Here is another, easier, approach. It does not guarantee to return all elements from the dataset, though on average it will. For each pass through the sampler, first sample `class_per_batch` classes; then, for each batch, sample `batch_size` elements from that class subset (by first picking a class from the subset, then picking a data point from that class).
```python
import random

class Sampler():
    def __init__(self, classes, class_per_batch, batch_size):
        self.classes = classes
        self.n_batches = sum([len(x) for x in classes]) // batch_size
        self.class_per_batch = class_per_batch
        self.batch_size = batch_size

    def __iter__(self):
        # pick the subset of classes used for this pass through the sampler
        classes = random.sample(range(len(self.classes)), self.class_per_batch)
        batches = []
        for _ in range(self.n_batches):
            batch = []
            for _ in range(self.batch_size):
                # pick a class from the subset, then a random element of that class
                klass = random.choice(classes)
                batch.append(random.choice(self.classes[klass]))
            batches.append(batch)
        return iter(batches)
```
You can try it this way:
```python
>>> s = Sampler(ds.classes(), class_per_batch=2, batch_size=4)
>>> list(s)
[[16, 0, 0, 9], [10, 8, 11, 2], [16, 9, 16, 8], [2, 9, 2, 3]]

>>> dl = DataLoader(ds, batch_sampler=s)
>>> list(iter(dl))
[tensor([ -5, -6, -21, -13]), tensor([ -4, -4, -13, -13]), tensor([ -3, -21, -2, -5]), tensor([-3, -5, -4, -6])]
```