Step 02: Causal masking
Learn to create attention masks to prevent the model from seeing future tokens during autoregressive generation.
Implementing causal masking
In this step you’ll implement the causal_mask() function. This creates a mask matrix that prevents the model from seeing future tokens when predicting the next token. The mask sets attention scores to negative infinity (-inf) for future positions. After softmax, these -inf values become zero probability, blocking information flow from later tokens.
GPT-2 generates text one token at a time, left-to-right. During training, causal masking prevents the model from “cheating” by looking ahead at tokens it should be predicting. Without this mask, the model would have access to information it won’t have during actual text generation.
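To see why -inf works as a blocker, here is a quick standalone check in plain NumPy (an illustration only, not part of the MAX implementation) showing that a -inf score turns into exactly zero attention weight after softmax:
import numpy as np
# Raw attention scores for one query: two visible positions and one masked future position
scores = np.array([2.0, 1.0, -np.inf])
# Numerically stable softmax; exp(-inf) is 0, so the future position gets zero weight
weights = np.exp(scores - scores.max())
weights /= weights.sum()
print(weights)  # approximately [0.731, 0.269, 0.0]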
Understanding the mask pattern
The mask creates a lower triangular pattern where each token can only attend to itself and previous tokens:
- Position 0 attends to: position 0 only
- Position 1 attends to: positions 0-1
- Position 2 attends to: positions 0-2
- And so on…
The mask shape is (sequence_length, sequence_length + num_tokens). This shape is designed for KV cache compatibility during generation: the KV cache stores the key and value tensors of previously generated tokens, so attention only needs to be computed for the new tokens (the sequence_length rows) while each new token attends to both the new tokens and the cached tokens (the sequence_length + num_tokens columns). This significantly speeds up generation by avoiding recomputation.
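For the simple case with no cached tokens (num_tokens = 0) and sequence_length = 4, the finished mask looks like this. The sketch below uses plain NumPy purely for illustration; the exercise builds the same pattern with MAX operations:
import numpy as np
seq_len = 4
# 0 on and below the diagonal (attention allowed), -inf above it (attention blocked)
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
print(mask)
# output (roughly):
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]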
You’ll use the following MAX operations to complete this task:
Functional decorator:
- @F.functional: Converts functions to graph operations for MAX compilation
Tensor operations:
- Tensor.constant(): Creates a scalar constant tensor
- F.broadcast_to(): Expands a tensor to a target shape
- F.band_part(): Extracts a band from a matrix, keeping the diagonal band and zeroing out the rest (see the sketch after this list)
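If the band_part arguments look opaque, the following NumPy sketch approximates their effect on an all -inf matrix (this mirrors the semantics described above; it is not the MAX API itself):
import numpy as np
x = np.full((3, 3), -np.inf)
# band_part(x, num_lower=None, num_upper=0) keeps the band -- the lower triangle
# including the diagonal -- and zeros everything outside it. On an all -inf input
# that leaves -inf exactly where attention should be allowed, roughly np.tril(x).
print(np.tril(x))
# exclude=True inverts this: the band itself is zeroed and values outside it are
# kept, producing 0 on and below the diagonal and -inf above -- the causal mask.
# Roughly np.triu(x, k=1).
print(np.triu(x, k=1))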
Implementing the mask
You’ll create the causal mask in several steps:
1. Import the required modules:
   - Device from max.driver: specifies the hardware device (CPU/GPU)
   - DType from max.dtype: data type specification
   - functional as F from max.experimental: the functional operations library
   - Tensor from max.experimental.tensor: tensor operations
   - Dim from max.graph: symbolic dimension handling
2. Add the @F.functional decorator: this converts the function into a MAX graph operation.
3. Calculate the total sequence length: combine sequence_length and num_tokens using Dim() to determine the mask width.
4. Create a constant tensor: use Tensor.constant(float("-inf"), dtype=dtype, device=device) to create a scalar that will be broadcast.
5. Broadcast to the target shape: use F.broadcast_to(mask, shape=(sequence_length, n)) to expand the scalar into a 2D matrix.
6. Apply band_part: use F.band_part(mask, num_lower=None, num_upper=0, exclude=True) to create the lower triangular pattern, keeping 0s on and below the diagonal and -inf above.
Implementation (step_02.py):
"""
Step 02: Causal Masking
Implement causal attention masking that prevents tokens from attending to future positions.
Tasks:
1. Import functional module (as F) and Tensor from max.experimental
2. Add @F.functional decorator to the causal_mask function
3. Create a constant tensor filled with negative infinity
4. Broadcast the mask to the correct shape (sequence_length, n)
5. Apply band_part to create the lower triangular causal structure
Run: pixi run s02
"""
# 1: Import the required modules from MAX
from max.driver import Device
from max.dtype import DType
# TODO: Import the functional module from max.experimental with the alias F
# https://docs.modular.com/max/api/python/experimental/functional
# TODO: Import Tensor object from max.experimental.tensor
# https://docs.modular.com/max/api/python/experimental/tensor.Tensor
from max.graph import Dim, DimLike
# 2: Add the @F.functional decorator to make this a MAX functional operation
# TODO: Add the decorator here
def causal_mask(
sequence_length: DimLike,
num_tokens: DimLike,
*,
dtype: DType,
device: Device,
):
"""Create a causal mask for autoregressive attention.
Args:
sequence_length: Length of the sequence.
num_tokens: Number of tokens.
dtype: Data type for the mask.
device: Device to create the mask on.
Returns:
A causal mask tensor.
"""
# Calculate total sequence length
n = Dim(sequence_length) + num_tokens
# 3: Create a constant tensor filled with negative infinity
# TODO: Use Tensor.constant() with float("-inf"), dtype, and device parameters
# https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.constant
# Hint: This creates the base mask value that will block attention to future tokens
mask = None
# 4: Broadcast the mask to the correct shape
# TODO: Use F.broadcast_to() to expand mask to shape (sequence_length, n)
# https://docs.modular.com/max/api/python/experimental/functional#max.experimental.functional.broadcast_to
# Hint: This creates a 2D attention mask matrix
mask = None
# 5: Apply band_part to create the causal (lower triangular) structure and return the mask
# TODO: Use F.band_part() with num_lower=None, num_upper=0, exclude=True
# https://docs.modular.com/max/api/python/experimental/functional/#max.experimental.functional.band_part
# Hint: This keeps only the lower triangle, allowing attention to past tokens only
return None
Validation
Run pixi run s02 to verify your implementation.
Solution (step_02.py):
"""
Solution for Step 02: Causal Masking
This module implements causal attention masking that prevents tokens from
attending to future positions in autoregressive generation.
"""
from max.driver import Device
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import Dim, DimLike
@F.functional
def causal_mask(
sequence_length: DimLike,
num_tokens: DimLike,
*,
dtype: DType,
device: Device,
):
"""Create a causal mask for autoregressive attention.
Args:
sequence_length: Length of the sequence.
num_tokens: Number of tokens.
dtype: Data type for the mask.
device: Device to create the mask on.
Returns:
A causal mask tensor.
"""
# Total mask width: the new tokens plus any previously cached tokens
n = Dim(sequence_length) + num_tokens
# Scalar -inf value that will be broadcast across the whole mask
mask = Tensor.constant(float("-inf"), dtype=dtype, device=device)
# Expand the scalar into a (sequence_length, n) matrix of -inf
mask = F.broadcast_to(mask, shape=(sequence_length, n))
# Zero the lower triangle (allowed positions) and keep -inf above the diagonal
return F.band_part(mask, num_lower=None, num_upper=0, exclude=True)
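Once the mask exists, it is added to the raw attention scores before the softmax, so blocked positions contribute nothing to the output. Here is a minimal NumPy sketch of that step (illustration only; the real model applies the same idea inside its MAX attention layers):
import numpy as np

def np_causal_mask(seq_len: int) -> np.ndarray:
    # 0 where attention is allowed, -inf where it is blocked
    return np.triu(np.full((seq_len, seq_len), -np.inf), k=1)

rng = np.random.default_rng(0)
seq_len, d = 4, 8
q = rng.standard_normal((seq_len, d))
k = rng.standard_normal((seq_len, d))

# Scaled dot-product scores, with the mask added before the softmax
scores = q @ k.T / np.sqrt(d) + np_causal_mask(seq_len)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(np.round(weights, 3))  # each row places zero weight on future positions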
Next: In Step 03, you’ll implement layer normalization to stabilize activations for effective training.