PyTorch: Training and Evaluation Loop
Bringing together the concepts of nn.Module for model definition, autograd for automatic differentiation, optimizers for parameter updates, and Datasets/DataLoaders for data handling, we can construct a complete training and evaluation loop for a PyTorch model.
The training loop typically involves:
1. Iterating over epochs: Multiple passes over the entire dataset.
2. Iterating over batches: Processing data in mini-batches within each epoch.
3. Forward pass: Propagating input through the model to get predictions.
4. Calculate loss: Quantifying the difference between predictions and true labels.
5. Backward pass: Computing gradients using autograd.
6. Optimizer step: Updating model parameters based on gradients.
The evaluation loop is similar but does not involve gradient computation or parameter updates. It often uses torch.no_grad() to save memory and computation.
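The six training steps and the gradient-free evaluation pass can be sketched on synthetic data before moving to a real dataset. This is a minimal sketch only: the linear model, the random tensors, and the helper names train_one_epoch and evaluate are assumptions made for illustration and are not part of the MNIST example.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic 2-class data (hypothetical stand-in for a real dataset)
features = torch.randn(256, 4)
targets = (features.sum(dim=1) > 0).long()
loader = DataLoader(TensorDataset(features, targets), batch_size=32, shuffle=True)

model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def train_one_epoch(model, loader, criterion, optimizer):
    """Run steps 2-6 once over the loader; return the mean loss per sample."""
    model.train()
    total_loss, total = 0.0, 0
    for inputs, labels in loader:          # 2. iterate over batches
        outputs = model(inputs)            # 3. forward pass
        loss = criterion(outputs, labels)  # 4. calculate loss
        optimizer.zero_grad()              # clear gradients left from the previous step
        loss.backward()                    # 5. backward pass
        optimizer.step()                   # 6. optimizer step
        total_loss += loss.item() * inputs.size(0)
        total += inputs.size(0)
    return total_loss / total


def evaluate(model, loader):
    """Return accuracy in [0, 1] without tracking gradients."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total


# 1. iterate over epochs
losses = [train_one_epoch(model, loader, criterion, optimizer) for _ in range(5)]
accuracy = evaluate(model, loader)
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}, accuracy {accuracy:.2%}")
```

Splitting the loop into two functions is one possible design choice; it keeps the mode switch (train versus eval) next to the code that depends on it.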
Example: Complete Training and Evaluation of a Simple CNN on MNIST
Let's train the ConvNet defined in pytorch_nn_module.md on the MNIST dataset, using the data loading approach from pytorch_datasets_dataloaders.md and the optimizer and loss function covered in pytorch_optimizers_losses.md.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# --- 1. Define the Model (reusing ConvNet from pytorch_nn_module.md) ---
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=0)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=0)
        # Calculate input features for the first fully connected layer
        # For a 28x28 input image:
        # After conv1: (28 - 3 + 1) = 26
        # After pool1: 26 / 2 = 13
        # After conv2: (13 - 3 + 1) = 11
        # After pool2: 11 / 2 = 5 (floor division)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# --- 2. Hyperparameters ---
num_epochs = 5
batch_size = 64
learning_rate = 0.001
num_classes = 10 # MNIST has 10 classes (0-9)
# --- 3. Device Configuration ---
# Check if GPU is available, otherwise use CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# --- 4. Data Loading and Preprocessing (reusing from pytorch_datasets_dataloaders.md) ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Standard normalization for MNIST
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# --- 5. Instantiate Model, Loss, and Optimizer ---
model = ConvNet(num_classes).to(device) # Move model to the selected device
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# --- 6. Training Loop ---
print("\n--- Starting Training ---")
train_losses = []
train_accuracies = []
for epoch in range(num_epochs):
    model.train()  # Set model to training mode (important for dropout/batchnorm)
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)  # Accumulate batch loss
        _, predicted = torch.max(outputs.data, 1)  # Get predicted class
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()
    epoch_loss = running_loss / total_samples
    epoch_accuracy = 100 * correct_predictions / total_samples
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')
# --- 7. Evaluation Loop ---
print("\n--- Starting Evaluation ---")
model.eval()  # Set model to evaluation mode (disable dropout/batchnorm updates)
with torch.no_grad():  # Disable gradient computation
    correct_predictions = 0
    total_samples = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()
test_accuracy = 100 * correct_predictions / total_samples
print(f'Test Accuracy of the model on the 10000 test images: {test_accuracy:.2f}%')
# --- 8. Visualize Training Progress (Optional) ---
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses)
plt.title('Training Loss per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.subplot(1, 2, 2)
plt.plot(train_accuracies)
plt.title('Training Accuracy per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.tight_layout()
plt.show()
# --- 9. Save the trained model ---
torch.save(model.state_dict(), 'mnist_cnn_model.pth')
print("\nModel saved to mnist_cnn_model.pth")
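To reuse saved weights later, one option is to rebuild the architecture and load the state_dict into it. The sketch below uses a tiny hypothetical TinyNet and a temporary file rather than ConvNet and mnist_cnn_model.pth, so that it runs on its own; the same pattern should carry over to the model trained above.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for ConvNet, kept small for illustration."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


trained = TinyNet()  # imagine this model has just been trained

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_model.pth")
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture, then load the saved parameters into it.
    restored = TinyNet()
    state_dict = torch.load(path, map_location="cpu")
    restored.load_state_dict(state_dict)

restored.eval()  # switch to evaluation mode before inference

sample = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(trained(sample), restored(sample))
print(f"Outputs match after reload: {same_output}")
```

Passing map_location="cpu" is a cautious default here: it lets weights saved on a GPU machine load on a CPU-only one, after which the model can be moved with .to(device).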
Best Practices and Considerations:
- model.train() vs model.eval(): Always set your model to the correct mode. model.train() enables training-time behavior such as Dropout and BatchNorm statistic updates, while model.eval() disables it, which is crucial for consistent evaluation results.
- optimizer.zero_grad(): Clear gradients before each backward pass, because PyTorch accumulates gradients across calls to backward() by default.
- torch.no_grad(): Use this context manager during evaluation to save memory and computation.
- Device Management: Move your model and data to the appropriate device (CPU/GPU) using .to(device).
- Hyperparameter Tuning: Experiment with learning_rate, batch_size, num_epochs, and optimizer choices.
- Early Stopping: Stop training if validation performance plateaus or degrades to prevent overfitting.
- Learning Rate Schedulers: Dynamically adjust the learning rate during training.
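Early stopping and a learning rate scheduler can be combined in a single loop. The following is a rough sketch on synthetic regression data: the patience values, the train/validation split, and the choice of ReduceLROnPlateau are assumptions for illustration, not settings taken from the MNIST example.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic regression data, split into training and validation parts
x = torch.randn(200, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.1 * torch.randn(200, 1)
x_train, y_train = x[:160], y[:160]
x_val, y_val = x[160:], y[160:]

model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Halve the learning rate when validation loss stops improving for 2 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

patience = 5  # hypothetical early-stopping patience, in epochs
best_val_loss = float("inf")
best_state = copy.deepcopy(model.state_dict())
epochs_without_improvement = 0
epochs_run = 0

for epoch in range(200):
    # One full-batch training step per epoch keeps the sketch short
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(x_val), y_val).item()

    scheduler.step(val_loss)  # the scheduler reacts to the validation loss
    epochs_run = epoch + 1

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break  # early stopping: no improvement for `patience` epochs

model.load_state_dict(best_state)  # restore the best weights seen
final_lr = optimizer.param_groups[0]["lr"]
print(f"stopped after {epochs_run} epochs, best val loss {best_val_loss:.4f}, lr {final_lr}")
```

Keeping a copy of the best state_dict means the model returned is the one with the lowest validation loss, not merely the last one trained.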
This comprehensive example demonstrates how to implement a full training and evaluation loop in PyTorch. From here, you can explore more complex architectures, advanced training techniques, and deployment strategies.