Build an LLM from scratch in MAX

Experimental APIs: The APIs in the experimental package are subject to change. Share feedback on the MAX LLMs forum.

Transformer models power today’s most impactful AI applications, from language models like ChatGPT to code generation tools like GitHub Copilot. Maybe you’ve been asked to adapt one of these models for your team, or you want to understand what’s actually happening when you call an inference API. Either way, building a transformer from scratch is one of the best ways to truly understand how they work.

This guide walks you through implementing GPT-2 using Modular’s MAX framework experimental API. You’ll build each component yourself: embeddings, attention mechanisms, and feed-forward layers. You’ll see how they fit together into a complete language model by completing the sequential coding challenges in the tutorial GitHub repository.

This API is unstable: This tutorial is built on the MAX Experimental API, which we expect to change over time and expand to include new features and functionality. As it evolves, we plan to update the tutorial accordingly. When this API comes out of experimental development, the tutorial content will also enter a more stable state. While in development, this tutorial will be pinned to a major release version.

Why GPT-2?

GPT-2 is the architectural foundation for modern language models. LLaMA, Mistral, and GPT-4 all build on the same core components you’ll implement here:

  • multi-head attention
  • feed-forward layers
  • layer normalization
  • residual connections

Modern variants add refinements like grouped-query attention or mixture of experts, but the fundamentals remain the same. GPT-2 is complex enough to teach real transformer architecture but simple enough to implement completely and understand deeply. When you grasp how its pieces fit together, you understand how to build any transformer-based model.

Learning by building: This tutorial follows a format popularized by Andrej Karpathy’s educational work and Sebastian Raschka’s hands-on approach. Rather than abstract theory, you’ll implement each component yourself, building intuition through practice.

Why MAX?

Traditional ML development often feels like stitching together tools that weren’t designed to work together. Maybe you write your model in PyTorch, optimize in CUDA, convert to ONNX for deployment, then use separate serving tools. Each handoff introduces complexity.

MAX Framework takes a different approach: everything happens in one unified system. You write code to define your model, load weights, and run inference, all in MAX’s Python API. The MAX Platform handles optimization automatically and you can even use MAX Serve to manage your deployment.

When you build GPT-2 in this guide, you’ll load pretrained weights from Hugging Face, implement the architecture, and run text generation, all in the same environment.

Why coding challenges?

This tutorial emphasizes active problem-solving over passive reading. Each step presents a focused implementation task with:

  1. Clear context: What you’re building and why it matters
  2. Guided implementation: Code structure with specific tasks to complete
  3. Immediate validation: Tests that verify correctness before moving forward
  4. Conceptual grounding: Explanations that connect code to architecture

Rather than presenting complete solutions, this approach helps you develop intuition for when and why to use specific patterns. The skills you build extend beyond GPT-2 to model development more broadly.

You can work through the tutorial sequentially for comprehensive understanding, or skip directly to topics you need. Each step is self-contained enough to be useful independently while building toward a complete implementation.

What you’ll build

This tutorial guides you through building GPT-2 in manageable steps:

| Step | Component | What you'll learn |
| --- | --- | --- |
| 1 | Model configuration | Define architecture hyperparameters matching HuggingFace GPT-2. |
| 2 | Causal masking | Create attention masks to prevent looking at future tokens. |
| 3 | Layer normalization | Stabilize activations for effective training. |
| 4 | GPT-2 MLP (feed-forward network) | Build the position-wise feed-forward network with GELU activation. |
| 5 | Token embeddings | Convert token IDs to continuous vector representations. |
| 6 | Position embeddings | Encode sequence order information. |
| 7 | Multi-head attention | Extend to multiple parallel attention heads. |
| 8 | Residual connections & layer norm | Enable training deep networks with skip connections. |
| 9 | Transformer block | Combine attention and MLP into the core building block. |
| 10 | Stacking transformer blocks | Create the complete 12-layer GPT-2 model. |
| 11 | Language model head | Project hidden states to vocabulary logits. |
| 12 | Text generation | Generate text autoregressively with temperature sampling. |
| 00 | Serve your model | Run GPT-2 as an endpoint with MAX Serve (coming soon). |

By the end, you’ll have a complete GPT-2 implementation and practical experience with MAX’s Python API. These are skills you can immediately apply to your own projects.

To install the puzzles, follow the steps in Setup.

How this book works

Each step includes automated tests that verify your implementation before moving forward. This immediate feedback helps you catch issues early and build confidence.

You’ll first need to clone the GitHub repository and navigate into it:

git clone https://github.com/modular/max-llm-book
cd max-llm-book

Then download and install pixi:

curl -fsSL https://pixi.sh/install.sh | sh

To validate a step, use the corresponding test command. For example, to test Step 01:

pixi run s01

Initially, tests will fail because the implementation isn’t complete:

✨ Pixi task (s01): python tests/test.step_01.py
Running tests for Step 01: Create Model Configuration...

Results:
❌ dataclass is not imported from dataclasses
❌ GPT2Config does not have the @dataclass decorator
❌ vocab_size is incorrect: expected match with Hugging Face model configuration, got None
# ...

Each failure tells you exactly what to implement.

When your implementation is correct, you’ll see:

✨ Pixi task (s01): python tests/test.step_01.py                                                                         
Running tests for Step 01: Create Model Configuration...

Results:
✅ dataclass is correctly imported from dataclasses
✅ GPT2Config has the @dataclass decorator
✅ vocab_size is correct
# ...

The test output tells you exactly what needs to be fixed, making it easy to iterate until your implementation is correct. Once all checks pass, you’re ready to move on to the next step.

Prerequisites

This tutorial assumes:

  • Basic Python knowledge: Classes, functions, type hints
  • Familiarity with neural networks: What embeddings and layers do (we’ll explain the specifics)
  • Interest in understanding: Curiosity matters more than prior transformer experience

Whether you’re exploring MAX for the first time or deepening your understanding of model architecture, this tutorial provides hands-on experience you can apply to current projects and learning priorities.

Ready to build? Let’s get started with Step 01: Model configuration.

Step 01: Model configuration

Learn to define the GPT-2 model architecture parameters using configuration classes.

Defining the model architecture

Before you can implement GPT-2, you need to define its architecture: the dimensions, layer counts, and structural parameters that determine how the model processes information.

In this step, you’ll create GPT2Config, a class that holds all the architectural decisions for GPT-2: the embedding dimension, the number of transformer layers, the number of attention heads, and so on. These parameters define the shape and capacity of your model.

OpenAI trained the original GPT-2 model with specific parameters that you can see in the config.json file on Hugging Face. By using the exact same values, we can access OpenAI’s pretrained weights in subsequent steps.

Understanding the parameters

Looking at the config.json file, we can see some key information about the model. Each parameter controls a different aspect of the model’s architecture:

  • vocab_size: Size of the token vocabulary (default: 50,257). This seemingly odd number is actually 50,000 Byte Pair Encoding (BPE) tokens + 256 byte-level tokens (fallback for rare characters) + 1 special token.
  • n_positions: Maximum sequence length, also called the context window (default: 1,024). Attention memory grows quadratically with sequence length, so longer contexts are increasingly expensive.
  • n_embd: Embedding dimension, or the size of the hidden states that flow through the model (default: 768). This determines the model’s capacity to represent information.
  • n_layer: Number of transformer blocks stacked vertically (default: 12). More layers allow the model to learn more complex patterns.
  • n_head: Number of attention heads per layer (default: 12). Multiple heads let the model attend to different types of patterns simultaneously.
  • n_inner: Dimension of the MLP intermediate layer (default: 3,072). This is 4× the embedding dimension, a ratio introduced in the original Attention Is All You Need paper and found empirically to work well.
  • layer_norm_epsilon: Small constant for numerical stability in layer normalization (default: 1e-5). This prevents division by zero when variance is very small.

These values define the small GPT-2 model. OpenAI released four sizes (small, medium, large, XL), each scaling up these parameters. This tutorial uses the small configuration throughout.

Implementing the configuration

Now implement this yourself. You’ll create the GPT2Config class using Python’s @dataclass decorator. Dataclasses reduce boilerplate: instead of writing __init__ and assigning each parameter manually, you declare the fields with type hints and default values.

First, you’ll need to import the dataclass decorator from the dataclasses module. Then you’ll add the @dataclass decorator to the GPT2Config class definition.

The actual parameter values come from Hugging Face. You can get them in two ways:

  • Option 1: Run pixi run huggingface to access these parameters programmatically from the Hugging Face transformers library.
  • Option 2: Read the values directly from the GPT-2 model card.

Once you have the values, replace each None in the GPT2Config class properties with the correct numbers from the configuration.
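
If you want to inspect the pretrained configuration yourself outside the pixi task, a minimal sketch (assuming the Hugging Face transformers package is available, which the pixi run huggingface task relies on) looks like this:

# Sketch: inspect GPT-2's pretrained configuration (assumes `transformers` is installed).
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("openai-community/gpt2")
print(hf_config.vocab_size)   # 50257
print(hf_config.n_positions)  # 1024
print(hf_config.n_embd)       # 768
print(hf_config.n_layer)      # 12
print(hf_config.n_head)       # 12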

Implementation (step_01.py):

"""
Step 01: Model Configuration

Implement the GPT-2 configuration dataclass that stores model hyperparameters.

Tasks:
1. Import dataclass from the dataclasses module
2. Add the @dataclass decorator to the GPT2Config class
3. Fill in the configuration values from HuggingFace GPT-2 model

Run: pixi run s01
"""

# 1. Import dataclass from the dataclasses module

# 2. Add the Python @dataclass decorator to the GPT2Config class


class GPT2Config:
    """GPT-2 configuration matching HuggingFace.

    Attributes:
        vocab_size: Size of the vocabulary.
        n_positions: Maximum sequence length.
        n_embd: Embedding dimension.
        n_layer: Number of transformer layers.
        n_head: Number of attention heads.
        n_inner: Inner dimension of feed-forward network (defaults to 4 * n_embd if None).
        layer_norm_epsilon: Epsilon for layer normalization.
    """

    # 3a. Run `pixi run huggingface` to access the model parameters from the Hugging Face `transformers` library
    # 3b. Alternately, read the values from GPT-2 model card: https://huggingface.co/openai-community/gpt2/blob/main/config.json
    # 4. Replace the None of the GPT2Config properties with the correct values
    vocab_size: int = None
    n_positions: int = None
    n_embd: int = None
    n_layer: int = None
    n_head: int = None
    n_inner: int = None  # Equal to 4 * n_embd
    layer_norm_epsilon: float = None

Validation

Run pixi run s01 to verify your implementation matches the expected configuration.

Show solution
"""
Solution for Step 01: Model Configuration

This module implements the GPT-2 configuration dataclass that stores
hyperparameters matching HuggingFace's GPT-2 model structure.
"""

from dataclasses import dataclass


@dataclass
class GPT2Config:
    """GPT-2 configuration matching HuggingFace.

    Attributes:
        vocab_size: Size of the vocabulary.
        n_positions: Maximum sequence length.
        n_embd: Embedding dimension.
        n_layer: Number of transformer layers.
        n_head: Number of attention heads.
        n_inner: Inner dimension of feed-forward network (defaults to 4 * n_embd if None).
        layer_norm_epsilon: Epsilon for layer normalization.
    """

    vocab_size: int = 50257
    n_positions: int = 1024
    n_embd: int = 768
    n_layer: int = 12
    n_head: int = 12
    n_inner: int = 3072
    layer_norm_epsilon: float = 1e-5

Next: In Step 02, you’ll implement causal masking to prevent tokens from attending to future positions in autoregressive generation.

Step 02: Causal masking

Learn to create attention masks to prevent the model from seeing future tokens during autoregressive generation.

Implementing causal masking

In this step you’ll implement the causal_mask() function. This creates a mask matrix that prevents the model from seeing future tokens when predicting the next token. The mask sets attention scores to negative infinity (-inf) for future positions. After softmax, these -inf values become zero probability, blocking information flow from later tokens.

GPT-2 generates text one token at a time, left-to-right. During training, causal masking prevents the model from “cheating” by looking ahead at tokens it should be predicting. Without this mask, the model would have access to information it won’t have during actual text generation.
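
To see why -inf works, here is a tiny NumPy illustration (not part of the MAX implementation) of what softmax does to a masked score:

import numpy as np

scores = np.array([2.0, 1.0, -np.inf])           # -inf marks a future position
weights = np.exp(scores) / np.exp(scores).sum()  # softmax
print(weights)  # [0.731, 0.269, 0.0] -- the masked position gets zero probability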

Understanding the mask pattern

The mask creates a lower triangular pattern where each token can only attend to itself and previous tokens:

  • Position 0 attends to: position 0 only
  • Position 1 attends to: positions 0-1
  • Position 2 attends to: positions 0-2
  • And so on…

The mask shape is (sequence_length, sequence_length + num_tokens). This shape is designed for KV cache compatibility during generation. The KV cache stores key and value tensors from previously generated tokens, so you only need to compute attention for new tokens while attending to both new tokens (sequence_length) and cached tokens (num_tokens). This significantly speeds up generation by avoiding recomputation.
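
For example, with sequence_length=4 and num_tokens=0, the mask is a 4×4 matrix with 0 on and below the diagonal and -inf above it:

[[  0, -inf, -inf, -inf],
 [  0,    0, -inf, -inf],
 [  0,    0,    0, -inf],
 [  0,    0,    0,    0]]

Row i can attend only to columns 0 through i; the mask is added to the attention scores before the softmax.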

MAX operations

You’ll use the following MAX operations to complete this task:

Functional decorator:

  • @F.functional: Converts functions to graph operations for MAX compilation

Tensor operations:

  • Tensor.constant(): Creates a scalar constant tensor
  • F.broadcast_to(): Expands a tensor to a target shape
  • F.band_part(): Keeps a band of a matrix (here, the lower triangle)

Implementing the mask

You’ll create the causal mask in several steps:

  1. Import required modules:

    • Device from max.driver - specifies hardware device (CPU/GPU)
    • DType from max.dtype - data type specification
    • functional as F from max.experimental - functional operations library
    • Tensor from max.experimental.tensor - tensor operations
    • Dim and DimLike from max.graph - dimension handling
  2. Add @F.functional decorator: This converts the function to a MAX graph operation.

  3. Calculate total sequence length: Combine sequence_length and num_tokens using Dim() to determine mask width.

  4. Create constant tensor: Use Tensor.constant(float("-inf"), dtype=dtype, device=device) to create a scalar that will be broadcast.

  5. Broadcast to target shape: Use F.broadcast_to(mask, shape=(sequence_length, n)) to expand the scalar to a 2D matrix.

  6. Apply band part: Use F.band_part(mask, num_lower=None, num_upper=0, exclude=True) to create the lower triangular pattern. This keeps 0s on and below the diagonal, -inf above.

Implementation (step_02.py):

"""
Step 02: Causal Masking

Implement causal attention masking that prevents tokens from attending to future positions.

Tasks:
1. Import functional module (as F) and Tensor from max.experimental
2. Add @F.functional decorator to the causal_mask function
3. Create a constant tensor filled with negative infinity
4. Broadcast the mask to the correct shape (sequence_length, n)
5. Apply band_part to create the lower triangular causal structure

Run: pixi run s02
"""

# 1: Import the required modules from MAX
from max.driver import Device
from max.dtype import DType
# TODO: Import the necessary functional module from max.experimental with the alias F
# https://docs.modular.com/max/api/python/experimental/functional

# TODO: Import Tensor object from max.experimental.tensor
# https://docs.modular.com/max/api/python/experimental/tensor.Tensor

from max.graph import Dim, DimLike

# 2: Add the @F.functional decorator to make this a MAX functional operation
# TODO: Add the decorator here


def causal_mask(
    sequence_length: DimLike,
    num_tokens: DimLike,
    *,
    dtype: DType,
    device: Device,
):
    """Create a causal mask for autoregressive attention.

    Args:
        sequence_length: Length of the sequence.
        num_tokens: Number of tokens.
        dtype: Data type for the mask.
        device: Device to create the mask on.

    Returns:
        A causal mask tensor.
    """
    # Calculate total sequence length
    n = Dim(sequence_length) + num_tokens

    # 3: Create a constant tensor filled with negative infinity
    # TODO: Use Tensor.constant() with float("-inf"), dtype, and device parameters
    # https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.constant
    # Hint: This creates the base mask value that will block attention to future tokens
    mask = None

    # 4: Broadcast the mask to the correct shape
    # TODO: Use F.broadcast_to() to expand mask to shape (sequence_length, n)
    # https://docs.modular.com/max/api/python/experimental/functional#max.experimental.functional.broadcast_to
    # Hint: This creates a 2D attention mask matrix
    mask = None

    # 5: Apply band_part to create the causal (lower triangular) structure and return the mask
    # TODO: Use F.band_part() with num_lower=None, num_upper=0, exclude=True
    # https://docs.modular.com/max/api/python/experimental/functional/#max.experimental.functional.band_part
    # Hint: This keeps only the lower triangle, allowing attention to past tokens only
    return None

Validation

Run pixi run s02 to verify your implementation.

Show solution
"""
Solution for Step 02: Causal Masking

This module implements causal attention masking that prevents tokens from
attending to future positions in autoregressive generation.
"""

from max.driver import Device
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import Dim, DimLike


@F.functional
def causal_mask(
    sequence_length: DimLike,
    num_tokens: DimLike,
    *,
    dtype: DType,
    device: Device,
):
    """Create a causal mask for autoregressive attention.

    Args:
        sequence_length: Length of the sequence.
        num_tokens: Number of tokens.
        dtype: Data type for the mask.
        device: Device to create the mask on.

    Returns:
        A causal mask tensor.
    """
    n = Dim(sequence_length) + num_tokens
    mask = Tensor.constant(float("-inf"), dtype=dtype, device=device)
    mask = F.broadcast_to(mask, shape=(sequence_length, n))
    return F.band_part(mask, num_lower=None, num_upper=0, exclude=True)

Next: In Step 03, you’ll implement layer normalization to stabilize activations for effective training.

Step 03: Layer normalization

Learn to implement layer normalization for stabilizing neural network training.

Building layer normalization

In this step, you’ll create the LayerNorm class that normalizes activations across the feature dimension. For each input, you compute the mean and variance across all features, normalize by subtracting the mean and dividing by the standard deviation, then apply learned weight and bias parameters to scale and shift the result.

Unlike batch normalization, layer normalization works independently for each example. This makes it ideal for transformers - no dependence on batch size, no tracking running statistics during inference, and consistent behavior between training and generation.

GPT-2 applies layer normalization before the attention and MLP blocks in each of its 12 transformer layers. This pre-normalization pattern stabilizes training in deep networks by keeping activations in a consistent range.

Understanding the operation

Layer normalization normalizes across the feature dimension (the last dimension) independently for each example. It learns two parameters per feature: weight (gamma) for scaling and bias (beta) for shifting.

The normalization follows this formula:

output = weight * (x - mean) / sqrt(variance + epsilon) + bias

The mean and variance are computed across all features in each example. After normalizing to zero mean and unit variance, the learned weight scales the result and the learned bias shifts it. The epsilon value (typically 1e-5) prevents division by zero when variance is very small.
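
As a quick sanity check, here is the same formula applied to a toy vector with NumPy (an illustration only; the MAX implementation uses F.layer_norm):

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
weight = np.ones(4)   # gamma, initialized to ones
bias = np.zeros(4)    # beta, initialized to zeros
eps = 1e-5

normalized = weight * (x - x.mean()) / np.sqrt(x.var() + eps) + bias
print(normalized)     # approximately [-1.342, -0.447, 0.447, 1.342]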

MAX operations

You’ll use the following MAX operations to complete this task:

Modules:

  • Module: Base class from max.nn.module_v3 that your layers subclass

Tensor initialization:

  • Tensor.ones(): Creates a tensor of ones (the initial weight, gamma)
  • Tensor.zeros(): Creates a tensor of zeros (the initial bias, beta)

Layer normalization:

  • F.layer_norm(): Applies layer normalization with parameters: input, gamma (weight), beta (bias), and epsilon

Implementing layer normalization

You’ll create the LayerNorm class that wraps MAX’s layer normalization function with learnable parameters. The implementation is straightforward - two parameters and a single function call.

First, import the required modules. You’ll need functional as F for the layer norm operation and Tensor for creating parameters.

In the __init__ method, create two learnable parameters:

  • Weight: Tensor.ones([dim]) stored as self.weight - initialized to ones so the initial transformation is identity
  • Bias: Tensor.zeros([dim]) stored as self.bias - initialized to zeros so there’s no initial shift

Store the epsilon value as self.eps for numerical stability.

In the forward method, apply layer normalization with F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). This computes the normalization and applies the learned parameters in one operation.

Implementation (step_03.py):

"""
Step 03: Layer Normalization

Implement layer normalization that normalizes activations for training stability.

Tasks:
1. Import functional module (as F) and Tensor from max.experimental
2. Initialize learnable weight (gamma) and bias (beta) parameters
3. Apply layer normalization using F.layer_norm in the forward pass

Run: pixi run s03
"""

# 1: Import the required modules from MAX
# TODO: Import functional module from max.experimental with the alias F
# https://docs.modular.com/max/api/python/experimental/functional

# TODO: Import Tensor from max.experimental.tensor
# https://docs.modular.com/max/api/python/experimental/tensor.Tensor

from max.graph import DimLike
from max.nn.module_v3 import Module


class LayerNorm(Module):
    """Layer normalization module.

    Args:
        dim: Dimension to normalize over.
        eps: Epsilon for numerical stability.
    """

    def __init__(self, dim: DimLike, *, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

        # 2: Initialize learnable weight and bias parameters
        # TODO: Create self.weight as a Tensor of ones with shape [dim]
        # https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.ones
        # Hint: This is the gamma parameter in layer normalization
        self.weight = None

        # TODO: Create self.bias as a Tensor of zeros with shape [dim]
        # https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.zeros
        # Hint: This is the beta parameter in layer normalization
        self.bias = None

    def __call__(self, x: Tensor) -> Tensor:
        """Apply layer normalization.

        Args:
            x: Input tensor.

        Returns:
            Normalized tensor.
        """
        # 3: Apply layer normalization and return the result
        # TODO: Use F.layer_norm() with x, gamma=self.weight, beta=self.bias, epsilon=self.eps
        # https://docs.modular.com/max/api/python/experimental/functional#max.experimental.functional.layer_norm
        # Hint: Layer normalization normalizes across the last dimension
        return None

Validation

Run pixi run s03 to verify your implementation.

Show solution
"""
Solution for Step 03: Layer Normalization

This module implements layer normalization that normalizes activations
across the embedding dimension for training stability.
"""

from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DimLike
from max.nn.module_v3 import Module


class LayerNorm(Module):
    """Layer normalization module.

    Args:
        dim: Dimension to normalize over.
        eps: Epsilon for numerical stability.
    """

    def __init__(self, dim: DimLike, *, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = Tensor.ones([dim])
        self.bias = Tensor.zeros([dim])

    def __call__(self, x: Tensor) -> Tensor:
        """Apply layer normalization.

        Args:
            x: Input tensor.

        Returns:
            Normalized tensor.
        """
        return F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)

Next: In Step 04, you’ll implement the feed-forward network (MLP) with GELU activation used in each transformer block.

Step 04: Feed-forward network

Learn to build the feed-forward network (MLP) that processes information after attention in each transformer block.

Building the MLP

In this step, you’ll create the GPT2MLP class: a two-layer feed-forward network that appears after attention in every transformer block. The MLP expands the embedding dimension by 4× (768 → 3,072), applies GELU activation for non-linearity, then projects back to the original dimension.

While attention lets tokens communicate with each other, the MLP processes each position independently. Attention aggregates information through weighted sums (linear operations), but the MLP adds non-linearity through GELU activation. This combination allows the model to learn complex patterns beyond what linear transformations alone can capture.

GPT-2 uses a 4× expansion ratio (768 to 3,072 dimensions) because this was found to work well in the original Transformer paper and has been validated across many architectures since.

Understanding the components

The MLP has three steps applied in sequence:

Expansion layer (c_fc): Projects from 768 to 3,072 dimensions using a linear layer. This expansion gives the network more capacity to process information.

GELU activation: Applies the Gaussian Error Linear Unit, a smooth non-linear function. GPT-2 uses approximate="tanh", the tanh-based approximation rather than the exact computation. The approximation was faster when GPT-2 was first implemented; exact GELU is fast enough today, but we keep the approximation so the outputs match the original model.
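
For reference, the tanh approximation computes:

gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))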

Projection layer (c_proj): Projects back from 3,072 to 768 dimensions using another linear layer. This returns to the embedding dimension so outputs can be added to residual connections.

The layer names c_fc (fully connected) and c_proj (projection) match Hugging Face’s GPT-2 checkpoint structure. This naming is essential for loading pretrained weights.
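
Here’s a NumPy sketch of the shape flow through the MLP (random placeholder weights, biases omitted; the real layers are Linear modules loaded with pretrained weights):

import numpy as np

batch, seq_len, n_embd, n_inner = 2, 5, 768, 3072

x = np.random.randn(batch, seq_len, n_embd)
w_fc = np.random.randn(n_embd, n_inner) * 0.02    # stands in for c_fc
w_proj = np.random.randn(n_inner, n_embd) * 0.02  # stands in for c_proj

h = x @ w_fc                                      # expand: (2, 5, 3072)
h = 0.5 * h * (1 + np.tanh(np.sqrt(2 / np.pi) * (h + 0.044715 * h**3)))  # tanh GELU
out = h @ w_proj                                  # project back: (2, 5, 768)
print(out.shape)                                  # (2, 5, 768)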

MAX operations

You’ll use the following MAX operations to complete this task:

Linear layers:

  • Linear(in_dim, out_dim, bias=True): Fully connected layer from max.nn.module_v3

GELU activation:

  • F.gelu(): Applies the GELU activation (pass approximate="tanh" to match GPT-2)

Implementing the MLP

You’ll create the GPT2MLP class that chains two linear layers with GELU activation between them. The implementation is straightforward - three operations applied in sequence.

First, import the required modules. You’ll need functional as F for the GELU activation, Tensor for type hints, Linear for the layers, and Module as the base class.

In the __init__ method, create two linear layers:

  • Expansion layer: Linear(embed_dim, intermediate_size, bias=True) stored as self.c_fc
  • Projection layer: Linear(intermediate_size, embed_dim, bias=True) stored as self.c_proj

Both layers include bias terms (bias=True). The intermediate size is typically 4× the embedding dimension.

In the forward method, apply the three transformations:

  1. Expand: hidden_states = self.c_fc(hidden_states)
  2. Activate: hidden_states = F.gelu(hidden_states, approximate="tanh")
  3. Project: hidden_states = self.c_proj(hidden_states)

Return the final hidden_states. The input and output shapes are the same: [batch, seq_length, embed_dim].

Implementation (step_04.py):

"""
Step 04: Feed-forward Network (MLP)

Implement the MLP used in each transformer block with GELU activation.

Tasks:
1. Import functional (as F), Tensor, Linear, and Module from MAX
2. Create c_fc linear layer (embedding to intermediate dimension)
3. Create c_proj linear layer (intermediate back to embedding dimension)
4. Apply c_fc transformation in forward pass
5. Apply GELU activation function
6. Apply c_proj transformation and return result

Run: pixi run s04
"""

# 1: Import the required modules from MAX
# TODO: Import functional module from max.experimental with the alias F
# https://docs.modular.com/max/api/python/experimental/functional

# TODO: Import Tensor from max.experimental.tensor
# https://docs.modular.com/max/api/python/experimental/tensor.Tensor

# TODO: Import Linear and Module from max.nn.module_v3
# https://docs.modular.com/max/api/python/nn/module_v3

from solutions.solution_01 import GPT2Config


class GPT2MLP(Module):
    """Feed-forward network matching HuggingFace GPT-2 structure.

    Args:
        intermediate_size: Size of the intermediate layer.
        config: GPT-2 configuration.
    """

    def __init__(self, intermediate_size: int, config: GPT2Config):
        super().__init__()
        embed_dim = config.n_embd

        # 2: Create the first linear layer (embedding to intermediate)
        # TODO: Create self.c_fc as a Linear layer from embed_dim to intermediate_size with bias=True
        # https://docs.modular.com/max/api/python/nn/module_v3#max.nn.module_v3.Linear
        # Hint: This is the expansion layer in the MLP
        self.c_fc = None

        # 3: Create the second linear layer (intermediate back to embedding)
        # TODO: Create self.c_proj as a Linear layer from intermediate_size to embed_dim with bias=True
        # https://docs.modular.com/max/api/python/nn/module_v3#max.nn.module_v3.Linear
        # Hint: This is the projection layer that brings us back to the embedding dimension
        self.c_proj = None

    def __call__(self, hidden_states: Tensor) -> Tensor:
        """Apply feed-forward network.

        Args:
            hidden_states: Input hidden states.

        Returns:
            MLP output.
        """
        # 4: Apply the first linear transformation
        # TODO: Apply self.c_fc to hidden_states
        # Hint: This expands the hidden dimension to the intermediate size
        hidden_states = None

        # 5: Apply GELU activation function
        # TODO: Use F.gelu() with hidden_states and approximate="tanh"
        # https://docs.modular.com/max/api/python/experimental/functional#max.experimental.functional.gelu
        # Hint: GELU is the non-linear activation used in GPT-2's MLP
        hidden_states = None

        # 6: Apply the second linear transformation and return
        # TODO: Apply self.c_proj to hidden_states and return the result
        # Hint: This projects back to the embedding dimension
        return None

Validation

Run pixi run s04 to verify your implementation.

Show solution
"""
Solution for Step 04: Feed-forward Network (MLP)

This module implements the feed-forward network (MLP) used in each
transformer block with GELU activation.
"""

from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Linear, Module

from solutions.solution_01 import GPT2Config


class GPT2MLP(Module):
    """Feed-forward network matching HuggingFace GPT-2 structure.

    Args:
        intermediate_size: Size of the intermediate layer.
        config: GPT-2 configuration.
    """

    def __init__(self, intermediate_size: int, config: GPT2Config):
        super().__init__()
        embed_dim = config.n_embd
        self.c_fc = Linear(embed_dim, intermediate_size, bias=True)
        self.c_proj = Linear(intermediate_size, embed_dim, bias=True)

    def __call__(self, hidden_states: Tensor) -> Tensor:
        """Apply feed-forward network.

        Args:
            hidden_states: Input hidden states.

        Returns:
            MLP output.
        """
        hidden_states = self.c_fc(hidden_states)
        hidden_states = F.gelu(hidden_states, approximate="tanh")
        hidden_states = self.c_proj(hidden_states)
        return hidden_states

Next: In Step 05, you’ll implement token embeddings to convert discrete token IDs into continuous vector representations.

Step 05: Token embeddings

Learn to create token embeddings that convert discrete token IDs into continuous vector representations.

Implementing token embeddings

In this step you’ll create the Embedding class. This converts discrete token IDs (integers) into continuous vector representations that the model can process. The embedding layer is a lookup table with shape [50257, 768] where 50257 is GPT-2’s vocabulary size and 768 is the embedding dimension.

Neural networks operate on continuous values, not discrete symbols. Token embeddings convert discrete token IDs into dense vectors that can be processed by matrix operations. During training, these embeddings naturally cluster semantically similar words closer together in vector space.

Understanding embeddings

The embedding layer stores one vector per vocabulary token. When you pass in token ID 1000, it returns row 1000 as the embedding vector. The layer name wte stands for “word token embeddings” and matches the naming in the original GPT-2 code for weight loading compatibility.
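
Conceptually, the lookup is just row indexing. Here’s a toy NumPy illustration (the sizes are made up for readability; GPT-2 uses 50,257 × 768):

import numpy as np

vocab_size, embed_dim = 6, 4
table = np.random.randn(vocab_size, embed_dim)  # stands in for the wte weight matrix

token_ids = np.array([3, 1, 3])                 # a tiny "sequence" of token IDs
embeddings = table[token_ids]                   # row lookup
print(embeddings.shape)                         # (3, 4)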

Key parameters:

  • Vocabulary size: 50,257 tokens (byte-pair encoding)
  • Embedding dimension: 768 for GPT-2 base
  • Shape: [vocab_size, embedding_dim]

MAX operations

You’ll use the following MAX operations to complete this task:

Embedding layer:

  • Embedding(vocab_size, dim=n_embd): Lookup table that maps token IDs to embedding vectors

Implementing the class

You’ll implement the Embedding class in several steps:

  1. Import required modules: Import Embedding and Module from MAX libraries.

  2. Create embedding layer: Use Embedding(config.vocab_size, dim=config.n_embd) and store in self.wte.

  3. Implement forward pass: Call self.wte(input_ids) to lookup embeddings. Input shape: [batch_size, seq_length]. Output shape: [batch_size, seq_length, n_embd].

Implementation (step_05.py):

"""
Step 05: Token Embeddings

Implement token embeddings that convert discrete token IDs into continuous vectors.

Tasks:
1. Import Embedding and Module from max.nn.module_v3
2. Create token embedding layer using Embedding(vocab_size, dim=n_embd)
3. Implement forward pass that looks up embeddings for input token IDs

Run: pixi run s05
"""

# TODO: Import required modules from MAX
# Hint: You'll need Embedding and Module from max.nn.module_v3

from solutions.solution_01 import GPT2Config


class GPT2Embeddings(Module):
    """Token embeddings for GPT-2."""

    def __init__(self, config: GPT2Config):
        super().__init__()

        # TODO: Create token embedding layer
        # Hint: Use Embedding(config.vocab_size, dim=config.n_embd)
        # This creates a lookup table that converts token IDs to embedding vectors
        self.wte = None

    def __call__(self, input_ids):
        """Convert token IDs to embeddings.

        Args:
            input_ids: Tensor of token IDs, shape [batch_size, seq_length]

        Returns:
            Token embeddings, shape [batch_size, seq_length, n_embd]
        """
        # TODO: Return the embedded tokens
        # Hint: Simply call self.wte with input_ids
        return None

Validation

Run pixi run s05 to verify your implementation.

Show solution
"""
Solution for Step 05: Token Embeddings

This module implements token embeddings that convert discrete token IDs
into continuous vector representations.
"""

from max.nn.module_v3 import Embedding, Module

from solutions.solution_01 import GPT2Config


class GPT2Embeddings(Module):
    """Token embeddings for GPT-2, matching HuggingFace structure."""

    def __init__(self, config: GPT2Config):
        """Initialize token embedding layer.

        Args:
            config: GPT2Config containing vocab_size and n_embd
        """
        super().__init__()

        # Token embedding: lookup table from vocab_size to embedding dimension
        # This converts discrete token IDs (0 to vocab_size-1) into dense vectors
        self.wte = Embedding(config.vocab_size, dim=config.n_embd)

    def __call__(self, input_ids):
        """Convert token IDs to embeddings.

        Args:
            input_ids: Tensor of token IDs, shape [batch_size, seq_length]

        Returns:
            Token embeddings, shape [batch_size, seq_length, n_embd]
        """
        # Simple lookup: each token ID becomes its corresponding embedding vector
        return self.wte(input_ids)

Next: In Step 06, you’ll implement position embeddings to encode sequence order information, which will be combined with these token embeddings.

Step 06: Position embeddings

Learn to create position embeddings that encode the order of tokens in a sequence.

Implementing position embeddings

In this step you’ll create position embeddings to encode where each token appears in the sequence. While token embeddings tell the model “what” each token is, position embeddings tell it “where” the token is located. These position vectors are added to token embeddings before entering the transformer blocks.

Transformers process all positions in parallel through attention, unlike Recurrent Neural Networks (RNNs) that process sequentially. This parallelism enables faster training but loses positional information. Position embeddings restore this information so the model can distinguish “dog bites man” from “man bites dog”.

Understanding position embeddings

Position embeddings work like token embeddings: a lookup table with shape [1024, 768] where 1024 is the maximum sequence length. Position 0 gets the first row, position 1 gets the second row, and so on.

GPT-2 uses learned position embeddings, meaning these vectors are initialized randomly and trained alongside the model. This differs from the original Transformer which used fixed sinusoidal position encodings. Learned embeddings let the model discover optimal position representations for its specific task, though they cannot generalize beyond the maximum length seen during training (1024 tokens).
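
Here’s a toy NumPy illustration of how position embeddings combine with token embeddings (sizes are made up; the real tables are learned weights):

import numpy as np

seq_len, embed_dim, max_positions = 5, 4, 1024

token_emb = np.random.randn(seq_len, embed_dim)        # "what" each token is
pos_table = np.random.randn(max_positions, embed_dim)  # stands in for the wpe weight matrix

position_ids = np.arange(seq_len)                      # [0, 1, 2, 3, 4]
pos_emb = pos_table[position_ids]                      # "where" each token is
hidden = token_emb + pos_emb                           # input to the transformer blocks
print(hidden.shape)                                    # (5, 4)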

Key parameters:

  • Maximum sequence length: 1,024 positions
  • Embedding dimension: 768 for GPT-2 base
  • Shape: [n_positions, n_embd]
  • Layer name: wpe (word position embeddings)

MAX operations

You’ll use the following MAX operations to complete this task:

Position indices:

  • Tensor: Position indices are passed in as an integer Tensor (imported from max.experimental.tensor)

Embedding layer:

  • Embedding(n_positions, dim=n_embd): Lookup table that maps position indices to embedding vectors

Implementing the class

You’ll implement the position embeddings in several steps:

  1. Import required modules: Import Tensor, Embedding, and Module from MAX libraries.

  2. Create position embedding layer: Use Embedding(config.n_positions, dim=config.n_embd) and store in self.wpe.

  3. Implement forward pass: Call self.wpe(position_ids) to lookup position embeddings. Input shape: [seq_length] or [batch, seq_length]. Output shape: [seq_length, n_embd] or [batch, seq_length, n_embd].

Implementation (step_06.py):

"""
Step 06: Position Embeddings

Implement position embeddings that encode sequence order information.

Tasks:
1. Import Tensor from max.experimental.tensor
2. Import Embedding and Module from max.nn.module_v3
3. Create position embedding layer using Embedding(n_positions, dim=n_embd)
4. Implement forward pass that looks up embeddings for position indices

Run: pixi run s06
"""

# TODO: Import required modules from MAX
# Hint: You'll need Tensor from max.experimental.tensor
# Hint: You'll need Embedding and Module from max.nn.module_v3

from solutions.solution_01 import GPT2Config


class GPT2PositionEmbeddings(Module):
    """Position embeddings for GPT-2."""

    def __init__(self, config: GPT2Config):
        super().__init__()

        # TODO: Create position embedding layer
        # Hint: Use Embedding(config.n_positions, dim=config.n_embd)
        # This creates a lookup table for position indices (0, 1, 2, ..., n_positions-1)
        self.wpe = None

    def __call__(self, position_ids):
        """Convert position indices to embeddings.

        Args:
            position_ids: Tensor of position indices, shape [seq_length] or [batch_size, seq_length]

        Returns:
            Position embeddings, shape matching input with added embedding dimension
        """
        # TODO: Return the position embeddings
        # Hint: Simply call self.wpe with position_ids
        return None

Validation

Run pixi run s06 to verify your implementation.

Show solution
"""
Solution for Step 06: Position Embeddings

This module implements position embeddings that encode sequence order information
into the transformer model.
"""

from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding, Module

from solutions.solution_01 import GPT2Config


class GPT2PositionEmbeddings(Module):
    """Position embeddings for GPT-2, matching HuggingFace structure."""

    def __init__(self, config: GPT2Config):
        """Initialize position embedding layer.

        Args:
            config: GPT2Config containing n_positions and n_embd
        """
        super().__init__()

        # Position embedding: lookup table from position indices to embedding vectors
        # This encodes "where" information - position 0, 1, 2, etc.
        self.wpe = Embedding(config.n_positions, dim=config.n_embd)

    def __call__(self, position_ids):
        """Convert position indices to embeddings.

        Args:
            position_ids: Tensor of position indices, shape [seq_length] or [batch_size, seq_length]

        Returns:
            Position embeddings, shape matching input with added embedding dimension
        """
        # Simple lookup: each position index becomes its corresponding embedding vector
        return self.wpe(position_ids)

Next: In Step 07, you’ll implement multi-head attention.

Step 07: Multi-head attention

Learn to use multi-head attention, enabling the model to attend to different representation subspaces.

Building multi-head attention

In this step, you’ll implement the GPT2MultiHeadAttention class that runs 12 attention operations in parallel. Instead of computing attention once over the full 768-dimensional space, you split the dimensions into 12 heads of 64 dimensions each. Each head learns to focus on different patterns.

GPT-2 uses 12 heads with 768-dimensional embeddings, giving each head 768 ÷ 12 = 64 dimensions. The Q, K, V tensors are reshaped to split the embedding dimension across heads, attention is computed for all heads in parallel, then the outputs are concatenated back together. This happens in a single efficient operation using tensor reshaping and broadcasting.

Multiple heads let the model learn complementary attention strategies. Heads can specialize in different relationships: for example, one might attend to adjacent tokens, another to syntactic patterns, and another to semantic similarity. This increases the model’s capacity without dramatically increasing computation.

Understanding the architecture

Multi-head attention splits the embedding dimension, computes attention independently for each head, then merges the results. This requires careful tensor reshaping to organize the computation efficiently.

Head splitting: Transform from [batch, seq_length, 768] to [batch, 12, seq_length, 64]. First reshape to add the head dimension: [batch, seq_length, 12, 64]. Then transpose to move heads before sequence: [batch, 12, seq_length, 64]. Now each of the 12 heads operates independently on its 64-dimensional subspace.

Parallel attention: With shape [batch, num_heads, seq_length, head_dim], you can compute attention for all heads simultaneously. The matrix multiplication Q @ K^T operates on the last two dimensions [seq_length, head_dim] @ [head_dim, seq_length], broadcasting across the batch and head dimensions. All 12 heads computed in a single efficient operation.
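
Written out, each head computes the same scaled dot-product attention:

attention(Q, K, V) = softmax(Q @ K^T / sqrt(head_dim) + mask) @ V

The causal mask from Step 02 is added to the scaled scores before the softmax.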

Head merging: Reverse the splitting to go from [batch, 12, seq_length, 64] back to [batch, seq_length, 768]. First transpose to [batch, seq_length, 12, 64], then reshape to flatten the head dimension: [batch, seq_length, 768]. This concatenates all head outputs back into the original dimension.
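
Here’s a NumPy sketch of the split and merge reshaping described above (shapes only; the MAX version uses tensor.reshape and tensor.transpose(-3, -2)):

import numpy as np

batch, seq_len, n_embd, n_head = 2, 5, 768, 12
head_dim = n_embd // n_head   # 64

x = np.random.randn(batch, seq_len, n_embd)

# Split heads: [batch, seq, 768] -> [batch, 12, seq, 64]
split = x.reshape(batch, seq_len, n_head, head_dim).transpose(0, 2, 1, 3)
print(split.shape)             # (2, 12, 5, 64)

# Merge heads: [batch, 12, seq, 64] -> [batch, seq, 768]
merged = split.transpose(0, 2, 1, 3).reshape(batch, seq_len, n_embd)
print(np.allclose(x, merged))  # True: splitting then merging is lossless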

Output projection (c_proj): After merging heads, apply a learned linear transformation that maps [batch, seq_length, 768] to [batch, seq_length, 768]. This lets the model mix information across heads, combining the different perspectives each head learned.

The layer names c_attn (combined Q/K/V projection) and c_proj (output projection) match Hugging Face’s GPT-2 implementation. This naming is essential for loading pretrained weights.

MAX operations

You’ll use the following MAX operations to complete this task:

Linear layers:

  • Linear(in_dim, out_dim, bias=True): Used for the combined Q/K/V projection (c_attn) and the output projection (c_proj)

Tensor operations:

  • tensor.reshape(new_shape): Splits or merges head dimension
  • tensor.transpose(axis1, axis2): Rearranges dimensions for parallel attention
  • F.split(tensor, split_sizes, axis): Divides Q/K/V from combined projection

Implementing multi-head attention

You’ll create the GPT2MultiHeadAttention class with helper methods for splitting and merging heads. The implementation builds on the attention mechanism from Step 02, extending it to work with multiple heads in parallel.

First, import the required modules. You’ll need math for scaling, functional as F for operations, Tensor for type hints, device and dtype utilities, and Linear and Module from MAX’s neural network module. You’ll also reuse the causal_mask function from Step 02.

In the __init__ method, create the projection layers and store configuration:

  • Combined Q/K/V projection: Linear(embed_dim, 3 * embed_dim, bias=True) stored as self.c_attn
  • Output projection: Linear(embed_dim, embed_dim, bias=True) stored as self.c_proj
  • Store self.num_heads (12) and self.head_dim (64) from config
  • Calculate self.split_size for splitting Q, K, V later

Implement _split_heads to reshape for parallel attention:

  • Calculate new shape by replacing the last dimension: tensor.shape[:-1] + [num_heads, attn_head_size]
  • Reshape to add the head dimension: tensor.reshape(new_shape)
  • Transpose to move heads to position 1: tensor.transpose(-3, -2)
  • Returns shape [batch, num_heads, seq_length, head_size]

Implement _merge_heads to concatenate head outputs:

  • Transpose to move heads back: tensor.transpose(-3, -2)
  • Calculate flattened shape: tensor.shape[:-2] + [num_heads * attn_head_size]
  • Reshape to merge heads: tensor.reshape(new_shape)
  • Returns shape [batch, seq_length, n_embd]

Implement _attn to compute scaled dot-product attention for all heads:

  • Compute attention scores: query @ key.transpose(-2, -1)
  • Scale by square root of head dimension
  • Apply causal mask to prevent attending to future positions
  • Apply softmax to get attention weights
  • Multiply weights by values: attn_weights @ value

In the forward method, orchestrate the complete multi-head attention:

  • Project to Q/K/V: qkv = self.c_attn(hidden_states)
  • Split into separate tensors: F.split(qkv, [self.split_size, self.split_size, self.split_size], axis=-1)
  • Split heads for each: query = self._split_heads(query, self.num_heads, self.head_dim) (repeat for key, value)
  • Compute attention: attn_output = self._attn(query, key, value)
  • Merge heads: attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  • Final projection: return self.c_proj(attn_output)

Implementation (step_07.py):

"""
Step 07: Multi-head Attention

Implement multi-head attention that splits Q/K/V into multiple heads,
computes attention in parallel for each head, and merges the results.

Tasks:
1. Import required modules (math, F, Tensor, Linear, Module, etc.)
2. Create c_attn and c_proj linear layers
3. Implement _split_heads: reshape and transpose to add head dimension
4. Implement _merge_heads: transpose and reshape to remove head dimension
5. Implement _attn: compute attention for all heads in parallel
6. Implement forward pass: project -> split -> attend -> merge -> project

Run: pixi run s07
"""

# TODO: Import required modules
# Hint: You'll need math for scaling
# Hint: You'll need functional as F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor, Device from max.driver, and DType from max.dtype
# Hint: You'll need Dim, DimLike from max.graph
# Hint: You'll also need Linear and Module from max.nn.module_v3

from solutions.solution_01 import GPT2Config


# TODO: Copy causal_mask function from solution_02.py
# This is the same function you implemented in Step 02


class GPT2MultiHeadAttention(Module):
    """Multi-head attention for GPT-2."""

    def __init__(self, config: GPT2Config):
        super().__init__()

        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # TODO: Create combined Q/K/V projection
        # Hint: Use Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
        self.c_attn = None

        # TODO: Create output projection
        # Hint: Use Linear(self.embed_dim, self.embed_dim, bias=True)
        self.c_proj = None

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """Split the last dimension into (num_heads, head_size).

        Args:
            tensor: Input tensor, shape [batch, seq_length, n_embd]
            num_heads: Number of attention heads
            attn_head_size: Dimension of each head

        Returns:
            Tensor with shape [batch, num_heads, seq_length, head_size]
        """
        # TODO: Add head dimension
        # Hint: new_shape = tensor.shape[:-1] + [num_heads, attn_head_size]
        # Hint: tensor = tensor.reshape(new_shape)
        pass

        # TODO: Move heads dimension to position 1
        # Hint: return tensor.transpose(-3, -2)
        return None

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """Merge attention heads back to original shape.

        Args:
            tensor: Input tensor, shape [batch, num_heads, seq_length, head_size]
            num_heads: Number of attention heads
            attn_head_size: Dimension of each head

        Returns:
            Tensor with shape [batch, seq_length, n_embd]
        """
        # TODO: Move heads dimension back
        # Hint: tensor = tensor.transpose(-3, -2)
        pass

        # TODO: Flatten head dimensions
        # Hint: new_shape = tensor.shape[:-2] + [num_heads * attn_head_size]
        # Hint: return tensor.reshape(new_shape)
        return None

    def _attn(self, query, key, value):
        """Compute attention for all heads in parallel.

        Args:
            query: Query tensor, shape [batch, num_heads, seq_length, head_size]
            key: Key tensor, shape [batch, num_heads, seq_length, head_size]
            value: Value tensor, shape [batch, num_heads, seq_length, head_size]

        Returns:
            Attention output, shape [batch, num_heads, seq_length, head_size]
        """
        # TODO: Implement attention computation
        # The same 5-step process: scores, scale, mask, softmax, weighted sum
        # Hint: Compute attention scores: query @ key.transpose(-1, -2)
        # Hint: Scale by sqrt(head_dim): attn_weights / math.sqrt(head_dim)
        # Hint: Apply causal mask using causal_mask function
        # Hint: Apply softmax: F.softmax(attn_weights)
        # Hint: Weighted sum: attn_weights @ value
        return None

    def __call__(self, hidden_states):
        """Apply multi-head attention.

        Args:
            hidden_states: Input tensor, shape [batch, seq_length, n_embd]

        Returns:
            Attention output, shape [batch, seq_length, n_embd]
        """
        # TODO: Project to Q, K, V
        # Hint: qkv = self.c_attn(hidden_states)
        # Hint: query, key, value = F.split(qkv, [self.split_size, self.split_size, self.split_size], axis=-1)
        pass

        # TODO: Split into multiple heads
        # Hint: query = self._split_heads(query, self.num_heads, self.head_dim)
        # Hint: key = self._split_heads(key, self.num_heads, self.head_dim)
        # Hint: value = self._split_heads(value, self.num_heads, self.head_dim)
        pass

        # TODO: Apply attention
        # Hint: attn_output = self._attn(query, key, value)
        pass

        # TODO: Merge heads back
        # Hint: attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        pass

        # TODO: Output projection
        # Hint: attn_output = self.c_proj(attn_output)
        # Hint: return attn_output
        return None

Validation

Run pixi run s07 to verify your implementation.

Show solution
"""
Solution for Step 07: Multi-head Attention

This module implements multi-head attention, which allows the model to jointly
attend to information from different representation subspaces at different positions.
"""

import math

from max.driver import Device
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import Dim, DimLike
from max.nn.module_v3 import Linear, Module

from solutions.solution_01 import GPT2Config


@F.functional
def causal_mask(
    sequence_length: DimLike,
    num_tokens: DimLike,
    *,
    dtype: DType,
    device: Device,
):
    """Create a causal attention mask."""
    n = Dim(sequence_length) + num_tokens
    mask = Tensor.constant(float("-inf"), dtype=dtype, device=device)
    mask = F.broadcast_to(mask, shape=(sequence_length, n))
    return F.band_part(mask, num_lower=None, num_upper=0, exclude=True)


class GPT2MultiHeadAttention(Module):
    """Multi-head attention for GPT-2, matching HuggingFace structure."""

    def __init__(self, config: GPT2Config):
        """Initialize multi-head attention.

        Args:
            config: GPT2Config containing n_embd and n_head
        """
        super().__init__()

        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # Combined Q/K/V projection
        self.c_attn = Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
        # Output projection
        self.c_proj = Linear(self.embed_dim, self.embed_dim, bias=True)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """Split the last dimension into (num_heads, head_size).

        Transforms shape from [batch, seq_length, n_embd]
        to [batch, num_heads, seq_length, head_size]

        Args:
            tensor: Input tensor, shape [batch, seq_length, n_embd]
            num_heads: Number of attention heads
            attn_head_size: Dimension of each head

        Returns:
            Tensor with shape [batch, num_heads, seq_length, head_size]
        """
        # Add head dimension: [batch, seq_length, n_embd] -> [batch, seq_length, num_heads, head_size]
        new_shape = tensor.shape[:-1] + [num_heads, attn_head_size]
        tensor = tensor.reshape(new_shape)
        # Move heads dimension: [batch, seq_length, num_heads, head_size] -> [batch, num_heads, seq_length, head_size]
        return tensor.transpose(-3, -2)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """Merge attention heads back to original shape.

        Transforms shape from [batch, num_heads, seq_length, head_size]
        to [batch, seq_length, n_embd]

        Args:
            tensor: Input tensor, shape [batch, num_heads, seq_length, head_size]
            num_heads: Number of attention heads
            attn_head_size: Dimension of each head

        Returns:
            Tensor with shape [batch, seq_length, n_embd]
        """
        # Move heads dimension back: [batch, num_heads, seq_length, head_size] -> [batch, seq_length, num_heads, head_size]
        tensor = tensor.transpose(-3, -2)
        # Flatten head dimensions: [batch, seq_length, num_heads, head_size] -> [batch, seq_length, n_embd]
        new_shape = tensor.shape[:-2] + [num_heads * attn_head_size]
        return tensor.reshape(new_shape)

    def _attn(self, query, key, value):
        """Compute attention for all heads in parallel.

        Args:
            query: Query tensor, shape [batch, num_heads, seq_length, head_size]
            key: Key tensor, shape [batch, num_heads, seq_length, head_size]
            value: Value tensor, shape [batch, num_heads, seq_length, head_size]

        Returns:
            Attention output, shape [batch, num_heads, seq_length, head_size]
        """
        # Compute attention scores
        attn_weights = query @ key.transpose(-1, -2)

        # Scale attention weights
        attn_weights = attn_weights / math.sqrt(int(value.shape[-1]))

        # Apply causal mask
        seq_len = query.shape[-2]
        mask = causal_mask(seq_len, 0, dtype=query.dtype, device=query.device)
        attn_weights = attn_weights + mask

        # Softmax and weighted sum
        attn_weights = F.softmax(attn_weights)
        attn_output = attn_weights @ value

        return attn_output

    def __call__(self, hidden_states):
        """Apply multi-head attention.

        Args:
            hidden_states: Input tensor, shape [batch, seq_length, n_embd]

        Returns:
            Attention output, shape [batch, seq_length, n_embd]
        """
        # Project to Q, K, V
        qkv = self.c_attn(hidden_states)
        query, key, value = F.split(
            qkv, [self.split_size, self.split_size, self.split_size], axis=-1
        )

        # Split into multiple heads
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Apply attention
        attn_output = self._attn(query, key, value)

        # Merge heads back
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)

        # Output projection
        attn_output = self.c_proj(attn_output)

        return attn_output
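
To see what the causal mask buys you, here is a small sketch in plain NumPy (not the MAX API, with made-up sizes) of single-head scaled dot-product attention. Adding -inf above the diagonal drives those softmax weights to zero, so each position attends only to itself and earlier positions.

import numpy as np

seq_len, head_dim = 4, 8
rng = np.random.default_rng(0)
queries = rng.normal(size=(seq_len, head_dim))
keys = rng.normal(size=(seq_len, head_dim))
values = rng.normal(size=(seq_len, head_dim))

scores = queries @ keys.T / np.sqrt(head_dim)              # [seq_len, seq_len]
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)  # -inf strictly above the diagonal
masked = scores + mask
masked -= masked.max(axis=-1, keepdims=True)               # numerical stability
weights = np.exp(masked)
weights /= weights.sum(axis=-1, keepdims=True)             # softmax over keys

print(np.round(weights, 3))   # upper triangle is exactly 0: no attention to future tokens
output = weights @ values     # same weighted sum the solution computes per head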

Next: In Step 08, you’ll implement residual connections and layer normalization to enable training deep transformer networks.

Step 08: Residual connections and layer normalization

Learn to implement residual connections and layer normalization to enable training deep transformer networks.

Building the residual pattern

In this step, you’ll combine residual connections and layer normalization into a reusable pattern for transformer blocks. Residual connections add the input directly to the output using output = input + layer(input), creating shortcuts that let gradients flow through deep networks. You’ll implement this alongside the layer normalization from Step 03.

GPT-2 uses pre-norm architecture where layer norm is applied before each sublayer (attention or MLP). The pattern is x = x + sublayer(layer_norm(x)): normalize first, process, then add the original input back. This is more stable than post-norm alternatives for deep networks.

Residual connections solve the vanishing gradient problem. During backpropagation, gradients flow through the identity path (x = x + ...) without being multiplied by layer weights. This allows training networks with 12+ layers. Without residuals, gradients would diminish exponentially as they propagate through many layers.

Layer normalization works identically during training and inference because it normalizes each example independently. No batch statistics, no running averages, just consistent normalization that keeps activation distributions stable throughout training.

Understanding the pattern

The pre-norm residual pattern combines three operations in sequence:

Layer normalization: Normalize the input with F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). This uses learnable weight (gamma) and bias (beta) parameters to scale and shift the normalized values. You already implemented this in Step 03.

Sublayer processing: Pass the normalized input through a sublayer (attention or MLP). The sublayer transforms the data while the layer norm keeps its input well-conditioned.

Residual addition: Add the original input back to the sublayer output using simple element-wise addition: x + sublayer_output. Both tensors must have identical shapes [batch, seq_length, embed_dim].

The complete pattern is x = x + sublayer(layer_norm(x)). This differs from post-norm x = layer_norm(x + sublayer(x)), as pre-norm is more stable because normalization happens before potentially unstable sublayer operations.
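
If it helps to see the numbers, here is a small sketch in plain NumPy (not the MAX API; the names are illustrative only) of the three operations above: layer norm with learnable gamma and beta, a stand-in sublayer, and the residual addition.

import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

dim = 8
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 4, dim))   # [batch, seq_length, dim]
gamma, beta = np.ones(dim), np.zeros(dim)

normalized = layer_norm(x, gamma, beta)
print(normalized.mean(-1).round(3))    # ~0 for every position
print(normalized.std(-1).round(3))     # ~1 for every position

sublayer = lambda h: 0.1 * h           # stand-in for attention or the MLP
out = x + sublayer(normalized)         # x = x + sublayer(layer_norm(x))
print(out.shape)                       # (2, 4, 8): shapes match, so the addition is valid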

MAX operations

You’ll use the following MAX operations to complete this task:

Layer normalization:

  • F.layer_norm(x, gamma=..., beta=..., epsilon=...): Normalizes across the last dimension with a learnable scale and shift

Tensor initialization:

  • Tensor.ones([dim]): Creates the learnable weight (gamma) parameter
  • Tensor.zeros([dim]): Creates the learnable bias (beta) parameter

Implementing the pattern

You’ll implement three pieces that demonstrate the residual pattern: a LayerNorm class for normalization, a ResidualBlock class that combines normalization with residual addition, and a standalone apply_residual_connection function.

First, import the required modules. You’ll need functional as F for layer norm, Tensor for parameters, DimLike for type hints, and Module as the base class.

LayerNorm implementation:

In __init__, create the learnable parameters:

  • Weight: Tensor.ones([dim]) stored as self.weight
  • Bias: Tensor.zeros([dim]) stored as self.bias
  • Store eps for numerical stability

In forward, apply normalization with F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). This returns a normalized tensor with the same shape as the input.

ResidualBlock implementation:

In __init__, create a LayerNorm instance: self.ln = LayerNorm(dim, eps=eps). This will normalize inputs before sublayers.

In forward, implement the residual addition. To keep the exercise focused, the skeleton passes the sublayer's output in as an argument, so your method performs only the final step of the pre-norm pattern:

  1. Normalize: normalized = self.ln(x)
  2. Process: sublayer_output = sublayer(normalized) (computed by the caller in this exercise)
  3. Add residual: return x + sublayer_output

Standalone function:

Implement apply_residual_connection(input_tensor, sublayer_output) that returns input_tensor + sublayer_output. This demonstrates the residual pattern as a simple function.

Implementation (step_08.py):

"""
Step 08: Residual Connections and Layer Normalization

Implement layer normalization and residual connections, which enable
training deep transformer networks by stabilizing gradients.

Tasks:
1. Import F (functional), Tensor, DimLike, and Module
2. Create LayerNorm class with learnable weight and bias parameters
3. Implement layer norm using F.layer_norm
4. Implement residual connection (simple addition)

Run: pixi run s08
"""

# TODO: Import required modules
# Hint: You'll need F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor
# Hint: You'll need DimLike from max.graph
# Hint: You'll need Module from max.nn.module_v3


class LayerNorm(Module):
    """Layer normalization module matching HuggingFace GPT-2."""

    def __init__(self, dim: DimLike, *, eps: float = 1e-5):
        """Initialize layer normalization.

        Args:
            dim: Dimension to normalize (embedding dimension)
            eps: Small epsilon for numerical stability
        """
        super().__init__()
        self.eps = eps

        # TODO: Create learnable scale parameter (weight)
        # Hint: Use Tensor.ones([dim])
        self.weight = None

        # TODO: Create learnable shift parameter (bias)
        # Hint: Use Tensor.zeros([dim])
        self.bias = None

    def __call__(self, x: Tensor) -> Tensor:
        """Apply layer normalization.

        Args:
            x: Input tensor, shape [..., dim]

        Returns:
            Normalized tensor, same shape as input
        """
        # TODO: Apply layer normalization
        # Hint: Use F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)
        return None


class ResidualBlock(Module):
    """Demonstrates residual connections with layer normalization."""

    def __init__(self, dim: int, eps: float = 1e-5):
        """Initialize residual block.

        Args:
            dim: Dimension of the input/output
            eps: Epsilon for layer normalization
        """
        super().__init__()

        # TODO: Create layer normalization
        # Hint: Use LayerNorm(dim, eps=eps)
        self.ln = None

    def __call__(self, x: Tensor, sublayer_output: Tensor) -> Tensor:
        """Apply residual connection.

        Args:
            x: Input tensor (the residual)
            sublayer_output: Output from sublayer applied to ln(x)

        Returns:
            x + sublayer_output
        """
        # TODO: Add input and sublayer output (residual connection)
        # Hint: return x + sublayer_output
        return None


def apply_residual_connection(input_tensor: Tensor, sublayer_output: Tensor) -> Tensor:
    """Apply a residual connection by adding input to sublayer output.

    Args:
        input_tensor: Original input (the residual)
        sublayer_output: Output from a sublayer (attention, MLP, etc.)

    Returns:
        input_tensor + sublayer_output
    """
    # TODO: Add the two tensors
    # Hint: return input_tensor + sublayer_output
    return None

Validation

Run pixi run s08 to verify your implementation.

Show solution
"""
Solution for Step 08: Residual Connections and Layer Normalization

This module implements layer normalization and demonstrates residual connections,
which are essential for training deep transformer networks.
"""

from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DimLike
from max.nn.module_v3 import Module


class LayerNorm(Module):
    """Layer normalization module matching HuggingFace GPT-2.

    Layer norm normalizes activations across the feature dimension,
    stabilizing training and allowing deeper networks.
    """

    def __init__(self, dim: DimLike, *, eps: float = 1e-5):
        """Initialize layer normalization.

        Args:
            dim: Dimension to normalize (embedding dimension)
            eps: Small epsilon for numerical stability
        """
        super().__init__()
        self.eps = eps
        # Learnable scale parameter (gamma)
        self.weight = Tensor.ones([dim])
        # Learnable shift parameter (beta)
        self.bias = Tensor.zeros([dim])

    def __call__(self, x: Tensor) -> Tensor:
        """Apply layer normalization.

        Args:
            x: Input tensor, shape [..., dim]

        Returns:
            Normalized tensor, same shape as input
        """
        return F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)


class ResidualBlock(Module):
    """Demonstrates residual connections with layer normalization.

    This shows the pre-norm architecture used in GPT-2:
    output = input + sublayer(layer_norm(input))
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Initialize residual block.

        Args:
            dim: Dimension of the input/output
            eps: Epsilon for layer normalization
        """
        super().__init__()
        self.ln = LayerNorm(dim, eps=eps)

    def __call__(self, x: Tensor, sublayer_output: Tensor) -> Tensor:
        """Apply residual connection.

        This demonstrates the pattern:
        1. Normalize input: ln(x)
        2. Apply sublayer (passed as argument for simplicity)
        3. Add residual: x + sublayer_output

        In practice, the sublayer (attention or MLP) is applied to ln(x),
        but we receive the result as a parameter for clarity.

        Args:
            x: Input tensor (the residual)
            sublayer_output: Output from sublayer applied to ln(x)

        Returns:
            x + sublayer_output
        """
        # In a real transformer block, you would do:
        # residual = x
        # x = self.ln(x)
        # x = sublayer(x)  # e.g., attention or MLP
        # x = x + residual

        # For this demonstration, we just add
        return x + sublayer_output


def apply_residual_connection(input_tensor: Tensor, sublayer_output: Tensor) -> Tensor:
    """Apply a residual connection by adding input to sublayer output.

    Residual connections allow gradients to flow directly through the network,
    enabling training of very deep models.

    Args:
        input_tensor: Original input (the residual)
        sublayer_output: Output from a sublayer (attention, MLP, etc.)

    Returns:
        input_tensor + sublayer_output
    """
    return input_tensor + sublayer_output

Next: In Step 09, you’ll combine multi-head attention, MLP, layer norm, and residual connections into a complete transformer block.

Step 09: Transformer block

Learn to combine attention, MLP, layer normalization, and residual connections into a complete transformer block.

Building the transformer block

In this step, you’ll build the GPT2Block class. This is a fundamental repeating unit of GPT-2. Each block combines multi-head attention and a feed-forward network, with layer normalization and residual connections around each.

The block processes input through two sequential operations. First, it applies layer norm, runs multi-head attention, then adds the result back to the input (residual connection). Second, it applies another layer norm, runs the MLP, and adds that result back. This pattern is x = x + sublayer(layer_norm(x)), called pre-normalization.

GPT-2 uses pre-norm because it stabilizes training in deep networks. By normalizing before each sublayer instead of after, gradients flow more smoothly through the network’s 12 stacked blocks.

Understanding the components

The transformer block consists of four components, applied in this order:

First layer norm (ln_1): Normalizes the input before attention. Uses epsilon=1e-5 for numerical stability.

Multi-head attention (attn): The self-attention mechanism from Step 07. Lets each position attend to all previous positions.

Second layer norm (ln_2): Normalizes before the MLP. Same configuration as the first.

Feed-forward network (mlp): The position-wise MLP from Step 04. Expands to 3,072 dimensions internally (4× the embedding size), then projects back to 768.

The block maintains a constant 768-dimensional representation throughout. Input shape [batch, seq_length, 768] stays the same after each sublayer, which is essential for stacking 12 blocks together.

Understanding the flow

Each sublayer follows the pre-norm pattern:

  1. Save the input as residual
  2. Apply layer normalization to the input
  3. Process through the sublayer (attention or MLP)
  4. Add the original residual back to the output

This happens twice per block, once for attention and once for the MLP. The residual connections let gradients flow directly through the network, preventing vanishing gradients in deep models.

Component names (ln_1, attn, ln_2, mlp) match Hugging Face’s GPT-2 implementation. This matters for loading pretrained weights in later steps.
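
The same four-step pattern runs twice per block. A short Python sketch with stand-in sublayers (illustrative names only, not the real modules) makes the structure explicit; the actual block wires in the attention and MLP modules from Steps 07 and 04.

import numpy as np

def pre_norm_step(x, norm, sublayer):
    residual = x              # 1. save the input
    h = norm(x)               # 2. layer normalization
    h = sublayer(h)           # 3. attention or MLP
    return residual + h       # 4. add the residual back

norm = lambda x: (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-5)
fake_attn = lambda x: 0.1 * x   # stand-ins that preserve shape
fake_mlp = lambda x: 0.1 * x

x = np.random.default_rng(0).normal(size=(1, 4, 8))
x = pre_norm_step(x, norm, fake_attn)   # attention half of the block
x = pre_norm_step(x, norm, fake_mlp)    # MLP half of the block
print(x.shape)                          # (1, 4, 8): ready to feed the next block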

Implementing the block

You’ll create the GPT2Block class by composing the components from earlier steps. The block takes GPT2Config and creates four sublayers, then applies them in sequence with residual connections.

First, import the required modules. You’ll need Module from MAX, plus the previously implemented components: GPT2Config, GPT2MLP, GPT2MultiHeadAttention, and LayerNorm.

In the __init__ method, create the four sublayers:

  • ln_1: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  • attn: GPT2MultiHeadAttention(config)
  • ln_2: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  • mlp: GPT2MLP(4 * config.n_embd, config)

The MLP uses 4 * config.n_embd (3,072 dimensions) as its inner dimension, following the standard transformer ratio.

In the forward method, implement the two sublayer blocks:

Attention block:

  1. Save residual = hidden_states
  2. Normalize: hidden_states = self.ln_1(hidden_states)
  3. Apply attention: attn_output = self.attn(hidden_states)
  4. Add back: hidden_states = attn_output + residual

MLP block:

  1. Save residual = hidden_states
  2. Normalize: hidden_states = self.ln_2(hidden_states)
  3. Apply MLP: feed_forward_hidden_states = self.mlp(hidden_states)
  4. Add back: hidden_states = residual + feed_forward_hidden_states

Finally, return hidden_states.

Implementation (step_09.py):

"""
Step 09: Transformer Block

Combine multi-head attention, MLP, layer normalization, and residual
connections into a complete transformer block.

Tasks:
1. Import Module and all previous solution components
2. Create ln_1, attn, ln_2, and mlp layers
3. Implement forward pass with pre-norm residual pattern

Run: pixi run s09
"""

# TODO: Import required modules
# Hint: You'll need Module from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import GPT2MLP from solutions.solution_04
# Hint: Import GPT2MultiHeadAttention from solutions.solution_07
# Hint: Import LayerNorm from solutions.solution_08


class GPT2Block(Module):
    """Complete GPT-2 transformer block."""

    def __init__(self, config: GPT2Config):
        """Initialize transformer block.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        hidden_size = config.n_embd
        inner_dim = (
            config.n_inner
            if hasattr(config, "n_inner") and config.n_inner is not None
            else 4 * hidden_size
        )

        # TODO: Create first layer norm (before attention)
        # Hint: Use LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.ln_1 = None

        # TODO: Create multi-head attention
        # Hint: Use GPT2MultiHeadAttention(config)
        self.attn = None

        # TODO: Create second layer norm (before MLP)
        # Hint: Use LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.ln_2 = None

        # TODO: Create MLP
        # Hint: Use GPT2MLP(inner_dim, config)
        self.mlp = None

    def __call__(self, hidden_states):
        """Apply transformer block.

        Args:
            hidden_states: Input tensor, shape [batch, seq_length, n_embd]

        Returns:
            Output tensor, shape [batch, seq_length, n_embd]
        """
        # TODO: Attention block with residual connection
        # Hint: residual = hidden_states
        # Hint: hidden_states = self.ln_1(hidden_states)
        # Hint: attn_output = self.attn(hidden_states)
        # Hint: hidden_states = attn_output + residual
        pass

        # TODO: MLP block with residual connection
        # Hint: residual = hidden_states
        # Hint: hidden_states = self.ln_2(hidden_states)
        # Hint: feed_forward_hidden_states = self.mlp(hidden_states)
        # Hint: hidden_states = residual + feed_forward_hidden_states
        pass

        # TODO: Return the output
        return None

Validation

Run pixi run s09 to verify your implementation.

Show solution
"""
Solution for Step 09: Transformer Block

This module implements a complete GPT-2 transformer block, combining
multi-head attention, MLP, layer normalization, and residual connections.
"""

from max.nn.module_v3 import Module

from solutions.solution_01 import GPT2Config
from solutions.solution_04 import GPT2MLP
from solutions.solution_07 import GPT2MultiHeadAttention
from solutions.solution_08 import LayerNorm


class GPT2Block(Module):
    """Complete GPT-2 transformer block matching HuggingFace structure.

    Architecture (pre-norm):
    1. x = x + attention(layer_norm(x))
    2. x = x + mlp(layer_norm(x))
    """

    def __init__(self, config: GPT2Config):
        """Initialize transformer block.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        hidden_size = config.n_embd
        # Inner dimension for MLP (4x hidden size by default)
        inner_dim = (
            config.n_inner
            if hasattr(config, "n_inner") and config.n_inner is not None
            else 4 * hidden_size
        )

        # First layer norm (before attention)
        self.ln_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        # Multi-head attention
        self.attn = GPT2MultiHeadAttention(config)
        # Second layer norm (before MLP)
        self.ln_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        # Feed-forward MLP
        self.mlp = GPT2MLP(inner_dim, config)

    def __call__(self, hidden_states):
        """Apply transformer block.

        Args:
            hidden_states: Input tensor, shape [batch, seq_length, n_embd]

        Returns:
            Output tensor, shape [batch, seq_length, n_embd]
        """
        # Attention block with residual connection
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(hidden_states)
        hidden_states = attn_output + residual

        # MLP block with residual connection
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        return hidden_states

Next: In Step 10, you’ll stack 12 transformer blocks together to create the complete GPT-2 model architecture.

Step 10: Stacking transformer blocks

Learn to stack 12 transformer blocks with embeddings and final normalization to create the complete GPT-2 model.

Building the complete model

In this step, you’ll create the GPT2Model class: the complete transformer that takes token IDs as input and outputs contextualized representations. This class combines embeddings, 12 stacked transformer blocks, and final layer normalization.

The model processes input in four stages: convert token IDs to embeddings, add position information, pass through 12 transformer blocks sequentially, and normalize the final output. Each transformer block refines the representation, building up from surface patterns in early layers to semantic understanding in later layers.

GPT-2 small uses 12 layers as a balance between capacity and cost: fewer layers would limit what the model can learn, while more layers would add parameters and training time without proportional quality gains at this scale.

Understanding the components

The complete model has four main components:

Token embeddings (wte): Maps each token ID to a 768-dimensional vector using a lookup table with 50,257 entries (one per vocabulary token).

Position embeddings (wpe): Maps each position (0 to 1,023) to a 768-dimensional vector. These are added to token embeddings so the model knows token order.

Transformer blocks (h): 12 identical blocks stacked using MAX’s Sequential module. Sequential applies blocks in order, passing each block’s output to the next.

Final layer norm (ln_f): Normalizes the output after all blocks. This stabilizes the representation before the language model head (added in Step 11) projects to vocabulary logits.

Understanding the forward pass

The forward method processes token IDs through the model:

First, create position indices using Tensor.arange. Generate positions [0, 1, 2, …, seq_length-1] matching the input’s dtype and device. This ensures compatibility when adding to embeddings.

Next, look up embeddings. Get token embeddings with self.wte(input_ids) and position embeddings with self.wpe(position_indices). Add them together element-wise, as both are shape [batch, seq_length, 768].

Then, pass through the transformer blocks with self.h(x). Sequential applies all 12 blocks in order, each refining the representation.

Finally, normalize the output with self.ln_f(x) and return the result. The output shape matches the input: [batch, seq_length, 768].
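
Here is a small sketch in plain NumPy (not the MAX API, with tiny made-up sizes rather than GPT-2's 50,257 × 768) of the embedding stage just described: rows are looked up from two tables and added element-wise.

import numpy as np

vocab_size, n_positions, n_embd = 100, 16, 8
rng = np.random.default_rng(0)
wte = rng.normal(size=(vocab_size, n_embd))      # token embedding table
wpe = rng.normal(size=(n_positions, n_embd))     # position embedding table

input_ids = np.array([[42, 7, 99]])              # [batch=1, seq_length=3]
positions = np.arange(input_ids.shape[1])        # [0, 1, 2]

x = wte[input_ids] + wpe[positions]              # broadcasts over the batch dimension
print(x.shape)                                   # (1, 3, 8)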

MAX operations

You’ll use the following MAX operations to complete this task:

Module composition:

  • Sequential(*modules): Applies the given modules in order, passing each output to the next

Embeddings:

  • Embedding(config.vocab_size, dim=config.n_embd): Lookup table that maps token or position indices to embedding vectors

Position generation:

  • Tensor.arange(seq_length, dtype=..., device=...): Creates position indices [0, 1, ..., seq_length - 1]

Implementing the model

You’ll create the GPT2Model class by composing embedding layers, transformer blocks, and layer normalization. The class builds on all the components from previous steps.

First, import the required modules. You’ll need Tensor for position indices, Embedding, Module, and Sequential from MAX’s neural network module, plus the previously implemented GPT2Config, LayerNorm, and GPT2Block.

In the __init__ method, create the four components:

  • Token embeddings: Embedding(config.vocab_size, dim=config.n_embd) stored as self.wte
  • Position embeddings: Embedding(config.n_positions, dim=config.n_embd) stored as self.wpe
  • Transformer blocks: Sequential(*(GPT2Block(config) for _ in range(config.n_layer))) stored as self.h
  • Final layer norm: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) stored as self.ln_f

The Sequential module takes a generator expression that creates 12 identically configured GPT2Block instances, each with its own weights. The * unpacks them as arguments to Sequential.
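
Conceptually, Sequential just applies its modules in order, feeding each output into the next. A rough plain-Python equivalent, assuming blocks is a list of callables such as the GPT2Block instances above:

def apply_in_order(blocks, x):
    for block in blocks:
        x = block(x)        # the output of one block becomes the input to the next
    return x

# apply_in_order(blocks, x) behaves like Sequential(*blocks) applied to x.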

In the forward method, implement the four-stage processing:

  1. Get the sequence length from input_ids.shape
  2. Create position indices: Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
  3. Look up embeddings and add them: x = self.wte(input_ids) + self.wpe(position_indices)
  4. Apply transformer blocks: x = self.h(x)
  5. Apply final normalization: x = self.ln_f(x)
  6. Return x

The position indices must match the input’s dtype and device to ensure the tensors are compatible for addition.

Implementation (step_10.py):

"""
Step 10: Stacking Transformer Blocks

Stack multiple transformer blocks with embeddings to create
the complete GPT-2 model architecture.

Tasks:
1. Import Tensor, Embedding, Module, Sequential, and previous components
2. Create token and position embeddings
3. Stack n_layer transformer blocks using Sequential
4. Create final layer normalization
5. Implement forward pass: embeddings -> blocks -> layer norm

Run: pixi run s10
"""

# TODO: Import required modules
# Hint: You'll need Tensor from max.experimental.tensor
# Hint: You'll need Embedding, Module, Sequential from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import LayerNorm from solutions.solution_08
# Hint: Import GPT2Block from solutions.solution_09


class GPT2Model(Module):
    """Complete GPT-2 transformer model."""

    def __init__(self, config: GPT2Config):
        """Initialize GPT-2 model.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        # TODO: Create token embeddings
        # Hint: Use Embedding(config.vocab_size, dim=config.n_embd)
        self.wte = None

        # TODO: Create position embeddings
        # Hint: Use Embedding(config.n_positions, dim=config.n_embd)
        self.wpe = None

        # TODO: Stack transformer blocks
        # Hint: Use Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))
        # This creates config.n_layer blocks (12 for GPT-2 base)
        self.h = None

        # TODO: Create final layer normalization
        # Hint: Use LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_f = None

    def __call__(self, input_ids):
        """Forward pass through the transformer.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Hidden states, shape [batch, seq_length, n_embd]
        """
        # TODO: Get batch size and sequence length
        # Hint: batch_size, seq_length = input_ids.shape
        pass

        # TODO: Get token embeddings
        # Hint: tok_embeds = self.wte(input_ids)
        pass

        # TODO: Get position embeddings
        # Hint: Create position indices with Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        # Hint: pos_embeds = self.wpe(position_indices)
        pass

        # TODO: Combine embeddings
        # Hint: x = tok_embeds + pos_embeds
        pass

        # TODO: Apply transformer blocks
        # Hint: x = self.h(x)
        pass

        # TODO: Apply final layer norm
        # Hint: x = self.ln_f(x)
        pass

        # TODO: Return the output
        return None

Validation

Run pixi run s10 to verify your implementation.

Show solution
"""
Solution for Step 10: Stacking Transformer Blocks

This module stacks multiple transformer blocks and adds embeddings
to create the complete GPT-2 transformer architecture.
"""

from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding, Module, Sequential

from solutions.solution_01 import GPT2Config
from solutions.solution_08 import LayerNorm
from solutions.solution_09 import GPT2Block


class GPT2Model(Module):
    """Complete GPT-2 transformer model matching HuggingFace structure.

    Architecture:
    1. Token embeddings + position embeddings
    2. Stack of n_layer transformer blocks
    3. Final layer normalization
    """

    def __init__(self, config: GPT2Config):
        """Initialize GPT-2 model.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        # Token embeddings (vocabulary -> embeddings)
        self.wte = Embedding(config.vocab_size, dim=config.n_embd)
        # Position embeddings (positions -> embeddings)
        self.wpe = Embedding(config.n_positions, dim=config.n_embd)

        # Stack of transformer blocks
        self.h = Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))

        # Final layer normalization
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def __call__(self, input_ids):
        """Forward pass through the transformer.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Hidden states, shape [batch, seq_length, n_embd]
        """
        batch_size, seq_length = input_ids.shape

        # Get token embeddings
        tok_embeds = self.wte(input_ids)

        # Get position embeddings
        pos_embeds = self.wpe(
            Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        )

        # Combine embeddings
        x = tok_embeds + pos_embeds

        # Apply transformer blocks
        x = self.h(x)

        # Final layer norm
        x = self.ln_f(x)

        return x

Next: In Step 11, you’ll add the language modeling head that projects hidden states to vocabulary logits for text generation.

Step 11: Language model head

Learn to add the final linear projection layer that converts hidden states to vocabulary logits for next-token prediction.

Adding the language model head

In this step, you’ll create the MaxGPT2LMHeadModel: the complete language model that can predict next tokens. This class wraps the transformer from Step 10 and adds a final linear layer that projects 768-dimensional hidden states to 50,257-dimensional vocabulary logits.

The language model head is a single linear layer without bias. For each position in the sequence, it outputs a score for every possible next token. Higher scores indicate the model thinks that token is more likely to come next.

At 768 × 50,257 ≈ 38.6M parameters, the LM head is the largest single weight matrix in GPT-2, matching the size of the token embedding table and accounting for roughly a third of the model’s 117M parameters. For comparison, each transformer block holds only about 7M parameters.
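
You can check these figures with a few lines of Python (used here only as a calculator, ignoring biases and layer norm parameters):

n_embd, vocab_size = 768, 50_257

lm_head = n_embd * vocab_size                          # 38,597,376 -> ~38.6M, same shape as wte
attn_weights = 3 * n_embd * n_embd + n_embd * n_embd   # c_attn + c_proj weight matrices
mlp_weights = 2 * (n_embd * 4 * n_embd)                # c_fc + c_proj weight matrices
per_block = attn_weights + mlp_weights                 # ~7.1M per transformer block
print(f"{lm_head:,}  {per_block:,}")                   # 38,597,376  7,077,888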

Understanding the projection

The language model head performs a simple linear projection using MAX’s Linear layer. It maps each 768-dimensional hidden state to 50,257 scores, one per vocabulary token.

The layer uses bias=False, meaning it has only weights and no bias vector. This saves 50,257 parameters, a negligible fraction of the model (roughly 0.04%). A bias here would add one extra value per vocabulary token, but in practice it provides little benefit, and omitting it matches the original GPT-2 design.

The output is called “logits,” which are raw scores before applying softmax. Logits can be any real number. During text generation (Step 12), you’ll convert logits to probabilities with softmax. Working with logits directly enables techniques like temperature scaling and top-k sampling.

Understanding the complete model

With the LM head added, you now have the complete GPT-2 architecture:

  1. Input: Token IDs [batch, seq_length]
  2. Embeddings: Token + position [batch, seq_length, 768]
  3. Transformer blocks: 12 blocks process the embeddings [batch, seq_length, 768]
  4. Final layer norm: Normalizes the output [batch, seq_length, 768]
  5. LM head: Projects to vocabulary [batch, seq_length, 50257]
  6. Output: Logits [batch, seq_length, 50257]

Each position gets independent logits over the vocabulary. To predict the next token after position i, you look at the logits at position i. The highest scoring token is the model’s top prediction.

MAX operations

You’ll use the following MAX operations to complete this task:

Linear layer:

  • Linear(config.n_embd, config.vocab_size, bias=False): Projects each 768-dimensional hidden state to 50,257 vocabulary logits without a bias term

Implementing the language model

You’ll create the MaxGPT2LMHeadModel class that wraps the transformer with a language modeling head. The implementation is straightforward, with just two components and a simple forward pass.

First, import the required modules. You’ll need Linear and Module from MAX, plus the previously implemented GPT2Config and GPT2Model.

In the __init__ method, create two components:

  • Transformer: GPT2Model(config) stored as self.transformer
  • LM head: Linear(config.n_embd, config.vocab_size, bias=False) stored as self.lm_head

Note the bias=False parameter, which creates a linear layer without bias terms.

In the forward method, implement a simple two-step process:

  1. Get hidden states from the transformer: hidden_states = self.transformer(input_ids)
  2. Project to vocabulary logits: logits = self.lm_head(hidden_states)
  3. Return logits

That’s it. The model takes token IDs and returns logits. In the next step, you’ll use these logits to generate text.

Implementation (step_11.py):

"""
Step 11: Language Model Head

Add the final projection layer that converts hidden states to vocabulary logits.

Tasks:
1. Import Linear, Module, and previous components
2. Create transformer and lm_head layers
3. Implement forward pass: transformer -> lm_head

Run: pixi run s11
"""

# TODO: Import required modules
# Hint: You'll need Linear and Module from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import GPT2Model from solutions.solution_10


class MaxGPT2LMHeadModel(Module):
    """Complete GPT-2 model with language modeling head."""

    def __init__(self, config: GPT2Config):
        """Initialize GPT-2 with LM head.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        self.config = config

        # TODO: Create the transformer
        # Hint: Use GPT2Model(config)
        self.transformer = None

        # TODO: Create language modeling head
        # Hint: Use Linear(config.n_embd, config.vocab_size, bias=False)
        # Projects from hidden dimension to vocabulary size
        self.lm_head = None

    def __call__(self, input_ids):
        """Forward pass through transformer and LM head.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Logits over vocabulary, shape [batch, seq_length, vocab_size]
        """
        # TODO: Get hidden states from transformer
        # Hint: hidden_states = self.transformer(input_ids)
        pass

        # TODO: Project to vocabulary logits
        # Hint: logits = self.lm_head(hidden_states)
        pass

        # TODO: Return logits
        return None

Validation

Run pixi run s11 to verify your implementation.

Show solution
"""
Solution for Step 11: Language Model Head

This module adds the final projection layer that converts hidden states
to vocabulary logits for predicting the next token.
"""

from max.nn.module_v3 import Linear, Module

from solutions.solution_01 import GPT2Config
from solutions.solution_10 import GPT2Model


class MaxGPT2LMHeadModel(Module):
    """Complete GPT-2 model with language modeling head.

    This is the full model that can be used for text generation.
    """

    def __init__(self, config: GPT2Config):
        """Initialize GPT-2 with LM head.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        self.config = config
        # The transformer (embeddings + blocks + final norm)
        self.transformer = GPT2Model(config)
        # Language modeling head (hidden states -> vocabulary logits)
        self.lm_head = Linear(config.n_embd, config.vocab_size, bias=False)

    def __call__(self, input_ids):
        """Forward pass through transformer and LM head.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Logits over vocabulary, shape [batch, seq_length, vocab_size]
        """
        # Get hidden states from transformer
        hidden_states = self.transformer(input_ids)

        # Project to vocabulary logits
        logits = self.lm_head(hidden_states)

        return logits

Next: In Step 12, you’ll implement text generation using sampling and temperature control to generate coherent text autoregressively.

Step 12: Text generation

Learn to implement autoregressive text generation with sampling and temperature control.

Generating text

In this final step, you’ll implement the generation loop that produces text one token at a time. The model predicts the next token, appends it to the sequence, and repeats until reaching the desired length.

Start with a prompt like “Hello world” (tokens [15496, 995]). The model predicts the next token, giving you [15496, 995, 318] (“Hello world is”). It predicts again, producing [15496, 995, 318, 257] (“Hello world is a”). This process continues, with each prediction feeding back as input for the next.

You’ll implement two generation strategies: greedy decoding (always pick the highest-scoring token) and sampling (randomly choose according to probabilities). You’ll also add temperature control to adjust how random or focused the generation is.

Understanding the generation loop

The generation loop is simple: run the model, extract the next token prediction, append it to the sequence, repeat. Each iteration requires a full forward pass through all 12 transformer blocks.

The model outputs logits with shape [batch, seq_length, vocab_size]. Since you only care about predicting the next token, extract the last position: logits[0, -1, :]. This gives you a vector of 50,257 scores, one per vocabulary token.

These scores are logits (unnormalized), not probabilities. To convert them to probabilities, apply softmax. Then you can either pick the highest-probability token (greedy) or sample from the distribution (random).

Understanding temperature control

Temperature scaling adjusts how random the generation is using the formula scaled_logits = logits / temperature.

With temperature 1.0, you use the original distribution. With temperature 0.7, you sharpen the distribution, and high-probability tokens become even more likely, making generation more focused and deterministic. With temperature 1.2, you flatten the distribution, and lower-probability tokens get more chances, making generation more diverse and creative.

Temperature is applied before softmax. Dividing by a value less than 1 makes large logits even larger (sharpening), while dividing by a value greater than 1 reduces the differences between logits (flattening).
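
A quick sketch in plain NumPy (made-up logits, not the MAX API) shows the effect:

import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, 0.1])        # toy scores for four tokens
for temperature in (0.7, 1.0, 1.2):
    print(temperature, softmax(logits / temperature).round(3))
# 0.7 concentrates probability on the top token; 1.2 spreads it across the others.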

Understanding sampling vs greedy

Greedy decoding always picks the highest-probability token using F.argmax. It’s fast, deterministic, and simple, but often produces repetitive text because the model keeps choosing the safest option.

Sampling randomly selects tokens according to their probabilities. Convert logits to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, then sample with np.random.choice. You use NumPy because MAX doesn’t have built-in sampling yet.

Most practical generation uses sampling with temperature control. This balances creativity with coherence, as the model can explore different possibilities while still favoring high-quality continuations.
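
The difference between the two strategies is easy to see on a toy distribution (plain NumPy, illustrative values only):

import numpy as np

rng = np.random.default_rng(0)
probs = np.array([0.6, 0.25, 0.1, 0.05])             # toy next-token probabilities

greedy = int(np.argmax(probs))                       # always token 0
sampled = rng.choice(len(probs), size=10, p=probs)   # mostly 0, occasionally others
print(greedy, sampled)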

MAX operations

You’ll use the following MAX operations to complete this task:

Probability operations:

  • F.softmax(logits): Converts logits to probabilities
  • F.argmax(logits): Selects the highest-scoring token for greedy decoding

Sequence building:

  • next_token.reshape([1, -1]): Reshapes the generated token so it can be concatenated
  • F.concat([generated_tokens, next_token_2d], axis=1): Appends the new token to the sequence

NumPy interop:

  • probs.to(CPU()): Transfers tensor to CPU
  • np.from_dlpack(probs): Converts MAX tensor to NumPy for sampling

Implementing text generation

You’ll create two functions: generate_next_token, which predicts a single token, and generate_tokens, which loops to produce full sequences.

First, import the required modules. You’ll need numpy for sampling, CPU from MAX’s driver, DType for type constants, functional as F for operations like softmax and argmax, and Tensor for creating tensors.

In generate_next_token, implement the prediction logic:

  1. Run the model to get logits: logits = model(input_ids)
  2. Extract the last position (next token prediction): next_token_logits = logits[0, -1, :]
  3. If using temperature, scale the logits by dividing by the temperature tensor
  4. For sampling: convert to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, sample with np.random.choice, then convert back to a MAX tensor
  5. For greedy: use F.argmax to select the highest-scoring token

The temperature must be a tensor with the same dtype and device as the logits. Create it with Tensor.constant(temperature, dtype=..., device=...).

In generate_tokens, implement the generation loop:

  1. Initialize with the input: generated_tokens = input_ids
  2. Loop max_new_tokens times
  3. Generate the next token: next_token = generate_next_token(model, generated_tokens, ...)
  4. Reshape to 2D: next_token_2d = next_token.reshape([1, -1])
  5. Concatenate to the sequence: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
  6. Return the complete sequence

The reshape is necessary because concat requires both tensors to have the same rank, and the generated token comes back as a 0-D scalar, so it must be reshaped to [1, 1] before being appended.

Implementation (step_12.py):

"""
Step 12: Text Generation

Implement autoregressive text generation with sampling and temperature control.

Tasks:
1. Import required modules (numpy, F, Tensor, etc.)
2. Implement generate_next_token: get logits, apply temperature, sample/argmax
3. Implement generate_tokens: loop to generate multiple tokens

Run: pixi run s12
"""

# TODO: Import required modules
# Hint: You'll need numpy as np
# Hint: You'll need CPU from max.driver
# Hint: You'll need DType from max.dtype
# Hint: You'll need functional as F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor


def generate_next_token(model, input_ids, temperature=1.0, do_sample=True):
    """Generate the next token given input context.

    Args:
        model: GPT-2 model with LM head
        input_ids: Current sequence, shape [batch, seq_length]
        temperature: Sampling temperature (higher = more random)
        do_sample: If True, sample from distribution; if False, use greedy (argmax)

    Returns:
        Next token ID as a Tensor
    """
    # TODO: Get logits from model
    # Hint: logits = model(input_ids)
    pass

    # TODO: Get logits for last position
    # Hint: next_token_logits = logits[0, -1, :]
    pass

    # TODO: If sampling with temperature
    if do_sample and temperature > 0:
        # TODO: Apply temperature scaling
        # Hint: temp_tensor = Tensor.constant(temperature, dtype=next_token_logits.dtype, device=next_token_logits.device)
        # Hint: next_token_logits = next_token_logits / temp_tensor
        pass

        # TODO: Convert to probabilities
        # Hint: probs = F.softmax(next_token_logits)
        pass

        # TODO: Sample from distribution
        # Hint: probs_np = np.from_dlpack(probs.to(CPU()))
        # Hint: next_token_id = np.random.choice(len(probs_np), p=probs_np)
        # Hint: next_token_tensor = Tensor.constant(next_token_id, dtype=DType.int64, device=input_ids.device)
        pass
    else:
        # TODO: Greedy decoding (select most likely token)
        # Hint: next_token_tensor = F.argmax(next_token_logits)
        pass

    # TODO: Return the next token
    return None


def generate_tokens(
    model, input_ids, max_new_tokens=10, temperature=1.0, do_sample=True
):
    """Generate multiple tokens autoregressively.

    Args:
        model: GPT-2 model with LM head
        input_ids: Initial sequence, shape [batch, seq_length]
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature
        do_sample: Whether to sample or use greedy decoding

    Returns:
        Generated sequence including input, shape [batch, seq_length + max_new_tokens]
    """
    # TODO: Initialize generated tokens with input
    # Hint: generated_tokens = input_ids
    pass

    # TODO: Generation loop
    # Hint: for _ in range(max_new_tokens):
    pass

    # TODO: Generate next token
    # Hint: next_token = generate_next_token(model, generated_tokens, temperature=temperature, do_sample=do_sample)
    pass

    # TODO: Reshape to [1, 1] for concatenation
    # Hint: next_token_2d = next_token.reshape([1, -1])
    pass

    # TODO: Append to sequence
    # Hint: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
    pass

    # TODO: Return generated sequence
    return None

Validation

Run pixi run s12 to verify your implementation.

Show solution
"""
Solution for Step 12: Text Generation

This module implements autoregressive text generation using the GPT-2 model.
"""

import numpy as np

from max.driver import CPU
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor


def generate_next_token(model, input_ids, temperature=1.0, do_sample=True):
    """Generate the next token given input context.

    Args:
        model: GPT-2 model with LM head
        input_ids: Current sequence, shape [batch, seq_length]
        temperature: Sampling temperature (higher = more random)
        do_sample: If True, sample from distribution; if False, use greedy (argmax)

    Returns:
        Next token ID as a Tensor
    """
    # Get logits from model
    logits = model(input_ids)

    # Get logits for last position (next token prediction)
    next_token_logits = logits[0, -1, :]  # Shape: [vocab_size]

    if do_sample and temperature > 0:
        # Apply temperature scaling
        temp_tensor = Tensor.constant(
            temperature, dtype=next_token_logits.dtype, device=next_token_logits.device
        )
        next_token_logits = next_token_logits / temp_tensor

        # Convert to probabilities
        probs = F.softmax(next_token_logits)

        # Sample from distribution
        probs_np = np.from_dlpack(probs.to(CPU()))
        next_token_id = np.random.choice(len(probs_np), p=probs_np)
        next_token_tensor = Tensor.constant(
            next_token_id, dtype=DType.int64, device=input_ids.device
        )
    else:
        # Greedy decoding: select most likely token
        next_token_tensor = F.argmax(next_token_logits)

    return next_token_tensor


def generate_tokens(
    model, input_ids, max_new_tokens=10, temperature=1.0, do_sample=True
):
    """Generate multiple tokens autoregressively.

    Args:
        model: GPT-2 model with LM head
        input_ids: Initial sequence, shape [batch, seq_length]
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature
        do_sample: Whether to sample or use greedy decoding

    Returns:
        Generated sequence including input, shape [batch, seq_length + max_new_tokens]
    """
    generated_tokens = input_ids

    for _ in range(max_new_tokens):
        # Generate next token
        next_token = generate_next_token(
            model, generated_tokens, temperature=temperature, do_sample=do_sample
        )

        # Reshape to [1, 1] for concatenation
        next_token_2d = next_token.reshape([1, -1])

        # Append to sequence
        generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)

    return generated_tokens

What you’ve built

You’ve completed all 12 steps and built a complete GPT-2 model from scratch using MAX. You now have a working implementation of:

Core components:

  • Model configuration and architecture definition
  • Causal masking for autoregressive generation
  • Layer normalization for training stability
  • Feed-forward networks with GELU activation
  • Token and position embeddings
  • Multi-head self-attention
  • Residual connections and transformer blocks
  • Language model head for next-token prediction
  • Text generation with temperature and sampling

Your model loads OpenAI’s pretrained GPT-2 weights and generates text. You understand how every component works, from the low-level tensor operations to the high-level architecture decisions.

What’s next?

You now understand the architectural foundation that powers modern language models. LLaMA, Mistral, and more build on these same components with incremental refinements. You have everything you need to implement those refinements yourself.

Consider extending your implementation with:

  • Grouped-query attention (GQA): Reduce memory consumption by sharing key-value pairs across multiple query heads, as used in LLaMA 2.
  • Rotary position embeddings (RoPE): Replace learned position embeddings with rotation-based encoding, improving length extrapolation in models like LLaMA and GPT-NeoX.
  • SwiGLU activation: Swap GELU for the gated linear unit variant used in LLaMA and PaLM.
  • Mixture of experts (MoE): Add sparse expert routing to scale model capacity efficiently, as in Mixtral and GPT-4.

Each refinement builds directly on what you’ve implemented. The attention mechanism you wrote becomes grouped-query attention with a simple modification to how you reshape key-value tensors. Your position embeddings can be replaced with RoPE by changing how you encode positional information. The feed-forward network you built becomes SwiGLU by adding a gating mechanism.
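
As one example of how close these refinements are to what you already wrote, here is a minimal NumPy sketch of a SwiGLU feed-forward layer (an illustration with made-up names, not any library's API): a second gate projection, passed through SiLU, is multiplied element-wise with the up projection before the down projection, replacing the single GELU branch from Step 04.

import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))       # SiLU / Swish activation

def swiglu_ffn(x, w_gate, w_up, w_down):
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

rng = np.random.default_rng(0)
n_embd, inner = 8, 16
x = rng.normal(size=(1, 4, n_embd))
w_gate = 0.1 * rng.normal(size=(n_embd, inner))
w_up = 0.1 * rng.normal(size=(n_embd, inner))
w_down = 0.1 * rng.normal(size=(inner, n_embd))
print(swiglu_ffn(x, w_gate, w_up, w_down).shape)   # (1, 4, 8): shape is preserved, just like GPT2MLP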

Pick an architecture that interests you and start building. You’ll find the patterns are familiar because the fundamentals haven’t changed.