Step 10: Stacking transformer blocks

Learn to stack 12 transformer blocks with embeddings and final normalization to create the complete GPT-2 model.

Building the complete model

In this step, you’ll create the GPT2Model class - the complete transformer that takes token IDs as input and outputs contextualized representations. This class combines embeddings, 12 stacked transformer blocks, and final layer normalization.

The model processes input in four stages: convert token IDs to embeddings, add position information, pass through 12 transformer blocks sequentially, and normalize the final output. Each transformer block refines the representation, building up from surface patterns in early layers to semantic understanding in later layers.
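The following toy NumPy sketch illustrates that data flow at tiny scale. It is not the MAX API, and the "blocks" are simple stand-in functions; it only shows how the four stages chain together and what the tensor shapes look like.

import numpy as np

vocab_size, n_positions, n_embd = 10, 8, 4             # tiny sizes for illustration
wte = np.random.randn(vocab_size, n_embd)              # token embedding table
wpe = np.random.randn(n_positions, n_embd)             # position embedding table
blocks = [lambda x: x + np.tanh(x) for _ in range(3)]  # stand-ins for transformer blocks

token_ids = np.array([[3, 1, 4, 1, 5]])                # [batch=1, seq=5]
x = wte[token_ids]                                     # stage 1: look up token vectors -> [1, 5, 4]
x = x + wpe[np.arange(token_ids.shape[1])]             # stage 2: add position vectors (broadcasts)
for block in blocks:                                   # stage 3: apply the blocks in order
    x = block(x)
x = (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-5)  # stage 4: final normalization
print(x.shape)                                         # (1, 5, 4): same [batch, seq, dim] throughout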

GPT-2 uses 12 layers because this depth allows the model to learn complex patterns while remaining trainable. Fewer layers would limit the model’s capacity. More layers would increase training difficulty without proportional gains in quality for a 117M parameter model.

Understanding the components

The complete model has four main components:

Token embeddings (wte): Maps each token ID to a 768-dimensional vector using a lookup table with 50,257 entries (one per vocabulary token).

Position embeddings (wpe): Maps each position (0 to 1,023) to a 768-dimensional vector. These are added to token embeddings so the model knows token order.

Transformer blocks (h): 12 structurally identical blocks (each with its own weights), stacked using MAX’s Sequential module. Sequential applies the blocks in order, passing each block’s output to the next.

Final layer norm (ln_f): Normalizes the output after all blocks. This stabilizes the representation before the language model head (added in Step 11) projects to vocabulary logits.
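For orientation, here is a rough inventory of the learned parameters these four components hold at GPT-2 base size. The keys are the attribute names you will create below; the LayerNorm entry describes its parameters informally rather than MAX's exact attribute names.

component_parameters = {
    "wte": (50257, 768),   # token lookup table: one 768-dim row per vocabulary entry
    "wpe": (1024, 768),    # position lookup table: one row per position 0 to 1,023
    "h": "12 x GPT2Block (attention + MLP, built in Step 9)",
    "ln_f": "a (768,) scale and a (768,) shift",  # final LayerNorm parameters
}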

Understanding the forward pass

The forward method processes token IDs through the model:

First, create position indices using Tensor.arange. Generate positions [0, 1, 2, …, seq_length-1] matching the input’s dtype and device. This ensures compatibility when adding to embeddings.

Next, look up embeddings. Get token embeddings with self.wte(input_ids), shape [batch, seq_length, 768], and position embeddings with self.wpe(position_indices), shape [seq_length, 768]. Add them element-wise; the position embeddings broadcast across the batch dimension.

Then, pass through the transformer blocks with self.h(x). Sequential applies all 12 blocks in order, each refining the representation.

Finally, normalize the output with self.ln_f(x) and return the result. The output shape matches the input: [batch, seq_length, 768].
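A quick NumPy check shows why the addition works even though the two embedding tensors have different ranks: the position embeddings broadcast across the batch dimension (toy sizes here, not the MAX API).

import numpy as np

tok_embeds = np.zeros((2, 5, 8))   # from wte(input_ids):        [batch=2, seq=5, dim=8]
pos_embeds = np.ones((5, 8))       # from wpe(position_indices): [seq=5, dim=8]
x = tok_embeds + pos_embeds        # broadcasts across the batch dimension
print(x.shape)                     # (2, 5, 8)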

MAX operations

You’ll use the following MAX operations to complete this task:

Module composition: Module and Sequential from max.nn.module_v3

Embeddings: Embedding from max.nn.module_v3

Position generation: Tensor.arange from max.experimental.tensor
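The corresponding imports, exactly as they appear in the solution at the end of this step:

from max.experimental.tensor import Tensor                   # Tensor.arange for position indices
from max.nn.module_v3 import Embedding, Module, Sequential   # composition and embedding layers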

Implementing the model

You’ll create the GPT2Model class by composing embedding layers, transformer blocks, and layer normalization. The class builds on all the components from previous steps.

First, import the required modules. You’ll need Tensor (from max.experimental.tensor) for position indices; Embedding, Module, and Sequential from max.nn.module_v3; plus the previously implemented GPT2Config, LayerNorm, and GPT2Block.

In the __init__ method, create the four components:

  • Token embeddings: Embedding(config.vocab_size, dim=config.n_embd) stored as self.wte
  • Position embeddings: Embedding(config.n_positions, dim=config.n_embd) stored as self.wpe
  • Transformer blocks: Sequential(*(GPT2Block(config) for _ in range(config.n_layer))) stored as self.h
  • Final layer norm: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) stored as self.ln_f

The Sequential module receives a generator expression that creates config.n_layer (12) GPT2Block instances, each with its own weights. The * unpacks them as positional arguments to Sequential.
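If the * unpacking is unfamiliar, here is the same pattern in plain Python (no MAX involved): the generator expression produces items lazily, and * spreads them out as positional arguments.

def count_layers(*layers):                    # accepts any number of positional arguments
    return len(layers)

blocks = (f"block_{i}" for i in range(12))    # generator expression: nothing is built yet
print(count_layers(*blocks))                  # 12 -- each generated item becomes one argument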

In the forward method, implement the four-stage processing:

  1. Get the sequence length from input_ids.shape
  2. Create position indices: Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
  3. Look up embeddings and add them: x = self.wte(input_ids) + self.wpe(position_indices)
  4. Apply transformer blocks: x = self.h(x)
  5. Apply final normalization: x = self.ln_f(x)
  6. Return x

The position indices must match the input’s dtype and device to ensure the tensors are compatible for addition.

Implementation (step_10.py):

"""
Step 10: Stacking Transformer Blocks

Stack multiple transformer blocks with embeddings to create
the complete GPT-2 model architecture.

Tasks:
1. Import Tensor, Embedding, Module, Sequential, and previous components
2. Create token and position embeddings
3. Stack n_layer transformer blocks using Sequential
4. Create final layer normalization
5. Implement forward pass: embeddings -> blocks -> layer norm

Run: pixi run s10
"""

# TODO: Import required modules
# Hint: You'll need Tensor from max.experimental.tensor
# Hint: You'll need Embedding, Module, Sequential from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import LayerNorm from solutions.solution_08
# Hint: Import GPT2Block from solutions.solution_09


class GPT2Model(Module):
    """Complete GPT-2 transformer model."""

    def __init__(self, config: GPT2Config):
        """Initialize GPT-2 model.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        # TODO: Create token embeddings
        # Hint: Use Embedding(config.vocab_size, dim=config.n_embd)
        self.wte = None

        # TODO: Create position embeddings
        # Hint: Use Embedding(config.n_positions, dim=config.n_embd)
        self.wpe = None

        # TODO: Stack transformer blocks
        # Hint: Use Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))
        # This creates config.n_layer blocks (12 for GPT-2 base)
        self.h = None

        # TODO: Create final layer normalization
        # Hint: Use LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_f = None

    def __call__(self, input_ids):
        """Forward pass through the transformer.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Hidden states, shape [batch, seq_length, n_embd]
        """
        # TODO: Get batch size and sequence length
        # Hint: batch_size, seq_length = input_ids.shape
        pass

        # TODO: Get token embeddings
        # Hint: tok_embeds = self.wte(input_ids)
        pass

        # TODO: Get position embeddings
        # Hint: Create position indices with Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        # Hint: pos_embeds = self.wpe(position_indices)
        pass

        # TODO: Combine embeddings
        # Hint: x = tok_embeds + pos_embeds
        pass

        # TODO: Apply transformer blocks
        # Hint: x = self.h(x)
        pass

        # TODO: Apply final layer norm
        # Hint: x = self.ln_f(x)
        pass

        # TODO: Return the output
        return None

Validation

Run pixi run s10 to verify your implementation.

Solution:
"""
Solution for Step 10: Stacking Transformer Blocks

This module stacks multiple transformer blocks and adds embeddings
to create the complete GPT-2 transformer architecture.
"""

from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding, Module, Sequential

from solutions.solution_01 import GPT2Config
from solutions.solution_08 import LayerNorm
from solutions.solution_09 import GPT2Block


class GPT2Model(Module):
    """Complete GPT-2 transformer model matching HuggingFace structure.

    Architecture:
    1. Token embeddings + position embeddings
    2. Stack of n_layer transformer blocks
    3. Final layer normalization
    """

    def __init__(self, config: GPT2Config):
        """Initialize GPT-2 model.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        # Token embeddings (vocabulary -> embeddings)
        self.wte = Embedding(config.vocab_size, dim=config.n_embd)
        # Position embeddings (positions -> embeddings)
        self.wpe = Embedding(config.n_positions, dim=config.n_embd)

        # Stack of transformer blocks
        self.h = Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))

        # Final layer normalization
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def __call__(self, input_ids):
        """Forward pass through the transformer.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Hidden states, shape [batch, seq_length, n_embd]
        """
        batch_size, seq_length = input_ids.shape

        # Get token embeddings
        tok_embeds = self.wte(input_ids)

        # Get position embeddings
        pos_embeds = self.wpe(
            Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        )

        # Combine embeddings
        x = tok_embeds + pos_embeds

        # Apply transformer blocks
        x = self.h(x)

        # Final layer norm
        x = self.ln_f(x)

        return x
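Once the class is complete, using it looks roughly like the sketch below. This is a hedged example: whether GPT2Config() provides defaults, and how a token-ID Tensor is constructed, depend on earlier steps and are assumptions here.

config = GPT2Config()         # assumed defaults: 12 layers, 768 dims, 1,024 positions, 50,257-token vocab
model = GPT2Model(config)

# input_ids = ...             # a [batch, seq_length] integer Tensor of token IDs (see earlier steps)
# hidden = model(input_ids)   # -> [batch, seq_length, 768] hidden states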

Next: In Step 11, you’ll add the language modeling head that projects hidden states to vocabulary logits for text generation.