Step 09: Transformer block
Learn to combine attention, MLP, layer normalization, and residual connections into a complete transformer block.
Building the transformer block
In this step, you’ll build the GPT2Block class, the fundamental
repeating unit of GPT-2. Each block combines multi-head attention and a
feed-forward network, with layer normalization and residual connections around
each sublayer.
The block processes input through two sequential operations. First, it applies
layer norm, runs multi-head attention, then adds the result back to the input
(residual connection). Second, it applies another layer norm, runs the MLP, and
adds that result back. This pattern is x = x + sublayer(layer_norm(x)), called
pre-normalization.
GPT-2 uses pre-norm because it stabilizes training in deep networks. By normalizing before each sublayer instead of after, gradients flow more smoothly through the network’s 12 stacked blocks.
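To make the ordering concrete, here is a minimal sketch of the two patterns. The sublayer and layer_norm arguments are placeholders standing in for the real attention/MLP and LayerNorm modules:

def post_norm_step(x, sublayer, layer_norm):
    # Original Transformer ordering: normalize after the residual addition
    return layer_norm(x + sublayer(x))

def pre_norm_step(x, sublayer, layer_norm):
    # GPT-2 ordering: normalize first, then add the untouched input back
    return x + sublayer(layer_norm(x))

GPT-2 applies the pre-norm step twice per block: once with attention as the sublayer, once with the MLP.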
Understanding the components
The transformer block consists of four components, applied in this order:
First layer norm (ln_1): Normalizes the input before attention. Uses epsilon=1e-5 for numerical stability.
Multi-head attention (attn): The self-attention mechanism from Step 07. Lets each position attend to all previous positions.
Second layer norm (ln_2): Normalizes before the MLP. Same configuration as the first.
Feed-forward network (mlp): The position-wise MLP from Step 04. Expands to 3,072 dimensions internally (4× the embedding size), then projects back to 768.
The block maintains a constant 768-dimensional representation throughout. Input
shape [batch, seq_length, 768] stays the same after each sublayer, which is
essential for stacking 12 blocks together.
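As a quick check on those shapes, here is the bookkeeping for GPT-2 small; only the MLP’s internal activation ever leaves 768 dimensions:

n_embd = 768            # width of every sublayer's input and output
inner_dim = 4 * n_embd  # 3072, used only inside the MLP (expand, activate, project back)

# ln_1, attn, ln_2, and mlp each map [batch, seq_length, 768] -> [batch, seq_length, 768],
# so both residual additions are shape-compatible and the blocks can be stacked directly.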
Understanding the flow
Each sublayer follows the pre-norm pattern:
- Save the input as residual
- Apply layer normalization to the input
- Process through the sublayer (attention or MLP)
- Add the original residual back to the output
This happens twice per block, once for attention and once for the MLP. The residual connections let gradients flow directly through the network, preventing vanishing gradients in deep models.
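To see why the residuals help, take one residual step y = x + f(x) in isolation; the chain rule gives:

    dL/dx = dL/dy + dL/dy · (df/dx)

The first term is an identity path: the gradient reaching the output flows back to the input unchanged, however small df/dx becomes, and this shortcut repeats through every stacked block.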
Component names (ln_1, attn, ln_2, mlp) match Hugging Face’s GPT-2
implementation. This matters for loading pretrained weights in later steps.
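For example, the first block’s parameters in a Hugging Face GPT-2 checkpoint are stored under names like the following (listed for illustration; the h.0. prefix is HF’s index into its list of blocks):

h.0.ln_1.weight, h.0.ln_1.bias
h.0.attn.c_attn.weight, h.0.attn.c_attn.bias
h.0.attn.c_proj.weight, h.0.attn.c_proj.bias
h.0.ln_2.weight, h.0.ln_2.bias
h.0.mlp.c_fc.weight, h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight, h.0.mlp.c_proj.bias

Keeping ln_1, attn, ln_2, and mlp as the attribute names means these keys line up with your module hierarchy when the weights are loaded.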
Implementing the block
You’ll create the GPT2Block class by composing the components from earlier
steps. The block takes GPT2Config and creates four sublayers, then applies
them in sequence with residual connections.
First, import the required modules. You’ll need Module from MAX, plus the
previously implemented components: GPT2Config, GPT2MLP,
GPT2MultiHeadAttention, and LayerNorm.
In the __init__ method, create the four sublayers:
- ln_1: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
- attn: GPT2MultiHeadAttention(config)
- ln_2: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
- mlp: GPT2MLP(4 * config.n_embd, config)
The MLP uses 4 * config.n_embd (3,072 dimensions) as its inner dimension, following the standard transformer ratio.
In the forward method, implement the two sublayer blocks:
Attention block:
- Save residual = hidden_states
- Normalize: hidden_states = self.ln_1(hidden_states)
- Apply attention: attn_output = self.attn(hidden_states)
- Add back: hidden_states = attn_output + residual
MLP block:
- Save residual = hidden_states
- Normalize: hidden_states = self.ln_2(hidden_states)
- Apply MLP: feed_forward_hidden_states = self.mlp(hidden_states)
- Add back: hidden_states = residual + feed_forward_hidden_states
Finally, return hidden_states.
Implementation (step_09.py):
"""
Step 09: Transformer Block
Combine multi-head attention, MLP, layer normalization, and residual
connections into a complete transformer block.
Tasks:
1. Import Module and all previous solution components
2. Create ln_1, attn, ln_2, and mlp layers
3. Implement forward pass with pre-norm residual pattern
Run: pixi run s09
"""
# TODO: Import required modules
# Hint: You'll need Module from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import GPT2MLP from solutions.solution_04
# Hint: Import GPT2MultiHeadAttention from solutions.solution_07
# Hint: Import LayerNorm from solutions.solution_08
class GPT2Block(Module):
"""Complete GPT-2 transformer block."""
def __init__(self, config: GPT2Config):
"""Initialize transformer block.
Args:
config: GPT2Config containing model hyperparameters
"""
super().__init__()
hidden_size = config.n_embd
inner_dim = (
config.n_inner
if hasattr(config, "n_inner") and config.n_inner is not None
else 4 * hidden_size
)
# TODO: Create first layer norm (before attention)
# Hint: Use LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_1 = None
# TODO: Create multi-head attention
# Hint: Use GPT2MultiHeadAttention(config)
self.attn = None
# TODO: Create second layer norm (before MLP)
# Hint: Use LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_2 = None
# TODO: Create MLP
# Hint: Use GPT2MLP(inner_dim, config)
self.mlp = None
def __call__(self, hidden_states):
"""Apply transformer block.
Args:
hidden_states: Input tensor, shape [batch, seq_length, n_embd]
Returns:
Output tensor, shape [batch, seq_length, n_embd]
"""
# TODO: Attention block with residual connection
# Hint: residual = hidden_states
# Hint: hidden_states = self.ln_1(hidden_states)
# Hint: attn_output = self.attn(hidden_states)
# Hint: hidden_states = attn_output + residual
pass
# TODO: MLP block with residual connection
# Hint: residual = hidden_states
# Hint: hidden_states = self.ln_2(hidden_states)
# Hint: feed_forward_hidden_states = self.mlp(hidden_states)
# Hint: hidden_states = residual + feed_forward_hidden_states
pass
# TODO: Return the output
return None
Validation
Run pixi run s09 to verify your implementation.
Solution:
"""
Solution for Step 09: Transformer Block
This module implements a complete GPT-2 transformer block, combining
multi-head attention, MLP, layer normalization, and residual connections.
"""
from max.nn.module_v3 import Module
from solutions.solution_01 import GPT2Config
from solutions.solution_04 import GPT2MLP
from solutions.solution_07 import GPT2MultiHeadAttention
from solutions.solution_08 import LayerNorm
class GPT2Block(Module):
"""Complete GPT-2 transformer block matching HuggingFace structure.
Architecture (pre-norm):
1. x = x + attention(layer_norm(x))
2. x = x + mlp(layer_norm(x))
"""
def __init__(self, config: GPT2Config):
"""Initialize transformer block.
Args:
config: GPT2Config containing model hyperparameters
"""
super().__init__()
hidden_size = config.n_embd
# Inner dimension for MLP (4x hidden size by default)
inner_dim = (
config.n_inner
if hasattr(config, "n_inner") and config.n_inner is not None
else 4 * hidden_size
)
# First layer norm (before attention)
self.ln_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
# Multi-head attention
self.attn = GPT2MultiHeadAttention(config)
# Second layer norm (before MLP)
self.ln_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
# Feed-forward MLP
self.mlp = GPT2MLP(inner_dim, config)
def __call__(self, hidden_states):
"""Apply transformer block.
Args:
hidden_states: Input tensor, shape [batch, seq_length, n_embd]
Returns:
Output tensor, shape [batch, seq_length, n_embd]
"""
# Attention block with residual connection
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(hidden_states)
hidden_states = attn_output + residual
# MLP block with residual connection
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = residual + feed_forward_hidden_states
return hidden_states
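With the block defined, a minimal usage sketch looks like this (constructing the input tensor depends on the embedding layers from earlier steps, so it is left as a comment):

config = GPT2Config()
block = GPT2Block(config)

# hidden_states: a [batch, seq_length, n_embd] tensor produced by the embedding layers
# output = block(hidden_states)   # same shape: [batch, seq_length, n_embd]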
Next: In Step 10, you’ll stack 12 transformer blocks together to create the complete GPT-2 model architecture.