Stacking transformer blocks
Stack 12 transformer blocks with embeddings and final normalization to create the complete body of the GPT-2 model.
MaxGPT2Model is the body of GPT-2. It converts raw token IDs to embeddings,
adds position information, passes through all 12 transformer blocks, and
normalizes the final output.
The model processes input in four stages:
- Token embeddings: Convert each token ID to a 768-dimensional vector via a learned lookup table with 50,257 entries.
- Position embeddings: Add a learned position vector for each token’s position (0 to 1,023). These are added element-wise to the token embeddings so the model knows token order.
- Transformer blocks: Pass through 12 identical GPT2Block layers sequentially. Each block refines the representation.
- Final layer norm: Normalize the output before the language model head.
Why 12 layers?
GPT-2 uses 12 layers because this depth is deep enough to learn complex patterns while remaining stable and practical to train. Early layers tend to capture surface-level patterns like word shapes and punctuation; later layers capture higher-level semantic patterns. The representations from all layers contribute to the final output.
Key APIs
Sequential
chains the 12 transformer blocks in order, passing each block’s output to the
next. The * in
Sequential(*(GPT2Block(config) for _ in range(config.n_layer))) unpacks the
generator as positional arguments.
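A minimal sketch of what the `*` unpacking accomplishes, using a hypothetical `Chain` class as a stand-in for `Sequential` and simple `+1` functions as stand-ins for transformer blocks:

```python
class Chain:
    """Stand-in for Sequential: calls its layers in order."""

    def __init__(self, *layers):
        # The * in the signature collects each unpacked positional
        # argument into a tuple.
        self.layers = layers

    def __call__(self, x):
        # Pass each layer's output to the next, in order.
        for layer in self.layers:
            x = layer(x)
        return x


# Unpack a generator of 12 "blocks" as positional arguments,
# mirroring Sequential(*(GPT2Block(config) for _ in range(config.n_layer))).
chain = Chain(*(lambda x: x + 1 for _ in range(12)))
print(chain(0))  # 12 blocks, each adding 1, so the result is 12
```

Without the `*`, `Chain` would receive the generator object itself as a single argument rather than the 12 blocks it produces.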
Tensor.arange
generates position indices [0, 1, ..., seq_length-1] matching the input’s
dtype and device so they’re compatible for embedding lookup.
Embedding(vocab_size, dim)
is used for both token and position embeddings.
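An embedding is just a learned lookup table: row i is the vector for ID i. The two lookups can be sketched in NumPy (NumPy here is only an illustration of the indexing, not the framework's Tensor API; the token IDs and the small `n_embd` are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n_embd = 8  # small for illustration; GPT-2 uses 768

wte = rng.normal(size=(50_257, n_embd))  # token embedding table
wpe = rng.normal(size=(1_024, n_embd))   # position embedding table

input_ids = np.array([[15496, 995, 13]])         # (batch=1, seq=3)
tok_embeds = wte[input_ids]                      # row lookup -> (1, 3, 8)
pos_embeds = wpe[np.arange(input_ids.shape[1])]  # rows 0..2  -> (3, 8)
print(tok_embeds.shape, pos_embeds.shape)
```

The same indexing pattern underlies both `self.wte(input_ids)` and `self.wpe(Tensor.arange(seq_length, ...))`.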
The code
```python
class MaxGPT2Model(Module):  # type: ignore[type-arg]
    def __init__(
        self,
        config: GPT2Config,
    ) -> None:
        self.wte = Embedding(config.vocab_size, dim=config.n_embd)
        self.wpe = Embedding(config.n_positions, dim=config.n_embd)
        self.h = Sequential(*(GPT2Block(config) for _ in range(config.n_layer)))
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids: Tensor) -> Tensor:
        _, seq_length = input_ids.shape
        tok_embeds = self.wte(input_ids)
        pos_embeds = self.wpe(
            Tensor.arange(seq_length, dtype=input_ids.dtype, device=input_ids.device)
        )
        x = tok_embeds + pos_embeds
        x = self.h(x)
        x = self.ln_f(x)
        return x
```
The _ in _, seq_length = input_ids.shape discards the batch dimension—we
only need the sequence length to generate position indices.
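The addition works because pos_embeds has shape (seq, dim) while tok_embeds has shape (batch, seq, dim): the position vectors broadcast across the batch, so every sequence in the batch gets the same position information. A NumPy sketch of that broadcast (NumPy as a stand-in for the Tensor API, with small illustrative shapes):

```python
import numpy as np

batch, seq, dim = 2, 4, 8
tok_embeds = np.ones((batch, seq, dim))  # one vector per token per batch item
pos_embeds = np.ones((seq, dim))         # one vector per position, shared

x = tok_embeds + pos_embeds  # (seq, dim) broadcasts over the batch axis
print(x.shape)  # (2, 4, 8)
```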
Next: Section 8 adds the language model head that projects these 768-dimensional hidden states to vocabulary logits.