Multi-head attention

Implement scaled dot-product attention with multiple heads, enabling the model to attend to different representation subspaces.

GPT2MultiHeadAttention runs 12 attention operations in parallel. Instead of computing attention once over the full 768-dimensional space, it splits the embedding into 12 heads of 768 ÷ 12 = 64 dimensions each. Each head independently learns to focus on different patterns: syntactic structure, semantic similarity, positional relationships, and so on.

The Q, K, V tensors are reshaped to split the embedding across the heads, attention is computed for all heads in parallel via broadcasting, and the outputs are concatenated back together. The whole computation happens in a single efficient sequence of tensor operations.
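To make the data flow concrete, here is how the shapes evolve through one forward pass, sketched for batch size 1 and a sequence of 10 tokens (the numbers follow directly from the description above):

hidden_states          [1, 10, 768]
c_attn output          [1, 10, 2304]   # Q, K, V concatenated: 3 × 768
Q, K, V after split    [1, 10, 768]    # each
after head split       [1, 12, 10, 64]
Q @ K^T scores         [1, 12, 10, 10]
scores @ V             [1, 12, 10, 64]
after head merge       [1, 10, 768]
c_proj output          [1, 10, 768]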

Head splitting and merging

Splitting transforms from [batch, seq_length, 768] to [batch, 12, seq_length, 64]. First reshape to add the head dimension: [batch, seq_length, 12, 64], then transpose to move heads before the sequence dimension: [batch, 12, seq_length, 64]. Now each of the 12 heads operates independently on its 64-dimensional subspace.

Merging reverses the process: transpose back to [batch, seq_length, 12, 64], then reshape to flatten the head dimension: [batch, seq_length, 768]. This concatenates all head outputs back into the original dimension.
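Because splitting and merging are exact inverses, the round trip reproduces the input bit for bit. A minimal NumPy sketch of the same reshape-and-transpose mechanics (NumPy is used here purely for illustration; the real implementation below operates on MAX tensors):

import numpy as np

batch, seq_length, n_embd, n_head = 2, 5, 768, 12
head_dim = n_embd // n_head  # 64

x = np.random.randn(batch, seq_length, n_embd)

# Split: [2, 5, 768] -> [2, 5, 12, 64] -> [2, 12, 5, 64]
heads = x.reshape(batch, seq_length, n_head, head_dim).transpose(0, 2, 1, 3)

# Merge: undo the transpose, then flatten heads back into one 768-dim axis
merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq_length, n_embd)

assert np.array_equal(x, merged)  # lossless round trip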

Scaled dot-product attention

With shape [batch, num_heads, seq_length, head_dim], computing attention for all heads simultaneously is just a matrix multiplication across the last two dimensions. The scaling factor 1 / sqrt(head_dim) prevents the dot products from growing too large as head dimension increases, which would push softmax into regions with very small gradients.
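To see why the scaling matters, measure the spread of dot products between random vectors as the dimension grows. A quick NumPy experiment (illustrative only, not part of the model code):

import numpy as np

rng = np.random.default_rng(0)
for head_dim in (64, 256, 1024):
    q = rng.standard_normal((10_000, head_dim))
    k = rng.standard_normal((10_000, head_dim))
    scores = (q * k).sum(axis=-1)
    print(head_dim, scores.std(), (scores / np.sqrt(head_dim)).std())

# The unscaled standard deviation grows like sqrt(head_dim)
# (roughly 8, 16, 32), while dividing by sqrt(head_dim) keeps it
# near 1, so softmax stays out of its saturated, small-gradient regime.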

The causal mask from Section 3 is added to the attention scores before softmax, masking out future positions.
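Concretely, the additive mask is 0 on and below the diagonal and -inf above it; adding -inf to a score drives its softmax weight to exactly zero. A small sketch for four tokens, assuming causal_mask from Section 3 follows this standard construction:

import numpy as np

seq_len = 4
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
# [[  0., -inf, -inf, -inf],
#  [  0.,   0., -inf, -inf],
#  [  0.,   0.,   0., -inf],
#  [  0.,   0.,   0.,   0.]]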

The output projection (c_proj) then mixes information across heads, combining the different perspectives each head learned.

The layer names c_attn (combined Q/K/V projection) and c_proj (output projection) match Hugging Face’s GPT-2 implementation for weight loading.
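If you have the transformers package installed, you can verify the names against the original checkpoint (the exact key list may vary slightly across transformers versions):

from transformers import GPT2Model

hf_model = GPT2Model.from_pretrained("gpt2")
print([k for k in hf_model.state_dict() if "h.0.attn" in k])
# Includes 'h.0.attn.c_attn.weight', 'h.0.attn.c_attn.bias',
#          'h.0.attn.c_proj.weight', 'h.0.attn.c_proj.bias'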

The code

class GPT2MultiHeadAttention(Module):  # type: ignore[type-arg]
    """Exact HuggingFace GPT-2 attention structure"""

    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # c_attn projects to Q, K, V in one matmul; c_proj mixes head outputs.
        # The names match the Hugging Face GPT-2 checkpoint for weight loading.
        self.c_attn = Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
        self.c_proj = Linear(self.embed_dim, self.embed_dim, bias=True)

    def _attn(
        self,
        query: Tensor | TensorValue,
        key: Tensor | TensorValue,
        value: Tensor | TensorValue,
    ) -> Tensor | TensorValue:
        # Raw scores: [..., seq, head_dim] @ [..., head_dim, seq] -> [..., seq, seq]
        attn_weights = query @ key.transpose(-1, -2)

        # Scale by 1 / sqrt(head_dim) to keep the score variance near 1
        attn_weights = attn_weights / math.sqrt(int(value.shape[-1]))

        # Add the causal mask (Section 3): -inf above the diagonal blocks
        # attention to future positions
        seq_len = query.shape[-2]
        mask = causal_mask(seq_len, 0, dtype=query.dtype, device=query.device)
        attn_weights = attn_weights + mask

        # Softmax over the last axis gives each query a distribution over keys
        attn_weights = F.softmax(attn_weights)
        # Weighted sum of values: [..., seq, seq] @ [..., seq, head_dim]
        attn_output = attn_weights @ value

        return attn_output

    def _split_heads(
        self, tensor: Tensor | TensorValue, num_heads: int, attn_head_size: int
    ) -> Tensor | TensorValue:
        """Split the last dimension into (num_heads, head_size)"""
        new_shape = list(tensor.shape[:-1]) + [num_heads, attn_head_size]
        tensor = tensor.reshape(new_shape)
        return tensor.transpose(-3, -2)  # (batch, head, seq_length, head_features)

    def _merge_heads(
        self, tensor: Tensor | TensorValue, num_heads: int, attn_head_size: int
    ) -> Tensor | TensorValue:
        """Merge attention heads back"""
        tensor = tensor.transpose(-3, -2)  # (batch, seq_length, head, head_features)
        new_shape = list(tensor.shape[:-2]) + [num_heads * attn_head_size]
        return tensor.reshape(new_shape)

    def forward(self, hidden_states: Tensor) -> Tensor:
        # One matmul produces Q, K, V concatenated ([batch, seq_length, 3 * 768]),
        # then split into three equal [batch, seq_length, 768] tensors
        split_result = F.split(
            self.c_attn(hidden_states),
            [self.split_size, self.split_size, self.split_size],
            axis=2,
        )
        query = cast(Tensor | TensorValue, split_result[0])
        key = cast(Tensor | TensorValue, split_result[1])
        value = cast(Tensor | TensorValue, split_result[2])

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        attn_output = self._attn(query, key, value)  # [batch, head, seq, head_dim]
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(cast(Tensor, attn_output))  # mix across heads

        return cast(Tensor, attn_output)


F.split divides the combined Q/K/V projection into three equal tensors along the last axis. The cast calls are needed because MAX’s type system requires explicit casts at certain boundaries between Tensor and TensorValue.

Next: Section 5 implements layer normalization, which normalizes activations before each sublayer in the transformer block.