Step 02: Causal masking

Learn to create attention masks to prevent the model from seeing future tokens during autoregressive generation.

Implementing causal masking

In this step you’ll implement the causal_mask() function. This creates a mask matrix that prevents the model from seeing future tokens when predicting the next token. The mask sets attention scores to negative infinity (-inf) for future positions. After softmax, these -inf values become zero probability, blocking information flow from later tokens.
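
To see why -inf works as a mask value, here is a minimal softmax sketch in plain NumPy (illustrative only, not part of the MAX implementation); the last key position is a future token and is masked:

import numpy as np

# Raw attention scores for one query over three key positions;
# the third key is a future token, so it receives -inf.
scores = np.array([2.0, 1.0, 0.5])
mask = np.array([0.0, 0.0, float("-inf")])

masked = scores + mask
weights = np.exp(masked - masked.max())
weights /= weights.sum()
print(weights)  # roughly [0.73, 0.27, 0.0] -- the masked position gets zero weight

Because exp(-inf) is exactly 0, the masked position contributes nothing to the weighted sum of values.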

GPT-2 generates text one token at a time, left-to-right. During training, causal masking prevents the model from “cheating” by looking ahead at tokens it should be predicting. Without this mask, the model would have access to information it won’t have during actual text generation.

Understanding the mask pattern

The mask creates a lower triangular pattern where each token can only attend to itself and previous tokens:

  • Position 0 attends to: position 0 only
  • Position 1 attends to: positions 0-1
  • Position 2 attends to: positions 0-2
  • And so on…
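
For a four-token sequence, this pattern can be visualized as a small boolean matrix (plain NumPy, for illustration only; 1 marks a position the token may attend to):

import numpy as np

seq_len = 4
# allowed[i, j] is True when token i may attend to token j, i.e. j <= i.
allowed = np.tril(np.ones((seq_len, seq_len), dtype=bool))
print(allowed.astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]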

The mask has shape (sequence_length, sequence_length + num_tokens). This shape is designed for KV-cache compatibility during generation: the KV cache stores the key and value tensors of previously generated tokens, so attention scores only need to be computed for the new tokens (the rows), while each new token attends to both the cached tokens (num_tokens) and the new tokens up to its own position (the columns). This significantly speeds up generation by avoiding recomputation.
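
As a concrete illustration of these shapes (a NumPy analogue of the described pattern, not a call into MAX, and assuming cached keys occupy the leading columns):

import numpy as np

def sketch_mask(seq_len, num_cached):
    rows = np.arange(seq_len)[:, None]
    cols = np.arange(seq_len + num_cached)[None, :]
    # 0 where attention is allowed (all cached keys plus new keys up to the
    # current position); -inf blocks everything further to the right.
    return np.where(cols - rows > num_cached, float("-inf"), 0.0)

print(sketch_mask(4, 0).shape)  # (4, 4): a 4-token prompt with an empty cache
print(sketch_mask(1, 4).shape)  # (1, 5): one new token attending to 4 cached tokens plus itself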

MAX operations

You’ll use the following MAX operations to complete this task:

Functional decorator:

  • @F.functional: Converts functions to graph operations for MAX compilation

Tensor operations:

  • Tensor.constant(): Creates a scalar constant tensor (here, filled with -inf)
  • F.broadcast_to(): Broadcasts the scalar constant to the full 2D mask shape
  • F.band_part(): Zeroes or keeps a band of the matrix to form the triangular structure

Implementing the mask

You’ll create the causal mask in several steps:

  1. Import required modules:

    • Device from max.driver - specifies hardware device (CPU/GPU)
    • DType from max.dtype - data type specification
    • functional as F from max.experimental - functional operations library
    • Tensor from max.experimental.tensor - tensor operations
    • Dim and DimLike from max.graph - dimension handling
  2. Add @F.functional decorator: This converts the function to a MAX graph operation.

  3. Calculate total sequence length: Combine sequence_length and num_tokens using Dim() to determine mask width.

  4. Create constant tensor: Use Tensor.constant(float("-inf"), dtype=dtype, device=device) to create a scalar that will be broadcast.

  5. Broadcast to target shape: Use F.broadcast_to(mask, shape=(sequence_length, n)) to expand the scalar to a 2D matrix.

  6. Apply band part: Use F.band_part(mask, num_lower=None, num_upper=0, exclude=True) to create the causal structure. With exclude=True, the band on and below the diagonal is zeroed out while positions above the diagonal keep -inf, so each token can attend only to itself and earlier positions (see the sketch below).
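
To make the exclude=True behavior concrete, this is the result expected for a three-token sequence with no cached tokens (a NumPy mock-up of the expected output, not the MAX operation itself):

import numpy as np

inf = float("-inf")
# The broadcast mask from steps 4-5 is all -inf. After
# band_part(num_lower=None, num_upper=0, exclude=True), the band on and below
# the diagonal is zeroed and -inf remains above it:
expected = np.where(np.triu(np.ones((3, 3), dtype=bool), k=1), inf, 0.0)
print(expected)
# [[  0. -inf -inf]
#  [  0.   0. -inf]
#  [  0.   0.   0.]]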

Implementation (step_02.py):

"""
Step 02: Causal Masking

Implement causal attention masking that prevents tokens from attending to future positions.

Tasks:
1. Import functional module (as F) and Tensor from max.experimental
2. Add @F.functional decorator to the causal_mask function
3. Create a constant tensor filled with negative infinity
4. Broadcast the mask to the correct shape (sequence_length, n)
5. Apply band_part to create the lower triangular causal structure

Run: pixi run s02
"""

# 1: Import the required modules from MAX
from max.driver import Device
from max.dtype import DType
# TODO: Import the functional module from max.experimental with the alias F
# https://docs.modular.com/max/api/python/experimental/functional

# TODO: Import Tensor object from max.experimental.tensor
# https://docs.modular.com/max/api/python/experimental/tensor.Tensor

from max.graph import Dim, DimLike

# 2: Add the @F.functional decorator to make this a MAX functional operation
# TODO: Add the decorator here


def causal_mask(
    sequence_length: DimLike,
    num_tokens: DimLike,
    *,
    dtype: DType,
    device: Device,
):
    """Create a causal mask for autoregressive attention.

    Args:
        sequence_length: Length of the sequence.
        num_tokens: Number of tokens.
        dtype: Data type for the mask.
        device: Device to create the mask on.

    Returns:
        A causal mask tensor.
    """
    # Calculate total sequence length
    n = Dim(sequence_length) + num_tokens

    # 3: Create a constant tensor filled with negative infinity
    # TODO: Use Tensor.constant() with float("-inf"), dtype, and device parameters
    # https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.constant
    # Hint: This creates the base mask value that will block attention to future tokens
    mask = None

    # 4: Broadcast the mask to the correct shape
    # TODO: Use F.broadcast_to() to expand mask to shape (sequence_length, n)
    # https://docs.modular.com/max/api/python/experimental/functional#max.experimental.functional.broadcast_to
    # Hint: This creates a 2D attention mask matrix
    mask = None

    # 5: Apply band_part to create the causal (lower triangular) structure and return the mask
    # TODO: Use F.band_part() with num_lower=None, num_upper=0, exclude=True
    # https://docs.modular.com/max/api/python/experimental/functional/#max.experimental.functional.band_part
    # Hint: This keeps only the lower triangle, allowing attention to past tokens only
    return None

Validation

Run pixi run s02 to verify your implementation.

Solution (step_02.py):
"""
Solution for Step 02: Causal Masking

This module implements causal attention masking that prevents tokens from
attending to future positions in autoregressive generation.
"""

from max.driver import Device
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import Dim, DimLike


@F.functional
def causal_mask(
    sequence_length: DimLike,
    num_tokens: DimLike,
    *,
    dtype: DType,
    device: Device,
):
    """Create a causal mask for autoregressive attention.

    Args:
        sequence_length: Length of the sequence.
        num_tokens: Number of tokens.
        dtype: Data type for the mask.
        device: Device to create the mask on.

    Returns:
        A causal mask tensor.
    """
    n = Dim(sequence_length) + num_tokens
    mask = Tensor.constant(float("-inf"), dtype=dtype, device=device)
    mask = F.broadcast_to(mask, shape=(sequence_length, n))
    return F.band_part(mask, num_lower=None, num_upper=0, exclude=True)
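
Example usage (a hypothetical sketch: it assumes the @F.functional-decorated function can be called eagerly with concrete dimensions, and that CPU is importable from max.driver in your MAX version):

from max.driver import CPU
from max.dtype import DType

# Mask for a 4-token prompt with an empty KV cache.
mask = causal_mask(4, 0, dtype=DType.float32, device=CPU())
# Expected: shape (4, 4) with 0 on and below the diagonal and -inf above.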

Next: In Step 03, you’ll implement layer normalization to stabilize activations for effective training.