Stacking transformer blocks

Stack 12 transformer blocks with embeddings and final normalization to create the complete body of the GPT-2 model.

MaxGPT2Model is the body of GPT-2. It converts raw token IDs to embeddings, adds position information, passes through all 12 transformer blocks, and normalizes the final output.

The model processes input in four stages:

  1. Token embeddings: Convert each token ID to a 768-dimensional vector via a learned lookup table with 50,257 entries.
  2. Position embeddings: Add a learned position vector for each token’s position (0 to 1,023). These are added element-wise to the token embeddings so the model knows token order.
  3. Transformer blocks: Pass through 12 identical GPT2Block layers sequentially. Each block refines the representation.
  4. Final layer norm: Normalize the output before the language model head.
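The four stages above can be sketched in NumPy with toy sizes (the class names, token IDs, and dimensions here are placeholders, not the real implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy sizes; real GPT-2 uses vocab=50_257, n_positions=1_024, n_embd=768.
vocab, n_pos, dim = 1_000, 16, 8

wte = rng.standard_normal((vocab, dim))   # token embedding table
wpe = rng.standard_normal((n_pos, dim))   # position embedding table

input_ids = np.array([[464, 379, 318, 13],
                      [40, 588, 262, 0]])     # (batch=2, seq=4)

tok = wte[input_ids]                          # stage 1: (2, 4, 8)
pos = wpe[np.arange(input_ids.shape[1])]      # stage 2: (4, 8)
x = tok + pos                                 # broadcasts to (2, 4, 8)

# Stage 3 would apply the 12 transformer blocks here (omitted).

# Stage 4: layer norm over the feature axis.
mean = x.mean(-1, keepdims=True)
var = x.var(-1, keepdims=True)
x = (x - mean) / np.sqrt(var + 1e-5)
print(x.shape)  # (2, 4, 8)
```

After stage 4, each position's feature vector has zero mean and (approximately) unit variance, ready for the language model head.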

Why 12 layers?

The smallest GPT-2 variant (124M parameters) uses 12 layers because this depth allows complex pattern learning while remaining trainable. Early layers tend to capture surface-level patterns like word shapes and punctuation; later layers capture higher-level semantic patterns. Because each block adds its output to the residual stream rather than replacing it, representations from all layers contribute to the final output.

Key APIs

Sequential chains the 12 transformer blocks in order, passing each block’s output to the next. The * in Sequential(*(GPT2Block(config) for _ in range(config.n_layer))) unpacks the generator as positional arguments.
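The unpacking behavior can be demonstrated with plain-Python stand-ins (Block and sequential here are hypothetical, not the real GPT2Block and Sequential):

```python
class Block:
    """Stand-in for GPT2Block; just records its index."""
    def __init__(self, i: int) -> None:
        self.i = i

def sequential(*layers):
    """Stand-in for Sequential: collects positional arguments."""
    return list(layers)

# Without *, sequential would receive one argument: the generator object.
# With *, the generator is unpacked into 12 positional Block arguments.
blocks = sequential(*(Block(i) for i in range(12)))
print(len(blocks))  # 12
```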

Tensor.arange generates position indices [0, 1, ..., seq_length-1] matching the input’s dtype and device so they’re compatible for embedding lookup.
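A NumPy analogue makes the dtype matching concrete (the token IDs are arbitrary placeholders):

```python
import numpy as np

# Hypothetical batch of token IDs, shape (batch=1, seq=3).
input_ids = np.array([[464, 379, 13]], dtype=np.int64)
_, seq_length = input_ids.shape

# Analogue of Tensor.arange: positions 0..seq_length-1,
# created with the same dtype as the input IDs.
positions = np.arange(seq_length, dtype=input_ids.dtype)
print(positions)        # [0 1 2]
print(positions.dtype)  # int64
```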

Embedding(vocab_size, dim) is used for both token and position embeddings.
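Conceptually, an embedding layer is a gather from a learned table rather than a matrix multiply. A minimal NumPy sketch (toy sizes, randomly initialized table):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, dim = 10, 4

# The table would be learned during training; random here.
table = rng.standard_normal((vocab_size, dim))

ids = np.array([3, 3, 7])
vectors = table[ids]  # shape (3, 4): rows 3, 3, and 7 of the table

# The same ID always maps to the same vector.
assert np.array_equal(vectors[0], vectors[1])
```

The same lookup mechanism serves both tables: wte is indexed by token IDs, wpe by position indices.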

The code

class MaxGPT2Model(Module):  # type: ignore[type-arg]
    def __init__(
        self,
        config: GPT2Config,
    ) -> None:
        # Learned lookup tables: one row per token ID, one row per position.
        self.wte = Embedding(config.vocab_size, dim=config.n_embd)
        self.wpe = Embedding(config.n_positions, dim=config.n_embd)
        # The 12 identical transformer blocks, applied in sequence.
        self.h = Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))
        # Final normalization before the language model head.
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids: Tensor) -> Tensor:
        _, seq_length = input_ids.shape
        tok_embeds = self.wte(input_ids)  # (batch, seq, n_embd)
        pos_embeds = self.wpe(
            Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        )  # (seq, n_embd); broadcasts across the batch in the add below
        x = tok_embeds + pos_embeds
        x = self.h(x)
        x = self.ln_f(x)
        return x


The _ in _, seq_length = input_ids.shape discards the batch dimension; we only need the sequence length to generate position indices.
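Discarding the batch dimension works because of broadcasting: pos_embeds has shape (seq, dim) and is stretched across the batch axis when added to tok_embeds of shape (batch, seq, dim). A NumPy sketch with toy sizes:

```python
import numpy as np

batch, seq, dim = 2, 4, 3
tok_embeds = np.zeros((batch, seq, dim))
pos_embeds = np.arange(seq * dim, dtype=float).reshape(seq, dim)

# (batch, seq, dim) + (seq, dim): the position embeddings are
# broadcast across the batch axis.
x = tok_embeds + pos_embeds
print(x.shape)  # (2, 4, 3)
assert np.array_equal(x[0], x[1])  # every batch row gets the same positions
```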

Next: Section 8 adds the language model head that projects these 768-dimensional hidden states to vocabulary logits.