PyTorch: Distributed Training

As models and datasets grow, training on a single device (e.g., one GPU) becomes impractical or impossible because of memory limits and long training times. PyTorch provides tools for distributed training that let you train a model across multiple GPUs on one machine or across multiple machines (nodes) in a cluster.

Key Concepts:

  • Data Parallelism:

    • The most common form of distributed training.
    • Each worker (e.g., GPU) gets a subset of the data.
    • Each worker has a complete copy of the model.
    • Forward pass and backward pass are done independently on each worker.
    • Gradients are then collected and averaged across all workers to update the model parameters.
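    • PyTorch offers torch.nn.DataParallel (DP), which runs in a single process on one machine, and torch.nn.parallel.DistributedDataParallel (DDP), which runs one process per GPU and also works across machines. DDP is generally preferred, even on a single machine, for its better performance and scalability.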
  • Model Parallelism:

    • Different layers of a single model are placed on different devices.
    • Suitable when a single model is too large to fit into one GPU's memory.
    • Data is typically processed sequentially through the model, moving between devices.
    • More complex to implement than data parallelism and often requires careful partitioning of the model.
  • torch.distributed package:

    • The core package for communication primitives in distributed environments.
    • Provides functionalities like init_process_group, send, recv, all_reduce, broadcast, gather, scatter, etc.
  • Process Groups:

    • A set of processes that can communicate with each other.
    • init_process_group initializes the process group.
    • Each process is identified by a unique rank, and the total number of processes is world_size.
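
To make the communication primitives more concrete, here is a minimal sketch of gradient averaging with all_reduce. It assumes a single-process group on the CPU with the gloo backend and a file-based rendezvous, chosen only so the snippet can run without GPUs or a free port. The helper names average_across_workers and demo are hypothetical, not part of PyTorch.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def average_across_workers(tensor: torch.Tensor) -> torch.Tensor:
    """Sum `tensor` over all ranks in place, then divide by world_size."""
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= dist.get_world_size()
    return tensor


def demo() -> torch.Tensor:
    # Assumption: one CPU process. A file:// rendezvous needs a path to a
    # file that does not exist yet inside an existing directory.
    init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{init_file}",
        rank=0,
        world_size=1,
    )
    try:
        local_grad = torch.tensor([2.0, 4.0, 6.0])  # stand-in for a local gradient
        return average_across_workers(local_grad)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    print(demo())
```

With a world_size of 1 the average equals the input. With several ranks, every rank would end up holding the same mean tensor, which is the operation data-parallel training relies on for gradients.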

DistributedDataParallel (DDP) is PyTorch's preferred approach for data-parallel training. It keeps a full replica of the model in every worker process, with one process per GPU, on one machine or several. This avoids the main bottlenecks of DataParallel, which drives all GPUs from a single process and has to scatter inputs and re-replicate the model on every forward pass. DDP also synchronizes gradients more efficiently: it all-reduces them in buckets while the backward pass is still running, so communication overlaps with computation.

Example: Setting up DDP on a Single Machine with Multiple GPUs

This example demonstrates how to set up DDP on a single machine with, for instance, 2 GPUs. Each GPU will be managed by a separate Python process. We'll use torch.multiprocessing.spawn to launch these processes.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms
import os

# --- 1. Define a Simple Model (for demonstration) ---
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc = nn.Linear(320, num_classes) # Adjusted for MNIST 28x28 input

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 320) # Flatten
        x = self.fc(x)
        return x

# --- 2. Main Training Function for Each Process ---
def train_ddp_model(rank, world_size):
    # a. Initialize the process group
    #   'env://' uses environment variables for configuration (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE)
    #   backend='nccl' is recommended for GPU training
    print(f"Rank {rank} initializing process group...")
    os.environ['MASTER_ADDR'] = 'localhost' # For single machine
    os.environ['MASTER_PORT'] = '12355' # Choose a free port
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    torch.distributed.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    print(f"Rank {rank} process group initialized.")

    # b. Setup Device
    torch.cuda.set_device(rank) # Set the current GPU device for this process
    device = torch.device(f"cuda:{rank}")
    print(f"Rank {rank} using device: {device}")

    # c. Data Loading with DistributedSampler
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
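    # Let rank 0 download the dataset first so that processes do not write the same files concurrently
    if rank == 0:
        datasets.MNIST(root='./data_ddp', train=True, download=True)
    torch.distributed.barrier() # Other ranks wait here until the download is done
    train_dataset = datasets.MNIST(root='./data_ddp', train=True, download=False, transform=transform)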

    # DistributedSampler ensures each process gets a unique subset of the data
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler, num_workers=2)

    # d. Model Initialization and DDP Wrapper
    model = SimpleCNN(num_classes=10).to(device)
    # Wrap the model with DistributedDataParallel
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    print(f"Rank {rank} model wrapped with DDP.")

    # e. Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # f. Training Loop
    num_epochs = 5
    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch) # Important for shuffling to be effective across epochs
        model.train()
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Rank {rank}, Epoch {epoch+1}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

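        # Average the per-rank mean loss across all processes.
        # This is optional, but useful for overall monitoring.
        avg_loss = torch.tensor(running_loss / len(train_loader), device=device)
        torch.distributed.all_reduce(avg_loss, op=torch.distributed.ReduceOp.SUM) # In-place sum over ranks
        avg_loss /= world_size
        if rank == 0:
            print(f"Epoch {epoch+1} finished. Avg Loss (all ranks): {avg_loss.item():.4f}")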

    print(f"Rank {rank} training complete.")
    torch.distributed.destroy_process_group() # Clean up

# --- 3. Entry Point for Launching Processes ---
def main():
    # Detect available GPUs
    world_size = torch.cuda.device_count()
    if world_size == 0:
        # The training function uses the NCCL backend and CUDA devices, so it cannot run on CPU.
        print("No GPUs found. This example requires at least one CUDA GPU.")
        return
    print(f"Found {world_size} GPUs. Launching {world_size} processes for DDP.")

    # Use torch.multiprocessing.spawn to launch a process for each GPU
    # 'spawn' creates new Python processes, ensuring clean state for each.
    mp.spawn(train_ddp_model,
             args=(world_size,),
             nprocs=world_size,
             join=True) # join=True waits for all processes to finish

if __name__ == '__main__':
    main()

Explanation of DDP components:

  • init_process_group: Initializes the distributed environment.
    • backend: nccl for GPU, gloo for CPU.
    • rank: Unique identifier for each process within the group (0 to world_size-1).
    • world_size: Total number of processes participating in the training.
  • DistributedSampler: Gives each process its own shard of the dataset, so that processes do not train on the same samples within an epoch. If the dataset size is not divisible by the number of processes, the sampler pads the shards by repeating a few samples unless drop_last=True is set.
  • torch.cuda.set_device(rank): Binds each process to one GPU. On a single machine the rank can double as the GPU index; across multiple machines, use the local rank instead.
  • nn.parallel.DistributedDataParallel(model, device_ids=[rank]): Wraps your model. DDP broadcasts the initial parameters from rank 0 and all-reduces gradients during the backward pass, so every replica applies the same update and stays consistent.
  • train_sampler.set_epoch(epoch): Reseeds the DistributedSampler's shuffle for each epoch. Without this call, every epoch iterates the data in the same order.
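
The sharding behavior of DistributedSampler can be inspected without starting any processes. The sketch below is an illustration under assumptions: it passes num_replicas and rank explicitly (so no process group is required), uses a toy TensorDataset, and the helper shard_indices is a hypothetical name.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler


def shard_indices(dataset, world_size, epoch, seed=0):
    """Return, for each rank, the list of dataset indices it sees in `epoch`."""
    shards = []
    for rank in range(world_size):
        # Explicit num_replicas and rank mean no process group is needed.
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=seed
        )
        sampler.set_epoch(epoch)  # the shuffle is derived from seed + epoch
        shards.append(list(sampler))
    return shards


if __name__ == "__main__":
    even = TensorDataset(torch.arange(10))  # divisible by 2 ranks
    odd = TensorDataset(torch.arange(11))   # not divisible: one index is repeated
    print("epoch 0:", shard_indices(even, world_size=2, epoch=0))
    print("epoch 1:", shard_indices(even, world_size=2, epoch=1))
    print("padded :", shard_indices(odd, world_size=2, epoch=0))
```

With 10 samples and 2 ranks, each rank gets 5 distinct indices and the two shards together cover the dataset once. With 11 samples, each rank gets 6 indices and one sample appears twice because of padding. Calling the helper with a different epoch should normally produce a different ordering, which is the effect set_epoch has in the training loop.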

Considerations for Multi-Node DDP:

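For training across multiple machines, you typically use a launcher such as torchrun, which supersedes the deprecated torch.distributed.launch. The launcher starts one process per GPU on each node and sets environment variables such as RANK, LOCAL_RANK, and WORLD_SIZE for every process. MASTER_ADDR and MASTER_PORT point all processes at the rendezvous host, usually the node that runs rank 0.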
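
As a rough sketch of what a launcher-driven script might look like, the code below reads the variables that torchrun exports instead of setting them by hand. The DistEnv dataclass and the helpers read_dist_env and init_from_env are hypothetical names, and the single-process defaults are an assumption made so the parsing can be tried without a launcher.

```python
import os
from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class DistEnv:
    rank: int        # global rank across all nodes
    local_rank: int  # rank within this node, used as the GPU index
    world_size: int  # total number of processes


def read_dist_env(env=None) -> DistEnv:
    """Parse launcher variables; default to a single-process configuration."""
    env = os.environ if env is None else env
    return DistEnv(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


def init_from_env() -> torch.device:
    """Initialize the default process group from the environment."""
    cfg = read_dist_env()
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(cfg.local_rank)
    # With no init_method argument, init_process_group uses env:// and reads
    # MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment.
    dist.init_process_group(backend="nccl" if use_cuda else "gloo")
    return torch.device(f"cuda:{cfg.local_rank}" if use_cuda else "cpu")
```

A script built this way could be started on each node with a command along the lines of torchrun --nnodes=2 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=node0.example.com:29500 train.py, where the host name and train.py are placeholders. The model would then be wrapped with device_ids=[cfg.local_rank] rather than the global rank.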

Further Topics:

  • Gradient Accumulation: Effectively increasing batch size without needing more memory by accumulating gradients over several mini-batches before optimizing.
  • Mixed Precision Training: Using automatic mixed precision (torch.amp, formerly torch.cuda.amp) to train models with a mix of float16 or bfloat16 and float32 data types for faster training and reduced memory usage on compatible GPUs.
  • Custom Communication Backends: For advanced use cases.
  • Model Parallelism Implementation: Manually splitting model layers across devices.
  • Checkpointing in DDP: Ensuring that only the rank 0 process saves the model to avoid race conditions and redundant saving.
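
Two of these topics, gradient accumulation and rank-0 checkpointing, can be sketched briefly. This is an illustration under assumptions rather than a definitive recipe: the helpers accumulate_step and save_on_rank0 are hypothetical names, micro-batches are assumed to be equally sized, and both functions fall back to plain single-process behavior when the model is not wrapped in DDP or no process group is initialized.

```python
import contextlib

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def accumulate_step(model, optimizer, criterion, micro_batches):
    """Perform one optimizer step using gradients summed over micro-batches."""
    optimizer.zero_grad()
    n = len(micro_batches)
    total_loss = 0.0
    for i, (data, target) in enumerate(micro_batches):
        is_last = i == n - 1
        # Under DDP, no_sync() skips the gradient all-reduce, so ranks only
        # communicate once, on the last micro-batch.
        if isinstance(model, DistributedDataParallel) and not is_last:
            sync_ctx = model.no_sync()
        else:
            sync_ctx = contextlib.nullcontext()
        with sync_ctx:
            # Dividing by n makes the summed gradient match a full-batch mean.
            loss = criterion(model(data), target) / n
            loss.backward()  # gradients accumulate in each parameter's .grad
        total_loss += loss.item()
    optimizer.step()
    return total_loss


def save_on_rank0(model, path):
    """Write the checkpoint from rank 0 only; other ranks wait at a barrier."""
    initialized = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if initialized else 0
    if rank == 0:
        # Unwrap DDP so the saved keys carry no "module." prefix.
        to_save = model.module if isinstance(model, DistributedDataParallel) else model
        torch.save(to_save.state_dict(), path)
    if initialized:
        dist.barrier()  # make sure the file exists before any rank loads it
```

In a DDP training loop, accumulate_step would replace the zero_grad, backward, and step sequence for a group of micro-batches, and save_on_rank0 would typically be called at the end of an epoch. When loading the file on other ranks, passing map_location to torch.load helps place the tensors on each process's own device.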

Distributed training is a complex but essential skill for scaling deep learning applications. DDP offers a powerful and relatively straightforward way to achieve data parallelism efficiently in PyTorch.