Step 05: Token embeddings
Learn to create token embeddings that convert discrete token IDs into continuous vector representations.
Implementing token embeddings
In this step you’ll create the GPT2Embeddings class. It converts discrete token
IDs (integers) into continuous vector representations that the model can
process. The embedding layer is a lookup table with shape [50257, 768], where
50257 is GPT-2’s vocabulary size and 768 is the embedding dimension.
Neural networks operate on continuous values, not discrete symbols. Token embeddings convert discrete token IDs into dense vectors that can be processed by matrix operations. During training, these embeddings naturally cluster semantically similar words closer together in vector space.
Understanding embeddings
The embedding layer stores one vector per vocabulary token. When you pass in
token ID 1000, it returns row 1000 as the embedding vector. The layer name wte
stands for “word token embeddings” and matches the naming in the original GPT-2
code for weight loading compatibility.
Key parameters:
- Vocabulary size: 50,257 tokens (byte-pair encoding)
- Embedding dimension: 768 for GPT-2 base
- Shape: [vocab_size, embedding_dim]
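To make the row lookup concrete, here is a minimal sketch in plain NumPy (not MAX); the table values and token IDs are arbitrary placeholders:

```python
import numpy as np

# An embedding table is just a [vocab_size, n_embd] matrix of learned weights;
# a lookup selects one row per token ID.
vocab_size, n_embd = 50257, 768
table = np.random.randn(vocab_size, n_embd).astype(np.float32)

input_ids = np.array([[1000, 42, 7]])  # arbitrary token IDs, shape [batch_size=1, seq_length=3]
vectors = table[input_ids]             # row indexing -> shape [1, 3, 768]

assert np.array_equal(vectors[0, 0], table[1000])  # token ID 1000 returns row 1000
print(vectors.shape)                               # (1, 3, 768)
```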
You’ll use the following MAX operations to complete this task:
Embedding layer:
Embedding(num_embeddings, dim): Creates an embedding lookup table with automatic weight initialization
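As a minimal usage sketch, assuming the layer can be constructed on its own exactly as it is constructed inside the module below:

```python
from max.nn.module_v3 import Embedding

# Lookup table with shape [50257, 768], matching GPT-2 base.
wte = Embedding(50257, dim=768)

# Calling the layer performs the lookup:
#   wte(input_ids)  # [batch_size, seq_length] -> [batch_size, seq_length, 768]
```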
Implementing the class
You’ll implement the GPT2Embeddings class in three steps:
1. Import required modules: Import Embedding and Module from max.nn.module_v3.
2. Create the embedding layer: Use Embedding(config.vocab_size, dim=config.n_embd) and store it in self.wte.
3. Implement the forward pass: Call self.wte(input_ids) to look up embeddings. Input shape: [batch_size, seq_length]. Output shape: [batch_size, seq_length, n_embd].
Implementation (step_05.py):
"""
Step 05: Token Embeddings
Implement token embeddings that convert discrete token IDs into continuous vectors.
Tasks:
1. Import Embedding and Module from max.nn.module_v3
2. Create token embedding layer using Embedding(vocab_size, dim=n_embd)
3. Implement forward pass that looks up embeddings for input token IDs
Run: pixi run s05
"""
# TODO: Import required modules from MAX
# Hint: You'll need Embedding and Module from max.nn.module_v3
from solutions.solution_01 import GPT2Config
class GPT2Embeddings(Module):
"""Token embeddings for GPT-2."""
def __init__(self, config: GPT2Config):
super().__init__()
# TODO: Create token embedding layer
# Hint: Use Embedding(config.vocab_size, dim=config.n_embd)
# This creates a lookup table that converts token IDs to embedding vectors
self.wte = None
def __call__(self, input_ids):
"""Convert token IDs to embeddings.
Args:
input_ids: Tensor of token IDs, shape [batch_size, seq_length]
Returns:
Token embeddings, shape [batch_size, seq_length, n_embd]
"""
# TODO: Return the embedded tokens
# Hint: Simply call self.wte with input_ids
return None
Validation
Run pixi run s05 to verify your implementation.
Solution
"""
Solution for Step 05: Token Embeddings
This module implements token embeddings that convert discrete token IDs
into continuous vector representations.
"""
from max.nn.module_v3 import Embedding, Module
from solutions.solution_01 import GPT2Config
class GPT2Embeddings(Module):
"""Token embeddings for GPT-2, matching HuggingFace structure."""
def __init__(self, config: GPT2Config):
"""Initialize token embedding layer.
Args:
config: GPT2Config containing vocab_size and n_embd
"""
super().__init__()
# Token embedding: lookup table from vocab_size to embedding dimension
# This converts discrete token IDs (0 to vocab_size-1) into dense vectors
self.wte = Embedding(config.vocab_size, dim=config.n_embd)
def __call__(self, input_ids):
"""Convert token IDs to embeddings.
Args:
input_ids: Tensor of token IDs, shape [batch_size, seq_length]
Returns:
Token embeddings, shape [batch_size, seq_length, n_embd]
"""
# Simple lookup: each token ID becomes its corresponding embedding vector
return self.wte(input_ids)
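A quick, hypothetical sanity check is sketched below; it assumes GPT2Config from Step 01 can be constructed with GPT-2 base defaults, and the authoritative check remains pixi run s05:

```python
from solutions.solution_01 import GPT2Config

config = GPT2Config()                 # assumed to default to vocab_size=50257, n_embd=768
embeddings = GPT2Embeddings(config)

# Calling embeddings(input_ids) with a token-ID tensor of shape
# [batch_size, seq_length] should return a tensor of shape
# [batch_size, seq_length, config.n_embd].
```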
Next: In Step 06, you’ll implement position embeddings to encode sequence order information, which will be combined with these token embeddings.