Step 05: Token embeddings

Learn to create token embeddings that convert discrete token IDs into continuous vector representations.

Implementing token embeddings

In this step you’ll create the GPT2Embeddings class, which wraps a MAX Embedding layer to convert discrete token IDs (integers) into continuous vector representations the model can process. The embedding layer is a lookup table of shape [50257, 768], where 50257 is GPT-2’s vocabulary size and 768 is the embedding dimension.

Neural networks operate on continuous values, not discrete symbols. Token embeddings convert discrete token IDs into dense vectors that can be processed by matrix operations. During training, these embeddings naturally cluster semantically similar words closer together in vector space.
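To make the idea concrete, here is a minimal NumPy sketch of the lookup (illustrative only; the MAX Embedding layer handles this for you, and the table values here are random rather than learned):

import numpy as np

vocab_size, n_embd = 50257, 768

# The embedding "layer" is just a matrix with one learned row per token.
table = np.random.randn(vocab_size, n_embd).astype(np.float32)

# A batch of arbitrary token IDs, shape [batch_size, seq_length] = [2, 4].
input_ids = np.array([[464, 3290, 318, 257],
                      [1212, 318, 262, 1332]])

# Embedding lookup is row indexing: each ID selects its row of the table.
embeddings = table[input_ids]
print(embeddings.shape)  # (2, 4, 768)
print(np.array_equal(embeddings[0, 0], table[464]))  # True: ID 464 -> row 464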

Understanding embeddings

The embedding layer stores one vector per vocabulary token. When you pass in token ID 1000, it returns row 1000 as the embedding vector. The layer name wte stands for “word token embeddings” and matches the naming in the original GPT-2 code for weight loading compatibility.
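You can see where this naming comes from by inspecting the pretrained checkpoint directly. A quick check, assuming the Hugging Face transformers package (and its PyTorch dependency) is installed:

from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")

# The original GPT-2 model stores its token embeddings under the name "wte",
# so using the same attribute name lets pretrained weights map onto this layer.
print(model.wte)               # Embedding(50257, 768)
print(model.wte.weight.shape)  # torch.Size([50257, 768])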

Key parameters:

  • Vocabulary size: 50,257 tokens (byte-pair encoding)
  • Embedding dimension: 768 for GPT-2 base
  • Shape: [vocab_size, embedding_dim]
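Multiplying these out gives the size of the lookup table: 50,257 × 768 = 38,597,376 learned parameters (about 38.6M) in the token-embedding matrix alone, a substantial share of the full model’s weights.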

MAX operations

You’ll use the following MAX operations to complete this task:

  • Embedding: a lookup table layer that maps token IDs to dense embedding vectors; imported from max.nn.module_v3.

Implementing the class

You’ll implement the GPT2Embeddings class in several steps:

  1. Import required modules: Import Embedding and Module from max.nn.module_v3.

  2. Create embedding layer: Use Embedding(config.vocab_size, dim=config.n_embd) and store it in self.wte.

  3. Implement forward pass: Call self.wte(input_ids) to look up embeddings. Input shape: [batch_size, seq_length]. Output shape: [batch_size, seq_length, n_embd].

Implementation (step_05.py):

"""
Step 05: Token Embeddings

Implement token embeddings that convert discrete token IDs into continuous vectors.

Tasks:
1. Import Embedding and Module from max.nn.module_v3
2. Create token embedding layer using Embedding(vocab_size, dim=n_embd)
3. Implement forward pass that looks up embeddings for input token IDs

Run: pixi run s05
"""

# TODO: Import required modules from MAX
# Hint: You'll need Embedding and Module from max.nn.module_v3

from solutions.solution_01 import GPT2Config


class GPT2Embeddings(Module):
    """Token embeddings for GPT-2."""

    def __init__(self, config: GPT2Config):
        super().__init__()

        # TODO: Create token embedding layer
        # Hint: Use Embedding(config.vocab_size, dim=config.n_embd)
        # This creates a lookup table that converts token IDs to embedding vectors
        self.wte = None

    def __call__(self, input_ids):
        """Convert token IDs to embeddings.

        Args:
            input_ids: Tensor of token IDs, shape [batch_size, seq_length]

        Returns:
            Token embeddings, shape [batch_size, seq_length, n_embd]
        """
        # TODO: Return the embedded tokens
        # Hint: Simply call self.wte with input_ids
        return None

Validation

Run pixi run s05 to verify your implementation.

Solution:

"""
Solution for Step 05: Token Embeddings

This module implements token embeddings that convert discrete token IDs
into continuous vector representations.
"""

from max.nn.module_v3 import Embedding, Module

from solutions.solution_01 import GPT2Config


class GPT2Embeddings(Module):
    """Token embeddings for GPT-2, matching HuggingFace structure."""

    def __init__(self, config: GPT2Config):
        """Initialize token embedding layer.

        Args:
            config: GPT2Config containing vocab_size and n_embd
        """
        super().__init__()

        # Token embedding: lookup table from vocab_size to embedding dimension
        # This converts discrete token IDs (0 to vocab_size-1) into dense vectors
        self.wte = Embedding(config.vocab_size, dim=config.n_embd)

    def __call__(self, input_ids):
        """Convert token IDs to embeddings.

        Args:
            input_ids: Tensor of token IDs, shape [batch_size, seq_length]

        Returns:
            Token embeddings, shape [batch_size, seq_length, n_embd]
        """
        # Simple lookup: each token ID becomes its corresponding embedding vector
        return self.wte(input_ids)

Next: In Step 06, you’ll implement position embeddings to encode sequence order information, which will be combined with these token embeddings.