Step 06: Position embeddings
Learn to create position embeddings that encode the order of tokens in a sequence.
Implementing position embeddings
In this step you’ll create position embeddings to encode where each token appears in the sequence. While token embeddings tell the model “what” each token is, position embeddings tell it “where” the token is located. These position vectors are added to token embeddings before entering the transformer blocks.
Transformers process all positions in parallel through attention, unlike Recurrent Neural Networks (RNNs) that process sequentially. This parallelism enables faster training but loses positional information. Position embeddings restore this information so the model can distinguish “dog bites man” from “man bites dog”.
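To make the order-sensitivity concrete, here is a toy NumPy illustration (separate from the MAX code in this step, using made-up 4-dimensional embeddings): without position information, the word "dog" maps to the same vector whether it appears first or last, and adding a position vector is what breaks that symmetry.

import numpy as np

rng = np.random.default_rng(0)
vocab = {"dog": 0, "bites": 1, "man": 2}
token_emb = rng.normal(size=(3, 4))  # toy token lookup table [vocab_size, dim]
pos_emb = rng.normal(size=(3, 4))    # toy position lookup table [positions, dim]

a = [vocab[w] for w in "dog bites man".split()]
b = [vocab[w] for w in "man bites dog".split()]

# Without positions, "dog" produces the same vector whether it is the first
# token (sentence a) or the last token (sentence b).
print(np.allclose(token_emb[a][0], token_emb[b][2]))                            # True

# Adding the row for its position gives the same word a different input
# vector at position 0 than at position 2.
print(np.allclose(token_emb[a][0] + pos_emb[0], token_emb[b][2] + pos_emb[2]))  # False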
Understanding position embeddings
Position embeddings work like token embeddings: a lookup table with shape [1024, 768], where 1024 is the maximum sequence length and 768 is the embedding dimension. Position 0 gets the first row, position 1 gets the second row, and so on.
GPT-2 uses learned position embeddings, meaning these vectors are initialized randomly and trained alongside the model. This differs from the original Transformer which used fixed sinusoidal position encodings. Learned embeddings let the model discover optimal position representations for its specific task, though they cannot generalize beyond the maximum length seen during training (1024 tokens).
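For contrast with the learned table, the original Transformer's sinusoidal encoding is a fixed function of the position index and can be computed directly. The sketch below (NumPy, not used anywhere in this GPT-2 implementation) shows that construction:

import numpy as np

def sinusoidal_positions(n_positions: int, n_embd: int) -> np.ndarray:
    """Fixed (non-learned) position encodings from "Attention Is All You Need"."""
    positions = np.arange(n_positions)[:, None]      # [n_positions, 1]
    dims = np.arange(0, n_embd, 2)[None, :]          # even dimensions 0, 2, 4, ...
    angles = positions / np.power(10000.0, dims / n_embd)
    table = np.zeros((n_positions, n_embd))
    table[:, 0::2] = np.sin(angles)                  # sine on even dimensions
    table[:, 1::2] = np.cos(angles)                  # cosine on odd dimensions
    return table

# GPT-2 skips this formula and instead learns its [1024, 768] table as
# ordinary trainable weights, one row per position.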
Key parameters:
- Maximum sequence length: 1,024 positions
- Embedding dimension: 768 for GPT-2 base
- Shape: [n_positions, n_embd]
- Layer name: wpe (word position embeddings)
You’ll use the following MAX operations to complete this task:
- Position indices: Tensor.arange(seq_length, dtype, device) creates the sequence positions [0, 1, 2, …, seq_length-1]
- Embedding layer: Embedding(num_embeddings, dim) is the same class used for token embeddings, here applied to positions
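Putting the two operations together, a minimal sketch might look like the following; the import paths match the solution at the end of this step, and the dtype/device arguments to Tensor.arange are omitted here on the assumption that defaults are acceptable:

from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding

seq_length = 8
wpe = Embedding(1024, dim=768)       # lookup table: one 768-dim row per position

# Position indices 0, 1, ..., 7; pass dtype/device as in the signature above
# if your setup requires them.
position_ids = Tensor.arange(seq_length)

pos_embeds = wpe(position_ids)       # shape [8, 768]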
Implementing the class
You’ll implement the position embeddings in several steps:
- Import required modules: Import Tensor, Embedding, and Module from the MAX libraries.
- Create position embedding layer: Use Embedding(config.n_positions, dim=config.n_embd) and store it in self.wpe.
- Implement forward pass: Call self.wpe(position_ids) to look up position embeddings. Input shape: [seq_length] or [batch, seq_length]. Output shape: [seq_length, n_embd] or [batch, seq_length, n_embd].
Implementation (step_06.py):
"""
Step 06: Position Embeddings
Implement position embeddings that encode sequence order information.
Tasks:
1. Import Tensor from max.experimental.tensor
2. Import Embedding and Module from max.nn.module_v3
3. Create position embedding layer using Embedding(n_positions, dim=n_embd)
4. Implement forward pass that looks up embeddings for position indices
Run: pixi run s06
"""
# TODO: Import required modules from MAX
# Hint: You'll need Tensor from max.experimental.tensor
# Hint: You'll need Embedding and Module from max.nn.module_v3
from solutions.solution_01 import GPT2Config
class GPT2PositionEmbeddings(Module):
"""Position embeddings for GPT-2."""
def __init__(self, config: GPT2Config):
super().__init__()
# TODO: Create position embedding layer
# Hint: Use Embedding(config.n_positions, dim=config.n_embd)
# This creates a lookup table for position indices (0, 1, 2, ..., n_positions-1)
self.wpe = None
def __call__(self, position_ids):
"""Convert position indices to embeddings.
Args:
position_ids: Tensor of position indices, shape [seq_length] or [batch_size, seq_length]
Returns:
Position embeddings, shape matching input with added embedding dimension
"""
# TODO: Return the position embeddings
# Hint: Simply call self.wpe with position_ids
return None
Validation
Run pixi run s06 to verify your implementation.
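If you also want an informal, interactive check before running the test, something along these lines should work (a hypothetical snippet, assuming GPT2Config defaults to the GPT-2 base sizes and that the output tensor exposes a shape attribute):

from max.experimental.tensor import Tensor

from solutions.solution_01 import GPT2Config
from step_06 import GPT2PositionEmbeddings  # adjust to wherever your step_06.py lives

config = GPT2Config()                        # assumed: GPT-2 base defaults
embeddings = GPT2PositionEmbeddings(config)

position_ids = Tensor.arange(8)              # positions 0..7
output = embeddings(position_ids)
print(output.shape)                          # expect [8, 768], i.e. [seq_length, n_embd]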
Solution:
"""
Solution for Step 06: Position Embeddings
This module implements position embeddings that encode sequence order information
into the transformer model.
"""
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding, Module
from solutions.solution_01 import GPT2Config
class GPT2PositionEmbeddings(Module):
"""Position embeddings for GPT-2, matching HuggingFace structure."""
def __init__(self, config: GPT2Config):
"""Initialize position embedding layer.
Args:
config: GPT2Config containing n_positions and n_embd
"""
super().__init__()
# Position embedding: lookup table from position indices to embedding vectors
# This encodes "where" information - position 0, 1, 2, etc.
self.wpe = Embedding(config.n_positions, dim=config.n_embd)
def __call__(self, position_ids):
"""Convert position indices to embeddings.
Args:
position_ids: Tensor of position indices, shape [seq_length] or [batch_size, seq_length]
Returns:
Position embeddings, shape matching input with added embedding dimension
"""
# Simple lookup: each position index becomes its corresponding embedding vector
return self.wpe(position_ids)
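To connect this step with token embeddings: in the assembled model, the two lookups are summed elementwise before the transformer blocks. The helper below is illustrative only, not one of the workshop files; it assumes GPT2Config exposes a vocab_size field and that MAX tensors support + for elementwise addition.

from max.nn.module_v3 import Embedding, Module

from solutions.solution_01 import GPT2Config


class GPT2Embeddings(Module):
    """Illustrative only: combines token ("what") and position ("where") lookups."""

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.wte = Embedding(config.vocab_size, dim=config.n_embd)   # token table (Step 05)
        self.wpe = Embedding(config.n_positions, dim=config.n_embd)  # position table (this step)

    def __call__(self, input_ids, position_ids):
        # Elementwise sum: each token's vector is shifted by its position's vector.
        return self.wte(input_ids) + self.wpe(position_ids)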
Next: In Step 07, you’ll implement multi-head attention.