Step 07: Multi-head attention

Learn to use multi-head attention, enabling the model to attend to different representation subspaces.

Building multi-head attention

In this step, you’ll implement the GPT2MultiHeadAttention class that runs 12 attention operations in parallel. Instead of computing attention once over the full 768-dimensional space, you split the dimensions into 12 heads of 64 dimensions each. Each head can learn to focus on different patterns.

GPT-2 uses 12 heads with 768-dimensional embeddings, giving each head 768 ÷ 12 = 64 dimensions. The Q, K, V tensors are reshaped to split the embedding dimension across heads, attention is computed for all heads in parallel, and the outputs are then concatenated back together. No loop over heads is needed: reshaping turns the heads into an extra batch-like dimension, and broadcasting lets a single batched matrix multiplication handle all of them.

Multiple heads let the model learn complementary attention strategies. Different heads can specialize in different relationships: for example, one head might attend to adjacent tokens, another to syntactic patterns, and another to semantic similarity. Because the heads split the 768 dimensions rather than adding new ones, the Q/K/V and output projections are the same size as in single-head attention, and the attention matrix multiplications perform the same total number of operations.
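To check the computation claim with concrete numbers, the sketch below counts multiply-accumulate operations for the two attention matrix multiplications (Q @ K^T and weights @ V). The sequence length of 1024 (GPT-2's maximum context) is an illustrative choice. Projections are left out because their shapes do not depend on the number of heads, and softmax and masking, which run once per head, are not counted.

```python
def attention_matmul_macs(seq_len: int, embed_dim: int, num_heads: int) -> int:
    """Multiply-accumulates for Q @ K^T and weights @ V over one sequence."""
    head_dim = embed_dim // num_heads
    # Per head: [seq, head_dim] @ [head_dim, seq] -> seq * seq * head_dim MACs
    scores = num_heads * seq_len * seq_len * head_dim
    # Per head: [seq, seq] @ [seq, head_dim] -> seq * seq * head_dim MACs
    weighted_sum = num_heads * seq_len * seq_len * head_dim
    return scores + weighted_sum


single = attention_matmul_macs(seq_len=1024, embed_dim=768, num_heads=1)
multi = attention_matmul_macs(seq_len=1024, embed_dim=768, num_heads=12)
print(single, multi, single == multi)
```

Both counts reduce to 2 × seq_len² × 768, since 12 × 64 = 1 × 768.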

Understanding the architecture

Multi-head attention splits the embedding dimension, computes attention independently for each head, then merges the results. This requires careful tensor reshaping to organize the computation efficiently.

Head splitting: Transform from [batch, seq_length, 768] to [batch, 12, seq_length, 64]. First reshape to add the head dimension: [batch, seq_length, 12, 64]. Then transpose to move heads before sequence: [batch, 12, seq_length, 64]. Now each of the 12 heads operates independently on its 64-dimensional subspace.
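NumPy can stand in for MAX tensors to check the splitting recipe. This sketch assumes NumPy's row-major reshape behaves like tensor.reshape, and uses np.swapaxes(…, -3, -2) to mirror tensor.transpose(-3, -2); the batch size of 2 and sequence length of 5 are arbitrary. It confirms that each head owns one contiguous 64-wide slice of the embedding.

```python
import numpy as np

batch, seq_length, num_heads, head_dim = 2, 5, 12, 64
embed_dim = num_heads * head_dim  # 768

rng = np.random.default_rng(0)
x = rng.standard_normal((batch, seq_length, embed_dim))

# Step 1: reshape [batch, seq, 768] -> [batch, seq, 12, 64]
split = x.reshape(batch, seq_length, num_heads, head_dim)
# Step 2: swap the sequence and head axes -> [batch, 12, seq, 64]
heads = np.swapaxes(split, -3, -2)
print(heads.shape)  # (2, 12, 5, 64)

# Head h sees exactly embedding dimensions h*64 through h*64 + 63.
for h in range(num_heads):
    assert np.array_equal(heads[:, h], x[..., h * head_dim:(h + 1) * head_dim])
```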

Parallel attention: With shape [batch, num_heads, seq_length, head_dim], you can compute attention for all heads simultaneously. The matrix multiplication Q @ K^T operates on the last two dimensions, [seq_length, head_dim] @ [head_dim, seq_length], producing a [seq_length, seq_length] score matrix per head while broadcasting across the batch and head dimensions. All 12 heads are computed in a single batched operation.
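The sketch below uses NumPy's @ operator (a stand-in for the MAX tensor matmul) to show that one batched multiplication over [batch, heads, seq, head_dim] gives the same scores as an explicit loop over every batch element and head. Shapes are small illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_heads, seq_length, head_dim = 2, 12, 5, 64
q = rng.standard_normal((batch, num_heads, seq_length, head_dim))
k = rng.standard_normal((batch, num_heads, seq_length, head_dim))

# One batched matmul: the last two axes multiply, batch and head axes broadcast.
scores = q @ np.swapaxes(k, -1, -2)  # [2, 12, 5, 5]

# Reference: an explicit loop over every (batch, head) pair.
looped = np.empty_like(scores)
for b in range(batch):
    for h in range(num_heads):
        looped[b, h] = q[b, h] @ k[b, h].T

print(scores.shape, np.allclose(scores, looped))
```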

Head merging: Reverse the splitting to go from [batch, 12, seq_length, 64] back to [batch, seq_length, 768]. First transpose to [batch, seq_length, 12, 64], then reshape to flatten the head dimension: [batch, seq_length, 768]. This concatenates all head outputs back into the original dimension.

Output projection (c_proj): After merging heads, apply a learned linear transformation that maps [batch, seq_length, 768] to [batch, seq_length, 768]. This lets the model mix information across heads, combining the different perspectives each head learned.
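To see why c_proj is what mixes the heads, split its weight matrix into 12 row blocks of 64: projecting the concatenated output equals summing each head's output through its own block, so every output feature can draw on every head. The sketch uses random NumPy arrays as hypothetical stand-ins for the learned weights and applies the projection as x @ w + b; MAX's Linear may store its weight in a different orientation.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_length, num_heads, head_dim = 5, 12, 64
embed_dim = num_heads * head_dim

# Per-head attention outputs, concatenated along the last axis (the merge step).
head_outputs = [rng.standard_normal((seq_length, head_dim)) for _ in range(num_heads)]
merged = np.concatenate(head_outputs, axis=-1)  # [5, 768]

# Hypothetical c_proj parameters, applied as merged @ w + b.
w = rng.standard_normal((embed_dim, embed_dim))
b = rng.standard_normal(embed_dim)
projected = merged @ w + b

# Each block of 64 rows in w reads from exactly one head, so the projection
# is a sum of per-head contributions: this is where the heads get combined.
per_head_sum = sum(
    head_outputs[h] @ w[h * head_dim:(h + 1) * head_dim] for h in range(num_heads)
) + b
print(np.allclose(projected, per_head_sum))
```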

The layer names c_attn (combined Q/K/V projection) and c_proj (output projection) match Hugging Face’s GPT-2 implementation. This naming is essential for loading pretrained weights.

MAX operations

You’ll use the following MAX operations to complete this task:

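Linear layers:

  • Linear(in_dim, out_dim, bias=True): Learned projection; used for c_attn (768 to 2304 dims) and c_proj (768 to 768 dims)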

Tensor operations:

  • tensor.reshape(new_shape): Splits or merges head dimension
  • tensor.transpose(axis1, axis2): Rearranges dimensions for parallel attention
  • F.split(tensor, split_sizes, axis): Divides Q/K/V from combined projection
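F.split takes chunk sizes rather than cut points. As a rough mental model, here is a hypothetical NumPy helper, split_by_sizes (not part of MAX), that cuts a [batch, seq_length, 2304] c_attn output into query, key, and value in that order. How F.split itself validates its arguments is not covered here.

```python
import numpy as np


def split_by_sizes(x, split_sizes, axis=-1):
    """Hypothetical NumPy stand-in for F.split: consecutive chunks of the given sizes."""
    if sum(split_sizes) != x.shape[axis]:
        raise ValueError("split sizes must add up to the axis length")
    # np.split expects cut points, not sizes: [768, 768, 768] -> cuts at [768, 1536]
    cut_points = np.cumsum(split_sizes)[:-1]
    return np.split(x, cut_points, axis=axis)


qkv = np.arange(2 * 3 * 2304, dtype=np.float32).reshape(2, 3, 2304)  # [batch, seq, 3 * 768]
query, key, value = split_by_sizes(qkv, [768, 768, 768], axis=-1)
print(query.shape, key.shape, value.shape)
```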

Implementing multi-head attention

You’ll create the GPT2MultiHeadAttention class with helper methods for splitting and merging heads. The implementation builds on the attention mechanism from Step 02, extending it to work with multiple heads in parallel.

First, import the required modules. You’ll need math for scaling; functional as F for tensor operations; Tensor, Device, DType, Dim, and DimLike, which the causal_mask function uses; and Linear and Module from MAX’s neural network module. You’ll also reuse the causal_mask function from Step 02.

In the __init__ method, create the projection layers and store configuration:

  • Combined Q/K/V projection: Linear(embed_dim, 3 * embed_dim, bias=True) stored as self.c_attn
  • Output projection: Linear(embed_dim, embed_dim, bias=True) stored as self.c_proj
  • Store self.num_heads (12, from config.n_head) and compute self.head_dim = embed_dim // num_heads (64)
  • Set self.split_size = embed_dim (768), the width of each of Q, K, and V within the combined c_attn output
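The configuration arithmetic is worth checking once by hand. This sketch uses a hypothetical MiniConfig dataclass in place of GPT2Config and counts the parameters each bias-enabled linear layer holds (weight plus bias). The divisibility check is an extra safeguard, since integer division would silently drop a remainder.

```python
from dataclasses import dataclass


@dataclass
class MiniConfig:
    """Hypothetical stand-in for GPT2Config with only the fields used here."""
    n_embd: int = 768
    n_head: int = 12


def linear_params(in_dim: int, out_dim: int, bias: bool = True) -> int:
    # Weight matrix plus an optional bias vector.
    return in_dim * out_dim + (out_dim if bias else 0)


config = MiniConfig()
embed_dim = config.n_embd
assert embed_dim % config.n_head == 0, "n_embd must divide evenly across heads"
head_dim = embed_dim // config.n_head  # 64
split_size = embed_dim  # width of each of Q, K, V inside the c_attn output

c_attn_params = linear_params(embed_dim, 3 * embed_dim)  # 768 -> 2304
c_proj_params = linear_params(embed_dim, embed_dim)      # 768 -> 768
print(head_dim, split_size, c_attn_params, c_proj_params)
```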

Implement _split_heads to reshape for parallel attention:

  • Calculate new shape by replacing the last dimension: tensor.shape[:-1] + [num_heads, attn_head_size]
  • Reshape to add the head dimension: tensor.reshape(new_shape)
  • Transpose to move heads to position 1: tensor.transpose(-3, -2)
  • Returns shape [batch, num_heads, seq_length, head_size]

Implement _merge_heads to concatenate head outputs:

  • Transpose to move heads back: tensor.transpose(-3, -2)
  • Calculate flattened shape: tensor.shape[:-2] + [num_heads * attn_head_size]
  • Reshape to merge heads: tensor.reshape(new_shape)
  • Returns shape [batch, seq_length, n_embd]
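The two helpers are exact inverses. This NumPy sketch mirrors their steps with np.swapaxes standing in for tensor.transpose; because NumPy shapes are tuples, it concatenates tuples where the MAX version concatenates lists.

```python
import numpy as np


def split_heads(tensor, num_heads, attn_head_size):
    # [batch, seq, n_embd] -> [batch, seq, heads, head_size] -> [batch, heads, seq, head_size]
    new_shape = tensor.shape[:-1] + (num_heads, attn_head_size)
    return np.swapaxes(tensor.reshape(new_shape), -3, -2)


def merge_heads(tensor, num_heads, attn_head_size):
    # [batch, heads, seq, head_size] -> [batch, seq, heads, head_size] -> [batch, seq, n_embd]
    tensor = np.swapaxes(tensor, -3, -2)
    new_shape = tensor.shape[:-2] + (num_heads * attn_head_size,)
    return tensor.reshape(new_shape)


x = np.random.default_rng(0).standard_normal((2, 5, 768))
heads = split_heads(x, 12, 64)
restored = merge_heads(heads, 12, 64)
print(heads.shape, restored.shape, np.array_equal(x, restored))
```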

Implement _attn to compute scaled dot-product attention for all heads:

  • Compute attention scores: query @ key.transpose(-2, -1)
  • Scale by square root of head dimension
  • Add the causal mask (0 on and below the diagonal, -inf above it) so no position can attend to future positions
  • Apply softmax to get attention weights
  • Multiply weights by values: attn_weights @ value
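Here is a NumPy sketch of those five steps for all heads at once. The additive mask built with np.triu plays the role of causal_mask (0 where attention is allowed, -inf above the diagonal). Softmax is written out by hand with the usual row-max subtraction because NumPy has no built-in softmax; unlike _attn, the sketch also returns the weights so they can be inspected.

```python
import math

import numpy as np


def causal_attention(query, key, value):
    """NumPy sketch of the five _attn steps; inputs are [batch, heads, seq, head_size]."""
    # 1. Scores: [.., seq, head_size] @ [.., head_size, seq] -> [.., seq, seq]
    scores = query @ np.swapaxes(key, -1, -2)
    # 2. Scale by sqrt(head_size)
    scores = scores / math.sqrt(value.shape[-1])
    # 3. Additive causal mask: 0 on/below the diagonal, -inf above it
    seq_len = query.shape[-2]
    scores = scores + np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
    # 4. Softmax over the last axis (subtracting the row max for stability)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # 5. Weighted sum of values
    return weights @ value, weights


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 12, 4, 64)) for _ in range(3))
out, weights = causal_attention(q, k, v)
print(out.shape)          # (1, 12, 4, 64)
print(weights[0, 0, 0])   # first query can only see itself: [1. 0. 0. 0.]
```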

In the __call__ method (the forward pass), orchestrate the complete multi-head attention:

  • Project to Q/K/V: qkv = self.c_attn(hidden_states)
  • Split into separate tensors: F.split(qkv, [self.split_size, self.split_size, self.split_size], axis=-1)
  • Split heads for each: query = self._split_heads(query, self.num_heads, self.head_dim) (repeat for key, value)
  • Compute attention: attn_output = self._attn(query, key, value)
  • Merge heads: attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  • Final projection: return self.c_proj(attn_output)
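Putting the pieces together, this NumPy sketch runs the whole forward pass with small random weights and then checks the property the causal mask exists for: changing the last token must leave every earlier output unchanged. The parameter layout (x @ w + b, weights shaped [in, out]) and the helper names are assumptions for illustration, not MAX's Linear internals.

```python
import math

import numpy as np


def causal_attention(q, k, v):
    # Scores, scale, additive causal mask, softmax, weighted sum.
    scores = q @ np.swapaxes(k, -1, -2) / math.sqrt(v.shape[-1])
    seq_len = q.shape[-2]
    scores = scores + np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def multi_head_attention(x, w_attn, b_attn, w_proj, b_proj, num_heads):
    """NumPy sketch of the forward pass: project -> split -> attend -> merge -> project."""
    embed_dim = x.shape[-1]
    head_dim = embed_dim // num_heads
    qkv = x @ w_attn + b_attn                          # [batch, seq, 3 * embed_dim]
    query, key, value = np.split(qkv, 3, axis=-1)      # three [batch, seq, embed_dim]

    def split_heads(t):                                # -> [batch, heads, seq, head_dim]
        return np.swapaxes(t.reshape(t.shape[:-1] + (num_heads, head_dim)), -3, -2)

    out = causal_attention(split_heads(query), split_heads(key), split_heads(value))
    out = np.swapaxes(out, -3, -2)                     # [batch, seq, heads, head_dim]
    out = out.reshape(out.shape[:-2] + (embed_dim,))   # [batch, seq, embed_dim]
    return out @ w_proj + b_proj


rng = np.random.default_rng(0)
embed_dim, num_heads = 768, 12
# Illustrative random parameters (std 0.02), zero biases.
params = (
    rng.standard_normal((embed_dim, 3 * embed_dim)) * 0.02,
    np.zeros(3 * embed_dim),
    rng.standard_normal((embed_dim, embed_dim)) * 0.02,
    np.zeros(embed_dim),
)
x = rng.standard_normal((2, 6, embed_dim))
y = multi_head_attention(x, *params, num_heads=num_heads)

# Causality check: perturb only the last token and compare earlier outputs.
x_changed = x.copy()
x_changed[:, -1] += 1.0
y_changed = multi_head_attention(x_changed, *params, num_heads=num_heads)
print(y.shape, np.allclose(y[:, :-1], y_changed[:, :-1]))
```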

Implementation (step_07.py):

"""
Step 07: Multi-head Attention

Implement multi-head attention that splits Q/K/V into multiple heads,
computes attention in parallel for each head, and merges the results.

Tasks:
1. Import required modules (math, F, Tensor, Linear, Module, etc.)
2. Create c_attn and c_proj linear layers
3. Implement _split_heads: reshape and transpose to add head dimension
4. Implement _merge_heads: transpose and reshape to remove head dimension
5. Implement _attn: compute attention for all heads in parallel
6. Implement forward pass: project -> split -> attend -> merge -> project

Run: pixi run s07
"""

# TODO: Import required modules
# Hint: You'll need math for scaling
# Hint: You'll need functional as F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor, Device from max.driver, and DType from max.dtype
# Hint: You'll need Dim, DimLike from max.graph
# Hint: You'll also need Linear and Module from max.nn.module_v3

from solutions.solution_01 import GPT2Config


# TODO: Copy causal_mask function from solution_02.py
# This is the same function you implemented in Step 02


class GPT2MultiHeadAttention(Module):
    """Multi-head attention for GPT-2."""

    def __init__(self, config: GPT2Config):
        super().__init__()

        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # TODO: Create combined Q/K/V projection
        # Hint: Use Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
        self.c_attn = None

        # TODO: Create output projection
        # Hint: Use Linear(self.embed_dim, self.embed_dim, bias=True)
        self.c_proj = None

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """Split the last dimension into (num_heads, head_size).

        Args:
            tensor: Input tensor, shape [batch, seq_length, n_embd]
            num_heads: Number of attention heads
            attn_head_size: Dimension of each head

        Returns:
            Tensor with shape [batch, num_heads, seq_length, head_size]
        """
        # TODO: Add head dimension
        # Hint: new_shape = tensor.shape[:-1] + [num_heads, attn_head_size]
        # Hint: tensor = tensor.reshape(new_shape)
        pass

        # TODO: Move heads dimension to position 1
        # Hint: return tensor.transpose(-3, -2)
        return None

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """Merge attention heads back to original shape.

        Args:
            tensor: Input tensor, shape [batch, num_heads, seq_length, head_size]
            num_heads: Number of attention heads
            attn_head_size: Dimension of each head

        Returns:
            Tensor with shape [batch, seq_length, n_embd]
        """
        # TODO: Move heads dimension back
        # Hint: tensor = tensor.transpose(-3, -2)
        pass

        # TODO: Flatten head dimensions
        # Hint: new_shape = tensor.shape[:-2] + [num_heads * attn_head_size]
        # Hint: return tensor.reshape(new_shape)
        return None

    def _attn(self, query, key, value):
        """Compute attention for all heads in parallel.

        Args:
            query: Query tensor, shape [batch, num_heads, seq_length, head_size]
            key: Key tensor, shape [batch, num_heads, seq_length, head_size]
            value: Value tensor, shape [batch, num_heads, seq_length, head_size]

        Returns:
            Attention output, shape [batch, num_heads, seq_length, head_size]
        """
        # TODO: Implement attention computation
        # The same 5-step process: scores, scale, mask, softmax, weighted sum
        # Hint: Compute attention scores: query @ key.transpose(-1, -2)
        # Hint: Scale by sqrt(head_dim): attn_weights / math.sqrt(self.head_dim)
        # Hint: Add causal_mask(query.shape[-2], 0, dtype=query.dtype, device=query.device) to the scores
        # Hint: Apply softmax: F.softmax(attn_weights)
        # Hint: Weighted sum: attn_weights @ value
        return None

    def __call__(self, hidden_states):
        """Apply multi-head attention.

        Args:
            hidden_states: Input tensor, shape [batch, seq_length, n_embd]

        Returns:
            Attention output, shape [batch, seq_length, n_embd]
        """
        # TODO: Project to Q, K, V
        # Hint: qkv = self.c_attn(hidden_states)
        # Hint: query, key, value = F.split(qkv, [self.split_size, self.split_size, self.split_size], axis=-1)
        pass

        # TODO: Split into multiple heads
        # Hint: query = self._split_heads(query, self.num_heads, self.head_dim)
        # Hint: key = self._split_heads(key, self.num_heads, self.head_dim)
        # Hint: value = self._split_heads(value, self.num_heads, self.head_dim)
        pass

        # TODO: Apply attention
        # Hint: attn_output = self._attn(query, key, value)
        pass

        # TODO: Merge heads back
        # Hint: attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        pass

        # TODO: Output projection
        # Hint: attn_output = self.c_proj(attn_output)
        # Hint: return attn_output
        return None

Validation

Run pixi run s07 to verify your implementation.

Solution:
"""
Solution for Step 07: Multi-head Attention

This module implements multi-head attention, which allows the model to jointly
attend to information from different representation subspaces at different positions.
"""

import math

from max.driver import Device
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import Dim, DimLike
from max.nn.module_v3 import Linear, Module

from solutions.solution_01 import GPT2Config


@F.functional
def causal_mask(
    sequence_length: DimLike,
    num_tokens: DimLike,
    *,
    dtype: DType,
    device: Device,
):
    """Create a causal attention mask."""
    n = Dim(sequence_length) + num_tokens
    mask = Tensor.constant(float("-inf"), dtype=dtype, device=device)
    mask = F.broadcast_to(mask, shape=(sequence_length, n))
    return F.band_part(mask, num_lower=None, num_upper=0, exclude=True)


class GPT2MultiHeadAttention(Module):
    """Multi-head attention for GPT-2, matching HuggingFace structure."""

    def __init__(self, config: GPT2Config):
        """Initialize multi-head attention.

        Args:
            config: GPT2Config containing n_embd and n_head
        """
        super().__init__()

        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # Combined Q/K/V projection
        self.c_attn = Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
        # Output projection
        self.c_proj = Linear(self.embed_dim, self.embed_dim, bias=True)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """Split the last dimension into (num_heads, head_size).

        Transforms shape from [batch, seq_length, n_embd]
        to [batch, num_heads, seq_length, head_size]

        Args:
            tensor: Input tensor, shape [batch, seq_length, n_embd]
            num_heads: Number of attention heads
            attn_head_size: Dimension of each head

        Returns:
            Tensor with shape [batch, num_heads, seq_length, head_size]
        """
        # Add head dimension: [batch, seq_length, n_embd] -> [batch, seq_length, num_heads, head_size]
        new_shape = tensor.shape[:-1] + [num_heads, attn_head_size]
        tensor = tensor.reshape(new_shape)
        # Move heads dimension: [batch, seq_length, num_heads, head_size] -> [batch, num_heads, seq_length, head_size]
        return tensor.transpose(-3, -2)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """Merge attention heads back to original shape.

        Transforms shape from [batch, num_heads, seq_length, head_size]
        to [batch, seq_length, n_embd]

        Args:
            tensor: Input tensor, shape [batch, num_heads, seq_length, head_size]
            num_heads: Number of attention heads
            attn_head_size: Dimension of each head

        Returns:
            Tensor with shape [batch, seq_length, n_embd]
        """
        # Move heads dimension back: [batch, num_heads, seq_length, head_size] -> [batch, seq_length, num_heads, head_size]
        tensor = tensor.transpose(-3, -2)
        # Flatten head dimensions: [batch, seq_length, num_heads, head_size] -> [batch, seq_length, n_embd]
        new_shape = tensor.shape[:-2] + [num_heads * attn_head_size]
        return tensor.reshape(new_shape)

    def _attn(self, query, key, value):
        """Compute attention for all heads in parallel.

        Args:
            query: Query tensor, shape [batch, num_heads, seq_length, head_size]
            key: Key tensor, shape [batch, num_heads, seq_length, head_size]
            value: Value tensor, shape [batch, num_heads, seq_length, head_size]

        Returns:
            Attention output, shape [batch, num_heads, seq_length, head_size]
        """
        # Compute attention scores
        attn_weights = query @ key.transpose(-1, -2)

        # Scale attention weights
        attn_weights = attn_weights / math.sqrt(int(value.shape[-1]))

        # Apply causal mask
        seq_len = query.shape[-2]
        mask = causal_mask(seq_len, 0, dtype=query.dtype, device=query.device)
        attn_weights = attn_weights + mask

        # Softmax and weighted sum
        attn_weights = F.softmax(attn_weights)
        attn_output = attn_weights @ value

        return attn_output

    def __call__(self, hidden_states):
        """Apply multi-head attention.

        Args:
            hidden_states: Input tensor, shape [batch, seq_length, n_embd]

        Returns:
            Attention output, shape [batch, seq_length, n_embd]
        """
        # Project to Q, K, V
        qkv = self.c_attn(hidden_states)
        query, key, value = F.split(
            qkv, [self.split_size, self.split_size, self.split_size], axis=-1
        )

        # Split into multiple heads
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Apply attention
        attn_output = self._attn(query, key, value)

        # Merge heads back
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)

        # Output projection
        attn_output = self.c_proj(attn_output)

        return attn_output

Next: In Step 08, you’ll implement residual connections and layer normalization to enable training deep transformer networks.