Load weights and run model
Load pretrained weights from HuggingFace and prepare the model for text generation.
main() brings everything together: it loads OpenAI’s pretrained GPT-2
weights, builds the MAX model, compiles two inference heads, initializes the
tokenizer, and starts an interactive session.
Loading and transposing weights
The pretrained weights are loaded from HuggingFace with
GPT2LMHeadModel.from_pretrained("gpt2"), then transferred into the MAX model
via load_state_dict.
There’s one complication: HuggingFace’s GPT-2 uses Conv1D for its linear
layers, which stores weights transposed relative to MAX’s Linear
([in, out] instead of [out, in]). The transposed_state loop
pre-transposes the affected layers (c_attn, c_proj, c_fc) before loading,
so the weights land in the correct orientation without modifying the model’s
layer definitions.
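The selection rule can be isolated as a small predicate. This is a sketch for illustration (the loading loop below inlines the same test rather than calling a helper); the state-dict key names follow HuggingFace's GPT-2 naming:

```python
# Sketch of the transpose rule: only Conv1D weight matrices (c_attn,
# c_proj, c_fc) need transposing; biases, embeddings, and LayerNorm
# parameters pass through unchanged.
def needs_transpose(name: str) -> bool:
    is_conv1d = any(k in name for k in ["c_attn", "c_proj", "c_fc"])
    return is_conv1d and name.endswith(".weight")

print(needs_transpose("transformer.h.0.attn.c_attn.weight"))  # True
print(needs_transpose("transformer.h.0.attn.c_attn.bias"))    # False
print(needs_transpose("transformer.h.0.ln_1.weight"))         # False
```

Note that the `.weight` suffix check matters: Conv1D biases are plain vectors and need no transpose.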
Lazy initialization
The model is constructed inside F.lazy():
with F.lazy():
    max_model = MaxGPT2LMHeadModel(config)
    max_model.load_state_dict(transposed_state)
Without F.lazy(), the Embedding and Linear initializers would allocate
random tensors immediately, only to discard them when load_state_dict replaces
them. Inside the lazy context, those random initializations are deferred—they’re
never allocated or compiled. load_state_dict installs the real HuggingFace
weights directly, saving both time and memory.
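The effect can be mimicked with a toy thunk in pure Python (this is an illustration of the idea, not MAX's actual mechanism): the expensive initializer is stored as a callable and runs only on first read, so installing the real weights first means the random init never executes.

```python
# Toy illustration of lazy initialization (not MAX's API): the initializer
# is deferred as a thunk and triggered only when the value is first read.
class LazyTensor:
    def __init__(self, init_fn):
        self._init_fn = init_fn
        self._value = None
        self.realized = False

    def materialize(self):
        if not self.realized:          # first read triggers the init
            self._value = self._init_fn()
            self.realized = True
        return self._value

init_calls = []

def random_init():
    init_calls.append("random init")   # track whether the init ever ran
    return [0.0] * 4

weight = LazyTensor(random_init)

# load_state_dict-style replacement: real weights installed before any
# read, so the random initializer is never invoked.
weight._value, weight.realized = [1.0, 2.0, 3.0, 4.0], True

print(weight.materialize())  # [1.0, 2.0, 3.0, 4.0]
print(init_calls)            # [] (random_init never ran)
```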
Compiling two heads
The model is wrapped in GPT2SamplingHead and GPT2GreedyHead (from Section
10), then each is compiled with TensorType inputs using symbolic dimensions:
token_type = TensorType(DType.int64, ("batch", "seqlen"), device=device)
temp_type = TensorType(dtype, [], device=device)
compiled_sampler = sampling_head.compile(token_type, temp_type)
compiled_greedy = greedy_head.compile(token_type)
Symbolic dimensions ("batch", "seqlen") let the compiled model accept any
sequence length without recompilation. Compilation takes a few seconds but only
happens once per session.
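A toy compile cache shows what symbolic dimensions buy (a sketch, not MAX's compiler): keying compiled artifacts by concrete shape forces one compile per new sequence length, while a symbolic key compiles once and serves every length.

```python
# Toy model of a compile cache (not MAX's implementation). Concrete shape
# keys miss the cache for every new sequence length; a symbolic key like
# ("batch", "seqlen") hits after the first compile.
compile_count = 0

def compile_model(key):
    global compile_count
    compile_count += 1                 # stand-in for the expensive step
    return lambda tokens: len(tokens)  # stand-in for a compiled model

cache = {}

def run(tokens, symbolic):
    key = ("batch", "seqlen") if symbolic else (1, len(tokens))
    if key not in cache:
        cache[key] = compile_model(key)
    return cache[key](tokens)

for n in (4, 8, 16):                   # three different sequence lengths
    run(list(range(n)), symbolic=True)
print(compile_count)  # 1

for n in (4, 8, 16):
    run(list(range(n)), symbolic=False)
print(compile_count)  # 4 (one extra compile per concrete shape)
```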
The full main function
def main() -> None:
    parser = argparse.ArgumentParser(description="MAX GPT-2 text generation")
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Run timed benchmark instead of interactive generation",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Run single generation with this prompt and exit (non-interactive)",
    )
    parser.add_argument(
        "--chat",
        action="store_true",
        help="Open a rich terminal chat session (Human vs GPT-2)",
    )
    parser.add_argument(
        "--chat-temperature",
        type=float,
        default=0.8,
        help="Sampling temperature for --chat mode (default: 0.8; lower = more focused)",
    )
    args = parser.parse_args()

    dtype, device = defaults()
    print(f"Using device: {device}, dtype: {dtype}")

    # Load HuggingFace model
    torch_dtype = torch.bfloat16 if dtype == DType.bfloat16 else torch.float32
    hf_model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch_dtype)
    print(f"Loaded HuggingFace model:\n{hf_model}")

    config = GPT2Config()
    print(
        f"Model has {config.n_layer} layers, {config.n_head} heads,"
        f" {config.n_embd} embedding dim"
    )

    # 1. Build MAX model and load weights. `defaults()` resolves `device` to
    #    GPU when one is available; input tensors and compile types both use
    #    that device so everything stays on the same device without .to().
    #    HuggingFace GPT-2 Conv1D stores weights as [in, out]; MAX Linear
    #    expects [out, in], so pre-transpose before loading.
    print("Building model and loading weights...", flush=True)
    hf_state = hf_model.state_dict()
    transposed_state: dict[str, torch.Tensor] = {}
    for name, param in hf_state.items():
        if any(k in name for k in ["c_attn", "c_proj", "c_fc"]) and name.endswith(
            ".weight"
        ):
            transposed_state[name] = param.T.contiguous()
        else:
            transposed_state[name] = param

    # F.lazy() defers all ops inside the block — random.normal in
    # Linear.__init__ / Embedding.__init__ is NEVER compiled or allocated.
    # load_state_dict replaces the lazy random tensors with the real HF
    # weights before they are ever realized.
    t0 = time.perf_counter()
    with F.lazy():
        max_model = MaxGPT2LMHeadModel(config)
        max_model.load_state_dict(transposed_state)
    print(
        f" model init : {(time.perf_counter() - t0) * 1e3:.0f} ms (lazy)",
        flush=True,
    )

    t0 = time.perf_counter()
    max_model.to(device)
    print(
        f" to({device}) : {(time.perf_counter() - t0) * 1e3:.0f} ms",
        flush=True,
    )

    # 2. Wrap in compiled heads.
    sampling_head = GPT2SamplingHead(max_model)
    greedy_head = GPT2GreedyHead(max_model)

    token_type = TensorType(DType.int64, ("batch", "seqlen"), device=device)
    temp_type = TensorType(dtype, [], device=device)

    print("\nCompiling sampling model...", flush=True)
    t_compile_start = time.perf_counter()
    compiled_sampler = sampling_head.compile(token_type, temp_type)
    print("Compiling greedy model...", flush=True)
    compiled_greedy = greedy_head.compile(token_type)
    t_compile_end = time.perf_counter()
    print(f"Compile time: {t_compile_end - t_compile_start:.2f}s", flush=True)

    # Initialize tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    if args.benchmark:
        run_benchmark(compiled_sampler, tokenizer, device, dtype)
        return

    if args.prompt:
        generate_text(
            compiled_sampler,
            compiled_greedy,
            tokenizer,
            device,
            dtype,
            args.prompt,
            max_new_tokens=20,
            temperature=0.8,
            do_sample=True,
        )
        return

    if args.chat:
        chat_loop(
            compiled_sampler,
            tokenizer,
            device,
            dtype,
            temperature=args.chat_temperature,
        )
        return

    # Interactive prompt loop
    print("\n" + "=" * 50)
    print("Model ready! Enter prompts to generate text.")
    print("Press Ctrl+C or type 'quit' to exit.")
    print("=" * 50 + "\n")

    try:
        while True:
            user_input = input("Enter your prompt: ").strip()
            if user_input.lower() in ["quit", "exit", "q"]:
                print("Exiting...")
                break
            if not user_input:
                print("Please enter a non-empty prompt.\n")
                continue

            print()
            generated_text = generate_text(
                compiled_sampler,
                compiled_greedy,
                tokenizer,
                device,
                dtype,
                user_input,
                max_new_tokens=50,
                temperature=0.8,
                do_sample=True,
            )
            print(f"\nGenerated text:\n{generated_text}\n")
            print("-" * 50 + "\n")
    except KeyboardInterrupt:
        print("\n\nExiting...")
With --benchmark, the program runs a timed benchmark and exits. With --prompt,
the model generates 20 tokens and exits. With --chat, it opens the rich
terminal chat interface. Without flags, it starts an interactive prompt loop.
Running the model
pixi run gpt2
pixi run gpt2 -- --prompt "Once upon a time"
pixi run gpt2 -- --chat
pixi run gpt2 -- --benchmark
Next: Section 12 walks through the streaming chat
implementation—stop sequences, BPE boundary handling, and the rich live
rendering that makes the --chat mode work.