Weight adaptation
How MAX’s typed parameter interface works, and the three mappings that load GPT-2’s Hugging Face checkpoint into it.
compile(weights=state_dict) maps a dict[str, WeightData] to the named
parameters in your module. Each key must match a parameter name exactly, and
each value must carry the right shape for that parameter. MAX enforces this at
compile time: a mismatched name leaves a parameter uninitialized; a mismatched
shape fails the compile.
The adapter produces the dict[str, WeightData] that compile() requires.
Three things satisfy that contract for GPT-2: key renaming to match MAX’s
module hierarchy, matrix transposition to match Linear’s declared shape,
and an explicit copy of the tied embedding weight the checkpoint omits.
MAX’s typed parameter interface
MAX modules declare parameters by name through the class hierarchy.
MaxGPT2LMHeadModel contains a MaxGPT2Model called transformer, which
contains a list of MaxGPT2Block layers under h, each with a
MaxGPT2Attention called attn, and so on. When MAX loads weights, it walks
this hierarchy to construct the expected parameter names:
transformer.h.0.attn.c_attn.weight, transformer.h.0.ln_1.weight,
lm_head.weight.
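The dotted names fall straight out of that hierarchy: each parameter's name is
the chain of attribute names from the root module down to the parameter, joined
with dots. A minimal plain-Python sketch of the naming scheme (the nested dict
below is only an illustration, not how MAX represents modules internally):

# Illustrative only: how dotted parameter names compose from a module tree.
# Dicts stand in for submodules; lists stand in for a leaf module's parameters.
gpt2_tree = {
    "transformer": {
        "wte": ["weight"],
        "h": {
            "0": {  # 12 blocks in GPT-2 small; one shown
                "ln_1": ["weight", "bias"],
                "attn": {"c_attn": ["weight", "bias"]},
            },
        },
    },
    "lm_head": ["weight"],
}

def parameter_names(node, prefix=""):
    if isinstance(node, list):  # leaf module: emit its parameter names
        for param in node:
            yield f"{prefix}.{param}"
    else:  # submodule: recurse with a longer dotted prefix
        for name, child in node.items():
            yield from parameter_names(child, f"{prefix}.{name}" if prefix else name)

print(list(parameter_names(gpt2_tree)))
# ['transformer.wte.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias',
#  'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'lm_head.weight']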
WeightData.from_numpy(arr, name) binds an array to one of those names. The
adapter builds the output dict by producing one WeightData per parameter, with
the name MAX expects and an array in the shape MAX expects. That's the entire
contract: name and shape.
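One entry of that dict, sketched with the WeightData.from_numpy call described
above and GPT-2 small's 768-dimensional hidden size (the shape of a layer-norm
scale vector):

import numpy as np
from max.graph.weights import WeightData

name = "transformer.h.0.ln_1.weight"
arr = np.ones(768, dtype=np.float32)  # layer-norm scale: shape [hidden] = [768]

# The dict key and the WeightData's name are both the parameter name the module
# declares, and the array already has the declared shape.
entry = {name: WeightData.from_numpy(arr, name)}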
For any model you bring up in MAX, this is the same pattern: declare your
modules, identify the checkpoint’s naming and layout conventions, and write an
adapter that produces dict[str, WeightData] with the keys and shapes your
modules declare. The adapter is the explicit boundary between what a checkpoint
provides and what MAX’s typed parameter interface requires.
Checkpoint mappings
Key naming: The Hugging Face checkpoint stores keys without the top-level
module name: h.0.ln_1.weight. MAX expects transformer.h.0.ln_1.weight. The
adapter prepends transformer. to any key that doesn’t already have the prefix.
Shape alignment: OpenAI trained GPT-2 with a custom Conv1D layer that
stores weight matrices as [in_features, out_features]. MAX's Linear
declares its weight as [out_features, in_features]. Three layer names are
affected: c_attn (the combined Q/K/V projection), c_proj (the attention
output projection and the MLP contraction), and c_fc (the MLP expansion). The
adapter transposes these before wrapping them in WeightData. All other weight
matrices are already in the right layout.
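With GPT-2 small's hidden size of 768, the concrete shapes look like this. The
snippet is a sketch of the layout difference only, not the adapter itself:

import numpy as np

hidden = 768  # GPT-2 small

# Conv1D layout in the checkpoint: [in_features, out_features].
c_attn_ckpt = np.zeros((hidden, 3 * hidden), dtype=np.float32)  # [768, 2304], Q/K/V stacked
c_fc_ckpt = np.zeros((hidden, 4 * hidden), dtype=np.float32)    # [768, 3072], MLP expansion

# MAX Linear layout: [out_features, in_features].
c_attn_linear = np.ascontiguousarray(c_attn_ckpt.T)  # [2304, 768]
c_fc_linear = np.ascontiguousarray(c_fc_ckpt.T)      # [3072, 768]

# The attention c_proj is square ([768, 768]), so its shape is unchanged by the
# transpose, but its values still have to move.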
Tied weight: GPT-2’s safetensors file doesn’t include lm_head.weight.
The language model head shares its weight matrix with the token embedding
table, so the checkpoint omits it to save 38.6M parameters on disk. MAX’s
module declares lm_head.weight as a distinct named parameter, so the adapter
adds it by copying transformer.wte.weight into a new array under the
lm_head.weight key.
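That 38.6M figure is just the embedding table itself: 50,257 vocabulary entries
times 768 hidden dimensions. Weight tying means the head's [vocab, hidden]
Linear weight is the same matrix as wte, which is why a plain copy satisfies
lm_head's declared shape with no transpose. A NumPy sketch of the shared-matrix
idea (illustration only, not MAX code):

import numpy as np

vocab, hidden = 50257, 768
wte = np.zeros((vocab, hidden), dtype=np.float32)  # token embedding table

# Embedding lookup: rows of wte, indexed by token id.
token_ids = np.array([0, 42])          # arbitrary token ids
embeddings = wte[token_ids]            # [2, 768]

# LM head with tied weights: the same matrix, used as a [vocab, hidden] Linear
# weight, maps hidden states back to vocabulary logits.
hidden_states = np.zeros((2, hidden), dtype=np.float32)
logits = hidden_states @ wte.T         # [2, 50257]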
The adapter copies each weight into a fresh NumPy array rather than wrapping
the original buffer. GPT-2’s weights arrive as memory-mapped safetensors
buffers, read-only views into the file. compile() requires contiguous,
writable memory; _to_numpy() ensures that requirement is always met.
Two keys per transformer block are skipped entirely: .attn.bias and
.attn.masked_bias. These are pre-computed causal mask buffers, not trainable
parameters. The model computes its own causal mask at runtime from
causal_mask().
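Those buffers hold nothing the model can't rebuild: a causal mask is just a
lower-triangular pattern over sequence positions. An illustrative NumPy version
(not MAX's causal_mask() implementation):

import numpy as np

seq_len = 4
mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
# Row i may attend to columns 0..i only:
# [[ True False False False]
#  [ True  True False False]
#  [ True  True  True False]
#  [ True  True  True  True]]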
The adapter
convert_safetensor_state_dict() applies all three operations in a single
pass over the checkpoint keys:
from __future__ import annotations

import numpy as np

from max.graph.weights import WeightData, Weights

# Layer name suffixes that use Conv1D and need transposing.
_CONV1D_LAYERS = ("c_attn", "c_proj", "c_fc")

# Keys in the safetensors file that are causal-mask buffers, not parameters.
_SKIP_SUFFIXES = (".attn.bias", ".attn.masked_bias")


def _to_numpy(wd: WeightData) -> np.ndarray:
    # np.from_dlpack() reads via DLPack; np.array() then copies into new,
    # contiguous, writable memory, as compile() requires.
    return np.array(np.from_dlpack(wd))


def convert_safetensor_state_dict(
    state_dict: dict[str, Weights],
    **unused_kwargs,
) -> dict[str, WeightData]:
    result: dict[str, WeightData] = {}

    for key, value in state_dict.items():
        # Skip causal-mask buffers; they are not model parameters.
        if any(key.endswith(suffix) for suffix in _SKIP_SUFFIXES):
            continue

        mapped_key = (
            key if key.startswith("transformer.") else f"transformer.{key}"
        )
        arr = _to_numpy(value.data())

        # Conv1D stores [in, out]; MAX Linear expects [out, in].
        if any(
            layer in mapped_key for layer in _CONV1D_LAYERS
        ) and mapped_key.endswith(".weight"):
            arr = np.ascontiguousarray(arr.T)

        result[mapped_key] = WeightData.from_numpy(arr, mapped_key)

    # GPT-2 small: lm_head weight is tied to wte; add it explicitly.
    wte_key = "transformer.wte.weight"
    if "lm_head.weight" not in result and wte_key in result:
        wte_arr = np.array(result[wte_key].data)
        result["lm_head.weight"] = WeightData.from_numpy(
            wte_arr, "lm_head.weight"
        )

    return result
The transpose condition checks two things: the key ends in .weight, and it
contains one of the three Conv1D layer names. Bias vectors, stored as
[out_features] in both conventions, don’t need transposing; only the weight
matrices do.
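Putting it together, a sketch of how the adapter's output feeds the
compile(weights=...) call the section opened with. The checkpoint loader and
model constructor below (load_gpt2_safetensors_state_dict, build_max_gpt2) are
hypothetical stand-ins; they live outside this file:

# Sketch only: loader and constructor names are placeholders, not real APIs.
raw_state_dict = load_gpt2_safetensors_state_dict("gpt2")       # dict[str, Weights]
max_state_dict = convert_safetensor_state_dict(raw_state_dict)  # dict[str, WeightData]

model = build_max_gpt2()
model.compile(weights=max_state_dict)  # names and shapes now match what the modules declare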
Next: KV cache configuration covers model_config.py,
which tells the serving layer how much cache to allocate before the first
token runs.