Step 09: Transformer block

Learn to combine attention, MLP, layer normalization, and residual connections into a complete transformer block.

Building the transformer block

In this step, you’ll build the GPT2Block class, the fundamental repeating unit of GPT-2. Each block combines multi-head attention and a feed-forward network, with layer normalization and a residual connection around each sublayer.

The block processes input through two sequential sublayers. First, it applies layer norm, runs multi-head attention, and adds the result back to the input (a residual connection). Second, it applies another layer norm, runs the MLP, and adds that result back. This pattern, x = x + sublayer(layer_norm(x)), is called pre-normalization.

GPT-2 uses pre-norm because it stabilizes training in deep networks. Normalizing before each sublayer, rather than after it, keeps gradients flowing smoothly through the network’s 12 stacked blocks.
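
To make the two orderings concrete, here is a minimal framework-agnostic sketch in NumPy (illustrative only, not the MAX implementation; the sublayer is a stand-in for attention or the MLP):

import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (embedding) dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sublayer(x):
    # Stand-in for attention or the MLP; any shape-preserving function works here.
    return np.tanh(x)

x = np.random.randn(2, 4, 768)           # [batch, seq_length, n_embd]

pre_norm = x + sublayer(layer_norm(x))   # GPT-2: normalize first, then add the residual
post_norm = layer_norm(x + sublayer(x))  # original Transformer: add first, then normalize

print(pre_norm.shape, post_norm.shape)   # both stay (2, 4, 768)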

Understanding the components

The transformer block consists of four components, applied in this order:

First layer norm (ln_1): Normalizes the input before attention. Uses epsilon=1e-5 for numerical stability.

Multi-head attention (attn): The self-attention mechanism from Step 07. Lets each position attend to itself and all earlier positions.

Second layer norm (ln_2): Normalizes before the MLP. Same configuration as the first.

Feed-forward network (mlp): The position-wise MLP from Step 04. Expands to 3,072 dimensions internally (4× the embedding size), then projects back to 768.

The block maintains a constant 768-dimensional representation throughout. Input shape [batch, seq_length, 768] stays the same after each sublayer, which is essential for stacking 12 blocks together.
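
In GPT-2 small’s numbers, the dimension bookkeeping looks like this (a quick plain-Python sketch using the values described above):

n_embd = 768                 # width carried between sublayers
inner_dim = 4 * n_embd       # 3,072: used only inside the MLP
layer_norm_epsilon = 1e-5    # eps shared by ln_1 and ln_2

# Every sublayer maps [batch, seq_length, n_embd] -> [batch, seq_length, n_embd],
# so the same block can be repeated 12 times without any reshaping.
print(n_embd, inner_dim)     # 768 3072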

Understanding the flow

Each sublayer follows the pre-norm pattern:

  1. Save the input as residual
  2. Apply layer normalization to the input
  3. Process through the sublayer (attention or MLP)
  4. Add the original residual back to the output

This happens twice per block, once for attention and once for the MLP. The residual connections let gradients flow directly through the network, preventing vanishing gradients in deep models.
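
A toy numerical check of why the identity path matters: for y = x + f(x), the derivative is 1 + f'(x), so the gradient stays close to 1 even when the sublayer’s own gradient collapses. The snippet below is an illustrative NumPy finite-difference check, not part of the step:

import numpy as np

def f(x):
    # Stand-in sublayer whose derivative is essentially zero here (saturated tanh).
    return np.tanh(10.0 * x)

x, h = 3.0, 1e-6
grad_without_residual = (f(x + h) - f(x - h)) / (2 * h)                    # ~0: vanishes
grad_with_residual = ((x + h + f(x + h)) - (x - h + f(x - h))) / (2 * h)   # ~1: identity path survives

print(grad_without_residual, grad_with_residual)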

Component names (ln_1, attn, ln_2, mlp) match Hugging Face’s GPT-2 implementation. This matters for loading pretrained weights in later steps.
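
If you want to see those names on the actual checkpoint, here is an optional check using the Hugging Face transformers library (this assumes transformers and PyTorch are installed; it is not required for this step):

from transformers import GPT2Model

hf_model = GPT2Model.from_pretrained("gpt2")
for name, param in hf_model.named_parameters():
    if name.startswith("h.0."):          # parameters of the first transformer block
        print(name, tuple(param.shape))
# Prints names such as h.0.ln_1.weight, h.0.attn.c_attn.weight,
# h.0.ln_2.bias, and h.0.mlp.c_fc.weight.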

Implementing the block

You’ll create the GPT2Block class by composing the components from earlier steps. The block takes GPT2Config and creates four sublayers, then applies them in sequence with residual connections.

First, import the required modules. You’ll need Module from max.nn.module_v3, plus the previously implemented components: GPT2Config, GPT2MLP, GPT2MultiHeadAttention, and LayerNorm.

In the __init__ method, create the four sublayers:

  • ln_1: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  • attn: GPT2MultiHeadAttention(config)
  • ln_2: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  • mlp: GPT2MLP(4 * config.n_embd, config)

The MLP’s inner dimension is config.n_inner when it is set, and 4 * config.n_embd (3,072 dimensions) otherwise, following the standard 4× transformer ratio.

In the forward method, implement the two sublayer blocks:

Attention block:

  1. Save residual = hidden_states
  2. Normalize: hidden_states = self.ln_1(hidden_states)
  3. Apply attention: attn_output = self.attn(hidden_states)
  4. Add back: hidden_states = attn_output + residual

MLP block:

  1. Save residual = hidden_states
  2. Normalize: hidden_states = self.ln_2(hidden_states)
  3. Apply MLP: feed_forward_hidden_states = self.mlp(hidden_states)
  4. Add back: hidden_states = residual + feed_forward_hidden_states

Finally, return hidden_states.
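
If it helps, here is the same two-sublayer flow as a minimal NumPy sketch. The layer_norm, attn, and mlp functions are simple stand-ins for the real sublayers, so this only illustrates the wiring and shapes, not GPT-2’s actual math:

import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

# Stand-ins for the real sublayers: any shape-preserving functions will do.
attn = np.tanh
mlp = np.tanh

def block(hidden_states):
    # Attention sublayer with pre-norm and residual connection
    residual = hidden_states
    hidden_states = layer_norm(hidden_states)       # ln_1
    hidden_states = attn(hidden_states) + residual

    # MLP sublayer with pre-norm and residual connection
    residual = hidden_states
    hidden_states = layer_norm(hidden_states)       # ln_2
    hidden_states = residual + mlp(hidden_states)
    return hidden_states

x = np.random.randn(2, 10, 768)
print(block(block(x)).shape)  # (2, 10, 768): blocks stack cleanly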

Implementation (step_09.py):

"""
Step 09: Transformer Block

Combine multi-head attention, MLP, layer normalization, and residual
connections into a complete transformer block.

Tasks:
1. Import Module and all previous solution components
2. Create ln_1, attn, ln_2, and mlp layers
3. Implement forward pass with pre-norm residual pattern

Run: pixi run s09
"""

# TODO: Import required modules
# Hint: You'll need Module from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import GPT2MLP from solutions.solution_04
# Hint: Import GPT2MultiHeadAttention from solutions.solution_07
# Hint: Import LayerNorm from solutions.solution_08


class GPT2Block(Module):
    """Complete GPT-2 transformer block."""

    def __init__(self, config: GPT2Config):
        """Initialize transformer block.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        hidden_size = config.n_embd
        inner_dim = (
            config.n_inner
            if hasattr(config, "n_inner") and config.n_inner is not None
            else 4 * hidden_size
        )

        # TODO: Create first layer norm (before attention)
        # Hint: Use LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.ln_1 = None

        # TODO: Create multi-head attention
        # Hint: Use GPT2MultiHeadAttention(config)
        self.attn = None

        # TODO: Create second layer norm (before MLP)
        # Hint: Use LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.ln_2 = None

        # TODO: Create MLP
        # Hint: Use GPT2MLP(inner_dim, config)
        self.mlp = None

    def __call__(self, hidden_states):
        """Apply transformer block.

        Args:
            hidden_states: Input tensor, shape [batch, seq_length, n_embd]

        Returns:
            Output tensor, shape [batch, seq_length, n_embd]
        """
        # TODO: Attention block with residual connection
        # Hint: residual = hidden_states
        # Hint: hidden_states = self.ln_1(hidden_states)
        # Hint: attn_output = self.attn(hidden_states)
        # Hint: hidden_states = attn_output + residual
        pass

        # TODO: MLP block with residual connection
        # Hint: residual = hidden_states
        # Hint: hidden_states = self.ln_2(hidden_states)
        # Hint: feed_forward_hidden_states = self.mlp(hidden_states)
        # Hint: hidden_states = residual + feed_forward_hidden_states
        pass

        # TODO: Return the output
        return None

Validation

Run pixi run s09 to verify your implementation.

Solution:
"""
Solution for Step 09: Transformer Block

This module implements a complete GPT-2 transformer block, combining
multi-head attention, MLP, layer normalization, and residual connections.
"""

from max.nn.module_v3 import Module

from solutions.solution_01 import GPT2Config
from solutions.solution_04 import GPT2MLP
from solutions.solution_07 import GPT2MultiHeadAttention
from solutions.solution_08 import LayerNorm


class GPT2Block(Module):
    """Complete GPT-2 transformer block matching HuggingFace structure.

    Architecture (pre-norm):
    1. x = x + attention(layer_norm(x))
    2. x = x + mlp(layer_norm(x))
    """

    def __init__(self, config: GPT2Config):
        """Initialize transformer block.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        hidden_size = config.n_embd
        # Inner dimension for MLP (4x hidden size by default)
        inner_dim = (
            config.n_inner
            if hasattr(config, "n_inner") and config.n_inner is not None
            else 4 * hidden_size
        )

        # First layer norm (before attention)
        self.ln_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        # Multi-head attention
        self.attn = GPT2MultiHeadAttention(config)
        # Second layer norm (before MLP)
        self.ln_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        # Feed-forward MLP
        self.mlp = GPT2MLP(inner_dim, config)

    def __call__(self, hidden_states):
        """Apply transformer block.

        Args:
            hidden_states: Input tensor, shape [batch, seq_length, n_embd]

        Returns:
            Output tensor, shape [batch, seq_length, n_embd]
        """
        # Attention block with residual connection
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(hidden_states)
        hidden_states = attn_output + residual

        # MLP block with residual connection
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        return hidden_states

Next: In Step 10, you’ll stack 12 transformer blocks together to create the complete GPT-2 model architecture.