
Training Loops & Data Loading

Datasets, DataLoaders, the training loop pattern, schedulers, and mixed precision

Training a neural network in PyTorch means writing an explicit training loop. Unlike higher-level frameworks that hide the loop, PyTorch gives you full control. This is both its greatest strength (flexibility) and a common source of bugs (forgetting to zero gradients, forgetting to call .eval(), etc.).

Dataset and DataLoader

PyTorch's data pipeline has two core abstractions:

  • Dataset: Holds your data and defines how to access individual samples.
  • DataLoader: Wraps a Dataset and provides batching, shuffling, and parallel loading.
python
import torch
from torch.utils.data import Dataset, DataLoader

# --- Built-in datasets (torchvision, torchaudio, torchtext) ---
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),                          # PIL Image -> Tensor
    transforms.Normalize((0.1307,), (0.3081,)),     # MNIST mean and std
])

train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

# --- DataLoader: batching, shuffling, parallel workers ---
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,        # Shuffle every epoch (important for training!)
    num_workers=4,       # Parallel data loading processes
    pin_memory=True,     # Faster CPU->GPU transfer
    drop_last=True,      # Drop incomplete last batch
)

# Iterate over batches
for batch_idx, (images, labels) in enumerate(train_loader):
    print(f"Batch {batch_idx}: images={images.shape}, labels={labels.shape}")
    # images: (64, 1, 28, 28), labels: (64,)
    if batch_idx == 2:
        break

Custom Dataset

For your own data, subclass Dataset and implement three methods:

python
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

class CSVDataset(Dataset):
    """Custom dataset that reads from a CSV file."""

    def __init__(self, csv_path, target_column, transform=None):
        self.data = pd.read_csv(csv_path)
        self.target_column = target_column
        self.transform = transform

        # Separate features and target
        self.features = self.data.drop(columns=[target_column]).values
        self.targets = self.data[target_column].values

    def __len__(self):
        """Return the total number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return a single sample (features, target) at the given index."""
        x = torch.tensor(self.features[idx], dtype=torch.float32)
        y = torch.tensor(self.targets[idx], dtype=torch.long)

        if self.transform:
            x = self.transform(x)

        return x, y

# Usage
dataset = CSVDataset("train.csv", target_column="label")
loader = DataLoader(dataset, batch_size=32, shuffle=True)

Data Loading Performance Tips

Use num_workers > 0 (try 4 or 8) for parallel data loading. Set pin_memory=True when training on GPU — it speeds up CPU-to-GPU transfer. Use persistent_workers=True (PyTorch 1.7+) to avoid re-spawning workers each epoch. If your dataset fits in RAM, preload everything in __init__ rather than reading from disk in __getitem__.
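
Putting those tips together, a DataLoader for GPU training might be configured roughly like this (the worker count is illustrative, not a tuned value), reusing train_dataset from above:

python
from torch.utils.data import DataLoader

loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,             # parallel worker processes (try 4 or 8)
    pin_memory=True,           # page-locked memory for faster CPU->GPU copies
    persistent_workers=True,   # keep workers alive across epochs (PyTorch 1.7+)
)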

The Training Loop

Here is the canonical PyTorch training loop, annotated step by step:

python
import torch
import torch.nn as nn
import torch.optim as optim

# Setup (SimpleClassifier is the nn.Module classifier defined earlier)
model = SimpleClassifier(784, 256, 10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ============= TRAINING LOOP =============
num_epochs = 10

for epoch in range(num_epochs):
    model.train()                          # 1. Set training mode (enables dropout, batchnorm training behavior)
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)  # 2. Move data to device

        optimizer.zero_grad()              # 3. Zero all parameter gradients (CRITICAL!)

        outputs = model(inputs.view(inputs.size(0), -1))  # 4. Forward pass
        loss = criterion(outputs, targets)                # 5. Compute loss

        loss.backward()                    # 6. Backward pass (compute gradients)
        optimizer.step()                   # 7. Update parameters

        # Track metrics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    train_loss = running_loss / len(train_loader)
    train_acc = 100.0 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs} — Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")

The Five Steps of a Training Iteration

Every training step follows this pattern:

1. optimizer.zero_grad() — clear accumulated gradients
2. outputs = model(inputs) — forward pass
3. loss = criterion(outputs, targets) — compute loss
4. loss.backward() — compute gradients via backprop
5. optimizer.step() — update parameters using gradients

Forgetting zero_grad() causes gradients to accumulate across batches, leading to incorrect updates.
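
To see why step 1 matters, here is a small standalone sketch (not part of the loop above) showing that .grad buffers are summed across backward() calls until they are cleared:

python
import torch

w = torch.tensor([2.0], requires_grad=True)

loss = (3 * w).sum()
loss.backward()
print(w.grad)      # tensor([3.]) -> gradient from the first backward

loss = (3 * w).sum()
loss.backward()
print(w.grad)      # tensor([6.]) -> gradients were ADDED, not replaced

w.grad.zero_()     # this is what optimizer.zero_grad() does for every parameter
loss = (3 * w).sum()
loss.backward()
print(w.grad)      # tensor([3.]) -> clean gradient again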

Validation Loop

The validation loop is similar but simpler — no gradients needed:

python
def validate(model, val_loader, criterion, device):
    model.eval()                          # Set evaluation mode (disables dropout, uses running stats for batchnorm)
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():                 # Disable gradient computation
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs.view(inputs.size(0), -1))
            loss = criterion(outputs, targets)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    avg_loss = val_loss / len(val_loader)
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy

# Call after each training epoch
val_loss, val_acc = validate(model, val_loader, criterion, device)
print(f"Validation — Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")

model.train() vs model.eval()

These are NOT optional! model.train() enables dropout and uses batch statistics for BatchNorm; model.eval() disables dropout and uses running statistics for BatchNorm. Forgetting model.eval() during validation gives unreliable metrics. Forgetting model.train() before the next training epoch leaves the model in eval mode, so dropout stays disabled while you train.
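
A quick way to see the difference is to push the same input through a dropout layer in both modes (a standalone sketch, independent of the model above):

python
import torch
import torch.nn as nn

layer = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

layer.train()        # training mode: ~half the elements zeroed, survivors scaled by 1/(1-p) = 2
print(layer(x))      # e.g. tensor([[2., 0., 2., 2., 0., 0., 2., 0.]])  (random each call)

layer.eval()         # evaluation mode: dropout is an identity op
print(layer(x))      # tensor([[1., 1., 1., 1., 1., 1., 1., 1.]])  (always)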

Learning Rate Schedulers

Adjusting the learning rate during training often improves convergence:

python
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, OneCycleLR

optimizer = optim.Adam(model.parameters(), lr=1e-3)

# --- StepLR: multiply LR by gamma every step_size epochs ---
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
# LR: 1e-3 -> 1e-4 (epoch 10) -> 1e-5 (epoch 20)

# --- CosineAnnealingLR: smooth cosine decay ---
scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
# LR follows a cosine curve from 1e-3 to 1e-6 over 50 epochs

# --- OneCycleLR: warmup then decay (per-batch scheduler) ---
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
)

# Usage in training loop:
for epoch in range(num_epochs):
    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # For OneCycleLR, step per batch:
        # scheduler.step()

    # For StepLR / CosineAnnealingLR, step per epoch:
    scheduler.step()
    print(f"LR: {scheduler.get_last_lr()[0]:.6f}")

Mixed Precision Training

Mixed precision runs most operations in float16, which speeds up training and reduces GPU memory use:

python
import torch
from torch.cuda.amp import autocast, GradScaler

# Mixed precision: compute forward pass in float16, keep master weights in float32
scaler = GradScaler()

for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()

    # Automatic mixed precision: operations inside autocast run in float16
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Scale loss to prevent float16 underflow, then backward
    scaler.scale(loss).backward()

    # Unscale gradients and step optimizer
    scaler.step(optimizer)
    scaler.update()

Gradient Clipping

Gradient clipping prevents exploding gradients, which are especially common in RNNs and Transformers:

python
import torch.nn.utils as utils

for inputs, targets in train_loader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()

    # Clip gradient norms to max_norm (e.g., 1.0)
    utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

Gradient Clipping Placement

Gradient clipping must happen AFTER loss.backward() (so gradients exist) and BEFORE optimizer.step() (so the update uses clipped gradients). clip_grad_norm_ clips by the total gradient norm across all parameters, which is generally preferred over clip_grad_value_, which clips each gradient element independently.
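
When mixed precision and gradient clipping are used together, gradients must be unscaled before clipping so the threshold is compared against true float32 gradient norms. A sketch of how the two pieces combine, continuing the GradScaler setup from the mixed precision section:

python
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()

    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()

    # Unscale first, then clip the real gradients
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    scaler.step(optimizer)     # skips the update if gradients contain inf/NaN
    scaler.update()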