
PyTorch: Transfer Learning and Fine-tuning

Transfer learning is a machine learning technique where a model trained on one task is re-purposed for a second related task. Fine-tuning takes this a step further by continuing the training of a pre-trained model on a new, typically smaller, dataset for a specific task. This approach is highly effective in deep learning, especially when you have limited data for your target task, as it leverages the features learned by the pre-trained model from a very large dataset.

Why Transfer Learning?

  1. Reduced Training Time: Pre-trained models have already learned general features (e.g., edge detectors, texture detectors in images) from massive datasets (like ImageNet), saving a huge amount of training time.
  2. Better Performance with Less Data: By using pre-trained weights, your model can achieve higher accuracy even with a relatively small target dataset, as it benefits from the extensive knowledge encoded in the pre-trained model.
  3. Overfitting Prevention: For small datasets, training a deep network from scratch often leads to overfitting. Transfer learning acts as a form of regularization.

Common Strategies for Transfer Learning in PyTorch

  1. Feature Extractor (Fixed Feature Extractor):

    • Initialize a pre-trained model.
    • Freeze the weights of most of the pre-trained layers (e.g., convolutional base).
    • Replace the final classification layer(s) with new layers tailored to your specific task (e.g., number of output classes).
    • Only train the new, randomly initialized layers. This is faster and generally good for smaller datasets.
  2. Fine-tuning (All Layers or Partial Layers):

    • Initialize a pre-trained model.
    • Replace the final classification layer(s) (if necessary).
    • Unfreeze some or all layers of the pre-trained model.
    • Train the entire model (or a larger portion) with a very small learning rate, as the existing weights are already good. This is suitable for larger datasets or when the new task is significantly different from the original task.

Example: Transfer Learning with a Pre-trained ResNet

Let's fine-tune a pre-trained ResNet-18 model for a simple image classification task. A 5,000-image subset of CIFAR-10 (10 classes) stands in for a small custom dataset.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# --- 1. Device Configuration ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 2. Data Loading and Preprocessing ---
# Assume you have a dataset structured like:
# data/
# └─ train/
# │   └─ class_a/
# │   │   └─ img1.jpg
# │   │   └─ img2.jpg
# │   └─ class_b/
# │       └─ img3.jpg
# │       └─ img4.jpg
# └─ val/
#     └─ class_a/
#     │   └─ img5.jpg
#     └─ class_b/
#         └─ img6.jpg

# For demonstration, a subset of CIFAR-10 stands in for a small custom dataset,
# so the model is trained on CIFAR-10's 10 classes.
# To use your own data, replace the CIFAR-10 loading code below with ImageFolder
# pointing at your train/ and val/ directories, for example:
# data_dir = 'path/to/your/custom/dataset'
# image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224), # 224x224 matches the ImageNet pre-training resolution
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet normalization
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load CIFAR-10, we'll pretend it's a smaller custom dataset
print("Loading CIFAR-10 dataset (as a stand-in for custom data)...")
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=data_transforms['train'])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transforms['val']) # Using test as val

# Let's reduce the size of the CIFAR-10 to simulate a smaller custom dataset
# For demonstration, take a subset
indices_train = torch.randperm(len(train_dataset))[:5000].tolist() # 5000 random samples, as plain Python ints
indices_test = torch.randperm(len(test_dataset))[:1000].tolist()   # 1000 random samples, as plain Python ints
train_dataset_subset = torch.utils.data.Subset(train_dataset, indices_train)
test_dataset_subset = torch.utils.data.Subset(test_dataset, indices_test)


dataloaders = {
    'train': DataLoader(train_dataset_subset, batch_size=32, shuffle=True, num_workers=2),
    'val': DataLoader(test_dataset_subset, batch_size=32, shuffle=False, num_workers=2),
}
dataset_sizes = {'train': len(train_dataset_subset), 'val': len(test_dataset_subset)}
class_names = train_dataset.classes # For CIFAR-10, otherwise from image_datasets['train'].classes
num_classes = len(class_names)

print(f"Training dataset size: {dataset_sizes['train']}")
print(f"Validation dataset size: {dataset_sizes['val']}")
print(f"Number of classes: {num_classes}")

# --- 3. Load Pre-trained Model (ResNet-18) ---
print("\nLoading pre-trained ResNet-18 model...")
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) # Use latest recommended weights

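# --- 4. (Optional) Freeze all parameters (Feature Extractor strategy) ---
# Left commented out, so this example fine-tunes every layer.
# Uncomment to train only the new final layer added in step 5.
# for param in model_ft.parameters():
#     param.requires_grad = False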

# --- 5. Modify the final layer for our new classification task ---
# ResNet-18's last layer is a fully connected layer (fc)
num_ftrs = model_ft.fc.in_features # Get the number of input features to the last layer
model_ft.fc = nn.Linear(num_ftrs, num_classes) # Replace with a new linear layer for our specific number of classes

# Move the model to the device
model_ft = model_ft.to(device)

# --- 6. Define Loss Function and Optimizer ---
# If you froze layers in step 4, pass only the trainable parameters to the optimizer,
# e.g. filter(lambda p: p.requires_grad, model_ft.parameters()).
# Nothing is frozen here, so all parameters are optimized (full fine-tuning).
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001)

# You might use a learning rate scheduler for fine-tuning
# exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# --- 7. Training and Evaluation Loop ---
print("\n--- Starting Training (Fine-tuning) ---")

best_acc = 0.0
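# state_dict() returns references to the live tensors, so clone them to keep a snapshot
best_model_wts = {k: v.detach().clone() for k, v in model_ft.state_dict().items()}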
num_epochs = 5 # Reduced for demonstration

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model_ft.train() # Set model to training mode
        else:
            model_ft.eval()  # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # Iterate over data.
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer_ft.zero_grad()

            # Forward
            # Track gradients only if in training phase
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model_ft(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # Backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer_ft.step()

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

        # if phase == 'train':
        #     exp_lr_scheduler.step() # Update learning rate

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Deep copy the model if it's the best performing
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
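            # Clone the tensors; a bare state_dict() would keep changing as training continues
            best_model_wts = {k: v.detach().clone() for k, v in model_ft.state_dict().items()}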

print(f'\nBest val Acc: {best_acc:.4f}')

# Load best model weights
model_ft.load_state_dict(best_model_wts)

# --- 8. Save the Fine-tuned Model ---
torch.save(model_ft.state_dict(), 'fine_tuned_resnet18_cifar10.pth')
print("\nFine-tuned model saved to fine_tuned_resnet18_cifar10.pth")

Considerations for Transfer Learning and Fine-tuning:

  • Dataset Size and Similarity:
    • Small dataset, similar to pre-training dataset: Use the pre-trained model as a fixed feature extractor. Train only the new head.
    • Small dataset, different from pre-training dataset: Freeze lower layers, fine-tune higher layers, and train a new head.
    • Large dataset, similar to pre-training dataset: Fine-tune the entire model with a small learning rate.
    • Large dataset, different from pre-training dataset: Fine-tune the entire model or train it from scratch; even in the latter case, pre-trained weights often provide a better initialization than random weights.
  • Learning Rate: Fine-tuning usually calls for a smaller learning rate than training from scratch. You can also use different learning rates for different blocks of layers (e.g., lower for pre-trained layers, higher for newly added layers); frozen layers receive no updates at all.
  • Regularization: Be mindful of overfitting, especially when unfreezing more layers.
  • Data Augmentation: Essential for making the most out of your (often limited) target dataset.
  • Pre-trained Model Choice: Select a pre-trained model suitable for your data type (e.g., ResNet, VGG for images; BERT, GPT for text).

This document introduced transfer learning and fine-tuning in PyTorch and showed how to adapt a pre-trained ResNet-18 to a new task with less data and compute than training from scratch.