Pipeline model
How model.py loads the compiled GPT-2 graph, runs it on each decode step,
and manages the growing token sequence.
To extend a pipeline and connect it to a serving layer, you’ll need to subclass PipelineModelWithKVCache and tell it how to load your model, run each decode step, and manage the token sequence.
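The shape of that contract can be sketched without MAX installed. In this minimal sketch, PipelineModelStub is a hypothetical stand-in for PipelineModelWithKVCache that declares only the four hooks this page covers, with simplified signatures. ToyModel is a toy subclass whose "model" predicts the previous token ID plus one, and its execute() returns that token directly instead of logits, skipping the sampling MAX performs. Neither class reflects the real base class's constructor or full API.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, List


class PipelineModelStub(ABC):
    """Hypothetical stand-in for PipelineModelWithKVCache.

    Declares only the four hooks covered on this page, with simplified
    signatures; the real base class has its own constructor and more duties.
    """

    def __init__(self) -> None:
        # Toy wiring: load the model once, up front.
        self.model = self._load_model(weights=None, adapter=None)

    @abstractmethod
    def _load_model(self, weights: Any, adapter: Any) -> Any: ...

    @abstractmethod
    def execute(self, model_inputs: Any) -> Any: ...

    @abstractmethod
    def prepare_initial_token_inputs(self, replica_batches: Any) -> Any: ...

    @abstractmethod
    def prepare_next_token_inputs(self, next_tokens: Any, prev_model_inputs: Any) -> Any: ...


class ToyModel(PipelineModelStub):
    """Toy subclass whose 'model' predicts the previous token ID plus one."""

    def _load_model(self, weights: Any, adapter: Any) -> Callable[[List[int]], int]:
        return lambda tokens: tokens[-1] + 1

    def execute(self, model_inputs: List[int]) -> int:
        return self.model(model_inputs)

    def prepare_initial_token_inputs(self, replica_batches: List[List[List[int]]]) -> List[int]:
        return list(replica_batches[0][0])  # prompt of the first context

    def prepare_next_token_inputs(self, next_tokens: int, prev_model_inputs: List[int]) -> List[int]:
        return prev_model_inputs + [next_tokens]  # extended sequence


model = ToyModel()
inputs = model.prepare_initial_token_inputs([[[5, 6]]])  # prefill
for _ in range(3):                                       # three decode steps
    inputs = model.prepare_next_token_inputs(model.execute(inputs), inputs)
print(inputs)  # [5, 6, 7, 8, 9]
```

Leaving any of the four hooks unimplemented makes the subclass abstract, so Python refuses to instantiate it; that is the same "tell it how" obligation the paragraph describes.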
Load the model
Every Linear and Embedding layer in MaxGPT2LMHeadModel allocates its weight tensors when it is constructed. Without a lazy context, those tensors are filled with random initial values, only to be discarded as soon as the checkpoint loads.
F.lazy() defers all allocation inside the block: layers are declared, but nothing is allocated until compile() runs.
default_device() and default_dtype() set context variables that module construction code reads inside the lazy block, so layers pick up the right device and numeric type without having them passed explicitly.
compile() runs outside the lazy block. Loading safetensors buffers inside F.lazy() triggers the same memory alignment error that _to_numpy() in weight_adapters.py solves by copying into a fresh array. Passing weights=state_dict to compile() loads and compiles in one step after the lazy context closes:
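```python
def _load_model(
    self,
    weights: Weights,
    adapter: WeightsAdapter | None,
) -> Any:
    hf_config = self.huggingface_config
    device = self.devices[0]
    # Collect the checkpoint weights into a state dict (applying the adapter, if any).
    state_dict = parse_state_dict_from_weights(
        self.pipeline_config, weights, adapter
    )
    # Declare the module lazily: no tensors are allocated inside this block.
    with F.lazy(), default_device(device), default_dtype(self.dtype):
        gpt2_module = MaxGPT2LMHeadModel(hf_config)
        gpt2_module.to(device)
    # Outside the lazy block: describe the input, then compile with real weights.
    token_type = TensorType(
        DType.int64, ("batch", "seq_len"), device=device
    )
    return gpt2_module.compile(token_type, weights=state_dict)
```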
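The alignment problem itself can be illustrated with NumPy alone. This sketch assumes the error comes from an array view whose data pointer is not a multiple of its element size, as can happen with a view into a raw or memory-mapped byte buffer; to_aligned_numpy() is a hypothetical helper written for this illustration, not the actual _to_numpy() in weight_adapters.py.

```python
import numpy as np


def to_aligned_numpy(arr: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return `arr` if it is aligned and C-contiguous,
    otherwise copy it into a freshly allocated (and therefore aligned) array."""
    if arr.flags.aligned and arr.flags.c_contiguous:
        return arr
    return np.array(arr, copy=True, order="C")


# Stand-in for a raw byte buffer such as a memory-mapped checkpoint file.
raw = np.arange(17, dtype=np.uint8)
# Reinterpret bytes 1..16 as four float32 values: the view starts one byte
# past an aligned address, so its data pointer is not a multiple of 4.
misaligned = raw[1:].view(np.float32)
fixed = to_aligned_numpy(misaligned)

print(misaligned.flags.aligned, fixed.flags.aligned)  # False True
```

The copy costs one extra allocation per tensor, which is why it is done once at load time rather than on every step.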
Execute a step
execute() receives a GPT2Inputs, a dataclass whose tokens field is a [1, seq_len] int64 Buffer containing all token IDs for the current sequence (it also carries the kv_cache_inputs passed along by the prepare methods).
Tensor.from_dlpack() converts the driver Buffer to a MAX Tensor without copying. The compiled model returns [1, seq_len, vocab_size]: one logit vector per position. Only the final position’s logits are needed to sample the next token, and MAX’s serving infrastructure handles the sampling itself (temperature scaling, top-p filtering, and token selection), so execute() only narrows the output to [1, vocab_size]. Here the narrowing is a round trip through NumPy on the CPU:
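```python
def execute(self, model_inputs: ModelInputs) -> ModelOutputs:
    assert isinstance(model_inputs, GPT2Inputs)
    # Wrap the driver Buffer as a MAX Tensor (no copy), then place it on the model's device.
    input_tensor = Tensor.from_dlpack(model_inputs.tokens).to(
        self.devices[0]
    )
    # [1, seq_len, vocab_size]: one logit vector per position.
    all_logits: Tensor = self.model(input_tensor)
    # Copy to the host and keep only the last position: [1, vocab_size].
    last_logits_np: np.ndarray = np.from_dlpack(all_logits.to(CPU()))
    last_logits_np = np.ascontiguousarray(last_logits_np[0, -1:, :])
    # Move the narrowed logits back to the device for the sampler.
    last_buf = Buffer.from_numpy(last_logits_np).to(self.devices[0])
    return ModelOutputs(logits=last_buf, next_token_logits=last_buf)
```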
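To see the shapes and the sampling stage side by side, this NumPy sketch mirrors the [0, -1:, :] slice from execute() and then applies temperature scaling, top-p filtering, and token selection to the narrowed logits. sample_top_p() is a generic nucleus-sampling illustration written for this page, not MAX's sampler, whose exact behavior may differ.

```python
import numpy as np


def last_position_logits(all_logits: np.ndarray) -> np.ndarray:
    """[1, seq_len, vocab_size] -> contiguous [1, vocab_size] for the final position."""
    return np.ascontiguousarray(all_logits[0, -1:, :])


def sample_top_p(
    logits: np.ndarray, temperature: float, top_p: float, rng: np.random.Generator
) -> int:
    """Generic nucleus sampling over a 1-D logit vector (illustrative only)."""
    scaled = logits.astype(np.float64) / temperature  # temperature scaling
    probs = np.exp(scaled - scaled.max())             # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]                   # most likely token first
    cumulative = np.cumsum(probs[order])
    # Smallest prefix of tokens whose total probability reaches top_p.
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]
    kept_probs = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=kept_probs))        # token selection


rng = np.random.default_rng(0)
all_logits = rng.standard_normal((1, 5, 8)).astype(np.float32)  # seq_len=5, vocab=8
last = last_position_logits(all_logits)
print(last.shape)  # (1, 8)
token_id = sample_top_p(last[0], temperature=0.7, top_p=0.9, rng=rng)
```

Lower temperatures sharpen the distribution before filtering; a very small top_p keeps only the most likely token, which reduces sampling to greedy decoding.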
Manage the token sequence
On the first step (prefill), prepare_initial_token_inputs() reads the full
prompt from ctx.tokens.all and packages it as GPT2Inputs. On each decode
step, prepare_next_token_inputs() appends the newly sampled token to the
previous token array and returns the extended sequence.
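Because this implementation does not use an incremental KV cache (kv_cache_inputs is passed through the prepare methods but never read by execute()), every decode step re-processes the full token history from position 0. Generating 30 tokens from a 10-token prompt means the prefill step processes 10 tokens, the first decode step 11, the second 12, and so on up to 39 tokens on the 29th and final decode step. The implementation stays simple at the cost of efficiency: the number of tokens processed per step grows linearly with sequence length, so the total number of positions processed over a generation grows quadratically.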
```python
def prepare_initial_token_inputs(
    self,
    replica_batches: Sequence[Sequence[TextContext]],
    kv_cache_inputs: KVCacheInputs[Buffer, Buffer] | None = None,
    return_n_logits: int = 1,
) -> GPT2Inputs:
    _ = return_n_logits  # PipelineModel API; last-token logits only in `execute`.
    # Only the first context of the first replica is used (batch size 1).
    ctx = replica_batches[0][0]
    token_ids = _tokens_from_context(ctx)
    inputs = _make_gpt2_inputs(token_ids, self.devices[0])
    inputs.kv_cache_inputs = kv_cache_inputs
    return inputs


def prepare_next_token_inputs(
    self,
    next_tokens: Buffer,
    prev_model_inputs: ModelInputs,
) -> GPT2Inputs:
    assert isinstance(prev_model_inputs, GPT2Inputs)
    # Bring the previous sequence and the newly sampled token to the host.
    prev_np: np.ndarray = np.from_dlpack(prev_model_inputs.tokens.to(CPU()))
    new_token_np: np.ndarray = np.from_dlpack(next_tokens.to(CPU()))
    new_token = int(new_token_np.ravel()[0])
    # Append it and rebuild GPT2Inputs from the full, extended history.
    extended = np.concatenate([prev_np.ravel(), [new_token]])
    inputs = _make_gpt2_inputs(extended.tolist(), self.devices[0])
    inputs.kv_cache_inputs = prev_model_inputs.kv_cache_inputs
    return inputs
```
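The growth described above can be checked with a short pure-Python simulation. It assumes nothing about MAX: the token IDs are placeholders, and the loop only mimics the append-and-rerun pattern of prepare_next_token_inputs() to count how many positions each step feeds to the model, with and without an incremental KV cache.

```python
from typing import List


def tokens_per_step(prompt_len: int, new_tokens: int, kv_cache: bool) -> List[int]:
    """Count the positions the model processes at each generation step.

    Step 0 is prefill; every later step is one decode step that yields one token.
    """
    tokens = list(range(prompt_len))  # placeholder prompt token IDs
    counts = []
    for step in range(new_tokens):
        if kv_cache and step > 0:
            counts.append(1)            # only the newest token is new work
        else:
            counts.append(len(tokens))  # full history re-run from position 0
        tokens.append(-1)               # placeholder for the sampled token
    return counts


no_cache = tokens_per_step(10, 30, kv_cache=False)
cached = tokens_per_step(10, 30, kv_cache=True)
print(no_cache[:3], no_cache[-1], sum(no_cache))  # [10, 11, 12] 39 735
print(sum(cached))                                # 39
```

Without a cache the total is a sum of consecutive sequence lengths, so it grows quadratically with the number of generated tokens; with one, each decode step adds a single new position (attention still reads the cached history).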
Next: Architecture registration covers arch.py and
__init__.py, the three-line contract that plugs the whole package into
max serve.