# Geometric-AI Kernels
Fused CuteDSL kernels for the loss functions that dominate post-training workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL self-distillation. Each kernel ships a single-launch fused forward + backward path that returns `(loss, grad_logprobs)` directly: no `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward kernel, and no host-side syncs in the hot path.
Background and benchmarks: see the release post.
- Backend: CUDA (NVIDIA CUTLASS DSL).
- Min GPU: SM80 (Ampere), required by `nvidia-cutlass-dsl`. Tested on H100 (SM90). Should work on SM80 (Ampere), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
- Min CUDA: 12.8.
- Dtypes: `float32`, `float16`, `bfloat16`.
- Dynamic shapes: a single compile handles arbitrary batch size and sequence length; no recompiles when shapes change between calls (common in post-training rollouts).
## Kernels

| Kernel family | Direct (no autograd) | Autograd-aware | Forward-only |
|---|---|---|---|
| BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` |
| GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` |
| Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` |
## Entry points

Each kernel family exposes three entry points backed by the same underlying CuteDSL kernel:

- `<name>(...)` - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit` dispatch. Lowest-overhead path; the caller chains the gradient into the upstream model with `policy_logprobs.backward(grad)`. Use this in custom training loops where you control gradient flow.
- `<name>_autograd(...)` - same kernel, registered via `torch.library.custom_op` + `register_autograd`. `loss.backward()` works and composes with `torch.compile(fullgraph=True)`. There is a noticeable per-call dispatcher overhead vs. the direct path.
- `<name>_fwd(...)` - forward-only, returns the scalar `loss` and skips the gradient buffer entirely. Use for inference / validation / reward-model scoring.
## Loading the kernels

```bash
pip install kernels apache-tvm-ffi nvidia-cutlass-dsl
```

```python
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
```
## BNPO Loss

Batch-Normalized Policy Optimization sums per-token policy and KL terms across the entire batch and divides by the global valid-token count:

```
loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)
```

where `per_token_loss` is the PPO-clipped ratio loss:

```
ratio     = exp(policy_logprobs − old_policy_logprobs)
clipped   = clip(ratio, 1−ε, 1+ε_high)
per_token = −advantages · min(ratio, clipped)
kl        = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
```
The global denominator is computed entirely on-GPU via cross-CTA atomics; there is no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-code-eliminated at compile time.
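For reference, the same math can be written out in eager PyTorch. The sketch below restates the formulas above for sanity-checking the fused kernel's output; it is not the kernel's implementation:

```python
import torch

def bnpo_loss_reference(policy_logprobs, old_policy_logprobs, ref_logprobs,
                        advantages, completions_mask,
                        epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Eager restatement of the BNPO formula above (sketch, not the fused kernel)."""
    mask = completions_mask.to(policy_logprobs.dtype)
    ratio = torch.exp(policy_logprobs - old_policy_logprobs)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon_high)
    per_token = -advantages.unsqueeze(-1) * torch.minimum(ratio, clipped)
    delta = ref_logprobs - policy_logprobs
    kl = torch.exp(delta) - delta - 1
    # Global denominator: one valid-token count for the whole batch, clamped to 1.
    denom = mask.sum().clamp(min=1)
    return ((per_token + beta * kl) * mask).sum() / denom
```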
Inputs:

- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8

Returns: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`.
```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.bnpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.bnpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.bnpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```
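The `_autograd` entry point also composes with `torch.compile`. A minimal sketch reusing the tensors above; the compiled wrapper is illustrative, not part of the kernel API:

```python
@torch.compile(fullgraph=True)
def compiled_bnpo_step(policy_logprobs, old_policy_logprobs, ref_logprobs,
                       advantages, completions_mask):
    # bnpo_loss_autograd is registered via torch.library.custom_op + register_autograd,
    # so it traces as a single op inside a fullgraph-compiled function.
    return km.bnpo_loss_autograd(
        policy_logprobs, old_policy_logprobs, ref_logprobs,
        advantages, completions_mask,
        epsilon=0.2, epsilon_high=0.2, beta=0.1,
    )

loss = compiled_bnpo_step(policy_logprobs, old_policy_logprobs, ref_logprobs,
                          advantages, completions_mask)
loss.backward()
```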
## GRPO Loss

Group Relative Policy Optimization implements TRL's default per-response normalization variant: each response is normalized by its own valid-token count before averaging across the batch:

```
loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )
```

`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as in BNPO.
`completions_mask` is required because the per-response denominator is mask-derived. The kernel uses one CTA per row, so the per-row mask sum is reduced inside the block; there are no cross-CTA atomics on the scaling pass.
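Relative to the BNPO reference sketch above, only the normalization changes. Assuming the same intermediate `per_token`, `kl`, and `mask` tensors from that sketch, the eager equivalent is (again a sketch, not the kernel's implementation):

```python
# Per-response normalization: each row is scaled by its own valid-token count
# before averaging across the batch, instead of one global denominator.
per_row = ((per_token + beta * kl) * mask).sum(-1) / mask.sum(-1).clamp(min=1)
loss = per_row.mean()
```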
Inputs:

- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8 (required)

Returns: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.grpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.grpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.grpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```
## Reverse KL

Reverse-KL self-distillation computes KL(student ‖ teacher) over a `(num_tokens, vocab)` slab using an online normalization algorithm that reads each logit row exactly once on the forward-only path:

```
p = softmax(student_logits)
q = softmax(teacher_logits)
kl_per_row = Σ_v p_v · (log p_v − log q_v)
loss = (mask · kl_per_row).sum() / mask.sum()
```

The gradient through the softmax Jacobian is analytical:

```
grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
```

where `scale = mask[r] · inv_n_valid`.
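The loss and its analytical gradient can be cross-checked with a dense eager restatement. This is a sketch of the formulas above; the kernel itself uses a single online pass per row rather than materializing full softmaxes:

```python
import torch

def reverse_kl_reference(student_logits, teacher_logits, completions_mask):
    """Eager restatement of the reverse-KL formulas above (sketch, not the fused kernel)."""
    log_p = torch.log_softmax(student_logits.float(), dim=-1)   # student
    log_q = torch.log_softmax(teacher_logits.float(), dim=-1)   # teacher
    p = log_p.exp()
    kl_per_row = (p * (log_p - log_q)).sum(-1)                  # KL(student ‖ teacher) per token
    mask = completions_mask.to(kl_per_row.dtype)
    inv_n_valid = 1.0 / mask.sum()                              # unclamped, matching the kernel
    loss = (mask * kl_per_row).sum() * inv_n_valid
    # Analytical gradient through the softmax Jacobian.
    scale = (mask * inv_n_valid).unsqueeze(-1)
    grad_student = scale * p * (log_p - log_q - kl_per_row.unsqueeze(-1))
    return loss, grad_student
```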
Inputs:

- `student_logits`, `teacher_logits`: `(*, V)`, arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
- `completions_mask`: shape matching `student_logits.shape[:-1]`
⚠️ Fully-masked batches: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable (a guard sketch follows the example below).
Returns: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`.
```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

# Qwen3.5-style vocab; arbitrary leading dims supported
bs, seq_len, vocab = 4, 256, 248320
student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True)
teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
student_logits.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.reverse_kl_autograd(
    student_logits.requires_grad_(), teacher_logits, completions_mask
)
loss.backward()

# 3) Forward-only - inference / KL monitoring, no gradient buffer
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
```
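If fully-masked batches are reachable (see the warning above), a minimal upstream guard could look like this; the zero-loss fallback is an assumption about what the surrounding training loop expects:

```python
if completions_mask.any():
    loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
else:
    # Every token is masked: skip the kernel to avoid the unclamped 1 / mask.sum().
    loss = student_logits.new_zeros(())
    grad = torch.zeros_like(student_logits)
```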
## Performance

All numbers are geometric-mean speedups, measured on an H100 SXM (SM90a). Full methodology and per-shape plots are in the release post.

### `kernels` CLI benchmark

Timed with `time.perf_counter` + `cuda.synchronize()`, mean over 100 iterations.
| Kernel | vs eager | vs torch.compile |
|---|---|---|
| `grpo_loss_fwd` | 5.68× | 2.45× |
| `grpo_loss` | 20.79× | 1.98× |
| `bnpo_loss_fwd` | 5.29× | 2.52× |
| `bnpo_loss` | 16.81× | 2.27× |
| `reverse_kl_fwd` | 6.88× | 2.45× |
| `reverse_kl` | 7.03× | 2.61× |
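A minimal sketch of that timing methodology, reusing the BNPO tensors from the example above; the warmup pass is an addition here, and the numbers in the table come from the `kernels` CLI benchmark, not this snippet:

```python
import time
import torch

def time_kernel(fn, iters=100, warmup=10):
    """Mean wall-clock seconds per call: perf_counter around synchronized launches."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

mean_s = time_kernel(lambda: km.bnpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
))
print(f"bnpo_loss_fwd: {mean_s * 1e6:.1f} µs/call")
```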
### Benchmark animations

- BNPO Loss vs eager PyTorch
- BNPO Loss vs torch.compile
- GRPO Loss vs eager PyTorch
- GRPO Loss vs torch.compile
- Reverse KL vs eager PyTorch
- Reverse KL vs torch.compile