Layer normalization
Learn to implement layer normalization for stabilizing neural network training.
In this step, you’ll create the LayerNorm class that normalizes activations
across the feature dimension. For each input, you compute the mean and variance
across all features, normalize by subtracting the mean and dividing by the
standard deviation, then apply learned weight and bias parameters to scale and
shift the result.
Unlike batch normalization, layer normalization works independently for each example. This makes it ideal for transformers - no dependence on batch size, no tracking running statistics during inference, and consistent behavior between training and generation.
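As a rough sketch of this independence (plain NumPy rather than MAX, with a made-up helper name), each row's statistics come only from that row, so its normalized values don't change when the rest of the batch changes:

import numpy as np

def norm_rows(x, eps=1e-5):
    # Per-example statistics over the feature (last) axis only
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

row = np.array([1.0, 2.0, 3.0, 4.0])
small_batch = norm_rows(np.stack([row]))
large_batch = norm_rows(np.stack([row, 10 * row, -5 * row]))
print(np.allclose(small_batch[0], large_batch[0]))  # True: the rest of the batch doesn't matter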
GPT-2 applies layer normalization before the attention and MLP blocks in each of its 12 transformer layers. This pre-normalization pattern stabilizes training in deep networks by keeping activations in a consistent range.
While layer normalization is most critical during training to stabilize gradients and prevent activations from exploding or vanishing, it’s still required during inference. The pretrained GPT-2 model we’re loading was trained with layer normalization - its learned weights and biases expect normalized inputs. Skipping layer normalization during inference would cause activations to be in completely different ranges than what the model learned during training, leading to poor or nonsensical outputs.
Understanding the operation
Layer normalization normalizes across the feature dimension (the last dimension) independently for each example. It learns two parameters per feature: weight (gamma) for scaling and bias (beta) for shifting.
The normalization follows this formula:
output = weight * (x - mean) / sqrt(variance + epsilon) + bias
The mean and variance are computed across all features in each example. After normalizing to zero mean and unit variance, the learned weight scales the result and the learned bias shifts it. The epsilon value (typically 1e-5) prevents division by zero when variance is very small.
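For reference, here's a minimal NumPy sketch of this formula (your MAX implementation below will call F.layer_norm rather than computing it by hand):

import numpy as np

def layer_norm_ref(x, weight, bias, eps=1e-5):
    # Mean and variance over the feature (last) dimension of each example
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    # Normalize, then apply the learned scale (gamma) and shift (beta)
    return weight * (x - mean) / np.sqrt(var + eps) + bias

x = np.array([[0.5, -1.0, 2.0, 3.5]])
weight = np.ones(4)   # gamma: ones means identity scaling
bias = np.zeros(4)    # beta: zeros means no shift
out = layer_norm_ref(x, weight, bias)
print(out.mean(axis=-1), out.std(axis=-1))  # roughly 0 and 1 per example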
You’ll use the following MAX operations to complete this task:
Modules:
Module: The base class for modules built on eager tensors
Tensor initialization:
Tensor.ones(): Creates a tensor filled with 1.0 values
Tensor.zeros(): Creates a tensor filled with 0.0 values
Layer normalization:
F.layer_norm(): Applies layer normalization with parameters: input, gamma (weight), beta (bias), and epsilon
Implementing layer normalization
You’ll create the LayerNorm class that wraps MAX’s layer normalization function
with learnable parameters. The implementation is straightforward - two
parameters and a single function call.
First, import the required modules. You’ll need functional as F for the layer
norm operation and Tensor for creating parameters.
In the __init__ method, create two learnable parameters:
- Weight: Tensor.ones([dim]) stored as self.weight, initialized to ones so the initial transformation is the identity
- Bias: Tensor.zeros([dim]) stored as self.bias, initialized to zeros so there is no initial shift
Store the epsilon value as self.eps for numerical stability.
In the forward method, apply layer normalization with
F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). This
computes the normalization and applies the learned parameters in one operation.
Implementation (step_05.py):
# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Step 05: Layer Normalization
Implement layer normalization that normalizes activations for training stability.
Tasks:
1. Import max.functional (as F) and Tensor from max.tensor
2. Initialize learnable weight (gamma) and bias (beta) parameters
3. Apply layer normalization using F.layer_norm in the forward pass
Run: pixi run s05
"""
# 1: Import the required modules from MAX
# TODO: Import max.functional with the alias F
# https://docs.modular.com/max/api/python/nn/functional
# TODO: Import Tensor from max.tensor
# https://docs.modular.com/max/api/python/tensor.Tensor
from max.graph import DimLike
from max.nn import Module
from max.tensor import Tensor
class LayerNorm(Module):
"""Layer normalization module.
Args:
dim: Dimension to normalize over.
eps: Epsilon for numerical stability.
"""
def __init__(self, dim: DimLike, *, eps: float = 1e-5) -> None:
super().__init__()
self.eps = eps
# 2: Initialize learnable weight and bias parameters
# TODO: Create self.weight as a Tensor of ones with shape [dim]
# https://docs.modular.com/max/api/python/tensor#max.tensor.Tensor.ones
# Hint: This is the gamma parameter in layer normalization
self.weight = None
# TODO: Create self.bias as a Tensor of zeros with shape [dim]
# https://docs.modular.com/max/api/python/tensor#max.tensor.Tensor.zeros
# Hint: This is the beta parameter in layer normalization
self.bias = None
def forward(self, x: Tensor) -> Tensor:
"""Apply layer normalization.
Args:
x: Input tensor.
Returns:
Normalized tensor.
"""
# 3: Apply layer normalization and return the result
# TODO: Use F.layer_norm() with x, gamma=self.weight, beta=self.bias, epsilon=self.eps
# https://docs.modular.com/max/api/python/nn/functional#max.nn.functional.layer_norm
# Hint: Layer normalization normalizes across the last dimension
return None
Validation
Run pixi run s05 to verify your implementation.
Show solution
# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Solution for Step 05: Layer Normalization
This module implements layer normalization that normalizes activations
across the embedding dimension for training stability.
"""
import max.functional as F
from max.graph import DimLike
from max.nn import Module
from max.tensor import Tensor
class LayerNorm(Module):
"""Layer normalization module.
Args:
dim: Dimension to normalize over.
eps: Epsilon for numerical stability.
"""
def __init__(self, dim: DimLike, *, eps: float = 1e-5) -> None:
super().__init__()
self.eps = eps
self.weight = Tensor.ones([dim])
self.bias = Tensor.zeros([dim])
def forward(self, x: Tensor) -> Tensor:
"""Apply layer normalization.
Args:
x: Input tensor.
Returns:
Normalized tensor.
"""
return F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)
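Once the solution is in place, a quick sanity check might look like the following sketch (this assumes calling forward directly and uses a dummy all-ones input; because a constant row has zero variance, the normalized output is just the bias, i.e. all zeros):

# Hypothetical usage: GPT-2 small uses an embedding dimension of 768
ln = LayerNorm(768)
x = Tensor.ones([2, 768])   # dummy batch of 2 embedding vectors
out = ln.forward(x)         # same shape as x; all zeros for a constant input
print(out)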
Next: In Step 06, you’ll combine multi-head attention, MLP, layer norm, and residual connections into a complete transformer block.