Stacking transformer blocks

Learn to stack 12 transformer blocks with embeddings and final normalization to create the complete GPT-2 model.

In this step, you’ll create the body of the GPT-2 model (the MaxGPT2Model module) as a sequence of transformer blocks (GPT2Block) followed by a final LayerNorm. Because the model body receives raw token IDs during inference, you’ll also need to convert those IDs into token embeddings first, so the transformer blocks and the rest of the network can process them.

The model processes input in four stages: convert token IDs to embeddings, add position information, pass through 12 transformer blocks sequentially, and normalize the final output. Each transformer block refines the representation, building up from surface patterns in early layers to semantic understanding in later layers.

GPT-2 uses 12 layers because this depth allows the model to learn complex patterns while remaining trainable. Fewer layers would limit the model’s capacity. More layers would increase training difficulty without proportional gains in quality for a 117M parameter model.
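To make those size claims concrete, here is a rough back-of-the-envelope parameter count in plain Python. It is only a sketch: it assumes the standard GPT-2 base hyperparameters (50,257-token vocabulary, 768-dimensional embeddings, 1,024 positions, 12 blocks, 3,072-wide MLP) and the block layout built in the previous steps, and published totals for this model range from roughly 117M to 124M depending on the counting convention.

# Back-of-the-envelope parameter count for GPT-2 base (illustrative only).
vocab_size, n_embd, n_positions, n_layer, n_inner = 50_257, 768, 1_024, 12, 3_072

wte = vocab_size * n_embd                  # token embedding table
wpe = n_positions * n_embd                 # position embedding table

attn = n_embd * 3 * n_embd + 3 * n_embd    # combined Q/K/V projection + bias
attn += n_embd * n_embd + n_embd           # attention output projection + bias
mlp = n_embd * n_inner + n_inner           # MLP up-projection + bias
mlp += n_inner * n_embd + n_embd           # MLP down-projection + bias
norms = 2 * 2 * n_embd                     # two LayerNorms (scale and shift each)
block = attn + mlp + norms                 # roughly 7.1M parameters per block

ln_f = 2 * n_embd                          # final LayerNorm
total = wte + wpe + n_layer * block + ln_f
print(f"{total:,}")                        # roughly 124 million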

Understanding the components

The complete model has four main components:

Token embeddings (wte): Maps each token ID to a 768-dimensional vector using a lookup table with 50,257 entries (one per vocabulary token).

Position embeddings (wpe): Maps each position (0 to 1,023) to a 768-dimensional vector. These are added to token embeddings so the model knows token order.

Transformer blocks (h): 12 identical blocks stacked using MAX’s Sequential module. Sequential applies blocks in order, passing each block’s output to the next.

Final layer norm (ln_f): Normalizes the output after all blocks. This stabilizes the representation before the language model head (added in Step 8) projects to vocabulary logits.
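If “lookup table” is unfamiliar: an embedding layer is just a matrix whose rows are indexed by token ID (or by position). The following NumPy toy, with made-up sizes, shows the idea; it is not MAX code.

import numpy as np

# Toy table: 10-token vocabulary, 4-dimensional embeddings
# (the real wte is 50,257 x 768).
wte = np.random.randn(10, 4)

token_ids = np.array([3, 7, 7, 1])   # a 4-token sequence
tok_embeds = wte[token_ids]          # lookup = row indexing, shape (4, 4)
print(tok_embeds.shape)              # (4, 4)

Position embeddings work the same way, except the row index is the position (0 to 1,023) rather than the token ID.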

Understanding the forward pass

The forward method processes token IDs through the model:

First, create position indices using Tensor.arange. Generate positions [0, 1, 2, …, seq_length-1] with the input’s dtype and device, so the indices live alongside the token IDs and the later embedding lookup and addition need no conversion.

Next, look up embeddings. Get token embeddings with self.wte(input_ids), shape [batch, seq_length, 768], and position embeddings with self.wpe(position_indices), shape [seq_length, 768]. Add them together element-wise; the position embeddings broadcast across the batch dimension.

Then, pass through the transformer blocks with self.h(x). The Sequential module applies all 12 transformer blocks in order, each refining the representation.

Finally, normalize the output with self.ln_f(x) and return the result. The output shape matches the input: [batch, seq_length, 768].
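The whole forward pass is easiest to see as a shape diagram. This NumPy sketch mocks the stages with random tensors purely to show how the shapes flow; the real computation uses the MAX modules built in this and earlier steps.

import numpy as np

batch, seq_length, n_embd = 2, 5, 768

tok_embeds = np.random.randn(batch, seq_length, n_embd)  # wte(input_ids): [2, 5, 768]
pos_embeds = np.random.randn(seq_length, n_embd)         # wpe(arange(5)): [5, 768]

x = tok_embeds + pos_embeds   # broadcasts over the batch dimension -> [2, 5, 768]

# Each of the 12 blocks, and the final LayerNorm, map
# [batch, seq_length, 768] -> [batch, seq_length, 768].
print(x.shape)                # (2, 5, 768)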

MAX operations

You’ll use the following MAX operations to complete this task:

Module composition: Module and Sequential from max.nn

Embeddings: Embedding from max.nn

Position generation: Tensor.arange from max.tensor

Implementing the model

You’ll create the MaxGPT2Model class by composing embedding layers, transformer blocks, and layer normalization. The class builds on all the components from previous steps.

First, import the required modules. You’ll need Tensor for position indices, Embedding, Module, and Sequential from MAX’s neural network module, plus the previously implemented GPT2Config, LayerNorm, and GPT2Block.

In the __init__ method, create the four components:

  • Token embeddings: Embedding(config.vocab_size, dim=config.n_embd) stored as self.wte
  • Position embeddings: Embedding(config.n_positions, dim=config.n_embd) stored as self.wpe
  • Transformer blocks: Sequential(*(GPT2Block(config) for _ in range(config.n_layer))) stored as self.h
  • Final layer norm: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) stored as self.ln_f

The generator expression creates config.n_layer identical GPT2Block instances (12 for GPT-2 base), and the * unpacks them so each block is passed to Sequential as a separate argument.
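Conceptually, Sequential is nothing more than a loop over its blocks. This plain-Python stand-in (not the MAX implementation) captures the behavior described above:

# A stand-in for Sequential: apply each block in order, feeding the
# previous block's output to the next.
def sequential(*blocks):
    def run(x):
        for block in blocks:
            x = block(x)
        return x
    return run

# Toy "blocks" that each add 1; in the real model they are GPT2Block instances.
stack = sequential(*(lambda x: x + 1 for _ in range(12)))
print(stack(0))  # 12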

In the forward method, implement the forward pass:

  1. Get the sequence length from input_ids.shape
  2. Create position indices: Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
  3. Look up embeddings and add them: x = self.wte(input_ids) + self.wpe(position_indices)
  4. Apply transformer blocks: x = self.h(x)
  5. Apply final normalization: x = self.ln_f(x)
  6. Return x

The position indices must match the input’s dtype and device to ensure the tensors are compatible for addition.

Implementation (step_07.py):

# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Step 07: Stacking Transformer Blocks

Stack multiple transformer blocks with embeddings to create
the complete GPT-2 model architecture.

Tasks:
1. Import Tensor, Embedding, Module, Sequential, and previous components
2. Create token and position embeddings
3. Stack n_layer transformer blocks using Sequential
4. Create final layer normalization
5. Implement forward pass: embeddings -> blocks -> layer norm

Run: pixi run s07
"""

# TODO: Import required modules
# Hint: You'll need Tensor from max.tensor
# Hint: You'll need Embedding, Module, Sequential from max.nn
# Hint: You'll also need LayerNorm from step_05 and GPT2Block from step_06

from max.tensor import Tensor
from step_01 import GPT2Config


class MaxGPT2Model(Module):
    """Complete GPT-2 transformer model."""

    def __init__(self, config: GPT2Config) -> None:
        """Initialize GPT-2 model.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        # TODO: Create token embeddings
        # Hint: Use Embedding(config.vocab_size, dim=config.n_embd)
        self.wte = None

        # TODO: Create position embeddings
        # Hint: Use Embedding(config.n_positions, dim=config.n_embd)
        self.wpe = None

        # TODO: Stack transformer blocks
        # Hint: Use Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))
        # This creates config.n_layer blocks (12 for GPT-2 base)
        self.h = None

        # TODO: Create final layer normalization
        # Hint: Use LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_f = None

    def forward(self, input_ids: Tensor) -> Tensor:
        """Forward pass through the transformer.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Hidden states, shape [batch, seq_length, n_embd]
        """
        # TODO: Get batch size and sequence length
        # Hint: batch_size, seq_length = input_ids.shape
        pass

        # TODO: Get token embeddings
        # Hint: tok_embeds = self.wte(input_ids)
        pass

        # TODO: Get position embeddings
        # Hint: Create position indices with Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        # Hint: pos_embeds = self.wpe(position_indices)
        pass

        # TODO: Combine embeddings
        # Hint: x = tok_embeds + pos_embeds
        pass

        # TODO: Apply transformer blocks
        # Hint: x = self.h(x)
        pass

        # TODO: Apply final layer norm
        # Hint: x = self.ln_f(x)
        pass

        # TODO: Return the output
        return None

Validation

Run pixi run s07 to verify your implementation.

Solution (step_07.py):
# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Solution for Step 07: Stacking Transformer Blocks

This module stacks multiple transformer blocks and adds embeddings
to create the complete GPT-2 transformer architecture.
"""

from max.nn import Embedding, Module, Sequential
from max.tensor import Tensor
from step_01 import GPT2Config
from step_05 import LayerNorm
from step_06 import GPT2Block


class MaxGPT2Model(Module):
    """Complete GPT-2 transformer model matching HuggingFace structure.

    Architecture:
    1. Token embeddings + position embeddings
    2. Stack of n_layer transformer blocks
    3. Final layer normalization
    """

    def __init__(self, config: GPT2Config) -> None:
        """Initialize GPT-2 model.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        # Token embeddings (vocabulary -> embeddings)
        self.wte = Embedding(config.vocab_size, dim=config.n_embd)
        # Position embeddings (positions -> embeddings)
        self.wpe = Embedding(config.n_positions, dim=config.n_embd)

        # Stack of transformer blocks
        self.h = Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))

        # Final layer normalization
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids: Tensor) -> Tensor:
        """Forward pass through the transformer.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Hidden states, shape [batch, seq_length, n_embd]
        """
        _, seq_length = input_ids.shape

        # Get token embeddings
        tok_embeds = self.wte(input_ids)

        # Get position embeddings
        pos_embeds = self.wpe(
            Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        )

        # Combine embeddings
        x = tok_embeds + pos_embeds

        # Apply transformer blocks
        x = self.h(x)

        # Final layer norm
        x = self.ln_f(x)

        return x

Next: In Step 08, you’ll add the language modeling head that projects hidden states to vocabulary logits for text generation.