FlashHead: Accelerating Language Model Inference with an Efficient Drop-In Replacement for the Classification Head

Community Article Published March 11, 2026

Post-Training · Model-Agnostic · Drop-In

📄 Paper  ·  🤗 Model Collection  ·  📊 Benchmarks  ·  📝 Blog


Large language models have gotten dramatically more efficient over the last few years. We now have:

  • aggressive weight quantization
  • highly optimized attention kernels
  • KV-cache optimization
  • specialized inference runtimes

Everyone loves optimizing the transformer body. But in many small and edge-deployed models, the part quietly burning budget at every decode step is the very last layer: the dense classification head.

The final projection from hidden states to the vocabulary can account for up to 60% of total model parameters 📦 and roughly half of the inference compute. Even if the transformer layers are highly optimized, the output head still has to evaluate every token in the vocabulary at each decoding step.

In this article we dive into FlashHead, a drop-in replacement for the dense classification head that makes language models significantly faster without retraining the model. Instead of treating token prediction as a dense matrix multiplication, FlashHead reframes it as a retrieval problem.

TL;DR

FlashHead is a post-training replacement for the LM head that accelerates token generation. Key points:

  • In small and mid-sized models, the dense output classification head can account for a large fraction of parameters and inference latency.
  • FlashHead replaces dense full-vocabulary scoring with a two-stage retrieval pipeline over clustered token embeddings.
  • The method is post-training and model-agnostic.
  • FlashHead complements existing optimization techniques such as quantization and speculative decoding, achieving significant speedups on top of them.
  • The same design now shows up in the latest public multimodal release, embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead, aimed at fast image/video reasoning on Jetson-class hardware.

The Hidden Bottleneck: Vocabulary Projection

The final step in language model decoding computes logits over the entire vocabulary:

$$\text{logits} = h_t \, W_{\text{vocab}}$$

where $h_t \in \mathbb{R}^{1 \times d}$ is the hidden state at decoding step $t$, $W_{\text{vocab}} \in \mathbb{R}^{d \times V}$ is the vocabulary matrix, $\text{logits} \in \mathbb{R}^{1 \times V}$, and $V$ is the vocabulary size.

That means every token requires a full pass over the vocabulary. As vocabularies keep growing, this becomes a structural problem, not just an implementation detail. The output layer becomes an increasingly painful place to spend memory bandwidth and compute.
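To make the cost concrete, here is a minimal NumPy sketch of the dense head. The sizes used in the code are toy values so it runs instantly; the Llama-scale numbers appear only in the final parameter count.

```python
import numpy as np

# Toy sizes so the example runs instantly (real models: d ~ 2048, V ~ 128256)
d, V = 64, 1024
rng = np.random.default_rng(0)
h_t = rng.standard_normal((1, d)).astype(np.float32)      # hidden state at step t
W_vocab = rng.standard_normal((d, V)).astype(np.float32)  # vocabulary matrix

logits = h_t @ W_vocab               # (1, V): one score for every vocab token
next_token = int(np.argmax(logits))  # greedy decoding

# At Llama-3.2-like sizes, the head alone is d * V weights, read at every step:
print(f"{2048 * 128256:,}")  # 262,668,288
```

Every one of those weights is streamed through memory once per generated token, which is why the head tends to be memory-bound.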

Looking at the number of parameters in the head for typical LLMs:

Table: FlashHead significantly reduces the parameter footprint of the LM classification head across Llama, Qwen, and Gemma models. Compared to the baseline dense head, the FlashHead head is much smaller, which also reduces the total model parameter count.

And more:

  • the head is typically memory-bound rather than compute-bound
  • it scales linearly with vocabulary
  • it often dominates edge-device inference

For on-device deployment, especially for real-time use cases with batch size 1, the final dense head becomes a repeated per-token tax.


From Dense MatMul to Two-Stage Retrieval

The standard LM head computes a dense matrix multiplication at every decode step, scoring all vocabulary tokens regardless of relevance. FlashHead reframes this as a two-stage retrieval problem over clustered token embeddings: first identify which regions of vocabulary space are relevant, then score only those candidates.


The key tradeoff: a dense head scores 128,256 tokens per step (for a 128K vocabulary). With c = 8,016 balanced clusters of 16 tokens each and p = 256 probes, FlashHead computes only 8,016 centroid scores plus 256 × 16 = 4,096 token logits, i.e. 12,112 dot products per step: roughly a 10× reduction in scored tokens, while multi-probe retrieval maintains near-perfect recall of the correct next token.

Note: The offline clustering step runs once per model and adds zero overhead at inference time. Both stages are designed around contiguous memory access patterns for GPU and edge accelerator efficiency.
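The candidate arithmetic above can be checked directly, using the same illustrative values of c and p:

```python
V = 128256             # vocabulary size
c, p = 8016, 256       # clusters, probes
cluster_size = V // c  # 16 tokens per balanced cluster
scored = c + p * cluster_size  # centroid scores + probed token logits
print(scored)                  # 12112
print(round(V / scored, 1))    # 10.6 -> ~10x fewer dot products than dense
```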


Rethinking Token Prediction as Retrieval

FlashHead takes a different perspective for token prediction:

Generating the next token is not a dense matrix multiplication problem; it is a retrieval problem.

Instead of scoring the whole vocabulary for every new token, FlashHead:

  1. Clusters token embeddings (offline, before inference)
  2. Selects promising clusters by scoring centroids against the hidden state
  3. Computes logits only for tokens inside those clusters
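The three steps can be sketched in NumPy as follows. The sizes, names, and the stand-in contiguous clustering are illustrative assumptions, not FlashHead's actual implementation:

```python
import numpy as np

def flashhead_greedy(h, centroids, cluster_tokens, W_vocab, n_probe):
    """Two-stage retrieval (sketch): score cluster centroids, probe the
    top-scoring clusters, then compute exact logits only for their tokens."""
    centroid_scores = h @ centroids.T                       # stage 1: (c,) scores
    probed = np.argpartition(centroid_scores, -n_probe)[-n_probe:]
    candidates = cluster_tokens[probed].ravel()             # token ids in probed clusters
    cand_logits = h @ W_vocab[:, candidates]                # stage 2: exact logits
    return int(candidates[np.argmax(cand_logits)])

# Toy setup with hypothetical sizes; real clustering is done offline per model.
rng = np.random.default_rng(0)
d, V, c = 64, 1024, 64
W_vocab = rng.standard_normal((d, V)).astype(np.float32)
cluster_tokens = np.arange(V).reshape(c, V // c)            # stand-in balanced clusters
centroids = W_vocab.T[cluster_tokens].mean(axis=1)          # (c, d) mean embeddings

h = rng.standard_normal(d).astype(np.float32)
next_token = flashhead_greedy(h, centroids, cluster_tokens, W_vocab, n_probe=8)
print(0 <= next_token < V)  # True
```

With n_probe = 8 and 16-token clusters, stage 2 touches only 128 of the 1,024 vocabulary columns, mirroring the reduction described above.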



FlashHead: Four Key Ideas

FlashHead overview
Figure 1: Illustration of the FlashHead algorithm, highlighting the difference between the greedy approach of selecting the most likely next token and a sampling approach that enables probabilistic sampling of new tokens.

FlashHead combines several ideas from information retrieval and efficient indexing.

1. Equal-Sized Token Clustering

Token embeddings are grouped into balanced clusters 📚. This ensures:

  • predictable memory access
  • efficient parallel compute
  • stable latency across tokens

Unlike hierarchical softmax, cluster sizes remain uniform, which is important for GPU and edge accelerators.
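Equal-sized clustering can be approximated with a simple capacity-constrained assignment. The greedy heuristic below is only one illustrative way to balance clusters, not necessarily the method FlashHead uses:

```python
import numpy as np

def balanced_assign(embeddings, centroids, size):
    """Assign each token to its nearest centroid that still has capacity,
    processing the most confident tokens first (a simple balancing heuristic)."""
    V, c = len(embeddings), len(centroids)
    dist = ((embeddings[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)  # (V, c)
    order = np.argsort(dist.min(axis=1))   # most confident tokens first
    capacity = np.full(c, size)
    assign = np.empty(V, dtype=int)
    for t in order:
        for k in np.argsort(dist[t]):      # nearest centroid with room
            if capacity[k] > 0:
                assign[t] = k
                capacity[k] -= 1
                break
    return assign

rng = np.random.default_rng(0)
emb = rng.standard_normal((256, 16)).astype(np.float32)  # toy token embeddings
cent = emb[rng.choice(256, 8, replace=False)]            # 8 initial centroids
a = balanced_assign(emb, cent, size=256 // 8)
print(np.bincount(a))  # every cluster gets exactly 256 // 8 = 32 tokens
```

Because total capacity equals the vocabulary size, every cluster ends up with exactly the same number of tokens, which is what makes memory access and latency predictable.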

2. Multi-Probe Retrieval

Instead of retrieving only a single cluster, FlashHead probes multiple clusters. After scoring the centroids against the hidden state, it probes a large number of relevant clusters rather than committing to a single coarse partition.

This dramatically improves recall of the correct token while still evaluating far fewer candidates than the full vocabulary.

Think of it as beam search for the vocabulary space.

3. Inference-Time Sampling

FlashHead supports both:

  • greedy decoding
  • sampling decoding

For sampling, clusters are selected proportionally to centroid probabilities before sampling tokens inside the candidate set 🎲.

This preserves compatibility with typical generation methods.
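A sketch of this two-stage sampling, with hypothetical helper names and toy sizes (the stand-in clustering is the same simplification as before):

```python
import numpy as np

def softmax(x):
    x = np.asarray(x, dtype=np.float64)  # float64 so probabilities sum to 1
    z = np.exp(x - x.max())
    return z / z.sum()

def flashhead_sample(h, centroids, cluster_tokens, W_vocab, rng):
    """Two-stage sampling (sketch): pick a cluster with probability
    proportional to its softmaxed centroid score, then sample a token
    from the exact logits inside that cluster."""
    k = rng.choice(len(centroids), p=softmax(h @ centroids.T))  # cluster draw
    tokens = cluster_tokens[k]
    token_probs = softmax(h @ W_vocab[:, tokens])
    return int(tokens[rng.choice(len(tokens), p=token_probs)])

# Toy setup (hypothetical sizes; clustering here is a stand-in)
rng = np.random.default_rng(0)
d, V, c = 64, 1024, 64
W_vocab = rng.standard_normal((d, V)).astype(np.float32)
cluster_tokens = np.arange(V).reshape(c, V // c)
centroids = W_vocab.T[cluster_tokens].mean(axis=1)

h = rng.standard_normal(d).astype(np.float32)
tok = flashhead_sample(h, centroids, cluster_tokens, W_vocab, rng)
print(0 <= tok < V)  # True
```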

4. Selective Quantization

FlashHead uses selective quantization for the parts of the head that tolerate lower precision, reducing both memory and compute. The first stage (scoring centroids) is especially well suited to low-precision execution because it operates on a static matrix and its coarse decisions are refined later in higher precision. FlashHead exploits this by quantizing the retrieval stage aggressively while preserving higher precision where it matters.

💡 Dense heads are often the part deployment stacks avoid quantizing aggressively. FlashHead flips that from a weakness into a strength.
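As an illustration of why the centroid stage tolerates low precision, here is generic symmetric int8 quantization applied to a toy centroid matrix; FlashHead's actual quantization recipe may differ:

```python
import numpy as np

def quantize_int8(W):
    """Symmetric per-row int8 quantization (a generic scheme)."""
    scale = np.abs(W).max(axis=1, keepdims=True) / 127.0
    q = np.round(W / scale).astype(np.int8)
    return q, scale.astype(np.float32)

rng = np.random.default_rng(0)
centroids = rng.standard_normal((64, 32)).astype(np.float32)  # toy centroid matrix
q, scale = quantize_int8(centroids)

h = rng.standard_normal(32).astype(np.float32)
approx = (q.astype(np.float32) * scale) @ h  # stage-1 scores from int8 weights
exact = centroids @ h                        # full-precision reference
print(float(np.max(np.abs(approx - exact))))  # small score error
```

Small per-score errors rarely change which clusters land in the top-p probe set, and any residual mistakes are corrected by the higher-precision second stage.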

Why this is different from older head-acceleration work

There is a long line of work on faster softmax layers and approximate output heads, but many earlier methods run into one of four problems:

  • they require retraining or fine-tuning,
  • they permanently trim the vocabulary,
  • they approximate probabilities in ways that are mainly safe for high-likelihood tokens,
  • or they focus on top-k retrieval rather than preserving a workable decoding process.

FlashHead is interesting because it stays in the post-training, drop-in regime while still being designed around real accelerator behavior.


Results

FlashHead was initially evaluated on several architectures, including Llama-3.2, Gemma-3, and Qwen-3.

FlashHead results for Llama Gemma Qwen

On Llama-3.2-1B-Instruct, evaluating Time Per Output Token (TPOT) and accuracy on Big Bench Hard (https://arxiv.org/abs/2210.09261):

| Method    | BBH  | Head TPOT | Full TPOT (BF16) | Full TPOT (INT4) |
|-----------|------|-----------|------------------|------------------|
| Baseline  | 0.38 | 1.94 ms   | 7.69 ms          | 3.60 ms          |
| FlashHead | 0.38 | 0.40 ms   | 6.15 ms          | 2.06 ms          |

Importantly, the improvement compounds with other optimizations — meaning FlashHead speeds up models that are already quantized or heavily optimized.


FlashHead in Production

One advantage of FlashHead is that it is purely post-training 🧩.

This makes it attractive for:

  • proprietary models
  • already-trained checkpoints
  • edge deployment scenarios

FlashHead is part of a broader effort to make models practical on real devices. It complements existing inference optimization techniques like mixed-precision quantization and speculative decoding, and therefore achieves speedups on top of these methods.


Multimodal Edge Inference with FlashHead

Cosmos-Reason2-2B Benchmark Results

One recent FlashHead-enabled release is:

👉 embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead

This model combines several efficiency techniques:

  • FlashHead LM Head
  • Quantization (W4A16): INT4 weights + FP16 activations.
  • Edge2 mixed exclusions: Keep sensitive layers in FP16 precision.

In other words, it attacks both sides of the inference problem:

  • optimize the transformer body with mixed-precision quantization,
  • optimize the head with FlashHead.

The result is a Text + Image / Video → Text multimodal reasoning model that can run on edge devices such as NVIDIA Jetson Orin Nano, while maintaining nearly the same reasoning accuracy as the original model.


Why This Matters

Model efficiency research has historically focused on the transformer stack:

  • attention
  • feed-forward layers
  • KV caching

But the output head is often the largest remaining inefficiency.

FlashHead shows that:

Significant gains are still possible by rethinking seemingly simple components.

As models move toward edge deployment and larger vocabularies, the output layer becomes an increasingly important optimization target.


Try It

Try the latest FlashHead-enabled model on Jetson: embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead

docker run --rm -it \
  --network host \
  --shm-size=8g \
  --ulimit memlock=-1 \
  --ulimit stack=67108864 \
  --runtime=nvidia \
  --name=vllm-serve \
  -e HF_TOKEN=hf_*** \
  -e HF_HOME=/root/.cache/huggingface \
  embedl/vllm:latest-jetson-orin-flashhead \
  vllm serve "embedl/Cosmos-Reason2-2B-W4A16-Edge2-FlashHead" \
    --max-model-len 8192 \
    --gpu-memory-utilization 0.75 \
    --max-num-seqs 2 \
    --trust-remote-code

And explore the full FlashHead model collection:

👉 embedl/flashhead


If you're interested in efficient AI systems for edge devices, I would love to hear your feedback.


Further Reading / Resources
