Skip to main content

Building Models with nn.Module

Create neural networks by subclassing nn.Module — layers, parameters, and model composition

~50 min
Listen to this lesson

Building Models with nn.Module

Every neural network in PyTorch is built by subclassing torch.nn.Module. This base class provides the scaffolding for defining layers, managing parameters, moving to GPU, saving/loading weights, and more.

The nn.Module Pattern

The pattern is always the same:

1. Subclass nn.Module
2. Define the layers in __init__
3. Define the forward pass in forward

python
import torch
import torch.nn as nn


class SimpleClassifier(nn.Module):
    """A minimal two-layer MLP classifier.

    Data flow: Linear -> ReLU -> Dropout -> Linear (raw logits out).
    """

    def __init__(self, input_size, hidden_size, num_classes):
        # The parent constructor must run first so that layer assignment
        # below registers parameters correctly.
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map a (batch, input_size) tensor to (batch, num_classes) logits."""
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(hidden)


# Instantiate the model.
model = SimpleClassifier(input_size=784, hidden_size=256, num_classes=10)

# Forward pass: call the instance like a function. This routes through
# __call__ (hooks, etc.) and then invokes forward() for us.
batch = torch.randn(32, 784)  # batch of 32 samples, 784 features each
logits = model(batch)
print(logits.shape)  # torch.Size([32, 10])

Never call forward() directly

Always use model(input) rather than model.forward(input). Calling the model directly invokes __call__, which runs registered hooks, handles gradient tracking, and then calls forward(). Calling forward() directly skips all of that.

Built-in Layers

PyTorch provides dozens of pre-built layers in torch.nn:

python
import torch.nn as nn

# --- Fully Connected ---
linear = nn.Linear(in_features=128, out_features=64)  # weight: (64, 128), bias: (64,)

# --- Convolutional ---
conv2d = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
# Input: (batch, 3, H, W) -> Output: (batch, 16, H, W) with padding=1

# --- Normalization ---
batch_norm = nn.BatchNorm2d(16)        # For conv layers with 16 channels
layer_norm = nn.LayerNorm(64)          # For FC layers with 64 features

# --- Regularization (active only in train() mode; no-op in eval() mode) ---
dropout = nn.Dropout(p=0.5)           # Randomly zeros 50% of elements during training
dropout2d = nn.Dropout2d(p=0.25)      # Drops entire channels (for conv layers)

# --- Embeddings ---
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=256)
# Maps integer indices to dense vectors: input (batch, seq_len) -> (batch, seq_len, 256)

# --- Activations (stateless; safe to reuse one instance in several places) ---
relu = nn.ReLU()
gelu = nn.GELU()
sigmoid = nn.Sigmoid()
softmax = nn.Softmax(dim=-1)

# --- Pooling ---
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)  # Halves spatial dimensions
avg_pool = nn.AdaptiveAvgPool2d((1, 1))            # Global average pooling

Parameter Inspection

nn.Module automatically tracks all parameters (learnable weights):

python
import torch
import torch.nn as nn

model = SimpleClassifier(784, 256, 10)

# Count parameters: every tensor registered on the module (and all of its
# children) is yielded by .parameters().
total_params = sum(weight.numel() for weight in model.parameters())
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
print(f"Total params: {total_params:,}")        # 203,530
print(f"Trainable params: {trainable_params:,}")  # 203,530

# Walk the parameters together with their dotted names.
for param_name, tensor in model.named_parameters():
    print(f"{param_name}: shape={tensor.shape}, requires_grad={tensor.requires_grad}")
# fc1.weight: shape=torch.Size([256, 784]), requires_grad=True
# fc1.bias: shape=torch.Size([256]), requires_grad=True
# fc2.weight: shape=torch.Size([10, 256]), requires_grad=True
# fc2.bias: shape=torch.Size([10]), requires_grad=True

# Walk the module tree itself (the root module has the empty name "").
for module_name, submodule in model.named_modules():
    print(f"{module_name}: {submodule.__class__.__name__}")

Custom Layers

You can create reusable building blocks as custom nn.Module subclasses:

python
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
class ResidualBlock(nn.Module):
    """A two-layer MLP block with a skip connection and post-norm.

    Output: LayerNorm(x + MLP(x)); input and output share shape (..., dim).
    """

    def __init__(self, dim):
        super().__init__()
        # Two Linear layers with a ReLU between them, same width throughout.
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the MLP, add the skip path, then normalize."""
        residual = self.net(x)
        return self.norm(residual + x)
19
class Swish(nn.Module):
    """Swish (a.k.a. SiLU) activation: f(x) = x * sigmoid(x).

    Stateless — no parameters, so only forward() needs defining.
    """

    def forward(self, x):
        # Same as x * torch.sigmoid(x), written with tensor methods.
        return x.mul(x.sigmoid())

Saving and Loading Models

PyTorch uses state_dict — a Python dictionary mapping parameter names to tensors:

python
import torch
import torch.nn as nn

model = SimpleClassifier(784, 256, 10)

# --- Save model weights ---
# state_dict() is a plain dict mapping parameter names to tensors.
torch.save(model.state_dict(), "model_weights.pth")

# --- Load model weights ---
loaded_model = SimpleClassifier(784, 256, 10)  # Must create the architecture first!
# weights_only=True restricts unpickling to tensors and simple containers —
# the safe way to load a bare state_dict (and the default since PyTorch 2.6).
loaded_model.load_state_dict(torch.load("model_weights.pth", weights_only=True))
loaded_model.eval()  # Set to evaluation mode (disables dropout, etc.)

# --- Save entire model (not recommended for production) ---
torch.save(model, "full_model.pth")
# Loading a pickled model object needs full (arbitrary-code) unpickling, so
# weights_only=False must be explicit on PyTorch >= 2.6. Only do this with
# files you created yourself.
loaded = torch.load("full_model.pth", weights_only=False)

# --- Save checkpoint (weights + optimizer + epoch) ---
optimizer = torch.optim.Adam(model.parameters())  # the optimizer being checkpointed
checkpoint = {
    "epoch": 10,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": 0.42,
}
torch.save(checkpoint, "checkpoint.pth")

state_dict is the standard

Always save model.state_dict() rather than the entire model. Saving the entire model uses Python's pickle, which is fragile — it breaks if you rename files, move classes, or change the code structure. state_dict is just a dictionary of tensors and always works.

Model Composition

PyTorch provides containers for building models from parts:

python
import torch  # fix: needed for torch.randn and torch.relu below
import torch.nn as nn

# --- nn.Sequential: a chain of layers ---
# Each module runs in order; no explicit forward() is needed.
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
output = model(torch.randn(32, 784))  # Just chain them all

# --- nn.ModuleList: a list of modules (for dynamic architectures) ---
class DynamicMLP(nn.Module):
    """An MLP whose depth and widths are chosen at construction time.

    layer_sizes: e.g. [784, 256, 10] builds two Linear layers.
    """
    def __init__(self, layer_sizes):
        super().__init__()
        # nn.ModuleList — unlike a plain Python list — registers each
        # layer's parameters with this parent module.
        self.layers = nn.ModuleList([
            nn.Linear(layer_sizes[i], layer_sizes[i + 1])
            for i in range(len(layer_sizes) - 1)
        ])
        self.relu = nn.ReLU()

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = self.relu(layer(x))
        x = self.layers[-1](x)  # No activation on last layer
        return x

# --- nn.ModuleDict: named modules for branching ---
class MultiHeadModel(nn.Module):
    """A shared backbone with one output head per task name."""
    def __init__(self, input_dim):
        super().__init__()
        self.backbone = nn.Linear(input_dim, 128)
        self.heads = nn.ModuleDict({
            "classification": nn.Linear(128, 10),
            "regression": nn.Linear(128, 1),
        })

    def forward(self, x, task="classification"):
        # Raises KeyError for an unknown task name.
        features = torch.relu(self.backbone(x))
        return self.heads[task](features)

Always use nn.ModuleList, never plain Python lists

If you store layers in a regular Python list (self.layers = [nn.Linear(...), ...]), PyTorch will NOT register them as sub-modules. Their parameters will not appear in model.parameters(), they won't be moved when you call model.to(device), and they won't be saved in state_dict. Always use nn.ModuleList or nn.ModuleDict.