NKI Kernel Experiments: Flux2-klein-4B on Neuron
Hardware: AWS Trn1.32xlarge (32 NeuronCores), TP=4, bfloat16
Model: black-forest-labs/FLUX.2-klein-4B
Shapes: B=1, 512×512 → img_S=256 (2× patchify), txt_S=512, inner_dim H=3072, n_heads=24, head_dim=128
1. RoPE kernel (nkilib.core.embeddings.rope)
Kernel constraints
- d_head ∈ {64, 128} (Flux2-klein: 128 ✓)
- S ≤ 512, applied to the sequence-length dimension before attention
- n_heads ≤ 16 per rank (after TP=4 sharding: 24/4 = 6 ✓)
- Input layout must be [B, n_heads, S, d_head]
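The constraints above can be checked programmatically before dispatching to the kernel. A minimal sketch (the helper name and constraint values are taken from these notes, not from an official nkilib API):

```python
def rope_kernel_applicable(shape, tp_degree=4):
    """Check the NKI RoPE kernel constraints for a [B, n_heads, S, d_head] input.

    Illustrative helper; constraint values come from the notes above,
    not from an official nkilib API.
    """
    B, n_heads, S, d_head = shape
    heads_per_rank = n_heads // tp_degree  # 24 / 4 = 6 after TP sharding
    return d_head in (64, 128) and S <= 512 and heads_per_rank <= 16

# Flux2-klein double-stream text: S=512, 24 heads, d_head=128 -> fits (at the boundary)
print(rope_kernel_applicable((1, 24, 512, 128)))   # True
# Single-stream: concatenated img+txt sequence S = 256+512 = 768 -> exceeds S <= 512
print(rope_kernel_applicable((1, 24, 768, 128)))   # False
```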
Flux2-klein applicability
| Block type | S | Fits S≤512? | Notes |
|---|---|---|---|
| Single-stream | img_S + txt_S = 256 + 512 = 768 | No | RoPE is applied to the concatenated image+text sequence |
| Double-stream (image) | img_S = 256 | Yes | But double-stream blocks apply RoPE inside FluxAttnProcessor after separate Q/K projections β hooks into NKI require custom processor |
| Double-stream (text) | txt_S = 512 | Yes (boundary) | |
Verdict: Not practical. Single-stream blocks (20/25 total) exceed S=512. Double-stream blocks (5/25) would require custom processors. The XLA compiler already fuses RoPE with the surrounding matmuls in the same NEFF; a standalone NKI kernel would break that fusion (see §2).
2. Pipeline integration: --fused-qkv flag
Implementation: Flux2AttnProcessorFusedQKV in pipeline.py, activated by --fused-qkv.
Replaces 3 separate to_q / to_k / to_v ColwisePar linear calls in double-stream blocks with a single NKI nki_qkv kernel call.
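The rewrite itself is simple to state: three separate [H, H] projection matmuls collapse into one [H, 3H] matmul followed by a split. A NumPy sketch of the idea (scaled-down shapes; the real inner_dim is 3072, and the actual kernel is NKI, not NumPy):

```python
import numpy as np

# Fused-QKV sketch: one matmul against concatenated weights replaces the
# three separate to_q / to_k / to_v projections. H is scaled down here.
rng = np.random.default_rng(0)
H, S = 64, 16
x = rng.standard_normal((S, H))
w_q, w_k, w_v = (rng.standard_normal((H, H)) for _ in range(3))

# Baseline: three separate projections.
q_ref, k_ref, v_ref = x @ w_q, x @ w_k, x @ w_v

# Fused: concatenate the weights once, do a single matmul, then split.
w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)   # [H, 3H]
q, k, v = np.split(x @ w_qkv, 3, axis=1)

assert np.allclose(q, q_ref) and np.allclose(k, k_ref) and np.allclose(v, v_ref)
```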
Timing (warm steps, avg of last 3/20 steps)
| Mode | Steps | Warm avg (s/step) | vs baseline |
|---|---|---|---|
| Eager, baseline | 20 | 0.824 | 1× |
| Eager, --fused-qkv | 20 | 14.86 | 18× slower |
Output correctness: identical pixel range, mean, and std at every step; the kernel produces correct results.
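The correctness criterion reduces to comparing summary statistics of the two pipelines' outputs at each step. A sketch of that comparison (stand-in random tensors; the real run compares per-step latents):

```python
import numpy as np

def stats(t):
    """Summary statistics compared between baseline and fused-QKV outputs:
    pixel range (min/max), mean, and std."""
    return (float(t.min()), float(t.max()), float(t.mean()), float(t.std()))

rng = np.random.default_rng(0)
baseline = rng.standard_normal((1, 256, 64))
fused = baseline.copy()   # stand-in for the fused-QKV pipeline output

# Matching stats at every denoising step is the signal that the kernel is
# numerically correct even though the runtime differs drastically.
assert stats(fused) == stats(baseline)
```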
Root cause: XLA whole-block fusion
In eager (lazy-XLA) mode, the XLA compiler traces the entire transformer block as one HLO program and compiles it into a single NEFF (neff_cache/{hash}.neff). This fuses:
- All Q/K/V projections
- RoPE embeddings
- Flash attention (via custom prim decomposition)
- Output projection + MLP
- Layer norms
Inserting a standalone NKI kernel (@nki.jit) creates opaque tensor boundaries: XLA cannot inline or fuse across NKI kernel calls. The compiler sees:
[XLA subgraph] → NKI qkv kernel → [XLA subgraph]
instead of one monolithic NEFF. This fragmentation:
- Adds kernel launch overhead (PCIe round-trips for each NKI call)
- Prevents data reuse that XLA would achieve within the fused NEFF
- Defeats the cache: the fragmented graphs generate different, smaller NEFFs with no sharing benefit
The 18× slowdown is consistent with this: the baseline fused NEFF is highly optimised; the fragmented version is not.
3. Compile mode + fused QKV (--mode compile --fused-qkv): bug fix note
4. Flash attention kernel (flux2_flash_attn)
Script: examples/flux2-klein/nki_flash_attn.py
Run: torchrun --nproc_per_node=4 flux2-klein/nki_flash_attn.py
Online-softmax structure collapsed into a single exp-only pass (no row-max subtraction; see note below), BLOCK_Q=128, BLOCK_K=128, bidirectional (no causal mask).
Uses the older NKI ISA API (sbuf.view / psum.view / hbm.view / nisa.*).
Algorithm
For each head (N = 6, looped sequentially in one kernel instance):
- For each Q tile (q_idx = 0..5), with the online softmax collapsed into a single exp-only pass:
  - For each K tile (ks = 0..5):
    - score_T = k_tile.T @ q_tile → (BLOCK_K, BLOCK_Q), via the nc_matmul transposed trick
    - probs_T = exp(score_T * scale)
    - out_psum += probs_T.T @ v_tile → (BLOCK_Q, D)
    - row_sum += probs_T.T @ ones_v → (BLOCK_Q, 1)
  - out = out_psum / row_sum → bf16 → HBM
Note: this is an unnormalized (non-numerically-stable) softmax with no row_max subtraction. Suitable for a correctness test; may overflow for long sequences or large activations.
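A NumPy reference for the exp-only variant makes the trade-off concrete: dividing by the row sum afterwards recovers the exact softmax result, but skipping the row-max subtraction risks overflow once the scaled scores grow large (illustrative dense implementation, not the tiled NKI code):

```python
import numpy as np

def attn_unnormalized(q, k, v, scale):
    """exp-only softmax: no row-max subtraction (as in the kernel above)."""
    probs = np.exp((q @ k.T) * scale)          # can overflow for large scores
    return (probs @ v) / probs.sum(-1, keepdims=True)

def attn_stable(q, k, v, scale):
    """Numerically stable softmax with row-max subtraction, for comparison."""
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(-1, keepdims=True))   # the exp(-max) factor cancels in the division
    return (p @ v) / p.sum(-1, keepdims=True)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 128)) for _ in range(3))
scale = 1 / np.sqrt(128)

# For moderate activations the two agree exactly; for large scaled scores
# (roughly > 88 in float32) the unnormalized exp() overflows to inf.
assert np.allclose(attn_unnormalized(q, k, v, scale), attn_stable(q, k, v, scale))
```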
5. Compile mode full comparison
Hardware: trn2.3xlarge, TP=4, bfloat16, 512Γ512, 4 steps, random weights, 4 runs (1 cold + 3 warm)
Date: 2026-03-31 | neuronxcc: 2.0.236418.0a0+9af338ad
All four compile-mode variants were measured on the same neuronxcc build for a fair apples-to-apples comparison.
Vanilla compile (no custom kernels)
| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 533.449s | 3.868s | 3.868s | 3.868s | 545.053s |
| 2β4 | WARM | 3.868s | 3.869s | 3.869s | 3.869s | 15.475s |
Cold: 533.4s · Warm avg: 3.869 s/step · Throughput: 0.258 steps/s
Compile + --fused-qkv
| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 651.147s | 19.874s | 3.859s | 3.859s | 678.740s |
| 2β4 | WARM | 3.859s | 3.859s | 3.860s | 3.860s | 15.438s |
Cold: 651.1s · Warm avg: 3.859 s/step · Throughput: 0.259 steps/s
Compile + --flash-attn
| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 862.344s | 19.601s | 4.159s | 4.159s | 890.263s |
| 2β4 | WARM | 4.159s | 4.159s | 4.159s | 4.159s | 16.636s |
Cold: 862.3s · Warm avg: 4.159 s/step · Throughput: 0.240 steps/s
Compile + --fused-qkv --flash-attn (combined)
| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 830.249s | 19.558s | 4.149s | 4.149s | 858.105s |
| 2β4 | WARM | 4.149s | 4.149s | 4.149s | 4.149s | 16.597s |
Cold: 830.2s · Warm avg: 4.149 s/step · Throughput: 0.241 steps/s
Summary table
| Mode | Cold (s) | Warm avg/step | Throughput | vs vanilla compile |
|---|---|---|---|---|
| Eager, baseline | 9.3s | 0.835 s/step | 1.198 steps/s | 4.6× faster |
| Compile, vanilla | 533.4s | 3.869 s/step | 0.258 steps/s | 1× (baseline) |
| Compile, --fused-qkv | 651.1s | 3.859 s/step | 0.259 steps/s | -0.3% (noise) |
| Compile, --flash-attn | 862.3s | 4.159 s/step | 0.240 steps/s | +7.5% slower |
| Compile, --fused-qkv --flash-attn | 830.2s | 4.149 s/step | 0.241 steps/s | +7.2% slower |
Interpretation
- Fused-QKV has no measurable effect in compile mode (3.859 vs 3.869 s/step, within run-to-run noise). The Dynamo+NEFF compiler already fuses QKV projections at the HLO level; the explicit NKI kernel neither helps nor hurts, but adds 118s to cold compilation.
- Flash-attn is ~7% slower than vanilla regardless of whether fused-QKV is also enabled. The unnormalized single-pass softmax and sequential head loop are less efficient than the compiler's built-in attention decomposition (two-pass, numerically stable, with better SPMD utilisation).
- Combining both kernels gives the same result as flash-attn alone (4.149 vs 4.159 s/step, within noise); fused-QKV contributes nothing additional in compile mode.
- Cold compilation time grows with NKI kernel count: vanilla (533s) → fused-qkv (651s, +22%) → combined (830s, +56%) → flash-attn alone (862s, +62%). Each NKI kernel adds a separate KLIR compilation pass inside neuronxcc.
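The "vs vanilla" deltas in the summary table follow directly from the warm per-step averages; a quick check:

```python
# Recompute the "vs vanilla" deltas from the warm per-step averages above.
warm = {
    "vanilla": 3.869,
    "fused-qkv": 3.859,
    "flash-attn": 4.159,
    "fused-qkv + flash-attn": 4.149,
}
for mode, t in warm.items():
    delta = (t - warm["vanilla"]) / warm["vanilla"] * 100
    print(f"{mode}: {t:.3f} s/step ({delta:+.1f}% vs vanilla)")
```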
6. Conclusions
| Kernel | Correct | Practical for eager | Practical for compile |
|---|---|---|---|
| NKI RoPE | Not run | No (S > 512 for single-stream) | No (same constraint) |
| NKI QKV | Yes | No (breaks XLA fusion, 18× slower) | Negligible effect (within noise) |
| NKI Flash Attention | Yes (cosine=0.9999) | TBD | No (7% slower than vanilla, +62% compile time) |
| NKI QKV + Flash Attention | Yes | No | Same as flash-attn alone |