Multi-head attention
Implement scaled dot-product attention with multiple heads, enabling the model to attend to different representation subspaces.
GPT2MultiHeadAttention runs 12 attention operations in parallel. Instead of
computing attention once over the full 768-dimensional space, it splits the
dimensions into 12 heads of 64 dimensions each. Each head independently learns
to focus on different patterns—syntactic structure, semantic similarity,
positional relationships, and so on.
The Q, K, V tensors are reshaped to split the embedding across the heads, attention is computed for all heads in parallel via broadcasting, and the head outputs are concatenated back together. The whole computation happens in a single efficient sequence of tensor operations.
Head splitting and merging
Splitting transforms from [batch, seq_length, 768] to
[batch, 12, seq_length, 64]. First reshape to add the head dimension:
[batch, seq_length, 12, 64], then transpose to move heads before the sequence
dimension: [batch, 12, seq_length, 64]. Now each of the 12 heads operates
independently on its 64-dimensional subspace.
Merging reverses the process: transpose back to
[batch, seq_length, 12, 64], then reshape to flatten the head dimension:
[batch, seq_length, 768]. This concatenates all head outputs back into the
original dimension.
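The two reshapes can be sketched with NumPy (standing in here for the MAX tensor ops; the shapes and sizes are GPT-2's):

```python
import numpy as np

# Shape round trip for head splitting and merging, using GPT-2's sizes:
# 12 heads of 64 dimensions each, embedding dimension 768.
batch, seq_length, num_heads, head_dim = 2, 10, 12, 64
embed_dim = num_heads * head_dim  # 768

x = np.random.default_rng(0).normal(size=(batch, seq_length, embed_dim))

# Split: [batch, seq, 768] -> [batch, seq, 12, 64] -> [batch, 12, seq, 64]
split = x.reshape(batch, seq_length, num_heads, head_dim).transpose(0, 2, 1, 3)
assert split.shape == (batch, num_heads, seq_length, head_dim)

# Merge: [batch, 12, seq, 64] -> [batch, seq, 12, 64] -> [batch, seq, 768]
merged = split.transpose(0, 2, 1, 3).reshape(batch, seq_length, embed_dim)
assert merged.shape == (batch, seq_length, embed_dim)

# The round trip is lossless: merging undoes splitting exactly.
assert np.allclose(merged, x)
```

Because splitting is a reshape plus a transpose, merging with the inverse transpose and reshape recovers the original tensor bit-for-bit.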
Scaled dot-product attention
With shape [batch, num_heads, seq_length, head_dim], computing attention for
all heads simultaneously is just a matrix multiplication across the last two
dimensions. The scaling factor 1 / sqrt(head_dim) prevents the dot products
from growing too large as head dimension increases, which would push softmax
into regions with very small gradients.
The causal mask from Section 3 is added to the attention scores before softmax, masking out future positions.
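The per-head computation can be sketched in NumPy (a minimal stand-in for the MAX ops; `np.triu` plays the role of Section 3's `causal_mask`):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """Causal scaled dot-product attention for all heads at once.

    q, k, v have shape [batch, num_heads, seq_length, head_dim].
    """
    head_dim = q.shape[-1]
    # Scores: [batch, heads, seq, seq], scaled by 1/sqrt(head_dim)
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(head_dim)
    # Causal mask: -inf above the diagonal blocks attention to future positions
    seq = q.shape[-2]
    scores = scores + np.triu(np.full((seq, seq), -np.inf), k=1)
    # Softmax over the key axis (numerically stabilized)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

Note that position 0 can only attend to itself, so its output is exactly its own value vector; every later position mixes values from itself and earlier positions only.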
The output projection (c_proj) then mixes information across heads, combining
the different perspectives each head learned into a single representation.
The layer names c_attn (combined Q/K/V projection) and c_proj (output
projection) match Hugging Face’s GPT-2 implementation for weight loading.
The code
class GPT2MultiHeadAttention(Module):  # type: ignore[type-arg]
    """Exact HuggingFace GPT-2 attention structure"""

    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        # Layer names match Hugging Face's GPT-2 for weight loading
        self.c_attn = Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
        self.c_proj = Linear(self.embed_dim, self.embed_dim, bias=True)

    def _attn(
        self,
        query: Tensor | TensorValue,
        key: Tensor | TensorValue,
        value: Tensor | TensorValue,
    ) -> Tensor | TensorValue:
        # Raw scores: [batch, heads, seq, seq]
        attn_weights = query @ key.transpose(-1, -2)
        # Scale attention weights by 1/sqrt(head_dim)
        attn_weights = attn_weights / math.sqrt(int(value.shape[-1]))
        # Apply causal mask so no position attends to the future
        seq_len = query.shape[-2]
        mask = causal_mask(seq_len, 0, dtype=query.dtype, device=query.device)
        attn_weights = attn_weights + mask
        attn_weights = F.softmax(attn_weights)
        attn_output = attn_weights @ value
        return attn_output

    def _split_heads(
        self, tensor: Tensor | TensorValue, num_heads: int, attn_head_size: int
    ) -> Tensor | TensorValue:
        """Split the last dimension into (num_heads, head_size)"""
        new_shape = list(tensor.shape[:-1]) + [num_heads, attn_head_size]
        tensor = tensor.reshape(new_shape)
        return tensor.transpose(-3, -2)  # (batch, head, seq_length, head_features)

    def _merge_heads(
        self, tensor: Tensor | TensorValue, num_heads: int, attn_head_size: int
    ) -> Tensor | TensorValue:
        """Merge attention heads back into a single embedding dimension"""
        tensor = tensor.transpose(-3, -2)
        new_shape = list(tensor.shape[:-2]) + [num_heads * attn_head_size]
        return tensor.reshape(new_shape)

    def forward(self, hidden_states: Tensor) -> Tensor:
        # Fused projection, then split into Q, K, V along the last axis
        split_result = F.split(
            self.c_attn(hidden_states),
            [self.split_size, self.split_size, self.split_size],
            axis=2,
        )
        query = cast(Tensor | TensorValue, split_result[0])
        key = cast(Tensor | TensorValue, split_result[1])
        value = cast(Tensor | TensorValue, split_result[2])
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        attn_output = self._attn(query, key, value)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(cast(Tensor, attn_output))
        return cast(Tensor, attn_output)
F.split divides the combined Q/K/V projection into three equal tensors along
the last axis. The cast calls are needed because MAX’s type system requires
explicit casts at certain boundaries between Tensor and TensorValue.
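The split itself is easy to picture with NumPy's `np.split` (a stand-in for `F.split`; the three equal widths correspond to the fused Q, K, V columns of c_attn's output):

```python
import numpy as np

# The fused c_attn output has width 3 * 768; cutting it into three equal
# pieces along the last axis yields the query, key, and value tensors.
batch, seq_length, embed_dim = 2, 10, 768
qkv = np.random.default_rng(3).normal(size=(batch, seq_length, 3 * embed_dim))

query, key, value = np.split(qkv, 3, axis=2)  # axis=2 is the last axis here
assert query.shape == (batch, seq_length, embed_dim)

# Splitting is a pure slicing operation: concatenating recovers the input.
assert np.allclose(np.concatenate([query, key, value], axis=2), qkv)
```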
Next: Section 5 implements layer normalization, which normalizes activations before each sublayer in the transformer block.