Step 11: Language model head

Learn to add the final linear projection layer that converts hidden states to vocabulary logits for next-token prediction.

Adding the language model head

In this step, you’ll create MaxGPT2LMHeadModel, the complete language model that can predict next tokens. This class wraps the transformer from Step 10 and adds a final linear layer that projects 768-dimensional hidden states to 50,257-dimensional vocabulary logits.

The language model head is a single linear layer without bias. For each position in the sequence, it outputs a score for every possible next token. Higher scores indicate the model thinks that token is more likely to come next.

At 768 × 50,257 = 38.6M parameters, the LM head is the largest single layer in GPT-2 (the token embedding matrix has the same shape and size), representing roughly a third of the model’s ~117M total parameters. It is several times larger than any individual transformer block, each of which holds roughly 7M parameters.
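You can sanity-check those figures with plain Python; the per-block estimate below ignores the small layer-norm and bias terms:

n_embd, vocab_size = 768, 50_257

# LM head: a single weight matrix, no bias
lm_head_params = n_embd * vocab_size
print(f"LM head: {lm_head_params:,}")              # 38,597,376 ≈ 38.6M

# One transformer block, approximately:
attn_params = 4 * n_embd * n_embd                  # Q, K, V, and output projections
mlp_params = 2 * n_embd * (4 * n_embd)             # up- and down-projections (4x expansion)
print(f"One block: {attn_params + mlp_params:,}")  # 7,077,888 ≈ 7.1M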

Understanding the projection

The language model head performs a simple linear projection using MAX’s Linear layer. It maps each 768-dimensional hidden state to 50,257 scores, one per vocabulary token.

The layer uses bias=False, meaning it has only a weight matrix and no bias vector. This matches the original GPT-2, which also omits the output bias, and saves 50,257 parameters (about 0.04% of the model). A bias here would add little: it could only contribute a fixed, input-independent offset to each vocabulary token’s score.

The outputs are called “logits”: raw scores before softmax is applied. A logit can be any real number. During text generation (Step 12), you’ll convert logits to probabilities with softmax. Working with logits directly enables techniques like temperature scaling and top-k sampling.
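To make the distinction concrete, here is a small NumPy sketch with made-up logits over a tiny five-token vocabulary:

import numpy as np

logits = np.array([2.0, 1.0, 0.5, -1.0, -3.0])  # raw scores, any real numbers

def softmax(x):
    x = x - x.max()          # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum()

probs = softmax(logits)      # probabilities: non-negative, summing to 1
print(probs.round(3))

# Temperature scaling divides the logits before softmax:
# T < 1 sharpens the distribution, T > 1 flattens it.
print(softmax(logits / 0.7).round(3))
print(softmax(logits / 1.5).round(3))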

Understanding the complete model

With the LM head added, you now have the complete GPT-2 architecture:

  1. Input: Token IDs [batch, seq_length]
  2. Embeddings: Token + position [batch, seq_length, 768]
  3. Transformer blocks: 12 blocks process the embeddings [batch, seq_length, 768]
  4. Final layer norm: Normalizes the output [batch, seq_length, 768]
  5. LM head: Projects to vocabulary [batch, seq_length, 50257]
  6. Output: Logits [batch, seq_length, 50257]

Each position gets independent logits over the vocabulary. To predict the next token after position i, you look at the logits at position i. The highest scoring token is the model’s top prediction.
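As a sketch (NumPy, with random stand-in logits; the real values come from the model), greedy next-token prediction looks like this:

import numpy as np

batch, seq_length, vocab_size = 1, 8, 50_257
logits = np.random.randn(batch, seq_length, vocab_size)  # stand-in for model output

# The logits at the last position score every candidate for the next token.
next_token_logits = logits[0, -1]                 # shape [vocab_size]
next_token_id = int(next_token_logits.argmax())   # greedy: highest-scoring token
print(next_token_id)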

MAX operations

You’ll use the following MAX operation to complete this task:

  • Linear: a dense projection layer from max.nn.module_v3. Constructed as Linear(in_dim, out_dim, bias=False), it multiplies the last dimension of its input by a learned weight matrix, with no bias term.
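Mathematically, a bias-free linear layer is just a matrix multiply applied along the last axis. A NumPy sketch of the equivalent computation (MAX’s Linear does this for you, with a learned weight matrix rather than a random stand-in):

import numpy as np

batch, seq_length, n_embd, vocab_size = 1, 4, 768, 50_257
hidden_states = np.random.randn(batch, seq_length, n_embd)
weight = np.random.randn(vocab_size, n_embd)   # stand-in for the LM head weight

logits = hidden_states @ weight.T              # no bias added
print(logits.shape)                            # (1, 4, 50257)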

Implementing the language model

You’ll create the MaxGPT2LMHeadModel class that wraps the transformer with a language modeling head. The implementation is straightforward, with just two components and a simple forward pass.

First, import the required modules. You’ll need Linear and Module from MAX, plus the previously implemented GPT2Config and GPT2Model.

In the __init__ method, create two components:

  • Transformer: GPT2Model(config) stored as self.transformer
  • LM head: Linear(config.n_embd, config.vocab_size, bias=False) stored as self.lm_head

Note the bias=False parameter, which creates a linear layer without bias terms.

In the forward method, the computation is just two steps:

  1. Get hidden states from the transformer: hidden_states = self.transformer(input_ids)
  2. Project to vocabulary logits: logits = self.lm_head(hidden_states)

Then return the logits.

That’s it. The model takes token IDs and returns logits. In the next step, you’ll use these logits to generate text.

Implementation (step_11.py):

"""
Step 11: Language Model Head

Add the final projection layer that converts hidden states to vocabulary logits.

Tasks:
1. Import Linear, Module, and previous components
2. Create transformer and lm_head layers
3. Implement forward pass: transformer -> lm_head

Run: pixi run s11
"""

# TODO: Import required modules
# Hint: You'll need Linear and Module from max.nn.module_v3
# Hint: Import GPT2Config from solutions.solution_01
# Hint: Import GPT2Model from solutions.solution_10


class MaxGPT2LMHeadModel(Module):
    """Complete GPT-2 model with language modeling head."""

    def __init__(self, config: GPT2Config):
        """Initialize GPT-2 with LM head.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        self.config = config

        # TODO: Create the transformer
        # Hint: Use GPT2Model(config)
        self.transformer = None

        # TODO: Create language modeling head
        # Hint: Use Linear(config.n_embd, config.vocab_size, bias=False)
        # Projects from hidden dimension to vocabulary size
        self.lm_head = None

    def __call__(self, input_ids):
        """Forward pass through transformer and LM head.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Logits over vocabulary, shape [batch, seq_length, vocab_size]
        """
        # TODO: Get hidden states from transformer
        # Hint: hidden_states = self.transformer(input_ids)
        pass

        # TODO: Project to vocabulary logits
        # Hint: logits = self.lm_head(hidden_states)
        pass

        # TODO: Return logits
        return None

Validation

Run pixi run s11 to verify your implementation.

Solution:
"""
Solution for Step 11: Language Model Head

This module adds the final projection layer that converts hidden states
to vocabulary logits for predicting the next token.
"""

from max.nn.module_v3 import Linear, Module

from solutions.solution_01 import GPT2Config
from solutions.solution_10 import GPT2Model


class MaxGPT2LMHeadModel(Module):
    """Complete GPT-2 model with language modeling head.

    This is the full model that can be used for text generation.
    """

    def __init__(self, config: GPT2Config):
        """Initialize GPT-2 with LM head.

        Args:
            config: GPT2Config containing model hyperparameters
        """
        super().__init__()

        self.config = config
        # The transformer (embeddings + blocks + final norm)
        self.transformer = GPT2Model(config)
        # Language modeling head (hidden states -> vocabulary logits)
        self.lm_head = Linear(config.n_embd, config.vocab_size, bias=False)

    def __call__(self, input_ids):
        """Forward pass through transformer and LM head.

        Args:
            input_ids: Token IDs, shape [batch, seq_length]

        Returns:
            Logits over vocabulary, shape [batch, seq_length, vocab_size]
        """
        # Get hidden states from transformer
        hidden_states = self.transformer(input_ids)

        # Project to vocabulary logits
        logits = self.lm_head(hidden_states)

        return logits

Next: In Step 12, you’ll implement text generation using sampling and temperature control to generate coherent text autoregressively.