
Transformers & Attention

Understand the self-attention mechanism that powers modern NLP — from the attention formula to multi-head attention to full transformer architectures like BERT and GPT.



The Transformer architecture, introduced in the landmark 2017 paper *"Attention Is All You Need"*, replaced RNNs and LSTMs as the dominant architecture in NLP. Its core innovation — self-attention — allows every token to directly attend to every other token in the sequence, regardless of distance.

Why Attention?

Consider the sentence:

> *"The animal didn't cross the street because it was too tired."*

What does "it" refer to? The animal. But how does a model figure this out?

  • An RNN processes tokens left-to-right. By the time it reaches "it", the representation of "animal" has been compressed through many hidden states and may be diluted.
  • Self-attention lets "it" directly look at "animal" (and every other word) in a single step, computing a relevance score for each pair.
Self-Attention in One Sentence

Self-attention computes a weighted sum of all token representations, where the weights are learned based on how relevant each token is to the current token being processed.

Scaled Dot-Product Attention

The attention mechanism uses three matrices derived from the input:

  • Q (Query) — "What am I looking for?"
  • K (Key) — "What do I contain?"
  • V (Value) — "What information do I provide?"

The formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Step by step:

1. Compute QK^T — dot product between every query and every key (measures similarity)
2. Divide by sqrt(d_k) — prevents dot products from growing too large as dimensions increase
3. Apply softmax — convert raw scores into a probability distribution (weights sum to 1)
4. Multiply by V — weighted sum of value vectors

The scaling by sqrt(d_k) is critical. Without it, when d_k is large (e.g., 512), dot products can become very large, pushing softmax into regions with near-zero gradients and making training unstable.
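
You can see this saturation directly. A minimal sketch (the dimension 512 and the four random keys are illustrative choices, not from the lesson): unscaled dot products of random d_k-dimensional vectors have standard deviation around sqrt(d_k), so the softmax collapses toward a one-hot vector, while the scaled scores keep a softer distribution.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
d_k = 512
q = rng.standard_normal(d_k)            # one query
keys = rng.standard_normal((4, d_k))    # four keys to attend over

raw = keys @ q                # unscaled scores: std grows like sqrt(d_k)
scaled = raw / np.sqrt(d_k)   # scaled scores: std stays near 1

print("unscaled:", softmax(raw).round(4))    # close to one-hot — tiny gradients
print("scaled:  ", softmax(scaled).round(4)) # softer, trainable distribution
```

Dividing by sqrt(d_k) is exactly a temperature: it flattens the distribution, so no single position monopolizes the attention weight early in training.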

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        Q: Query matrix (seq_len_q, d_k)
        K: Key matrix (seq_len_k, d_k)
        V: Value matrix (seq_len_k, d_v)
        mask: Optional mask to prevent attending to certain positions
    Returns:
        output: Weighted sum of values (seq_len_q, d_v)
        attention_weights: Attention weight matrix (seq_len_q, seq_len_k)
    """
    d_k = Q.shape[-1]

    # Step 1: QK^T — similarity scores
    scores = Q @ K.T  # (seq_len_q, seq_len_k)

    # Step 2: Scale by sqrt(d_k)
    scores = scores / np.sqrt(d_k)

    # Step 3: Optional masking (set masked positions to -inf before softmax)
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)

    # Step 4: Softmax to get attention weights
    exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Step 5: Weighted sum of values
    output = attention_weights @ V  # (seq_len_q, d_v)

    return output, attention_weights

# --- Example ---
np.random.seed(42)
seq_len, d_k, d_v = 4, 8, 8

Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)

output, weights = scaled_dot_product_attention(Q, K, V)
print("Attention weights (each row sums to 1):")
print(weights.round(3))
print("\nRow sums:", weights.sum(axis=-1).round(3))
print("Output shape:", output.shape)
```
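
The mask argument is what turns this into decoder-style causal attention. A standalone sketch using the same convention as the function above (mask == 0 marks blocked positions): a lower-triangular mask lets position i attend only to positions 0..i, so no token can look at the future.

```python
import numpy as np

# Causal ("look-ahead") mask: 1 = keep, 0 = block future positions
seq_len = 4
causal_mask = np.tril(np.ones((seq_len, seq_len)))

rng = np.random.default_rng(0)
scores = rng.standard_normal((seq_len, seq_len))
scores = np.where(causal_mask == 0, -1e9, scores)  # -1e9 → ~0 after softmax

exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

print(weights.round(3))  # strictly upper-triangular entries are 0
```

Row 0 puts all its weight on token 0 (it has nothing earlier to look at), and each later row spreads weight only over past and current positions — exactly the constraint GPT-style models train under.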

Multi-Head Attention

A single attention head can only focus on one type of relationship at a time. Multi-head attention runs several attention heads in parallel, each learning to attend to different things:

  • Head 1 might learn syntactic relationships (subject → verb)
  • Head 2 might learn semantic similarity (synonyms)
  • Head 3 might learn positional patterns (nearby words)

The outputs from all heads are concatenated and linearly projected:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O$$

where each $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$.

```python
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

class MultiHeadAttention(layers.Layer):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads  # d_k per head

        # Linear projections for Q, K, V and output
        self.wq = layers.Dense(d_model)
        self.wk = layers.Dense(d_model)
        self.wv = layers.Dense(d_model)
        self.wo = layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq_len, d_model) → (batch, num_heads, seq_len, depth)"""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, q, k, v, mask=None):
        batch_size = tf.shape(q)[0]

        # Linear projections
        q = self.wq(q)  # (batch, seq_len, d_model)
        k = self.wk(k)
        v = self.wv(v)

        # Split into multiple heads
        q = self.split_heads(q, batch_size)  # (batch, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # Scaled dot-product attention (per head)
        d_k = tf.cast(self.depth, tf.float32)
        scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(d_k)

        if mask is not None:
            scores += (mask * -1e9)

        weights = tf.nn.softmax(scores, axis=-1)
        attn_output = tf.matmul(weights, v)  # (batch, num_heads, seq_len_q, depth)

        # Concatenate heads
        attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3])
        concat = tf.reshape(attn_output, (batch_size, -1, self.d_model))

        # Final linear projection
        return self.wo(concat)

# --- Test ---
mha = MultiHeadAttention(d_model=128, num_heads=8)
x = tf.random.normal((2, 10, 128))  # (batch=2, seq_len=10, d_model=128)
output = mha(x, x, x)  # Self-attention: Q=K=V=x
print("Input shape:", x.shape)
print("Output shape:", output.shape)  # Same shape: (2, 10, 128)
```

Transformer Variants

The original Transformer has both an encoder and decoder. Modern models often use only one half:

| Model | Architecture | Training Objective | Best For |
|-------|--------------|--------------------|----------|
| BERT | Encoder-only | Masked language modeling (fill in blanks) | Classification, NER, QA |
| GPT | Decoder-only | Autoregressive (predict next token) | Text generation, chat |
| T5 | Encoder-decoder | Text-to-text (every task is seq2seq) | Translation, summarization |
| ViT | Encoder-only | Image classification (patches as tokens) | Computer vision |

BERT sees the full context (bidirectional) but can't generate text naturally. GPT generates text autoregressively (left-to-right) but only sees past context during training. T5 frames every NLP task as "given this input text, produce this output text".

Positional Encoding

Unlike RNNs, Transformers process all tokens in parallel — they have no built-in notion of order. Positional encodings are added to the input embeddings to tell the model where each token sits in the sequence. The original paper used sinusoidal functions; modern models often use learned positional embeddings. Without positional encoding, "dog bites man" and "man bites dog" would produce identical representations.
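
A minimal NumPy sketch of the sinusoidal encodings from the original paper: PE[pos, 2i] = sin(pos / 10000^(2i/d_model)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model)). The max_len and d_model values below are arbitrary demo choices.

```python
import numpy as np

def sinusoidal_positional_encoding(max_len, d_model):
    """Return a (max_len, d_model) matrix of sinusoidal position encodings."""
    pos = np.arange(max_len)[:, None]              # (max_len, 1)
    i = np.arange(0, d_model, 2)[None, :]          # even dimension indices
    angles = pos / np.power(10000.0, i / d_model)  # (max_len, d_model/2)

    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)   # even dims get sine
    pe[:, 1::2] = np.cos(angles)   # odd dims get cosine
    return pe

pe = sinusoidal_positional_encoding(max_len=50, d_model=16)
print(pe.shape)  # (50, 16)
```

Every position gets a unique pattern, values stay in [-1, 1] so they can simply be added to token embeddings, and the varying frequencies let the model recover relative offsets between positions.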

Using Hugging Face Transformers

The Hugging Face transformers library makes it trivial to use pre-trained models:

```python
from transformers import pipeline

# --- Sentiment Analysis ---
classifier = pipeline("sentiment-analysis")
result = classifier("I absolutely loved this movie! The acting was superb.")
print(result)
# [{'label': 'POSITIVE', 'score': 0.9998}]

# --- Summarization ---
summarizer = pipeline("summarization")
article = """
    The Transformer architecture has revolutionized natural language processing.
    Introduced in 2017, it replaced recurrent neural networks with self-attention
    mechanisms that can process all tokens in parallel. This led to models like
    BERT, GPT, and T5 that achieve state-of-the-art results on virtually every
    NLP benchmark. The key innovation is the ability of each token to directly
    attend to every other token, capturing long-range dependencies efficiently.
"""
summary = summarizer(article, max_length=50, min_length=20)
print(summary[0]["summary_text"])

# --- Zero-shot classification (no fine-tuning needed!) ---
zero_shot = pipeline("zero-shot-classification")
result = zero_shot(
    "The new iPhone has an incredible camera and battery life.",
    candidate_labels=["technology", "sports", "politics", "food"],
)
print(f"Label: {result['labels'][0]}, Score: {result['scores'][0]:.4f}")
```

Fine-tuning BERT for Classification

When a pre-trained model doesn't perfectly fit your task, you can fine-tune it on your specific dataset:

```python
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset
import numpy as np

# Load dataset
dataset = load_dataset("imdb")

# Load pre-trained BERT tokenizer and model
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=2
)

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=256,
    )

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Use a small subset for demonstration
small_train = tokenized_datasets["train"].shuffle(seed=42).select(range(2000))
small_test = tokenized_datasets["test"].shuffle(seed=42).select(range(500))

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    learning_rate=2e-5,           # Small LR for fine-tuning!
    weight_decay=0.01,
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == labels).mean()
    return {"accuracy": accuracy}

# Train!
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train,
    eval_dataset=small_test,
    compute_metrics=compute_metrics,
)

trainer.train()
# Fine-tuned BERT typically achieves 92-94% on IMDB
```

Fine-tuning Tips

Use a very small learning rate (2e-5 to 5e-5) when fine-tuning pre-trained models — large learning rates will destroy the pre-trained weights. Train for only 2-4 epochs to avoid overfitting. If your dataset is small (< 1000 examples), consider few-shot prompting with GPT instead of fine-tuning.