
PyTorch nn.Module: Building Neural Networks

In PyTorch, the torch.nn module is the core component for building neural networks. The nn.Module class is the base class for all neural network modules. Your models should subclass this class.

Key Concepts of nn.Module:

  • Layers: Each layer (e.g., convolution, linear, activation) in your neural network is typically an instance of nn.Module.
  • Parameters: nn.Module automatically registers every torch.nn.Parameter (such as weights and biases) assigned as an attribute, including those belonging to its submodules, and exposes them through methods like parameters() and named_parameters().
  • forward() method: Every nn.Module subclass must override the forward() method, which defines how input data flows through the layers to produce an output. Call the model as model(x) rather than model.forward(x) so that any registered hooks also run.
  • Hooks: Let you register functions that run before or after the forward pass, or when gradients are computed during the backward pass.
  • Moving to Device: Provides methods like .to(device) to easily move all model parameters and buffers to a GPU or CPU.
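To make the registration, hook, and device rules above concrete, here is a minimal sketch. The class ScaledLinear and the helper record_output_shape are hypothetical names used only for illustration; register_buffer, register_forward_hook, and .to() are standard nn.Module methods. It shows which attributes end up tracked and which are silently ignored.

```python
import torch
import torch.nn as nn


class ScaledLinear(nn.Module):
    """Hypothetical module, used only to show what gets registered."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # A submodule: its weight and bias are registered automatically.
        self.linear = nn.Linear(in_features, out_features)
        # An nn.Parameter attribute: registered as a trainable parameter.
        self.scale = nn.Parameter(torch.ones(out_features))
        # A buffer: saved in state_dict and moved by .to(), but not trained.
        self.register_buffer("offset", torch.zeros(out_features))
        # A plain tensor attribute: NOT registered, so .to() and state_dict ignore it.
        self.untracked = torch.zeros(out_features)

    def forward(self, x):
        return self.linear(x) * self.scale + self.offset


module = ScaledLinear(4, 3)
print(sorted(name for name, _ in module.named_parameters()))
# ['linear.bias', 'linear.weight', 'scale']
print(sorted(module.state_dict().keys()))
# ['linear.bias', 'linear.weight', 'offset', 'scale']

# A forward hook runs after forward() and receives (module, inputs, output).
seen_shapes = []


def record_output_shape(mod, inputs, output):
    seen_shapes.append(tuple(output.shape))


handle = module.register_forward_hook(record_output_shape)
module(torch.randn(2, 4))  # calling the module (not .forward) triggers hooks
handle.remove()            # detach the hook once it is no longer needed
print(seen_shapes)  # [(2, 3)]

# .to() moves registered parameters and buffers, but not plain tensor attributes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module.to(device)
print(module.offset.device.type == device.type, module.untracked.device.type)  # True cpu
```

The untracked attribute is the common pitfall: a tensor stored without nn.Parameter or register_buffer stays on the CPU after .to(device) and is missing from saved checkpoints.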

Building a Simple Feedforward Neural Network

Let's construct a simple neural network for classification.

Example: Basic Classification Network

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1. Define the Neural Network Class
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNN, self).__init__() # Call the parent constructor
        self.fc1 = nn.Linear(input_size, hidden_size)  # First fully connected layer
        self.relu = nn.ReLU()                         # Activation function
        self.fc2 = nn.Linear(hidden_size, num_classes) # Second fully connected layer (output layer)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# 2. Instantiate the Model
input_size = 784  # For example, if input is flattened 28x28 images
hidden_size = 128
num_classes = 10  # For example, 10 digit classification (0-9)

model = SimpleNN(input_size, hidden_size, num_classes)
print("Model Architecture:")
print(model)

# 3. Inspect Model Parameters
print("\nModel Parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}, Size: {param.size()}")

# 4. Perform a Forward Pass (dummy data)
dummy_input = torch.randn(64, input_size) # Batch size of 64
output = model(dummy_input)
print(f"\nOutput shape: {output.shape}") # Expected: (batch_size, num_classes)
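As a sanity check on the parameters listed above, the count can be derived by hand: each nn.Linear(in, out) stores an out × in weight matrix plus out biases. The sketch below redefines SimpleNN so it runs on its own; count_parameters is a hypothetical helper, not a PyTorch API. It also confirms that the nn.ReLU module used here and the functional F.relu (imported but unused in the example) compute the same values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


def count_parameters(model):
    # Hypothetical helper: sums the element counts of all trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = SimpleNN(784, 128, 10)
fc1_params = 784 * 128 + 128  # weight (128 x 784) + bias (128) = 100,480
fc2_params = 128 * 10 + 10    # weight (10 x 128) + bias (10)  = 1,290
print(count_parameters(model))  # 101770

# nn.ReLU (a module) and F.relu (a function) give identical results;
# the module form simply appears in print(model).
x = torch.randn(5, 128)
print(torch.equal(nn.ReLU()(x), F.relu(x)))  # True
```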

Advanced Network Architectures

Example: Convolutional Neural Network (CNN)

For tasks like image classification, CNNs are highly effective.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        # Convolutional Layer 1
        # Input: (batch_size, 1, 28, 28) for grayscale images like MNIST
        # Output: (batch_size, 32, 26, 26) after conv, (batch_size, 32, 13, 13) after pooling
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=0)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Convolutional Layer 2
        # Input: (batch_size, 32, 13, 13)
        # Output: (batch_size, 64, 11, 11) after conv, (batch_size, 64, 5, 5) after pooling
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=0)

        # Fully Connected Layer 1
        # The input features for fc1 depend on the output size of the last pooling layer.
        # For a 28x28 input image:
        # After conv1: (28 - 3 + 1) = 26
        # After pool1: 26 / 2 = 13
        # After conv2: (13 - 3 + 1) = 11
        # After pool2: 11 / 2 = 5 (floor division)
        # So, 64 channels * 5 * 5 features
        self.fc1 = nn.Linear(64 * 5 * 5, 128)

        # Output Layer
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # -> conv1 -> ReLU -> pool
        x = self.pool(F.relu(self.conv1(x)))
        # -> conv2 -> ReLU -> pool
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the output for the fully connected layers
        x = x.view(-1, 64 * 5 * 5) # -1 infers batch size
        # -> fc1 -> ReLU
        x = F.relu(self.fc1(x))
        # -> fc2 (output)
        x = self.fc2(x)
        return x

# Instantiate the CNN model
cnn_model = ConvNet(num_classes=10)
print("\nCNN Model Architecture:")
print(cnn_model)

# Dummy input for the CNN: a batch containing one single-channel 28x28 image
dummy_cnn_input = torch.randn(1, 1, 28, 28)
cnn_output = cnn_model(dummy_cnn_input)
print(f"\nCNN Output shape: {cnn_output.shape}")
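The hand calculation of 64 * 5 * 5 in the ConvNet comments follows a general rule: with dilation 1 (and ceil_mode=False for pooling), each spatial dimension becomes floor((size + 2 * padding - kernel_size) / stride) + 1. The sketch below applies that rule step by step and then checks it empirically by passing a dummy batch through the same convolution and pooling stack. conv_output_size is a hypothetical helper, and rebuilding the layers with nn.Sequential is just a convenient way to reuse them here.

```python
import torch
import torch.nn as nn


def conv_output_size(size, kernel_size, stride=1, padding=0):
    # Hypothetical helper: one spatial dimension after nn.Conv2d / nn.MaxPool2d,
    # assuming dilation=1 and ceil_mode=False.
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28
size = conv_output_size(size, 3)            # conv1 -> 26
size = conv_output_size(size, 2, stride=2)  # pool  -> 13
size = conv_output_size(size, 3)            # conv2 -> 11
size = conv_output_size(size, 2, stride=2)  # pool  -> 5
print(64 * size * size)  # 1600, i.e. 64 * 5 * 5

# Empirical check with the same layer configuration as ConvNet.
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2, 2),
)
with torch.no_grad():
    out = features(torch.randn(8, 1, 28, 28))
print(out.shape)                    # torch.Size([8, 64, 5, 5])
print(torch.flatten(out, 1).shape)  # torch.Size([8, 1600])
```

torch.flatten(x, 1) keeps the batch dimension and flattens everything else, so it avoids hardcoding the feature count the way x.view(-1, 64 * 5 * 5) does; fc1's in_features still has to match, however.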

Common nn.Module Layers:

  • nn.Linear: Applies a linear transformation to the incoming data (fully connected layer).
  • nn.Conv1d, nn.Conv2d, nn.Conv3d: Convolutional layers for 1D, 2D, and 3D data.
  • nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d: Max pooling layers.
  • nn.ReLU, nn.Sigmoid, nn.Tanh, nn.Softmax: Activation functions.
  • nn.Dropout: During training, randomly zeroes each input element with probability p and scales the remaining elements by 1/(1-p); in evaluation mode it passes inputs through unchanged.
  • nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d: Batch Normalization layers.
  • nn.LSTM, nn.GRU, nn.RNN: Recurrent neural network layers.
  • nn.Embedding: A lookup table that maps integer indices (e.g., token IDs) to learnable fixed-size vectors.
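Some of these layers depend on the module's mode or expect a particular input type. The short sketch below illustrates three of them with arbitrary sizes chosen only for illustration: nn.Dropout in train versus eval mode, nn.Embedding looking up integer indices, and nn.LSTM consuming the resulting sequence (with batch_first=True, which is not the default).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# nn.Dropout: zeroes elements with probability p in training mode and scales
# survivors by 1 / (1 - p); in eval mode it is the identity.
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()
y_train = dropout(x)
print(torch.unique(y_train))  # only 0.0 and 2.0 appear

dropout.eval()
print(torch.equal(dropout(x), x))  # True

# nn.Embedding: indexes rows of a learnable (num_embeddings x embedding_dim) table.
embedding = nn.Embedding(num_embeddings=100, embedding_dim=8)
token_ids = torch.tensor([[1, 5, 7], [2, 2, 99]])  # (batch=2, seq_len=3)
vectors = embedding(token_ids)
print(vectors.shape)  # torch.Size([2, 3, 8])
print(torch.equal(vectors[1, 0], vectors[1, 1]))  # True: same index, same row

# nn.LSTM returns (output, (h_n, c_n)); batch_first=True means (batch, seq, feature).
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
output, (h_n, c_n) = lstm(vectors)
print(output.shape)  # torch.Size([2, 3, 16])
print(h_n.shape)     # torch.Size([1, 2, 16]): (num_layers, batch, hidden_size)
```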

Saving and Loading Models

PyTorch models are saved and loaded with torch.save and torch.load. The recommended approach is to save only the model's state_dict, a dictionary mapping each parameter and buffer name to its tensor, rather than pickling the entire model object.

# Save the state_dict
torch.save(model.state_dict(), 'simple_nn_model.pth')

# Load the state_dict
loaded_model = SimpleNN(input_size, hidden_size, num_classes) # Re-instantiate the model
loaded_model.load_state_dict(torch.load('simple_nn_model.pth'))
loaded_model.eval() # Set evaluation mode: disables dropout and makes batch norm use its running statistics

print("\nLoaded Model:")
print(loaded_model)
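To confirm that a state_dict round trip restores the learned values exactly, the sketch below saves to an in-memory io.BytesIO buffer (used here only to avoid writing a file; a path works the same way) and compares every tensor. map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine, and load_state_dict, which is strict by default, reports any missing or unexpected keys.

```python
import io

import torch
import torch.nn as nn


class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = SimpleNN(784, 128, 10)

buffer = io.BytesIO()  # stands in for a .pth file on disk
torch.save(model.state_dict(), buffer)
buffer.seek(0)

loaded_model = SimpleNN(784, 128, 10)  # fresh, randomly initialized weights
state = torch.load(buffer, map_location="cpu")
result = loaded_model.load_state_dict(state)  # strict=True by default
loaded_model.eval()

print(result.missing_keys, result.unexpected_keys)  # [] []
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), loaded_model.state_dict().values())
)
print(same)  # True

# Same weights and eval mode on both models -> matching predictions.
x = torch.randn(4, 784)
model.eval()
with torch.no_grad():
    print(torch.allclose(model(x), loaded_model(x)))  # True
```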

Further Topics:

  • Loss Functions (nn.CrossEntropyLoss, nn.MSELoss)
  • Optimizers (torch.optim.Adam, torch.optim.SGD)
  • DataLoaders and Datasets (torch.utils.data)
  • Training Loops and Evaluation
  • Transfer Learning
  • Custom Layers and Modules

This document provides a foundational understanding of nn.Module and how to build neural networks in PyTorch. The next steps involve understanding how to train these models effectively.