Stacking transformer blocks
Learn to stack 12 transformer blocks with embeddings and final normalization to create the complete GPT-2 model.
In this step, you’ll create the body of the GPT-2 model (the MaxGPT2Model module)
as a sequence of transformer blocks (GPT2Block) followed by a final LayerNorm. Because
the model body receives raw token IDs during inference, you’ll also need to first
convert those token IDs into token embeddings that the transformer blocks and the
rest of the network can process.
The model processes input in four stages: convert token IDs to embeddings, add position information, pass through 12 transformer blocks sequentially, and normalize the final output. Each transformer block refines the representation, building up from surface patterns in early layers to semantic understanding in later layers.
GPT-2 uses 12 layers because this depth allows the model to learn complex patterns while remaining trainable. Fewer layers would limit the model’s capacity. More layers would increase training difficulty without proportional gains in quality for a 117M parameter model.
Understanding the components
The complete model has four main components:
Token embeddings (wte): Maps each token ID to a 768-dimensional vector
using a lookup table with 50,257 entries (one per vocabulary token).
Position embeddings (wpe): Maps each position (0 to 1,023) to a
768-dimensional vector. These are added to token embeddings so the model knows
token order.
Transformer blocks (h): 12 identical blocks stacked using MAX’s
Sequential
module. Sequential applies blocks in order, passing each block’s output to the
next.
Final layer norm (ln_f): Normalizes the output after all blocks. This
stabilizes the representation before the language model head (added in the next step)
projects to vocabulary logits.
Understanding the forward pass
The forward method processes token IDs through the model:
First, create position indices using
Tensor.arange.
Generate positions [0, 1, 2, …, seq_length-1] matching the input’s dtype and
device. This ensures compatibility when adding to embeddings.
Next, look up embeddings. Get token embeddings with self.wte(input_ids)
(shape [batch, seq_length, 768]) and position embeddings with
self.wpe(position_indices) (shape [seq_length, 768]). Add them together
element-wise; the position embeddings broadcast across the batch dimension.
Then, pass through the transformer blocks with self.h(x). The
Sequential
module applies all 12 transformer blocks in order, each refining the representation.
Finally, normalize the output with self.ln_f(x) and return the result. The
output shape matches the input: [batch, seq_length, 768].
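To make those shapes concrete before touching MAX APIs, here's a framework-independent NumPy sketch of the same four stages. The sizes are scaled down, and toy_block and the random tables are illustrative stand-ins rather than MAX code.
import numpy as np

# Scaled-down sizes for a quick shape check
# (GPT-2 base uses vocab_size=50257, n_positions=1024, n_embd=768).
batch, seq_length = 2, 5
vocab_size, n_positions, n_embd = 100, 16, 8

# Stand-ins for the learned embedding tables (random here, trained in practice).
wte = np.random.randn(vocab_size, n_embd)
wpe = np.random.randn(n_positions, n_embd)

input_ids = np.random.randint(0, vocab_size, size=(batch, seq_length))

# Stages 1-2: token lookup plus position lookup, combined by broadcasting.
tok_embeds = wte[input_ids]              # [batch, seq_length, n_embd]
pos_embeds = wpe[np.arange(seq_length)]  # [seq_length, n_embd], broadcasts over batch
x = tok_embeds + pos_embeds              # [batch, seq_length, n_embd]

# Stage 3: each block maps [batch, seq_length, n_embd] to the same shape.
def toy_block(h):
    return h  # identity stand-in for a GPT2Block

for _ in range(12):
    x = toy_block(x)

# Stage 4: final layer norm over the embedding axis.
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
x = (x - mean) / np.sqrt(var + 1e-5)

print(x.shape)  # (2, 5, 8) -- output shape matches the embedded input
Note that the only shape change happens at the embedding lookup; every later stage preserves [batch, seq_length, n_embd].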
You’ll use the following MAX operations to complete this task:
- Module composition: Sequential(*modules) chains transformer blocks in sequence
- Embeddings: Embedding(num_embeddings, dim) creates token and position embeddings
- Position generation: Tensor.arange(seq_length, dtype, device) creates position indices
Implementing the model
You’ll create the MaxGPT2Model class by composing embedding layers, transformer
blocks, and layer normalization. The class builds on all the components from
previous steps.
First, import the required modules. You’ll need Tensor for position indices,
Embedding, Module, and Sequential from MAX’s neural network module, plus
the previously implemented GPT2Config, LayerNorm, and GPT2Block.
In the __init__ method, create the four components:
- Token embeddings: Embedding(config.vocab_size, dim=config.n_embd), stored as self.wte
- Position embeddings: Embedding(config.n_positions, dim=config.n_embd), stored as self.wpe
- Transformer blocks: Sequential(*(GPT2Block(config) for _ in range(config.n_layer))), stored as self.h
- Final layer norm: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon), stored as self.ln_f
The Sequential module takes a generator expression that creates 12 identical
GPT2Block instances. The * unpacks them as arguments to Sequential.
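If the * unpacking idiom is unfamiliar, here's a toy illustration in plain Python; sequential and Double below are simplified stand-ins, not MAX's classes.
# Toy illustration of the unpacking pattern (plain Python, not MAX).
class Double:
    def __call__(self, x):
        return x * 2


def sequential(*layers):  # stand-in for MAX's Sequential
    def run(x):
        for layer in layers:  # apply each layer in order
            x = layer(x)
        return x
    return run


# The * unpacks the generator into three separate positional arguments,
# just like Sequential(*(GPT2Block(config) for _ in range(config.n_layer))).
pipeline = sequential(*(Double() for _ in range(3)))
print(pipeline(1))  # 8
MAX's Sequential plays the same role: it receives the 12 blocks as separate positional arguments and applies them in order when you call self.h(x).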
In the forward method, implement the four-stage processing:
- Get the sequence length from input_ids.shape
- Create position indices: Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
- Look up embeddings and add them: x = self.wte(input_ids) + self.wpe(position_indices)
- Apply transformer blocks: x = self.h(x)
- Apply final normalization: x = self.ln_f(x)
- Return x
The position indices must match the input’s dtype and device to ensure the tensors are compatible for addition.
Implementation (step_07.py):
# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Step 07: Stacking Transformer Blocks

Stack multiple transformer blocks with embeddings to create
the complete GPT-2 model architecture.

Tasks:
1. Import Tensor, Embedding, Module, Sequential, and previous components
2. Create token and position embeddings
3. Stack n_layer transformer blocks using Sequential
4. Create final layer normalization
5. Implement forward pass: embeddings -> blocks -> layer norm

Run: pixi run s07
"""

# TODO: Import required modules
# Hint: You'll need Tensor from max.tensor
# Hint: You'll need Embedding, Module, Sequential from max.nn
from max.tensor import Tensor

from step_01 import GPT2Config


class MaxGPT2Model(Module):
    """Complete GPT-2 transformer model."""

    def __init__(self, config: GPT2Config) -> None:
        """Initialize GPT-2 model.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        # TODO: Create token embeddings
        # Hint: Use Embedding(config.vocab_size, dim=config.n_embd)
        self.wte = None

        # TODO: Create position embeddings
        # Hint: Use Embedding(config.n_positions, dim=config.n_embd)
        self.wpe = None

        # TODO: Stack transformer blocks
        # Hint: Use Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))
        # This creates config.n_layer blocks (12 for GPT-2 base)
        self.h = None

        # TODO: Create final layer normalization
        # Hint: Use LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_f = None

    def forward(self, input_ids: Tensor) -> Tensor:
        """Forward pass through the transformer.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Hidden states, shape [batch, seq_length, n_embd]
        """
        # TODO: Get batch size and sequence length
        # Hint: batch_size, seq_length = input_ids.shape
        pass

        # TODO: Get token embeddings
        # Hint: tok_embeds = self.wte(input_ids)
        pass

        # TODO: Get position embeddings
        # Hint: Create position indices with Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        # Hint: pos_embeds = self.wpe(position_indices)
        pass

        # TODO: Combine embeddings
        # Hint: x = tok_embeds + pos_embeds
        pass

        # TODO: Apply transformer blocks
        # Hint: x = self.h(x)
        pass

        # TODO: Apply final layer norm
        # Hint: x = self.ln_f(x)
        pass

        # TODO: Return the output
        return None
Validation
Run pixi run s07 to verify your implementation.
Solution:
# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Solution for Step 07: Stacking Transformer Blocks

This module stacks multiple transformer blocks and adds embeddings
to create the complete GPT-2 transformer architecture.
"""

from max.nn import Embedding, Module, Sequential
from max.tensor import Tensor

from step_01 import GPT2Config
from step_05 import LayerNorm
from step_06 import GPT2Block


class MaxGPT2Model(Module):
    """Complete GPT-2 transformer model matching HuggingFace structure.

    Architecture:
        1. Token embeddings + position embeddings
        2. Stack of n_layer transformer blocks
        3. Final layer normalization
    """

    def __init__(self, config: GPT2Config) -> None:
        """Initialize GPT-2 model.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        # Token embeddings (vocabulary -> embeddings)
        self.wte = Embedding(config.vocab_size, dim=config.n_embd)

        # Position embeddings (positions -> embeddings)
        self.wpe = Embedding(config.n_positions, dim=config.n_embd)

        # Stack of transformer blocks
        self.h = Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))

        # Final layer normalization
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids: Tensor) -> Tensor:
        """Forward pass through the transformer.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Hidden states, shape [batch, seq_length, n_embd]
        """
        _, seq_length = input_ids.shape

        # Get token embeddings
        tok_embeds = self.wte(input_ids)

        # Get position embeddings
        pos_embeds = self.wpe(
            Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        )

        # Combine embeddings
        x = tok_embeds + pos_embeds

        # Apply transformer blocks
        x = self.h(x)

        # Final layer norm
        x = self.ln_f(x)

        return x
Next: In Step 08, you’ll add the language modeling head that projects hidden states to vocabulary logits for text generation.