Build an LLM from scratch in MAX
Experimental APIs: The APIs in the experimental package are subject to change. Share feedback on the MAX LLMs forum.
Transformer models power today’s most impactful AI applications, from language models like ChatGPT to code generation tools like GitHub Copilot. Maybe you’ve been asked to adapt one of these models for your team, or you want to understand what’s actually happening when you call an inference API. Either way, building a transformer from scratch is one of the best ways to truly understand how they work.
This guide walks you through implementing GPT-2 using Modular’s MAX framework experimental API. You’ll build each component yourself: embeddings, attention mechanisms, and feed-forward layers. You’ll see how they fit together into a complete language model by completing the sequential coding challenges in the tutorial GitHub repository.
This API is unstable: This tutorial is built on the MAX Experimental API, which we expect to change over time and expand to include new features and functionality. As it evolves, we plan to update the tutorial accordingly. When this API comes out of experimental development, the tutorial content will also enter a more stable state. While in development, this tutorial will be pinned to a major release version.
Why GPT-2?
It’s the architectural foundation for modern language models. LLaMA, Mistral, and GPT-4 are all built on the same core components you’ll implement here:
- multi-head attention
- feed-forward layers
- layer normalization
- residual connections
Modern variants add refinements like grouped-query attention or mixture of experts, but the fundamentals remain the same. GPT-2 is complex enough to teach real transformer architecture but simple enough to implement completely and understand deeply. When you grasp how its pieces fit together, you understand how to build any transformer-based model.
Learning by building: This tutorial follows a format popularized by Andrej Karpathy’s educational work and Sebastian Raschka’s hands-on approach. Rather than abstract theory, you’ll implement each component yourself, building intuition through practice.
Why MAX?
Traditional ML development often feels like stitching together tools that weren’t designed to work together. Maybe you write your model in PyTorch, optimize in CUDA, convert to ONNX for deployment, then use separate serving tools. Each handoff introduces complexity.
MAX Framework takes a different approach: everything happens in one unified system. You write code to define your model, load weights, and run inference, all in MAX’s Python API. The MAX Platform handles optimization automatically, and you can even use MAX Serve to manage your deployment.
When you build GPT-2 in this guide, you’ll load pretrained weights from Hugging Face, implement the architecture, and run text generation, all in the same environment.
Why coding challenges?
This tutorial emphasizes active problem-solving over passive reading. Each step presents a focused implementation task with:
- Clear context: What you’re building and why it matters
- Guided implementation: Code structure with specific tasks to complete
- Immediate validation: Tests that verify correctness before moving forward
- Conceptual grounding: Explanations that connect code to architecture
Rather than presenting complete solutions, this approach helps you develop intuition for when and why to use specific patterns. The skills you build extend beyond GPT-2 to model development more broadly.
You can work through the tutorial sequentially for comprehensive understanding, or skip directly to topics you need. Each step is self-contained enough to be useful independently while building toward a complete implementation.
What you’ll build
This tutorial guides you through building GPT-2 in manageable steps:
| Step | Component | What you’ll learn |
|---|---|---|
| 1 | Model configuration | Define architecture hyperparameters matching HuggingFace GPT-2. |
| 2 | Causal masking | Create attention masks to prevent looking at future tokens. |
| 3 | Layer normalization | Stabilize activations for effective training. |
| 4 | GPT-2 MLP (feed-forward network) | Build the position-wise feed-forward network with GELU activation. |
| 5 | Token embeddings | Convert token IDs to continuous vector representations. |
| 6 | Position embeddings | Encode sequence order information. |
| 7 | Multi-head attention | Extend to multiple parallel attention heads. |
| 8 | Residual connections & layer norm | Enable training deep networks with skip connections. |
| 9 | Transformer block | Combine attention and MLP into the core building block. |
| 10 | Stacking transformer blocks | Create the complete 12-layer GPT-2 model. |
| 11 | Language model head | Project hidden states to vocabulary logits. |
| 12 | Text generation | Generate text autoregressively with temperature sampling. |
| 00 | Serve your model | Run GPT-2 as an endpoint with MAX Serve (coming soon). |
By the end, you’ll have a complete GPT-2 implementation and practical experience with MAX’s Python API. These are skills you can immediately apply to your own projects.
To set up the coding challenges, follow the steps in Setup.
How this book works
Each step includes automated tests that verify your implementation before moving forward. This immediate feedback helps you catch issues early and build confidence.
You’ll first need to clone the GitHub repository and change into its directory:
git clone https://github.com/modular/max-llm-book
cd max-llm-book
Then download and install pixi:
curl -fsSL https://pixi.sh/install.sh | sh
To validate a step, use the corresponding test command. For example, to test Step 01:
pixi run s01
Initially, tests will fail because the implementation isn’t complete:
✨ Pixi task (s01): python tests/test.step_01.py
Running tests for Step 01: Create Model Configuration...
Results:
❌ dataclass is not imported from dataclasses
❌ GPT2Config does not have the @dataclass decorator
❌ vocab_size is incorrect: expected match with Hugging Face model configuration, got None
# ...
Each failure tells you exactly what to implement.
When your implementation is correct, you’ll see:
✨ Pixi task (s01): python tests/test.step_01.py
Running tests for Step 01: Create Model Configuration...
Results:
✅ dataclass is correctly imported from dataclasses
✅ GPT2Config has the @dataclass decorator
✅ vocab_size is correct
# ...
The test output tells you exactly what needs to be fixed, making it easy to iterate until your implementation is correct. Once all checks pass, you’re ready to move on to the next step.
Prerequisites
This tutorial assumes:
- Basic Python knowledge: Classes, functions, type hints
- Familiarity with neural networks: What embeddings and layers do (we’ll explain the specifics)
- Interest in understanding: Curiosity matters more than prior transformer experience
Whether you’re exploring MAX for the first time or deepening your understanding of model architecture, this tutorial provides hands-on experience you can apply to current projects and learning priorities.
Ready to build? Let’s get started with Step 01: Model configuration.
Step 01: Model configuration
Learn to define the GPT-2 model architecture parameters using configuration classes.
Defining the model architecture
Before you can implement GPT-2, you need to define its architecture: the dimensions, layer counts, and structural parameters that determine how the model processes information.
In this step, you’ll create GPT2Config, a class that holds all the architectural decisions for GPT-2. This class describes things like: embedding dimensions, number of transformer layers, and number of attention heads. These parameters define the shape and capacity of your model.
OpenAI trained the original GPT-2 model with specific parameters that you can see in the config.json file on Hugging Face. By using the exact same values, we can access OpenAI’s pretrained weights in subsequent steps.
Understanding the parameters
Looking at the config.json file, we can see some key information about the model. Each parameter controls a different aspect of the model’s architecture:
- `vocab_size`: Size of the token vocabulary (default: 50,257). This seemingly odd number is actually 50,000 Byte Pair Encoding (BPE) tokens + 256 byte-level tokens (a fallback for rare characters) + 1 special token.
- `n_positions`: Maximum sequence length, also called the context window (default: 1,024). Longer sequences require quadratic memory in attention.
- `n_embd`: Embedding dimension, or the size of the hidden states that flow through the model (default: 768). This determines the model’s capacity to represent information.
- `n_layer`: Number of transformer blocks stacked vertically (default: 12). More layers allow the model to learn more complex patterns.
- `n_head`: Number of attention heads per layer (default: 12). Multiple heads let the model attend to different types of patterns simultaneously.
- `n_inner`: Dimension of the MLP intermediate layer (default: 3,072). This is 4x the embedding dimension, a ratio found empirically in the Attention Is All You Need paper to work well.
- `layer_norm_epsilon`: Small constant for numerical stability in layer normalization (default: `1e-5`). This prevents division by zero when variance is very small.
These values define the small GPT-2 model. OpenAI released four sizes (small, medium, large, XL), each with a configuration that scales up these parameters. This tutorial uses the small model’s values.
Implementing the configuration
Now it’s your turn to implement this. You’ll create the GPT2Config class using Python’s @dataclass decorator. Dataclasses reduce boilerplate: instead of writing __init__ and defining each parameter manually, you just declare the fields with type hints and default values.
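For example, here is a minimal, generic dataclass (not part of the tutorial code) that shows the pattern you’ll follow:

```python
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    # Fields are declared with type hints and defaults; @dataclass generates
    # __init__, __repr__, and __eq__ for you.
    learning_rate: float = 3e-4
    batch_size: int = 8

cfg = TrainingConfig(batch_size=16)
print(cfg)  # TrainingConfig(learning_rate=0.0003, batch_size=16)
```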
First, you’ll need to import the dataclass decorator from the dataclasses module. Then you’ll add the @dataclass decorator to the GPT2Config class definition.
The actual parameter values come from Hugging Face. You can get them in two ways:
- Option 1: Run `pixi run huggingface` to access these parameters programmatically from the Hugging Face `transformers` library.
- Option 2: Read the values directly from the GPT-2 model card.
Once you have the values, replace each None in the GPT2Config class properties with the correct numbers from the configuration.
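If you’d like to check the values programmatically outside the provided pixi task, a minimal sketch using the Hugging Face `transformers` library (assuming it’s available in your environment) looks like this:

```python
from transformers import AutoConfig

# Fetch the published GPT-2 configuration from the Hugging Face Hub.
hf_config = AutoConfig.from_pretrained("openai-community/gpt2")

print(hf_config.vocab_size)   # 50257
print(hf_config.n_positions)  # 1024
print(hf_config.n_embd)       # 768
print(hf_config.n_layer)      # 12
print(hf_config.n_head)       # 12
```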
Implementation (step_01.py):
"""
Step 01: Model Configuration
Implement the GPT-2 configuration dataclass that stores model hyperparameters.
Tasks:
1. Import dataclass from the dataclasses module
2. Add the @dataclass decorator to the GPT2Config class
3. Fill in the configuration values from HuggingFace GPT-2 model
Run: pixi run s01
"""
# 1. Import dataclass from the dataclasses module
# 2. Add the Python @dataclass decorator to the GPT2Config class
class GPT2Config:
"""GPT-2 configuration matching HuggingFace.
Attributes:
vocab_size: Size of the vocabulary.
n_positions: Maximum sequence length.
n_embd: Embedding dimension.
n_layer: Number of transformer layers.
n_head: Number of attention heads.
n_inner: Inner dimension of feed-forward network (defaults to 4 * n_embd if None).
layer_norm_epsilon: Epsilon for layer normalization.
"""
# 3a. Run `pixi run huggingface` to access the model parameters from the Hugging Face `transformers` library
# 3b. Alternatively, read the values from the GPT-2 model card: https://huggingface.co/openai-community/gpt2/blob/main/config.json
# 4. Replace each None in the GPT2Config properties with the correct values
vocab_size: int = None
n_positions: int = None
n_embd: int = None
n_layer: int = None
n_head: int = None
n_inner: int = None # Equal to 4 * n_embd
layer_norm_epsilon: float = None
Validation
Run pixi run s01 to verify your implementation matches the expected configuration.
Show solution
"""
Solution for Step 01: Model Configuration
This module implements the GPT-2 configuration dataclass that stores
hyperparameters matching HuggingFace's GPT-2 model structure.
"""
from dataclasses import dataclass
@dataclass
class GPT2Config:
"""GPT-2 configuration matching HuggingFace.
Attributes:
vocab_size: Size of the vocabulary.
n_positions: Maximum sequence length.
n_embd: Embedding dimension.
n_layer: Number of transformer layers.
n_head: Number of attention heads.
n_inner: Inner dimension of feed-forward network (defaults to 4 * n_embd if None).
layer_norm_epsilon: Epsilon for layer normalization.
"""
vocab_size: int = 50257
n_positions: int = 1024
n_embd: int = 768
n_layer: int = 12
n_head: int = 12
n_inner: int = 3072
layer_norm_epsilon: float = 1e-5
Next: In Step 02, you’ll implement causal masking to prevent tokens from attending to future positions in autoregressive generation.
Step 02: Causal masking
Learn to create attention masks to prevent the model from seeing future tokens during autoregressive generation.
Implementing causal masking
In this step you’ll implement the causal_mask() function. This creates a mask matrix that prevents the model from seeing future tokens when predicting the next token. The mask sets attention scores to negative infinity (-inf) for future positions. After softmax, these -inf values become zero probability, blocking information flow from later tokens.
GPT-2 generates text one token at a time, left-to-right. During training, causal masking prevents the model from “cheating” by looking ahead at tokens it should be predicting. Without this mask, the model would have access to information it won’t have during actual text generation.
Understanding the mask pattern
The mask creates a lower triangular pattern where each token can only attend to itself and previous tokens:
- Position 0 attends to: position 0 only
- Position 1 attends to: positions 0-1
- Position 2 attends to: positions 0-2
- And so on…
The mask shape is (sequence_length, sequence_length + num_tokens). This shape is designed for KV cache compatibility during generation. The KV cache stores key and value tensors from previously generated tokens, so you only need to compute attention for new tokens while attending to both new tokens (sequence_length) and cached tokens (num_tokens). This significantly speeds up generation by avoiding recomputation.
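To make the pattern concrete, here is a small NumPy sketch (illustration only, not MAX code) of the mask for `sequence_length=4` and `num_tokens=0`:

```python
import numpy as np

seq_len, num_tokens = 4, 0
n = seq_len + num_tokens

# -inf above the diagonal blocks attention to future positions;
# 0 on and below the diagonal leaves past and current positions visible.
mask = np.triu(np.full((seq_len, n), -np.inf), k=1)
print(mask)
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]
```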
You’ll use the following MAX operations to complete this task:
Functional decorator:
@F.functional: Converts functions to graph operations for MAX compilation
Tensor operations:
- `Tensor.constant()`: Creates a scalar constant tensor
- `F.broadcast_to()`: Expands tensor dimensions to a target shape
- `F.band_part()`: Extracts a band matrix (keeps the diagonal band, zeros out the rest)
Implementing the mask
You’ll create the causal mask in several steps:
- Import required modules:
  - `Device` from `max.driver`: specifies the hardware device (CPU/GPU)
  - `DType` from `max.dtype`: data type specification
  - `functional` as `F` from `max.experimental`: functional operations library
  - `Tensor` from `max.experimental.tensor`: tensor operations
  - `Dim` and `DimLike` from `max.graph`: dimension handling
- Add the `@F.functional` decorator: This converts the function to a MAX graph operation.
- Calculate the total sequence length: Combine `sequence_length` and `num_tokens` using `Dim()` to determine the mask width.
- Create a constant tensor: Use `Tensor.constant(float("-inf"), dtype=dtype, device=device)` to create a scalar that will be broadcast.
- Broadcast to the target shape: Use `F.broadcast_to(mask, shape=(sequence_length, n))` to expand the scalar to a 2D matrix.
- Apply band part: Use `F.band_part(mask, num_lower=None, num_upper=0, exclude=True)` to create the lower triangular pattern. This keeps 0s on and below the diagonal and `-inf` above.
Implementation (step_02.py):
"""
Step 02: Causal Masking
Implement causal attention masking that prevents tokens from attending to future positions.
Tasks:
1. Import functional module (as F) and Tensor from max.experimental
2. Add @F.functional decorator to the causal_mask function
3. Create a constant tensor filled with negative infinity
4. Broadcast the mask to the correct shape (sequence_length, n)
5. Apply band_part to create the lower triangular causal structure
Run: pixi run s02
"""
# 1: Import the required modules from MAX
from max.driver import Device
from max.dtype import DType
# TODO: Import the functional module from max.experimental with the alias F
# https://docs.modular.com/max/api/python/experimental/functional
# TODO: Import Tensor object from max.experimental.tensor
# https://docs.modular.com/max/api/python/experimental/tensor.Tensor
from max.graph import Dim, DimLike
# 2: Add the @F.functional decorator to make this a MAX functional operation
# TODO: Add the decorator here
def causal_mask(
sequence_length: DimLike,
num_tokens: DimLike,
*,
dtype: DType,
device: Device,
):
"""Create a causal mask for autoregressive attention.
Args:
sequence_length: Length of the sequence.
num_tokens: Number of tokens.
dtype: Data type for the mask.
device: Device to create the mask on.
Returns:
A causal mask tensor.
"""
# Calculate total sequence length
n = Dim(sequence_length) + num_tokens
# 3: Create a constant tensor filled with negative infinity
# TODO: Use Tensor.constant() with float("-inf"), dtype, and device parameters
# https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.constant
# Hint: This creates the base mask value that will block attention to future tokens
mask = None
# 4: Broadcast the mask to the correct shape
# TODO: Use F.broadcast_to() to expand mask to shape (sequence_length, n)
# https://docs.modular.com/max/api/python/experimental/functional#max.experimental.functional.broadcast_to
# Hint: This creates a 2D attention mask matrix
mask = None
# 5: Apply band_part to create the causal (lower triangular) structure and return the mask
# TODO: Use F.band_part() with num_lower=None, num_upper=0, exclude=True
# https://docs.modular.com/max/api/python/experimental/functional/#max.experimental.functional.band_part
# Hint: This keeps only the lower triangle, allowing attention to past tokens only
return None
Validation
Run pixi run s02 to verify your implementation.
Show solution
"""
Solution for Step 02: Causal Masking
This module implements causal attention masking that prevents tokens from
attending to future positions in autoregressive generation.
"""
from max.driver import Device
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import Dim, DimLike
@F.functional
def causal_mask(
sequence_length: DimLike,
num_tokens: DimLike,
*,
dtype: DType,
device: Device,
):
"""Create a causal mask for autoregressive attention.
Args:
sequence_length: Length of the sequence.
num_tokens: Number of tokens.
dtype: Data type for the mask.
device: Device to create the mask on.
Returns:
A causal mask tensor.
"""
n = Dim(sequence_length) + num_tokens
mask = Tensor.constant(float("-inf"), dtype=dtype, device=device)
mask = F.broadcast_to(mask, shape=(sequence_length, n))
return F.band_part(mask, num_lower=None, num_upper=0, exclude=True)
Next: In Step 03, you’ll implement layer normalization to stabilize activations for effective training.
Step 03: Layer normalization
Learn to implement layer normalization for stabilizing neural network training.
Building layer normalization
In this step, you’ll create the LayerNorm class that normalizes activations across the feature dimension. For each input, you compute the mean and variance across all features, normalize by subtracting the mean and dividing by the standard deviation, then apply learned weight and bias parameters to scale and shift the result.
Unlike batch normalization, layer normalization works independently for each example. This makes it ideal for transformers - no dependence on batch size, no tracking running statistics during inference, and consistent behavior between training and generation.
GPT-2 applies layer normalization before the attention and MLP blocks in each of its 12 transformer layers. This pre-normalization pattern stabilizes training in deep networks by keeping activations in a consistent range.
Understanding the operation
Layer normalization normalizes across the feature dimension (the last dimension) independently for each example. It learns two parameters per feature: weight (gamma) for scaling and bias (beta) for shifting.
The normalization follows this formula:
output = weight * (x - mean) / sqrt(variance + epsilon) + bias
The mean and variance are computed across all features in each example. After normalizing to zero mean and unit variance, the learned weight scales the result and the learned bias shifts it. The epsilon value (typically 1e-5) prevents division by zero when variance is very small.
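Here is a short NumPy sketch (illustration only) that applies the formula to a single feature vector so you can verify the result has roughly zero mean and unit variance:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
weight = np.ones(4)   # gamma: learned scale, initialized to ones
bias = np.zeros(4)    # beta: learned shift, initialized to zeros
eps = 1e-5

mean = x.mean()
variance = x.var()
normalized = (x - mean) / np.sqrt(variance + eps)
output = weight * normalized + bias
print(output.mean(), output.var())  # approximately 0.0 and 1.0
```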
You’ll use the following MAX operations to complete this task:
Modules:
Module: The Module class used for eager tensors
Tensor initialization:
- `Tensor.ones()`: Creates a tensor filled with 1.0 values
- `Tensor.zeros()`: Creates a tensor filled with 0.0 values
Layer normalization:
- `F.layer_norm()`: Applies layer normalization with parameters `input`, `gamma` (weight), `beta` (bias), and `epsilon`
Implementing layer normalization
You’ll create the LayerNorm class that wraps MAX’s layer normalization function with learnable parameters. The implementation is straightforward - two parameters and a single function call.
First, import the required modules. You’ll need functional as F for the layer norm operation and Tensor for creating parameters.
In the __init__ method, create two learnable parameters:
- Weight: `Tensor.ones([dim])` stored as `self.weight`, initialized to ones so the initial transformation is the identity
- Bias: `Tensor.zeros([dim])` stored as `self.bias`, initialized to zeros so there’s no initial shift
Store the epsilon value as self.eps for numerical stability.
In the forward method, apply layer normalization with F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). This computes the normalization and applies the learned parameters in one operation.
Implementation (step_03.py):
"""
Step 03: Layer Normalization
Implement layer normalization that normalizes activations for training stability.
Tasks:
1. Import functional module (as F) and Tensor from max.experimental
2. Initialize learnable weight (gamma) and bias (beta) parameters
3. Apply layer normalization using F.layer_norm in the forward pass
Run: pixi run s03
"""
# 1: Import the required modules from MAX
# TODO: Import functional module from max.experimental with the alias F
# https://docs.modular.com/max/api/python/experimental/functional
# TODO: Import Tensor from max.experimental.tensor
# https://docs.modular.com/max/api/python/experimental/tensor.Tensor
from max.graph import DimLike
from max.nn.module_v3 import Module
class LayerNorm(Module):
"""Layer normalization module.
Args:
dim: Dimension to normalize over.
eps: Epsilon for numerical stability.
"""
def __init__(self, dim: DimLike, *, eps: float = 1e-5):
super().__init__()
self.eps = eps
# 2: Initialize learnable weight and bias parameters
# TODO: Create self.weight as a Tensor of ones with shape [dim]
# https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.ones
# Hint: This is the gamma parameter in layer normalization
self.weight = None
# TODO: Create self.bias as a Tensor of zeros with shape [dim]
# https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.zeros
# Hint: This is the beta parameter in layer normalization
self.bias = None
def __call__(self, x: Tensor) -> Tensor:
"""Apply layer normalization.
Args:
x: Input tensor.
Returns:
Normalized tensor.
"""
# 3: Apply layer normalization and return the result
# TODO: Use F.layer_norm() with x, gamma=self.weight, beta=self.bias, epsilon=self.eps
# https://docs.modular.com/max/api/python/experimental/functional#max.experimental.functional.layer_norm
# Hint: Layer normalization normalizes across the last dimension
return None
Validation
Run pixi run s03 to verify your implementation.
Show solution
"""
Solution for Step 03: Layer Normalization
This module implements layer normalization that normalizes activations
across the embedding dimension for training stability.
"""
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DimLike
from max.nn.module_v3 import Module
class LayerNorm(Module):
"""Layer normalization module.
Args:
dim: Dimension to normalize over.
eps: Epsilon for numerical stability.
"""
def __init__(self, dim: DimLike, *, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = Tensor.ones([dim])
self.bias = Tensor.zeros([dim])
def __call__(self, x: Tensor) -> Tensor:
"""Apply layer normalization.
Args:
x: Input tensor.
Returns:
Normalized tensor.
"""
return F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)
Next: In Step 04, you’ll implement the feed-forward network (MLP) with GELU activation used in each transformer block.
Step 04: Feed-forward network
Learn to build the feed-forward network (MLP) that processes information after attention in each transformer block.
Building the MLP
In this step, you’ll create the GPT2MLP class: a two-layer feed-forward
network that appears after attention in every transformer block. The MLP expands
the embedding dimension by 4× (768 → 3,072), applies GELU activation for
non-linearity, then projects back to the original dimension.
While attention lets tokens communicate with each other, the MLP processes each position independently. Attention aggregates information through weighted sums (linear operations), but the MLP adds non-linearity through GELU activation. This combination allows the model to learn complex patterns beyond what linear transformations alone can capture.
GPT-2 uses a 4× expansion ratio (768 to 3,072 dimensions) because this was found to work well in the original Transformer paper and has been validated across many architectures since.
Understanding the components
The MLP has three steps applied in sequence:
Expansion layer (c_fc): Projects from 768 to 3,072 dimensions using a linear layer. This expansion gives the network more capacity to process information.
GELU activation: Applies Gaussian Error Linear Unit, a smooth non-linear
function. GPT-2 uses approximate="tanh" for the tanh-based approximation
instead of the exact computation. The tanh approximation was faster when GPT-2 was originally implemented; exact GELU is fast enough today, but we keep the approximation to match the original pretrained weights.
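For reference, the tanh approximation that `approximate="tanh"` selects can be written out directly. A NumPy sketch (illustration only):

```python
import numpy as np

def gelu_tanh(x):
    # GELU, tanh approximation:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(gelu_tanh(x))  # smooth curve: negatives shrink toward zero, positives pass through
```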
Projection layer (c_proj): Projects back from 3,072 to 768 dimensions
using another linear layer. This returns to the embedding dimension so outputs
can be added to residual connections.
The layer names c_fc (fully connected) and c_proj (projection) match Hugging
Face’s GPT-2 checkpoint structure. This naming is essential for loading
pretrained weights.
You’ll use the following MAX operations to complete this task:
Linear layers:
- `Linear(in_features, out_features, bias=True)`: Applies the linear transformation `y = xW^T + b`
GELU activation:
F.gelu(input, approximate="tanh"): Applies GELU activation with tanh approximation for faster computation
Implementing the MLP
You’ll create the GPT2MLP class that chains two linear layers with GELU
activation between them. The implementation is straightforward - three
operations applied in sequence.
First, import the required modules. You’ll need functional as F for the GELU
activation, Tensor for type hints, Linear for the layers, and Module as
the base class.
In the __init__ method, create two linear layers:
- Expansion layer: `Linear(embed_dim, intermediate_size, bias=True)` stored as `self.c_fc`
- Projection layer: `Linear(intermediate_size, embed_dim, bias=True)` stored as `self.c_proj`
Both layers include bias terms (bias=True). The intermediate size is typically
4× the embedding dimension.
In the forward method, apply the three transformations:
- Expand: `hidden_states = self.c_fc(hidden_states)`
- Activate: `hidden_states = F.gelu(hidden_states, approximate="tanh")`
- Project: `hidden_states = self.c_proj(hidden_states)`
Return the final hidden_states. The input and output shapes are the same:
[batch, seq_length, embed_dim].
Implementation (step_04.py):
"""
Step 04: Feed-forward Network (MLP)
Implement the MLP used in each transformer block with GELU activation.
Tasks:
1. Import functional (as F), Tensor, Linear, and Module from MAX
2. Create c_fc linear layer (embedding to intermediate dimension)
3. Create c_proj linear layer (intermediate back to embedding dimension)
4. Apply c_fc transformation in forward pass
5. Apply GELU activation function
6. Apply c_proj transformation and return result
Run: pixi run s04
"""
# 1: Import the required modules from MAX
# TODO: Import functional module from max.experimental with the alias F
# https://docs.modular.com/max/api/python/experimental/functional
# TODO: Import Tensor from max.experimental.tensor
# https://docs.modular.com/max/api/python/experimental/tensor.Tensor
# TODO: Import Linear and Module from max.nn.module_v3
# https://docs.modular.com/max/api/python/nn/module_v3
from solutions.solution_01 import GPT2Config
class GPT2MLP(Module):
"""Feed-forward network matching HuggingFace GPT-2 structure.
Args:
intermediate_size: Size of the intermediate layer.
config: GPT-2 configuration.
"""
def __init__(self, intermediate_size: int, config: GPT2Config):
super().__init__()
embed_dim = config.n_embd
# 2: Create the first linear layer (embedding to intermediate)
# TODO: Create self.c_fc as a Linear layer from embed_dim to intermediate_size with bias=True
# https://docs.modular.com/max/api/python/nn/module_v3#max.nn.module_v3.Linear
# Hint: This is the expansion layer in the MLP
self.c_fc = None
# 3: Create the second linear layer (intermediate back to embedding)
# TODO: Create self.c_proj as a Linear layer from intermediate_size to embed_dim with bias=True
# https://docs.modular.com/max/api/python/nn/module_v3#max.nn.module_v3.Linear
# Hint: This is the projection layer that brings us back to the embedding dimension
self.c_proj = None
def __call__(self, hidden_states: Tensor) -> Tensor:
"""Apply feed-forward network.
Args:
hidden_states: Input hidden states.
Returns:
MLP output.
"""
# 4: Apply the first linear transformation
# TODO: Apply self.c_fc to hidden_states
# Hint: This expands the hidden dimension to the intermediate size
hidden_states = None
# 5: Apply GELU activation function
# TODO: Use F.gelu() with hidden_states and approximate="tanh"
# https://docs.modular.com/max/api/python/experimental/functional#max.experimental.functional.gelu
# Hint: GELU is the non-linear activation used in GPT-2's MLP
hidden_states = None
# 6: Apply the second linear transformation and return
# TODO: Apply self.c_proj to hidden_states and return the result
# Hint: This projects back to the embedding dimension
return None
Validation
Run pixi run s04 to verify your implementation.
Show solution
"""
Solution for Step 04: Feed-forward Network (MLP)
This module implements the feed-forward network (MLP) used in each
transformer block with GELU activation.
"""
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Linear, Module
from solutions.solution_01 import GPT2Config
class GPT2MLP(Module):
"""Feed-forward network matching HuggingFace GPT-2 structure.
Args:
intermediate_size: Size of the intermediate layer.
config: GPT-2 configuration.
"""
def __init__(self, intermediate_size: int, config: GPT2Config):
super().__init__()
embed_dim = config.n_embd
self.c_fc = Linear(embed_dim, intermediate_size, bias=True)
self.c_proj = Linear(intermediate_size, embed_dim, bias=True)
def __call__(self, hidden_states: Tensor) -> Tensor:
"""Apply feed-forward network.
Args:
hidden_states: Input hidden states.
Returns:
MLP output.
"""
hidden_states = self.c_fc(hidden_states)
hidden_states = F.gelu(hidden_states, approximate="tanh")
hidden_states = self.c_proj(hidden_states)
return hidden_states
Next: In Step 05, you’ll implement token embeddings to convert discrete token IDs into continuous vector representations.
Step 05: Token embeddings
Learn to create token embeddings that convert discrete token IDs into continuous vector representations.
Implementing token embeddings
In this step you’ll create the GPT2Embeddings class, which wraps MAX’s Embedding layer to convert discrete token
IDs (integers) into continuous vector representations that the model can
process. The embedding layer is a lookup table with shape [50257, 768] where
50257 is GPT-2’s vocabulary size and 768 is the embedding dimension.
Neural networks operate on continuous values, not discrete symbols. Token embeddings convert discrete token IDs into dense vectors that can be processed by matrix operations. During training, these embeddings naturally cluster semantically similar words closer together in vector space.
Understanding embeddings
The embedding layer stores one vector per vocabulary token. When you pass in
token ID 1000, it returns row 1000 as the embedding vector. The layer name wte
stands for “word token embeddings” and matches the naming in the original GPT-2
code for weight loading compatibility.
Key parameters:
- Vocabulary size: 50,257 tokens (byte-pair encoding)
- Embedding dimension: 768 for GPT-2 base
- Shape: [vocab_size, embedding_dim]
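Conceptually, the lookup is just row indexing into a weight matrix. A NumPy sketch with a tiny toy vocabulary (illustration only, not MAX code):

```python
import numpy as np

vocab_size, n_embd = 10, 4                 # toy sizes for illustration
wte = np.random.randn(vocab_size, n_embd)  # the learned embedding table

input_ids = np.array([[3, 7, 1]])          # shape [batch=1, seq_length=3]
token_embeddings = wte[input_ids]          # one row per token ID
print(token_embeddings.shape)              # (1, 3, 4) -> [batch, seq_length, n_embd]
```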
You’ll use the following MAX operations to complete this task:
Embedding layer:
Embedding(num_embeddings, dim): Creates embedding lookup table with automatic weight initialization
Implementing the class
You’ll implement the GPT2Embeddings class in several steps:
- Import required modules: Import `Embedding` and `Module` from the MAX libraries.
- Create the embedding layer: Use `Embedding(config.vocab_size, dim=config.n_embd)` and store it in `self.wte`.
- Implement the forward pass: Call `self.wte(input_ids)` to look up embeddings. Input shape: `[batch_size, seq_length]`. Output shape: `[batch_size, seq_length, n_embd]`.
Implementation (step_05.py):
"""
Step 05: Token Embeddings
Implement token embeddings that convert discrete token IDs into continuous vectors.
Tasks:
1. Import Embedding and Module from max.nn.module_v3
2. Create token embedding layer using Embedding(vocab_size, dim=n_embd)
3. Implement forward pass that looks up embeddings for input token IDs
Run: pixi run s05
"""
# TODO: Import required modules from MAX
# Hint: You'll need Embedding and Module from max.nn.module_v3
from solutions.solution_01 import GPT2Config
class GPT2Embeddings(Module):
"""Token embeddings for GPT-2."""
def __init__(self, config: GPT2Config):
super().__init__()
# TODO: Create token embedding layer
# Hint: Use Embedding(config.vocab_size, dim=config.n_embd)
# This creates a lookup table that converts token IDs to embedding vectors
self.wte = None
def __call__(self, input_ids):
"""Convert token IDs to embeddings.
Args:
input_ids: Tensor of token IDs, shape [batch_size, seq_length]
Returns:
Token embeddings, shape [batch_size, seq_length, n_embd]
"""
# TODO: Return the embedded tokens
# Hint: Simply call self.wte with input_ids
return None
Validation
Run pixi run s05 to verify your implementation.
Show solution
"""
Solution for Step 05: Token Embeddings
This module implements token embeddings that convert discrete token IDs
into continuous vector representations.
"""
from max.nn.module_v3 import Embedding, Module
from solutions.solution_01 import GPT2Config
class GPT2Embeddings(Module):
"""Token embeddings for GPT-2, matching HuggingFace structure."""
def __init__(self, config: GPT2Config):
"""Initialize token embedding layer.
Args:
config: GPT2Config containing vocab_size and n_embd
"""
super().__init__()
# Token embedding: lookup table from vocab_size to embedding dimension
# This converts discrete token IDs (0 to vocab_size-1) into dense vectors
self.wte = Embedding(config.vocab_size, dim=config.n_embd)
def __call__(self, input_ids):
"""Convert token IDs to embeddings.
Args:
input_ids: Tensor of token IDs, shape [batch_size, seq_length]
Returns:
Token embeddings, shape [batch_size, seq_length, n_embd]
"""
# Simple lookup: each token ID becomes its corresponding embedding vector
return self.wte(input_ids)
Next: In Step 06, you’ll implement position embeddings to encode sequence order information, which will be combined with these token embeddings.
Step 06: Position embeddings
Learn to create position embeddings that encode the order of tokens in a sequence.
Implementing position embeddings
In this step you’ll create position embeddings to encode where each token appears in the sequence. While token embeddings tell the model “what” each token is, position embeddings tell it “where” the token is located. These position vectors are added to token embeddings before entering the transformer blocks.
Transformers process all positions in parallel through attention, unlike Recurrent Neural Networks (RNNs) that process sequentially. This parallelism enables faster training but loses positional information. Position embeddings restore this information so the model can distinguish “dog bites man” from “man bites dog”.
Understanding position embeddings
Position embeddings work like token embeddings: a lookup table with shape [1024, 768] where 1024 is the maximum sequence length. Position 0 gets the first row, position 1 gets the second row, and so on.
GPT-2 uses learned position embeddings, meaning these vectors are initialized randomly and trained alongside the model. This differs from the original Transformer which used fixed sinusoidal position encodings. Learned embeddings let the model discover optimal position representations for its specific task, though they cannot generalize beyond the maximum length seen during training (1024 tokens).
Key parameters:
- Maximum sequence length: 1,024 positions
- Embedding dimension: 768 for GPT-2 base
- Shape: [n_positions, n_embd]
- Layer name: `wpe` (word position embeddings)
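The same lookup idea applies to positions. In a NumPy sketch (illustration only), the position indices come from an arange and the result is added to the token embeddings from the previous step:

```python
import numpy as np

n_positions, n_embd = 8, 4                        # toy sizes for illustration
wpe = np.random.randn(n_positions, n_embd)        # learned position table
token_embeddings = np.random.randn(1, 3, n_embd)  # [batch, seq_length, n_embd]

position_ids = np.arange(3)                  # [0, 1, 2] for a 3-token sequence
position_embeddings = wpe[position_ids]      # shape (3, 4)

# Broadcasting adds the same position vectors to every batch element.
hidden_states = token_embeddings + position_embeddings
print(hidden_states.shape)                   # (1, 3, 4)
```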
You’ll use the following MAX operations to complete this task:
Position indices:
Tensor.arange(seq_length, dtype, device): Creates sequence positions [0, 1, 2, …, seq_length-1]
Embedding layer:
Embedding(num_embeddings, dim): Same class as token embeddings, but for positions
Implementing the class
You’ll implement the position embeddings in several steps:
- Import required modules: Import `Tensor`, `Embedding`, and `Module` from the MAX libraries.
- Create the position embedding layer: Use `Embedding(config.n_positions, dim=config.n_embd)` and store it in `self.wpe`.
- Implement the forward pass: Call `self.wpe(position_ids)` to look up position embeddings. Input shape: `[seq_length]` or `[batch, seq_length]`. Output shape: `[seq_length, n_embd]` or `[batch, seq_length, n_embd]`.
Implementation (step_06.py):
"""
Step 06: Position Embeddings
Implement position embeddings that encode sequence order information.
Tasks:
1. Import Tensor from max.experimental.tensor
2. Import Embedding and Module from max.nn.module_v3
3. Create position embedding layer using Embedding(n_positions, dim=n_embd)
4. Implement forward pass that looks up embeddings for position indices
Run: pixi run s06
"""
# TODO: Import required modules from MAX
# Hint: You'll need Tensor from max.experimental.tensor
# Hint: You'll need Embedding and Module from max.nn.module_v3
from solutions.solution_01 import GPT2Config
class GPT2PositionEmbeddings(Module):
"""Position embeddings for GPT-2."""
def __init__(self, config: GPT2Config):
super().__init__()
# TODO: Create position embedding layer
# Hint: Use Embedding(config.n_positions, dim=config.n_embd)
# This creates a lookup table for position indices (0, 1, 2, ..., n_positions-1)
self.wpe = None
def __call__(self, position_ids):
"""Convert position indices to embeddings.
Args:
position_ids: Tensor of position indices, shape [seq_length] or [batch_size, seq_length]
Returns:
Position embeddings, shape matching input with added embedding dimension
"""
# TODO: Return the position embeddings
# Hint: Simply call self.wpe with position_ids
return None
Validation
Run pixi run s06 to verify your implementation.
Show solution
"""
Solution for Step 06: Position Embeddings
This module implements position embeddings that encode sequence order information
into the transformer model.
"""
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding, Module
from solutions.solution_01 import GPT2Config
class GPT2PositionEmbeddings(Module):
"""Position embeddings for GPT-2, matching HuggingFace structure."""
def __init__(self, config: GPT2Config):
"""Initialize position embedding layer.
Args:
config: GPT2Config containing n_positions and n_embd
"""
super().__init__()
# Position embedding: lookup table from position indices to embedding vectors
# This encodes "where" information - position 0, 1, 2, etc.
self.wpe = Embedding(config.n_positions, dim=config.n_embd)
def __call__(self, position_ids):
"""Convert position indices to embeddings.
Args:
position_ids: Tensor of position indices, shape [seq_length] or [batch_size, seq_length]
Returns:
Position embeddings, shape matching input with added embedding dimension
"""
# Simple lookup: each position index becomes its corresponding embedding vector
return self.wpe(position_ids)
Next: In Step 07, you’ll implement multi-head attention.
Step 07: Multi-head attention
Learn to implement multi-head attention, which lets the model attend to different representation subspaces in parallel.
Building multi-head attention
In this step, you’ll implement the GPT2MultiHeadAttention class that runs 12
attention operations in parallel. Instead of computing attention once over the
full 768-dimensional space, you split the dimensions into 12 heads of 64
dimensions each. Each head learns to focus on different patterns.
GPT-2 uses 12 heads with 768-dimensional embeddings, giving each head 768 ÷ 12 = 64 dimensions. The Q, K, V tensors are reshaped to split the embedding dimension across heads, attention is computed for all heads in parallel, then the outputs are concatenated back together. This happens in a single efficient operation using tensor reshaping and broadcasting.
Multiple heads let the model learn complementary attention strategies. Different heads can specialize in different relationships: one might attend to adjacent tokens, another to syntactic patterns, and another to semantic similarity. This increases the model’s capacity without dramatically increasing computation.
Understanding the architecture
Multi-head attention splits the embedding dimension, computes attention independently for each head, then merges the results. This requires careful tensor reshaping to organize the computation efficiently.
Head splitting: Transform from [batch, seq_length, 768] to
[batch, 12, seq_length, 64]. First reshape to add the head dimension:
[batch, seq_length, 12, 64]. Then transpose to move heads before sequence:
[batch, 12, seq_length, 64]. Now each of the 12 heads operates independently
on its 64-dimensional subspace.
Parallel attention: With shape [batch, num_heads, seq_length, head_dim],
you can compute attention for all heads simultaneously. The matrix
multiplication Q @ K^T operates on the last two dimensions
[seq_length, head_dim] @ [head_dim, seq_length], broadcasting across the batch
and head dimensions. All 12 heads are computed in a single efficient operation.
Head merging: Reverse the splitting to go from
[batch, 12, seq_length, 64] back to [batch, seq_length, 768]. First
transpose to [batch, seq_length, 12, 64], then reshape to flatten the head
dimension: [batch, seq_length, 768]. This concatenates all head outputs back
into the original dimension.
Output projection (c_proj): After merging heads, apply a learned linear
transformation that maps [batch, seq_length, 768] to
[batch, seq_length, 768]. This lets the model mix information across heads,
combining the different perspectives each head learned.
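A NumPy sketch (illustration only) of the split-transpose-merge round trip, using toy sizes instead of GPT-2’s 12 heads of 64 dimensions:

```python
import numpy as np

batch, seq_len, n_embd = 2, 5, 8
num_heads, head_dim = 4, 2    # n_embd = num_heads * head_dim

x = np.random.randn(batch, seq_len, n_embd)

# Split: [batch, seq, n_embd] -> [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
split = x.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)
print(split.shape)             # (2, 4, 5, 2)

# Merge: reverse the transpose, then flatten the head dimensions back together.
merged = split.transpose(0, 2, 1, 3).reshape(batch, seq_len, n_embd)
print(np.allclose(merged, x))  # True: splitting and merging are inverses
```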
The layer names c_attn (combined Q/K/V projection) and c_proj (output
projection) match Hugging Face’s GPT-2 implementation. This naming is essential
for loading pretrained weights.
You’ll use the following MAX operations to complete this task:
Linear layers:
Linear(in_features, out_features, bias=True): Q/K/V and output projections
Tensor operations:
- `tensor.reshape(new_shape)`: Splits or merges the head dimension
- `tensor.transpose(axis1, axis2)`: Rearranges dimensions for parallel attention
- `F.split(tensor, split_sizes, axis)`: Divides Q/K/V from the combined projection
Implementing multi-head attention
You’ll create the GPT2MultiHeadAttention class with helper methods for splitting and merging heads. The implementation builds on the attention mechanism from Step 02, extending it to work with multiple heads in parallel.
First, import the required modules. You’ll need math for scaling, functional as F for operations, Tensor for type hints, device and dtype utilities, and Linear and Module from MAX’s neural network module. You’ll also reuse the causal_mask function from Step 02.
In the __init__ method, create the projection layers and store configuration:
- Combined Q/K/V projection: `Linear(embed_dim, 3 * embed_dim, bias=True)` stored as `self.c_attn`
- Output projection: `Linear(embed_dim, embed_dim, bias=True)` stored as `self.c_proj`
- Store `self.num_heads` (12) and `self.head_dim` (64) from the config
- Calculate `self.split_size` for splitting Q, K, V later
Implement _split_heads to reshape for parallel attention:
- Calculate the new shape by replacing the last dimension: `tensor.shape[:-1] + [num_heads, attn_head_size]`
- Reshape to add the head dimension: `tensor.reshape(new_shape)`
- Transpose to move heads to position 1: `tensor.transpose(-3, -2)`
- Returns shape `[batch, num_heads, seq_length, head_size]`
Implement _merge_heads to concatenate head outputs:
- Transpose to move heads back: `tensor.transpose(-3, -2)`
- Calculate the flattened shape: `tensor.shape[:-2] + [num_heads * attn_head_size]`
- Reshape to merge heads: `tensor.reshape(new_shape)`
- Returns shape `[batch, seq_length, n_embd]`
Implement _attn to compute scaled dot-product attention for all heads:
- Compute attention scores: `query @ key.transpose(-2, -1)`
- Scale by the square root of the head dimension
- Apply the causal mask to prevent attending to future positions
- Apply softmax to get attention weights
- Multiply the weights by the values: `attn_weights @ value`
In the forward method, orchestrate the complete multi-head attention:
- Project to Q/K/V: `qkv = self.c_attn(hidden_states)`
- Split into separate tensors: `F.split(qkv, [self.split_size, self.split_size, self.split_size], axis=-1)`
- Split heads for each: `query = self._split_heads(query, self.num_heads, self.head_dim)` (repeat for key and value)
- Compute attention: `attn_output = self._attn(query, key, value)`
- Merge heads: `attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)`
- Final projection: `return self.c_proj(attn_output)`
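Before working through the skeleton, it can help to see the `_attn` math end to end. Here is a minimal NumPy sketch of masked, scaled dot-product attention for a single head (illustration only; the MAX version broadcasts this over the batch and head dimensions):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

seq_len, head_dim = 4, 8
query = np.random.randn(seq_len, head_dim)
key = np.random.randn(seq_len, head_dim)
value = np.random.randn(seq_len, head_dim)

scores = query @ key.T / np.sqrt(head_dim)                 # scaled attention scores
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)  # causal mask from Step 02
weights = softmax(scores + mask)                           # future positions get weight 0
output = weights @ value                                   # weighted sum of values
print(output.shape)                                        # (4, 8)
```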
Implementation (step_07.py):
"""
Step 07: Multi-head Attention
Implement multi-head attention that splits Q/K/V into multiple heads,
computes attention in parallel for each head, and merges the results.
Tasks:
1. Import required modules (math, F, Tensor, Linear, Module, etc.)
2. Create c_attn and c_proj linear layers
3. Implement _split_heads: reshape and transpose to add head dimension
4. Implement _merge_heads: transpose and reshape to remove head dimension
5. Implement _attn: compute attention for all heads in parallel
6. Implement forward pass: project -> split -> attend -> merge -> project
Run: pixi run s07
"""
# TODO: Import required modules
# Hint: You'll need math for scaling
# Hint: You'll need functional as F from max.experimental
# Hint: You'll need Tensor, Device, DType from max.experimental.tensor and max.driver
# Hint: You'll need Dim, DimLike from max.graph
# Hint: You'll also need Linear and Module from max.nn.module_v3
from solutions.solution_01 import GPT2Config
# TODO: Copy causal_mask function from solution_02.py
# This is the same function you implemented in Step 02
class GPT2MultiHeadAttention(Module):
"""Multi-head attention for GPT-2."""
def __init__(self, config: GPT2Config):
super().__init__()
self.embed_dim = config.n_embd
self.num_heads = config.n_head
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
# TODO: Create combined Q/K/V projection
# Hint: Use Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
self.c_attn = None
# TODO: Create output projection
# Hint: Use Linear(self.embed_dim, self.embed_dim, bias=True)
self.c_proj = None
def _split_heads(self, tensor, num_heads, attn_head_size):
"""Split the last dimension into (num_heads, head_size).
Args:
tensor: Input tensor, shape [batch, seq_length, n_embd]
num_heads: Number of attention heads
attn_head_size: Dimension of each head
Returns:
Tensor with shape [batch, num_heads, seq_length, head_size]
"""
# TODO: Add head dimension
# Hint: new_shape = tensor.shape[:-1] + [num_heads, attn_head_size]
# Hint: tensor = tensor.reshape(new_shape)
pass
# TODO: Move heads dimension to position 1
# Hint: return tensor.transpose(-3, -2)
return None
def _merge_heads(self, tensor, num_heads, attn_head_size):
"""Merge attention heads back to original shape.
Args:
tensor: Input tensor, shape [batch, num_heads, seq_length, head_size]
num_heads: Number of attention heads
attn_head_size: Dimension of each head
Returns:
Tensor with shape [batch, seq_length, n_embd]
"""
# TODO: Move heads dimension back
# Hint: tensor = tensor.transpose(-3, -2)
pass
# TODO: Flatten head dimensions
# Hint: new_shape = tensor.shape[:-2] + [num_heads * attn_head_size]
# Hint: return tensor.reshape(new_shape)
return None
def _attn(self, query, key, value):
"""Compute attention for all heads in parallel.
Args:
query: Query tensor, shape [batch, num_heads, seq_length, head_size]
key: Key tensor, shape [batch, num_heads, seq_length, head_size]
value: Value tensor, shape [batch, num_heads, seq_length, head_size]
Returns:
Attention output, shape [batch, num_heads, seq_length, head_size]
"""
# TODO: Implement attention computation
# The same 5-step process: scores, scale, mask, softmax, weighted sum
# Hint: Compute attention scores: query @ key.transpose(-1, -2)
# Hint: Scale by sqrt(head_dim): attn_weights / math.sqrt(head_dim)
# Hint: Apply causal mask using causal_mask function
# Hint: Apply softmax: F.softmax(attn_weights)
# Hint: Weighted sum: attn_weights @ value
return None
def __call__(self, hidden_states):
"""Apply multi-head attention.
Args:
hidden_states: Input tensor, shape [batch, seq_length, n_embd]
Returns:
Attention output, shape [batch, seq_length, n_embd]
"""
# TODO: Project to Q, K, V
# Hint: qkv = self.c_attn(hidden_states)
# Hint: query, key, value = F.split(qkv, [self.split_size, self.split_size, self.split_size], axis=-1)
pass
# TODO: Split into multiple heads
# Hint: query = self._split_heads(query, self.num_heads, self.head_dim)
# Hint: key = self._split_heads(key, self.num_heads, self.head_dim)
# Hint: value = self._split_heads(value, self.num_heads, self.head_dim)
pass
# TODO: Apply attention
# Hint: attn_output = self._attn(query, key, value)
pass
# TODO: Merge heads back
# Hint: attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
pass
# TODO: Output projection
# Hint: attn_output = self.c_proj(attn_output)
# Hint: return attn_output
return None
Validation
Run pixi run s07 to verify your implementation.
Show solution
"""
Solution for Step 07: Multi-head Attention
This module implements multi-head attention, which allows the model to jointly
attend to information from different representation subspaces at different positions.
"""
import math
from max.driver import Device
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import Dim, DimLike
from max.nn.module_v3 import Linear, Module
from solutions.solution_01 import GPT2Config
@F.functional
def causal_mask(
sequence_length: DimLike,
num_tokens: DimLike,
*,
dtype: DType,
device: Device,
):
"""Create a causal attention mask."""
n = Dim(sequence_length) + num_tokens
mask = Tensor.constant(float("-inf"), dtype=dtype, device=device)
mask = F.broadcast_to(mask, shape=(sequence_length, n))
return F.band_part(mask, num_lower=None, num_upper=0, exclude=True)
class GPT2MultiHeadAttention(Module):
"""Multi-head attention for GPT-2, matching HuggingFace structure."""
def __init__(self, config: GPT2Config):
"""Initialize multi-head attention.
Args:
config: GPT2Config containing n_embd and n_head
"""
super().__init__()
self.embed_dim = config.n_embd
self.num_heads = config.n_head
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
# Combined Q/K/V projection
self.c_attn = Linear(self.embed_dim, 3 * self.embed_dim, bias=True)
# Output projection
self.c_proj = Linear(self.embed_dim, self.embed_dim, bias=True)
def _split_heads(self, tensor, num_heads, attn_head_size):
"""Split the last dimension into (num_heads, head_size).
Transforms shape from [batch, seq_length, n_embd]
to [batch, num_heads, seq_length, head_size]
Args:
tensor: Input tensor, shape [batch, seq_length, n_embd]
num_heads: Number of attention heads
attn_head_size: Dimension of each head
Returns:
Tensor with shape [batch, num_heads, seq_length, head_size]
"""
# Add head dimension: [batch, seq_length, n_embd] -> [batch, seq_length, num_heads, head_size]
new_shape = tensor.shape[:-1] + [num_heads, attn_head_size]
tensor = tensor.reshape(new_shape)
# Move heads dimension: [batch, seq_length, num_heads, head_size] -> [batch, num_heads, seq_length, head_size]
return tensor.transpose(-3, -2)
def _merge_heads(self, tensor, num_heads, attn_head_size):
"""Merge attention heads back to original shape.
Transforms shape from [batch, num_heads, seq_length, head_size]
to [batch, seq_length, n_embd]
Args:
tensor: Input tensor, shape [batch, num_heads, seq_length, head_size]
num_heads: Number of attention heads
attn_head_size: Dimension of each head
Returns:
Tensor with shape [batch, seq_length, n_embd]
"""
# Move heads dimension back: [batch, num_heads, seq_length, head_size] -> [batch, seq_length, num_heads, head_size]
tensor = tensor.transpose(-3, -2)
# Flatten head dimensions: [batch, seq_length, num_heads, head_size] -> [batch, seq_length, n_embd]
new_shape = tensor.shape[:-2] + [num_heads * attn_head_size]
return tensor.reshape(new_shape)
def _attn(self, query, key, value):
"""Compute attention for all heads in parallel.
Args:
query: Query tensor, shape [batch, num_heads, seq_length, head_size]
key: Key tensor, shape [batch, num_heads, seq_length, head_size]
value: Value tensor, shape [batch, num_heads, seq_length, head_size]
Returns:
Attention output, shape [batch, num_heads, seq_length, head_size]
"""
# Compute attention scores
attn_weights = query @ key.transpose(-1, -2)
# Scale attention weights
attn_weights = attn_weights / math.sqrt(int(value.shape[-1]))
# Apply causal mask
seq_len = query.shape[-2]
mask = causal_mask(seq_len, 0, dtype=query.dtype, device=query.device)
attn_weights = attn_weights + mask
# Softmax and weighted sum
attn_weights = F.softmax(attn_weights)
attn_output = attn_weights @ value
return attn_output
def __call__(self, hidden_states):
"""Apply multi-head attention.
Args:
hidden_states: Input tensor, shape [batch, seq_length, n_embd]
Returns:
Attention output, shape [batch, seq_length, n_embd]
"""
# Project to Q, K, V
qkv = self.c_attn(hidden_states)
query, key, value = F.split(
qkv, [self.split_size, self.split_size, self.split_size], axis=-1
)
# Split into multiple heads
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
# Apply attention
attn_output = self._attn(query, key, value)
# Merge heads back
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
# Output projection
attn_output = self.c_proj(attn_output)
return attn_output
Next: In Step 08, you’ll implement residual connections and layer normalization to enable training deep transformer networks.
Step 08: Residual connections and layer normalization
Learn to implement residual connections and layer normalization to enable training deep transformer networks.
Building the residual pattern
In this step, you’ll combine residual connections and layer normalization into a
reusable pattern for transformer blocks. Residual connections add the input
directly to the output using output = input + layer(input), creating shortcuts
that let gradients flow through deep networks. You’ll implement this alongside
the layer normalization from Step 03.
GPT-2 uses pre-norm architecture where layer norm is applied before each
sublayer (attention or MLP). The pattern is x = x + sublayer(layer_norm(x)):
normalize first, process, then add the original input back. This is more stable
than post-norm alternatives for deep networks.
Residual connections solve the vanishing gradient problem. During
backpropagation, gradients flow through the identity path (x = x + ...)
without being multiplied by layer weights. This allows training networks with
12+ layers. Without residuals, gradients would diminish exponentially as they
propagate through many layers.
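One quick way to see why the identity path matters (a toy numeric sketch, not part of the tutorial code): for y = x + f(x), the derivative is 1 + f'(x), so even when f'(x) is tiny, a factor of at least 1 still flows back.
import numpy as np

def f(x):
    return 0.001 * np.tanh(x)  # a sublayer whose gradient is tiny

def residual(x):
    return x + f(x)

x0, h = 1.5, 1e-6
grad_f = (f(x0 + h) - f(x0 - h)) / (2 * h)                  # ~0.0002, nearly vanished
grad_res = (residual(x0 + h) - residual(x0 - h)) / (2 * h)  # ~1.0002, gradient preserved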
Layer normalization works identically during training and inference because it normalizes each example independently. No batch statistics, no running averages, just consistent normalization that keeps activation distributions stable throughout training.
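If you want to see exactly what layer normalization computes, here's a small NumPy sketch (illustrative only) that normalizes each token's feature vector independently and then applies the learnable scale and shift:
import numpy as np

x = np.random.randn(2, 4, 8)                        # [batch, seq_length, embed_dim]
weight, bias, eps = np.ones(8), np.zeros(8), 1e-5   # gamma, beta, epsilon

mean = x.mean(axis=-1, keepdims=True)   # per-token mean over the feature dimension
var = x.var(axis=-1, keepdims=True)     # per-token variance over the feature dimension
out = (x - mean) / np.sqrt(var + eps) * weight + bias  # same shape as x, no batch statistics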
Understanding the pattern
The pre-norm residual pattern combines three operations in sequence:
Layer normalization: Normalize the input with
F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). This
uses learnable weight (gamma) and bias (beta) parameters to scale and shift the
normalized values. You already implemented this in Step 03.
Sublayer processing: Pass the normalized input through a sublayer (attention or MLP). The sublayer transforms the data while the layer norm keeps its input well-conditioned.
Residual addition: Add the original input back to the sublayer output using
simple element-wise addition: x + sublayer_output. Both tensors must have
identical shapes [batch, seq_length, embed_dim].
The complete pattern is x = x + sublayer(layer_norm(x)). This differs from
post-norm x = layer_norm(x + sublayer(x)), as pre-norm is more stable because
normalization happens before potentially unstable sublayer operations.
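Written as plain Python pseudocode (a sketch of the two orderings, not MAX code), the difference looks like this:
def pre_norm_block(x, sublayer, layer_norm):
    # GPT-2 style: normalize, process, then add the residual
    return x + sublayer(layer_norm(x))

def post_norm_block(x, sublayer, layer_norm):
    # Original Transformer style: process, add, then normalize the sum
    return layer_norm(x + sublayer(x))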
You’ll use the following MAX operations to complete this task:
Layer normalization:
F.layer_norm(x, gamma, beta, epsilon): Normalizes across feature dimension
Tensor initialization:
Tensor.ones([dim]): Creates weight parameter
Tensor.zeros([dim]): Creates bias parameter
Implementing the pattern
You’ll implement three classes that demonstrate the residual pattern:
LayerNorm for normalization, ResidualBlock that combines norm and residual
addition, and a standalone apply_residual_connection function.
First, import the required modules. You’ll need functional as F for layer
norm, Tensor for parameters, DimLike for type hints, and Module as the
base class.
LayerNorm implementation:
In __init__, create the learnable parameters:
- Weight: Tensor.ones([dim]) stored as self.weight
- Bias: Tensor.zeros([dim]) stored as self.bias
- Store eps for numerical stability
In forward, apply normalization with F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). This returns a normalized tensor with the same shape as the input.
ResidualBlock implementation:
In __init__, create a LayerNorm instance:
self.ln = LayerNorm(dim, eps=eps). This will normalize inputs before
sublayers.
In forward, implement the residual addition. The full pre-norm pattern is:
- Normalize: normalized = self.ln(x)
- Process: sublayer_output = sublayer(normalized)
- Add residual: return x + sublayer_output
In this exercise the sublayer output is passed in as an argument, so your forward method only needs the final addition: return x + sublayer_output.
Standalone function:
Implement apply_residual_connection(input_tensor, sublayer_output) that
returns input_tensor + sublayer_output. This demonstrates the residual pattern
as a simple function.
Implementation (step_08.py):
"""
Step 08: Residual Connections and Layer Normalization
Implement layer normalization and residual connections, which enable
training deep transformer networks by stabilizing gradients.
Tasks:
1. Import F (functional), Tensor, DimLike, and Module
2. Create LayerNorm class with learnable weight and bias parameters
3. Implement layer norm using F.layer_norm
4. Implement residual connection (simple addition)
Run: pixi run s08
"""
# TODO: Import required modules
# Hint: You'll need F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor
# Hint: You'll need DimLike from max.graph
# Hint: You'll need Module from max.nn.module_v3
class LayerNorm(Module):
"""Layer normalization module matching HuggingFace GPT-2."""
def __init__(self, dim: DimLike, *, eps: float = 1e-5):
"""Initialize layer normalization.
Args:
dim: Dimension to normalize (embedding dimension)
eps: Small epsilon for numerical stability
"""
super().__init__()
self.eps = eps
# TODO: Create learnable scale parameter (weight)
# Hint: Use Tensor.ones([dim])
self.weight = None
# TODO: Create learnable shift parameter (bias)
# Hint: Use Tensor.zeros([dim])
self.bias = None
def __call__(self, x: Tensor) -> Tensor:
"""Apply layer normalization.
Args:
x: Input tensor, shape [..., dim]
Returns:
Normalized tensor, same shape as input
"""
# TODO: Apply layer normalization
# Hint: Use F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)
return None
class ResidualBlock(Module):
"""Demonstrates residual connections with layer normalization."""
def __init__(self, dim: int, eps: float = 1e-5):
"""Initialize residual block.
Args:
dim: Dimension of the input/output
eps: Epsilon for layer normalization
"""
super().__init__()
# TODO: Create layer normalization
# Hint: Use LayerNorm(dim, eps=eps)
self.ln = None
def __call__(self, x: Tensor, sublayer_output: Tensor) -> Tensor:
"""Apply residual connection.
Args:
x: Input tensor (the residual)
sublayer_output: Output from sublayer applied to ln(x)
Returns:
x + sublayer_output
"""
# TODO: Add input and sublayer output (residual connection)
# Hint: return x + sublayer_output
return None
def apply_residual_connection(input_tensor: Tensor, sublayer_output: Tensor) -> Tensor:
"""Apply a residual connection by adding input to sublayer output.
Args:
input_tensor: Original input (the residual)
sublayer_output: Output from a sublayer (attention, MLP, etc.)
Returns:
input_tensor + sublayer_output
"""
# TODO: Add the two tensors
# Hint: return input_tensor + sublayer_output
return None
Validation
Run pixi run s08 to verify your implementation.
Show solution
"""
Solution for Step 08: Residual Connections and Layer Normalization
This module implements layer normalization and demonstrates residual connections,
which are essential for training deep transformer networks.
"""
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DimLike
from max.nn.module_v3 import Module
class LayerNorm(Module):
"""Layer normalization module matching HuggingFace GPT-2.
Layer norm normalizes activations across the feature dimension,
stabilizing training and allowing deeper networks.
"""
def __init__(self, dim: DimLike, *, eps: float = 1e-5):
"""Initialize layer normalization.
Args:
dim: Dimension to normalize (embedding dimension)
eps: Small epsilon for numerical stability
"""
super().__init__()
self.eps = eps
# Learnable scale parameter (gamma)
self.weight = Tensor.ones([dim])
# Learnable shift parameter (beta)
self.bias = Tensor.zeros([dim])
def __call__(self, x: Tensor) -> Tensor:
"""Apply layer normalization.
Args:
x: Input tensor, shape [..., dim]
Returns:
Normalized tensor, same shape as input
"""
return F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)
class ResidualBlock(Module):
"""Demonstrates residual connections with layer normalization.
This shows the pre-norm architecture used in GPT-2:
output = input + sublayer(layer_norm(input))
"""
def __init__(self, dim: int, eps: float = 1e-5):
"""Initialize residual block.
Args:
dim: Dimension of the input/output
eps: Epsilon for layer normalization
"""
super().__init__()
self.ln = LayerNorm(dim, eps=eps)
def __call__(self, x: Tensor, sublayer_output: Tensor) -> Tensor:
"""Apply residual connection.
This demonstrates the pattern:
1. Normalize input: ln(x)
2. Apply sublayer (passed as argument for simplicity)
3. Add residual: x + sublayer_output
In practice, the sublayer (attention or MLP) is applied to ln(x),
but we receive the result as a parameter for clarity.
Args:
x: Input tensor (the residual)
sublayer_output: Output from sublayer applied to ln(x)
Returns:
x + sublayer_output
"""
# In a real transformer block, you would do:
# residual = x
# x = self.ln(x)
# x = sublayer(x) # e.g., attention or MLP
# x = x + residual
# For this demonstration, we just add
return x + sublayer_output
def apply_residual_connection(input_tensor: Tensor, sublayer_output: Tensor) -> Tensor:
"""Apply a residual connection by adding input to sublayer output.
Residual connections allow gradients to flow directly through the network,
enabling training of very deep models.
Args:
input_tensor: Original input (the residual)
sublayer_output: Output from a sublayer (attention, MLP, etc.)
Returns:
input_tensor + sublayer_output
"""
return input_tensor + sublayer_output
Next: In Step 09, you’ll combine multi-head attention, MLP, layer norm, and residual connections into a complete transformer block.
Step 09: Transformer block
Learn to combine attention, MLP, layer normalization, and residual connections into a complete transformer block.
Building the transformer block
In this step, you’ll build the GPT2Block class. This is a fundamental
repeating unit of GPT-2. Each block combines multi-head attention and a
feed-forward network, with layer normalization and residual connections around
each.
The block processes input through two sequential operations. First, it applies
layer norm, runs multi-head attention, then adds the result back to the input
(residual connection). Second, it applies another layer norm, runs the MLP, and
adds that result back. This pattern is x = x + sublayer(layer_norm(x)), called
pre-normalization.
GPT-2 uses pre-norm because it stabilizes training in deep networks. By normalizing before each sublayer instead of after, gradients flow more smoothly through the network’s 12 stacked blocks.
Understanding the components
The transformer block consists of four components, applied in this order:
First layer norm (ln_1): Normalizes the input before attention. Uses epsilon=1e-5 for numerical stability.
Multi-head attention (attn): The self-attention mechanism from Step 07. Lets each position attend to all previous positions.
Second layer norm (ln_2): Normalizes before the MLP. Same configuration as the first.
Feed-forward network (mlp): The position-wise MLP from Step 04. Expands to 3,072 dimensions internally (4× the embedding size), then projects back to 768.
The block maintains a constant 768-dimensional representation throughout. Input
shape [batch, seq_length, 768] stays the same after each sublayer, which is
essential for stacking 12 blocks together.
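To make those sizes concrete, here's a back-of-the-envelope parameter count for a single block with n_embd = 768 (a rough sketch; the exact total depends on bias terms):
n_embd = 768
inner_dim = 4 * n_embd                           # 3,072

attn_qkv = n_embd * 3 * n_embd + 3 * n_embd      # c_attn weights + bias
attn_proj = n_embd * n_embd + n_embd             # attention output projection
mlp_fc = n_embd * inner_dim + inner_dim          # expand to 3,072
mlp_proj = inner_dim * n_embd + n_embd           # project back to 768
layer_norms = 2 * 2 * n_embd                     # two LayerNorms, weight + bias each

per_block = attn_qkv + attn_proj + mlp_fc + mlp_proj + layer_norms
print(per_block)                                 # ~7.1M parameters; 12 blocks ~85M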
Understanding the flow
Each sublayer follows the pre-norm pattern:
- Save the input as residual
- Apply layer normalization to the input
- Process through the sublayer (attention or MLP)
- Add the original residual back to the output
This happens twice per block, once for attention and once for the MLP. The residual connections let gradients flow directly through the network, preventing vanishing gradients in deep models.
Component names (ln_1, attn, ln_2, mlp) match Hugging Face’s GPT-2
implementation. This matters for loading pretrained weights in later steps.
Implementing the block
You’ll create the GPT2Block class by composing the components from earlier
steps. The block takes GPT2Config and creates four sublayers, then applies
them in sequence with residual connections.
First, import the required modules. You’ll need Module from MAX, plus the
previously implemented components: GPT2Config, GPT2MLP,
GPT2MultiHeadAttention, and LayerNorm.
In the __init__ method, create the four sublayers:
- ln_1: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
- attn: GPT2MultiHeadAttention(config)
- ln_2: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
- mlp: GPT2MLP(4 * config.n_embd, config)
The MLP uses 4 * config.n_embd (3,072 dimensions) as its inner dimension, following the standard transformer ratio.
In the forward method, implement the two sublayer blocks:
Attention block:
- Save residual = hidden_states
- Normalize: hidden_states = self.ln_1(hidden_states)
- Apply attention: attn_output = self.attn(hidden_states)
- Add back: hidden_states = attn_output + residual
MLP block:
- Save residual = hidden_states
- Normalize: hidden_states = self.ln_2(hidden_states)
- Apply MLP: feed_forward_hidden_states = self.mlp(hidden_states)
- Add back: hidden_states = residual + feed_forward_hidden_states
Finally, return hidden_states.
Implementation (step_09.py):
"""
Step 09: Transformer Block
Combine multi-head attention, MLP, layer normalization, and residual
connections into a complete transformer block.
Tasks:
1. Import Module and all previous solution components
2. Create ln_1, attn, ln_2, and mlp layers
3. Implement forward pass with pre-norm residual pattern
Run: pixi run s09
"""
# TODO: Import required modules
# Hint: You'll need Module from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import GPT2MLP from solutions.solution_04
# Hint: Import GPT2MultiHeadAttention from solutions.solution_07
# Hint: Import LayerNorm from solutions.solution_08
class GPT2Block(Module):
"""Complete GPT-2 transformer block."""
def __init__(self, config: GPT2Config):
"""Initialize transformer block.
Args:
config: GPT2Config containing model hyperparameters
"""
super().__init__()
hidden_size = config.n_embd
inner_dim = (
config.n_inner
if hasattr(config, "n_inner") and config.n_inner is not None
else 4 * hidden_size
)
# TODO: Create first layer norm (before attention)
# Hint: Use LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_1 = None
# TODO: Create multi-head attention
# Hint: Use GPT2MultiHeadAttention(config)
self.attn = None
# TODO: Create second layer norm (before MLP)
# Hint: Use LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_2 = None
# TODO: Create MLP
# Hint: Use GPT2MLP(inner_dim, config)
self.mlp = None
def __call__(self, hidden_states):
"""Apply transformer block.
Args:
hidden_states: Input tensor, shape [batch, seq_length, n_embd]
Returns:
Output tensor, shape [batch, seq_length, n_embd]
"""
# TODO: Attention block with residual connection
# Hint: residual = hidden_states
# Hint: hidden_states = self.ln_1(hidden_states)
# Hint: attn_output = self.attn(hidden_states)
# Hint: hidden_states = attn_output + residual
pass
# TODO: MLP block with residual connection
# Hint: residual = hidden_states
# Hint: hidden_states = self.ln_2(hidden_states)
# Hint: feed_forward_hidden_states = self.mlp(hidden_states)
# Hint: hidden_states = residual + feed_forward_hidden_states
pass
# TODO: Return the output
return None
Validation
Run pixi run s09 to verify your implementation.
Show solution
"""
Solution for Step 09: Transformer Block
This module implements a complete GPT-2 transformer block, combining
multi-head attention, MLP, layer normalization, and residual connections.
"""
from max.nn.module_v3 import Module
from solutions.solution_01 import GPT2Config
from solutions.solution_04 import GPT2MLP
from solutions.solution_07 import GPT2MultiHeadAttention
from solutions.solution_08 import LayerNorm
class GPT2Block(Module):
"""Complete GPT-2 transformer block matching HuggingFace structure.
Architecture (pre-norm):
1. x = x + attention(layer_norm(x))
2. x = x + mlp(layer_norm(x))
"""
def __init__(self, config: GPT2Config):
"""Initialize transformer block.
Args:
config: GPT2Config containing model hyperparameters
"""
super().__init__()
hidden_size = config.n_embd
# Inner dimension for MLP (4x hidden size by default)
inner_dim = (
config.n_inner
if hasattr(config, "n_inner") and config.n_inner is not None
else 4 * hidden_size
)
# First layer norm (before attention)
self.ln_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
# Multi-head attention
self.attn = GPT2MultiHeadAttention(config)
# Second layer norm (before MLP)
self.ln_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
# Feed-forward MLP
self.mlp = GPT2MLP(inner_dim, config)
def __call__(self, hidden_states):
"""Apply transformer block.
Args:
hidden_states: Input tensor, shape [batch, seq_length, n_embd]
Returns:
Output tensor, shape [batch, seq_length, n_embd]
"""
# Attention block with residual connection
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(hidden_states)
hidden_states = attn_output + residual
# MLP block with residual connection
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = residual + feed_forward_hidden_states
return hidden_states
Next: In Step 10, you’ll stack 12 transformer blocks together to create the complete GPT-2 model architecture.
Step 10: Stacking transformer blocks
Learn to stack 12 transformer blocks with embeddings and final normalization to create the complete GPT-2 model.
Building the complete model
In this step, you’ll create the GPT2Model class - the complete transformer that takes token IDs as input and outputs contextualized representations. This class combines embeddings, 12 stacked transformer blocks, and final layer normalization.
The model processes input in four stages: convert token IDs to embeddings, add position information, pass through 12 transformer blocks sequentially, and normalize the final output. Each transformer block refines the representation, building up from surface patterns in early layers to semantic understanding in later layers.
GPT-2 uses 12 layers because this depth allows the model to learn complex patterns while remaining trainable. Fewer layers would limit the model’s capacity. More layers would increase training difficulty without proportional gains in quality for a 117M parameter model.
Understanding the components
The complete model has four main components:
Token embeddings (wte): Maps each token ID to a 768-dimensional vector using a lookup table with 50,257 entries (one per vocabulary token).
Position embeddings (wpe): Maps each position (0 to 1,023) to a 768-dimensional vector. These are added to token embeddings so the model knows token order.
Transformer blocks (h): 12 identical blocks stacked using MAX’s Sequential module. Sequential applies blocks in order, passing each block’s output to the next.
Final layer norm (ln_f): Normalizes the output after all blocks. This stabilizes the representation before the language model head (added in Step 11) projects to vocabulary logits.
Understanding the forward pass
The forward method processes token IDs through the model:
First, create position indices using Tensor.arange. Generate positions [0, 1, 2, …, seq_length-1] matching the input’s dtype and device. This ensures compatibility when adding to embeddings.
Next, look up embeddings. Get token embeddings with self.wte(input_ids) and position embeddings with self.wpe(position_indices). Add them together element-wise: the token embeddings have shape [batch, seq_length, 768], and the position embeddings broadcast across the batch dimension.
Then, pass through the transformer blocks with self.h(x). Sequential applies all 12 blocks in order, each refining the representation.
Finally, normalize the output with self.ln_f(x) and return the result. The output shape matches the input: [batch, seq_length, 768].
You’ll use the following MAX operations to complete this task:
Module composition:
Sequential(*modules): Chains transformer blocks in sequence
Embeddings:
Embedding(num_embeddings, dim): Token and position embeddings
Position generation:
Tensor.arange(seq_length, dtype, device): Creates position indices
Implementing the model
You’ll create the GPT2Model class by composing embedding layers, transformer blocks, and layer normalization. The class builds on all the components from previous steps.
First, import the required modules. You’ll need Tensor for position indices, Embedding, Module, and Sequential from MAX’s neural network module, plus the previously implemented GPT2Config, LayerNorm, and GPT2Block.
In the __init__ method, create the four components:
- Token embeddings: Embedding(config.vocab_size, dim=config.n_embd) stored as self.wte
- Position embeddings: Embedding(config.n_positions, dim=config.n_embd) stored as self.wpe
- Transformer blocks: Sequential(*(GPT2Block(config) for _ in range(config.n_layer))) stored as self.h
- Final layer norm: LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) stored as self.ln_f
The Sequential module takes a generator expression that creates 12 identical GPT2Block instances. The * unpacks them as arguments to Sequential.
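The * unpacking is ordinary Python, not anything MAX-specific. A tiny sketch (with a stand-in function) shows what it does:
def sequential(*layers):                    # stand-in for Sequential's signature
    return list(layers)

blocks = sequential(*(f"block_{i}" for i in range(12)))
# Equivalent to sequential("block_0", "block_1", ..., "block_11")
print(len(blocks))                          # 12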
In the forward method, implement the four-stage processing:
- Get the sequence length from input_ids.shape
- Create position indices: Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
- Look up embeddings and add them: x = self.wte(input_ids) + self.wpe(position_indices)
- Apply transformer blocks: x = self.h(x)
- Apply final normalization: x = self.ln_f(x)
- Return x
The position indices must match the input’s dtype and device to ensure the tensors are compatible for addition.
Implementation (step_10.py):
"""
Step 10: Stacking Transformer Blocks
Stack multiple transformer blocks with embeddings to create
the complete GPT-2 model architecture.
Tasks:
1. Import Tensor, Embedding, Module, Sequential, and previous components
2. Create token and position embeddings
3. Stack n_layer transformer blocks using Sequential
4. Create final layer normalization
5. Implement forward pass: embeddings -> blocks -> layer norm
Run: pixi run s10
"""
# TODO: Import required modules
# Hint: You'll need Tensor from max.experimental.tensor
# Hint: You'll need Embedding, Module, Sequential from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import LayerNorm from solutions.solution_08
# Hint: Import GPT2Block from solutions.solution_09
class GPT2Model(Module):
"""Complete GPT-2 transformer model."""
def __init__(self, config: GPT2Config):
"""Initialize GPT-2 model.
Args:
config: GPT2Config containing model hyperparameters
"""
super().__init__()
# TODO: Create token embeddings
# Hint: Use Embedding(config.vocab_size, dim=config.n_embd)
self.wte = None
# TODO: Create position embeddings
# Hint: Use Embedding(config.n_positions, dim=config.n_embd)
self.wpe = None
# TODO: Stack transformer blocks
# Hint: Use Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))
# This creates config.n_layer blocks (12 for GPT-2 base)
self.h = None
# TODO: Create final layer normalization
# Hint: Use LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.ln_f = None
def __call__(self, input_ids):
"""Forward pass through the transformer.
Args:
input_ids: Token IDs, shape [batch, seq_length]
Returns:
Hidden states, shape [batch, seq_length, n_embd]
"""
# TODO: Get batch size and sequence length
# Hint: batch_size, seq_length = input_ids.shape
pass
# TODO: Get token embeddings
# Hint: tok_embeds = self.wte(input_ids)
pass
# TODO: Get position embeddings
# Hint: Create position indices with Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
# Hint: pos_embeds = self.wpe(position_indices)
pass
# TODO: Combine embeddings
# Hint: x = tok_embeds + pos_embeds
pass
# TODO: Apply transformer blocks
# Hint: x = self.h(x)
pass
# TODO: Apply final layer norm
# Hint: x = self.ln_f(x)
pass
# TODO: Return the output
return None
Validation
Run pixi run s10 to verify your implementation.
Show solution
"""
Solution for Step 10: Stacking Transformer Blocks
This module stacks multiple transformer blocks and adds embeddings
to create the complete GPT-2 transformer architecture.
"""
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding, Module, Sequential
from solutions.solution_01 import GPT2Config
from solutions.solution_08 import LayerNorm
from solutions.solution_09 import GPT2Block
class GPT2Model(Module):
"""Complete GPT-2 transformer model matching HuggingFace structure.
Architecture:
1. Token embeddings + position embeddings
2. Stack of n_layer transformer blocks
3. Final layer normalization
"""
def __init__(self, config: GPT2Config):
"""Initialize GPT-2 model.
Args:
config: GPT2Config containing model hyperparameters
"""
super().__init__()
# Token embeddings (vocabulary -> embeddings)
self.wte = Embedding(config.vocab_size, dim=config.n_embd)
# Position embeddings (positions -> embeddings)
self.wpe = Embedding(config.n_positions, dim=config.n_embd)
# Stack of transformer blocks
self.h = Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))
# Final layer normalization
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
def __call__(self, input_ids):
"""Forward pass through the transformer.
Args:
input_ids: Token IDs, shape [batch, seq_length]
Returns:
Hidden states, shape [batch, seq_length, n_embd]
"""
batch_size, seq_length = input_ids.shape
# Get token embeddings
tok_embeds = self.wte(input_ids)
# Get position embeddings
pos_embeds = self.wpe(
Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
)
# Combine embeddings
x = tok_embeds + pos_embeds
# Apply transformer blocks
x = self.h(x)
# Final layer norm
x = self.ln_f(x)
return x
Next: In Step 11, you’ll add the language modeling head that projects hidden states to vocabulary logits for text generation.
Step 11: Language model head
Learn to add the final linear projection layer that converts hidden states to vocabulary logits for next-token prediction.
Adding the language model head
In this step, you’ll create the MaxGPT2LMHeadModel - the complete language
model that can predict next tokens. This class wraps the transformer from Step
10 and adds a final linear layer that projects 768-dimensional hidden states to
50,257-dimensional vocabulary logits.
The language model head is a single linear layer without bias. For each position in the sequence, it outputs a score for every possible next token. Higher scores indicate the model thinks that token is more likely to come next.
At 768 × 50,257 ≈ 38.6M parameters, the LM head matches the token embedding as the largest single weight matrix in GPT-2, representing about 33% of the model’s 117M total parameters. It is several times larger than any individual transformer block, each of which holds roughly 7M parameters.
Understanding the projection
The language model head performs a simple linear projection using MAX’s
Linear
layer. It maps each 768-dimensional hidden state to 50,257 scores, one per
vocabulary token.
The layer uses bias=False, meaning it only has weights and no bias vector.
This matches the original GPT-2, where the output projection shares its
weights with the token embedding matrix, which has no bias term. Omitting the
bias also saves 50,257 parameters (about 0.04% of the model) and in practice
makes little difference to the predicted distribution.
The output is called “logits,” which are raw scores before applying softmax. Logits can be any real number. During text generation (Step 12), you’ll convert logits to probabilities with softmax. Working with logits directly enables techniques like temperature scaling and top-k sampling.
Understanding the complete model
With the LM head added, you now have the complete GPT-2 architecture:
- Input: Token IDs [batch, seq_length]
- Embeddings: Token + position [batch, seq_length, 768]
- Transformer blocks: 12 blocks process the embeddings [batch, seq_length, 768]
- Final layer norm: Normalizes the output [batch, seq_length, 768]
- LM head: Projects to vocabulary [batch, seq_length, 50257]
- Output: Logits [batch, seq_length, 50257]
Each position gets independent logits over the vocabulary. To predict the next token after position i, you look at the logits at position i. The highest scoring token is the model’s top prediction.
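Here's a tiny NumPy sketch (with random numbers standing in for real logits) of how you read off the top prediction for the token that follows the final position:
import numpy as np

batch, seq_length, vocab_size = 1, 4, 50257
logits = np.random.randn(batch, seq_length, vocab_size)

next_token_logits = logits[0, -1, :]               # scores for the token after the final position
next_token_id = int(np.argmax(next_token_logits))  # the model's top prediction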
You’ll use the following MAX operations to complete this task:
Linear layer:
Linear(in_features, out_features, bias=False): Projects hidden states to vocabulary logits
Implementing the language model
You’ll create the MaxGPT2LMHeadModel class that wraps the transformer with a
language modeling head. The implementation is straightforward, with just two
components and a simple forward pass.
First, import the required modules. You’ll need Linear and Module from MAX,
plus the previously implemented GPT2Config and GPT2Model.
In the __init__ method, create two components:
- Transformer: GPT2Model(config) stored as self.transformer
- LM head: Linear(config.n_embd, config.vocab_size, bias=False) stored as self.lm_head
Note the bias=False parameter, which creates a linear layer without bias
terms.
In the forward method, implement a simple two-step process:
- Get hidden states from the transformer: hidden_states = self.transformer(input_ids)
- Project to vocabulary logits: logits = self.lm_head(hidden_states)
- Return logits
That’s it. The model takes token IDs and returns logits. In the next step, you’ll use these logits to generate text.
Implementation (step_11.py):
"""
Step 11: Language Model Head
Add the final projection layer that converts hidden states to vocabulary logits.
Tasks:
1. Import Linear, Module, and previous components
2. Create transformer and lm_head layers
3. Implement forward pass: transformer -> lm_head
Run: pixi run s11
"""
# TODO: Import required modules
# Hint: You'll need Linear and Module from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import GPT2Model from solutions.solution_10
class MaxGPT2LMHeadModel(Module):
"""Complete GPT-2 model with language modeling head."""
def __init__(self, config: GPT2Config):
"""Initialize GPT-2 with LM head.
Args:
config: GPT2Config containing model hyperparameters
"""
super().__init__()
self.config = config
# TODO: Create the transformer
# Hint: Use GPT2Model(config)
self.transformer = None
# TODO: Create language modeling head
# Hint: Use Linear(config.n_embd, config.vocab_size, bias=False)
# Projects from hidden dimension to vocabulary size
self.lm_head = None
def __call__(self, input_ids):
"""Forward pass through transformer and LM head.
Args:
input_ids: Token IDs, shape [batch, seq_length]
Returns:
Logits over vocabulary, shape [batch, seq_length, vocab_size]
"""
# TODO: Get hidden states from transformer
# Hint: hidden_states = self.transformer(input_ids)
pass
# TODO: Project to vocabulary logits
# Hint: logits = self.lm_head(hidden_states)
pass
# TODO: Return logits
return None
Validation
Run pixi run s11 to verify your implementation.
Show solution
"""
Solution for Step 11: Language Model Head
This module adds the final projection layer that converts hidden states
to vocabulary logits for predicting the next token.
"""
from max.nn.module_v3 import Linear, Module
from solutions.solution_01 import GPT2Config
from solutions.solution_10 import GPT2Model
class MaxGPT2LMHeadModel(Module):
"""Complete GPT-2 model with language modeling head.
This is the full model that can be used for text generation.
"""
def __init__(self, config: GPT2Config):
"""Initialize GPT-2 with LM head.
Args:
config: GPT2Config containing model hyperparameters
"""
super().__init__()
self.config = config
# The transformer (embeddings + blocks + final norm)
self.transformer = GPT2Model(config)
# Language modeling head (hidden states -> vocabulary logits)
self.lm_head = Linear(config.n_embd, config.vocab_size, bias=False)
def __call__(self, input_ids):
"""Forward pass through transformer and LM head.
Args:
input_ids: Token IDs, shape [batch, seq_length]
Returns:
Logits over vocabulary, shape [batch, seq_length, vocab_size]
"""
# Get hidden states from transformer
hidden_states = self.transformer(input_ids)
# Project to vocabulary logits
logits = self.lm_head(hidden_states)
return logits
Next: In Step 12, you’ll implement text generation using sampling and temperature control to generate coherent text autoregressively.
Step 12: Text generation
Learn to implement autoregressive text generation with sampling and temperature control.
Generating text
In this final step, you’ll implement the generation loop that produces text one token at a time. The model predicts the next token, appends it to the sequence, and repeats until reaching the desired length.
Start with a prompt like “Hello world” (tokens [15496, 995]). The model predicts the next token, giving you [15496, 995, 318] (“Hello world is”). It predicts again, producing [15496, 995, 318, 257] (“Hello world is a”). This process continues, with each prediction feeding back as input for the next.
You’ll implement two generation strategies: greedy decoding (always pick the highest-scoring token) and sampling (randomly choose according to probabilities). You’ll also add temperature control to adjust how random or focused the generation is.
Understanding the generation loop
The generation loop is simple: run the model, extract the next token prediction, append it to the sequence, repeat. Each iteration requires a full forward pass through all 12 transformer blocks.
The model outputs logits with shape [batch, seq_length, vocab_size]. Since you only care about predicting the next token, extract the last position: logits[0, -1, :]. This gives you a vector of 50,257 scores, one per vocabulary token.
These scores are logits (unnormalized), not probabilities. To convert them to probabilities, apply softmax. Then you can either pick the highest-probability token (greedy) or sample from the distribution (random).
Understanding temperature control
Temperature scaling adjusts how random the generation is using the formula scaled_logits = logits / temperature.
With temperature 1.0, you use the original distribution. With temperature 0.7, you sharpen the distribution, and high-probability tokens become even more likely, making generation more focused and deterministic. With temperature 1.2, you flatten the distribution, and lower-probability tokens get more chances, making generation more diverse and creative.
Temperature is applied before softmax. Dividing by a value less than 1 makes large logits even larger (sharpening), while dividing by a value greater than 1 reduces the differences between logits (flattening).
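A quick NumPy experiment (illustrative only, with a three-token vocabulary) makes the sharpening and flattening visible:
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])

print(softmax(logits))        # ~[0.66, 0.24, 0.10]  temperature 1.0: original distribution
print(softmax(logits / 0.7))  # ~[0.77, 0.18, 0.05]  temperature 0.7: sharper, more focused
print(softmax(logits / 1.2))  # ~[0.61, 0.27, 0.13]  temperature 1.2: flatter, more diverse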
Understanding sampling vs greedy
Greedy decoding always picks the highest-probability token using F.argmax. It’s fast, deterministic, and simple, but often produces repetitive text because the model keeps choosing the safest option.
Sampling randomly selects tokens according to their probabilities. Convert logits to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, then sample with np.random.choice. You use NumPy because MAX doesn’t have built-in sampling yet.
Most practical generation uses sampling with temperature control. This balances creativity with coherence, as the model can explore different possibilities while still favoring high-quality continuations.
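In plain NumPy (a sketch over a toy distribution), the two strategies differ only in how the final token ID is chosen:
import numpy as np

probs = np.array([0.66, 0.24, 0.10])       # toy next-token probabilities

greedy_pick = int(np.argmax(probs))                        # always index 0: deterministic
sampled_pick = int(np.random.choice(len(probs), p=probs))  # usually 0, sometimes 1 or 2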
You’ll use the following MAX operations to complete this task:
Probability operations:
F.softmax(logits): Converts logits to probabilities
F.argmax(logits): Selects highest-probability token (greedy)
Sequence building:
F.concat([seq, new_token], axis=1): Appends token to sequence
Tensor.constant(value, dtype, device): Creates scalar tensors
NumPy interop:
probs.to(CPU()): Transfers tensor to CPU
np.from_dlpack(probs): Converts MAX tensor to NumPy for sampling
Implementing text generation
You’ll create two functions: generate_next_token that predicts a single token, and generate_tokens that loops to produce full sequences.
First, import the required modules. You’ll need numpy for sampling, CPU from MAX’s driver, DType for type constants, functional as F for operations like softmax and argmax, and Tensor for creating tensors.
In generate_next_token, implement the prediction logic:
- Run the model to get logits: logits = model(input_ids)
- Extract the last position (next token prediction): next_token_logits = logits[0, -1, :]
- If using temperature, scale the logits by dividing by the temperature tensor
- For sampling: convert to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, sample with np.random.choice, then convert back to a MAX tensor
- For greedy: use F.argmax to select the highest-scoring token
The temperature must be a tensor with the same dtype and device as the logits. Create it with Tensor.constant(temperature, dtype=..., device=...).
In generate_tokens, implement the generation loop:
- Initialize with the input: generated_tokens = input_ids
- Loop max_new_tokens times
- Generate the next token: next_token = generate_next_token(model, generated_tokens, ...)
- Reshape to 2D: next_token_2d = next_token.reshape([1, -1])
- Concatenate to the sequence: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
- Return the complete sequence
The reshape is necessary because concat requires matching dimensions, and the generated token is 0D (scalar).
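In NumPy terms (a shape-only sketch of what the MAX reshape and concat do), using the "Hello world" tokens from earlier:
import numpy as np

generated_tokens = np.array([[15496, 995]])   # shape [1, 2]
next_token = np.array(318)                    # 0-D scalar prediction

next_token_2d = next_token.reshape([1, -1])   # shape [1, 1]
generated_tokens = np.concatenate([generated_tokens, next_token_2d], axis=1)  # shape [1, 3]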
Implementation (step_12.py):
"""
Step 12: Text Generation
Implement autoregressive text generation with sampling and temperature control.
Tasks:
1. Import required modules (numpy, F, Tensor, etc.)
2. Implement generate_next_token: get logits, apply temperature, sample/argmax
3. Implement generate_tokens: loop to generate multiple tokens
Run: pixi run s12
"""
# TODO: Import required modules
# Hint: You'll need numpy as np
# Hint: You'll need CPU from max.driver
# Hint: You'll need DType from max.dtype
# Hint: You'll need functional as F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor
def generate_next_token(model, input_ids, temperature=1.0, do_sample=True):
"""Generate the next token given input context.
Args:
model: GPT-2 model with LM head
input_ids: Current sequence, shape [batch, seq_length]
temperature: Sampling temperature (higher = more random)
do_sample: If True, sample from distribution; if False, use greedy (argmax)
Returns:
Next token ID as a Tensor
"""
# TODO: Get logits from model
# Hint: logits = model(input_ids)
pass
# TODO: Get logits for last position
# Hint: next_token_logits = logits[0, -1, :]
pass
# TODO: If sampling with temperature
if do_sample and temperature > 0:
# TODO: Apply temperature scaling
# Hint: temp_tensor = Tensor.constant(temperature, dtype=next_token_logits.dtype, device=next_token_logits.device)
# Hint: next_token_logits = next_token_logits / temp_tensor
pass
# TODO: Convert to probabilities
# Hint: probs = F.softmax(next_token_logits)
pass
# TODO: Sample from distribution
# Hint: probs_np = np.from_dlpack(probs.to(CPU()))
# Hint: next_token_id = np.random.choice(len(probs_np), p=probs_np)
# Hint: next_token_tensor = Tensor.constant(next_token_id, dtype=DType.int64, device=input_ids.device)
pass
else:
# TODO: Greedy decoding (select most likely token)
# Hint: next_token_tensor = F.argmax(next_token_logits)
pass
# TODO: Return the next token
return None
def generate_tokens(
model, input_ids, max_new_tokens=10, temperature=1.0, do_sample=True
):
"""Generate multiple tokens autoregressively.
Args:
model: GPT-2 model with LM head
input_ids: Initial sequence, shape [batch, seq_length]
max_new_tokens: Number of tokens to generate
temperature: Sampling temperature
do_sample: Whether to sample or use greedy decoding
Returns:
Generated sequence including input, shape [batch, seq_length + max_new_tokens]
"""
# TODO: Initialize generated tokens with input
# Hint: generated_tokens = input_ids
pass
# TODO: Generation loop
# Hint: for _ in range(max_new_tokens):
pass
# TODO: Generate next token
# Hint: next_token = generate_next_token(model, generated_tokens, temperature=temperature, do_sample=do_sample)
pass
# TODO: Reshape to [1, 1] for concatenation
# Hint: next_token_2d = next_token.reshape([1, -1])
pass
# TODO: Append to sequence
# Hint: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
pass
# TODO: Return generated sequence
return None
Validation
Run pixi run s12 to verify your implementation.
Show solution
"""
Solution for Step 12: Text Generation
This module implements autoregressive text generation using the GPT-2 model.
"""
import numpy as np
from max.driver import CPU
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
def generate_next_token(model, input_ids, temperature=1.0, do_sample=True):
"""Generate the next token given input context.
Args:
model: GPT-2 model with LM head
input_ids: Current sequence, shape [batch, seq_length]
temperature: Sampling temperature (higher = more random)
do_sample: If True, sample from distribution; if False, use greedy (argmax)
Returns:
Next token ID as a Tensor
"""
# Get logits from model
logits = model(input_ids)
# Get logits for last position (next token prediction)
next_token_logits = logits[0, -1, :] # Shape: [vocab_size]
if do_sample and temperature > 0:
# Apply temperature scaling
temp_tensor = Tensor.constant(
temperature, dtype=next_token_logits.dtype, device=next_token_logits.device
)
next_token_logits = next_token_logits / temp_tensor
# Convert to probabilities
probs = F.softmax(next_token_logits)
# Sample from distribution
probs_np = np.from_dlpack(probs.to(CPU()))
next_token_id = np.random.choice(len(probs_np), p=probs_np)
next_token_tensor = Tensor.constant(
next_token_id, dtype=DType.int64, device=input_ids.device
)
else:
# Greedy decoding: select most likely token
next_token_tensor = F.argmax(next_token_logits)
return next_token_tensor
def generate_tokens(
model, input_ids, max_new_tokens=10, temperature=1.0, do_sample=True
):
"""Generate multiple tokens autoregressively.
Args:
model: GPT-2 model with LM head
input_ids: Initial sequence, shape [batch, seq_length]
max_new_tokens: Number of tokens to generate
temperature: Sampling temperature
do_sample: Whether to sample or use greedy decoding
Returns:
Generated sequence including input, shape [batch, seq_length + max_new_tokens]
"""
generated_tokens = input_ids
for _ in range(max_new_tokens):
# Generate next token
next_token = generate_next_token(
model, generated_tokens, temperature=temperature, do_sample=do_sample
)
# Reshape to [1, 1] for concatenation
next_token_2d = next_token.reshape([1, -1])
# Append to sequence
generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
return generated_tokens
What you’ve built
You’ve completed all 12 steps and built a complete GPT-2 model from scratch using MAX. You now have a working implementation of:
Core components:
- Model configuration and architecture definition
- Causal masking for autoregressive generation
- Layer normalization for training stability
- Feed-forward networks with GELU activation
- Token and position embeddings
- Multi-head self-attention
- Residual connections and transformer blocks
- Language model head for next-token prediction
- Text generation with temperature and sampling
Your model loads OpenAI’s pretrained GPT-2 weights and generates text. You understand how every component works, from the low-level tensor operations to the high-level architecture decisions.
What’s next?
You now understand the architectural foundation that powers modern language models. LLaMA, Mistral, and more build on these same components with incremental refinements. You have everything you need to implement those refinements yourself.
Consider extending your implementation with:
- Grouped-query attention (GQA): Reduce memory consumption by sharing key-value pairs across multiple query heads, as used in LLaMA 2.
- Rotary position embeddings (RoPE): Replace learned position embeddings with rotation-based encoding, improving length extrapolation in models like LLaMA and GPT-NeoX.
- SwiGLU activation: Swap GELU for the gated linear unit variant used in LLaMA and PaLM.
- Mixture of experts (MoE): Add sparse expert routing to scale model capacity efficiently, as in Mixtral and GPT-4.
Each refinement builds directly on what you’ve implemented. The attention mechanism you wrote becomes grouped-query attention with a simple modification to how you reshape key-value tensors. Your position embeddings can be replaced with RoPE by changing how you encode positional information. The feed-forward network you built becomes SwiGLU by adding a gating mechanism.
Pick an architecture that interests you and start building. You’ll find the patterns are familiar because the fundamentals haven’t changed.