Step 12: Text generation

Learn to implement autoregressive text generation with sampling and temperature control.

Generating text

In this final step, you’ll implement the generation loop that produces text one token at a time. The model predicts the next token, appends it to the sequence, and repeats until reaching the desired length.

Start with a prompt like “Hello world” (tokens [15496, 995]). The model predicts the next token, giving you [15496, 995, 318] (“Hello world is”). It predicts again, producing [15496, 995, 318, 257] (“Hello world is a”). This process continues, with each prediction feeding back as input for the next.
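To make the shape of this loop concrete, here is a minimal plain-Python sketch; fake_model is a made-up stand-in for the real model, hard-coded to always predict the same token:

def fake_model(tokens):
    # Stand-in for the GPT-2 forward pass: always "predicts" token 318 (" is").
    return 318

tokens = [15496, 995]  # "Hello world"
for _ in range(3):
    next_id = fake_model(tokens)  # predict the next token from the current sequence
    tokens.append(next_id)        # the prediction becomes part of the next input

print(tokens)  # [15496, 995, 318, 318, 318]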

You’ll implement two generation strategies: greedy decoding (always pick the highest-scoring token) and sampling (randomly choose according to probabilities). You’ll also add temperature control to adjust how random or focused the generation is.

Understanding the generation loop

The generation loop is simple: run the model, extract the next token prediction, append it to the sequence, repeat. Each iteration requires a full forward pass through all 12 transformer blocks.

The model outputs logits with shape [batch, seq_length, vocab_size]. Since you only care about predicting the next token, extract the last position: logits[0, -1, :]. This gives you a vector of 50,257 scores, one per vocabulary token.

These scores are logits (unnormalized), not probabilities. To convert them to probabilities, apply softmax. Then you can either pick the highest-probability token (greedy) or sample from the distribution (random).
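As a rough NumPy analogue of that extraction (shapes only; the real implementation uses MAX tensors and F.softmax):

import numpy as np

# NumPy stand-in for the MAX logits tensor, just to show the indexing and shapes.
batch, seq_length, vocab_size = 1, 2, 50257
logits = np.random.randn(batch, seq_length, vocab_size)

next_token_logits = logits[0, -1, :]   # shape: (50257,), scores for the next token
probs = np.exp(next_token_logits - next_token_logits.max())
probs /= probs.sum()                   # softmax: non-negative, sums to 1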

Understanding temperature control

Temperature scaling adjusts how random the generation is using the formula scaled_logits = logits / temperature.

With temperature 1.0, you use the original distribution. With temperature 0.7, you sharpen the distribution: high-probability tokens become even more likely, making generation more focused and predictable. With temperature 1.2, you flatten the distribution: lower-probability tokens get more chances, making generation more diverse and creative.

Temperature is applied before softmax. Dividing by a value less than 1 makes large logits even larger (sharpening), while dividing by a value greater than 1 reduces the differences between logits (flattening).
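A small numeric example (plain NumPy, with made-up logits) shows the effect:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])

print(softmax(logits / 1.0))  # ~[0.66, 0.24, 0.10]  original distribution
print(softmax(logits / 0.7))  # ~[0.77, 0.18, 0.05]  sharper: top token dominates
print(softmax(logits / 1.2))  # ~[0.61, 0.27, 0.13]  flatter: more spread out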

Understanding sampling vs greedy

Greedy decoding always picks the highest-probability token using F.argmax. It’s fast, deterministic, and simple, but often produces repetitive text because the model keeps choosing the safest option.

Sampling randomly selects tokens according to their probabilities. Convert logits to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, then sample with np.random.choice. You use NumPy because MAX doesn’t have built-in sampling yet.

Most practical generation uses sampling with temperature control. This balances creativity with coherence, as the model can explore different possibilities while still favoring high-quality continuations.
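The difference is easy to see on a toy distribution (plain NumPy, with made-up probabilities):

import numpy as np

rng = np.random.default_rng(0)
probs = np.array([0.7, 0.2, 0.1])  # toy distribution over three "tokens"

greedy_pick = int(np.argmax(probs))                       # always token 0
sampled_picks = rng.choice(len(probs), size=10, p=probs)  # mostly 0, sometimes 1 or 2
print(greedy_pick, sampled_picks)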

MAX operations

You’ll use the following MAX operations to complete this task:

Probability operations:

  • F.softmax(logits): Converts logits to a probability distribution
  • F.argmax(logits): Selects the index of the highest-scoring token

Sequence building:

  • tensor.reshape([1, -1]): Reshapes the generated token to 2D for concatenation
  • F.concat([t1, t2], axis=1): Appends the new token along the sequence axis

NumPy interop:

  • probs.to(CPU()): Transfers tensor to CPU
  • np.from_dlpack(probs): Converts MAX tensor to NumPy for sampling

Implementing text generation

You’ll create two functions: generate_next_token, which predicts a single token, and generate_tokens, which loops to produce a full sequence.

First, import the required modules. You’ll need numpy for sampling, CPU from MAX’s driver, DType for type constants, functional as F for operations like softmax and argmax, and Tensor for creating tensors.

In generate_next_token, implement the prediction logic:

  1. Run the model to get logits: logits = model(input_ids)
  2. Extract the last position (next token prediction): next_token_logits = logits[0, -1, :]
  3. If using temperature, scale the logits by dividing by the temperature tensor
  4. For sampling: convert to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, sample with np.random.choice, then convert back to a MAX tensor
  5. For greedy: use F.argmax to select the highest-scoring token

The temperature must be a tensor with the same dtype and device as the logits. Create it with Tensor.constant(temperature, dtype=..., device=...).

In generate_tokens, implement the generation loop:

  1. Initialize with the input: generated_tokens = input_ids
  2. Loop max_new_tokens times
  3. Generate the next token: next_token = generate_next_token(model, generated_tokens, ...)
  4. Reshape to 2D: next_token_2d = next_token.reshape([1, -1])
  5. Concatenate to the sequence: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
  6. Return the complete sequence

The reshape is necessary because concat requires tensors with the same number of dimensions, and the generated token is 0D (a scalar).
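The same constraint shows up with NumPy arrays, which makes the reshape easy to see (an illustration of the shapes, not MAX code):

import numpy as np

generated = np.array([[15496, 995, 318]])  # shape (1, 3): [batch, seq_length]
next_token = np.array(257)                 # 0D scalar from the prediction step

# Concatenation needs both operands to have the same number of dimensions,
# so the scalar is reshaped to (1, 1) before appending along the sequence axis.
next_token_2d = next_token.reshape(1, -1)  # shape (1, 1)
generated = np.concatenate([generated, next_token_2d], axis=1)
print(generated.shape)                     # (1, 4)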

Implementation (step_12.py):

"""
Step 12: Text Generation

Implement autoregressive text generation with sampling and temperature control.

Tasks:
1. Import required modules (numpy, F, Tensor, etc.)
2. Implement generate_next_token: get logits, apply temperature, sample/argmax
3. Implement generate_tokens: loop to generate multiple tokens

Run: pixi run s12
"""

# TODO: Import required modules
# Hint: You'll need numpy as np
# Hint: You'll need CPU from max.driver
# Hint: You'll need DType from max.dtype
# Hint: You'll need functional as F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor


def generate_next_token(model, input_ids, temperature=1.0, do_sample=True):
    """Generate the next token given input context.

    Args:
        model: GPT-2 model with LM head
        input_ids: Current sequence, shape [batch, seq_length]
        temperature: Sampling temperature (higher = more random)
        do_sample: If True, sample from distribution; if False, use greedy (argmax)

    Returns:
        Next token ID as a Tensor
    """
    # TODO: Get logits from model
    # Hint: logits = model(input_ids)
    pass

    # TODO: Get logits for last position
    # Hint: next_token_logits = logits[0, -1, :]
    pass

    # TODO: If sampling with temperature
    if do_sample and temperature > 0:
        # TODO: Apply temperature scaling
        # Hint: temp_tensor = Tensor.constant(temperature, dtype=next_token_logits.dtype, device=next_token_logits.device)
        # Hint: next_token_logits = next_token_logits / temp_tensor
        pass

        # TODO: Convert to probabilities
        # Hint: probs = F.softmax(next_token_logits)
        pass

        # TODO: Sample from distribution
        # Hint: probs_np = np.from_dlpack(probs.to(CPU()))
        # Hint: next_token_id = np.random.choice(len(probs_np), p=probs_np)
        # Hint: next_token_tensor = Tensor.constant(next_token_id, dtype=DType.int64, device=input_ids.device)
        pass
    else:
        # TODO: Greedy decoding (select most likely token)
        # Hint: next_token_tensor = F.argmax(next_token_logits)
        pass

    # TODO: Return the next token
    return None


def generate_tokens(
    model, input_ids, max_new_tokens=10, temperature=1.0, do_sample=True
):
    """Generate multiple tokens autoregressively.

    Args:
        model: GPT-2 model with LM head
        input_ids: Initial sequence, shape [batch, seq_length]
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature
        do_sample: Whether to sample or use greedy decoding

    Returns:
        Generated sequence including input, shape [batch, seq_length + max_new_tokens]
    """
    # TODO: Initialize generated tokens with input
    # Hint: generated_tokens = input_ids
    pass

    # TODO: Generation loop
    # Hint: for _ in range(max_new_tokens):
    pass

    # TODO: Generate next token
    # Hint: next_token = generate_next_token(model, generated_tokens, temperature=temperature, do_sample=do_sample)
    pass

    # TODO: Reshape to [1, 1] for concatenation
    # Hint: next_token_2d = next_token.reshape([1, -1])
    pass

    # TODO: Append to sequence
    # Hint: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
    pass

    # TODO: Return generated sequence
    return None

Validation

Run pixi run s12 to verify your implementation.

Solution:
"""
Solution for Step 12: Text Generation

This module implements autoregressive text generation using the GPT-2 model.
"""

import numpy as np

from max.driver import CPU
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor


def generate_next_token(model, input_ids, temperature=1.0, do_sample=True):
    """Generate the next token given input context.

    Args:
        model: GPT-2 model with LM head
        input_ids: Current sequence, shape [batch, seq_length]
        temperature: Sampling temperature (higher = more random)
        do_sample: If True, sample from distribution; if False, use greedy (argmax)

    Returns:
        Next token ID as a Tensor
    """
    # Get logits from model
    logits = model(input_ids)

    # Get logits for last position (next token prediction)
    next_token_logits = logits[0, -1, :]  # Shape: [vocab_size]

    if do_sample and temperature > 0:
        # Apply temperature scaling
        temp_tensor = Tensor.constant(
            temperature, dtype=next_token_logits.dtype, device=next_token_logits.device
        )
        next_token_logits = next_token_logits / temp_tensor

        # Convert to probabilities
        probs = F.softmax(next_token_logits)

        # Sample from distribution
        probs_np = np.from_dlpack(probs.to(CPU()))
        next_token_id = np.random.choice(len(probs_np), p=probs_np)
        next_token_tensor = Tensor.constant(
            next_token_id, dtype=DType.int64, device=input_ids.device
        )
    else:
        # Greedy decoding: select most likely token
        next_token_tensor = F.argmax(next_token_logits)

    return next_token_tensor


def generate_tokens(
    model, input_ids, max_new_tokens=10, temperature=1.0, do_sample=True
):
    """Generate multiple tokens autoregressively.

    Args:
        model: GPT-2 model with LM head
        input_ids: Initial sequence, shape [batch, seq_length]
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature
        do_sample: Whether to sample or use greedy decoding

    Returns:
        Generated sequence including input, shape [batch, seq_length + max_new_tokens]
    """
    generated_tokens = input_ids

    for _ in range(max_new_tokens):
        # Generate next token
        next_token = generate_next_token(
            model, generated_tokens, temperature=temperature, do_sample=do_sample
        )

        # Reshape to [1, 1] for concatenation
        next_token_2d = next_token.reshape([1, -1])

        # Append to sequence
        generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)

    return generated_tokens

What you’ve built

You’ve completed all 12 steps and built a complete GPT-2 model from scratch using MAX. You now have a working implementation of:

Core components:

  • Model configuration and architecture definition
  • Causal masking for autoregressive generation
  • Layer normalization for training stability
  • Feed-forward networks with GELU activation
  • Token and position embeddings
  • Multi-head self-attention
  • Residual connections and transformer blocks
  • Language model head for next-token prediction
  • Text generation with temperature and sampling

Your model loads OpenAI’s pretrained GPT-2 weights and generates text. You understand how every component works, from the low-level tensor operations to the high-level architecture decisions.

What’s next?

You now understand the architectural foundation that powers modern language models. LLaMA, Mistral, and their successors build on these same components with incremental refinements. You have everything you need to implement those refinements yourself.

Consider extending your implementation with:

  • Grouped-query attention (GQA): Reduce memory consumption by sharing key-value pairs across multiple query heads, as used in LLaMA 2.
  • Rotary position embeddings (RoPE): Replace learned position embeddings with rotation-based encoding, improving length extrapolation in models like LLaMA and GPT-NeoX.
  • SwiGLU activation: Swap GELU for the gated linear unit variant used in LLaMA and PaLM.
  • Mixture of experts (MoE): Add sparse expert routing to scale model capacity efficiently, as in Mixtral (and, reportedly, GPT-4).

Each refinement builds directly on what you’ve implemented. The attention mechanism you wrote becomes grouped-query attention with a simple modification to how you reshape key-value tensors. Your position embeddings can be replaced with RoPE by changing how you encode positional information. The feed-forward network you built becomes SwiGLU by adding a gating mechanism.
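As one example, here is a minimal NumPy sketch of a SwiGLU feed-forward block, just to show how the gating path differs from the GELU MLP you built; the dimensions and weight names are made up for illustration:

import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff = 8, 32
w_gate = rng.standard_normal((d_model, d_ff))  # gate projection
w_up = rng.standard_normal((d_model, d_ff))    # up projection
w_down = rng.standard_normal((d_ff, d_model))  # down projection

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x):
    # The SiLU-activated gate multiplies the up projection elementwise,
    # then the result is projected back down to d_model.
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

x = rng.standard_normal((1, 4, d_model))  # [batch, seq, d_model]
print(swiglu_ffn(x).shape)                # (1, 4, 8)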

Pick an architecture that interests you and start building. You’ll find the patterns are familiar because the fundamentals haven’t changed.