Text generation

Learn to implement autoregressive text generation with sampling and temperature control.

In this final step, you’ll implement the generation loop that produces text one token at a time. The model predicts the next token, appends it to the sequence, feeds that into the model again, and repeats until reaching the desired length.

Start with a prompt like “Hello world” (tokens [15496, 995]). The model predicts the next token, giving you [15496, 995, 318] (“Hello world is”). It predicts again, producing [15496, 995, 318, 257] (“Hello world is a”). This process continues, with each prediction feeding back as input for the next.

You’ll implement two generation strategies: greedy decoding (always pick the highest-scoring token) and sampling (randomly choose according to probabilities). You’ll also add temperature control to adjust how random or focused the generation is: a higher temperature produces more variety, at the cost of less coherent output.

Understanding the generation loop

The generation loop is simple: run the model, extract the next token prediction, append it to the sequence, repeat. Each iteration requires a full forward pass through all 12 transformer blocks.

The model outputs logits with shape [batch, seq_length, vocab_size]. Since you only care about predicting the next token, extract the last position: logits[0, -1, :]. This gives you a vector of 50,257 scores, one per vocabulary token.

These scores are logits (unnormalized), not probabilities. To convert them to probabilities, apply softmax. Then you can either pick the highest-probability token (greedy) or sample from the distribution (random).
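
To make the shapes concrete, here is a small NumPy sketch of the same extraction and conversion (the random logits are a stand-in for real model output; the MAX calls later in this step behave analogously):

import numpy as np

# Pretend model output: [batch=1, seq_length=2, vocab_size=50257]
logits = np.random.randn(1, 2, 50257)

# Keep only the scores at the last position: the next-token prediction
next_token_logits = logits[0, -1, :]        # shape: (50257,)

# Softmax turns unnormalized scores into a probability distribution
probs = np.exp(next_token_logits - next_token_logits.max())
probs /= probs.sum()

greedy_token = int(np.argmax(probs))                        # pick the single best token
sampled_token = int(np.random.choice(len(probs), p=probs))  # or sample from the distribution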

Understanding temperature control

Temperature scaling adjusts how random the generation is using the formula scaled_logits = logits / temperature.

Setting the temperature to 1.0 uses the model’s original distribution. With temperature 0.7, you sharpen the distribution: high-probability tokens become even more likely, making generation more focused and predictable. With temperature 1.2, you flatten the distribution: lower-probability tokens get more chances, making generation more diverse and creative. Temperature must be greater than 0, and values are typically kept in roughly the 0 to 2 range; as it approaches 0, generation becomes effectively greedy.

Temperature is applied before softmax. Dividing by a value less than 1 makes large logits even larger (sharpening), while dividing by a value greater than 1 reduces the differences between logits (flattening).
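
A quick numeric illustration with NumPy (the three logits are made up, purely to show the sharpening and flattening effect):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])

print(softmax(logits / 1.0))  # original distribution: ~[0.66, 0.24, 0.10]
print(softmax(logits / 0.7))  # temperature 0.7: sharper, the top token gains probability
print(softmax(logits / 1.2))  # temperature 1.2: flatter, low-probability tokens gain ground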

Understanding sampling vs greedy

Greedy decoding always picks the highest-probability token using F.argmax. It’s fast, deterministic, and simple, but often produces repetitive text because the model keeps choosing the safest option.

Sampling randomly selects tokens according to their probabilities. Convert logits to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, ensure the array is 1D and has float dtype (required by np.random.choice), then sample with np.random.choice. You use NumPy because MAX doesn’t have built-in sampling yet.

Most practical generation uses sampling with temperature control. This balances creativity with coherence, as the model can explore different possibilities while still favoring high-quality continuations.
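
Here is the same contrast in a few lines of plain NumPy, using a toy five-word vocabulary with invented probabilities; the MAX version later in this step follows the same recipe, with np.from_dlpack bridging the tensor into NumPy:

import numpy as np

# Toy next-token distribution over a 5-word vocabulary (already softmaxed)
vocab = ["the", "cat", "sat", "on", "mat"]
probs = np.array([0.50, 0.20, 0.15, 0.10, 0.05])

# Greedy: always the argmax -- deterministic and repetition-prone
greedy_id = int(np.argmax(probs))                        # always 0 -> "the"

# Sampling: np.random.choice needs a 1-D float array that sums to 1
sampled_id = int(np.random.choice(len(probs), p=probs))  # usually "the", sometimes others

print("greedy:", vocab[greedy_id], "| sampled:", vocab[sampled_id])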

MAX operations

You’ll use the following MAX operations to complete this task:

Probability operations:

  • F.softmax(logits): Converts logits into a probability distribution
  • F.argmax(logits): Selects the index of the highest-scoring token

Sequence building:

  • tensor.reshape([1, -1]): Reshapes the new token to 2D for concatenation
  • F.concat([tensors], axis=1): Appends the new token to the growing sequence

NumPy interop:

  • probs.to(CPU()): Transfers tensor to CPU
  • np.from_dlpack(probs): Converts MAX tensor to NumPy for sampling

Implementing text generation

You’ll implement a single generate_text function with two parts: the per-step prediction logic that produces one new token, and the outer loop that repeats it to build the full sequence.

First, import the required modules. You’ll need numpy for sampling, CPU from MAX’s driver, DType for type constants, functional as F for operations like softmax and argmax, and Tensor for creating tensors.

For the per-step prediction logic:

  1. Run the model on the current sequence to get logits: logits = model(generated_tokens)
  2. Extract the last position (next token prediction): next_token_logits = logits[0, -1, :]
  3. If using temperature, scale the logits by dividing by the temperature tensor
  4. For sampling: convert to probabilities with F.softmax, transfer to CPU, convert to NumPy with explicit type annotation (probs_np: np.ndarray), flatten if needed, convert to float64 with .astype(np.float64) (required by np.random.choice), sample with np.random.choice, then convert back to a MAX tensor
  5. For greedy: use F.argmax to select the highest-scoring token

The temperature must be a tensor with the same dtype and device as the logits. Create it with Tensor.constant(temperature, dtype=..., device=...).

For the generation loop around it:

  1. Tokenize the prompt to initialize the sequence: generated_tokens = encode_text(prompt, tokenizer, device, max_length=100)
  2. Loop max_new_tokens times
  3. Run the prediction logic above on generated_tokens to get next_token_tensor
  4. Reshape it to 2D: next_token_2d = next_token_tensor.reshape([1, -1])
  5. Concatenate it to the sequence: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
  6. Decode and return the complete sequence with decode_tokens

The reshape is necessary because concat requires matching dimensions, and the generated token is 0D (scalar).
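
To see the shape bookkeeping, here is the same reshape-and-concat step done with NumPy arrays, reusing the token IDs from the “Hello world” example above:

import numpy as np

generated = np.array([[15496, 995]])         # shape (1, 2): [batch, seq_len] for "Hello world"
next_token = np.array(318)                   # shape (): the 0-D result of argmax/sampling
next_token_2d = next_token.reshape(1, -1)    # shape (1, 1): now both arrays are 2-D
generated = np.concatenate([generated, next_token_2d], axis=1)
print(generated)                             # [[15496  995  318]] -- shape (1, 3)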

Implementation (step_10.py):

# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Step 10: Text Generation

Implement autoregressive text generation with sampling and temperature control.

Tasks:
1. Import required modules (numpy, F, Tensor, DType, CPU)
2. Implement the generate_text function with temperature scaling
3. Add sampling logic with temperature control
4. Concatenate new tokens to generate sequences

Run: pixi run s10
"""

# TODO: Import required modules
# Hint: You'll need numpy as np
# Hint: You'll need CPU from max.driver
# Hint: You'll need DType from max.dtype
# Hint: You'll need functional as F (import max.functional as F)
# Hint: You'll need Tensor from max.tensor

from max.driver import Device
from max.nn import Module
from transformers import GPT2Tokenizer


def generate_text(
    model: Module,
    tokenizer: GPT2Tokenizer,
    device: Device,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
    do_sample: bool = True,
) -> str:
    """Generate text using the Max model.

    Args:
        model: Compiled MAX model
        tokenizer: HuggingFace tokenizer
        device: Device to run on
        prompt: Starting text
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        do_sample: Whether to sample or use greedy decoding

    Returns:
        Generated text string
    """
    # TODO: Tokenize the prompt text
    # Hint: Use encode_text(prompt, tokenizer, device, max_length=100)
    generated_tokens = None

    print(f"Starting generation from: '{prompt}'")
    print(
        f"Settings: max_new_tokens={max_new_tokens}, temperature={temperature}, do_sample={do_sample}"
    )
    print("-" * 50)

    # TODO: Implement generation loop for max_new_tokens steps
    # Hint: for step in range(max_new_tokens):
    pass

    # TODO: Get model predictions (logits) for current sequence
    # Hint: logits = model(generated_tokens)

    # TODO: Extract logits for next token prediction
    # Hint: next_token_logits = logits[0, -1, :]
    # Note: Shape is [batch, seq_len, vocab_size], we want last position

    # TODO: Apply temperature scaling if sampling
    # Hint: if do_sample and temperature > 0:
    #     Create a temperature tensor with Tensor.constant()
    #     Divide next_token_logits by temperature
    #     Apply softmax: probs = F.softmax(next_token_logits)
    #     Convert to numpy with explicit type annotation: probs_np: np.ndarray = np.from_dlpack(probs.to(CPU()))
    #     Ensure it's 1D: if probs_np.ndim > 1: probs_np = probs_np.flatten()
    #     Convert to float for np.random.choice: probs_np = probs_np.astype(np.float64)
    #     Sample: next_token_id = np.random.choice(len(probs_np), p=probs_np)
    #     Convert back to tensor: next_token_tensor = Tensor.constant(next_token_id, dtype=DType.int64, device=generated_tokens.device)
    # Note: np.random.choice requires p to be a 1D float array

    # TODO: Use greedy decoding if not sampling
    # Hint: else: next_token_tensor = F.argmax(next_token_logits)

    # TODO: Reshape next token to 2D for concatenation
    # Hint: next_token_2d = next_token_tensor.reshape([1, -1])

    # TODO: Concatenate to growing sequence
    # Hint: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)

    # TODO: Print progress every 5 steps
    # Hint: if step % 5 == 0 or step == max_new_tokens - 1:
    #     current_text = decode_tokens(generated_tokens, tokenizer)
    #     print(f"Step {step + 1:2d}: {current_text}")

    # TODO: Decode final generated sequence
    # Hint: final_text = decode_tokens(generated_tokens, tokenizer)
    final_text = None

    print("-" * 50)
    print(f"Final generated text: '{final_text}'")
    return final_text

Validation

Run pixi run s10 to verify your implementation.

Solution:
# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Solution for Step 10: Text Generation

This module implements autoregressive text generation using the GPT-2 model.
"""

import max.functional as F
import numpy as np
from max.driver import CPU, Device
from max.dtype import DType
from max.nn import Module
from max.tensor import Tensor
from step_09 import decode_tokens, encode_text
from transformers import GPT2Tokenizer


def generate_text(
    model: Module,
    tokenizer: GPT2Tokenizer,
    device: Device,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
    do_sample: bool = True,
) -> str:
    """Generate text using the Max model."""
    generated_tokens = encode_text(prompt, tokenizer, device, max_length=100)

    print(f"Starting generation from: '{prompt}'")
    print(
        f"Settings: max_new_tokens={max_new_tokens}, temperature={temperature}, do_sample={do_sample}"
    )
    print("-" * 50)

    for step in range(max_new_tokens):
        logits = model(generated_tokens)
        next_token_logits = logits[0, -1, :]

        if do_sample and temperature > 0:
            # Simple temperature scaling without top-k
            temp_tensor = Tensor.constant(
                temperature,
                dtype=next_token_logits.dtype,
                device=next_token_logits.device,
            )
            next_token_logits = next_token_logits / temp_tensor
            probs = F.softmax(next_token_logits)

            # Convert to numpy for actual sampling
            # Explicitly convert to 1D float array for np.random.choice
            probs_np: np.ndarray = np.from_dlpack(probs.to(CPU()))
            if probs_np.ndim > 1:
                probs_np = probs_np.flatten()
            probs_np = probs_np.astype(np.float64)
            next_token_id = np.random.choice(len(probs_np), p=probs_np)
            next_token_tensor = Tensor.constant(
                next_token_id, dtype=DType.int64, device=generated_tokens.device
            )
        else:
            next_token_tensor = F.argmax(next_token_logits)

        next_token_2d = next_token_tensor.reshape([1, -1])
        generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)

        if step % 5 == 0 or step == max_new_tokens - 1:
            current_text = decode_tokens(generated_tokens, tokenizer)
            print(f"Step {step + 1:2d}: {current_text}")

    final_text = decode_tokens(generated_tokens, tokenizer)
    print("-" * 50)
    print(f"Final generated text: '{final_text}'")
    return final_text

Next: In Step 11, you’ll load pretrained weights and interact with your complete GPT-2 implementation!