PyTorch: Datasets and DataLoaders
Efficiently handling and loading data is crucial for training machine learning models, especially with large datasets. PyTorch provides two primary abstractions for this: torch.utils.data.Dataset and torch.utils.data.DataLoader.
Key Concepts:
Dataset: An abstract class representing a dataset. Custom datasets inherit from Dataset and override two methods:
- __len__: Returns the number of samples in the dataset.
- __getitem__: Returns the sample at a given index.
DataLoader: Wraps a Dataset and provides an iterable over it. It handles batching, shuffling, and multi-process data loading.
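A minimal sketch of the two pieces working together, assuming a toy dataset of squared integers (the `SquaresDataset` class is hypothetical, for illustration only). With 10 samples and `batch_size=4`, the loader yields batches of 4, 4 and 2 samples; passing `drop_last=True` discards the incomplete final batch.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; the default sampler generates indices from range(len(self))
        return self.n

    def __getitem__(self, idx):
        # Return one (input, target) sample for index idx
        return torch.tensor(float(idx)), torch.tensor(float(idx ** 2))


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4)  # shuffle=False: indices 0..9 in order
batch_sizes = [x.shape[0] for x, y in loader]
print(batch_sizes)  # [4, 4, 2]

loader_drop = DataLoader(dataset, batch_size=4, drop_last=True)
print([x.shape[0] for x, y in loader_drop])  # [4, 4]
```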
Why use Dataset and DataLoader?
- Batching: Models are typically trained in mini-batches rather than one sample at a time, for computational efficiency. DataLoader automates this.
- Shuffling: Randomizing the order of samples helps prevent the model from learning the order of the training data and improves generalization. With shuffle=True, DataLoader draws a new random order at the start of each epoch.
- Parallel Data Loading: For large datasets, loading data from disk can be a bottleneck. DataLoader can use multiple worker processes (num_workers) to load batches in parallel, speeding up training.
- Transformation: Data can be preprocessed (e.g., resizing images, normalizing pixel values) on the fly as each sample is loaded. Transforms applied in __getitem__ run inside the worker processes when num_workers > 0.
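A minimal sketch of the shuffling behavior, assuming a plain Python list as the dataset (any object with `__len__` and `__getitem__` works as a map-style dataset). Passing a seeded `torch.Generator` through the `generator` argument makes the shuffle reproducible, while each new pass over the loader still draws a fresh permutation.

```python
import torch
from torch.utils.data import DataLoader

# A plain list supports len() and indexing, so DataLoader accepts it directly
data = list(range(10))

g = torch.Generator()
g.manual_seed(0)  # seed the shuffle so runs are reproducible

loader = DataLoader(data, batch_size=5, shuffle=True, generator=g)

epochs = []
for epoch in range(2):
    # Each new iteration over the loader draws a fresh permutation of the indices
    order = torch.cat([batch for batch in loader]).tolist()
    epochs.append(order)
    print(f"epoch {epoch}: {order}")
```

Re-creating the loader with a generator seeded to the same value replays the same sequence of epoch orders.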
Example: Custom Dataset and DataLoader
Let's create a simple custom dataset for a regression task and then use DataLoader to iterate over it.
```python
import torch
from torch.utils.data import Dataset, DataLoader


# 1. Define a Custom Dataset
class CustomRegressionDataset(Dataset):
    def __init__(self, num_samples=100, transform=None):
        self.num_samples = num_samples
        # Generate synthetic data: y = 2x + 1 + noise
        self.X = torch.randn(num_samples, 1) * 10  # Features
        self.y = 2 * self.X + 1 + torch.randn(num_samples, 1) * 2  # Targets with noise
        self.transform = transform

    def __len__(self):
        # Number of samples; the sampler draws indices from range(len(self))
        return self.num_samples

    def __getitem__(self, idx):
        # Return one sample as a dict; the default collate_fn batches dicts key by key
        sample = {'feature': self.X[idx], 'target': self.y[idx]}
        if self.transform:
            sample = self.transform(sample)
        return sample


# Optional: Define a simple transform
class ToTensor:
    def __call__(self, sample):
        feature, target = sample['feature'], sample['target']
        # torch.as_tensor accepts tensors, NumPy arrays and lists without the
        # copy-construct warning that torch.tensor(existing_tensor) emits
        return {'feature': torch.as_tensor(feature, dtype=torch.float32),
                'target': torch.as_tensor(target, dtype=torch.float32)}


# 2. Instantiate the Dataset
# dataset = CustomRegressionDataset(transform=ToTensor())  # If transform is needed
dataset = CustomRegressionDataset()
print(f"Dataset size: {len(dataset)} samples")
print(f"First sample (raw): {dataset[0]}")

# 3. Create a DataLoader
batch_size = 16
shuffle_data = True
# Set num_workers > 0 to load batches in worker processes. With the spawn start
# method (the default on Windows and macOS), the script must then run under
# an `if __name__ == '__main__':` guard.
num_workers = 0
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle_data, num_workers=num_workers)
print(f"\nDataLoader will yield batches of size: {batch_size}")

# 4. Iterate over the DataLoader
print("\nIterating through first 2 batches:")
for i, batch in enumerate(dataloader):
    features = batch['feature']
    targets = batch['target']
    print(f"Batch {i+1}: Features shape {features.shape}, Targets shape {targets.shape}")
    if i == 1:  # Stop after the first two batches
        break
```
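Each batch above is a dict because the default `collate_fn` merges a list of dict samples key by key, stacking each field along a new leading batch dimension; that is why the printed feature shape is `torch.Size([16, 1])`. A minimal sketch of that merge, calling `torch.utils.data.default_collate` directly on two hand-made samples (this assumes PyTorch 1.11 or newer, where the function is public):

```python
import torch
from torch.utils.data import default_collate

# Two hand-made samples shaped like those returned by CustomRegressionDataset.__getitem__
samples = [
    {'feature': torch.tensor([1.0]), 'target': torch.tensor([3.0])},
    {'feature': torch.tensor([2.0]), 'target': torch.tensor([5.0])},
]

# The default collate_fn turns a list of dicts into one dict,
# stacking each field along a new batch dimension
batch = default_collate(samples)
print(batch['feature'].shape)  # torch.Size([2, 1])
print(batch['target'])
```

With 100 samples and batch_size=16, the loader yields ceil(100 / 16) = 7 batches per epoch: six of 16 samples and a final one of 100 - 96 = 4, unless drop_last=True is passed.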
Example: Using torchvision.datasets and DataLoader for Image Data
For common datasets such as MNIST and CIFAR-10, the torchvision library provides pre-built dataset classes in torchvision.datasets.
```python
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1. Define Transforms
# Compose multiple transforms together
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL Image or uint8 (H, W, C) array -> float tensor (C, H, W) scaled to [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # Computes (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1]
])

# 2. Load the Dataset
# MNIST dataset from torchvision (downloaded to ./data on the first run)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
print(f"\nMNIST Training dataset size: {len(train_dataset)}")
print(f"MNIST Test dataset size: {len(test_dataset)}")

# Get a sample and check its shape
image, label = train_dataset[0]
print(f"Shape of first image: {image.shape}")  # (C, H, W) = (1, 28, 28) for MNIST
print(f"Label of first image: {label}")

# 3. Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)  # no shuffling needed for evaluation

# 4. Iterate through a batch from the DataLoader
print("\nIterating through a batch from MNIST DataLoader:")
for images, labels in train_loader:
    print(f"Batch images shape: {images.shape}")  # (Batch_size, Channels, Height, Width)
    print(f"Batch labels shape: {labels.shape}")  # (Batch_size,)
    break  # Just show the first batch
```
Further Topics:
- Custom collate_fn for flexible batching
- Handling imbalanced datasets
- Data augmentation techniques (torchvision.transforms)
- Distributed data loading
- Memory pinning for faster GPU transfer
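As a sketch of the first and last of these topics: a custom `collate_fn` can pad variable-length sequences that the default collate cannot stack, and `pin_memory=True` asks the loader to return batches in page-locked host memory, which speeds up copies to the GPU. The sequences and the `pad_collate` function are hypothetical examples; `torch.nn.utils.rnn.pad_sequence` is the PyTorch utility doing the padding.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical variable-length sequences; the default collate_fn would fail
# to stack these because their lengths differ
sequences = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]


def pad_collate(batch):
    """Pad a list of 1-D tensors to a common length and record the true lengths."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths


loader = DataLoader(
    sequences,
    batch_size=3,
    collate_fn=pad_collate,
    pin_memory=torch.cuda.is_available(),  # only pin when a GPU is present
)

padded, lengths = next(iter(loader))
print(padded)   # rows padded with 0 to length 3
print(lengths)  # tensor([3, 2, 1])
```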
Understanding Dataset and DataLoader is fundamental for building robust and scalable data pipelines in PyTorch, ensuring that data is fed to your models efficiently during training and evaluation.