
Image Segmentation

Semantic, instance, and panoptic segmentation with U-Net, Mask R-CNN, and SAM


While object detection draws bounding boxes around objects, segmentation provides pixel-level understanding of the image.

Types of Segmentation

Semantic Segmentation

Assigns a class label to every pixel in the image. All pixels belonging to "car" get the same label, regardless of which individual car they belong to.
  • Does NOT distinguish between different instances of the same class
  • Use case: autonomous driving (road vs sidewalk vs building), medical imaging
Instance Segmentation

Detects each object instance and provides a pixel mask for each one. Two different cars get different masks.
  • Combines object detection with segmentation
  • Use case: counting objects, robotics, photo editing

Panoptic Segmentation

Combines semantic and instance segmentation:
  • "Stuff" classes (sky, road, grass) get semantic segmentation
  • "Thing" classes (car, person, dog) get instance segmentation
  • Every pixel gets a label — no pixel is left unclassified
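The three output formats can be made concrete with tiny label arrays. This is a hand-built illustration (the "image", masks, and labels are invented for the example, not produced by a real model):

```python
import numpy as np

H, W = 4, 6  # tiny "image" for illustration

# Semantic segmentation: one class id per pixel (0 = road, 1 = car)
semantic = np.zeros((H, W), dtype=np.int64)
semantic[1:3, 0:2] = 1   # first car
semantic[1:3, 4:6] = 1   # second car; same label, instances not separated

# Instance segmentation: one binary mask per detected object
instance_masks = [
    np.zeros((H, W), dtype=bool),
    np.zeros((H, W), dtype=bool),
]
instance_masks[0][1:3, 0:2] = True   # car #1
instance_masks[1][1:3, 4:6] = True   # car #2 gets a separate mask

# Panoptic segmentation: every pixel gets a (class id, instance id) pair;
# "stuff" like road keeps instance id 0
panoptic = np.stack([semantic, np.zeros((H, W), dtype=np.int64)], axis=-1)
panoptic[instance_masks[0], 1] = 1
panoptic[instance_masks[1], 1] = 2

print("semantic unique labels:", np.unique(semantic))   # [0 1]
print("number of instances:", len(instance_masks))      # 2
print("panoptic (class, instance) pairs:",
      {tuple(p) for p in panoptic.reshape(-1, 2).tolist()})
```

Note how semantic output alone cannot tell the two cars apart, while the panoptic encoding labels every pixel and still separates the instances.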
U-Net Architecture

U-Net is the foundational architecture for segmentation, originally developed for biomedical image segmentation (2015).

Encoder (Contracting Path)

  • Standard CNN backbone (e.g., ResNet) that progressively downsamples
  • Each level extracts features at a different scale
  • Captures what is in the image (semantic information)

Decoder (Expanding Path)

  • Uses transposed convolutions (or bilinear upsampling) to increase spatial resolution
  • Gradually recovers the full-resolution output
  • Captures where things are (spatial information)

Skip Connections

The critical innovation: feature maps from the encoder are concatenated with the corresponding decoder feature maps. This provides:
  • High-resolution spatial details from early encoder layers
  • Rich semantic information from deep encoder layers
  • Sharp, precise segmentation boundaries

U-Net Skip Connections vs ResNet Skip Connections

Both architectures use skip connections, but differently:
  • **ResNet**: Skip connections ADD the input to the output (element-wise addition) within the same resolution level. Purpose: ease gradient flow.
  • **U-Net**: Skip connections CONCATENATE encoder features with decoder features across different parts of the network. Purpose: combine spatial detail from the encoder with semantic information from the decoder.

U-Net skips bridge across the encoder-decoder gap. ResNet skips bridge across layers at the same resolution.
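The two flavors of skip connection differ by a single tensor operation. A minimal sketch (toy layer shapes, not taken from either paper's code):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)

# ResNet-style: element-wise ADDITION at the same resolution.
# Channel count is unchanged, so the residual branch must output 64 channels.
residual_branch = nn.Sequential(
    nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1),
)
resnet_out = x + residual_branch(x)
print(tuple(resnet_out.shape))  # (1, 64, 32, 32): channels unchanged

# U-Net-style: CONCATENATION of encoder features with upsampled decoder
# features. Channel counts add up, so the next layer expects 64 + 64 = 128.
decoder_features = torch.randn(1, 64, 32, 32)  # upsampled from a deeper level
unet_out = torch.cat([x, decoder_features], dim=1)
print(tuple(unet_out.shape))  # (1, 128, 32, 32): channels stacked
```

The shape difference is the whole story: addition preserves channel count (cheap, good for gradient flow), while concatenation doubles it (more parameters downstream, but both information sources survive intact).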

Mask R-CNN

Mask R-CNN extends Faster R-CNN by adding a segmentation branch:
  1. Backbone CNN extracts features (typically ResNet + FPN)
  2. Region Proposal Network generates proposals
  3. RoI Align (improved RoI Pooling without quantization) extracts features
  4. Three parallel heads:
     - Classification head (what class?)
     - Box regression head (where exactly?)
     - Mask head (pixel-level mask for each instance)

Key insight: The mask head is a small FCN (Fully Convolutional Network) applied to each RoI independently, predicting a binary mask per class.

Segment Anything Model (SAM)

Meta's SAM (2023) is a foundation model for segmentation:

  • Trained on 1.1 billion masks from 11 million images (SA-1B dataset)
  • Promptable: Accepts points, boxes, or masks as input prompts (text prompting was explored in the paper)
  • Zero-shot: Segments objects it has never seen before
  • Architecture: ViT encoder + prompt encoder + lightweight mask decoder

SAM 2 (2024)

Extends SAM to video with temporal consistency:
  • Tracks and segments objects across video frames
  • Memory mechanism maintains object identity over time
Loss Functions for Segmentation

Pixel-wise Cross-Entropy Loss

Standard classification loss applied to each pixel independently:

$$L = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C} y_{i,c} \log(p_{i,c})$$

  • Simple and effective
  • Problem: doesn't handle class imbalance well (e.g., small tumors in large images)

Dice Loss

Based on the Dice coefficient (F1 score for sets):

$$\text{Dice} = \frac{2|A \cap B|}{|A| + |B|} = \frac{2\sum p_i g_i}{\sum p_i + \sum g_i}$$

$$L_{\text{dice}} = 1 - \text{Dice}$$

  • Handles class imbalance naturally
  • Measures overlap between predicted and ground truth masks
  • Common in medical imaging

Focal Loss

Downweights easy examples to focus learning on hard ones:

$$L_{\text{focal}} = -\alpha_t (1 - p_t)^\gamma \log(p_t)$$

  • γ (gamma) controls how much to downweight easy examples (typically γ = 2)
  • Very effective for extreme class imbalance
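The focal loss formula translates almost directly to code. A sketch for multiclass segmentation (the α-weighting term is omitted here for brevity):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, gamma=2.0):
    """Multiclass focal loss for segmentation.

    logits: [B, C, H, W] raw scores; target: [B, H, W] class indices.
    """
    # log(p_t) for the true class of each pixel
    log_pt = F.log_softmax(logits, dim=1)
    log_pt = log_pt.gather(1, target.unsqueeze(1)).squeeze(1)  # [B, H, W]
    pt = log_pt.exp()

    # (1 - p_t)^gamma downweights well-classified (high p_t) pixels
    loss = -((1 - pt) ** gamma) * log_pt
    return loss.mean()

logits = torch.randn(2, 21, 64, 64)
target = torch.randint(0, 21, (2, 64, 64))
print(f"Focal loss: {focal_loss(logits, target).item():.4f}")

# Sanity check: with gamma = 0 the focal loss reduces to cross-entropy
assert torch.allclose(focal_loss(logits, target, gamma=0.0),
                      F.cross_entropy(logits, target), atol=1e-5)
```

The γ = 0 sanity check is a useful habit: every focal loss implementation should degrade exactly to cross-entropy when the modulating factor is turned off.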
Evaluation Metrics

| Metric | Description |
| --- | --- |
| Pixel Accuracy | % of pixels correctly classified (misleading with imbalance) |
| IoU (per class) | Intersection / Union for each class |
| mIoU | Mean IoU across all classes — the standard metric |
| Dice Score | 2 * intersection / (pred + GT) — equivalent to F1 |
| Boundary F1 | F1 score computed only on boundary pixels |

```python
# ==============================================================
# Simple U-Net implementation in PyTorch
# ==============================================================
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two 3x3 convolutions with BatchNorm and ReLU."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)  # 512 from up + 512 from skip
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # Output
        self.out_conv = nn.Conv2d(64, num_classes, 1)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)               # [B, 64, H, W]
        e2 = self.enc2(self.pool(e1))   # [B, 128, H/2, W/2]
        e3 = self.enc3(self.pool(e2))   # [B, 256, H/4, W/4]
        e4 = self.enc4(self.pool(e3))   # [B, 512, H/8, W/8]

        # Bottleneck
        b = self.bottleneck(self.pool(e4))  # [B, 1024, H/16, W/16]

        # Decoder with skip connections (concatenation)
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return self.out_conv(d1)


# Test the architecture
model = UNet(in_channels=3, num_classes=21)  # 21 classes for Pascal VOC
x = torch.randn(2, 3, 256, 256)
out = model(x)
print(f"Input shape:  {x.shape}")       # [2, 3, 256, 256]
print(f"Output shape: {out.shape}")     # [2, 21, 256, 256]
print(f"Parameters:   {sum(p.numel() for p in model.parameters()):,}")
```
```python
# ==============================================================
# Dice Loss and combined loss implementation
# ==============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    """Dice Loss for binary or multiclass segmentation."""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        # pred: [B, C, H, W] (logits)
        # target: [B, H, W] (class indices)
        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]

        # One-hot encode target
        target_onehot = F.one_hot(target, num_classes)  # [B, H, W, C]
        target_onehot = target_onehot.permute(0, 3, 1, 2).float()  # [B, C, H, W]

        # Compute dice per class
        intersection = (pred * target_onehot).sum(dim=(2, 3))
        union = pred.sum(dim=(2, 3)) + target_onehot.sum(dim=(2, 3))

        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()


class CombinedLoss(nn.Module):
    """Combine Cross-Entropy and Dice Loss."""
    def __init__(self, ce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = DiceLoss()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        return (self.ce_weight * self.ce(pred, target) +
                self.dice_weight * self.dice(pred, target))


# Usage
criterion = CombinedLoss(ce_weight=0.5, dice_weight=0.5)
pred = torch.randn(4, 21, 256, 256)  # predictions for 4 images, 21 classes
target = torch.randint(0, 21, (4, 256, 256))  # ground truth labels
loss = criterion(pred, target)
print(f"Combined loss: {loss.item():.4f}")
```
```python
# ==============================================================
# Using Segment Anything Model (SAM)
# pip install segment-anything
# ==============================================================
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor

# Load SAM model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda" if torch.cuda.is_available() else "cpu")
predictor = SamPredictor(sam)

# Load and set image
image = cv2.imread("example.jpg")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image_rgb)

# Segment with a point prompt
# (x, y) point and label (1 = foreground, 0 = background)
input_point = np.array([[500, 375]])  # click location
input_label = np.array([1])           # foreground

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,  # returns 3 masks at different granularities
)

# Visualize the best mask
best_idx = np.argmax(scores)
best_mask = masks[best_idx]

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
axes[0].imshow(image_rgb)
axes[0].set_title("Original Image")

axes[1].imshow(image_rgb)
axes[1].imshow(best_mask, alpha=0.5, cmap="jet")
axes[1].scatter(*input_point[0], c="red", s=200, marker="*")
axes[1].set_title(f"Best Mask (score: {scores[best_idx]:.3f})")

# Segment with a box prompt
input_box = np.array([100, 100, 600, 500])  # [x1, y1, x2, y2]
masks_box, scores_box, _ = predictor.predict(box=input_box, multimask_output=False)

axes[2].imshow(image_rgb)
axes[2].imshow(masks_box[0], alpha=0.5, cmap="jet")
axes[2].set_title("Box-Prompted Segmentation")

for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
```

mIoU: The Standard Segmentation Metric

mIoU (mean Intersection over Union) is computed as:
  1. For each class, compute IoU = TP / (TP + FP + FN)
  2. Average IoU across all classes

Important: mIoU treats all classes equally regardless of pixel count. A rare class (e.g., bicycle) matters as much as a common class (e.g., road). This makes it a fair metric but can be dominated by hard-to-segment classes.
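These two steps can be sketched via a confusion matrix. A minimal NumPy implementation (the tiny labels below are toy data for the worked example, not a real benchmark):

```python
import numpy as np

def mean_iou(pred, gt, num_classes):
    """Compute per-class IoU and mIoU from integer label arrays."""
    # Confusion matrix: rows = ground-truth class, cols = predicted class
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (gt.ravel(), pred.ravel()), 1)

    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp   # predicted as c but actually another class
    fn = cm.sum(axis=1) - tp   # actually c but predicted as another class

    # IoU = TP / (TP + FP + FN); classes absent from pred AND GT become NaN
    denom = tp + fp + fn
    iou = np.where(denom > 0, tp / np.maximum(denom, 1), np.nan)
    return iou, np.nanmean(iou)

gt   = np.array([[0, 0, 1], [0, 1, 1]])
pred = np.array([[0, 1, 1], [0, 1, 1]])
iou, miou = mean_iou(pred, gt, num_classes=2)
print("per-class IoU:", iou)   # class 0: 2/3, class 1: 3/4
print("mIoU:", miou)           # (2/3 + 3/4) / 2 = 17/24
```

Worked by hand: class 0 has TP = 2, FP = 0, FN = 1 (one road pixel called class 1), so IoU = 2/3; class 1 has TP = 3, FP = 1, FN = 0, so IoU = 3/4. Their unweighted mean is the mIoU, which is exactly why a rare class can swing the score as much as a dominant one.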