Text generation
Learn to implement autoregressive text generation with sampling and temperature control.
In this step, you’ll implement the generation loop that produces text one token at a time. The model predicts the next token, appends it to the sequence, feeds the extended sequence into the model again, and repeats until reaching the desired length.
Start with a prompt like “Hello world” (tokens [15496, 995]). The model predicts the next token, giving you [15496, 995, 318] (“Hello world is”). It predicts again, producing [15496, 995, 318, 257] (“Hello world is a”). This process continues, with each prediction feeding back as input for the next.
You’ll implement two generation strategies: greedy decoding (always pick the highest-scoring token) and sampling (randomly choose according to probabilities). You’ll also add temperature control to adjust how random or focused the generation is: a higher temperature produces more varied output, at the cost of coherence.
Understanding the generation loop
The generation loop is simple: run the model, extract the next token prediction, append it to the sequence, repeat. Each iteration requires a full forward pass through all 12 transformer blocks.
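The loop can be sketched in plain Python. This is only an illustration under stated assumptions: toy_next_token is a hypothetical stand-in for the real forward pass and token selection, and its lookup table simply replays the “Hello world is a” token IDs from the example above.

```python
from typing import Callable, List


def toy_next_token(tokens: List[int]) -> int:
    """Hypothetical stand-in for a model forward pass plus token selection."""
    # Lookup keyed on the last token; 0 is an arbitrary fallback ID.
    table = {995: 318, 318: 257}
    return table.get(tokens[-1], 0)


def generate(
    next_token_fn: Callable[[List[int]], int],
    prompt: List[int],
    max_new_tokens: int,
) -> List[int]:
    """Autoregressive loop: predict, append, feed the longer sequence back in."""
    tokens = list(prompt)  # copy so the caller's prompt is not mutated
    for _ in range(max_new_tokens):
        tokens.append(next_token_fn(tokens))
    return tokens


prompt = [15496, 995]
result = generate(toy_next_token, prompt, max_new_tokens=2)
print(result)  # [15496, 995, 318, 257]
```

Each call to next_token_fn receives the whole sequence so far, which mirrors the full forward pass per iteration described above.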
The model outputs logits with shape [batch, seq_length, vocab_size]. Since you only care about predicting the next token, extract the last position: logits[0, -1, :]. This gives you a vector of 50,257 scores, one per vocabulary token.
These scores are logits (unnormalized), not probabilities. To convert them to probabilities, apply softmax. Then you can either pick the highest-probability token (greedy) or sample from the distribution (random).
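As a rough illustration of the logits-to-probabilities step, here is a plain-Python softmax over a three-token toy vocabulary. The logit values are made up for the example; in the actual implementation F.softmax and F.argmax do this work on MAX tensors.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Convert unnormalized scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]  # made-up scores for a 3-token vocabulary
probs = softmax(logits)

# Greedy choice: the index with the highest probability.
greedy_index = max(range(len(probs)), key=lambda i: probs[i])

print([round(p, 3) for p in probs])
print(greedy_index)
```

Because softmax preserves the ordering of the scores, the greedy choice is the same whether you take the argmax of the logits or of the probabilities.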
Understanding temperature control
Temperature scaling adjusts how random the generation is using the formula scaled_logits = logits / temperature.
Setting the temperature to 1.0 leaves the model’s distribution unchanged. With temperature 0.7, you sharpen the distribution: high-probability tokens become even more likely, making generation more focused and predictable. With temperature 1.2, you flatten the distribution: lower-probability tokens get more chances, making generation more diverse and creative. Temperature must be greater than 0 because the logits are divided by it; the implementation in this step falls back to greedy decoding otherwise.
Temperature is applied before softmax. Dividing by a value less than 1 widens the gaps between logits, so softmax concentrates more probability on the top tokens (sharpening), while dividing by a value greater than 1 narrows the gaps (flattening).
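To see the effect numerically, the following sketch applies the formula scaled_logits = logits / temperature to made-up logits and compares the resulting distributions. It uses plain Python rather than MAX tensors, and softmax_with_temperature is a name chosen for this illustration only.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Scale logits by 1/temperature, then apply softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be greater than 0")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]  # made-up scores
top_prob = {}
for t in (0.7, 1.0, 1.2):
    probs = softmax_with_temperature(logits, t)
    top_prob[t] = probs[0]
    print(t, [round(p, 3) for p in probs])
```

The probability of the top-scoring token is highest at 0.7 and lowest at 1.2, which is the sharpening and flattening described above.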
Understanding sampling vs greedy
Greedy decoding always picks the highest-probability token using F.argmax. It’s fast, deterministic, and simple, but often produces repetitive text because the model keeps choosing the safest option.
Sampling randomly selects tokens according to their probabilities. Convert logits to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, then sample with np.random.choice. You use NumPy because MAX doesn’t have built-in sampling yet.
Most practical generation uses sampling with temperature control. This balances creativity with coherence, as the model can explore different possibilities while still favoring high-quality continuations.
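The difference between the two strategies shows up when you pick a token many times from the same distribution. This sketch uses Python’s standard random module as a stand-in for np.random.choice, with made-up probabilities; pick_greedy and pick_sampled are illustrative names, not part of MAX.

```python
import random
from typing import List


def pick_greedy(probs: List[float]) -> int:
    """Always return the index of the highest-probability token."""
    return max(range(len(probs)), key=lambda i: probs[i])


def pick_sampled(probs: List[float], rng: random.Random) -> int:
    """Draw one index at random, weighted by the probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = [0.6, 0.3, 0.1]  # made-up distribution over 3 tokens
rng = random.Random(0)   # seeded so the run is repeatable

greedy_picks = [pick_greedy(probs) for _ in range(1000)]
sampled_picks = [pick_sampled(probs, rng) for _ in range(1000)]
counts = [sampled_picks.count(i) for i in range(len(probs))]

print(set(greedy_picks))  # greedy only ever picks one token
print(counts)             # sampling visits every token, roughly in proportion
```

Greedy decoding never leaves the top token, which is why it tends toward repetition, while sampling still favors that token but lets the others through.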
You’ll use the following MAX operations to complete this task:
Probability operations:
- F.softmax(logits): Converts logits to probabilities
- F.argmax(logits): Selects highest-probability token (greedy)
Sequence building:
- F.concat([seq, new_token], axis=1): Appends token to sequence
- Tensor.constant(value, dtype, device): Creates scalar tensors
NumPy interop:
- probs.to(CPU()): Transfers tensor to CPU
- np.from_dlpack(probs): Converts MAX tensor to NumPy for sampling
Implementing text generation
Generation has two parts: generate_next_token, which predicts a single token, and generate, which loops to produce full sequences. The step_10.py skeleton in this step combines both parts in a single generate_text function.
First, import the required modules. You’ll need numpy for sampling, CPU from MAX’s driver, DType for type constants, functional as F for operations like softmax and argmax, and Tensor for creating tensors.
In generate_next_token, implement the prediction logic:
- Run the model to get logits: logits = model(input_ids)
- Extract the last position (next token prediction): next_token_logits = logits[0, -1, :]
- If using temperature, scale the logits by dividing by the temperature tensor
- For sampling: convert to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, sample with np.random.choice, then convert back to a MAX tensor
- For greedy: use F.argmax to select the highest-scoring token
The temperature must be a tensor with the same dtype and device as the logits. Create it with Tensor.constant(temperature, dtype=..., device=...).
In generate, implement the generation loop:
- Initialize with the input: generated_tokens = input_ids
- Loop max_new_tokens times
- Generate the next token: next_token = generate_next_token(model, generated_tokens, ...)
- Reshape to 2D: next_token_2d = next_token.reshape([1, -1])
- Concatenate to the sequence: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
- Return the complete sequence
The reshape is necessary because concat requires both inputs to have the same number of dimensions: generated_tokens is 2D with shape [batch, seq_length], while the generated token is 0D (scalar).
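The shape bookkeeping can be mimicked with nested Python lists. This is only an analogy, not MAX code: concat_axis1 is a hypothetical helper written for this sketch, and wrapping the token as [[token]] plays the role of reshape([1, -1]).

```python
from typing import List


def concat_axis1(a: List[List[int]], b: List[List[int]]) -> List[List[int]]:
    """Concatenate two 2D nested lists along axis 1 (the sequence axis)."""
    if not (isinstance(a, list) and isinstance(b, list)):
        raise TypeError("both inputs must be 2D nested lists")
    if len(a) != len(b):
        raise ValueError("batch sizes must match")
    return [row_a + row_b for row_a, row_b in zip(a, b)]


generated_tokens = [[15496, 995]]  # shape [1, 2]: one batch row, two tokens
next_token = 318                   # 0D scalar, cannot be concatenated as-is
next_token_2d = [[next_token]]     # shape [1, 1], like reshape([1, -1])

generated_tokens = concat_axis1(generated_tokens, next_token_2d)
print(generated_tokens)  # [[15496, 995, 318]]
```

Passing the bare scalar to concat_axis1 raises a TypeError here, which is the plain-Python counterpart of the dimension mismatch the reshape avoids.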
Implementation (step_10.py):

"""
Step 10: Text Generation

Implement autoregressive text generation with sampling and temperature control.

Tasks:
1. Import required modules (numpy, F, Tensor, DType, CPU)
2. Implement the generate_text function with temperature scaling
3. Add sampling logic with temperature control
4. Concatenate new tokens to generate sequences

Run: pixi run s10
"""

# TODO: Import required modules
# Hint: You'll need numpy as np
# Hint: You'll need CPU from max.driver
# Hint: You'll need DType from max.dtype
# Hint: You'll need functional as F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor

from step_09 import tokenize_text, decode_tokens


def generate_text(
    model,
    tokenizer,
    device,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
    do_sample: bool = True,
):
    """Generate text using the Max model.

    Args:
        model: Compiled MAX model
        tokenizer: HuggingFace tokenizer
        device: Device to run on
        prompt: Starting text
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        do_sample: Whether to sample or use greedy decoding

    Returns:
        Generated text string
    """
    # TODO: Tokenize the prompt text
    # Hint: Use tokenize_text(prompt, tokenizer, device, max_length=100)
    generated_tokens = None

    print(f"Starting generation from: '{prompt}'")
    print(
        f"Settings: max_new_tokens={max_new_tokens}, temperature={temperature}, do_sample={do_sample}"
    )
    print("-" * 50)

    # TODO: Implement generation loop for max_new_tokens steps
    # Hint: for step in range(max_new_tokens):
    pass

        # TODO: Get model predictions (logits) for current sequence
        # Hint: logits = model(generated_tokens)

        # TODO: Extract logits for next token prediction
        # Hint: next_token_logits = logits[0, -1, :]
        # Note: Shape is [batch, seq_len, vocab_size], we want last position

        # TODO: Apply temperature scaling if sampling
        # Hint: if do_sample and temperature > 0:
        #     Create a temperature tensor with Tensor.constant()
        #     Divide next_token_logits by temperature
        #     Apply softmax: probs = F.softmax(next_token_logits)
        #     Convert to numpy: probs_np = np.from_dlpack(probs.to(CPU()))
        #     Sample: next_token_id = np.random.choice(len(probs_np), p=probs_np)
        #     Convert back to tensor: next_token_tensor = Tensor.constant(next_token_id, dtype=DType.int64, device=device)

        # TODO: Use greedy decoding if not sampling
        # Hint: else: next_token_tensor = F.argmax(next_token_logits)

        # TODO: Reshape next token to 2D for concatenation
        # Hint: next_token_2d = next_token_tensor.reshape([1, -1])

        # TODO: Concatenate to growing sequence
        # Hint: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)

        # TODO: Print progress every 5 steps
        # Hint: if step % 5 == 0 or step == max_new_tokens - 1:
        #     current_text = decode_tokens(generated_tokens, tokenizer)
        #     print(f"Step {step + 1:2d}: {current_text}")

    # TODO: Decode final generated sequence
    # Hint: final_text = decode_tokens(generated_tokens, tokenizer)
    final_text = None

    print("-" * 50)
    print(f"Final generated text: '{final_text}'")
    return final_text
Validation
Run pixi run s10 to verify your implementation.
Solution:

"""
Solution for Step 10: Text Generation

This module implements autoregressive text generation using the GPT-2 model.
"""

import numpy as np
from max.driver import CPU
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor

from solution_09 import tokenize_text, decode_tokens


def generate_text(
    model,
    tokenizer,
    device,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
    do_sample: bool = True,
):
    """Generate text using the Max model."""
    generated_tokens = tokenize_text(prompt, tokenizer, device, max_length=100)

    print(f"Starting generation from: '{prompt}'")
    print(
        f"Settings: max_new_tokens={max_new_tokens}, temperature={temperature}, do_sample={do_sample}"
    )
    print("-" * 50)

    for step in range(max_new_tokens):
        logits = model(generated_tokens)
        next_token_logits = logits[0, -1, :]

        if do_sample and temperature > 0:
            # Simple temperature scaling without top-k
            temp_tensor = Tensor.constant(
                temperature,
                dtype=next_token_logits.dtype,
                device=next_token_logits.device,
            )
            next_token_logits = next_token_logits / temp_tensor
            probs = F.softmax(next_token_logits)

            # Convert to numpy for actual sampling
            probs_np = np.from_dlpack(probs.to(CPU()))
            next_token_id = np.random.choice(len(probs_np), p=probs_np)
            next_token_tensor = Tensor.constant(
                next_token_id, dtype=DType.int64, device=generated_tokens.device
            )
        else:
            next_token_tensor = F.argmax(next_token_logits)

        next_token_2d = next_token_tensor.reshape([1, -1])
        generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)

        if step % 5 == 0 or step == max_new_tokens - 1:
            current_text = decode_tokens(generated_tokens, tokenizer)
            print(f"Step {step + 1:2d}: {current_text}")

    final_text = decode_tokens(generated_tokens, tokenizer)
    print("-" * 50)
    print(f"Final generated text: '{final_text}'")
    return final_text
Next: In Step 11, you’ll load pretrained weights and interact with your complete GPT-2 implementation!