Step 03: Layer normalization

Learn to implement layer normalization for stabilizing neural network training.

Building layer normalization

In this step, you’ll create the LayerNorm class that normalizes activations across the feature dimension. For each input, you compute the mean and variance across all features, normalize by subtracting the mean and dividing by the standard deviation, then apply learned weight and bias parameters to scale and shift the result.

Unlike batch normalization, layer normalization works independently for each example. This makes it well suited to transformers: it doesn’t depend on batch size, needs no running statistics at inference time, and behaves the same during training and generation.
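To make the per-example independence concrete, here is a minimal sketch in plain Python (no MAX calls). The layer_norm helper is a hypothetical stand-in written for this illustration and omits the learned weight and bias. It normalizes the same example inside two different batches and checks that the result doesn’t change.

```python
import math


def layer_norm(features, eps=1e-5):
    """Normalize one example's features to roughly zero mean and unit variance."""
    mean = sum(features) / len(features)
    variance = sum((v - mean) ** 2 for v in features) / len(features)
    return [(v - mean) / math.sqrt(variance + eps) for v in features]


example = [1.0, 2.0, 3.0, 4.0]

# The same example placed in two batches of different size and content.
batch_a = [example, [10.0, 20.0, 30.0, 40.0]]
batch_b = [example, [-5.0, 0.0, 5.0, 100.0], [0.5, 0.5, 0.5, 0.5]]

# Each row is normalized on its own; no statistics are shared across rows.
out_a = [layer_norm(row) for row in batch_a]
out_b = [layer_norm(row) for row in batch_b]

# The first row's output is identical regardless of its batch neighbours.
print(out_a[0] == out_b[0])
```

Because the statistics come only from the row being normalized, there is nothing to accumulate during training and nothing to switch off at inference time.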

GPT-2 applies layer normalization before the attention and MLP blocks in each of its 12 transformer layers. This pre-normalization pattern stabilizes training in deep networks by keeping activations in a consistent range.
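The pre-normalization pattern can be sketched in plain Python as follows. This is an illustration under assumptions, not the tutorial’s implementation: fake_attention and fake_mlp are hypothetical stand-ins for the real sublayers, the residual additions reflect the usual GPT-2 block layout, and the learned weight and bias are left out for brevity.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a feature vector (learned weight and bias omitted)."""
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(variance + eps) for v in x]


def add(a, b):
    """Element-wise sum, used here for the residual connections."""
    return [p + q for p, q in zip(a, b)]


# Hypothetical stand-ins for the real attention and MLP sublayers.
def fake_attention(x):
    return [0.5 * v for v in x]


def fake_mlp(x):
    return [max(0.0, v) for v in x]


def pre_norm_block(x):
    # Normalize first, run the sublayer, then add the result to the input.
    x = add(x, fake_attention(layer_norm(x)))
    x = add(x, fake_mlp(layer_norm(x)))
    return x


hidden = [100.0, -200.0, 300.0, 50.0]  # deliberately large activations
sublayer_input = layer_norm(hidden)    # what the attention stand-in receives
out = pre_norm_block(hidden)

# Stack 12 blocks, mirroring the layer count mentioned above.
stacked = hidden
for _ in range(12):
    stacked = pre_norm_block(stacked)
```

However large the incoming activations are, each sublayer only ever sees the normalized values, which is what keeps its inputs in a consistent range.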

Understanding the operation

Layer normalization normalizes across the feature dimension (the last dimension) independently for each example. It learns two parameters per feature: weight (gamma) for scaling and bias (beta) for shifting.

The normalization follows this formula:

output = weight * (x - mean) / sqrt(variance + epsilon) + bias

The mean and variance are computed across all features in each example. After normalizing to zero mean and unit variance, the learned weight scales the result and the learned bias shifts it. The epsilon value (typically 1e-5) prevents division by zero when variance is very small.
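As a reference for the formula, here is a minimal pure-Python sketch for a single example. It is not the MAX implementation, and the input, weight, and bias values are made up for illustration. It also shows the role of epsilon on a constant input, where the variance is exactly zero.

```python
import math


def layer_norm(x, weight, bias, eps=1e-5):
    """Compute weight * (x - mean) / sqrt(variance + eps) + bias per feature."""
    n = len(x)
    mean = sum(x) / n
    variance = sum((v - mean) ** 2 for v in x) / n  # population variance
    inv_std = 1.0 / math.sqrt(variance + eps)
    return [w * (v - mean) * inv_std + b for v, w, b in zip(x, weight, bias)]


# Illustrative values: mean is 5.0 and variance is 5.0.
x = [2.0, 4.0, 6.0, 8.0]
weight = [1.0, 1.0, 2.0, 2.0]
bias = [0.0, 0.5, 0.0, 0.5]

y = layer_norm(x, weight, bias)
print([round(v, 4) for v in y])

# A constant input has zero variance; eps keeps the denominator non-zero,
# so every feature normalizes to 0.0 instead of raising ZeroDivisionError.
flat = layer_norm([3.0, 3.0, 3.0, 3.0], [1.0] * 4, [0.0] * 4)
print(flat)
```

Note that the variance divides by the number of features rather than by one less than it, which is the convention layer normalization uses.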

MAX operations

You’ll use the following MAX operations to complete this task:

Modules:

  • Module: The Module class used for eager tensors

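Tensor initialization:

  • Tensor.ones(): Creates a tensor filled with ones
  • Tensor.zeros(): Creates a tensor filled with zeros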

Layer normalization:

  • F.layer_norm(): Applies layer normalization with parameters: input, gamma (weight), beta (bias), and epsilon

Implementing layer normalization

You’ll create the LayerNorm class that wraps MAX’s layer normalization function with learnable parameters. The implementation is short: two parameters and a single function call.

First, import the required modules. You’ll need functional as F for the layer norm operation and Tensor for creating parameters.

In the __init__ method, create two learnable parameters:

  • Weight: Tensor.ones([dim]) stored as self.weight - initialized to ones so the normalized values are not rescaled at the start
  • Bias: Tensor.zeros([dim]) stored as self.bias - initialized to zeros so there’s no initial shift

Store the epsilon value as self.eps for numerical stability.
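The effect of this initialization can be checked with a small plain-Python sketch. The normalize and affine helpers are hypothetical names used only for this illustration, and the “trained” values are invented to show what changing the parameters does. With a weight of ones and a bias of zeros, the scale-and-shift step leaves the normalized values untouched.

```python
import math


def normalize(x, eps=1e-5):
    """Zero-mean, unit-variance normalization of one feature vector."""
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(variance + eps) for v in x]


def affine(values, weight, bias):
    """Per-feature scale and shift: weight * value + bias."""
    return [w * v + b for v, w, b in zip(values, weight, bias)]


dim = 4
x = [0.5, -1.0, 3.0, 2.5]
normalized = normalize(x)

# Initial parameters: ones for weight, zeros for bias.
at_init = affine(normalized, [1.0] * dim, [0.0] * dim)
print(at_init == normalized)

# Invented "trained" parameters: the output is now scaled by 2 and shifted by 1.
trained = affine(normalized, [2.0] * dim, [1.0] * dim)
trained_mean = sum(trained) / dim
print(round(trained_mean, 6))
```

Starting from a no-op scale and shift means the layer begins as plain normalization, and training only moves away from that where it helps.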

In the __call__ method, apply layer normalization with F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). This computes the normalization and applies the learned parameters in one operation.

Implementation (step_03.py):

"""
Step 03: Layer Normalization

Implement layer normalization that normalizes activations for training stability.

Tasks:
1. Import functional module (as F) and Tensor from max.experimental
2. Initialize learnable weight (gamma) and bias (beta) parameters
3. Apply layer normalization using F.layer_norm in the forward pass

Run: pixi run s03
"""

# 1: Import the required modules from MAX
# TODO: Import functional module from max.experimental with the alias F
# https://docs.modular.com/max/api/python/experimental/functional

# TODO: Import Tensor from max.experimental.tensor
# https://docs.modular.com/max/api/python/experimental/tensor.Tensor

from max.graph import DimLike
from max.nn.module_v3 import Module


class LayerNorm(Module):
    """Layer normalization module.

    Args:
        dim: Size of the feature (last) dimension to normalize over.
        eps: Epsilon for numerical stability.
    """

    def __init__(self, dim: DimLike, *, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

        # 2: Initialize learnable weight and bias parameters
        # TODO: Create self.weight as a Tensor of ones with shape [dim]
        # https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.ones
        # Hint: This is the gamma parameter in layer normalization
        self.weight = None

        # TODO: Create self.bias as a Tensor of zeros with shape [dim]
        # https://docs.modular.com/max/api/python/experimental/tensor#max.experimental.tensor.Tensor.zeros
        # Hint: This is the beta parameter in layer normalization
        self.bias = None

    def __call__(self, x: Tensor) -> Tensor:
        """Apply layer normalization.

        Args:
            x: Input tensor.

        Returns:
            Normalized tensor.
        """
        # 3: Apply layer normalization and return the result
        # TODO: Use F.layer_norm() with x, gamma=self.weight, beta=self.bias, epsilon=self.eps
        # https://docs.modular.com/max/api/python/experimental/functional#max.experimental.functional.layer_norm
        # Hint: Layer normalization normalizes across the last dimension
        return None

Validation

Run pixi run s03 to verify your implementation.

Solution:

"""
Solution for Step 03: Layer Normalization

This module implements layer normalization that normalizes activations
across the embedding dimension for training stability.
"""

from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DimLike
from max.nn.module_v3 import Module


class LayerNorm(Module):
    """Layer normalization module.

    Args:
        dim: Size of the feature (last) dimension to normalize over.
        eps: Epsilon for numerical stability.
    """

    def __init__(self, dim: DimLike, *, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = Tensor.ones([dim])
        self.bias = Tensor.zeros([dim])

    def __call__(self, x: Tensor) -> Tensor:
        """Apply layer normalization.

        Args:
            x: Input tensor.

        Returns:
            Normalized tensor.
        """
        return F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)

Next: In Step 04, you’ll implement the feed-forward network (MLP) with GELU activation used in each transformer block.