Q-Learning & Deep Q-Networks

From Q-tables to deep neural network approximators


Q-Learning is one of the most important RL algorithms. It learns the Q-function (action-value function) directly, without needing a model of the environment.

Q-Table Learning

For environments with small, discrete state and action spaces, we can represent Q(s, a) as a table:

| State | Left | Right |
|-------|------|-------|
| s0    | 0.5  | 1.2   |
| s1    | 2.1  | 0.3   |
| s2    | 0.0  | 3.5   |

The Q-Learning Update Rule

After taking action *a* in state *s*, observing reward *r* and next state *s'*:

**Q(s, a) <- Q(s, a) + alpha * [r + gamma * max_a' Q(s', a') - Q(s, a)]**

Where:

  • alpha is the learning rate
  • gamma is the discount factor
  • r + gamma * max_a' Q(s', a') is the TD target
  • r + gamma * max_a' Q(s', a') - Q(s, a) is the TD error
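As a concrete check of the update rule, here is one update applied to the Q-table from the example above. The values alpha = 0.1, gamma = 0.9, and the sampled transition are chosen here purely for illustration:

```python
import numpy as np

alpha, gamma = 0.1, 0.9          # illustrative hyperparameters

q = np.array([[0.5, 1.2],        # Q-table from the example above
              [2.1, 0.3],
              [0.0, 3.5]])

# Hypothetical transition: took "Right" (action 1) in s0,
# received reward 1.0, landed in s2.
s, a, r, s_next = 0, 1, 1.0, 2

td_target = r + gamma * q[s_next].max()   # 1.0 + 0.9 * 3.5 = 4.15
td_error = td_target - q[s, a]            # 4.15 - 1.2 = 2.95
q[s, a] += alpha * td_error               # 1.2 + 0.1 * 2.95 = 1.495

print(q[s, a])                            # 1.495
```

Note how the learning rate keeps each step small: the estimate moves only 10% of the way toward the TD target.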
Off-Policy vs On-Policy

Q-Learning is **off-policy**: it learns about the optimal policy (max Q) while following an exploratory policy (epsilon-greedy). This is powerful because the agent can learn the optimal strategy even while exploring. SARSA is the on-policy counterpart, updating with the action actually taken instead of the max.
```python
import numpy as np

class QLearningAgent:
    """Tabular Q-Learning agent."""

    def __init__(self, n_states, n_actions, lr=0.1, gamma=0.99, epsilon=1.0):
        self.q_table = np.zeros((n_states, n_actions))
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon

    def select_action(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.q_table.shape[1])
        return np.argmax(self.q_table[state])

    def update(self, state, action, reward, next_state, done):
        """Apply the Q-learning update rule."""
        if done:
            td_target = reward
        else:
            td_target = reward + self.gamma * np.max(self.q_table[next_state])

        td_error = td_target - self.q_table[state, action]
        self.q_table[state, action] += self.lr * td_error

# Example: Simple environment with 16 states and 4 actions
agent = QLearningAgent(n_states=16, n_actions=4)
print(f"Q-table shape: {agent.q_table.shape}")
print(f"Initial Q-values for state 0: {agent.q_table[0]}")
```
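For contrast, the on-policy SARSA update mentioned above can be sketched as a small variant of the same agent. This class and its `next_action` parameter are additions for illustration, not part of the lesson's agent:

```python
import numpy as np

class SarsaAgent:
    """Tabular SARSA agent: on-policy counterpart to Q-Learning (sketch)."""

    def __init__(self, n_states, n_actions, lr=0.1, gamma=0.99, epsilon=0.1):
        self.q_table = np.zeros((n_states, n_actions))
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon

    def select_action(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.q_table.shape[1])
        return int(np.argmax(self.q_table[state]))

    def update(self, state, action, reward, next_state, next_action, done):
        # SARSA bootstraps from the action actually taken in next_state,
        # not from max_a' Q(next_state, a') as Q-Learning does.
        if done:
            td_target = reward
        else:
            td_target = reward + self.gamma * self.q_table[next_state, next_action]
        self.q_table[state, action] += self.lr * (td_target - self.q_table[state, action])

agent = SarsaAgent(n_states=4, n_actions=2)
agent.update(state=0, action=1, reward=1.0, next_state=1, next_action=0, done=False)
print(agent.q_table[0, 1])  # 0.1 * (1.0 + 0.99 * 0.0) = 0.1
```

The only behavioral difference is the TD target: because SARSA evaluates the policy it actually follows (including exploration), it tends to learn more conservative values near risky states.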

From Q-Tables to Deep Q-Networks (DQN)

Q-tables don't scale. If your state is an image (e.g., Atari game frames at 210x160 pixels), the table would need an astronomically large number of rows. DQN replaces the table with a neural network that approximates Q(s, a).

DQN Architecture

  • Input: State (e.g., stacked game frames)
  • Network: Convolutional or fully-connected layers
  • Output: Q-value for each possible action

Key DQN Innovations

1. **Experience Replay.** Instead of learning from consecutive experiences (which are correlated), store transitions in a replay buffer and sample random mini-batches. This breaks correlations and improves data efficiency.

2. **Target Network.** Use a separate, slowly updated copy of the Q-network for computing TD targets. This stabilizes training by preventing the targets from shifting rapidly.

```python
import numpy as np
from collections import deque
import random

class ReplayBuffer:
    """Experience replay buffer for DQN."""

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size=32):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.float32),
        )

    def __len__(self):
        return len(self.buffer)

# Demo
buffer = ReplayBuffer(capacity=5000)
for i in range(100):
    state = np.random.randn(4)
    action = np.random.randint(2)
    reward = np.random.randn()
    next_state = np.random.randn(4)
    done = i % 20 == 0
    buffer.push(state, action, reward, next_state, done)

states, actions, rewards, next_states, dones = buffer.sample(8)
print(f"Sampled batch - states shape: {states.shape}, actions: {actions}")
```
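The target-network idea (innovation 2) can be illustrated without any deep learning framework, using plain NumPy arrays as stand-ins for network weights. The shapes, drift magnitude, and step count here are all illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for network parameters: one small weight matrix per network.
online_w = rng.standard_normal((4, 2))
target_w = online_w.copy()               # target starts as an exact copy

# Simulated training phase: the online weights drift with each "gradient
# step" while the target weights stay frozen.
for _ in range(100):
    online_w += 0.01 * rng.standard_normal(online_w.shape)

print(np.allclose(online_w, target_w))   # False: the target lags behind

# Hard update: snap the target back to the online weights.
target_w = online_w.copy()
print(np.allclose(online_w, target_w))   # True: networks synced
```

A common variant is the soft (Polyak) update, `target_w = tau * online_w + (1 - tau) * target_w` with a small `tau`, which tracks the online network smoothly instead of in periodic jumps.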

Double DQN

Standard DQN tends to overestimate Q-values because it uses the max operator for both selecting and evaluating actions. Double DQN fixes this by decoupling selection and evaluation:

Standard DQN target: r + gamma * max_a' Q_target(s', a')

Double DQN target: r + gamma * Q_target(s', argmax_a' Q_online(s', a'))

The online network selects the best action, but the target network evaluates it. This significantly reduces overestimation and improves performance.
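The two targets can be compared numerically. The Q-values below are made up for illustration; note how the standard target picks up the target network's largest (possibly overestimated) value, while the Double DQN target evaluates the action the online network actually prefers:

```python
import numpy as np

gamma, r = 0.99, 1.0

# Hypothetical Q-values for three actions in the next state s'.
q_online = np.array([1.0, 2.5, 2.0])   # online network's estimates
q_target = np.array([1.2, 1.8, 2.2])   # target network's estimates

# Standard DQN: target network both selects and evaluates (max).
standard_target = r + gamma * q_target.max()        # 1.0 + 0.99 * 2.2 = 3.178

# Double DQN: online network selects, target network evaluates.
best_action = int(q_online.argmax())                # action 1
double_target = r + gamma * q_target[best_action]   # 1.0 + 0.99 * 1.8 = 2.782

print(standard_target, double_target)
```

Whenever the two networks disagree about the best action, the Double DQN target is no larger than the standard one, which is exactly the overestimation bias being removed.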

DQN Timeline

  • 2013: DeepMind introduces DQN, playing Atari from pixels.
  • 2015: The Nature paper adds target networks and achieves human-level play on many games.
  • 2016: Double DQN, Dueling DQN, and Prioritized Experience Replay further improve performance.

These innovations laid the groundwork for modern deep RL.
```python
# Pseudocode for a DQN training loop (PyTorch-style). It is wrapped in a
# string so this cell runs without PyTorch installed; env, select_action,
# and hyperparameters (epsilon, gamma, batch_size, target_update_freq)
# are assumed to be defined.
"""
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
        )

    def forward(self, x):
        return self.network(x)

# Training loop pseudocode
online_net = DQN(state_dim=4, action_dim=2)
target_net = DQN(state_dim=4, action_dim=2)
target_net.load_state_dict(online_net.state_dict())

optimizer = torch.optim.Adam(online_net.parameters(), lr=1e-3)
buffer = ReplayBuffer(capacity=10000)
step = 0

for episode in range(1000):
    state = env.reset()
    done = False

    while not done:
        # Epsilon-greedy action selection using online_net
        action = select_action(state, online_net, epsilon)
        next_state, reward, done, _ = env.step(action)
        buffer.push(state, action, reward, next_state, done)
        step += 1

        if len(buffer) >= batch_size:
            # Sample a batch (NumPy arrays) and convert to tensors
            states, actions, rewards, next_states, dones = buffer.sample(32)
            states = torch.as_tensor(states, dtype=torch.float32)
            actions = torch.as_tensor(actions, dtype=torch.int64)
            rewards = torch.as_tensor(rewards)
            next_states = torch.as_tensor(next_states, dtype=torch.float32)
            dones = torch.as_tensor(dones)

            # Double DQN target: online net selects, target net evaluates
            with torch.no_grad():
                best_actions = online_net(next_states).argmax(dim=1)
                target_q = target_net(next_states).gather(1, best_actions.unsqueeze(1))
                targets = rewards + gamma * target_q.squeeze(1) * (1 - dones)

            current_q = online_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
            loss = nn.MSELoss()(current_q, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Update target network periodically
        if step % target_update_freq == 0:
            target_net.load_state_dict(online_net.state_dict())

        state = next_state
"""
print("DQN architecture and training loop defined (pseudocode)")
print("Key components: online net, target net, replay buffer, Double DQN targets")
```