Speculative Decoding in Practice: How EAGLE3 Makes LLMs Faster Without Changing Their Outputs
Your GPU Is Mostly Waiting
When you profile LLM inference at batch size 1, something counterintuitive appears. Your H100 — capable of nearly 2,000 TFLOPS of BF16 compute — is running at less than 1% utilization during token generation. It is not doing math. It is waiting.
The reason is that autoregressive decoding is memory-bandwidth bound, not compute-bound. Every forward pass must load the entire model's weights from HBM (high-bandwidth memory) to produce a single token. On an H100 with ~3.35 TB/s peak memory bandwidth, a 70 GB model can be read approximately 48 times per second. That becomes the hard ceiling on token generation speed — not the thousands of idle tensor cores.
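That ceiling is simple arithmetic. A quick sketch, using the illustrative numbers above (helper name is ours, not from any library):

```python
# Back-of-envelope: at batch size 1, decoding is memory-bandwidth bound.
# Each forward pass must stream every weight byte from HBM once, so the
# number of full weight reads per second caps the token rate.
def max_tokens_per_second(model_size_gb: float, bandwidth_gb_s: float) -> float:
    return bandwidth_gb_s / model_size_gb

# H100: ~3,350 GB/s peak HBM bandwidth; 70 GB of weights.
ceiling = max_tokens_per_second(70, 3350)
print(f"~{ceiling:.0f} forward passes/sec")  # ~48 — the hard cap on tokens/sec
```

Compute plays no role in this bound, which is why the idle tensor cores are free to do speculative work.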
The standard inference optimization playbook was written to address this, and it has mostly been applied already in production systems:
- Quantization reduces the bytes loaded per pass. It yields real speedups (1.8–2.4× at INT8), but involves accuracy tradeoffs at lower bit widths — aggressive joint quantization (e.g., W4A4) causes substantial degradation on multi-step reasoning tasks. It is also a one-time gain; you cannot quantize your way to another 2×.
- Knowledge distillation trains a smaller replacement model. It works, but gives you a different model that may behave differently under distribution shift.
- KV cache, continuous batching, and tensor parallelism are now standard in every production inference engine (vLLM, SGLang, TensorRT-LLM). If your serving stack is reasonably modern, you already have them.
Speculative decoding is different. It does not modify the model. It does not alter the output distribution. It exploits the observation that the GPU has enormous idle compute capacity during generation, and uses that capacity to propose multiple tokens simultaneously.
The output is identical to what the model would have produced without speculation — not approximately, but exactly. That guarantee comes from the algorithm, not from our benchmarking choices.
How Speculative Decoding Works
The setup requires two models: a small, fast draft model and the original target model.
At each decoding step:
- The draft model proposes candidate tokens using cheap forward passes
- The target model verifies all candidates in a single forward pass — the same cost as generating one token normally
The key insight: during verification, all draft token positions are already known, so the target can process them in parallel. In normal generation, this is impossible — each token's identity is unknown until the previous one is produced.
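The propose/verify loop can be sketched in a few lines. This is a toy greedy (temperature-0) variant with hypothetical `draft_step` and `target_probs` callables — the real algorithm's stochastic accept/reject rule is covered in the next section:

```python
# Toy sketch of one speculative decoding step (greedy variant).
# draft_step(ctx) -> next token id (cheap autoregressive draft).
# target_probs(ctx, proposed) -> list of distributions, one per proposed
# position, all computed in ONE target forward pass.
def speculative_step(ctx, draft_step, target_probs, k=4):
    # 1. Draft proposes k tokens autoregressively (cheap).
    proposed = []
    for _ in range(k):
        proposed.append(draft_step(ctx + proposed))
    # 2. Target scores all k positions in a single pass (cost of one token).
    p = target_probs(ctx, proposed)
    # 3. Keep the longest prefix the target agrees with; on the first
    #    disagreement, emit the target's own token and stop.
    accepted = []
    for i, tok in enumerate(proposed):
        target_choice = max(p[i], key=p[i].get)
        accepted.append(target_choice)
        if target_choice != tok:
            break
    return ctx + accepted

# Toy models: draft counts up; target counts up until 3, then jumps to 9.
def draft_step(c):
    return c[-1] + 1

def target_probs(ctx, proposed):
    seq = ctx + proposed
    return [{(seq[len(ctx) + i - 1] + 1 if seq[len(ctx) + i - 1] < 3 else 9): 1.0}
            for i in range(len(proposed))]

print(speculative_step([1], draft_step, target_probs, k=4))  # → [1, 2, 3, 9]
```

Note that even on rejection a token is emitted (the target's correction), so every verification pass makes progress.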
The Accept/Reject Rule
Simply accepting every token the draft proposes would give you the draft's output distribution, not the target's. The algorithm needs a principled rule for deciding which tokens to keep.
Let \( p(x) \) be the target model's probability for draft token \( x \) at a given position, and \( q(x) \) the draft model's probability for that same token. The acceptance rule is: accept \( x \) with probability \( \min\left(1, \frac{p(x)}{q(x)}\right) \).
Intuitively: if the draft assigned less probability to a token than the target does (\( q(x) \le p(x) \)), we always accept — the target agrees or would have rated it even higher. If the draft was overconfident (\( q(x) > p(x) \)), we accept with probability \( p(x)/q(x) \). When a token is rejected, we resample from the residual distribution \( \mathrm{norm}(\max(0, p(x) - q(x))) \), correcting for the draft's excess probability mass.
This construction echoes the Metropolis–Hastings acceptance ratio from Markov chain Monte Carlo. The result is that the marginal distribution of accepted tokens matches the target model's distribution exactly — a mathematical guarantee, not a heuristic. Leviathan et al. (ICML 2023) and Chen et al. (2023) proved this independently.
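The accept/resample rule fits in a short function. A minimal sketch with toy dict-based distributions (function name is ours; this mirrors the sampling-with-residual scheme from Leviathan et al., not any production implementation):

```python
import random

def accept_or_resample(token, p, q, rng=random):
    """Stochastic accept/reject for one draft token.
    p, q: dicts mapping token -> probability under target / draft.
    Returns (accepted, token_to_emit)."""
    # Accept with probability min(1, p(x)/q(x)).
    if rng.random() < min(1.0, p.get(token, 0.0) / q[token]):
        return True, token
    # Rejected: resample from the residual norm(max(0, p - q)).
    residual = {t: max(0.0, p.get(t, 0.0) - q.get(t, 0.0)) for t in p}
    z = sum(residual.values())
    r, acc = rng.random() * z, 0.0
    for t, w in residual.items():
        acc += w
        if r <= acc:
            return False, t
    return False, t  # float-rounding fallback: emit last residual token

# Draft and target agree -> always accepted.
print(accept_or_resample("a", {"a": 1.0}, {"a": 1.0}))   # (True, 'a')
# Target puts zero mass on the draft token -> rejected, resampled.
print(accept_or_resample("a", {"b": 1.0}, {"a": 1.0}))   # (False, 'b')
```

The resampling step is what makes the combined output distribution exactly the target's: the residual contains precisely the probability mass the draft under-proposed.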
Where the Speedup Comes From
If the draft has per-token acceptance rate \( \alpha \) and we propose \( k \) tokens per step, the expected number of accepted tokens per target forward pass is:

\( \mathbb{E}[\text{tokens per pass}] = \dfrac{1 - \alpha^{k+1}}{1 - \alpha} \)
The speedup grows rapidly with \( \alpha \) and more slowly with \( k \): doubling the number of draft tokens helps only if \( \alpha \) is already high enough that most of them get accepted.
The practical speedup depends on the model pair: a well-trained draft head might achieve \( \alpha \approx 0.7 \)–\( 0.8 \), yielding 2–4 tokens per pass, while a weaker draft gives diminishing returns regardless of how many tokens it proposes. The memory bandwidth cost stays fixed at one target forward pass.
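The geometric-series expression (which assumes i.i.d. per-token acceptance, a simplification) makes the diminishing returns easy to see:

```python
# Expected tokens emitted per target pass, assuming each draft token is
# accepted independently with probability alpha:
# 1 + alpha + alpha^2 + ... + alpha^k = (1 - alpha^(k+1)) / (1 - alpha)
def expected_tokens(alpha: float, k: int) -> float:
    return (1 - alpha ** (k + 1)) / (1 - alpha)

for alpha in (0.4, 0.6, 0.8):
    print(alpha, [round(expected_tokens(alpha, k), 2) for k in (2, 4, 8)])
```

At \( \alpha = 0.4 \), going from 2 to 8 draft tokens barely moves the needle; at \( \alpha = 0.8 \), the same change roughly doubles the expected tokens per pass.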
The speedup is entirely determined by \( \alpha \) — how well the draft model predicts the target's outputs — and this is exactly what the EAGLE family of methods is designed to maximize.
From EAGLE to EAGLE3
The EAGLE family trains a specialized draft head that conditions on the target model's own internal representations, rather than being an independent smaller model. This is a substantially better starting point.
EAGLE1 (ICML 2024) conditions the draft head on the target's final hidden state — the output of the last transformer block before the language modeling head. Because the draft sees what the target was processing at its output layer, it can make predictions closely aligned with the target's distribution. EAGLE1 achieved 2.7–3.5× latency speedup on LLaMA-2-Chat 70B.
EAGLE2 (EMNLP 2024) added dynamic draft trees: instead of proposing a linear sequence of tokens, the draft explores branching token paths and retains only the most confident branches. The target verifies the entire tree in a single pass, improving acceptance rates further.
EAGLE3 (NeurIPS 2025) made a more fundamental change: tri-layer feature fusion. Instead of conditioning on only the final hidden state, EAGLE3 fuses representations from three points in the target model simultaneously:
- Early layers — encode syntax, morphology, and local token context
- Middle layers — encode semantic relationships and broader discourse structure
- Late layers — encode the output probability distribution directly
By fusing all three, the draft head sees why the target would produce a particular token at every level of abstraction — not just the output distribution, but the full semantic reasoning context. EAGLE3 also switches from feature-level prediction to direct token prediction, removing a scaling ceiling that had limited EAGLE1 and EAGLE2 as training data increased.
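The fusion step itself is structurally simple. An illustrative sketch of the idea — not the official EAGLE3 code; the projection matrix, dimensions, and `fuse` helper are assumptions for illustration:

```python
import numpy as np

# Illustrative tri-layer feature fusion: hidden states from an early, middle,
# and late layer are concatenated along the feature axis and projected back
# to model width before feeding the draft head. W_fuse is learned in practice.
d_model = 64
rng = np.random.default_rng(0)
W_fuse = rng.standard_normal((3 * d_model, d_model)) * 0.02

def fuse(h_early, h_mid, h_late):
    g = np.concatenate([h_early, h_mid, h_late], axis=-1)  # (seq, 3*d_model)
    return g @ W_fuse                                      # (seq, d_model)

h = [rng.standard_normal((5, d_model)) for _ in range(3)]
fused = fuse(*h)
print(fused.shape)  # (5, 64)
```

The point of the sketch: the draft head's input width stays at `d_model`, so fusing three layers adds one projection, not three times the compute.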
The EAGLE3 paper reports 4.1–6.5× speedup at temperature 0 on academic benchmarks (Vicuna 13B, Llama-3.1-8B, Llama-3.3-70B).
EAGLE3 requires training a custom draft head per target model. Off-the-shelf draft heads exist for a handful of mainstream models. For anything outside that set — especially models with MoE layers, custom attention variants, or quantized formats — you train from scratch.
Validating on Llama-3.1-8B First
Before attempting EAGLE3 on GLM-4.7-Flash, a model with no published EAGLE3 baseline, we reproduced the paper's Llama-3.1-8B results on our own infrastructure. This matters: benchmark numbers in the literature are sensitive to inference engine version, batch configuration, and hardware generation. You need a reference point before trusting results on a novel target.
| Metric | Our result | Published paper |
|---|---|---|
| B=32 throughput speedup | 1.25× | 1.32× |
Reproducing 95% of the published speedup (1.25× vs 1.32×) on a newer software stack gave us confidence that our pipeline was correct. The gap is attributable to SGLang v0.5.6 (ours) versus v0.4.4 (paper) — version drift in batching behavior, not an implementation error.
During this validation work, we found three bugs in the official EAGLE3 model releases. One caused generation to silently produce truncated outputs: the server returned valid-looking responses shorter than the requested length, with no error or warning. You would not catch this without per-request token count validation against server-side metrics. More on measurement rigor in the "What We Learned" section below.
GLM-4.7-Flash With EAGLE3
GLM-4.7-Flash is a Mixture-of-Experts model with approximately 31B total parameters and 3B active parameters per forward pass. It fits on a single NVIDIA H100 80GB (TP=1), making it a practical single-node production deployment target.
We are releasing thoughtworks/GLM-4.7-Flash-Eagle3 — the first publicly available EAGLE3 draft head for the GLM-4.7 architecture. The checkpoint is 277 MB, small enough to co-deploy on the same GPU as the target model. Training took 1 hour 26 minutes on a single H100.
Inference Results
Benchmarked on MT-Bench (154 prompts), single NVIDIA H100 80GB, SGLang v0.5.6, FlashInfer backend. All metrics from server-side Prometheus.
Single user (batch size 1)
| Metric | Baseline | EAGLE3 | Improvement |
|---|---|---|---|
| Time per output token | 8.18 ms | 5.89 ms | 1.39× faster |
| Throughput | 120 tok/s | 168 tok/s | 1.39× |
Acceptance rate: 40% of proposed draft tokens accepted, averaging 2.4 accepted tokens per verification step (with 6 draft tokens proposed per step).
Concurrent serving (batch size 32)
| Metric | Baseline | EAGLE3 | Improvement |
|---|---|---|---|
| Time per output token | 22.6 ms | 17.3 ms | 1.30× faster |
| Throughput | 259 tok/s | 440 tok/s | 1.70× |
The per-request latency improvement (1.30×) is lower than at B=1 — expected, since the GPU has less idle capacity at higher concurrency. The system-level throughput gain (1.70×) is higher because speculative decoding lets requests complete faster, freeing slots in the continuous batching scheduler for new requests.
Hardware Cost Implications
To make the throughput numbers concrete: if you currently run 10 H100s to sustain your SLA:
| Scenario | H100s required | Monthly saving (~$3/hr per GPU) |
|---|---|---|
| Baseline | 10 | — |
| EAGLE3 at 1.70× throughput | ~6 | ~$8,600/month |
Estimates based on cloud H100 pricing at approximately $3/hour. On-premises deployments show similar proportional savings against amortized hardware costs.
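The arithmetic behind the table, as a reusable sketch (function name and the 30-day month are ours; adjust the hourly rate for your provider):

```python
import math

def gpu_savings(baseline_gpus: int, speedup: float, usd_per_gpu_hour: float = 3.0):
    """GPUs needed after a throughput speedup, and the monthly dollar saving."""
    needed = math.ceil(baseline_gpus / speedup)          # round up: partial GPUs don't exist
    saved = baseline_gpus - needed
    monthly = saved * usd_per_gpu_hour * 24 * 30         # 30-day month assumed
    return needed, monthly

needed, monthly = gpu_savings(10, 1.70)
print(needed, f"${monthly:,.0f}/month")  # 6 GPUs, $8,640/month
```

The `ceil` matters: a 1.70× speedup on 10 GPUs lands at 5.9 GPUs of demand, which still means provisioning 6.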
Training Details
| Parameter | Value |
|---|---|
| Framework | SpecForge (PyTorch) |
| Hardware | 1× NVIDIA H100 80GB |
| Dataset | 54K mixed (ShareGPT 45% / UltraChat 35% / PerfectBlend 20%) |
| Epochs | 3 |
| Learning rate | 1e-4 |
| Training time | 1h 26min |
| Checkpoint size | 277 MB |
What We Learned (and What Went Wrong)
Training data quality beats quantity
Curating the data mix — balancing conversational, instruction-following, and code domains — had more impact on draft quality than simply scaling dataset size. We trained on 54K examples from ShareGPT (45%), UltraChat (35%), and PerfectBlend (20%). A planned follow-up will use on-distribution regenerated data: prompts drawn from the target model's actual deployment traffic, with the target's own outputs as training labels.
Batch size dynamics matter
At B=1, even a modest acceptance rate yields meaningful latency gains because the GPU has substantial idle capacity. At higher batch sizes, that idle capacity shrinks — the GPU is better utilized, the effective amortized cost of a verification pass rises, and the break-even acceptance rate increases. This is why architecture matters: large-MoE models with expensive expert dispatch require higher acceptance rates to justify speculation at high concurrency.
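The break-even point can be made concrete with a toy cost model. Here `overhead` is the assumed cost of one verification pass relative to a plain decode step (near 1 at B=1, rising with batch size); the function and its bisection are ours, purely illustrative:

```python
def breakeven_alpha(k: int, overhead: float, tol: float = 1e-4) -> float:
    """Smallest per-token acceptance rate at which speculation pays off,
    under a toy cost model where one verification pass of k draft tokens
    costs `overhead` times a plain decode step."""
    def expected(a):  # geometric-series expected tokens per pass
        return (1 - a ** (k + 1)) / (1 - a) if a < 1 else k + 1
    lo, hi = 0.0, 1.0
    while hi - lo > tol:  # bisect: expected() is increasing in a
        mid = (lo + hi) / 2
        lo, hi = (mid, hi) if expected(mid) < overhead else (lo, mid)
    return (lo + hi) / 2

print(round(breakeven_alpha(k=4, overhead=1.3), 2))  # modest alpha suffices at low load
print(round(breakeven_alpha(k=4, overhead=2.5), 2))  # a much higher bar at high load
```

Under these toy numbers a ~23% acceptance rate already breaks even at light load, while heavy load pushes the bar past 60% — directionally consistent with our B=1 versus B=32 results, though the real overhead factor depends on the engine and architecture.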
Measurement is its own problem
We found a bug in SGLang's token counting for EAGLE3 requests that deflated measured throughput by approximately 35%. The server was computing output token counts incorrectly for speculative batches, making throughput numbers look worse than they were. Without server-side Prometheus histogram metrics as a cross-check, you cannot distinguish a real performance regression from a counting artifact. We strongly recommend validating any speculative decoding benchmark with server-side metrics rather than relying solely on client-side timing.
Bugs in official releases
Three bugs in the official EAGLE3 model releases surfaced during our Llama-3.1-8B validation work. One caused generation to silently terminate early — the server returned truncated outputs with no error. Another caused incorrect behavior on the first token of each batch. These are fixable issues and we submitted patches, but they illustrate that correct implementation requires detailed knowledge of the inference stack, and that reproducing published results before claiming new ones is not a formality.
When it doesn't pay off
Speculative decoding has a break-even acceptance rate below which it hurts throughput rather than helping. For GLM-4.7-Flash at B=1, our 40% acceptance rate is comfortably above break-even — the GPU has enough idle capacity that even modest draft quality pays off. At B=32, the break-even is higher: the GPU is better utilized in the baseline, leaving less headroom. In general: the leaner the baseline utilization, the lower the threshold for speculative decoding to be worth it.
A Note on MoE Models: Not Always a Free Lunch
On dense models like Llama-3.1-8B, speculative decoding is close to a free lunch — consistent speedups at modest acceptance rates. Mixture-of-Experts architectures complicate this picture.
The core issue is expert routing. When the target model verifies a batch of speculative tokens, each token may activate a different subset of experts. The routing overhead — determining which experts handle which tokens, dispatching the computation, and gathering results — grows with both the number of speculative tokens being verified simultaneously and the number of experts per layer.
For moderate-MoE models like GLM-4.7-Flash (~3B active out of 31B total), the gains still outweigh the overhead, as our numbers show. But as you move toward larger MoE architectures with more experts per layer, the break-even rises to a point where speculative decoding can hurt throughput rather than help it. The required draft accuracy becomes difficult to achieve in practice.
This is an active research area — not just for us, but for the community. Approaches like smarter verification batching, architecture-aware draft models, and overlapped expert dispatch all seem promising. We do not have all the answers yet. If you are working on this problem, we would be glad to compare notes.
How to Use
GLM-4.7-Flash support is not yet merged into the official SGLang or SpecForge releases. To use the model and draft head published here, please use our maintained forks — they include the patches needed for GLM-4.7 architecture support:
- SGLang fork: github.com/tails-mpt/sglang — for inference
- SpecForge fork: github.com/tails-mpt/SpecForge — for training your own draft head
We intend to keep these forks up to date. Once the patches land in the upstream projects, standard installation will work without the forks.
SGLang (recommended)
```shell
pip install git+https://github.com/tails-mpt/sglang.git

python -m sglang.launch_server \
  --model-path zai-org/GLM-4.7-Flash \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path thoughtworks/GLM-4.7-Flash-Eagle3 \
  --speculative-num-steps 3 \
  --speculative-num-draft-tokens 6 \
  --speculative-eagle-topk 4 \
  --tp 1 \
  --trust-remote-code \
  --port 30000
```
Then query with any OpenAI-compatible client:
```python
import requests

response = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Explain speculative decoding in 3 sentences."}],
        "max_tokens": 256,
    },
)
print(response.json()["choices"][0]["message"]["content"])
```
The Bigger Picture
This work is part of a broader Thoughtworks initiative in inference optimization. We have also built SpecJAX, a pure-JAX framework for training EAGLE3 draft models on Google Cloud TPUs, covering additional architectures including Llama, Qwen, and DeepSeek.
Faster inference per GPU means smaller deployments for the same traffic, which lowers the infrastructure bar for running capable models in production.
Links
- Draft model: thoughtworks/GLM-4.7-Flash-Eagle3
- SGLang fork (GLM-4.7 support): github.com/tails-mpt/sglang
- SpecForge fork (GPU draft head training): github.com/tails-mpt/SpecForge
- SpecJAX (TPU draft head training): github.com/tails-mpt/SpecJAX
- EAGLE3 paper: arXiv:2503.01840
- Original speculative decoding: Leviathan et al., ICML 2023
Citation
```bibtex
@inproceedings{li2025eagle3,
  title     = {{EAGLE-3}: Scaling up Inference Acceleration of Large Language Models via Training-Time Test},
  author    = {Li, Yuhui and Wei, Fangyun and Zhang, Chao and Zhang, Hongyang},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year      = {2025}
}
```