Step 08: Residual connections and layer normalization

Learn to implement residual connections and layer normalization to enable training deep transformer networks.

Building the residual pattern

In this step, you’ll combine residual connections and layer normalization into a reusable pattern for transformer blocks. Residual connections add the input directly to the output using output = input + layer(input), creating shortcuts that let gradients flow through deep networks. You’ll implement this alongside the layer normalization from Step 03.

GPT-2 uses a pre-norm architecture, in which layer norm is applied before each sublayer (attention or MLP). The pattern is x = x + sublayer(layer_norm(x)): normalize first, process, then add the original input back. In deep networks this tends to train more stably than the post-norm alternative.

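Residual connections mitigate the vanishing gradient problem. The derivative of x + sublayer(x) with respect to x is the identity plus the sublayer's derivative. During backpropagation, the identity path (x = x + ...) therefore passes gradients back without multiplying them by the layer's weights. This is what makes networks with 12 or more layers trainable. Without residuals, each layer multiplies the gradient by its own derivative, so gradients can shrink exponentially as they propagate back through many layers.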
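The effect of the identity path can be seen in a deliberately tiny model. This sketch assumes each "layer" is a scalar linear map f(x) = w * x, which is a toy stand-in for attention or an MLP and not part of the tutorial code. It compares the input gradient of a plain stack with that of a residual stack:

```python
# Toy model (assumption): each layer is a scalar linear map f(x) = w * x.
# Plain stack:    x_{k+1} = f(x_k)        -> d x_L / d x_0 = w ** L
# Residual stack: x_{k+1} = x_k + f(x_k)  -> d x_L / d x_0 = (1 + w) ** L


def plain_grad(w: float, num_layers: int) -> float:
    """Gradient of the output w.r.t. the input for a stack without residuals."""
    grad = 1.0
    for _ in range(num_layers):
        grad *= w  # chain rule: multiply by each layer's derivative
    return grad


def residual_grad(w: float, num_layers: int) -> float:
    """Gradient for a stack where each layer computes x + w * x."""
    grad = 1.0
    for _ in range(num_layers):
        grad *= 1.0 + w  # the identity path contributes the 1.0
    return grad


if __name__ == "__main__":
    w = 0.01  # small weights, similar in spirit to an initialized network
    for layers in (1, 12, 48):
        print(f"{layers:3d} layers: plain={plain_grad(w, layers):.3e}  "
              f"residual={residual_grad(w, layers):.3e}")
```

With w = 0.01 and 12 layers, the plain stack's gradient is about 1e-24. The residual stack's gradient stays near 1.13, because the identity path keeps it from collapsing.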

Layer normalization behaves identically during training and inference because it normalizes each token's feature vector independently. It uses no batch statistics and no running averages, so normalization is consistent and activation distributions stay stable throughout training.
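One way to check the "no batch statistics" property is to normalize a whole batch, then normalize one example on its own, and compare the results. This NumPy sketch uses a hypothetical layer_norm helper that normalizes over the last axis. It illustrates the math and is not MAX's F.layer_norm:

```python
import numpy as np


def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize over the last axis (the embedding dimension) of each position."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as in standard layer norm
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


rng = np.random.default_rng(0)
batch = rng.normal(size=(4, 3, 8))  # [batch, seq_length, embed_dim]
gamma, beta = np.ones(8), np.zeros(8)

full = layer_norm(batch, gamma, beta)
single = layer_norm(batch[:1], gamma, beta)  # the first example on its own

# The first example's output does not depend on the other examples in the batch.
print(np.allclose(full[:1], single))  # True
```

Because the statistics come only from each position's own features, training (large batches) and inference (often a single sequence) produce the same result for the same input.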

Understanding the pattern

The pre-norm residual pattern combines three operations in sequence:

Layer normalization: Normalize the input with F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). This uses learnable weight (gamma) and bias (beta) parameters to scale and shift the normalized values. You already implemented this in Step 03.

Sublayer processing: Pass the normalized input through a sublayer (attention or MLP). The sublayer transforms the data while the layer norm keeps its input well-conditioned.

Residual addition: Add the original input back to the sublayer output using simple element-wise addition: x + sublayer_output. Both tensors must have identical shapes [batch, seq_length, embed_dim].

The complete pattern is x = x + sublayer(layer_norm(x)). The post-norm alternative is x = layer_norm(x + sublayer(x)). Pre-norm is more stable because normalization happens before the potentially unstable sublayer operations, while the residual path carries x forward unchanged.
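The two orderings differ only in where the normalization sits. This NumPy sketch writes both side by side. It assumes a toy linear sublayer (toy_sublayer and its weight matrix W are hypothetical) and a layer norm with gamma = 1 and beta = 0, which are the initial parameter values:

```python
import numpy as np


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # gamma = 1, beta = 0 (the initial parameter values)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def pre_norm(x: np.ndarray, sublayer) -> np.ndarray:
    """GPT-2 style: x + sublayer(layer_norm(x))."""
    return x + sublayer(layer_norm(x))


def post_norm(x: np.ndarray, sublayer) -> np.ndarray:
    """Post-norm style (original Transformer): layer_norm(x + sublayer(x))."""
    return layer_norm(x + sublayer(x))


rng = np.random.default_rng(0)
embed_dim = 8
W = rng.normal(scale=0.1, size=(embed_dim, embed_dim))  # hypothetical sublayer weights


def toy_sublayer(h: np.ndarray) -> np.ndarray:
    # Stand-in for attention or the MLP: a single linear projection.
    return h @ W


x = rng.normal(size=(2, 5, embed_dim))  # [batch, seq_length, embed_dim]
pre, post = pre_norm(x, toy_sublayer), post_norm(x, toy_sublayer)
print(pre.shape, post.shape)  # (2, 5, 8) (2, 5, 8)
```

In the pre-norm version, pre - x is exactly the sublayer's contribution, so the raw input flows straight through to the next block. In the post-norm version, every block's output is re-normalized, so the residual stream itself passes through layer norm at each layer.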

MAX operations

You'll use the following MAX operations to complete this task:

Layer normalization:

  • F.layer_norm

Tensor initialization:

  • Tensor.ones
  • Tensor.zeros

Implementing the pattern

You'll implement two classes and one function that demonstrate the residual pattern: LayerNorm for normalization, ResidualBlock, which combines layer norm with the residual addition, and a standalone apply_residual_connection function.

First, import the required modules. You’ll need functional as F for layer norm, Tensor for parameters, DimLike for type hints, and Module as the base class.

LayerNorm implementation:

In __init__, create the learnable parameters:

  • Weight: Tensor.ones([dim]) stored as self.weight
  • Bias: Tensor.zeros([dim]) stored as self.bias
  • Store eps for numerical stability

In __call__, apply normalization with F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps). It returns a normalized tensor with the same shape as the input.
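To see what the layer norm computation does numerically, here is a NumPy stand-in that mirrors the LayerNorm module's structure: weight initialized to ones, bias to zeros, and eps stored on the instance. NumpyLayerNorm is a hypothetical name. It is not the MAX module, and it only approximates what F.layer_norm is documented to compute, (x - mean) / sqrt(var + eps) * gamma + beta over the last axis:

```python
import numpy as np


class NumpyLayerNorm:
    """NumPy stand-in for the LayerNorm module (not the MAX API)."""

    def __init__(self, dim: int, *, eps: float = 1e-5):
        self.eps = eps
        self.weight = np.ones(dim)   # gamma: starts as "no scaling"
        self.bias = np.zeros(dim)    # beta: starts as "no shift"

    def __call__(self, x: np.ndarray) -> np.ndarray:
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        return x_hat * self.weight + self.bias


ln = NumpyLayerNorm(8)
x = np.random.default_rng(1).normal(loc=3.0, scale=2.0, size=(2, 4, 8))
y = ln(x)
print(y.shape)                                      # (2, 4, 8)
print(np.allclose(y.mean(axis=-1), 0.0))            # True
print(np.allclose(y.var(axis=-1), 1.0, atol=1e-3))  # True
```

With the initial ones and zeros, each position's output has roughly zero mean and unit variance. During training, weight and bias learn to rescale and shift those values where that helps.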

ResidualBlock implementation:

In __init__, create a LayerNorm instance: self.ln = LayerNorm(dim, eps=eps). This will normalize inputs before sublayers.

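In __call__, complete the pre-norm pattern. The full pattern has three steps:

  1. Normalize: normalized = self.ln(x)
  2. Process: sublayer_output = sublayer(normalized)
  3. Add residual: return x + sublayer_output

To keep this step self-contained, __call__ receives sublayer_output already computed, so it performs only step 3: return x + sublayer_output.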
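The stub's __call__ takes sublayer_output as an argument rather than a sublayer. This means the caller performs steps 1 and 2, and the block performs step 3. The NumPy sketch below mirrors that split. ResidualBlockSketch and toy_mlp are hypothetical stand-ins and are not part of the tutorial code:

```python
import numpy as np


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


class ResidualBlockSketch:
    """NumPy mirror of ResidualBlock: owns the norm, adds the residual."""

    def __init__(self, eps: float = 1e-5):
        self.eps = eps

    def ln(self, x: np.ndarray) -> np.ndarray:
        return layer_norm(x, self.eps)

    def __call__(self, x: np.ndarray, sublayer_output: np.ndarray) -> np.ndarray:
        # Step 3 only: the residual addition on identically shaped tensors.
        return x + sublayer_output


def toy_mlp(h: np.ndarray) -> np.ndarray:
    # Hypothetical sublayer standing in for attention or the MLP.
    return 0.5 * np.tanh(h)


block = ResidualBlockSketch()
x = np.random.default_rng(2).normal(size=(2, 4, 8))  # [batch, seq_length, embed_dim]
sublayer_output = toy_mlp(block.ln(x))  # steps 1 and 2, done by the caller
out = block(x, sublayer_output)         # step 3, done by the block
```

Passing the sublayer's output in, rather than the sublayer itself, keeps this step focused on the residual addition. The full wiring of attention and MLP sublayers comes when the pieces are assembled into a transformer block.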

Standalone function:

Implement apply_residual_connection(input_tensor, sublayer_output) that returns input_tensor + sublayer_output. This demonstrates the residual pattern as a simple function.
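The residual addition assumes both tensors have the same [batch, seq_length, embed_dim] shape. In NumPy, adding a [embed_dim] vector to a [batch, seq_length, embed_dim] tensor would broadcast silently instead of failing. This sketch therefore adds an explicit shape check. apply_residual_connection_checked is a hypothetical variant for illustration, and whether MAX broadcasts in the same way is not assumed here:

```python
import numpy as np


def apply_residual_connection_checked(input_tensor: np.ndarray,
                                      sublayer_output: np.ndarray) -> np.ndarray:
    """Element-wise residual add that rejects mismatched shapes."""
    if input_tensor.shape != sublayer_output.shape:
        raise ValueError(
            f"shape mismatch: {input_tensor.shape} vs {sublayer_output.shape}"
        )
    return input_tensor + sublayer_output


x = np.zeros((2, 4, 8))
print(apply_residual_connection_checked(x, np.ones((2, 4, 8))).sum())  # 64.0

try:
    # Plain NumPy addition would silently broadcast this [embed_dim] vector.
    apply_residual_connection_checked(x, np.ones(8))
except ValueError as err:
    print(err)  # shape mismatch: (2, 4, 8) vs (8,)
```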

Implementation (step_08.py):

"""
Step 08: Residual Connections and Layer Normalization

Implement layer normalization and residual connections, which enable
training deep transformer networks by stabilizing gradients.

Tasks:
1. Import F (functional), Tensor, DimLike, and Module
2. Create LayerNorm class with learnable weight and bias parameters
3. Implement layer norm using F.layer_norm
4. Implement residual connection (simple addition)

Run: pixi run s08
"""

# TODO: Import required modules
# Hint: You'll need F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor
# Hint: You'll need DimLike from max.graph
# Hint: You'll need Module from max.nn.module_v3


class LayerNorm(Module):
    """Layer normalization module matching HuggingFace GPT-2."""

    def __init__(self, dim: DimLike, *, eps: float = 1e-5):
        """Initialize layer normalization.

        Args:
            dim: Dimension to normalize (embedding dimension)
            eps: Small epsilon for numerical stability
        """
        super().__init__()
        self.eps = eps

        # TODO: Create learnable scale parameter (weight)
        # Hint: Use Tensor.ones([dim])
        self.weight = None

        # TODO: Create learnable shift parameter (bias)
        # Hint: Use Tensor.zeros([dim])
        self.bias = None

    def __call__(self, x: Tensor) -> Tensor:
        """Apply layer normalization.

        Args:
            x: Input tensor, shape [..., dim]

        Returns:
            Normalized tensor, same shape as input
        """
        # TODO: Apply layer normalization
        # Hint: Use F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)
        return None


class ResidualBlock(Module):
    """Demonstrates residual connections with layer normalization."""

    def __init__(self, dim: int, eps: float = 1e-5):
        """Initialize residual block.

        Args:
            dim: Dimension of the input/output
            eps: Epsilon for layer normalization
        """
        super().__init__()

        # TODO: Create layer normalization
        # Hint: Use LayerNorm(dim, eps=eps)
        self.ln = None

    def __call__(self, x: Tensor, sublayer_output: Tensor) -> Tensor:
        """Apply residual connection.

        Args:
            x: Input tensor (the residual)
            sublayer_output: Output from sublayer applied to ln(x)

        Returns:
            x + sublayer_output
        """
        # TODO: Add input and sublayer output (residual connection)
        # Hint: return x + sublayer_output
        return None


def apply_residual_connection(input_tensor: Tensor, sublayer_output: Tensor) -> Tensor:
    """Apply a residual connection by adding input to sublayer output.

    Args:
        input_tensor: Original input (the residual)
        sublayer_output: Output from a sublayer (attention, MLP, etc.)

    Returns:
        input_tensor + sublayer_output
    """
    # TODO: Add the two tensors
    # Hint: return input_tensor + sublayer_output
    return None

Validation

Run pixi run s08 to verify your implementation.

Solution:
"""
Solution for Step 08: Residual Connections and Layer Normalization

This module implements layer normalization and demonstrates residual connections,
which are essential for training deep transformer networks.
"""

from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DimLike
from max.nn.module_v3 import Module


class LayerNorm(Module):
    """Layer normalization module matching HuggingFace GPT-2.

    Layer norm normalizes activations across the feature dimension,
    stabilizing training and allowing deeper networks.
    """

    def __init__(self, dim: DimLike, *, eps: float = 1e-5):
        """Initialize layer normalization.

        Args:
            dim: Dimension to normalize (embedding dimension)
            eps: Small epsilon for numerical stability
        """
        super().__init__()
        self.eps = eps
        # Learnable scale parameter (gamma)
        self.weight = Tensor.ones([dim])
        # Learnable shift parameter (beta)
        self.bias = Tensor.zeros([dim])

    def __call__(self, x: Tensor) -> Tensor:
        """Apply layer normalization.

        Args:
            x: Input tensor, shape [..., dim]

        Returns:
            Normalized tensor, same shape as input
        """
        return F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)


class ResidualBlock(Module):
    """Demonstrates residual connections with layer normalization.

    This shows the pre-norm architecture used in GPT-2:
    output = input + sublayer(layer_norm(input))
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Initialize residual block.

        Args:
            dim: Dimension of the input/output
            eps: Epsilon for layer normalization
        """
        super().__init__()
        self.ln = LayerNorm(dim, eps=eps)

    def __call__(self, x: Tensor, sublayer_output: Tensor) -> Tensor:
        """Apply residual connection.

        This demonstrates the pattern:
        1. Normalize input: ln(x)
        2. Apply sublayer (passed as argument for simplicity)
        3. Add residual: x + sublayer_output

        In practice, the sublayer (attention or MLP) is applied to ln(x),
        but we receive the result as a parameter for clarity.

        Args:
            x: Input tensor (the residual)
            sublayer_output: Output from sublayer applied to ln(x)

        Returns:
            x + sublayer_output
        """
        # In a real transformer block, you would do:
        # residual = x
        # x = self.ln(x)
        # x = sublayer(x)  # e.g., attention or MLP
        # x = x + residual

        # For this demonstration, we just add
        return x + sublayer_output


def apply_residual_connection(input_tensor: Tensor, sublayer_output: Tensor) -> Tensor:
    """Apply a residual connection by adding input to sublayer output.

    Residual connections allow gradients to flow directly through the network,
    enabling training of very deep models.

    Args:
        input_tensor: Original input (the residual)
        sublayer_output: Output from a sublayer (attention, MLP, etc.)

    Returns:
        input_tensor + sublayer_output
    """
    return input_tensor + sublayer_output

Next: In Step 09, you’ll combine multi-head attention, MLP, layer norm, and residual connections into a complete transformer block.