FlashHead: Accelerating Language Model Inference ~ Efficient drop-in replacement for the classification head
Post-Training · Model-Agnostic · Drop-In
📄 Paper · 🤗 Model Collection · 📊 Benchmarks · 📝 Blog
Large language models have gotten dramatically more efficient over the last few years. We now have:
- aggressive weight quantization
- highly optimized attention kernels
- KV-cache optimization
- specialized inference runtimes
Most optimization effort goes into the transformer body. But in many small and edge-deployed models, the part quietly burning budget at every decode step is the very last layer: the dense classification head.
The final projection from hidden states to the vocabulary can account for up to 60% of total model parameters 📦 and roughly half of the inference compute. Even if the transformer layers are highly optimized, the output head still has to evaluate every token in the vocabulary at each decoding step.
In this article we dive into FlashHead, a drop-in replacement for the dense classification head that makes language models significantly faster without retraining the model. Instead of treating token prediction as a dense matrix multiplication, FlashHead reframes it as a retrieval problem.
TL;DR
FlashHead is a post-training replacement for the LM head that accelerates token generation. Key points:
- In small and mid-sized models, the dense output classification head can account for a large fraction of parameters and inference latency.
- FlashHead replaces dense full-vocabulary scoring with a two-stage retrieval pipeline over clustered token embeddings.
- The method is post-training and model-agnostic.
- FlashHead complements existing optimization techniques such as quantization and speculative decoding, achieving significant speedups on top of them.
- The same design now shows up in the latest public multimodal release, embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead, aimed at fast image/video reasoning on Jetson-class hardware.
The Hidden Bottleneck: Vocabulary Projection
The final step in language model decoding computes logits over the entire vocabulary:

```
logits = W_head · h_t,   where W_head ∈ ℝ^{|V| × d_model}
```

Here h_t is the current hidden state, d_model the hidden size, and |V| the vocabulary size.
That means every token requires a full pass over the vocabulary. As vocabularies keep growing, this becomes a structural problem, not just an implementation detail. The output layer becomes an increasingly painful place to spend memory bandwidth and compute.
The head's parameter count is simply d_model × |V|, which for typical small LLMs with a 128K-token vocabulary runs into the hundreds of millions of parameters. And more:
- the head is typically memory-bound rather than compute-bound
- it scales linearly with vocabulary
- it often dominates edge-device inference
For on-device deployment, especially for real-time use cases with batch size 1, the final dense head becomes a repeated per-token tax.
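To see why the head is such a heavy per-token tax, here is a back-of-the-envelope sketch at batch size 1, using illustrative Llama-3.2-1B-class sizes (hidden size 2048, vocabulary 128,256; exact figures vary by model):

```python
# Back-of-the-envelope cost of a dense LM head at batch size 1.
# Illustrative, typical small-model sizes (not exact for any specific checkpoint).
d_model, vocab = 2048, 128_256

params = d_model * vocab      # weights that must be read every decode step
bytes_bf16 = params * 2       # BF16 = 2 bytes per parameter
flops = 2 * params            # one multiply-add per weight

print(f"{params / 1e6:.0f}M head params, "
      f"{bytes_bf16 / 1e6:.0f} MB streamed per token, "
      f"{flops / 1e9:.2f} GFLOPs per token")
```

At batch size 1 the entire weight matrix is streamed from memory for a single hidden-state vector, which is why the head is memory-bound rather than compute-bound.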
From Dense MatMul to Two-Stage Retrieval
The standard LM head computes a dense matrix multiplication at every decode step, scoring all vocabulary tokens regardless of relevance. FlashHead reframes this as a two-stage retrieval problem over clustered token embeddings: first identify which regions of vocabulary space are relevant, then score only those candidates.
The key tradeoff: A dense head scores 128,256 tokens per step (for a 128K vocabulary). With c=8,016 equal-sized clusters (16 tokens each) and p=256 probes, FlashHead scores only 8,016 centroids plus 256 × 16 candidate tokens, i.e. 8,016 + 4,096 = 12,112 scores: a roughly 10× reduction, while multi-probe retrieval maintains near-perfect recall of the correct next token.
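The arithmetic above can be checked directly (the cluster size of 16 follows from 128,256 / 8,016):

```python
# Candidate-count arithmetic for the two-stage retrieval head.
vocab, clusters, probes = 128_256, 8_016, 256

cluster_size = vocab // clusters            # equal-sized clusters -> 16 tokens each
scored = clusters + probes * cluster_size   # stage 1 (centroids) + stage 2 (candidates)

print(f"{scored} scores per step, {vocab / scored:.1f}x fewer than dense")
```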
Note: The offline clustering step runs once per model and adds zero overhead at inference time. Both stages are designed around contiguous memory access patterns for GPU and edge accelerator efficiency.
Rethinking Token Prediction as Retrieval
FlashHead takes a different perspective for token prediction:
Generating the next token is not a dense matrix multiplication problem; it is a retrieval problem.
Instead of scoring the whole vocabulary for every new token, FlashHead:
- Clusters token embeddings (offline - i.e., before inference)
- Selects promising clusters by scoring centroids against the hidden state
- Computes logits only for tokens inside those clusters
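The three steps above can be sketched in NumPy. This is an illustrative toy, not the actual FlashHead kernels: sizes are shrunk, and the offline clustering step is replaced by simple vocabulary chunking.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes; real models use e.g. |V| = 128,256 and d = 2048.
d, n_clusters, cluster_size, n_probe = 64, 32, 16, 4
vocab = n_clusters * cluster_size

E = rng.standard_normal((vocab, d)).astype(np.float32)  # token embeddings (head rows)

# Offline step (stand-in): equal-sized clusters by chunking the vocabulary.
cluster_ids = np.arange(vocab) // cluster_size
centroids = np.stack([E[cluster_ids == c].mean(0) for c in range(n_clusters)])

def flash_head_logits(h):
    # Stage 1: score centroids against the hidden state, keep the n_probe best.
    probe = np.argpartition(centroids @ h, -n_probe)[-n_probe:]
    # Stage 2: dense logits only for tokens inside the probed clusters.
    cand = np.flatnonzero(np.isin(cluster_ids, probe))
    return cand, E[cand] @ h

h = rng.standard_normal(d).astype(np.float32)
cand, logits = flash_head_logits(h)
pred = cand[np.argmax(logits)]
print(f"scored {len(cand)} of {vocab} tokens; greedy pick = token {pred}")
```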
FlashHead: Four Key Ideas
FlashHead combines several ideas from information retrieval and efficient indexing.
1. Equal-Sized Token Clustering
Token embeddings are grouped into balanced clusters 📚. This ensures:
- predictable memory access
- efficient parallel compute
- stable latency across tokens
Unlike hierarchical softmax, cluster sizes remain uniform, which is important for GPU and edge accelerators.
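One simple (hypothetical) way to obtain equal-sized clusters is a greedy capacity-constrained assignment: every cluster gets a hard capacity, and tokens claim their best centroid with room left. FlashHead's actual clustering procedure may differ; this only illustrates the balance constraint.

```python
import numpy as np

def balanced_assign(E, centroids):
    """Greedy capacity-constrained assignment: all clusters end up the
    same size. An illustrative sketch, not FlashHead's exact algorithm."""
    n, k = len(E), len(centroids)
    cap = n // k                                # equal cluster size
    sims = E @ centroids.T                      # token-to-centroid affinity
    order = np.argsort(sims.max(axis=1))[::-1]  # most confident tokens choose first
    assign, fill = np.full(n, -1), np.zeros(k, dtype=int)
    for i in order:
        for c in np.argsort(sims[i])[::-1]:     # best centroid with capacity left
            if fill[c] < cap:
                assign[i], fill[c] = c, fill[c] + 1
                break
    return assign

rng = np.random.default_rng(0)
E = rng.standard_normal((128, 8)).astype(np.float32)
C = rng.standard_normal((8, 8)).astype(np.float32)
a = balanced_assign(E, C)
print(np.bincount(a))  # every cluster holds exactly 16 tokens
```

The uniform cluster size is what makes stage 2 a fixed-shape, contiguous gather-and-matmul rather than a ragged workload.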
2. Multi-Probe Retrieval
Instead of retrieving only a single cluster, FlashHead probes multiple clusters: after scoring the centroids against the hidden state, it keeps many of the highest-scoring clusters rather than committing to a single coarse partition.
This dramatically improves recall of the correct token while still evaluating far fewer candidates than the full vocabulary.
Think of it as beam search for the vocabulary space.
3. Inference-Time Sampling
FlashHead supports both:
- greedy decoding
- sampling decoding
For sampling, clusters are selected proportionally to centroid probabilities before sampling tokens inside the candidate set 🎲.
This preserves compatibility with typical generation methods.
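A minimal sketch of the two-level sampling idea, on toy sizes: first draw a cluster in proportion to its centroid probability, then draw a token from that cluster's local softmax. This is an illustrative approximation of full-softmax sampling, not FlashHead's exact sampler.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_clusters, cluster_size = 32, 8, 16
E = rng.standard_normal((n_clusters * cluster_size, d)).astype(np.float32)
cluster_tokens = [np.arange(c * cluster_size, (c + 1) * cluster_size)
                  for c in range(n_clusters)]
centroids = np.stack([E[t].mean(0) for t in cluster_tokens])

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def sample_token(h):
    # Stage 1: sample a cluster in proportion to its centroid probability.
    c = rng.choice(n_clusters, p=softmax(centroids @ h))
    # Stage 2: sample a token from the chosen cluster's local softmax.
    toks = cluster_tokens[c]
    return rng.choice(toks, p=softmax(E[toks] @ h))

h = rng.standard_normal(d).astype(np.float32)
print(f"sampled token {sample_token(h)}")
```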
4. Selective Quantization
FlashHead uses selective quantization for parts of the head that tolerate lower precision, reducing both memory and compute. Namely, the first stage (scoring centroids) is especially well suited to low-precision execution because it operates on a static matrix and its coarse decisions are refined later in higher precision. FlashHead exploits that by quantizing the retrieval stage aggressively while preserving higher precision where it matters.
💡 Dense heads are often the part deployment stacks avoid quantizing aggressively. FlashHead flips that from a weakness into a strength.
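A sketch of why the centroid stage tolerates aggressive quantization: symmetric per-row INT8 on a static matrix barely perturbs which clusters rank among the top probes. The numbers and scheme here are illustrative, not FlashHead's exact quantizer.

```python
import numpy as np

def quantize_int8(W):
    """Symmetric per-row INT8 quantization (illustrative sketch)."""
    scale = np.abs(W).max(axis=1, keepdims=True) / 127.0
    q = np.round(W / scale).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
C = rng.standard_normal((8016, 64)).astype(np.float32)  # static centroid matrix
h = rng.standard_normal(64).astype(np.float32)

q, s = quantize_int8(C)
exact = C @ h
approx = (q.astype(np.float32) * s) @ h

# Coarse decisions survive: the top-256 probed clusters largely agree.
top_exact = set(np.argpartition(exact, -256)[-256:])
top_approx = set(np.argpartition(approx, -256)[-256:])
overlap = len(top_exact & top_approx) / 256
print(f"top-256 cluster overlap after INT8: {overlap:.3f}")
```

Because stage 2 re-scores the surviving candidates at higher precision, small ranking perturbations in stage 1 are absorbed downstream.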
Why this is different from older head-acceleration work
There is a long line of work on faster softmax layers and approximate output heads, but many earlier methods run into one of four problems:
- they require retraining or fine-tuning,
- they permanently trim the vocabulary,
- they approximate probabilities in ways that are mainly safe for high-likelihood tokens,
- or they focus on top-k retrieval rather than preserving a workable decoding process.
FlashHead is interesting because it stays in the post-training, drop-in regime while still being designed around real accelerator behavior.
Results
Initially, FlashHead was evaluated on several architectures, including Llama-3.2, Gemma-3, and Qwen-3.
On Llama-3.2-1B-Instruct, evaluating Time Per Output Token (TPOT) and accuracy on [Big Bench Hard (BBH)](https://arxiv.org/abs/2210.09261):
| Method | BBH | Head TPOT | Full TPOT (BF16) | Full TPOT (INT4) |
|---|---|---|---|---|
| Baseline | 0.38 | 1.94 ms | 7.69 ms | 3.60 ms |
| FlashHead | 0.38 | 0.40 ms | 6.15 ms | 2.06 ms |
Importantly, the improvement compounds with other optimizations — meaning FlashHead speeds up models that are already quantized or heavily optimized.
FlashHead in Production
One advantage of FlashHead is that it is purely post-training 🧩.
This makes it attractive for:
- proprietary models
- already-trained checkpoints
- edge deployment scenarios
FlashHead is part of a broader effort to make models practical on real devices. It complements existing inference optimization techniques like mixed-precision quantization and speculative decoding, and therefore delivers speedups on top of these methods.
Multimodal Edge Inference with FlashHead
One recent FlashHead-enabled release is:
👉 embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead
This model combines several efficiency techniques:
- FlashHead LM Head
- Quantization (W4A16): INT4 weights + FP16 activations.
- Edge2 mixed exclusions: Keep sensitive layers in FP16 precision.
In other words, it attacks both sides of the inference problem:
- optimize the transformer body with mixed-precision quantization,
- optimize the head with FlashHead.
The result is a Text + Image / Video → Text multimodal reasoning model that can run on edge devices such as NVIDIA Jetson Orin Nano, while maintaining nearly the same reasoning accuracy as the original model.
Why This Matters
Model efficiency research has historically focused on the transformer stack:
- attention
- feed-forward layers
- KV caching
But the output head is often the largest remaining inefficiency.
FlashHead shows that:
Significant gains are still possible by rethinking seemingly simple components.
As models move toward edge deployment and larger vocabularies, the output layer becomes an increasingly important optimization target.
Try It
Try the latest FlashHead-enabled model on Jetson: embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead
```shell
docker run --rm -it \
  --network host \
  --shm-size=8g \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \
  --runtime=nvidia \
  --name=vllm-serve \
  -e HF_TOKEN=hf_*** \
  -e HF_HOME=/root/.cache/huggingface \
  embedl/vllm:latest-jetson-orin-flashhead \
  vllm serve "embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead" \
    --max-model-len 8192 \
    --gpu-memory-utilization 0.75 \
    --max-num-seqs 2 \
    --trust-remote-code
```
And explore the full FlashHead model collection.
If you're interested in efficient AI systems for edge devices, I would love to hear your feedback.

