Step 12: Text generation
Learn to implement autoregressive text generation with sampling and temperature control.
Generating text
In this final step, you’ll implement the generation loop that produces text one token at a time. The model predicts the next token, appends it to the sequence, and repeats until reaching the desired length.
Start with a prompt like “Hello world” (tokens [15496, 995]). The model predicts the next token, giving you [15496, 995, 318] (“Hello world is”). It predicts again, producing [15496, 995, 318, 257] (“Hello world is a”). This process continues, with each prediction feeding back as input for the next.
You’ll implement two generation strategies: greedy decoding (always pick the highest-scoring token) and sampling (randomly choose according to probabilities). You’ll also add temperature control to adjust how random or focused the generation is.
Understanding the generation loop
The generation loop is simple: run the model, extract the next token prediction, append it to the sequence, repeat. Each iteration requires a full forward pass through all 12 transformer blocks.
The model outputs logits with shape [batch, seq_length, vocab_size]. Since you only care about predicting the next token, extract the last position: logits[0, -1, :]. This gives you a vector of 50,257 scores, one per vocabulary token.
These scores are logits (unnormalized), not probabilities. To convert them to probabilities, apply softmax. Then you can either pick the highest-probability token (greedy) or sample from the distribution (random).
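Here's a compressed, greedy-only sketch of that loop using the same operations you'll apply in this step. It is a preview, not the final code: prompt_tokens and max_new_tokens are placeholders, and the full implementation below adds temperature and sampling.

```python
# Greedy-only sketch of the autoregressive loop; the full version below
# adds temperature scaling and sampling. `prompt_tokens` is a placeholder
# for a [1, seq_length] tensor of token IDs.
tokens = prompt_tokens
for _ in range(max_new_tokens):
    logits = model(tokens)                    # [1, seq_length, 50257]
    next_token_logits = logits[0, -1, :]      # scores for the next token only
    next_token = F.argmax(next_token_logits)  # greedy choice
    tokens = F.concat([tokens, next_token.reshape([1, -1])], axis=1)
```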
Understanding temperature control
Temperature scaling adjusts how random the generation is using the formula scaled_logits = logits / temperature.
With temperature 1.0, you use the original distribution. With temperature 0.7, you sharpen the distribution: high-probability tokens become even more likely, making generation more focused and predictable. With temperature 1.2, you flatten the distribution: lower-probability tokens get more chances, making generation more diverse and creative.
Temperature is applied before softmax. Dividing by a value less than 1 makes large logits even larger (sharpening), while dividing by a value greater than 1 reduces the differences between logits (flattening).
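A quick NumPy illustration (toy numbers, not part of the exercise) shows the effect on a three-token distribution:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5])  # toy scores for three tokens
print(softmax(logits / 0.7))  # ~[0.74, 0.18, 0.09] -- sharpened, top token dominates
print(softmax(logits / 1.0))  # ~[0.63, 0.23, 0.14] -- original distribution
print(softmax(logits / 1.2))  # ~[0.58, 0.25, 0.17] -- flattened, tail gets more chance
```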
Understanding sampling vs greedy
Greedy decoding always picks the highest-probability token using F.argmax. It’s fast, deterministic, and simple, but often produces repetitive text because the model keeps choosing the safest option.
Sampling randomly selects tokens according to their probabilities. Convert logits to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, then sample with np.random.choice. You use NumPy because MAX doesn’t have built-in sampling yet.
Most practical generation uses sampling with temperature control. This balances creativity with coherence, as the model can explore different possibilities while still favoring high-quality continuations.
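The difference is easy to see with a toy NumPy example (illustrative only, using a three-token vocabulary):

```python
import numpy as np

probs = np.array([0.6, 0.3, 0.1])  # next-token probabilities over a 3-token vocabulary

greedy = int(np.argmax(probs))  # always token 0, every time
samples = [np.random.choice(len(probs), p=probs) for _ in range(10)]
# samples is mostly token 0, but tokens 1 and 2 appear roughly 30% and 10% of the time
```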
You’ll use the following MAX operations to complete this task:
Probability operations:
- F.softmax(logits): Converts logits to probabilities
- F.argmax(logits): Selects the highest-probability token (greedy)
Sequence building:
- F.concat([seq, new_token], axis=1): Appends a token to the sequence
- Tensor.constant(value, dtype, device): Creates scalar tensors
NumPy interop:
- probs.to(CPU()): Transfers a tensor to the CPU
- np.from_dlpack(probs): Converts a MAX tensor to NumPy for sampling
Implementing text generation
You’ll create two functions: generate_next_token, which predicts a single token, and generate_tokens, which loops to produce full sequences.
First, import the required modules. You’ll need numpy for sampling, CPU from MAX’s driver, DType for type constants, functional as F for operations like softmax and argmax, and Tensor for creating tensors.
In generate_next_token, implement the prediction logic:
- Run the model to get logits: logits = model(input_ids)
- Extract the last position (the next-token prediction): next_token_logits = logits[0, -1, :]
- If using temperature, scale the logits by dividing by the temperature tensor
- For sampling: convert to probabilities with F.softmax, transfer to CPU, convert to NumPy with np.from_dlpack, sample with np.random.choice, then convert back to a MAX tensor
- For greedy: use F.argmax to select the highest-scoring token
The temperature must be a tensor with the same dtype and device as the logits. Create it with Tensor.constant(temperature, dtype=..., device=...).
In generate_tokens, implement the generation loop:
- Initialize with the input: generated_tokens = input_ids
- Loop max_new_tokens times
- Generate the next token: next_token = generate_next_token(model, generated_tokens, ...)
- Reshape to 2D: next_token_2d = next_token.reshape([1, -1])
- Concatenate to the sequence: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
- Return the complete sequence
The reshape is necessary because concat requires inputs with matching ranks: the generated token comes back as a 0-D scalar, so it must be reshaped to [1, 1] before it can be appended along the sequence axis.
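For example, if the sequence so far has shape [1, 5], one iteration of the append step looks like this (shapes shown as comments, values illustrative):

```python
# generated_tokens: [1, 5], next_token: 0-D scalar
next_token_2d = next_token.reshape([1, -1])                              # [1, 1]
generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)   # [1, 6]
```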
Implementation (step_12.py):
"""
Step 12: Text Generation
Implement autoregressive text generation with sampling and temperature control.
Tasks:
1. Import required modules (numpy, F, Tensor, etc.)
2. Implement generate_next_token: get logits, apply temperature, sample/argmax
3. Implement generate_tokens: loop to generate multiple tokens
Run: pixi run s12
"""
# TODO: Import required modules
# Hint: You'll need numpy as np
# Hint: You'll need CPU from max.driver
# Hint: You'll need DType from max.dtype
# Hint: You'll need functional as F from max.experimental
# Hint: You'll need Tensor from max.experimental.tensor
def generate_next_token(model, input_ids, temperature=1.0, do_sample=True):
"""Generate the next token given input context.
Args:
model: GPT-2 model with LM head
input_ids: Current sequence, shape [batch, seq_length]
temperature: Sampling temperature (higher = more random)
do_sample: If True, sample from distribution; if False, use greedy (argmax)
Returns:
Next token ID as a Tensor
"""
# TODO: Get logits from model
# Hint: logits = model(input_ids)
pass
# TODO: Get logits for last position
# Hint: next_token_logits = logits[0, -1, :]
pass
# TODO: If sampling with temperature
if do_sample and temperature > 0:
# TODO: Apply temperature scaling
# Hint: temp_tensor = Tensor.constant(temperature, dtype=next_token_logits.dtype, device=next_token_logits.device)
# Hint: next_token_logits = next_token_logits / temp_tensor
pass
# TODO: Convert to probabilities
# Hint: probs = F.softmax(next_token_logits)
pass
# TODO: Sample from distribution
# Hint: probs_np = np.from_dlpack(probs.to(CPU()))
# Hint: next_token_id = np.random.choice(len(probs_np), p=probs_np)
# Hint: next_token_tensor = Tensor.constant(next_token_id, dtype=DType.int64, device=input_ids.device)
pass
else:
# TODO: Greedy decoding (select most likely token)
# Hint: next_token_tensor = F.argmax(next_token_logits)
pass
# TODO: Return the next token
return None
def generate_tokens(
model, input_ids, max_new_tokens=10, temperature=1.0, do_sample=True
):
"""Generate multiple tokens autoregressively.
Args:
model: GPT-2 model with LM head
input_ids: Initial sequence, shape [batch, seq_length]
max_new_tokens: Number of tokens to generate
temperature: Sampling temperature
do_sample: Whether to sample or use greedy decoding
Returns:
Generated sequence including input, shape [batch, seq_length + max_new_tokens]
"""
# TODO: Initialize generated tokens with input
# Hint: generated_tokens = input_ids
pass
# TODO: Generation loop
# Hint: for _ in range(max_new_tokens):
pass
# TODO: Generate next token
# Hint: next_token = generate_next_token(model, generated_tokens, temperature=temperature, do_sample=do_sample)
pass
# TODO: Reshape to [1, 1] for concatenation
# Hint: next_token_2d = next_token.reshape([1, -1])
pass
# TODO: Append to sequence
# Hint: generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
pass
# TODO: Return generated sequence
return None
Validation
Run pixi run s12 to verify your implementation.
Solution (step_12.py):
"""
Solution for Step 12: Text Generation
This module implements autoregressive text generation using the GPT-2 model.
"""
import numpy as np
from max.driver import CPU
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
def generate_next_token(model, input_ids, temperature=1.0, do_sample=True):
"""Generate the next token given input context.
Args:
model: GPT-2 model with LM head
input_ids: Current sequence, shape [batch, seq_length]
temperature: Sampling temperature (higher = more random)
do_sample: If True, sample from distribution; if False, use greedy (argmax)
Returns:
Next token ID as a Tensor
"""
# Get logits from model
logits = model(input_ids)
# Get logits for last position (next token prediction)
next_token_logits = logits[0, -1, :] # Shape: [vocab_size]
if do_sample and temperature > 0:
# Apply temperature scaling
temp_tensor = Tensor.constant(
temperature, dtype=next_token_logits.dtype, device=next_token_logits.device
)
next_token_logits = next_token_logits / temp_tensor
# Convert to probabilities
probs = F.softmax(next_token_logits)
# Sample from distribution
probs_np = np.from_dlpack(probs.to(CPU()))
next_token_id = np.random.choice(len(probs_np), p=probs_np)
next_token_tensor = Tensor.constant(
next_token_id, dtype=DType.int64, device=input_ids.device
)
else:
# Greedy decoding: select most likely token
next_token_tensor = F.argmax(next_token_logits)
return next_token_tensor
def generate_tokens(
model, input_ids, max_new_tokens=10, temperature=1.0, do_sample=True
):
"""Generate multiple tokens autoregressively.
Args:
model: GPT-2 model with LM head
input_ids: Initial sequence, shape [batch, seq_length]
max_new_tokens: Number of tokens to generate
temperature: Sampling temperature
do_sample: Whether to sample or use greedy decoding
Returns:
Generated sequence including input, shape [batch, seq_length + max_new_tokens]
"""
generated_tokens = input_ids
for _ in range(max_new_tokens):
# Generate next token
next_token = generate_next_token(
model, generated_tokens, temperature=temperature, do_sample=do_sample
)
# Reshape to [1, 1] for concatenation
next_token_2d = next_token.reshape([1, -1])
# Append to sequence
generated_tokens = F.concat([generated_tokens, next_token_2d], axis=1)
return generated_tokens
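To try the whole pipeline, you can drive generate_tokens from a text prompt. The sketch below is hypothetical: it assumes Hugging Face's GPT-2 tokenizer, the model you assembled in earlier steps, and the np and CPU imports from the solution module; the exact way you build the input tensor depends on how the previous steps construct MAX tensors.

```python
# Hypothetical driver script (adjust to match how earlier steps build tensors).
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prompt_ids = tokenizer.encode("Hello world")  # e.g. [15496, 995]
# input_ids = ... a [1, len(prompt_ids)] MAX Tensor of int64 token IDs ...

output = generate_tokens(model, input_ids, max_new_tokens=10, temperature=0.7)
output_ids = np.from_dlpack(output.to(CPU()))[0].tolist()
print(tokenizer.decode(output_ids))  # prompt plus 10 new tokens
```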
What you’ve built
You’ve completed all 12 steps and built a complete GPT-2 model from scratch using MAX. You now have a working implementation of:
Core components:
- Model configuration and architecture definition
- Causal masking for autoregressive generation
- Layer normalization for training stability
- Feed-forward networks with GELU activation
- Token and position embeddings
- Multi-head self-attention
- Residual connections and transformer blocks
- Language model head for next-token prediction
- Text generation with temperature and sampling
Your model loads OpenAI’s pretrained GPT-2 weights and generates text. You understand how every component works, from the low-level tensor operations to the high-level architecture decisions.
What’s next?
You now understand the architectural foundation that powers modern language models. LLaMA, Mistral, and other recent models build on these same components with incremental refinements. You have everything you need to implement those refinements yourself.
Consider extending your implementation with:
- Grouped-query attention (GQA): Reduce memory consumption by sharing key-value pairs across multiple query heads, as used in LLaMA 2.
- Rotary position embeddings (RoPE): Replace learned position embeddings with rotation-based encoding, improving length extrapolation in models like LLaMA and GPT-NeoX.
- SwiGLU activation: Swap GELU for the gated linear unit variant used in LLaMA and PaLM.
- Mixture of experts (MoE): Add sparse expert routing to scale model capacity efficiently, as in Mixtral and GPT-4.
Each refinement builds directly on what you’ve implemented. The attention mechanism you wrote becomes grouped-query attention with a simple modification to how you reshape key-value tensors. Your position embeddings can be replaced with RoPE by changing how you encode positional information. The feed-forward network you built becomes SwiGLU by adding a gating mechanism.
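As a concrete illustration of that last point, here is a SwiGLU feed-forward block sketched in plain NumPy (an illustration, not MAX code or part of the tutorial): the single up-projection of your GELU MLP is replaced by a gated pair of projections.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, w_gate, w_up, w_down):
    # x: [seq, d_model]; w_gate, w_up: [d_model, d_ff]; w_down: [d_ff, d_model]
    gate = silu(x @ w_gate)      # gating path
    up = x @ w_up                # linear path
    return (gate * up) @ w_down  # element-wise gate, then project back down
```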
Pick an architecture that interests you and start building. You’ll find the patterns are familiar because the fundamentals haven’t changed.