Model Optimization for Edge Deployment
The Edge AI Challenge
Deploying ML models to edge devices — mobile phones, IoT sensors, embedded systems — requires models that are small, fast, and energy-efficient. A state-of-the-art model that runs on a GPU cluster is useless if it cannot run on a phone in real time. This lesson covers the four pillars of model optimization: quantization, pruning, knowledge distillation, and efficient architecture design.
Quantization
Quantization reduces the precision of model weights and activations from 32-bit floating point (FP32) to lower-bit representations (INT8, INT4, or even binary). This reduces model size, memory usage, and inference latency.
Why Quantization Works
Neural networks are robust to noise — small perturbations in weight values have minimal impact on output quality. Quantization exploits this by mapping the continuous FP32 range to a discrete set of lower-precision values.
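The mapping itself is a simple affine transform: pick a scale and zero-point so the observed FP32 range covers the INT8 range. A minimal sketch (NumPy, illustrative only — real frameworks also calibrate per-channel ranges):

```python
import numpy as np

def quantize_int8(x: np.ndarray):
    """Affine-quantize an FP32 array to INT8 with a scale and zero-point."""
    qmin, qmax = -128, 127
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = int(round(qmin - x.min() / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.int8)
    return q, scale, zero_point

def dequantize(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Map INT8 values back to approximate FP32."""
    return (q.astype(np.float32) - zero_point) * scale

weights = np.random.randn(256, 256).astype(np.float32)
q, scale, zp = quantize_int8(weights)
error = np.abs(weights - dequantize(q, scale, zp)).max()
print(f"max reconstruction error: {error:.5f}")  # on the order of the scale, tiny vs. the weight range
```

The worst-case rounding error is about half the scale, which is exactly the "small perturbation" the network shrugs off.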
Quantization Types
| Type | Description | Accuracy Loss | Speed Gain |
|---|---|---|---|
| FP32 (baseline) | Standard 32-bit float | None | 1x |
| FP16 / BF16 | Half precision | Minimal | ~2x |
| INT8 | 8-bit integer | Small (1-2%) | ~2-4x |
| INT4 | 4-bit integer | Moderate (2-5%) | ~4-8x |
| Binary/Ternary | 1-2 bit | Significant | ~10-32x |
Dynamic vs Static Quantization
Dynamic quantization converts weights to INT8 ahead of time and quantizes activations on the fly at inference, with no extra data required. Static quantization also pre-computes activation ranges from a calibration dataset, which typically yields lower latency and better accuracy at the cost of a calibration step. The example below uses the dynamic approach.
```python
import os
import time

import torch
import torch.nn as nn

# --- Define a simple model ---
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim=784, hidden=256, output=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, hidden)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden, output)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)

model = SimpleClassifier()
model.eval()

# --- Dynamic Quantization (easiest) ---
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},        # Quantize all Linear layers
    dtype=torch.qint8,  # Use INT8
)

# --- Compare model sizes ---
def get_model_size(model):
    """Get model size in MB."""
    torch.save(model.state_dict(), "/tmp/model.pt")
    size = os.path.getsize("/tmp/model.pt") / 1e6
    os.remove("/tmp/model.pt")
    return size

original_size = get_model_size(model)
quantized_size = get_model_size(quantized_model)

print(f"Original model size: {original_size:.2f} MB")
print(f"Quantized model size: {quantized_size:.2f} MB")
print(f"Compression ratio: {original_size / quantized_size:.1f}x")

# --- Compare inference speed ---
dummy_input = torch.randn(1, 784)
n_iters = 1000

start = time.time()
for _ in range(n_iters):
    with torch.no_grad():
        model(dummy_input)
original_time = (time.time() - start) / n_iters * 1000

start = time.time()
for _ in range(n_iters):
    with torch.no_grad():
        quantized_model(dummy_input)
quantized_time = (time.time() - start) / n_iters * 1000

print(f"\nOriginal inference: {original_time:.3f} ms")
print(f"Quantized inference: {quantized_time:.3f} ms")
print(f"Speedup: {original_time / quantized_time:.1f}x")
```
Pruning
Pruning removes unnecessary weights or neurons from a model, making it smaller and faster without significant accuracy loss.
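As a sketch, PyTorch's `torch.nn.utils.prune` module covers both variants discussed below; here applied to a stand-in two-layer network (any `nn.Linear` works the same way):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# Unstructured: zero the 50% smallest-magnitude weights in each Linear layer
for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.5)

# Structured: additionally zero 25% of the first layer's output rows by L2 norm
prune.ln_structured(model[0], name="weight", amount=0.25, n=2, dim=0)

# Make pruning permanent (folds the mask into the weight tensor)
for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.remove(module, "weight")

sparsity = (model[0].weight == 0).float().mean().item()
print(f"first-layer sparsity: {sparsity:.0%}")
```

Note that zeroed weights only translate into speedups when the runtime exploits sparsity (or, for structured pruning, when the pruned rows are physically removed from the tensor).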
Unstructured Pruning
Sets individual weights to zero based on magnitude (the smallest-magnitude weights are assumed least important). Creates sparse weight matrices, which need sparse-aware hardware or kernels to realize the speedup.
Structured Pruning
Removes entire filters, channels, or layers rather than individual weights. Produces dense, smaller models that run faster on standard hardware.
Knowledge Distillation
Knowledge distillation trains a small "student" model to mimic a large "teacher" model. The student learns from the teacher's soft probability outputs (which contain richer information than hard labels).
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Knowledge Distillation ---

class TeacherModel(nn.Module):
    """Large, accurate model (e.g., ResNet-152)."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class StudentModel(nn.Module):
    """Small, fast model for edge deployment."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=4.0, alpha=0.7):
    """Combined distillation and classification loss.

    Args:
        student_logits: Raw output from student model
        teacher_logits: Raw output from teacher model
        labels: True class labels
        temperature: Softmax temperature (higher = softer distributions)
        alpha: Weight for distillation loss vs classification loss
    """
    # Soft targets from teacher (with temperature)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    soft_student = F.log_softmax(student_logits / temperature, dim=1)

    # KL divergence between student and teacher soft outputs
    distill_loss = F.kl_div(
        soft_student, soft_teacher, reduction="batchmean"
    ) * (temperature ** 2)

    # Standard cross-entropy with true labels
    hard_loss = F.cross_entropy(student_logits, labels)

    # Weighted combination
    return alpha * distill_loss + (1 - alpha) * hard_loss


# --- Training loop (simplified) ---
teacher = TeacherModel()
student = StudentModel()
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

# Simulate training
teacher.eval()  # Teacher is frozen
student.train()

for epoch in range(5):
    # Simulated batch
    x = torch.randn(32, 784)
    labels = torch.randint(0, 10, (32,))

    with torch.no_grad():
        teacher_logits = teacher(x)

    student_logits = student(x)
    loss = distillation_loss(student_logits, teacher_logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}: loss = {loss.item():.4f}")

# Compare sizes
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"\nTeacher parameters: {teacher_params:,}")
print(f"Student parameters: {student_params:,}")
print(f"Compression: {teacher_params / student_params:.1f}x")
```
Neural Architecture Search (NAS) for Efficiency
Rather than manually designing efficient architectures, NAS automates the search for architectures that trade off accuracy against latency, size, and energy on the target hardware.
The key insight: the best architecture depends on the target hardware. A model optimized for a GPU is very different from one optimized for a phone CPU or a microcontroller.
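One common way to fold hardware into the search is a latency-penalized reward, as in MnasNet: candidates are scored by accuracy scaled by how far their measured on-device latency strays from a target. A sketch (the two candidate architectures and their numbers are made up for illustration):

```python
def nas_reward(accuracy: float, latency_ms: float,
               target_ms: float = 80.0, w: float = -0.07) -> float:
    """MnasNet-style reward: accuracy times a soft latency penalty."""
    return accuracy * (latency_ms / target_ms) ** w

# Hypothetical candidates: (accuracy, latency measured on the target phone CPU)
candidates = {
    "wide-gpu-style": (0.78, 240.0),  # more accurate, but slow on-device
    "mobile-style": (0.75, 70.0),     # slightly less accurate, much faster
}
best = max(candidates, key=lambda name: nas_reward(*candidates[name]))
print(best)  # prints "mobile-style": the faster net wins once latency is priced in
```

Because the latency term is measured on the actual target device, the same search run on a GPU versus a microcontroller will rank candidates differently, which is exactly the key insight above.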