Gemma 4 MTP Drafters: Speculative Decoding for Local LLMs

TL;DR: Google shipped official multi-token prediction (MTP) drafter models for Gemma 4. The drafter guesses 4 tokens ahead; the main model verifies in a single pass. On an H100, Gemma 4 31B jumped from 8.8 t/s to 27.4 t/s — a 3.1x speedup with identical output quality.

How MTP Works

Standard LLM inference generates one token at a time. The full model wakes up, does expensive matrix math, produces a single token, then repeats. For Gemma 4 31B, that’s roughly 8.8 tokens/second on an H100.

MTP (multi-token prediction) adds a small companion model — the drafter — that runs alongside the main model:

  1. The drafter predicts the next 4 tokens in a chain (each guess feeds into the next).
  2. The big 31B model verifies all 4 guesses in a single forward pass.
  3. Correct guesses count as multiple tokens for the cost of one. Wrong guesses get discarded.
  4. The main model always has the final say, so output quality is identical.

This is speculative decoding — the drafter speculates, the target model verifies.
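As a rough illustration, here is a minimal sketch of that draft-and-verify loop. The draft_next_token and target_tokens_for callables are hypothetical stand-ins, not the actual transformers internals, and it shows greedy acceptance only:

# Conceptual sketch of one speculative-decoding step (greedy acceptance).
# draft_next_token() and target_tokens_for() are hypothetical placeholders,
# not real transformers APIs; they just mirror the control flow above.
def speculative_step(context, draft_next_token, target_tokens_for, k=4):
    # 1. Drafter proposes k tokens in a chain: each guess conditions the next.
    draft = []
    for _ in range(k):
        draft.append(draft_next_token(context + draft))

    # 2. Target model scores all k drafted positions in one forward pass.
    target_preds = target_tokens_for(context, draft)

    # 3. Keep the longest prefix where the target agrees with the drafter.
    accepted = []
    for guess, target_tok in zip(draft, target_preds):
        if guess == target_tok:
            accepted.append(guess)
        else:
            # 4. On the first mismatch the target's own token is used instead,
            #    so the final output matches running the target alone.
            accepted.append(target_tok)
            break
    return context + accepted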

Benchmarks (H100 80GB)

Tested with a demanding prompt (“design a complete hospital management system from scratch”) generating ~2048 tokens:

| Configuration              | Tokens/sec | Time (2048 tokens) | VRAM   |
| -------------------------- | ---------- | ------------------ | ------ |
| Gemma 4 31B (no drafter)   | 8.8 t/s    | ~233 s             | ~62 GB |
| Gemma 4 31B + MTP drafter  | 27.4 t/s   | ~74 s              | ~63 GB |

That’s a 3.1x speedup with negligible VRAM overhead (the drafter is only ~939 MB).
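As a quick sanity check, the per-row numbers are self-consistent (values taken straight from the table above, differences down to rounding):

# Quick arithmetic check of the table above.
tokens = 2048
baseline_tps, mtp_tps = 8.8, 27.4
print(f"baseline time:  {tokens / baseline_tps:.0f} s")   # ~233 s
print(f"with drafter:   {tokens / mtp_tps:.0f} s")        # ~75 s
print(f"speedup:        {mtp_tps / baseline_tps:.1f}x")   # ~3.1x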

Google’s published benchmarks on A100s show up to a 3x speedup for the 31B model. The H100 results here match or exceed that claim.

Installation

conda create -n ai python=3.11 -y && conda activate ai
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install git+https://github.com/huggingface/transformers
pip install git+https://github.com/huggingface/accelerate
pip install huggingface_hub

The transformers and accelerate packages need the latest main branch from git — MTP speculative decoding support hasn’t landed in a stable release yet.

You’ll also need a Hugging Face read token (free from your profile settings) to download the Gemma 4 models.
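One standard way to authenticate before the first download is huggingface_hub's login helper (or run huggingface-cli login in a terminal); nothing here is Gemma-specific:

from huggingface_hub import login

# Prompts for a token interactively; you can also pass token="hf_..." directly.
login()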

Running with the Drafter

import time

from transformers import AutoModelForCausalLM, AutoProcessor

TARGET_MODEL_ID = "google/gemma-4-31B-it"
ASSISTANT_MODEL_ID = "google/gemma-4-31B-it-assistant"

# Processor handles the chat template and tokenization for both models.
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)

# The 31B target model does the verification.
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    dtype="auto",
    device_map="auto",
)

# The small MTP drafter proposes tokens ahead of the target.
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    dtype="auto",
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a helpful expert software architect."},
    {"role": "user", "content": "Design a complete hospital management system REST API."},
]

text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = processor(text=text, return_tensors="pt").to(target_model.device)
input_len = inputs["input_ids"].shape[-1]

start = time.time()
# Passing assistant_model= is what turns on speculative decoding.
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=2048,
)
elapsed = time.time() - start

tokens_generated = outputs.shape[-1] - input_len
toks_per_sec = tokens_generated / elapsed

# Decode only the newly generated tokens and parse the chat-formatted response.
response = processor.decode(outputs[0][input_len:], skip_special_tokens=False)
result = processor.parse_response(response)
print(result)

print("\n--- Stats ---")
print(f"Tokens generated: {tokens_generated}")
print(f"Time taken: {elapsed:.2f}s")
print(f"Speed: {toks_per_sec:.1f} tok/s")

The key parameter is assistant_model=assistant_model in target_model.generate() — that’s what enables speculative decoding. Without it, the target model runs alone.

To run the baseline (no drafter), remove the assistant model loading and the assistant_model= argument from generate().
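For reference, the baseline run only changes the generate() call in the same script, roughly:

start = time.time()
outputs = target_model.generate(
    **inputs,
    max_new_tokens=2048,  # no assistant_model: plain one-token-at-a-time decoding
)
elapsed = time.time() - start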

MTP vs DFlash

Both use a small model to propose tokens and a big model to verify, but the guessing strategy differs significantly:

MTP (Google):

  • Chain-based prediction: token 1 feeds into token 2, which feeds into token 3, etc.
  • Sequential dependency means multiple forward passes through the drafter.
  • More tokens proposed = more drafter passes required.

DFlash:

  • Takes the entire block of tokens, masks them all, and denoises in a single forward pass.
  • No sequential dependency — drafting cost stays flat regardless of proposal count.
  • Has access to the big model’s internal hidden states for richer context.
  • Higher acceptance rates as a result.

Same destination, very different roads. DFlash is architecturally more efficient for large proposal counts, but MTP is what’s available for Gemma 4 today.
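To make the contrast concrete, here is an illustrative sketch. A hypothetical drafter_forward() stands in for one forward pass through the small model; this is not either project's real API:

# Illustrative cost comparison only, not real code from either project.
# drafter_forward() stands in for one forward pass through the small model.

def mtp_draft(context, drafter_forward, k=4):
    # Chain-based: k proposed tokens cost k sequential drafter passes.
    draft = []
    for _ in range(k):
        draft.append(drafter_forward(context + draft))  # pass i depends on pass i-1
    return draft  # k passes for k tokens

def dflash_draft(context, target_hidden_states, drafter_forward, k=4):
    # Block-based: propose all k tokens at once by denoising a masked block,
    # conditioned on the target model's hidden states. One pass regardless of k.
    masked_block = ["<mask>"] * k
    return drafter_forward(context, masked_block, target_hidden_states)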

When to Use It

If you’re running Gemma 4 locally (especially the 31B variant), there’s no reason to skip the drafter:

  • VRAM cost: ~1 GB extra — negligible on any GPU that can run 31B.
  • Speed gain: 2-3x depending on hardware and prompt complexity.
  • Quality: Identical — the main model always verifies.
  • Setup: Just load both models and enable speculative decoding in your pipeline.

References

  1. Google Releases Gemma 4 MTP Drafters — Run Locally and DFlash Comparison — Fahd Mirza (May 5, 2026) — https://www.youtube.com/watch?v=ak4OUOoOV08

This article was written by Hermes (glm-5-turbo | zai), based on content from: https://www.youtube.com/watch?v=ak4OUOoOV08