Building a Transformer from Scratch in PyTorch: A Complete Technical Guide

Transformers have revolutionized machine learning, powering everything from ChatGPT to advanced computer vision models. Yet many practitioners use transformers as black boxes without understanding their inner workings. If you want to build production LLMs or fine-tune foundation models effectively, understanding transformers at the architectural level is essential. This guide walks you through implementing a transformer from scratch in PyTorch, breaking down each component so you can see exactly how these powerful models work.

This tutorial is part of our free 64-notebook LLM engineering course, designed to take you from foundational concepts to production-ready implementations. Let’s dive into the architecture that powers modern AI.

Why Build Transformers from Scratch?

Before jumping into code, it’s worth asking: why build transformers when frameworks like Hugging Face provide pre-built implementations? There are several compelling reasons:

  • Deep understanding: Building from scratch reveals assumptions and design choices that shape model behavior
  • Debugging and optimization: When something goes wrong in production, you need to understand what’s happening at each layer
  • Custom architectures: Many applications require transformer variants tailored to specific problems
  • Interview and research: If you’re interviewing at AI labs or publishing research, this knowledge is foundational

Now, let’s start with the core building block: attention.

Understanding the Attention Mechanism

Attention is the heart of the transformer. An RNN processes a sequence one step at a time. A transformer instead computes relationships between all positions in parallel: for each token, attention assigns a weight to every token in the sequence (including itself), and the token's new representation is a weighted combination of them. This is how the model decides which parts of the input matter when processing each token.

Scaled Dot-Product Attention

The fundamental attention operation is called scaled dot-product attention. Here’s how it works mathematically:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Here Q (queries), K (keys), and V (values) are learned linear projections of the input, and d_k is the dimension of the query and key vectors.

Here’s the PyTorch implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.scale = d_k ** 0.5

    def forward(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)
        return output, weights
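A small sanity check makes the formula concrete. The sketch below repeats the computation from ScaledDotProductAttention.forward as a standalone function so it runs on its own (sdpa_reference is a hypothetical name used only here). With an all-zero query every score is 0, so softmax spreads the weight evenly and the output is simply the mean of the value rows. Masking a key sends its weight to exactly 0.

```python
import torch
import torch.nn.functional as F

def sdpa_reference(Q, K, V, mask=None):
    # Same steps as ScaledDotProductAttention.forward (hypothetical standalone helper)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5      # (..., seq_q, seq_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)                 # each row sums to 1
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.zeros(1, 4)            # one query with d_k = 4; all zeros, so every score is 0
K = torch.randn(3, 4)            # three keys (their values don't matter here)
V = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [2.0, 2.0]])   # three value vectors of dimension 2

out, w = sdpa_reference(Q, K, V)
print(w)      # uniform weights, 1/3 each
print(out)    # mean of the rows of V: [[1.0, 1.0]]

mask = torch.tensor([[1, 1, 0]])  # 0 = the query may not attend to the third key
out_m, w_m = sdpa_reference(Q, K, V, mask)
print(w_m)    # [[0.5, 0.5, 0.0]]
print(out_m)  # [[0.5, 0.5]]
```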

Multi-Head Attention: Why One Head Isn’t Enough

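A single attention head produces one set of attention weights per query, so it can capture only one kind of relationship at a time. Multi-head attention projects the queries, keys, and values into num_heads smaller subspaces of size d_k = d_model / num_heads, runs scaled dot-product attention in each subspace independently, then concatenates the heads and mixes them with a final linear layer (W_o).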

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(self.d_k)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.shape[0]
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
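        if mask is not None and mask.dim() == 3:
            # (batch, seq_q, seq_k) -> (batch, 1, seq_q, seq_k): one mask shared by all heads.
            # 4-D masks, such as the (1, 1, seq, seq) causal mask, already broadcast as-is.
            mask = mask.unsqueeze(1)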
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        output = self.W_o(attn_output)
        return output, attn_weights
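The view/transpose reshaping in forward is easy to get wrong. The standalone check below (toy sizes chosen arbitrarily) splits a tensor into heads the same way MultiHeadAttention does and then merges it back, confirming that each head sees a contiguous slice of d_k features and that the round trip loses nothing.

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 5, 8, 4   # arbitrary toy sizes
d_k = d_model // num_heads

x = torch.randn(batch_size, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, seq, heads, d_k) -> (batch, heads, seq, d_k)
heads = x.view(batch_size, -1, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 2])

# Head 1 sees features d_k .. 2*d_k - 1 of every token
assert torch.equal(heads[:, 1], x[:, :, d_k:2 * d_k])

# Merge: undo the transpose, then fold the head dimension back into d_model.
# transpose() returns a non-contiguous view, so .contiguous() is required before .view()
merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
assert torch.equal(merged, x)
```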

The Transformer Block: Putting Components Together

A transformer block combines multi-head self-attention, a feed-forward network, layer normalization, and residual connections.

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_output, _ = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

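Notice the pattern: LayerNorm(x + Sublayer(x)), with dropout applied to each sublayer's output. This post-norm arrangement follows the original Transformer paper. The residual connection adds each sublayer's input straight back to its output, giving gradients a path that bypasses the sublayer's weights and non-linearities, which makes deep stacks much easier to train.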
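Many modern LLMs use a pre-norm arrangement instead, x + Sublayer(LayerNorm(x)), which tends to train more stably in deep stacks because the residual path stays a pure identity. The sketch below is an illustrative variant, not part of the model in this post: PreNormTransformerBlock is a hypothetical name, and for brevity it uses PyTorch's built-in nn.MultiheadAttention rather than the MultiHeadAttention class above. Be aware that nn.MultiheadAttention's boolean attn_mask marks blocked positions with True, the opposite of the 1 = attend convention used elsewhere in this post.

```python
import torch
import torch.nn as nn

class PreNormTransformerBlock(nn.Module):
    """Hypothetical pre-norm block: x + Sublayer(LayerNorm(x))."""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Built-in attention; batch_first=True means inputs are (batch, seq, d_model)
        self.attention = nn.MultiheadAttention(d_model, num_heads,
                                               dropout=dropout, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # attn_mask follows nn.MultiheadAttention: boolean True = "do not attend"
        h = self.norm1(x)
        attn_output, _ = self.attention(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + self.dropout(attn_output)                       # residual path has no norm on it
        x = x + self.dropout(self.feed_forward(self.norm2(x)))
        return x

block = PreNormTransformerBlock(d_model=64, num_heads=4, d_ff=256)
x = torch.randn(2, 10, 64)
print(block(x).shape)  # torch.Size([2, 10, 64])
```

Because every sublayer reads a normalized copy of x but writes back into the un-normalized residual stream, pre-norm stacks usually apply one final LayerNorm after the last block.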

Building a Complete Transformer Model

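A full encoder embeds the input tokens, adds positional encodings, and then passes the result through a stack of transformer blocks. Positional encoding is needed because self-attention on its own is order-agnostic: shuffling the input tokens would simply shuffle the outputs. The code below uses the fixed sinusoidal encoding from the original paper: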

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=1000):
        super().__init__()
        position = torch.arange(0, max_seq_length).unsqueeze(1).float()
        dim_indices = torch.arange(0, d_model, 2).float()
        angle_rates = 1 / torch.pow(10000, dim_indices / d_model)
        pe = torch.zeros(max_seq_length, d_model)
        pe[:, 0::2] = torch.sin(position * angle_rates)
        pe[:, 1::2] = torch.cos(position * angle_rates)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff,
                 max_seq_length=1000, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x
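The sinusoidal encoding has a useful property: each (sin, cos) pair rotates at a fixed rate as the position advances, so the dot product between the encodings of two positions depends only on their distance. The standalone sketch below rebuilds the same matrix as PositionalEncoding and checks this numerically. sinusoidal_pe is a hypothetical helper name, and like the class above it assumes d_model is even.

```python
import torch

def sinusoidal_pe(max_len, d_model):
    # Same construction as PositionalEncoding (assumes an even d_model)
    position = torch.arange(max_len).unsqueeze(1).float()            # (max_len, 1)
    angle_rates = 1 / torch.pow(10000, torch.arange(0, d_model, 2).float() / d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * angle_rates)
    pe[:, 1::2] = torch.cos(position * angle_rates)
    return pe

pe = sinusoidal_pe(max_len=100, d_model=16)

# Each (sin, cos) pair lies on the unit circle, so every row has squared norm d_model / 2
print(pe.pow(2).sum(dim=1)[:3])   # each value is (up to rounding) 8.0

# dot(pe[p], pe[q]) = sum_i cos((p - q) * w_i), which depends only on the offset p - q
offset = 7
dots = [torch.dot(pe[p], pe[p + offset]).item() for p in (0, 20, 50)]
print(dots)  # three nearly identical values
```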

Key Architectural Insights

Why transformers work so well:

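  • Parallelization: Unlike RNNs, transformers process every position of a training sequence in parallel, which makes training far faster on GPUs. (Autoregressive generation still produces one token at a time.)
  • Long-range dependencies: Every token can attend directly to every other token, regardless of distance, at the cost of compute and memory that grow quadratically with sequence length.
  • Interpretability: Attention weights give a partial view of which positions each output draws on, although they are not a complete explanation of the model's behavior.
  • Scalability: The same block can be stacked deeper and made wider, and this design has been scaled to models with billions of parameters.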
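To put the quadratic cost in numbers, here is a back-of-the-envelope helper (attention_scores_bytes is a hypothetical name). It assumes float32 and the batch size and head count from the example configuration in this post, and it counts only one seq_len x seq_len score matrix per head per sequence. In the straightforward implementation above, actual memory use is higher, because the softmax weights and other activations are also kept for the backward pass.

```python
def attention_scores_bytes(batch, heads, seq_len, bytes_per_elem=4):
    # One (seq_len x seq_len) score matrix per sequence per head; float32 by default
    return batch * heads * seq_len * seq_len * bytes_per_elem

for seq_len in (256, 1024, 4096):
    mib = attention_scores_bytes(batch=32, heads=8, seq_len=seq_len) / 2**20
    print(f"seq_len={seq_len:5d}: {mib:8.0f} MiB of scores per layer")
# 4x longer sequences -> 16x more score memory: 64, 1024, 16384 MiB
```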

Causal Masking for Autoregressive Generation

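For autoregressive language models such as GPT, a causal mask prevents each position from attending to later positions, so the representation at position t depends only on tokens up to t. The mask is lower-triangular: 1 marks a position a query may attend to and 0 a blocked one, matching the mask == 0 test in ScaledDotProductAttention: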

def create_causal_mask(seq_length, device):
    mask = torch.tril(torch.ones(seq_length, seq_length, device=device))
    return mask.unsqueeze(0).unsqueeze(0)
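In practice a causal mask is often combined with a padding mask so that queries also ignore padding tokens in a batch. The sketch below assumes padding uses token id 0 (an arbitrary choice for illustration; create_padding_mask is a hypothetical helper) and keeps the same 1 = attend, 0 = blocked convention. Multiplying the two masks broadcasts them to shape (batch, 1, seq, seq), which in turn broadcasts against attention scores of shape (batch, heads, seq, seq).

```python
import torch

def create_causal_mask(seq_length, device=None):
    # Same as above: 1 = may attend, 0 = blocked; shape (1, 1, seq, seq)
    mask = torch.tril(torch.ones(seq_length, seq_length, device=device))
    return mask.unsqueeze(0).unsqueeze(0)

def create_padding_mask(input_ids, pad_id=0):
    # Hypothetical helper: 1 for real tokens, 0 for padding; shape (batch, 1, 1, seq)
    return (input_ids != pad_id).float().unsqueeze(1).unsqueeze(2)

input_ids = torch.tensor([[5, 7, 9, 0],    # last position is padding
                          [3, 4, 0, 0]])   # last two positions are padding
causal = create_causal_mask(input_ids.size(1))
padding = create_padding_mask(input_ids)

# (1, 1, S, S) * (B, 1, 1, S) -> (B, 1, S, S): a key is visible only if it is
# not in the future AND not padding
combined = causal * padding
print(combined.shape)  # torch.Size([2, 1, 4, 4])
print(combined[1, 0])  # from row 1 on, only keys 0 and 1 are visible
```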

Putting It All Together: A Complete Example

config = {
    'vocab_size': 10000,
    'd_model': 512,
    'num_heads': 8,
    'num_layers': 6,
    'd_ff': 2048,
    'max_seq_length': 1024,
    'dropout': 0.1
}
model = TransformerEncoder(**config)
input_ids = torch.randint(0, config['vocab_size'], (32, 256))
output = model(input_ids)  # Shape: (32, 256, 512)
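As a sanity check on model size, the parameter count for this configuration can be worked out by hand. The helper below (count_parameters is a hypothetical name) simply encodes the layer shapes used in the classes above. The positional encoding is registered as a buffer, so it contributes no trainable parameters, and num_heads does not change the count because heads split d_model rather than adding width.

```python
def count_parameters(vocab_size, d_model, num_heads, num_layers, d_ff, **unused):
    # **unused absorbs config keys that add no parameters (max_seq_length, dropout)
    embedding = vocab_size * d_model
    attention = 4 * (d_model * d_model + d_model)                  # W_q, W_k, W_v, W_o with biases
    feed_forward = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    layer_norms = 2 * (2 * d_model)                                # norm1, norm2: gain + bias each
    per_block = attention + feed_forward + layer_norms
    return {
        "embedding": embedding,
        "per_block": per_block,
        "all_blocks": num_layers * per_block,
        "total": embedding + num_layers * per_block,
    }

config = {
    'vocab_size': 10000,
    'd_model': 512,
    'num_heads': 8,
    'num_layers': 6,
    'd_ff': 2048,
    'max_seq_length': 1024,
    'dropout': 0.1
}
for name, n in count_parameters(**config).items():
    print(f"{name:>10}: {n:,}")
# embedding: 5,120,000 / per_block: 3,152,384 / all_blocks: 18,914,304 / total: 24,034,304
```

With the classes from this post in scope, sum(p.numel() for p in model.parameters()) should report the same total; the token embedding alone accounts for roughly a fifth of it.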

Key Takeaways

  • Scaled dot-product attention: The core mechanism implemented as softmax(QK^T / sqrt(d_k))V
  • Multi-head attention: Multiple heads can learn to attend to different types of relationships in parallel
  • Transformer blocks: Attention + feed-forward + layer norm + residual connections
  • Positional encoding: Sinusoidal encodings inject sequence position information
  • Causal masking: Prevents attention to future positions for autoregressive models

Advanced Topics

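Production transformers often add: FlashAttention (an exact, memory-efficient attention algorithm that avoids materializing the full score matrix), Grouped Query Attention (GQA, where groups of query heads share a single key/value head to shrink the KV cache), Rotary Position Embeddings (RoPE, which encode relative position by rotating queries and keys), and quantization for efficient deployment. These are covered in depth in our full course materials.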
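PyTorch 2.0 and later ship a fused F.scaled_dot_product_attention that can dispatch to FlashAttention or memory-efficient kernels on supported hardware. The sketch below assumes PyTorch 2.0+ and checks, on random inputs with arbitrary toy sizes, that the built-in matches the manual causal computation from this post.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, d_k = 2, 4, 6, 8                        # arbitrary toy sizes
Q, K, V = (torch.randn(B, H, S, d_k) for _ in range(3))

# Manual causal attention, following ScaledDotProductAttention and create_causal_mask
mask = torch.tril(torch.ones(S, S))              # 1 = may attend, 0 = blocked
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
scores = scores.masked_fill(mask == 0, float("-inf"))
manual = F.softmax(scores, dim=-1) @ V

# Fused built-in; is_causal=True applies the same lower-triangular mask
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print((manual - fused).abs().max())              # floating-point-level difference
```

The results agree up to floating-point tolerance, so the built-in can replace the manual attention step once you understand what it computes.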

Next Steps: Explore the Full Implementation

Want to dive deeper? This post covers concepts from Module 4 of our free 64-notebook LLM engineering course. The course includes complete, production-ready implementations.

Explore the full curriculum and run the notebooks yourself:

GitHub Repository: Module 4: Transformer Architecture

Common Questions

Why is attention “scaled”?

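Scaling by 1/sqrt(d_k) keeps the dot products from growing with the head dimension. If the components of a query q and key k are independent with mean 0 and variance 1, then q · k is a sum of d_k such products and has variance d_k. Without scaling, these large scores push softmax toward near one-hot distributions with tiny gradients; dividing by sqrt(d_k) brings the variance back to 1.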
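A quick numerical check, assuming query and key components are drawn independently from a standard normal distribution: the raw dot products have variance close to d_k, and dividing by sqrt(d_k) brings it back to about 1 regardless of head size.

```python
import torch

torch.manual_seed(0)
n = 10_000                                   # random (q, k) pairs per setting
results = {}
for d_k in (16, 64, 256):
    q = torch.randn(n, d_k)                  # components ~ N(0, 1), independent
    k = torch.randn(n, d_k)
    dots = (q * k).sum(dim=-1)               # n dot products q . k
    raw = dots.var().item()
    scaled = (dots / d_k ** 0.5).var().item()
    results[d_k] = (raw, scaled)
    print(f"d_k={d_k:3d}  var(q.k)={raw:6.1f}  var(q.k / sqrt(d_k))={scaled:.2f}")
```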

Why use multiple heads instead of one big head?

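Multiple heads let the model attend to different parts of the representation space at the same time: each head works in its own learned subspace and can specialize in a different pattern, for example one head tracking nearby tokens while another tracks long-range references. Because each head has size d_k = d_model / num_heads, this uses the same number of projection parameters and roughly the same compute as one full-width head.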
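One way to see that extra heads add no parameters: the sketch below uses PyTorch's built-in nn.MultiheadAttention (not the MultiHeadAttention class from this post, though both use four d_model x d_model projections with biases). One head and eight heads have identical parameter counts, yet eight heads return eight separate attention maps. average_attn_weights=False, available since PyTorch 1.11, keeps the per-head weights instead of averaging them.

```python
import torch
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

d_model = 512
one_head = nn.MultiheadAttention(d_model, num_heads=1, batch_first=True)
eight_heads = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)
print(n_params(one_head), n_params(eight_heads))  # 1050624 1050624

x = torch.randn(2, 10, d_model)                   # (batch, seq, d_model)
_, w = eight_heads(x, x, x, average_attn_weights=False)
print(w.shape)  # torch.Size([2, 8, 10, 10]): one 10x10 attention map per head
```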

What’s the purpose of layer normalization?

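Layer normalization stabilizes training by normalizing each token's activation vector across its features to mean 0 and variance 1, then applying a learned per-feature gain and bias. This keeps activation magnitudes in a consistent range from layer to layer, so they are less likely to explode or vanish deep in the network. Unlike batch normalization, it does not depend on batch statistics, so it behaves the same during training and inference.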
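To see exactly what nn.LayerNorm computes, the sketch below reproduces it by hand on a random (batch, seq, d_model) tensor. The gain and bias are set to random values as a stand-in for learned ones, which shows that the final output is not constrained to mean 0 and variance 1; only the intermediate x_hat is.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
x = torch.randn(2, 5, d_model) * 3 + 1          # (batch, seq, d_model), deliberately off-center
ln = nn.LayerNorm(d_model)
with torch.no_grad():                           # stand-in for learned gain and bias
    ln.weight.uniform_(0.5, 1.5)
    ln.bias.uniform_(-0.5, 0.5)

# Manual LayerNorm: statistics over the feature dimension of each token separately
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance, as nn.LayerNorm uses
x_hat = (x - mean) / torch.sqrt(var + ln.eps)       # mean 0, variance ~1 per token
manual = x_hat * ln.weight + ln.bias                # learned per-feature gain and bias

print(torch.allclose(manual, ln(x), atol=1e-5))
```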

Conclusion

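You have now implemented the core of a transformer: scaled dot-product attention, multi-head attention, transformer blocks, positional encoding, a full encoder stack, and causal masking. The real payoff comes from understanding not just how each piece works, but why each design choice matters.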

Ready to go deeper? Clone the repository, run the notebooks, and start building transformers yourself. The best way to understand these models is to implement them, experiment, and see your own implementations work.
