WhisperMamba: HedgeMamba Distillation of Whisper-tiny
Unofficial implementation of "Attention to Mamba: A Recipe for Cross-Architecture Distillation" (Moudgil et al., Apple + MILA, April 2026) applied to Whisper-tiny.
The decoder's self-attention layers are replaced with HedgeMamba SSM mixers using two-stage knowledge distillation. The encoder and cross-attention remain frozen from the original Whisper-tiny weights.
Not the authors' code, not affiliated with Apple or MILA.
Code: github.com/akashicMarga/hedge-mamba-distil
What this model is
Whisper-tiny has 4 decoder layers, each with a self-attention block. This student replaces every self-attention with a HedgeMambaMixer, a selective SSM with (see the sketch after this list):
- Hedgehog projection on B and C: φ(x) = softmax([Wx, −Wx]) doubles the effective state size and replaces Q/K
- Selective scan with input-dependent Δt (ZOH discretization)
- SiLU gate on the output
- Fix-B state caching for O(1) per-step autoregressive inference (no KV cache growth)
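A minimal, hedged sketch of one autoregressive step of such a mixer. The class, its attribute names, and the recurrence details are illustrative assumptions, not the repo's actual module; only d_model = 384 and state_size = 64 follow the config below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HedgeMambaStepSketch(nn.Module):
    """Illustrative single-step recurrence: Hedgehog features on B/C,
    input-dependent Δt with ZOH discretization, SiLU output gate."""

    def __init__(self, d_model=384, state_size=64):
        super().__init__()
        n = 2 * state_size                                         # doubled by Hedgehog
        self.B_proj = nn.Linear(d_model, state_size, bias=False)   # plays the role of k_proj
        self.C_proj = nn.Linear(d_model, state_size, bias=False)   # plays the role of q_proj
        self.dt_proj = nn.Linear(d_model, d_model)                 # input-dependent Δt
        self.A_log = nn.Parameter(torch.zeros(d_model, n))         # A = -exp(A_log) < 0
        self.gate = nn.Linear(d_model, d_model)

    @staticmethod
    def hedgehog(h):
        # φ(h) = softmax([h, -h]): all-positive features, twice the dimension
        return torch.softmax(torch.cat([h, -h], dim=-1), dim=-1)

    def step(self, x_t, state):
        # x_t: (batch, d_model); state: (batch, d_model, 2*state_size), fixed size
        B_t = self.hedgehog(self.B_proj(x_t))                      # (batch, n)
        C_t = self.hedgehog(self.C_proj(x_t))                      # (batch, n)
        dt = F.softplus(self.dt_proj(x_t))                         # (batch, d_model), Δt > 0
        A = -torch.exp(self.A_log)                                 # (d_model, n)
        A_bar = torch.exp(dt.unsqueeze(-1) * A)                    # ZOH: exp(Δt·A)
        state = A_bar * state + dt.unsqueeze(-1) * B_t.unsqueeze(1) * x_t.unsqueeze(-1)
        y = (state * C_t.unsqueeze(1)).sum(-1)                     # read out with C
        return y * F.silu(self.gate(x_t)), state                   # SiLU-gated output, new state

mixer = HedgeMambaStepSketch()
state = torch.zeros(1, 384, 128)
y, state = mixer.step(torch.randn(1, 384), state)  # state shape stays (1, 384, 128) every step
```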
The encoder (4 Transformer layers + Conv frontend) is fully frozen. Only the decoder SSM weights are learned from scratch.
Files in this repo
| File | Description |
|---|---|
| pytorch/whisper_mamba_final.pt | Final PyTorch state dict (Stage 1 + Stage 2, 144 MB) |
| pytorch/stage1_final.pt | Stage 1 only (cosine-distilled SSM, before ASR fine-tuning) |
| mlx/whisper_mamba_mlx_final.npz | Final MLX weights (142 MB, Apple Silicon inference) |
| mlx/whisper_mamba_mlx_final.json | MLX checkpoint metadata |
The .pt files are raw state_dict OrderedDicts; load them with torch.load(..., map_location="cpu"). The .npz is an MLX weight archive; load it with mlx.core.load(...).
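For example, a quick way to inspect both checkpoints (paths as listed above):

```python
import torch
import mlx.core as mx

# Raw PyTorch state dict: an OrderedDict of parameter name -> tensor
state_dict = torch.load("pytorch/whisper_mamba_final.pt", map_location="cpu")
print(len(state_dict), "tensors,", sum(t.numel() for t in state_dict.values()), "parameters")

# MLX weight archive: a plain dict of parameter name -> mlx array
mlx_weights = mx.load("mlx/whisper_mamba_mlx_final.npz")
print(sorted(mlx_weights)[:5])
```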
Results
WER on LibriSpeech test splits (greedy decoding, lowercase, no punctuation)
| Model | Split | WER |
|---|---|---|
| Whisper-tiny teacher | test.clean | 9.65% |
| WhisperMamba student (PyTorch) | test.clean | 8.49% |
| Whisper-tiny teacher | test.other | 20.23% |
| WhisperMamba student (PyTorch) | test.other | 18.0% |
The student outperforms the teacher on both splits. The larger gap on test.other suggests scheduled sampling gives the student better robustness to its own decoding errors.
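A hedged sketch of the scoring protocol above (jiwer is already in the install list below; the exact normalization regex is an assumption):

```python
import re
from jiwer import wer

def normalize(text):
    # lowercase, strip punctuation, collapse whitespace
    text = text.lower()
    return re.sub(r"[^a-z' ]+", " ", text).strip()

references = ["Hello, world!"]   # ground-truth transcripts
hypotheses = ["hello word"]      # greedy-decoded model output
score = wer([normalize(r) for r in references],
            [normalize(h) for h in hypotheses])
print(f"WER: {score:.2%}")
```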
Validation WER during Stage 2 (PyTorch, LibriSpeech train-clean-100)
| Epoch | Val WER |
|---|---|
| 3 | – |
| 5 | ~5% |
Inference latency (single utterance, 20 samples, Apple M-series)
| Model | Backend | Latency |
|---|---|---|
| Whisper-tiny teacher | PyTorch MPS | ~154 ms |
| WhisperMamba student | PyTorch MPS | ~129 ms |
| WhisperMamba student | MLX | ~41 ms |
MLX is ~3.7× faster than the PyTorch teacher. The O(1) SSM state means latency does not grow with sequence length (unlike the KV cache in standard Whisper).
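As a rough illustration of that claim (assumed shapes: 4 decoder layers, d_model = 384, effective state 128; float counts only, not measured memory):

```python
# Decoder state that must be carried across autoregressive steps
layers, d_model, state = 4, 384, 128

ssm_state_floats = layers * d_model * state          # constant, independent of length
for T in (64, 256, 1024):
    kv_cache_floats = layers * 2 * T * d_model       # K and V grow linearly with T
    print(f"T={T:5d}  KV cache={kv_cache_floats:9d}  SSM state={ssm_state_floats}")
```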
Training
Two-stage distillation
Stage 1: Cosine distillation (warm-up, ~3 h on M-series):
- Loss: layer-wise cosine similarity between student and teacher decoder hidden states
- Only SSM weights trained; everything else frozen
- Warm-initializes the SSM from the teacher attention projections (Appendix B parameter surgery: B_proj ← k_proj, C_proj ← q_proj); see the sketch after this list
- 2 epochs, LibriSpeech train-clean-100, batch size 8
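A minimal sketch of what that stage could look like, assuming one hidden-state tensor per decoder layer and teacher projections whose leading rows are copied when shapes differ (not the repo's actual code):

```python
import torch
import torch.nn.functional as F

def stage1_loss(student_hidden, teacher_hidden):
    # Layer-wise cosine distillation: 1 - cos(student, teacher), averaged over
    # layers, positions, and batch; hidden tensors are (batch, T, d_model).
    losses = [(1 - F.cosine_similarity(s, t, dim=-1)).mean()
              for s, t in zip(student_hidden, teacher_hidden)]
    return torch.stack(losses).mean()

@torch.no_grad()
def warm_init(mixer, self_attn):
    # Appendix-B style surgery: B takes the role of K, C the role of Q.
    # With state_size = 64 < d_model = 384, this naive sketch copies only the
    # first rows of the teacher projection matrices.
    rows = mixer.B_proj.weight.shape[0]
    mixer.B_proj.weight.copy_(self_attn.k_proj.weight[:rows])
    mixer.C_proj.weight.copy_(self_attn.q_proj.weight[:rows])
```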
Stage 2: ASR fine-tuning (~5 h on M-series):
- Loss: cross-entropy on LibriSpeech transcripts
- Scheduled sampling: ground-truth token replacement ramps 0% → 50% over the first half of training, closing the teacher-forcing gap (see the sketch after this list)
- SSM, cross-attn, FFN, and layer norms all trained
- 5 epochs, LibriSpeech train-clean-100, batch size 8
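A hedged sketch of that scheduled-sampling ramp (function and argument names are assumptions; ss_max_p matches the config below):

```python
import torch

def sampling_prob(step, total_steps, ss_max_p=0.5):
    # Linear ramp from 0 to ss_max_p over the first half of training, then flat.
    return min(ss_max_p, ss_max_p * step / (0.5 * total_steps))

def mix_decoder_inputs(gt_tokens, predicted_tokens, p):
    # With probability p, feed the model its own previous prediction instead
    # of the ground-truth token (both tensors are (batch, T) token ids).
    mask = torch.rand(gt_tokens.shape, device=gt_tokens.device) < p
    return torch.where(mask, predicted_tokens, gt_tokens)
```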
An MLX re-implementation trains both stages end-to-end in ~3.5 h.
Config
```yaml
teacher: openai/whisper-tiny
state_size: 64        # ×2 after Hedgehog = 128 effective
batch_size: 8
stage1_lr: 0.0005
stage2_lr: 0.0001
ss_max_p: 0.5         # scheduled sampling ceiling
stage1_epochs: 2
stage2_epochs: 5
data: librispeech_asr train.100 / validation
```
Usage
Install the source repo, then load the checkpoint:
```bash
pip install torch transformers datasets jiwer
git clone https://github.com/akashicMarga/hedge-mamba-distil
cd hedge-mamba-distil
```
```python
import torch
from transformers import WhisperProcessor
from src.student.whisper_mamba import WhisperMambaStudent

# Load the state dict
state_dict = torch.load("pytorch/whisper_mamba_final.pt", map_location="cpu")

# Rebuild the student (requires the source repo)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperMambaStudent.from_teacher("openai/whisper-tiny", state_size=64)
model.load_state_dict(state_dict)
model.eval()
```
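Continuing from the snippet above, a transcription smoke test might look like this. The dummy dataset is a common test fixture (not shipped by this repo), and the generate() call assumes the student keeps a Whisper-style decoding interface; check the scripts in the source repo for the canonical decode loop.

```python
from datasets import load_dataset

# Tiny LibriSpeech sample for a smoke test
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = ds[0]["audio"]
features = processor(audio["array"], sampling_rate=audio["sampling_rate"],
                     return_tensors="pt").input_features

# Assumption: the student exposes a Whisper-style generate()
with torch.no_grad():
    pred_ids = model.generate(features)
print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])
```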
For the MLX backend (Apple Silicon):
```bash
pip install mlx mlx-whisper
python scripts/mlx_inference.py   # benchmarks student vs teacher
python scripts/mic_demo.py        # live microphone
```
Deviations from the paper
| Paper | This repo | Reason |
|---|---|---|
| RoPE on B and C | Omitted | Whisper already has positional embeddings |
| state_size = hidden_size (N = D = 384) | state_size = 64 (128 after Hedgehog) | N = D makes the scan state (B, 768, 768), too slow on MPS |
| Parallel associative scan | Python for-loop | No fused Metal/Triton kernel yet |
| Per-head Hedgehog (H heads × D/H dim) | Single virtual head of size N | Avoids the H × D_h = D constraint when N ≠ D_h |
Citation
```bibtex
@misc{moudgil2026hedgemamba,
  title  = {Attention to Mamba: A Recipe for Cross-Architecture Distillation},
  author = {Moudgil, Abhinav and others},
  year   = {2026},
  url    = {https://arxiv.org/abs/2604.14191}
}
```