Using Weighted Random Sampler in PyTorch

Sometimes some classes in a dataset have far fewer samples than others. In such a scenario, you don’t want training batches to be dominated by samples from the few classes that have lots of samples. Ideally, a training batch should represent a good spread of the dataset. In PyTorch, this can be achieved using a weighted random sampler.

In this short post, I will walk you through the process of creating a weighted random sampler in PyTorch.

To start off, let’s assume you have a dataset with images grouped in folders based on their class. We can use an ImageFolder to create a dataset from it.

from torchvision import datasets

dataset = datasets.ImageFolder(root=data_dir, transform=image_transforms)

Also, let’s split this dataset into training and validation sets:

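from torch.utils.data import random_split

# Use 70% of the samples for training and the rest for validation.
dataset_size = len(dataset)
train_count = int(dataset_size * 0.7)
val_count = dataset_size - train_count
train_dataset, valid_dataset = random_split(dataset, [train_count, val_count])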
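
The split above can be tried on a small stand-in dataset. The following is a minimal sketch, not part of the original walkthrough: the names toy_dataset, toy_train, and toy_valid are hypothetical, and the seeded generator is an optional addition that makes the split reproducible. It also shows that each split is a Subset whose .indices attribute points back into the original dataset, which the later steps rely on.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy dataset: 10 samples with one feature each and alternating labels.
toy_features = torch.arange(10).float().unsqueeze(1)
toy_targets = torch.arange(10) % 2
toy_dataset = TensorDataset(toy_features, toy_targets)

# A seeded generator (an assumption for this sketch) makes the split repeatable.
generator = torch.Generator().manual_seed(42)
toy_train, toy_valid = random_split(toy_dataset, [7, 3], generator=generator)

# Each split is a Subset; .indices lists positions in the original dataset.
train_indices = list(toy_train.indices)
valid_indices = list(toy_valid.indices)
print(len(toy_train), len(toy_valid))
```

The two index lists are disjoint and together cover the whole dataset, so the labels of the training split can be looked up in the original dataset through these indices.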

Note: I am using an ImageFolder and training/validation splits just to emulate a real-world example. You can work with any PyTorch dataset.

We will use a weighted random sampler only for the training set. For the validation set, we don’t care about balancing a batch. Now that we have train_dataset, we need to define a weight for each class that is inversely proportional to the number of samples in that class.

First, let’s find the number of samples for each class.

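import numpy as np

# random_split returns Subset objects; .indices holds positions in the original dataset.
y_train_indices = train_dataset.indices

# Class label of every training sample, stored as an array so the comparison below is element-wise.
y_train = np.array([dataset.targets[i] for i in y_train_indices])

# Number of training samples per class, ordered by class label.
class_sample_count = np.array(
    [np.sum(y_train == t) for t in np.unique(y_train)])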

Next, we need to find the weight for each class and assign it to every sample of that class.

import torch

# Indexing weight by label assumes labels 0..K-1 all occur in the training split.
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in y_train])
samples_weight = torch.from_numpy(samples_weight)
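
To see why inverse-frequency weights balance the classes, here is a small worked sketch with NumPy only. The labels in toy_labels are made up for illustration, and np.bincount is used as one convenient way to count integer labels; it is not necessarily how you would count classes in your own pipeline.

```python
import numpy as np

# Hypothetical toy labels: class 0 has 6 samples, class 1 has 3, class 2 has 1.
toy_labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])

# Count how often each integer label 0..K-1 occurs.
toy_class_count = np.bincount(toy_labels)

# Per-class weight is the inverse of the class count.
toy_class_weight = 1.0 / toy_class_count

# Every sample receives the weight of its own class.
toy_sample_weight = toy_class_weight[toy_labels]

# Summing the sample weights per class gives the same total for every class,
# so each class is equally likely to be picked when sampling by these weights.
per_class_total = np.bincount(toy_labels, weights=toy_sample_weight)
print(per_class_total)
```

Note that np.bincount reports a count of zero for any label that never occurs, which would lead to a division by zero here, so this sketch assumes every class appears at least once.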

Now that we have a weight for each training sample, we can define a sampler.

from torch.utils.data import WeightedRandomSampler
sampler = WeightedRandomSampler(samples_weight.double(), len(samples_weight))
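
As a rough check of what the sampler does, the sketch below draws indices from a WeightedRandomSampler built on a made-up, heavily imbalanced label tensor. The names toy_labels and minority_fraction, the 90/10 split, and the number of draws are all assumptions for illustration. The sampler draws with replacement by default, so minority samples are repeated rather than majority samples being dropped.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical imbalanced labels: 90 samples of class 0 and 10 of class 1.
toy_labels = torch.tensor([0] * 90 + [1] * 10)

# Inverse-frequency weight per class, then one weight per sample.
class_count = torch.bincount(toy_labels)
sample_weight = (1.0 / class_count.double())[toy_labels]

# Draw many indices so the empirical class ratio is easy to read off.
sampler = WeightedRandomSampler(sample_weight, num_samples=10000, replacement=True)
drawn_indices = torch.tensor(list(sampler))
drawn_labels = toy_labels[drawn_indices]

# Fraction of draws that belong to the minority class.
minority_fraction = (drawn_labels == 1).double().mean().item()
print(minority_fraction)
```

Although only 10% of the samples belong to class 1, roughly half of the draws should come from it, because both classes carry the same total weight.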

Finally, we can use the sampler while defining the DataLoader.

from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=4, sampler=sampler)
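
The sketch below wires a sampler into a DataLoader on a small stand-in TensorDataset and counts the labels seen over one pass. The dataset, its sizes, and names such as toy_dataset and seen_counts are hypothetical. It also illustrates one practical detail: sampler and shuffle=True cannot be combined, since the sampler already decides the order in which samples are drawn.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical imbalanced dataset: 90 samples of class 0 and 10 of class 1.
features = torch.randn(100, 3)
labels = torch.tensor([0] * 90 + [1] * 10)
toy_dataset = TensorDataset(features, labels)

# One inverse-frequency weight per sample.
weights = (1.0 / torch.bincount(labels).double())[labels]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

loader = DataLoader(toy_dataset, batch_size=4, sampler=sampler)

# Collect the labels from every batch in one pass over the loader.
seen_labels = torch.cat([batch_labels for _, batch_labels in loader])
seen_counts = torch.bincount(seen_labels, minlength=2)
print(seen_counts)

# Passing shuffle=True together with a sampler is rejected by DataLoader.
try:
    DataLoader(toy_dataset, batch_size=4, sampler=sampler, shuffle=True)
    shuffle_conflict = False
except ValueError:
    shuffle_conflict = True
```

One pass still yields as many samples as the dataset holds, but the minority class now shows up far more often than its 10% share of the data.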

That’s it for this post. I hope you found it useful.

Vivek Maskara
GRA at The Luminosity Lab, ASU | Ex Senior Software Engineer, Zeta | Volunteer, Wikimedia Foundation
