Self-Attention from Scratch: The Core of Every LLM

Build scaled dot-product attention, multi-head attention, causal masking, KV cache, and grouped query attention from scratch in PyTorch. The fundamental operation behind GPT-4, LLaMA 3, and every modern language model.

Every modern LLM — GPT-4, LLaMA 3, Claude, Gemini — runs on the same fundamental operation: self-attention. It’s the mechanism that lets each token in a sequence look at every other token and decide what’s relevant. If you want to understand how language models actually work, this is where you start.

We’re going to build self-attention from scratch in PyTorch — not use a library wrapper, not call a magic function. Raw matrix operations. By the end of this post, you’ll understand exactly what happens inside torch.nn.MultiheadAttention and why every design choice exists.

The Query-Key-Value Framework

For each token in a sequence, self-attention creates three vectors:

  • Query (Q): “What am I looking for?”
  • Key (K): “What do I contain?”
  • Value (V): “What information do I give out?”

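The attention formula is $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, where d_k is the dimension of the query and key vectors.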

Think of it like a search engine. The Query is your search query. The Keys are the index entries. The dot product $QK^\top$ computes a relevance score between each query and every key. Softmax normalizes each query's scores into a probability distribution over the keys. Then we take a weighted sum of the Values using those probabilities.

Here it is in PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: Apply mask (optional)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax to get weights that sum to 1
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Weighted sum of values
    output = torch.matmul(attn_weights, V)

    return output, attn_weights

Four lines of actual math. That’s the entire core of attention. Let’s test it:

torch.manual_seed(42)
Q = torch.randn(1, 1, 6, 4)  # batch=1, heads=1, seq=6, d_k=4
K = torch.randn(1, 1, 6, 4)
V = torch.randn(1, 1, 6, 4)

output, weights = scaled_dot_product_attention(Q, K, V)
print(weights[0, 0].detach().numpy().round(3))

# Output:
# [[0.186 0.062 0.155 0.138 0.183 0.276]
#  [0.122 0.187 0.101 0.236 0.116 0.239]
#  [0.221 0.108 0.146 0.136 0.163 0.226]
#  [0.117 0.184 0.143 0.203 0.127 0.226]
#  [0.238 0.08  0.166 0.108 0.189 0.219]
#  [0.144 0.091 0.107 0.18  0.161 0.317]]

Each row sums to 1.0. Row i tells us how much token i attends to every token in the sequence, including itself. The output for token i is a weighted combination of all value vectors, using row i of these attention weights.
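A quick sanity check, assuming PyTorch 2.0 or later (where `torch.nn.functional.scaled_dot_product_attention` exists): every row of the weight matrix should sum to 1, each output row should be the weighted average of the value vectors, and the result should match PyTorch's built-in fused version up to floating-point tolerance. The function is repeated so the cell runs on its own.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    # Same four steps as the from-scratch version in this post
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, V), attn_weights


torch.manual_seed(42)
Q = torch.randn(1, 1, 6, 4)  # batch=1, heads=1, seq=6, d_k=4
K = torch.randn(1, 1, 6, 4)
V = torch.randn(1, 1, 6, 4)

output, weights = scaled_dot_product_attention(Q, K, V)

# Every row of the attention matrix is a probability distribution
row_sums = weights.sum(dim=-1)
print(torch.allclose(row_sums, torch.ones_like(row_sums)))

# Output row 0 is the average of the value vectors weighted by row 0 of weights
manual_row0 = (weights[0, 0, 0].unsqueeze(-1) * V[0, 0]).sum(dim=0)
print(torch.allclose(output[0, 0, 0], manual_row0, atol=1e-6))

# PyTorch's fused kernel uses the same 1/sqrt(d_k) scale by default
reference = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(output, reference, atol=1e-5))
```

Having a from-scratch reference like this is handy later: when you swap in an optimized kernel, you can check it against the four-line version.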

Why Scale by √d_k?

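This is the detail most explanations gloss over. When d_k is large, the dot products in $QK^\top$ grow in magnitude. If the components of a query and a key are independent with mean 0 and variance 1, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has variance d_k, so its typical size grows like √d_k. Large values push softmax into extreme regions where the gradient is essentially zero, and training stalls.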

Dividing by √d_k normalizes the variance back to approximately 1, keeping the softmax in a region where gradients flow. The notebook includes a visual comparison across d_k values from 4 to 1024 — without scaling, d_k=1024 produces a one-hot-like distribution where a single token gets 99%+ of the weight. With scaling, the distribution stays smooth regardless of dimension.
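The variance argument is easy to reproduce numerically. This sketch assumes query and key components are drawn i.i.d. from a standard normal (the setting in which the "variance equals d_k" statement holds exactly); the sample count and the 16 keys in the softmax comparison are arbitrary choices, not values from the notebook.

```python
import math

import torch

torch.manual_seed(0)
n_samples = 10_000

for d_k in (4, 64, 1024):
    # Assumption: query/key components are i.i.d. standard normal
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)  # n_samples independent dot products q.k
    scaled = dots / math.sqrt(d_k)
    print(f"d_k={d_k:5d}  var(q.k)={dots.var().item():8.1f}  "
          f"var(q.k/sqrt(d_k))={scaled.var().item():.2f}")

# Effect on softmax: one d_k=1024 query scored against 16 random keys
query = torch.randn(1024)
keys = torch.randn(16, 1024)
logits = keys @ query
p_unscaled = torch.softmax(logits, dim=-1).max().item()
p_scaled = torch.softmax(logits / math.sqrt(1024), dim=-1).max().item()
print(f"largest weight without scaling: {p_unscaled:.3f}")
print(f"largest weight with scaling:    {p_scaled:.3f}")
```

The unscaled variance tracks d_k while the scaled one stays near 1, and the unscaled softmax concentrates far more weight on its top key.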

Multi-Head Attention: Different Heads Learn Different Patterns

A single attention head produces one set of attention weights per token, so it can capture only one pattern of relevance at a time. Multi-head attention runs several attention operations in parallel, each head with its own learned Q, K, V projections, then concatenates the results and mixes them with an output projection.

In practice, different heads specialize. One head might learn syntactic dependencies (subject-verb agreement). Another tracks coreference (what “it” refers to). Another learns positional relationships (attending to the previous token).

The key implementation detail is how we split and merge heads efficiently:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, T, D = x.shape

        # Project and split into heads
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(attn, V)

        # Merge heads and project
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)

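The .view(B, T, n_heads, d_k).transpose(1, 2) reshaping is how we split d_model into parallel heads without copying data: view and transpose only change the tensor's shape and strides, and the only copy happens in .contiguous() when the heads are merged back. With d_model=512 and 8 heads, each head operates on a 64-dimensional subspace. Because the projections have no bias terms, the total parameter count is 4 × d_model² = 4 × 512² = 1,048,576: four projection matrices, each mapping d_model to d_model.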
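To check those numbers, the cell below re-declares the same MultiHeadAttention class (so it runs on its own), inspects the head split, and counts parameters. The comparison with `torch.nn.MultiheadAttention(..., bias=False)` is an extra check added here; it relies on PyTorch packing the three input projections into one 3·d_model × d_model weight plus a d_model × d_model output projection.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    # Same structure as the class above, repeated so this cell is self-contained
    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)


mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)  # batch=2, seq=10, d_model=512

# The head split: (B, T, d_model) -> (B, n_heads, T, d_k) with no data copy
proj = mha.W_q(x)
q_heads = proj.view(2, 10, 8, 64).transpose(1, 2)
print(q_heads.shape)                          # torch.Size([2, 8, 10, 64])
print(q_heads.data_ptr() == proj.data_ptr())  # True: same underlying storage
print(q_heads.is_contiguous())                # False: transpose only changes strides

print(mha(x).shape)  # torch.Size([2, 10, 512])

n_params = sum(p.numel() for p in mha.parameters())
print(n_params)  # 1048576

# PyTorch's built-in module with biases disabled has the same count
builtin = nn.MultiheadAttention(embed_dim=512, num_heads=8, bias=False)
print(sum(p.numel() for p in builtin.parameters()))  # 1048576
```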

Causal Masking: No Peeking at the Future

GPT-style models generate text left-to-right. During training, token i must only attend to tokens 0 through i — it can’t look at what comes next. We enforce this with a causal mask: a lower-triangular matrix of ones.

seq_len = 8
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
# Result: lower triangular matrix
# [[1, 0, 0, 0, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0, 0, 0, 0],
#  [1, 1, 1, 0, 0, 0, 0, 0],
#  ...
#  [1, 1, 1, 1, 1, 1, 1, 1]]

Where the mask is 0, we fill the attention scores with -inf before softmax. Since exp(-inf) = 0, those positions get exactly zero weight and contribute nothing to the output. Row 3 can attend to positions 0, 1, 2, 3 but not 4, 5, 6, 7.
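Here is the mask applied end to end on random inputs (the sizes and seed are arbitrary). The (seq_len, seq_len) mask broadcasts over the batch and head dimensions. The last check assumes PyTorch 2.0 or later, whose `F.scaled_dot_product_attention` can apply the same lower-triangular mask internally via `is_causal=True`.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 8, 4
Q = torch.randn(1, 1, seq_len, d_k)
K = torch.randn(1, 1, seq_len, d_k)
V = torch.randn(1, 1, seq_len, d_k)

causal_mask = torch.tril(torch.ones(seq_len, seq_len))

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
scores = scores.masked_fill(causal_mask == 0, float('-inf'))  # broadcasts over batch/heads
weights = F.softmax(scores, dim=-1)

print(weights[0, 0, 3])  # nonzero only at positions 0..3
print(torch.triu(weights[0, 0], diagonal=1).abs().sum().item())  # 0.0: no weight on the future

# Same result from PyTorch's built-in causal path
reference = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(weights @ V, reference, atol=1e-5))  # True
```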

BERT-style models use full attention (every token sees every other token). Models like Longformer use sliding-window attention for efficiency on long sequences. The notebook includes visual comparisons of all three masking strategies.
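The three strategies differ only in which entries of the mask are 1. A minimal sketch follows; the helper names are hypothetical, and the sliding-window version is simplified (a symmetric band of half-width `window`). Longformer additionally gives a few selected tokens global attention, which is not modeled here.

```python
import torch


def make_full_mask(seq_len):
    # BERT-style: every token may attend to every token
    return torch.ones(seq_len, seq_len)


def make_causal_mask(seq_len):
    # GPT-style: token i may attend to tokens 0..i
    return torch.tril(torch.ones(seq_len, seq_len))


def make_sliding_window_mask(seq_len, window):
    # Simplified local attention: token i may attend to token j when |i - j| <= window
    idx = torch.arange(seq_len)
    return ((idx[:, None] - idx[None, :]).abs() <= window).float()


for name, m in [("full", make_full_mask(6)),
                ("causal", make_causal_mask(6)),
                ("sliding window (w=1)", make_sliding_window_mask(6, 1))]:
    print(name)
    print(m.int())
```

Because the attention function above treats any 0 entry as "blocked", each of these tensors can be passed directly as its `mask` argument.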

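KV Cache: Why Inference Doesn’t Recompute the Past

During generation, the model produces one token at a time. Without caching, step t reruns the model over the whole prefix and recomputes K and V for all t tokens, so the K/V projection work alone totals O(T²) over T steps. KV caching stores the key and value tensors from previous steps and computes K and V only for the new token, reducing the total projection work to O(T). Attention itself still grows with context: the new token’s query must be scored against every cached key, so each step costs O(t) and a full generation O(T²), compared with O(T³) for naive recomputation.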
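A rough operation count makes the difference concrete. The sketch below counts, per layer, how many tokens pass through the K/V projections and how many query-key dot products are computed over T decoding steps. It assumes that decoding without a cache reruns full causal attention over the prefix at every step; the function names are illustrative.

```python
def kv_projection_counts(T):
    # Tokens pushed through W_k / W_v across T decoding steps
    without_cache = sum(t for t in range(1, T + 1))  # step t re-projects all t tokens
    with_cache = T                                   # step t projects only the new token
    return without_cache, with_cache


def attention_score_counts(T):
    # Query-key dot products across T decoding steps (causal attention)
    without_cache = sum(t * (t + 1) // 2 for t in range(1, T + 1))  # full causal matrix each step
    with_cache = sum(t for t in range(1, T + 1))                     # one new query vs t keys
    return without_cache, with_cache


for T in (128, 1024, 8192):
    print(T, kv_projection_counts(T), attention_score_counts(T))
```

The projection count drops from quadratic to linear, while the score count drops from cubic to quadratic, which is why long contexts remain expensive even with a cache.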

class AttentionWithKVCache(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

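    def forward(self, x, kv_cache=None):
        B, T, D = x.shape
        # Project only the new tokens (T == 1 during step-by-step decoding)
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Append new keys/values to the cached ones along the sequence dim
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=2)
            V = torch.cat([kv_cache[1], V], dim=2)

        new_cache = (K, V)
        S = K.size(2)  # cached + new positions
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Causal mask offset by the cache length: new token i sits at absolute
        # position S - T + i and may attend to keys 0..S - T + i.
        # For a single decoded token (T == 1) the mask is all ones.
        causal = torch.tril(torch.ones(T, S, device=x.device), diagonal=S - T)
        scores = scores.masked_fill(causal == 0, float('-inf'))

        out = torch.matmul(F.softmax(scores, dim=-1), V)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out), new_cache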
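One way to test a KV cache: prefill part of a sequence, decode the rest one token at a time, and confirm the outputs match a single causal pass over the whole sequence. The sketch below is self-contained; `CachedAttention` is a hypothetical stand-in for AttentionWithKVCache that includes a causal mask offset by the number of cached positions (without such a mask, a multi-token prefill would let tokens attend to later tokens). Sizes are arbitrary.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedAttention(nn.Module):
    # Hypothetical: same projections and cache handling as AttentionWithKVCache,
    # plus a causal mask that accounts for already-cached positions
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def split(self, t, B, T):
        # (B, T, d_model) -> (B, n_heads, T, d_k)
        return t.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x, kv_cache=None):
        B, T, D = x.shape
        Q = self.split(self.W_q(x), B, T)
        K = self.split(self.W_k(x), B, T)
        V = self.split(self.W_v(x), B, T)
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=2)
            V = torch.cat([kv_cache[1], V], dim=2)
        S = K.size(2)  # cached + new positions
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        mask = torch.tril(torch.ones(T, S, device=x.device), diagonal=S - T)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        out = (F.softmax(scores, dim=-1) @ V).transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out), (K, V)


torch.manual_seed(0)
attn = CachedAttention(d_model=32, n_heads=4)
x = torch.randn(1, 10, 32)

with torch.no_grad():
    # Reference: one causal pass over all 10 tokens
    full_out, _ = attn(x)

    # Incremental: prefill 6 tokens, then decode the remaining 4 one at a time
    out, cache = attn(x[:, :6])
    pieces = [out]
    for t in range(6, 10):
        out, cache = attn(x[:, t:t + 1], kv_cache=cache)
        pieces.append(out)
    incremental_out = torch.cat(pieces, dim=1)

print(torch.allclose(full_out, incremental_out, atol=1e-5))  # True
print(cache[0].shape)  # torch.Size([1, 4, 10, 8]): keys for all 10 positions
```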

The tradeoff is memory. For LLaMA 3 70B at a 32K context length, the KV cache alone consumes roughly 10 GiB per sequence in bf16. This is why techniques like Grouped Query Attention (GQA) exist: LLaMA 3 8B pairs 32 query heads with only 8 KV heads, cutting KV cache size by 4x, and LLaMA 3 70B pairs 64 query heads with 8 KV heads (an 8x cut), with minimal quality loss.
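To put a number on the cache size, here is the arithmetic, assuming LLaMA 3 70B's published configuration (80 layers, 64 query heads, 8 KV heads, head dimension 128) and 2-byte bf16 storage. The 64-KV-head line is a hypothetical full-MHA variant, included only for comparison.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    # Factor of 2 for keys and values; bytes_per_elem=2 assumes bf16/fp16 storage
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_elem


GiB = 1024 ** 3

# Assumed LLaMA 3 70B shape: 80 layers, 8 KV heads, head dim 128
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, d_head=128, seq_len=32_768)
# Hypothetical MHA variant with one KV head per query head (64)
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, d_head=128, seq_len=32_768)

print(f"GQA (8 KV heads):  {gqa / GiB:.1f} GiB")  # 10.0 GiB
print(f"MHA (64 KV heads): {mha / GiB:.1f} GiB")  # 80.0 GiB
```

The cache also scales linearly with batch size, so serving many long sequences at once multiplies these figures.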

Grouped Query Attention: What LLaMA 3 Actually Uses

Standard multi-head attention (MHA) gives each head its own K and V projections. GQA shares K/V heads across groups of query heads. Multi-query attention (MQA) is the extreme case where all query heads share a single K/V head.

# KV cache memory per token, per layer (bf16), for a 4096-dim model with 32 query heads (d_head = 4096 / 32 = 128):
# MHA:  32 KV heads × 128 d_head × 2 (K+V) × 2 bytes = 16,384 B/token
# GQA:   8 KV heads × 128 d_head × 2 (K+V) × 2 bytes =  4,096 B/token  (4x smaller)
# MQA:   1 KV head  × 128 d_head × 2 (K+V) × 2 bytes =    512 B/token  (32x smaller)

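The implementation trick is repeat_interleave: we expand the smaller set of KV heads to match the number of Q heads before computing attention, so each group of query heads shares one K/V head. The expansion itself adds no parameters (GQA's K and V projections are actually smaller than MHA's); it is just index manipulation. The full GQA implementation is in the notebook.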
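A minimal sketch of the idea, not the notebook's implementation: the class and argument names (`GroupedQueryAttention`, `n_kv_heads`) are hypothetical, and the dimensions are shrunk so it runs instantly on CPU. Setting n_kv_heads equal to n_heads recovers standard MHA, and n_kv_heads = 1 gives MQA.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    # n_kv_heads must divide n_heads; each group of n_heads // n_kv_heads
    # query heads shares one K/V head
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert d_model % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        # K/V projections only produce n_kv_heads heads' worth of outputs
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.W_o = nn.Linear(n_heads * self.d_k, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, _ = x.shape
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)

        # These smaller K and V are what a KV cache would store. Expand them so
        # query head h reads KV head h // group_size.
        group_size = self.n_heads // self.n_kv_heads
        K = K.repeat_interleave(group_size, dim=1)
        V = V.repeat_interleave(group_size, dim=1)

        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        out = F.softmax(scores, dim=-1) @ V
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)


torch.manual_seed(0)
x = torch.randn(2, 16, 512)
gqa = GroupedQueryAttention(d_model=512, n_heads=32, n_kv_heads=8)
y = gqa(x, mask=torch.tril(torch.ones(16, 16)))
print(y.shape)  # torch.Size([2, 16, 512])

# Parameter counts: MHA (32 KV heads) vs GQA (8) vs MQA (1)
counts = {n_kv: sum(p.numel() for p in GroupedQueryAttention(512, 32, n_kv).parameters())
          for n_kv in (32, 8, 1)}
print(counts)  # {32: 1048576, 8: 655360, 1: 540672}
```

Note that repeat_interleave materializes the expanded tensors; the memory saving comes from caching the unexpanded K and V.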

What to Do Next

The complete notebook has runnable code for everything covered here, plus attention visualization on a real DistilBERT model, KV cache memory profiling charts, and four exercises that will test whether you’ve actually understood the material.

Open the notebook in Google Colab — it runs on a free T4 GPU in about 90 minutes.

Next up in this series: Building a Transformer Block — where we combine self-attention with layer normalization, feed-forward networks, and residual connections to build the complete architecture that powers every modern LLM.

This post is part of TheAiSingularity’s LLM Engineering Course — 64 notebooks, 20 capstone projects, fully open source.
