Text generation
Learn to implement autoregressive text generation with sampling and temperature control.
In this step, you’ll implement the generation loop that produces text one token at a time. The model predicts the next token, appends it to the sequence, feeds that into the model again, and repeats until reaching the desired length.
Start with a prompt like “Hello world” (tokens [15496, 995]). The model
predicts the next token, giving you [15496, 995, 318] (“Hello world is”). It
predicts again, producing [15496, 995, 318, 257] (“Hello world is a”). This
process continues, with each prediction feeding back as input for the next.
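You can check this yourself with the GPT2Tokenizer that this step already imports; the token IDs below are the ones from the example above:
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
print(tokenizer.encode("Hello world"))           # [15496, 995]
print(tokenizer.decode([15496, 995, 318, 257]))  # "Hello world is a"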
You’ll implement two generation strategies: greedy decoding (always pick the highest-scoring token) and sampling (randomly choose according to probabilities). You’ll also add temperature control to adjust how random or focused the generation is: a higher temperature produces more varied but less predictable output.
Understanding the generation loop
The generation loop is simple: run the model, extract the next token prediction, append it to the sequence, repeat. Each iteration requires a full forward pass through all 12 transformer blocks.
The model outputs logits with shape [batch, seq_length, vocab_size]. Since you
only care about predicting the next token, extract the last position:
logits[0, -1, :]. This gives you a vector of 50,257 scores, one per vocabulary
token.
These scores are logits (unnormalized), not probabilities. To convert them to probabilities, apply softmax. Then you can either pick the highest-probability token (greedy) or sample from the distribution (random).
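As a minimal sketch of a single iteration (model and input_ids are assumed to be in scope here, and F refers to the functional module imported later in this step):
logits = model(input_ids)                 # [batch, seq_length, vocab_size]
next_token_logits = logits[0, -1, :]      # scores for the next token (50,257 values)
probs = F.softmax(next_token_logits)      # normalize the scores to probabilities
next_token = F.argmax(next_token_logits)  # greedy: pick the highest-scoring token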
Understanding temperature control
Temperature scaling adjusts how random the generation is using the formula
scaled_logits = logits / temperature.
Setting the temperature to 1.0 uses the model’s original distribution. With temperature 0.7, you sharpen the distribution: high-probability tokens become even more likely, making generation more focused and predictable. With temperature 1.2, you flatten the distribution: lower-probability tokens get more of a chance, making generation more diverse and creative. In practice, temperature values are typically kept between 0 and 2.0.
Temperature is applied before softmax. Dividing by a value less than 1 makes large logits even larger (sharpening), while dividing by a value greater than 1 reduces the differences between logits (flattening).
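To see the effect numerically, here is a small NumPy-only illustration (not MAX code; the logit values are made up for the example):
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, 0.1])
for t in (0.7, 1.0, 1.2):
    print(t, softmax(logits / t).round(3))
# 0.7 concentrates probability on the top token (sharper distribution);
# 1.2 spreads probability across more tokens (flatter distribution).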
Understanding sampling vs greedy
Greedy decoding always picks the highest-probability token using
F.argmax.
It’s fast, deterministic, and simple, but often produces repetitive text because
the model keeps choosing the safest option.
Sampling randomly selects tokens according to their probabilities. Convert
logits to probabilities with F.softmax, transfer to CPU, convert to NumPy with
np.from_dlpack, ensure the array is 1D and has float dtype (required by
np.random.choice), then sample with np.random.choice. You use NumPy because
MAX doesn’t have built-in sampling yet.
Most practical generation uses sampling with temperature control. This balances creativity with coherence, as the model can explore different possibilities while still favoring high-quality continuations.
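Sketched out, the sampling path looks like this (next_token_logits is assumed to be in scope, along with the np, F, and CPU imports described below; these are the same calls used in the solution):
probs = F.softmax(next_token_logits)                     # logits -> probabilities
probs_np: np.ndarray = np.from_dlpack(probs.to(CPU()))   # MAX tensor -> NumPy on CPU
if probs_np.ndim > 1:
    probs_np = probs_np.flatten()                        # np.random.choice needs a 1D array...
probs_np = probs_np.astype(np.float64)                   # ...with a float dtype
next_token_id = np.random.choice(len(probs_np), p=probs_np)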
You’ll use the following MAX operations to complete this task:
Probability operations:
- F.softmax(logits): Converts logits to probabilities
- F.argmax(logits): Selects the highest-probability token (greedy)
Sequence building:
- F.concat([seq, new_token], axis=1): Appends a token to the sequence
- Tensor.constant(value, dtype, device): Creates scalar tensors
NumPy interop:
- probs.to(CPU()): Transfers a tensor to CPU
- np.from_dlpack(probs): Converts a MAX tensor to NumPy for sampling
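As a rough sketch of how these fit together when appending a sampled token (generated_tokens is assumed to be the [1, seq_len] sequence and next_token_id the sampled integer from the previous sketch):
next_token_tensor = Tensor.constant(
    next_token_id, dtype=DType.int64, device=generated_tokens.device
)
next_token_2d = next_token_tensor.reshape([1, -1])       # 0-D token -> shape [1, 1]
generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)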
Implementing text generation
You’ll implement the generate_text function in two parts: prediction logic that produces a single next token, and a generation loop that repeats it to build the full sequence.
First, import the required modules. You’ll need numpy for sampling, CPU from
MAX’s driver, DType for type constants, functional as F for operations like
softmax and argmax, and Tensor for creating tensors.
Inside generate_text, each loop iteration needs the following prediction logic:
- Run the model to get logits: logits = model(generated_tokens)
- Extract the last position (the next-token prediction): next_token_logits = logits[0, -1, :]
- If sampling with a temperature, scale the logits by dividing them by a temperature tensor
- For sampling: convert to probabilities with F.softmax, transfer to CPU, convert to NumPy with an explicit type annotation (probs_np: np.ndarray), flatten if needed, convert to float64 with .astype(np.float64) (required by np.random.choice), sample with np.random.choice, then convert the sampled ID back to a MAX tensor
- For greedy decoding: use F.argmax to select the highest-scoring token
The temperature must be a tensor with the same dtype and device as the logits. Create it with Tensor.constant(temperature, dtype=..., device=...).
Wrap that prediction logic in the generation loop:
- Initialize the sequence with the encoded prompt: generated_tokens = encode_text(prompt, tokenizer, device, max_length=100)
- Loop max_new_tokens times
- On each iteration, run the prediction logic above to get next_token_tensor
- Reshape it to 2D: next_token_2d = next_token_tensor.reshape([1, -1])
- Concatenate it to the sequence: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
- Decode the final sequence with decode_tokens and return the text
The reshape is necessary because concat requires matching dimensions, and the generated token is 0D (a scalar).
Implementation (step_10.py):
# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Step 10: Text Generation
Implement autoregressive text generation with sampling and temperature control.
Tasks:
1. Import required modules (numpy, F, Tensor, DType, CPU)
2. Implement the generate_text function with temperature scaling
3. Add sampling logic with temperature control
4. Concatenate new tokens to generate sequences
Run: pixi run s10
"""
# TODO: Import required modules
# Hint: You'll need numpy as np
# Hint: You'll need CPU from max.driver
# Hint: You'll need DType from max.dtype
# Hint: You'll need functional as F from max.nn
# Hint: You'll need Tensor from max.tensor
from max.driver import Device
from max.nn import Module
from transformers import GPT2Tokenizer
def generate_text(
model: Module,
tokenizer: GPT2Tokenizer,
device: Device,
prompt: str,
max_new_tokens: int = 50,
temperature: float = 0.8,
do_sample: bool = True,
) -> str:
"""Generate text using the Max model.
Args:
model: Compiled MAX model
tokenizer: HuggingFace tokenizer
device: Device to run on
prompt: Starting text
max_new_tokens: Number of tokens to generate
temperature: Sampling temperature (higher = more random)
do_sample: Whether to sample or use greedy decoding
Returns:
Generated text string
"""
# TODO: Tokenize the prompt text
# Hint: Use encode_text(prompt, tokenizer, device, max_length=100)
generated_tokens = None
print(f"Starting generation from: '{prompt}'")
print(
f"Settings: max_new_tokens={max_new_tokens}, temperature={temperature}, do_sample={do_sample}"
)
print("-" * 50)
# TODO: Implement generation loop for max_new_tokens steps
# Hint: for step in range(max_new_tokens):
pass
# TODO: Get model predictions (logits) for current sequence
# Hint: logits = model(generated_tokens)
# TODO: Extract logits for next token prediction
# Hint: next_token_logits = logits[0, -1, :]
# Note: Shape is [batch, seq_len, vocab_size], we want last position
# TODO: Apply temperature scaling if sampling
# Hint: if do_sample and temperature > 0:
# Create a temperature tensor with Tensor.constant()
# Divide next_token_logits by temperature
# Apply softmax: probs = F.softmax(next_token_logits)
# Convert to numpy with explicit type annotation: probs_np: np.ndarray = np.from_dlpack(probs.to(CPU()))
# Ensure it's 1D: if probs_np.ndim > 1: probs_np = probs_np.flatten()
# Convert to float for np.random.choice: probs_np = probs_np.astype(np.float64)
# Sample: next_token_id = np.random.choice(len(probs_np), p=probs_np)
# Convert back to tensor: next_token_tensor = Tensor.constant(next_token_id, dtype=DType.int64, device=generated_tokens.device)
# Note: np.random.choice requires p to be a 1D float array
# TODO: Use greedy decoding if not sampling
# Hint: else: next_token_tensor = F.argmax(next_token_logits)
# TODO: Reshape next token to 2D for concatenation
# Hint: next_token_2d = next_token_tensor.reshape([1, -1])
# TODO: Concatenate to growing sequence
# Hint: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
# TODO: Print progress every 5 steps
# Hint: if step % 5 == 0 or step == max_new_tokens - 1:
# current_text = decode_tokens(generated_tokens, tokenizer)
# print(f"Step {step + 1:2d}: {current_text}")
# TODO: Decode final generated sequence
# Hint: final_text = decode_tokens(generated_tokens, tokenizer)
final_text = None
print("-" * 50)
print(f"Final generated text: '{final_text}'")
return final_text
Validation
Run pixi run s10 to verify your implementation.
Solution:
# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Solution for Step 10: Text Generation
This module implements autoregressive text generation using the GPT-2 model.
"""
import max.functional as F
import numpy as np
from max.driver import CPU, Device
from max.dtype import DType
from max.nn import Module
from max.tensor import Tensor
from step_09 import decode_tokens, encode_text
from transformers import GPT2Tokenizer
def generate_text(
model: Module,
tokenizer: GPT2Tokenizer,
device: Device,
prompt: str,
max_new_tokens: int = 50,
temperature: float = 0.8,
do_sample: bool = True,
) -> str:
"""Generate text using the Max model."""
generated_tokens = encode_text(prompt, tokenizer, device, max_length=100)
print(f"Starting generation from: '{prompt}'")
print(
f"Settings: max_new_tokens={max_new_tokens}, temperature={temperature}, do_sample={do_sample}"
)
print("-" * 50)
for step in range(max_new_tokens):
logits = model(generated_tokens)
next_token_logits = logits[0, -1, :]
if do_sample and temperature > 0:
# Simple temperature scaling without top-k
temp_tensor = Tensor.constant(
temperature,
dtype=next_token_logits.dtype,
device=next_token_logits.device,
)
next_token_logits = next_token_logits / temp_tensor
probs = F.softmax(next_token_logits)
# Convert to numpy for actual sampling
# Explicitly convert to 1D float array for np.random.choice
probs_np: np.ndarray = np.from_dlpack(probs.to(CPU()))
if probs_np.ndim > 1:
probs_np = probs_np.flatten()
probs_np = probs_np.astype(np.float64)
next_token_id = np.random.choice(len(probs_np), p=probs_np)
next_token_tensor = Tensor.constant(
next_token_id, dtype=DType.int64, device=generated_tokens.device
)
else:
next_token_tensor = F.argmax(next_token_logits)
next_token_2d = next_token_tensor.reshape([1, -1])
generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
if step % 5 == 0 or step == max_new_tokens - 1:
current_text = decode_tokens(generated_tokens, tokenizer)
print(f"Step {step + 1:2d}: {current_text}")
final_text = decode_tokens(generated_tokens, tokenizer)
print("-" * 50)
print(f"Final generated text: '{final_text}'")
return final_text
Next: In Step 11, you’ll load pretrained weights and interact with your complete GPT-2 implementation!