Building Models with nn.Module
Every neural network in PyTorch is built by subclassing torch.nn.Module. This base class provides the scaffolding for defining layers, managing parameters, moving to GPU, saving/loading weights, and more.
The nn.Module Pattern
The pattern is always the same:
1. Subclass nn.Module
2. Define layers in __init__
3. Define the forward pass in forward
python
import torch
import torch.nn as nn

class SimpleClassifier(nn.Module):
    """A minimal two-layer MLP classifier (fc -> ReLU -> dropout -> fc)."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()  # Always call super().__init__()!

        # Define layers as attributes so nn.Module registers their parameters
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Define how data flows through the network."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Create an instance
model = SimpleClassifier(input_size=784, hidden_size=256, num_classes=10)

# Forward pass (just call the model like a function!)
dummy_input = torch.randn(32, 784)  # batch of 32, 784 features
output = model(dummy_input)         # calls model.forward() internally
print(output.shape)  # torch.Size([32, 10])
Never call forward() directly
Always use model(input) rather than model.forward(input). Calling the model directly invokes __call__, which runs registered hooks, handles gradient tracking, and then calls forward(). Calling forward() directly skips all of that.
Built-in Layers
PyTorch provides dozens of pre-built layers in torch.nn:
python
import torch.nn as nn

# --- Fully Connected ---
linear = nn.Linear(in_features=128, out_features=64)  # weight: (64, 128), bias: (64,)

# --- Convolutional ---
conv2d = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
# Input: (batch, 3, H, W) -> Output: (batch, 16, H, W) with padding=1

# --- Normalization ---
batch_norm = nn.BatchNorm2d(16)  # For conv layers with 16 channels
layer_norm = nn.LayerNorm(64)    # For FC layers with 64 features

# --- Regularization ---
dropout = nn.Dropout(p=0.5)       # Randomly zeros 50% of elements during training
dropout2d = nn.Dropout2d(p=0.25)  # Drops entire channels (for conv layers)

# --- Embeddings ---
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=256)
# Maps integer indices to dense vectors: input (batch, seq_len) -> (batch, seq_len, 256)

# --- Activations ---
relu = nn.ReLU()
gelu = nn.GELU()
sigmoid = nn.Sigmoid()
softmax = nn.Softmax(dim=-1)

# --- Pooling ---
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)  # Halves spatial dimensions
avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # Global average pooling
Parameter Inspection
nn.Module automatically tracks all parameters (learnable weights):
python
import torch
import torch.nn as nn

model = SimpleClassifier(784, 256, 10)

# Count total parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total_params:,}")          # 203,530
print(f"Trainable params: {trainable_params:,}")  # 203,530

# Inspect named parameters
for name, param in model.named_parameters():
    print(f"{name}: shape={param.shape}, requires_grad={param.requires_grad}")
# fc1.weight: shape=torch.Size([256, 784]), requires_grad=True
# fc1.bias: shape=torch.Size([256]), requires_grad=True
# fc2.weight: shape=torch.Size([10, 256]), requires_grad=True
# fc2.bias: shape=torch.Size([10]), requires_grad=True

# List all sub-modules
for name, module in model.named_modules():
    print(f"{name}: {module.__class__.__name__}")
Custom Layers
You can create reusable building blocks as custom nn.Module subclasses:
python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """A block with a skip connection: LayerNorm(x + MLP(x))."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # Skip connection: add input to output
        return self.norm(x + self.net(x))
class Swish(nn.Module):
    """Custom activation: x * sigmoid(x)"""
    def forward(self, x):
        return x * torch.sigmoid(x)
Saving and Loading Models
PyTorch uses state_dict — a Python dictionary mapping parameter names to tensors:
python
import torch
import torch.nn as nn

model = SimpleClassifier(784, 256, 10)

# --- Save model weights ---
torch.save(model.state_dict(), "model_weights.pth")

# --- Load model weights ---
loaded_model = SimpleClassifier(784, 256, 10)  # Must create the architecture first!
loaded_model.load_state_dict(torch.load("model_weights.pth"))
loaded_model.eval()  # Set to evaluation mode

# --- Save entire model (not recommended for production) ---
torch.save(model, "full_model.pth")
loaded = torch.load("full_model.pth")

# --- Save checkpoint (weights + optimizer + epoch) ---
# The optimizer must exist before we can snapshot its state
optimizer = torch.optim.Adam(model.parameters())
checkpoint = {
    "epoch": 10,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": 0.42,
}
torch.save(checkpoint, "checkpoint.pth")
state_dict is the standard
Always save model.state_dict() rather than the entire model. Saving the entire model uses Python's pickle, which is fragile — it breaks if you rename files, move classes, or change the code structure. state_dict is just a dictionary of tensors and always works.
Model Composition
PyTorch provides containers for building models from parts:
python
import torch
import torch.nn as nn

# --- nn.Sequential: a chain of layers ---
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
output = model(torch.randn(32, 784))  # Just chain them all

# --- nn.ModuleList: a list of modules (for dynamic architectures) ---
class DynamicMLP(nn.Module):
    """MLP whose depth and widths are chosen at construction time."""

    def __init__(self, layer_sizes):
        super().__init__()
        # Must be nn.ModuleList (not a plain list) so parameters are registered
        self.layers = nn.ModuleList([
            nn.Linear(layer_sizes[i], layer_sizes[i + 1])
            for i in range(len(layer_sizes) - 1)
        ])
        self.relu = nn.ReLU()

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = self.relu(layer(x))
        x = self.layers[-1](x)  # No activation on last layer
        return x

# --- nn.ModuleDict: named modules for branching ---
class MultiHeadModel(nn.Module):
    """Shared backbone with per-task output heads selected by name."""

    def __init__(self, input_dim):
        super().__init__()
        self.backbone = nn.Linear(input_dim, 128)
        self.heads = nn.ModuleDict({
            "classification": nn.Linear(128, 10),
            "regression": nn.Linear(128, 1),
        })
    def forward(self, x, task="classification"):
        features = torch.relu(self.backbone(x))
        return self.heads[task](features)
Always use nn.ModuleList, never plain Python lists
If you store layers in a regular Python list (self.layers = [nn.Linear(...), ...]), PyTorch will NOT register them as sub-modules. Their parameters will not appear in model.parameters(), they won't be moved when you call model.to(device), and they won't be saved in state_dict. Always use nn.ModuleList or nn.ModuleDict.