Step 07: Multi-head attention
Learn to use multi-head attention, enabling the model to attend to different representation subspaces.
Building multi-head attention
In this step, you’ll implement the GPT2MultiHeadAttention class that runs 12
attention operations in parallel. Instead of computing attention once over the
full 768-dimensional space, you split the dimensions into 12 heads of 64
dimensions each. Each head learns to focus on different patterns.
GPT-2 uses 12 heads with 768-dimensional embeddings, giving each head 768 ÷ 12 = 64 dimensions. The Q, K, V tensors are reshaped to split the embedding dimension across heads, attention is computed for all heads in parallel, then the outputs are concatenated back together. This happens in a single efficient operation using tensor reshaping and broadcasting.
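As a quick sanity check on that arithmetic, here is a minimal plain-Python sketch (the names mirror the config fields you'll use later):

n_embd, n_head = 768, 12
head_dim = n_embd // n_head         # 64 dimensions per head
assert head_dim * n_head == n_embd  # the heads must tile the embedding exactly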
Multiple heads let the model learn complementary attention strategies. Different heads can specialize in different relationships: one might attend to adjacent tokens, another to syntactic patterns, another to semantic similarity. This increases the model’s capacity without dramatically increasing computation.
Understanding the architecture
Multi-head attention splits the embedding dimension, computes attention independently for each head, then merges the results. This requires careful tensor reshaping to organize the computation efficiently.
Head splitting: Transform from [batch, seq_length, 768] to
[batch, 12, seq_length, 64]. First reshape to add the head dimension:
[batch, seq_length, 12, 64]. Then transpose to move heads before sequence:
[batch, 12, seq_length, 64]. Now each of the 12 heads operates independently
on its 64-dimensional subspace.
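Here is the same transformation as a minimal NumPy sketch (illustrative only; the actual implementation below uses MAX tensor operations):

import numpy as np

batch, seq_length = 2, 5
x = np.zeros((batch, seq_length, 768))    # [batch, seq_length, n_embd]
x = x.reshape(batch, seq_length, 12, 64)  # add the head dimension
x = x.transpose(0, 2, 1, 3)               # move heads before sequence
print(x.shape)                            # (2, 12, 5, 64)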
Parallel attention: With shape [batch, num_heads, seq_length, head_dim],
you can compute attention for all heads simultaneously. The matrix
multiplication Q @ K^T operates on the last two dimensions
[seq_length, head_dim] @ [head_dim, seq_length], broadcasting across the batch
and head dimensions. All 12 heads are computed in a single efficient operation.
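A minimal NumPy sketch of that broadcasting behavior (illustrative only):

import numpy as np

q = np.random.randn(2, 12, 5, 64)  # [batch, num_heads, seq_length, head_dim]
k = np.random.randn(2, 12, 5, 64)
scores = q @ k.swapaxes(-1, -2)    # matmul over the last two dims, broadcast over batch and heads
print(scores.shape)                # (2, 12, 5, 5)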
Head merging: Reverse the splitting to go from
[batch, 12, seq_length, 64] back to [batch, seq_length, 768]. First
transpose to [batch, seq_length, 12, 64], then reshape to flatten the head
dimension: [batch, seq_length, 768]. This concatenates all head outputs back
into the original dimension.
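And the merge step in the same NumPy sketch style (illustrative only):

import numpy as np

out = np.zeros((2, 12, 5, 64))    # [batch, num_heads, seq_length, head_dim]
out = out.transpose(0, 2, 1, 3)   # [batch, seq_length, num_heads, head_dim]
out = out.reshape(2, 5, 12 * 64)  # [batch, seq_length, 768]
print(out.shape)                  # (2, 5, 768)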
Output projection (c_proj): After merging heads, apply a learned linear
transformation that maps [batch, seq_length, 768] to
[batch, seq_length, 768]. This lets the model mix information across heads,
combining the different perspectives each head learned.
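For intuition, the output projection is one more matrix multiply over the merged dimension. A minimal NumPy sketch, where W and b merely stand in for the learned c_proj parameters:

import numpy as np

merged = np.random.randn(2, 5, 768)  # concatenated head outputs
W = np.random.randn(768, 768)        # stand-in for the c_proj weight
b = np.zeros(768)                    # stand-in for the c_proj bias
mixed = merged @ W + b               # every output feature mixes all 12 heads
print(mixed.shape)                   # (2, 5, 768)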
The layer names c_attn (combined Q/K/V projection) and c_proj (output
projection) match Hugging Face’s GPT-2 implementation. This naming is essential
for loading pretrained weights.
You’ll use the following MAX operations to complete this task:
Linear layers:
Linear(in_features, out_features, bias=True): Q/K/V and output projections
Tensor operations:
tensor.reshape(new_shape): Splits or merges the head dimension
tensor.transpose(axis1, axis2): Rearranges dimensions for parallel attention
F.split(tensor, split_sizes, axis): Divides Q/K/V from the combined projection
Implementing multi-head attention
You’ll create the GPT2MultiHeadAttention class with helper methods for splitting and merging heads. The implementation builds on the attention mechanism from Step 02, extending it to work with multiple heads in parallel.
First, import the required modules. You’ll need math for scaling, functional as F for operations, Tensor for type hints, device and dtype utilities, and Linear and Module from MAX’s neural network module. You’ll also reuse the causal_mask function from Step 02.
In the __init__ method, create the projection layers and store configuration:
- Combined Q/K/V projection: Linear(embed_dim, 3 * embed_dim, bias=True), stored as self.c_attn
- Output projection: Linear(embed_dim, embed_dim, bias=True), stored as self.c_proj
- Store self.num_heads (12) and self.head_dim (64) from the config
- Calculate self.split_size for splitting Q, K, V later
Implement _split_heads to reshape for parallel attention:
- Calculate the new shape by replacing the last dimension: tensor.shape[:-1] + [num_heads, attn_head_size]
- Reshape to add the head dimension: tensor.reshape(new_shape)
- Transpose to move heads to position 1: tensor.transpose(-3, -2)
- The result has shape [batch, num_heads, seq_length, head_size]
Implement _merge_heads to concatenate head outputs:
- Transpose to move heads back: tensor.transpose(-3, -2)
- Calculate the flattened shape: tensor.shape[:-2] + [num_heads * attn_head_size]
- Reshape to merge heads: tensor.reshape(new_shape)
- The result has shape [batch, seq_length, n_embd]
Implement _attn to compute scaled dot-product attention for all heads:
- Compute attention scores: query @ key.transpose(-2, -1)
- Scale by the square root of the head dimension
- Apply the causal mask to prevent attending to future positions
- Apply softmax to get attention weights
- Multiply the weights by the values: attn_weights @ value (see the sketch after this list)
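A standalone NumPy sketch of these five steps (illustrative only; your _attn will use MAX operations and the shared causal_mask helper instead):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attn_sketch(query, key, value):
    # query, key, value: [batch, num_heads, seq_length, head_dim]
    head_dim = query.shape[-1]
    seq_length = query.shape[-2]
    scores = query @ key.swapaxes(-1, -2)  # 1. attention scores
    scores = scores / np.sqrt(head_dim)    # 2. scale
    mask = np.triu(np.full((seq_length, seq_length), -np.inf), 1)
    scores = scores + mask                 # 3. causal mask blocks future positions
    weights = softmax(scores)              # 4. attention weights
    return weights @ value                 # 5. weighted sum of values

q = np.random.randn(2, 12, 5, 64)
k = np.random.randn(2, 12, 5, 64)
v = np.random.randn(2, 12, 5, 64)
print(attn_sketch(q, k, v).shape)  # (2, 12, 5, 64)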
In the forward method, orchestrate the complete multi-head attention:
- Project to Q/K/V: qkv = self.c_attn(hidden_states)
- Split into separate tensors: F.split(qkv, [self.split_size, self.split_size, self.split_size], axis=-1)
- Split heads for each: query = self._split_heads(query, self.num_heads, self.head_dim) (repeat for key and value)
- Compute attention: attn_output = self._attn(query, key, value)
- Merge heads: attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
- Final projection: return self.c_proj(attn_output)
Implementation (step_07.py):
"""
Step 07: Multi-head Attention
Implement multi-head attention that splits Q/K/V into multiple heads,
computes attention in parallel for each head, and merges the results.
Tasks:
1. Import required modules (math, F, Tensor, Linear, Module, etc.)
2. Create c_attn and c_proj linear layers
3. Implement _split_heads: reshape and transpose to add head dimension
4. Implement _merge_heads: transpose and reshape to remove head dimension
5. Implement _attn: compute attention for all heads in parallel
6. Implement forward pass: project -> split -> attend -> merge -> project
Run: pixi run s07
"""
# TODO: Import required modules
# Hint: You'll need math for scaling
# Hint: You'll need functional as F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor, Device from max.driver, and DType from max.dtype
# Hint: You'll need Dim, DimLike from max.graph
# Hint: You'll also need Linear and Module from max.nn.module_v3
from solutions.solution_01 import GPT2Config
# TODO: Copy causal_mask function from solution_02.py
# This is the same function you implemented in Step 02
class GPT2MultiHeadAttention(Module):
"""Multi-head attention for GPT-2."""
def __init__(self, config: GPT2Config):
super().__init__()
self.embed_dim = config.n_embd
self.num_heads = config.n_head
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
# TODO: Create combined Q/K/V projection
# Hint: Use Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
self.c_attn = None
# TODO: Create output projection
# Hint: Use Linear(self.embed_dim, self.embed_dim, bias=True)
self.c_proj = None
def _split_heads(self, tensor, num_heads, attn_head_size):
"""Split the last dimension into (num_heads, head_size).
Args:
tensor: Input tensor, shape [batch, seq_length, n_embd]
num_heads: Number of attention heads
attn_head_size: Dimension of each head
Returns:
Tensor with shape [batch, num_heads, seq_length, head_size]
"""
# TODO: Add head dimension
# Hint: new_shape = tensor.shape[:-1] + [num_heads, attn_head_size]
# Hint: tensor = tensor.reshape(new_shape)
pass
# TODO: Move heads dimension to position 1
# Hint: return tensor.transpose(-3, -2)
return None
def _merge_heads(self, tensor, num_heads, attn_head_size):
"""Merge attention heads back to original shape.
Args:
tensor: Input tensor, shape [batch, num_heads, seq_length, head_size]
num_heads: Number of attention heads
attn_head_size: Dimension of each head
Returns:
Tensor with shape [batch, seq_length, n_embd]
"""
# TODO: Move heads dimension back
# Hint: tensor = tensor.transpose(-3, -2)
pass
# TODO: Flatten head dimensions
# Hint: new_shape = tensor.shape[:-2] + [num_heads * attn_head_size]
# Hint: return tensor.reshape(new_shape)
return None
def _attn(self, query, key, value):
"""Compute attention for all heads in parallel.
Args:
query: Query tensor, shape [batch, num_heads, seq_length, head_size]
key: Key tensor, shape [batch, num_heads, seq_length, head_size]
value: Value tensor, shape [batch, num_heads, seq_length, head_size]
Returns:
Attention output, shape [batch, num_heads, seq_length, head_size]
"""
# TODO: Implement attention computation
# The same 5-step process: scores, scale, mask, softmax, weighted sum
# Hint: Compute attention scores: query @ key.transpose(-1, -2)
# Hint: Scale by sqrt(head_dim): attn_weights / math.sqrt(head_dim)
# Hint: Apply causal mask using causal_mask function
# Hint: Apply softmax: F.softmax(attn_weights)
# Hint: Weighted sum: attn_weights @ value
return None
def __call__(self, hidden_states):
"""Apply multi-head attention.
Args:
hidden_states: Input tensor, shape [batch, seq_length, n_embd]
Returns:
Attention output, shape [batch, seq_length, n_embd]
"""
# TODO: Project to Q, K, V
# Hint: qkv = self.c_attn(hidden_states)
# Hint: query, key, value = F.split(qkv, [self.split_size, self.split_size, self.split_size], axis=-1)
pass
# TODO: Split into multiple heads
# Hint: query = self._split_heads(query, self.num_heads, self.head_dim)
# Hint: key = self._split_heads(key, self.num_heads, self.head_dim)
# Hint: value = self._split_heads(value, self.num_heads, self.head_dim)
pass
# TODO: Apply attention
# Hint: attn_output = self._attn(query, key, value)
pass
# TODO: Merge heads back
# Hint: attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
pass
# TODO: Output projection
# Hint: attn_output = self.c_proj(attn_output)
# Hint: return attn_output
return None
Validation
Run pixi run s07 to verify your implementation.
Solution:
"""
Solution for Step 07: Multi-head Attention
This module implements multi-head attention, which allows the model to jointly
attend to information from different representation subspaces at different positions.
"""
import math
from max.driver import Device
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import Dim, DimLike
from max.nn.module_v3 import Linear, Module
from solutions.solution_01 import GPT2Config
@F.functional
def causal_mask(
sequence_length: DimLike,
num_tokens: DimLike,
*,
dtype: DType,
device: Device,
):
"""Create a causal attention mask."""
n = Dim(sequence_length) + num_tokens
mask = Tensor.constant(float("-inf"), dtype=dtype, device=device)
mask = F.broadcast_to(mask, shape=(sequence_length, n))
return F.band_part(mask, num_lower=None, num_upper=0, exclude=True)
class GPT2MultiHeadAttention(Module):
"""Multi-head attention for GPT-2, matching HuggingFace structure."""
def __init__(self, config: GPT2Config):
"""Initialize multi-head attention.
Args:
config: GPT2Config containing n_embd and n_head
"""
super().__init__()
self.embed_dim = config.n_embd
self.num_heads = config.n_head
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
# Combined Q/K/V projection
self.c_attn = Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
# Output projection
self.c_proj = Linear(self.embed_dim, self.embed_dim, bias=True)
def _split_heads(self, tensor, num_heads, attn_head_size):
"""Split the last dimension into (num_heads, head_size).
Transforms shape from [batch, seq_length, n_embd]
to [batch, num_heads, seq_length, head_size]
Args:
tensor: Input tensor, shape [batch, seq_length, n_embd]
num_heads: Number of attention heads
attn_head_size: Dimension of each head
Returns:
Tensor with shape [batch, num_heads, seq_length, head_size]
"""
# Add head dimension: [batch, seq_length, n_embd] -> [batch, seq_length, num_heads, head_size]
new_shape = tensor.shape[:-1] + [num_heads, attn_head_size]
tensor = tensor.reshape(new_shape)
# Move heads dimension: [batch, seq_length, num_heads, head_size] -> [batch, num_heads, seq_length, head_size]
return tensor.transpose(-3, -2)
def _merge_heads(self, tensor, num_heads, attn_head_size):
"""Merge attention heads back to original shape.
Transforms shape from [batch, num_heads, seq_length, head_size]
to [batch, seq_length, n_embd]
Args:
tensor: Input tensor, shape [batch, num_heads, seq_length, head_size]
num_heads: Number of attention heads
attn_head_size: Dimension of each head
Returns:
Tensor with shape [batch, seq_length, n_embd]
"""
# Move heads dimension back: [batch, num_heads, seq_length, head_size] -> [batch, seq_length, num_heads, head_size]
tensor = tensor.transpose(-3, -2)
# Flatten head dimensions: [batch, seq_length, num_heads, head_size] -> [batch, seq_length, n_embd]
new_shape = tensor.shape[:-2] + [num_heads * attn_head_size]
return tensor.reshape(new_shape)
def _attn(self, query, key, value):
"""Compute attention for all heads in parallel.
Args:
query: Query tensor, shape [batch, num_heads, seq_length, head_size]
key: Key tensor, shape [batch, num_heads, seq_length, head_size]
value: Value tensor, shape [batch, num_heads, seq_length, head_size]
Returns:
Attention output, shape [batch, num_heads, seq_length, head_size]
"""
# Compute attention scores
attn_weights = query @ key.transpose(-1, -2)
# Scale attention weights
attn_weights = attn_weights / math.sqrt(int(value.shape[-1]))
# Apply causal mask
seq_len = query.shape[-2]
mask = causal_mask(seq_len, 0, dtype=query.dtype, device=query.device)
attn_weights = attn_weights + mask
# Softmax and weighted sum
attn_weights = F.softmax(attn_weights)
attn_output = attn_weights @ value
return attn_output
def __call__(self, hidden_states):
"""Apply multi-head attention.
Args:
hidden_states: Input tensor, shape [batch, seq_length, n_embd]
Returns:
Attention output, shape [batch, seq_length, n_embd]
"""
# Project to Q, K, V
qkv = self.c_attn(hidden_states)
query, key, value = F.split(
qkv, [self.split_size, self.split_size, self.split_size], axis=-1
)
# Split into multiple heads
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
# Apply attention
attn_output = self._attn(query, key, value)
# Merge heads back
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
# Output projection
attn_output = self.c_proj(attn_output)
return attn_output
Next: In Step 08, you’ll implement residual connections and layer normalization to enable training deep transformer networks.