Load weights and run model

Load pretrained weights from HuggingFace and prepare the model for text generation.

main() brings everything together: it loads OpenAI’s pretrained GPT-2 weights, builds the MAX model, compiles two inference heads, initializes the tokenizer, and starts an interactive session.

Loading and transposing weights

The pretrained weights are loaded from HuggingFace with GPT2LMHeadModel.from_pretrained("gpt2"), then transferred to the MAX model via load_state_dict.

There’s one complication: HuggingFace’s GPT-2 uses Conv1D for its linear layers, which stores weights transposed relative to MAX’s Linear ([in, out] instead of [out, in]). The transposed_state loop pre-transposes the affected layers (c_attn, c_proj, c_fc) before loading, so the weights land in the correct orientation without modifying the model’s layer definitions.
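The transpose logic can be seen in isolation with hypothetical state-dict entries (the tensor names and shapes below are illustrative, matching GPT-2's 768-dim embedding):

```python
import torch

# Illustrative weights showing the orientation mismatch: HuggingFace's
# Conv1D stores [in_features, out_features]; MAX's Linear expects [out, in].
state = {
    "h.0.attn.c_attn.weight": torch.zeros(768, 2304),  # Conv1D: [in, out]
    "h.0.attn.c_attn.bias": torch.zeros(2304),         # biases: untouched
    "h.0.ln_1.weight": torch.zeros(768),               # LayerNorm: untouched
}

transposed = {}
for name, param in state.items():
    # Only the .weight tensors of the Conv1D layers need flipping.
    if any(k in name for k in ["c_attn", "c_proj", "c_fc"]) and name.endswith(
        ".weight"
    ):
        transposed[name] = param.T.contiguous()        # now [out, in]
    else:
        transposed[name] = param

print(transposed["h.0.attn.c_attn.weight"].shape)  # torch.Size([2304, 768])
```

Biases and LayerNorm parameters are one-dimensional, so the `.weight` suffix check leaves them alone even when their name contains `c_attn`.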

Lazy initialization

The model is constructed inside F.lazy():

with F.lazy():
    max_model = MaxGPT2LMHeadModel(config)
    max_model.load_state_dict(transposed_state)

Without F.lazy(), the Embedding and Linear initializers would allocate random tensors immediately, only to discard them when load_state_dict replaces them. Inside the lazy context, those random initializations are deferred—they’re never allocated or compiled. load_state_dict installs the real HuggingFace weights directly, saving both time and memory.
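The deferral idea can be sketched with a toy context manager. This is an illustration of the concept only, not MAX's implementation; `lazy` and `Param` here are hypothetical stand-ins for F.lazy() and the framework's parameter initializers:

```python
import contextlib

_LAZY = False  # toy global flag standing in for the lazy-evaluation context

@contextlib.contextmanager
def lazy():
    global _LAZY
    prev, _LAZY = _LAZY, True
    try:
        yield
    finally:
        _LAZY = prev

class Param:
    def __init__(self, n: int) -> None:
        # Outside lazy(), initialization allocates immediately; inside,
        # the (expensive) random fill is skipped and may never happen.
        self.data = None if _LAZY else [0.0] * n

    def load(self, values) -> None:
        self.data = list(values)  # real weights replace the deferred init

with lazy():
    p = Param(4)              # nothing allocated here
assert p.data is None
p.load([1.0, 2.0, 3.0, 4.0])  # pretrained weights land directly
```

The real mechanism operates on the compiled graph rather than a flag, but the payoff is the same: the random initializers are never materialized before load_state_dict overwrites them.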

Compiling two heads

The model is wrapped in GPT2SamplingHead and GPT2GreedyHead (from Section 10), then each is compiled with TensorType inputs using symbolic dimensions:

token_type = TensorType(DType.int64, ("batch", "seqlen"), device=device)
temp_type  = TensorType(dtype, [], device=device)

compiled_sampler = sampling_head.compile(token_type, temp_type)
compiled_greedy  = greedy_head.compile(token_type)

Symbolic dimensions ("batch", "seqlen") let the compiled model accept any sequence length without recompilation. Compilation takes a few seconds but only happens once per session.
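Why symbolic dimensions avoid recompilation can be modeled with a toy compile cache. This is not MAX's compiler, just a sketch of cache behavior keyed by concrete shapes versus a symbolic key:

```python
# Toy model of shape-specialized vs. symbolic compilation. Concrete shapes
# miss the cache for every new (batch, seqlen) pair; a symbolic key
# compiles once and serves all sequence lengths.
compile_count = 0
cache: dict = {}

def get_compiled(shape_key):
    global compile_count
    if shape_key not in cache:
        compile_count += 1  # stands in for a multi-second compile
        cache[shape_key] = f"kernel<{shape_key}>"
    return cache[shape_key]

# Concrete shapes: one compile per distinct sequence length.
for seqlen in (8, 16, 32):
    get_compiled((1, seqlen))
assert compile_count == 3

# Symbolic dims: every call hits the same cached kernel.
compile_count, cache = 0, {}
for seqlen in (8, 16, 32):
    get_compiled(("batch", "seqlen"))
assert compile_count == 1
```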

The full main function

def main() -> None:
    parser = argparse.ArgumentParser(description="MAX GPT-2 text generation")
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Run timed benchmark instead of interactive generation",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Run single generation with this prompt and exit (non-interactive)",
    )
    parser.add_argument(
        "--chat",
        action="store_true",
        help="Open a rich terminal chat session (Human vs GPT-2)",
    )
    parser.add_argument(
        "--chat-temperature",
        type=float,
        default=0.8,
        help="Sampling temperature for --chat mode (default: 0.8; lower = more focused)",
    )
    args = parser.parse_args()

    dtype, device = defaults()
    print(f"Using device: {device}, dtype: {dtype}")

    # Load HuggingFace model
    torch_dtype = torch.bfloat16 if dtype == DType.bfloat16 else torch.float32
    hf_model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch_dtype)
    print(f"Loaded HuggingFace model:\n{hf_model}")

    config = GPT2Config()
    print(
        f"Model has {config.n_layer} layers, {config.n_head} heads,"
        f" {config.n_embd} embedding dim"
    )

    # 1. Build MAX model and load weights. `defaults()` resolves `device` to
    #    GPU when one is available; input tensors and compile types both use
    #    that device so everything stays on the same device without .to().
    #    HuggingFace GPT-2 Conv1D stores weights as [in, out]; MAX Linear
    #    expects [out, in], so pre-transpose before loading.
    print("Building model and loading weights...", flush=True)
    hf_state = hf_model.state_dict()
    transposed_state: dict[str, torch.Tensor] = {}
    for name, param in hf_state.items():
        if any(k in name for k in ["c_attn", "c_proj", "c_fc"]) and name.endswith(
            ".weight"
        ):
            transposed_state[name] = param.T.contiguous()
        else:
            transposed_state[name] = param

    # F.lazy() defers all ops inside the block — random.normal in
    # Linear.__init__ / Embedding.__init__ is NEVER compiled or allocated.
    # load_state_dict replaces the lazy random tensors with the real HF
    # weights before they are ever realized.
    t0 = time.perf_counter()
    with F.lazy():
        max_model = MaxGPT2LMHeadModel(config)
        max_model.load_state_dict(transposed_state)
    print(
        f"  model init   : {(time.perf_counter() - t0) * 1e3:.0f} ms (lazy)",
        flush=True,
    )

    t0 = time.perf_counter()
    max_model.to(device)
    print(
        f"  to({device})  : {(time.perf_counter() - t0) * 1e3:.0f} ms",
        flush=True,
    )

    # 2. Wrap in compiled heads.
    sampling_head = GPT2SamplingHead(max_model)
    greedy_head = GPT2GreedyHead(max_model)

    token_type = TensorType(DType.int64, ("batch", "seqlen"), device=device)
    temp_type = TensorType(dtype, [], device=device)

    print("\nCompiling sampling model...", flush=True)
    t_compile_start = time.perf_counter()
    compiled_sampler = sampling_head.compile(token_type, temp_type)

    print("Compiling greedy model...", flush=True)
    compiled_greedy = greedy_head.compile(token_type)
    t_compile_end = time.perf_counter()
    print(f"Compile time: {t_compile_end - t_compile_start:.2f}s", flush=True)

    # Initialize tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    if args.benchmark:
        run_benchmark(compiled_sampler, tokenizer, device, dtype)
        return

    if args.prompt:
        generate_text(
            compiled_sampler,
            compiled_greedy,
            tokenizer,
            device,
            dtype,
            args.prompt,
            max_new_tokens=20,
            temperature=0.8,
            do_sample=True,
        )
        return

    if args.chat:
        chat_loop(
            compiled_sampler,
            tokenizer,
            device,
            dtype,
            temperature=args.chat_temperature,
        )
        return

    # Interactive prompt loop
    print("\n" + "=" * 50)
    print("Model ready! Enter prompts to generate text.")
    print("Press Ctrl+C or type 'quit' to exit.")
    print("=" * 50 + "\n")

    try:
        while True:
            user_input = input("Enter your prompt: ").strip()

            if user_input.lower() in ["quit", "exit", "q"]:
                print("Exiting...")
                break

            if not user_input:
                print("Please enter a non-empty prompt.\n")
                continue

            print()
            generated_text = generate_text(
                compiled_sampler,
                compiled_greedy,
                tokenizer,
                device,
                dtype,
                user_input,
                max_new_tokens=50,
                temperature=0.8,
                do_sample=True,
            )
            print(f"\nGenerated text:\n{generated_text}\n")
            print("-" * 50 + "\n")

    except KeyboardInterrupt:
        print("\n\nExiting...")


With --prompt, the model generates 20 tokens and exits. With --chat, it opens the rich terminal chat interface. Without flags, it starts an interactive prompt loop.

Running the model

pixi run gpt2
pixi run gpt2 -- --prompt "Once upon a time"
pixi run gpt2 -- --chat
pixi run gpt2 -- --benchmark

Next: Section 12 walks through the streaming chat implementation—stop sequences, BPE boundary handling, and the rich live rendering that makes the --chat mode work.