Weight adaptation

How MAX’s typed parameter interface works, and the three mappings that load GPT-2’s Hugging Face checkpoint into it.

compile(weights=state_dict) maps a dict[str, WeightData] to the named parameters in your module. Each key must match a parameter name exactly, and each value must carry the right shape for that parameter. MAX enforces this at compile time: a mismatched name leaves a parameter uninitialized; a mismatched shape fails the compile.
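The two failure modes can be illustrated with a toy checker in plain NumPy. This is a sketch of the contract, not MAX's actual implementation; the parameter names and shapes are GPT-2 small's real ones:

```python
import numpy as np

# Declared parameters: name -> expected shape (GPT-2 small values).
declared = {
    "transformer.h.0.ln_1.weight": (768,),
    "transformer.h.0.attn.c_attn.weight": (2304, 768),
}

def check_weights(declared, state_dict):
    """Mimic the two failure modes: unmatched name, mismatched shape."""
    # Names with no matching key would be left uninitialized.
    uninitialized = [n for n in declared if n not in state_dict]
    # A matching key with the wrong shape fails outright.
    for name, arr in state_dict.items():
        if name in declared and arr.shape != declared[name]:
            raise ValueError(f"{name}: got {arr.shape}, want {declared[name]}")
    return uninitialized

state_dict = {
    "transformer.h.0.ln_1.weight": np.zeros((768,)),
    "transformer.h.0.attn.c_attn.weight": np.zeros((2304, 768)),
}
check_weights(declared, state_dict)  # -> []: every parameter initialized
```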

The adapter produces the dict[str, WeightData] that compile() requires. Three things satisfy that contract for GPT-2: key renaming to match MAX’s module hierarchy, matrix transposition to match Linear’s declared shape, and an explicit copy of the tied embedding weight the checkpoint omits.

MAX’s typed parameter interface

MAX modules declare parameters by name through the class hierarchy. MaxGPT2LMHeadModel contains a MaxGPT2Model called transformer, which contains a list of MaxGPT2Block layers under h, each with a MaxGPT2Attention called attn, and so on. When MAX loads weights, it walks this hierarchy to construct the expected parameter names: transformer.h.0.attn.c_attn.weight, transformer.h.0.ln_1.weight, lm_head.weight.
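The walk can be sketched with a nested dict standing in for the module tree. The attribute names mirror the text; the traversal itself is illustrative, not MAX's loader:

```python
# Leaf None marks a parameter; dicts are submodules; lists contribute
# numeric path components, like the block index under `h`.
hierarchy = {
    "transformer": {
        "h": [
            {"ln_1": {"weight": None}, "attn": {"c_attn": {"weight": None}}},
        ],
        "wte": {"weight": None},
    },
    "lm_head": {"weight": None},
}

def walk(node, prefix=""):
    if node is None:  # leaf: yield the fully qualified parameter name
        yield prefix
        return
    items = enumerate(node) if isinstance(node, list) else node.items()
    for name, child in items:
        yield from walk(child, f"{prefix}.{name}" if prefix else str(name))

sorted(walk(hierarchy))
# ['lm_head.weight', 'transformer.h.0.attn.c_attn.weight',
#  'transformer.h.0.ln_1.weight', 'transformer.wte.weight']
```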

WeightData.from_numpy(arr, name) binds an array to one of those names. The adapter builds the output dict by producing one WeightData per parameter, with the name MAX expects and an array in the shape MAX expects. That’s the entire contract: name and shape.

For any model you bring up in MAX, this is the same pattern: declare your modules, identify the checkpoint’s naming and layout conventions, and write an adapter that produces dict[str, WeightData] with the keys and shapes your modules declare. The adapter is the explicit boundary between what a checkpoint provides and what MAX’s typed parameter interface requires.

Checkpoint mappings

Key naming: The Hugging Face checkpoint stores keys without the top-level module name: h.0.ln_1.weight. MAX expects transformer.h.0.ln_1.weight. The adapter prepends transformer. to any key that doesn’t already have the prefix.

Shape alignment: OpenAI trained GPT-2 with a custom Conv1D layer that stores weight matrices as [in_features, out_features]. MAX’s Linear declares its weight as [out_features, in_features]. Three layers are affected: c_attn (the combined Q/K/V projection), c_proj (the attention output projection), and c_fc (the MLP expansion). The adapter transposes these before wrapping them in WeightData. All other weight matrices are already in the right layout.
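For c_attn in GPT-2 small the shapes are concrete: hidden size 768 in, 3 × 768 = 2304 out for the combined Q/K/V projection. A minimal NumPy sketch of the layout fix:

```python
import numpy as np

# Conv1D layout in the checkpoint: [in_features, out_features].
ckpt_c_attn = np.zeros((768, 2304), dtype=np.float32)

# MAX's Linear declares [out_features, in_features]. Transpose, then make
# the result contiguous: a bare transpose is only a strided view.
max_c_attn = np.ascontiguousarray(ckpt_c_attn.T)

max_c_attn.shape                   # (2304, 768)
max_c_attn.flags["C_CONTIGUOUS"]   # True
```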

Tied weight: GPT-2’s safetensors file doesn’t include lm_head.weight. The language model head shares its weight matrix with the token embedding table, so the checkpoint omits it to save 38.6M parameters on disk. MAX’s module declares lm_head.weight as a distinct named parameter, so the adapter adds it by copying transformer.wte.weight into a new array under the lm_head.weight key.
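The arithmetic and the copy can be checked in NumPy. GPT-2 small's vocabulary is 50,257 tokens with hidden size 768; this sketch stands in for the adapter's tied-weight branch:

```python
import numpy as np

vocab, hidden = 50257, 768
wte = np.zeros((vocab, hidden), dtype=np.float32)  # token embedding table

# The adapter materializes lm_head.weight as a fresh copy, not a view,
# so the two named parameters don't alias the same buffer.
lm_head = np.array(wte)
lm_head.base is None   # True: independent memory

# What the checkpoint saves by omitting the tied weight:
vocab * hidden         # 38_597_376, the 38.6M parameters from the text
```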

The adapter copies each weight into a fresh NumPy array rather than wrapping the original buffer. GPT-2’s weights arrive as memory-mapped safetensors buffers, read-only views into the file. compile() requires contiguous, writable memory; _to_numpy() ensures that requirement is always met.
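The read-only problem is easy to reproduce without a real safetensors file. Here a NumPy array with its writeable flag cleared stands in for the memory-mapped buffer:

```python
import numpy as np

# Simulate a memory-mapped safetensors buffer: a read-only view.
mmapped = np.arange(6, dtype=np.float32).reshape(2, 3)
mmapped.flags.writeable = False

# np.array() copies into new, contiguous, writable memory,
# which is what compile() requires of every weight buffer.
copied = np.array(mmapped)

copied.flags["WRITEABLE"]      # True
copied.flags["C_CONTIGUOUS"]   # True
```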

Two keys per transformer block are skipped entirely: .attn.bias and .attn.masked_bias. These are pre-computed causal mask buffers, not trainable parameters. The model computes its own causal mask at runtime from causal_mask().

The adapter

convert_safetensor_state_dict() applies all three operations in a single pass over the checkpoint keys:

weight_adapters.py
from __future__ import annotations

import numpy as np
from max.graph.weights import WeightData, Weights

# Layer name suffixes that use Conv1D and need transposing
_CONV1D_LAYERS = ("c_attn", "c_proj", "c_fc")

# Keys in the safetensors that are causal-mask buffers, not parameters.
_SKIP_SUFFIXES = (".attn.bias", ".attn.masked_bias")


def _to_numpy(wd: WeightData) -> np.ndarray:
    # np.from_dlpack() reads via DLPack; np.array() then copies into new,
    # contiguous, writable memory — required by compile().
    return np.array(np.from_dlpack(wd))


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights],
    **unused_kwargs,
) -> dict[str, WeightData]:
    result: dict[str, WeightData] = {}

    for key, value in state_dict.items():
        # Skip causal-mask buffers — they are not model parameters.
        if any(key.endswith(suffix) for suffix in _SKIP_SUFFIXES):
            continue

        mapped_key = (
            key if key.startswith("transformer.") else f"transformer.{key}"
        )
        arr = _to_numpy(value.data())

        # Conv1D stores [in, out]; MAX Linear expects [out, in].
        if any(
            layer in mapped_key for layer in _CONV1D_LAYERS
        ) and mapped_key.endswith(".weight"):
            arr = np.ascontiguousarray(arr.T)

        result[mapped_key] = WeightData.from_numpy(arr, mapped_key)

    # GPT-2 small: lm_head weight is tied to wte; add it explicitly.
    wte_key = "transformer.wte.weight"
    if "lm_head.weight" not in result and wte_key in result:
        wte_arr = _to_numpy(result[wte_key])  # fresh copy, not a view of wte
        result["lm_head.weight"] = WeightData.from_numpy(
            wte_arr, "lm_head.weight"
        )

    return result


The transpose condition checks two things: the key ends in .weight, and it contains one of the three Conv1D layer names. Bias vectors, stored as [out_features] in both conventions, don’t need transposing; only the weight matrices do.

Next: KV cache configuration covers model_config.py, which tells the serving layer how much cache to allocate before the first token runs.