PyTorch: Distributed Training
As models and datasets grow in size, training on a single device (e.g., one GPU) becomes impractical or impossible due to memory limitations and long training times. PyTorch provides robust tools for distributed training, allowing you to train models across multiple GPUs on one machine or across multiple machines (nodes) in a cluster.
Key Concepts:
- Data Parallelism:
  - The most common form of distributed training.
  - Each worker (e.g., GPU) gets a subset of the data.
  - Each worker has a complete copy of the model.
  - Forward pass and backward pass are done independently on each worker.
  - Gradients are then collected and averaged across all workers to update the model parameters.
  - PyTorch offers torch.nn.DataParallel (DP) and torch.nn.parallel.DistributedDataParallel (DDP). DDP is generally preferred for its better performance, scalability, and flexibility.
- Model Parallelism:
  - Different layers of a single model are placed on different devices.
  - Suitable when a single model is too large to fit into one GPU's memory.
  - Data is typically processed sequentially through the model, moving between devices.
  - More complex to implement than data parallelism and often requires careful partitioning of the model.
- torch.distributed package:
  - The core package for communication primitives in distributed environments.
  - Provides functions such as init_process_group, send, recv, all_reduce, broadcast, gather, scatter, etc.
- Process Groups:
  - A set of processes that can communicate with each other.
  - init_process_group initializes the process group.
  - Each process is identified by a unique rank, and the total number of processes is world_size.
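The claim that averaging gradients across workers reproduces full-batch training can be checked in a single process. In this sketch (an illustration only, not how DDP is implemented internally), the "workers" are simply equal shards of one batch, each with its own deep-copied replica of a toy nn.Linear model; the helper names are hypothetical. With a mean-reduced loss and equally sized shards, the average of the per-shard gradients matches the full-batch gradient.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def full_batch_grad(model, x, y):
    """Gradient of the mean-reduced MSE loss over the whole batch."""
    model.zero_grad()
    F.mse_loss(model(x), y).backward()
    return [p.grad.clone() for p in model.parameters()]


def simulated_data_parallel_grad(model, x, y, world_size):
    """Split the batch into `world_size` equal shards, give each simulated worker
    its own replica of the model, and average the per-replica gradients by hand
    (the step DDP performs with an all-reduce across real processes)."""
    per_worker = []
    for x_shard, y_shard in zip(x.chunk(world_size), y.chunk(world_size)):
        replica = copy.deepcopy(model)  # every worker holds a full copy of the model
        replica.zero_grad()
        F.mse_loss(replica(x_shard), y_shard).backward()
        per_worker.append([p.grad for p in replica.parameters()])
    # Average each parameter's gradient across the simulated workers
    return [torch.stack(grads).mean(dim=0) for grads in zip(*per_worker)]


torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)  # a batch of 8 splits evenly over 2 workers

reference = full_batch_grad(model, x, y)
averaged = simulated_data_parallel_grad(model, x, y, world_size=2)
for ref, avg in zip(reference, averaged):
    print(torch.allclose(ref, avg, atol=1e-6))  # True for both the weight and the bias
```

The equality depends on the shards being the same size: with unequal shards, a plain average of the shard gradients would weight the samples in the smaller shard more heavily.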
DistributedDataParallel (DDP) - The Recommended Approach
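DistributedDataParallel (DDP) is PyTorch's preferred approach for data-parallel training. Like DataParallel, it replicates the model on every worker and gives each worker a different slice of the data, but it avoids DataParallel's bottlenecks. DataParallel drives all GPUs from a single Python process, so its threads contend for the GIL, and it re-replicates the model on every forward pass; DDP instead runs a separate process for each GPU and can also span multiple machines. DDP synchronizes gradients more efficiently as well: it groups them into buckets and starts all-reducing each bucket while the rest of the backward pass is still running, so communication overlaps with computation.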
Example: Setting up DDP on a Single Machine with Multiple GPUs
This example demonstrates how to set up DDP on a single machine with, for instance, 2 GPUs. Each GPU will be managed by a separate Python process. We'll use torch.multiprocessing.spawn to launch these processes.
```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

# --- 1. Define a Simple Model (for demonstration) ---
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # 28x28 MNIST -> conv 24x24 -> pool 12x12 -> conv 8x8 -> pool 4x4; 20 * 4 * 4 = 320
        self.fc = nn.Linear(320, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = torch.flatten(x, 1)  # Flatten to (batch, 320)
        x = self.fc(x)
        return x

# --- 2. Main Training Function for Each Process ---
def train_ddp_model(rank, world_size, use_cuda):
    # a. Initialize the process group.
    # With no init_method, init_process_group uses 'env://', which reads MASTER_ADDR and
    # MASTER_PORT from the environment; rank and world_size are passed explicitly.
    # 'nccl' is recommended for GPU training, 'gloo' works on CPU.
    os.environ.setdefault('MASTER_ADDR', 'localhost')  # For single machine
    os.environ.setdefault('MASTER_PORT', '12355')      # Choose a free port
    backend = 'nccl' if use_cuda else 'gloo'
    print(f"Rank {rank} initializing process group ({backend})...")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    print(f"Rank {rank} process group initialized.")

    # b. Setup Device: one GPU per process
    if use_cuda:
        torch.cuda.set_device(rank)  # Set the current GPU device for this process
        device = torch.device(f"cuda:{rank}")
    else:
        device = torch.device("cpu")
    print(f"Rank {rank} using device: {device}")

    # c. Data Loading with DistributedSampler
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # The dataset is downloaded once in main(), so processes don't race to download it
    train_dataset = datasets.MNIST(root='./data_ddp', train=True, download=False, transform=transform)
    # DistributedSampler ensures each process gets its own subset of the data
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler, num_workers=2)

    # d. Model Initialization and DDP Wrapper
    model = SimpleCNN(num_classes=10).to(device)
    # device_ids must be None when the module lives on the CPU
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank] if use_cuda else None)
    print(f"Rank {rank} model wrapped with DDP.")

    # e. Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # f. Training Loop
    num_epochs = 5
    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)  # Important for shuffling to be effective across epochs
        model.train()
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()  # DDP averages gradients across processes during backward
            optimizer.step()
            running_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f"Rank {rank}, Epoch {epoch+1}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
        # Optional, for logging: sum every rank's loss total and batch count,
        # then report the average over all GPUs from rank 0 only
        stats = torch.tensor([running_loss, len(train_loader)], device=device)
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        if rank == 0:
            print(f"Epoch {epoch+1} finished. Avg Loss (all ranks): {(stats[0] / stats[1]).item():.4f}")

    print(f"Rank {rank} training complete.")
    dist.destroy_process_group()  # Clean up

# --- 3. Entry Point for Launching Processes ---
def main():
    use_cuda = torch.cuda.is_available()
    # One process per GPU; a single CPU process (gloo backend) when no GPU is present
    world_size = torch.cuda.device_count() if use_cuda else 1
    if use_cuda:
        print(f"Found {world_size} GPUs. Launching {world_size} processes for DDP.")
    else:
        print("No GPUs found. Running one CPU process with the gloo backend (for demonstration only).")
    # Download the dataset once, before any worker process starts
    datasets.MNIST(root='./data_ddp', train=True, download=True)
    # mp.spawn starts nprocs fresh Python processes and calls
    # train_ddp_model(i, *args) in each, with i = 0..nprocs-1 used as the rank
    mp.spawn(train_ddp_model,
             args=(world_size, use_cuda),
             nprocs=world_size,
             join=True)  # join=True waits for all processes to finish

if __name__ == '__main__':
    main()
```
Explanation of DDP components:
- init_process_group: Initializes the distributed environment.
  - backend: nccl for GPUs, gloo for CPUs.
  - rank: Unique identifier for each process within the group (0 to world_size-1).
  - world_size: Total number of processes participating in the training.
- DistributedSampler: Gives each process its own distinct subset of the dataset, which is crucial for data parallelism. By default (drop_last=False), if the dataset size is not divisible by world_size, it repeats a few samples so that every process receives the same number of samples.
- torch.cuda.set_device(rank): Tells PyTorch which GPU this process should use.
- nn.parallel.DistributedDataParallel(model, device_ids=[rank]): Wraps your model. At construction it broadcasts rank 0's parameters and buffers to every process, and during the backward pass it averages gradients across processes, so every replica applies the same update and stays consistent.
- train_sampler.set_epoch(epoch): DistributedSampler seeds its shuffle with the epoch number, so without this call every epoch visits the samples in the same order.
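To see how DistributedSampler partitions data, you can build it without a process group by passing num_replicas and rank explicitly, as the training script does. This sketch simulates two ranks inside one process on a toy TensorDataset (a stand-in for MNIST; the helper shards() is hypothetical). It shows the padding for an uneven dataset, that the ranks' shards stay disjoint when shuffling, and what happens if set_epoch is never called.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler


def shards(dataset, world_size, epoch=0, shuffle=True):
    """Return the list of indices each simulated rank would draw in `epoch`."""
    result = []
    for rank in range(world_size):
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
        sampler.set_epoch(epoch)  # all ranks must use the same epoch to share one permutation
        result.append(list(sampler))
    return result


# Uneven case: 5 samples over 2 ranks. The index list is padded by repeating
# its first element, so both ranks get 3 samples.
uneven = TensorDataset(torch.arange(5))
print(shards(uneven, world_size=2, shuffle=False))  # [[0, 2, 4], [1, 3, 0]]

# Shuffled case: in every epoch the two shards are disjoint and cover the dataset.
even = TensorDataset(torch.arange(8))
for epoch in range(3):
    rank0, rank1 = shards(even, world_size=2, epoch=epoch)
    assert set(rank0).isdisjoint(rank1) and set(rank0) | set(rank1) == set(range(8))

# Forgetting set_epoch: the sampler stays on epoch 0's seed, so every pass repeats the same order.
sampler = DistributedSampler(even, num_replicas=2, rank=0, shuffle=True)
first_pass, second_pass = list(sampler), list(sampler)
print(first_pass == second_pass)  # True
```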
Considerations for Multi-Node DDP:
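For training across multiple machines, you typically launch the script on every node with torchrun, which replaces the older, now-deprecated torch.distributed.launch. torchrun starts one process per GPU and exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT to each of them; MASTER_ADDR and MASTER_PORT (set with --master_addr and --master_port) must point to a reachable port on the node with node_rank 0, which acts as the central coordinator. Under torchrun the script does not call mp.spawn: each process reads its rank from the environment and uses LOCAL_RANK, not the global rank, to select its GPU.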
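A minimal sketch of a torchrun-compatible entry point, assuming the setup described above. The function names are hypothetical, and the os.environ.setdefault() fallbacks are an assumption added so the same file also runs as a single process when started with plain python; the training loop from the earlier example would replace the all_reduce sanity check.

```python
import os

import torch
import torch.distributed as dist


def setup_from_torchrun_env():
    """Initialize the default process group from the variables torchrun exports.

    The setdefault() fallbacks (an assumption for local debugging) only apply
    when the script is not started by torchrun.
    """
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("LOCAL_RANK", "0")

    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on *this* node
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        backend, device = "nccl", torch.device("cuda", local_rank)
    else:
        backend, device = "gloo", torch.device("cpu")

    # No init_method is given, so the default "env://" reads MASTER_ADDR,
    # MASTER_PORT, RANK and WORLD_SIZE from the environment.
    dist.init_process_group(backend=backend)
    return dist.get_rank(), dist.get_world_size(), device


def main():
    rank, world_size, device = setup_from_torchrun_env()
    # Sanity check: every rank contributes 1, so the summed value equals world_size.
    t = torch.ones(1, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    if rank == 0:
        print(f"world_size={world_size}, all_reduce sum={t.item():.0f}")
    dist.destroy_process_group()
    return t.item()


if __name__ == "__main__":
    main()
```

For example, on a hypothetical cluster of two nodes with four GPUs each, you might run torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr=10.0.0.1 --master_port=29500 train.py on the first node and the same command with --node_rank=1 on the second; the address and the script name train.py are placeholders.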
Further Topics:
- Gradient Accumulation: Effectively increasing batch size without needing more memory by accumulating gradients over several mini-batches before optimizing.
- Mixed Precision Training: Using torch.amp (formerly torch.cuda.amp) to train models with a mix of float16 and float32 data types for faster training and reduced memory usage on compatible GPUs.
- Custom Communication Backends: For advanced use cases.
- Model Parallelism Implementation: Manually splitting model layers across devices.
- Checkpointing in DDP: Ensuring that only the rank 0 process saves the model to avoid race conditions and redundant saving.
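Two of these topics, gradient accumulation and checkpointing, need small adjustments under DDP. The sketch below shows one way to handle both: DDP's no_sync() context manager skips the gradient all-reduce on micro-batches that do not step the optimizer, and only rank 0 writes the checkpoint while dist.barrier() makes the other ranks wait for it. The helper names, the toy nn.Linear model, and the single-process CPU demo (gloo, world_size=1, port 29502) are illustrative assumptions.

```python
import os
import tempfile
from contextlib import nullcontext

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def train_with_accumulation(ddp_model, optimizer, batches, accum_steps):
    """One optimizer step per `accum_steps` micro-batches.

    Assumes len(batches) is a multiple of accum_steps.
    """
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches):
        sync_now = (i + 1) % accum_steps == 0
        # Inside no_sync() DDP skips the all-reduce and gradients accumulate locally;
        # the next backward outside it synchronizes the accumulated gradients.
        ctx = nullcontext() if sync_now else ddp_model.no_sync()
        with ctx:
            # Divide so the summed gradients match one batch of accum_steps x the size
            loss = F.mse_loss(ddp_model(x), y) / accum_steps
            loss.backward()
        if sync_now:
            optimizer.step()
            optimizer.zero_grad()


def save_checkpoint(ddp_model, path):
    """Only rank 0 writes; the barrier keeps other ranks from reading a partial file."""
    if dist.get_rank() == 0:
        # .module is the wrapped model, so the saved keys have no "module." prefix
        torch.save(ddp_model.module.state_dict(), path)
    dist.barrier()


def demo(port="29502"):
    """Single-process CPU run (gloo, world_size=1) that exercises both helpers."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group("gloo", rank=0, world_size=1)
    torch.manual_seed(0)
    ddp_model = DDP(nn.Linear(4, 1))  # CPU module, so no device_ids
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)
    batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(4)]
    train_with_accumulation(ddp_model, optimizer, batches, accum_steps=2)

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "checkpoint.pt")
        save_checkpoint(ddp_model, path)
        restored = nn.Linear(4, 1)
        restored.load_state_dict(torch.load(path))
    dist.destroy_process_group()
    # The reloaded plain model should match the trained replica exactly
    return all(torch.equal(a, b) for a, b in zip(restored.parameters(), ddp_model.module.parameters()))
```

In a real multi-GPU run the same helpers are called from every process; only the process-group setup differs.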
Distributed training is a complex but essential skill for scaling deep learning applications. DDP offers a powerful and relatively straightforward way to achieve data parallelism efficiently in PyTorch.