NKI Kernel Experiments — Flux2-klein-4B on Neuron

Hardware: AWS Trn1.32xlarge (32 NeuronCores), TP=4, bfloat16
Model: black-forest-labs/FLUX.2-klein-4B
Shapes: B=1, 512×512 → img_S=256 (2× patchify), txt_S=512, inner_dim H=3072, n_heads=24, head_dim=128


1. RoPE kernel (nkilib.core.embeddings.rope)

Kernel constraints

  • d_head ∈ {64, 128} (Flux2-klein: 128 ✓)
  • S ≤ 512 — applied to the sequence-length dimension before attention
  • n_heads ≤ 16 per rank — after TP=4 sharding: 24/4 = 6 ✓
  • Input layout must be [B, n_heads, S, d_head]
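The constraints above can be collected into a small applicability check. This is an illustrative sketch only; `rope_kernel_applicable` is a hypothetical helper, not part of nkilib.

```python
# Sketch: check the documented RoPE kernel constraints against the
# Flux2-klein-4B shapes (d_head=128, n_heads=24, TP=4, img_S=256, txt_S=512).
def rope_kernel_applicable(d_head: int, n_heads: int, tp_degree: int, seq_len: int) -> bool:
    """Return True iff all documented kernel constraints hold."""
    heads_per_rank = n_heads // tp_degree      # TP shards the head dimension
    return (
        d_head in (64, 128)                    # supported head dims
        and seq_len <= 512                     # sequence-length ceiling
        and heads_per_rank <= 16               # per-rank head limit
    )

# Double-stream image/text streams fit; the concatenated single-stream does not.
assert rope_kernel_applicable(128, 24, 4, 256)            # image stream: OK
assert rope_kernel_applicable(128, 24, 4, 512)            # text stream: boundary OK
assert not rope_kernel_applicable(128, 24, 4, 256 + 512)  # single-stream: 768 > 512
```

This is exactly why only the double-stream blocks qualify in the table below.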

Flux2-klein applicability

| Block type | S | Fits S ≤ 512? | Notes |
|---|---|---|---|
| Single-stream | img_S + txt_S = 256 + 512 = 768 | No | RoPE is applied to the concatenated image+text sequence |
| Double-stream (image) | img_S = 256 | Yes | Double-stream blocks apply RoPE inside FluxAttnProcessor after separate Q/K projections; hooking into NKI requires a custom processor |
| Double-stream (text) | txt_S = 512 | Yes (boundary) | |

Verdict: Not practical. Single-stream blocks (20/25 total) exceed S=512. Double-stream blocks (5/25) would require custom processors. The XLA compiler already fuses RoPE with the surrounding matmuls in the same NEFF — a standalone NKI kernel would break that fusion (see §3).


2. Pipeline integration β€” --fused-qkv flag

Implementation: Flux2AttnProcessorFusedQKV in pipeline.py, activated by --fused-qkv. Replaces 3 separate to_q / to_k / to_v ColwisePar linear calls in double-stream blocks with a single NKI nki_qkv kernel call.
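Functionally, fusing the three projections amounts to one matmul against the concatenated weight matrix instead of three separate matmuls. A minimal numpy sketch of that equivalence (shapes from the double-stream blocks; this is not the nki_qkv kernel itself):

```python
import numpy as np

rng = np.random.default_rng(0)
S, H = 256, 3072  # image-stream tokens, inner_dim
x = rng.standard_normal((S, H), dtype=np.float32)
w_q, w_k, w_v = (rng.standard_normal((H, H), dtype=np.float32) for _ in range(3))

# Baseline: three independent to_q / to_k / to_v projections.
q, k, v = x @ w_q, x @ w_k, x @ w_v

# Fused: a single [H, 3H] matmul, split afterwards.
qkv = x @ np.concatenate([w_q, w_k, w_v], axis=1)
q2, k2, v2 = np.split(qkv, 3, axis=1)

assert np.allclose(q, q2, rtol=1e-3, atol=1e-2)
assert np.allclose(k, k2, rtol=1e-3, atol=1e-2)
assert np.allclose(v, v2, rtol=1e-3, atol=1e-2)
```

The fused form reads the activations once instead of three times, which is the intended memory-traffic saving; whether that materialises depends on what the compiler already fuses (see the timing below).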

Timing (warm steps, avg of last 3/20 steps)

| Mode | Steps | Warm avg (s/step) | vs baseline |
|---|---|---|---|
| Eager, baseline | 20 | 0.824 | 1× |
| Eager, --fused-qkv | 20 | 14.86 | 18× slower |

Output correctness: identical pixel range, mean, and std at every step — the kernel produces correct results.

Root cause: XLA whole-block fusion

In eager (lazy-XLA) mode, the XLA compiler traces the entire transformer block as one HLO program and compiles it into a single NEFF (neff_cache/{hash}.neff). This fuses:

  • All Q/K/V projections
  • RoPE embeddings
  • Flash attention (via custom prim decomposition)
  • Output projection + MLP
  • Layer norms

Inserting a standalone NKI kernel (@nki.jit) creates opaque tensor boundaries — XLA cannot inline or fuse across NKI kernel calls. The compiler sees:

[XLA subgraph] → NKI qkv kernel → [XLA subgraph]

instead of one monolithic NEFF. This fragmentation:

  1. Adds kernel launch overhead (PCIe round-trips for each NKI call)
  2. Prevents data reuse that XLA would achieve within the fused NEFF
  3. Defeats the cache: the fragmented graphs generate different, smaller NEFFs with no sharing benefit

The 18× slowdown is consistent with this — the baseline fused NEFF is highly optimised; the fragmented version is not.
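A quick arithmetic check of the quoted factor, using the warm-step averages from the table above:

```python
# Warm-step averages from the eager-mode table (seconds per step).
baseline_s = 0.824   # eager, baseline
fused_s = 14.86      # eager, --fused-qkv

slowdown = fused_s / baseline_s   # ~18.0
assert round(slowdown) == 18      # matches the "18x slower" figure
```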


3. Compile mode + fused QKV (--mode compile --fused-qkv) — bug fix note


4. Flash attention kernel (flux2_flash_attn)

Script: examples/flux2-klein/nki_flash_attn.py
Run: torchrun --nproc_per_node=4 flux2-klein/nki_flash_attn.py

Online softmax collapsed into a single exp-only pass (no two-pass row-max), BLOCK_Q=128, BLOCK_K=128, bidirectional (no causal mask). Uses the older NKI ISA API (sbuf.view / psum.view / hbm.view / nisa.*).

Algorithm

```
for head in range(6):                          # N=6 heads per rank, sequential in one kernel instance
    for q_idx in range(6):                     # Q tiles, BLOCK_Q = 128
        # pass 1 of online softmax, here collapsed into a single exp-only pass
        for ks in range(6):                    # K tiles, BLOCK_K = 128
            score_T   = k_tile.T @ q_tile      # (BLOCK_K, BLOCK_Q) via nc_matmul transposed trick
            probs_T   = exp(score_T * scale)
            out_psum += probs_T.T @ v_tile     # (BLOCK_Q, D)
            row_sum  += probs_T.T @ ones_v     # (BLOCK_Q, 1)
        out = out_psum / row_sum               # → bf16 → HBM
```

Note: this is an unnormalized (non-numerically-stable) softmax — no row_max subtraction. Suitable for a correctness test; may overflow for long sequences or large activations.
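The caveat can be demonstrated directly: an exp-only softmax matches the row-max-subtracted version for moderate scores, but overflows float32 once scores pass ~88 (since exp(88) is near float32 max). Illustrative numpy, not the kernel code:

```python
import numpy as np

def softmax_unstable(s):
    e = np.exp(s)                               # what the single exp-only pass computes
    return e / e.sum(-1, keepdims=True)

def softmax_stable(s):
    e = np.exp(s - s.max(-1, keepdims=True))    # row-max subtracted (numerically stable)
    return e / e.sum(-1, keepdims=True)

scores = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
assert np.allclose(softmax_unstable(scores), softmax_stable(scores))

big = np.array([[100.0, 101.0, 102.0]], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    bad = softmax_unstable(big)                 # exp overflows to inf, inf/inf -> nan
assert np.isnan(bad).any() or np.isinf(bad).any()
assert np.isfinite(softmax_stable(big)).all()   # stable version stays finite
```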

5. Compile mode full comparison

Hardware: trn2.3xlarge, TP=4, bfloat16, 512×512, 4 steps, random weights, 4 runs (1 cold + 3 warm)
Date: 2026-03-31 | neuronxcc: 2.0.236418.0a0+9af338ad

All four compile-mode variants measured on the same neuronxcc build for a fair apples-to-apples comparison.

Vanilla compile (no custom kernels)

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 533.449 s | 3.868 s | 3.868 s | 3.868 s | 545.053 s |
| 2–4 | WARM | 3.868 s | 3.869 s | 3.869 s | 3.869 s | 15.475 s |

Cold: 533.4 s · Warm avg: 3.869 s/step · Throughput: 0.258 steps/s

Compile + --fused-qkv

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 651.147 s | 19.874 s | 3.859 s | 3.859 s | 678.740 s |
| 2–4 | WARM | 3.859 s | 3.859 s | 3.860 s | 3.860 s | 15.438 s |

Cold: 651.1 s · Warm avg: 3.859 s/step · Throughput: 0.259 steps/s

Compile + --flash-attn

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 862.344 s | 19.601 s | 4.159 s | 4.159 s | 890.263 s |
| 2–4 | WARM | 4.159 s | 4.159 s | 4.159 s | 4.159 s | 16.636 s |

Cold: 862.3 s · Warm avg: 4.159 s/step · Throughput: 0.240 steps/s

Compile + --fused-qkv --flash-attn (combined)

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 830.249 s | 19.558 s | 4.149 s | 4.149 s | 858.105 s |
| 2–4 | WARM | 4.149 s | 4.149 s | 4.149 s | 4.149 s | 16.597 s |

Cold: 830.2 s · Warm avg: 4.149 s/step · Throughput: 0.241 steps/s

Summary table

| Mode | Cold (s) | Warm avg/step | Throughput | vs vanilla compile |
|---|---|---|---|---|
| Eager, baseline | 9.3 | 0.835 s/step | 1.198 steps/s | 4.6× faster |
| Compile, vanilla | 533.4 | 3.869 s/step | 0.258 steps/s | 1× (baseline) |
| Compile, --fused-qkv | 651.1 | 3.859 s/step | 0.259 steps/s | −0.3% (noise) |
| Compile, --flash-attn | 862.3 | 4.159 s/step | 0.240 steps/s | +7.5% slower |
| Compile, --fused-qkv --flash-attn | 830.2 | 4.149 s/step | 0.241 steps/s | +7.2% slower |

Interpretation

  • Fused-QKV has no measurable effect in compile mode (3.859 vs 3.869 — within run-to-run noise). The Dynamo+NEFF compiler already fuses QKV projections at the HLO level; the explicit NKI kernel neither helps nor hurts, but adds 118 s to cold compilation.
  • Flash-attn is ~7% slower than vanilla regardless of whether fused-QKV is also enabled. The unnormalized single-pass softmax and sequential head loop are less efficient than the compiler's built-in attention decomposition (two-pass numerically stable, better SPMD utilisation).
  • Combining both kernels gives the same result as flash-attn alone (4.149 vs 4.159 — within noise); fused-QKV contributes nothing additional in compile mode.
  • Cold compilation time grows with NKI kernel count: vanilla (533 s) → fused-qkv (651 s, +22%) → combined (830 s, +56%) → flash-attn alone (862 s, +62%). Each NKI kernel adds a separate KLIR compilation pass inside neuronxcc.
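The summary-table arithmetic is internally consistent; a quick check with the numbers copied from the table:

```python
# Warm per-step times from the summary table (seconds per step).
warm_eager = 0.835
warm_vanilla = 3.869
warm_flash = 4.159

# Throughput is the reciprocal of the warm per-step time.
assert round(1 / warm_vanilla, 3) == 0.258          # steps/s, matches table
assert round(1 / warm_flash, 3) == 0.240            # steps/s, matches table

# Relative figures follow from the ratios.
assert round(warm_vanilla / warm_eager, 1) == 4.6   # eager ~4.6x faster
assert abs((warm_flash / warm_vanilla - 1) * 100 - 7.5) < 0.1   # "+7.5% slower"
```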

8. Conclusions

| Kernel | Correct | Practical for eager | Practical for compile |
|---|---|---|---|
| NKI RoPE | — | No (S > 512 for single-stream) | No (same constraint) |
| NKI QKV | Yes | No — breaks XLA fusion (18× slower) | Negligible effect (within noise) |
| NKI Flash Attention | Yes (cosine=0.9999) | TBD | No — 7% slower than vanilla, +62% compile time |
| NKI QKV + Flash Attention | Yes | No | Same as flash-attn alone |