
PyTorch Lightning & Modern Tooling

Reduce boilerplate with Lightning, scale to multiple GPUs, and export models with ONNX


Raw PyTorch gives you full control, but the training loop boilerplate can become repetitive and error-prone. PyTorch Lightning is a lightweight wrapper that handles the engineering while you focus on the science.

Why Lightning?

Lightning eliminates boilerplate for:

  • Training/validation/test loops
  • Device management (CPU, GPU, multi-GPU, TPU)
  • Mixed precision training
  • Gradient accumulation and clipping
  • Logging (TensorBoard, W&B, CSV)
  • Checkpointing and early stopping
  • Distributed training across multiple GPUs/nodes

You still write pure PyTorch; Lightning just organizes it.

LightningModule

Instead of a raw training loop, you define a LightningModule that encapsulates model + training logic:

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from torchmetrics import Accuracy

class MNISTClassifier(L.LightningModule):
    def __init__(self, hidden_size=256, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()  # Saves all __init__ args to self.hparams

        # Model architecture (same as before)
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, 10),
        )

        # Metrics
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """Used for inference: model(x)"""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Replaces the inner training loop."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # Log metrics (automatically handles accumulation and logging)
        preds = logits.argmax(dim=1)
        self.train_acc(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True)

        return loss  # Lightning handles backward() and optimizer.step()

    def validation_step(self, batch, batch_idx):
        """Replaces the validation loop."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = logits.argmax(dim=1)
        self.val_acc(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def configure_optimizers(self):
        """Define optimizer and (optional) scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]

What LightningModule Replaces

training_step replaces the inner training loop body (forward + loss + logging). validation_step replaces the inner validation loop body. configure_optimizers replaces manual optimizer and scheduler creation. Lightning automatically handles: zero_grad, backward, optimizer.step, scheduler.step, model.train()/eval(), torch.no_grad(), device transfers, and metric aggregation.
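
For comparison, here is roughly the raw PyTorch skeleton those hooks stand in for (a minimal sketch; it assumes model, optimizer, scheduler, train_loader, val_loader, and device already exist):

python
import torch
import torch.nn.functional as F

# Roughly the manual loop that Lightning automates (minimal sketch).
# Assumes model, optimizer, scheduler, train_loader, val_loader, device exist.
model.train()
for x, y in train_loader:
    x, y = x.to(device), y.to(device)        # device transfer
    optimizer.zero_grad()                    # zero_grad
    loss = F.cross_entropy(model(x), y)      # body of training_step
    loss.backward()                          # backward
    optimizer.step()                         # optimizer.step
scheduler.step()                             # scheduler.step (once per epoch)

model.eval()
with torch.no_grad():                        # eval mode + no_grad
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        val_loss = F.cross_entropy(model(x), y)  # body of validation_step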

The Trainer

The Trainer handles everything outside the model: hardware, logging, callbacks, and distributed training.

python
import lightning as L
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    RichProgressBar,
)
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

# --- Callbacks ---
checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=3,               # Keep top 3 models
    filename="{epoch}-{val_loss:.2f}",
)

early_stop_cb = EarlyStopping(
    monitor="val_loss",
    patience=5,                 # Stop if val_loss doesn't improve for 5 epochs
    mode="min",
)

lr_monitor = LearningRateMonitor(logging_interval="epoch")

# --- Logger ---
tb_logger = TensorBoardLogger("logs/", name="mnist")
# wandb_logger = WandbLogger(project="mnist", name="run-01")

# --- Create Trainer ---
trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",         # Automatically use GPU/MPS/CPU
    devices="auto",             # Use all available GPUs
    precision="16-mixed",       # Mixed precision training
    callbacks=[checkpoint_cb, early_stop_cb, lr_monitor],
    logger=tb_logger,
    gradient_clip_val=1.0,      # Gradient clipping
    accumulate_grad_batches=4,  # Simulate 4x larger batch size
    log_every_n_steps=10,
)

# --- Train ---
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_ds = datasets.MNIST("./data", train=True, download=True, transform=transform)
val_ds = datasets.MNIST("./data", train=False, transform=transform)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=256, num_workers=4)

model = MNISTClassifier(hidden_size=256, lr=1e-3)
trainer.fit(model, train_loader, val_loader)

# --- Test ---
trainer.test(model, val_loader)

# --- Load best checkpoint ---
best_model = MNISTClassifier.load_from_checkpoint(checkpoint_cb.best_model_path)

LightningDataModule

For production, wrap your data pipeline in a LightningDataModule. It encapsulates prepare_data() (downloads), setup() (splits), and train/val/test_dataloader() methods. This keeps data logic separate from model logic and makes your code more reusable.
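
Here is a minimal sketch of what that can look like for MNIST (the class name MNISTDataModule and the 55,000/5,000 train/val split are illustrative choices, not something Lightning requires):

python
import lightning as L
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def prepare_data(self):
        # Download only (called once, not on every GPU process); assign no state here
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Create splits (called on every process)
        if stage in (None, "fit"):
            full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.train_ds, self.val_ds = random_split(full, [55000, 5000])
        if stage in (None, "test"):
            self.test_ds = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=256, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=256, num_workers=4)

# Usage: the Trainer calls the hooks for you
# trainer.fit(model, datamodule=MNISTDataModule())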

Lightning Fabric

If you want some Lightning benefits without the full LightningModule structure, Fabric is a lighter-weight alternative:

python
import lightning as L
import torch
import torch.nn as nn

# Fabric: keeps your raw training loop but handles hardware
fabric = L.Fabric(accelerator="auto", devices="auto", precision="16-mixed")
fabric.launch()

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Fabric wraps model, optimizer, and dataloaders
# (train_loader is an ordinary DataLoader, e.g. the MNIST loader defined earlier)
model, optimizer = fabric.setup(model, optimizer)
train_loader = fabric.setup_dataloaders(train_loader)

# Your normal training loop, but it works on any hardware
for epoch in range(10):
    model.train()
    for batch in train_loader:
        images, labels = batch
        optimizer.zero_grad()
        output = model(images)
        loss = nn.functional.cross_entropy(output, labels)
        fabric.backward(loss)   # Use fabric.backward instead of loss.backward
        optimizer.step()

TorchMetrics

Lightning integrates with TorchMetrics for metric computation that works correctly across distributed training:

python
import lightning as L
import torch.nn.functional as F
from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC
from torchmetrics import MetricCollection

# Individual metrics
accuracy = Accuracy(task="multiclass", num_classes=10)
precision = Precision(task="multiclass", num_classes=10, average="macro")
recall = Recall(task="multiclass", num_classes=10, average="macro")
f1 = F1Score(task="multiclass", num_classes=10, average="macro")

# MetricCollection: compute all at once
metrics = MetricCollection({
    "accuracy": Accuracy(task="multiclass", num_classes=10),
    "precision": Precision(task="multiclass", num_classes=10, average="macro"),
    "recall": Recall(task="multiclass", num_classes=10, average="macro"),
    "f1": F1Score(task="multiclass", num_classes=10, average="macro"),
})

# Usage in a LightningModule
class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.train_metrics(logits.argmax(dim=1), y)
        self.log_dict(self.train_metrics, prog_bar=True)
        return loss

ONNX Export

Export your PyTorch model to ONNX (Open Neural Network Exchange) for deployment in other runtimes:

python
import torch
import torch.onnx

# In practice, load trained weights first (e.g. from a Lightning checkpoint)
model = MNISTClassifier(hidden_size=256)
model.eval()

# Create dummy input with the correct shape
dummy_input = torch.randn(1, 1, 28, 28)

# Export to ONNX
torch.onnx.export(
    model,                       # Model to export
    dummy_input,                 # Example input
    "mnist_model.onnx",          # Output file
    export_params=True,          # Store trained weights
    opset_version=17,            # ONNX opset version
    input_names=["image"],       # Input tensor name
    output_names=["logits"],     # Output tensor name
    dynamic_axes={               # Variable-length axes (for batching)
        "image": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
)

# Verify the exported model
import onnx
onnx_model = onnx.load("mnist_model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")

# Run inference with ONNX Runtime
import onnxruntime as ort
session = ort.InferenceSession("mnist_model.onnx")
input_data = dummy_input.numpy()
result = session.run(None, {"image": input_data})
print(f"ONNX output shape: {result[0].shape}")

Why Export to ONNX?

ONNX is a universal model format supported by many inference runtimes (ONNX Runtime, TensorRT, CoreML, OpenVINO). Exporting to ONNX lets you: (1) deploy models without a PyTorch dependency, (2) run on specialized hardware accelerators, (3) use language-specific runtimes (C++, Java, C#), and (4) get inference speedups from optimized runtimes. ONNX Runtime alone is often reported to give 2-3x speedups over native PyTorch inference, though the gain depends on the model and hardware.
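
Before deploying, it is worth sanity-checking that ONNX Runtime produces the same outputs as PyTorch on the same input. A small sketch, reusing model, dummy_input, and session from the export example above:

python
import numpy as np
import torch

# Compare PyTorch and ONNX Runtime on the same input
# (model, dummy_input, and session come from the export example above)
model.eval()
with torch.no_grad():
    torch_out = model(dummy_input).numpy()

ort_out = session.run(None, {"image": dummy_input.numpy()})[0]

# Small numerical differences are expected (graph optimizations, precision),
# so compare within a tolerance rather than exactly
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5)
print("PyTorch and ONNX Runtime outputs match within tolerance")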

Integration with W&B (Weights & Biases)

Weights & Biases provides experiment tracking, visualization, and collaboration:

python
# pip install wandb
import wandb
import lightning as L
from lightning.pytorch.loggers import WandbLogger

# Option 1: Lightning integration
wandb_logger = WandbLogger(
    project="my-project",
    name="experiment-01",
    log_model=True,         # Log model checkpoints to W&B
)

trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=20,
)

# Option 2: Raw PyTorch integration
wandb.init(project="my-project", config={"lr": 1e-3, "epochs": 20})
for epoch in range(20):
    train_loss = train_one_epoch(...)
    val_loss = evaluate(...)
    wandb.log({
        "train_loss": train_loss,
        "val_loss": val_loss,
        "epoch": epoch,
    })
wandb.finish()