Text generation

Generate text autoregressively using compiled sampling and greedy decoding heads with temperature control.

Text generation is autoregressive: the model predicts one token at a time, appends it to the sequence, and feeds the extended sequence back in for the next prediction.

Start with "The quick brown fox" (a few tokens). The model predicts the next token, giving you one more token. It predicts again with that extended context. This continues until you’ve generated the desired number of tokens.
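The loop structure can be sketched independently of the model. Here a hypothetical next_token function stands in for the compiled GPT-2 head; only the append-and-refeed pattern matters:

```python
def next_token(token_ids: list[int]) -> int:
    # Toy stand-in "model": a deterministic function of the last token.
    # The real implementation calls the compiled GPT-2 head here.
    return (token_ids[-1] * 31 + 7) % 50257


def generate(prompt_ids: list[int], max_new_tokens: int) -> list[int]:
    token_ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        # Predict from the full sequence so far, append, repeat.
        token_ids.append(next_token(token_ids))
    return token_ids


print(generate([464, 2068, 7586, 21831], 3))  # prompt ids + 3 new tokens
```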

Compiled sampling heads

Before the generation loop, the implementation wraps the model in two thin heads—GPT2SamplingHead and GPT2GreedyHead—and compiles each one. The compiled callables are what the generation loop actually calls:

class GPT2SamplingHead(Module):  # type: ignore[type-arg]
    """Compiled forward: last-token log-probs scaled by temperature.

    Returns a float32 [vocab_size] log-probability tensor ready for Gumbel-max
    sampling. The float32 cast happens inside the compiled graph (zero overhead)
    so the caller can use numpy DLPack directly without any eager MAX ops.
    """

    def __init__(self, lm_head: MaxGPT2LMHeadModel) -> None:
        self.lm_head = lm_head  # no super().__init__() is needed for Module

    def forward(self, input_ids: Tensor, temperature: Tensor) -> Tensor:
        logits = self.lm_head(input_ids)  # [1, seq_len, vocab_size]
        last = logits[0, -1, :]  # [vocab_size]
        log_probs = F.logsoftmax(last / temperature)
        # Cast inside compiled graph — free; avoids eager cast op outside.
        return log_probs.cast(DType.float32)  # [vocab_size] float32 log-probs


class GPT2GreedyHead(Module):  # type: ignore[type-arg]
    """Compiled forward: greedy argmax, returns scalar token ID."""

    def __init__(self, lm_head: MaxGPT2LMHeadModel) -> None:
        self.lm_head = lm_head

    def forward(self, input_ids: Tensor) -> Tensor:
        logits = self.lm_head(input_ids)  # [1, seq_len, vocab_size]
        return F.argmax(logits[0, -1, :])  # scalar int64 token id


GPT2SamplingHead.forward takes input_ids and a temperature tensor. It runs the full model, extracts the last position’s logits, divides by temperature, and returns log-probabilities as float32—all inside the compiled graph, with no eager MAX ops outside the graph boundary.

GPT2GreedyHead.forward is simpler: it runs the model and returns F.argmax of the last-position logits as a scalar token ID.

Compiling these heads (done in Section 11’s main) lets MAX optimize the full forward pass—embedding lookups, 12 transformer blocks, layer norm, and the projection—into a single efficient execution plan.

Gumbel-max sampling

For stochastic generation, the implementation uses Gumbel-max sampling rather than calling np.random.choice on a probability distribution. The two approaches are mathematically equivalent, but Gumbel-max is faster: add independent Gumbel noise to log-probabilities, then take the argmax.

One GPU→CPU transfer (via DLPack, zero-copy) and a few NumPy operations on 50,257 floats take about 3 μs—negligible compared to the model forward pass.
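The equivalence is easy to verify empirically in plain NumPy: adding independent Gumbel(0, 1) noise to the log-probabilities and taking the argmax draws from the same categorical distribution as sampling the probabilities directly. A minimal check (with a small toy vocabulary, not the real 50,257-token one):

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=8)
log_probs = logits - np.log(np.sum(np.exp(logits)))  # log-softmax


def gumbel_sample(log_probs: np.ndarray, rng: np.random.Generator) -> int:
    # argmax(log p + Gumbel noise) is a single draw from softmax(log p).
    noise = rng.gumbel(size=log_probs.shape)
    return int(np.argmax(log_probs + noise))


# Compare empirical frequencies against the true probabilities.
n = 100_000
counts = np.bincount([gumbel_sample(log_probs, rng) for _ in range(n)], minlength=8)
max_error = np.max(np.abs(counts / n - np.exp(log_probs)))
print(max_error)  # close to zero
```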

The generation loop

def generate_text(
    sampler: Callable[[Tensor, Tensor], Tensor],
    greedy: Callable[[Tensor], Tensor],
    tokenizer: GPT2Tokenizer,
    device: Device,
    dtype: DType,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
    do_sample: bool = True,
    seed: int = 0,
) -> str:
    """Generate text using compiled MAX models.

    Args:
        sampler: Compiled GPT2SamplingHead — returns log-probs for stochastic
            decoding. Called as ``sampler(input_ids, temperature_tensor)``.
        greedy: Compiled GPT2GreedyHead — returns scalar token ID for greedy
            decoding. Called as ``greedy(input_ids)``.
        tokenizer: HuggingFace GPT-2 tokenizer.
        device: Target device for input tensor construction.
        dtype: Dtype for the temperature scalar.
        prompt: Text prompt to continue.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (ignored when do_sample=False).
        do_sample: If True, use Gumbel-max stochastic sampling; else greedy.
        seed: Initial RNG seed for reproducibility.

    Returns:
        The full generated string (prompt + new tokens), decoded.
    """
    token_ids: list[int] = tokenizer.encode(prompt, max_length=100, truncation=True)
    temperature_tensor = Tensor(temperature, dtype=dtype, device=device)
    rng_state = seed

    print(f"Starting generation from: '{prompt}'")
    print(
        f"Settings: max_new_tokens={max_new_tokens}, temperature={temperature},"
        f" do_sample={do_sample}"
    )
    print("-" * 50)

    for step in range(max_new_tokens):
        input_tensor = _make_token_tensor(token_ids, device)

        if do_sample:
            # Compiled: all deterministic NN ops → [vocab_size] log-probs
            log_probs = sampler(input_tensor, temperature_tensor)
            # Gumbel-max: one GPU→CPU transfer + fast numpy (~3μs for 50K floats)
            rng = np.random.default_rng(rng_state)
            token_id = _gumbel_sample(log_probs, rng)
            rng_state += 1
        else:
            token_id = int(greedy(input_tensor).item())

        token_ids.append(int(token_id))

        if step % 5 == 0 or step == max_new_tokens - 1:
            current_text = decode_tokens(token_ids, tokenizer)
            print(f"Step {step + 1:2d}: {current_text}")

    final_text = decode_tokens(token_ids, tokenizer)
    print("-" * 50)
    print(f"Final generated text: '{final_text}'")
    return final_text


Each step:

  1. Build a [1, seq_len] int64 tensor from the current token list (via _make_token_tensor, which hands the numpy array off through DLPack, zero-copy).
  2. If sampling: call the compiled sampler, apply Gumbel noise in numpy, take argmax.
  3. If greedy: call the compiled greedy head directly.
  4. Append the new token ID to the Python list and repeat.

rng_state is incremented each step so consecutive tokens use different random seeds while still being reproducible from the initial seed.
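The reseeding scheme can be sketched on its own (draws here is a hypothetical stand-in that draws one integer per step the way the loop draws one token):

```python
import numpy as np


def draws(seed: int, steps: int) -> list[int]:
    # Each step gets a fresh Generator seeded from an incrementing state,
    # so the whole sequence is reproducible from the initial seed.
    out, rng_state = [], seed
    for _ in range(steps):
        rng = np.random.default_rng(rng_state)
        out.append(int(rng.integers(0, 50257)))  # stand-in for a sampled token
        rng_state += 1
    return out


assert draws(0, 5) == draws(0, 5)  # same seed, same sequence
```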

Temperature

Temperature divides the logits before the softmax—in the sampling head, F.logsoftmax(last / temperature).

  • Lower temperature (e.g. 0.5): sharpens the distribution—the model strongly favors its top predictions, producing more focused text.
  • Higher temperature (e.g. 1.2): flattens the distribution—lower-ranked tokens get more chances, producing more varied or surprising text.
  • Temperature = 1.0: uses the model’s unmodified distribution.
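The effect is easy to see numerically. A small NumPy sketch (toy logits, not model output) shows the top token's probability rising as temperature falls:

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.5, -1.0])


def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    scaled = logits / temperature
    scaled = scaled - scaled.max()  # subtract max for numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()


for t in (0.5, 1.0, 1.2):
    print(t, softmax_with_temperature(logits, t).round(3))
```

At temperature 0.5 the top token dominates; at 1.2 the tail tokens get noticeably more mass.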

Next: Section 11 loads the pretrained weights and wires everything together into a runnable model.