Geometric-AI Kernels

Fused CuteDSL kernels for the loss functions that dominate post-training workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL self-distillation.

Each kernel ships a single-launch fused forward + backward path that returns (loss, grad_logprobs) directly. No torch.autograd.Function wrapper, no extra grad_output * dpolicy backward kernel, and no host-side syncs in the hot path.

Background and benchmarks: see the release post.

  • Backend: CUDA (NVIDIA CUTLASS DSL).
  • Min GPU: SM80 (Ampere), the floor required by nvidia-cutlass-dsl. Tested on H100 (SM90); should also work on SM80 (Ampere), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
  • Min CUDA: 12.8.
  • Dtypes: float32, float16, bfloat16.
  • Dynamic shapes: a single compile handles arbitrary batch size and sequence length, no recompiles when shapes change between calls (common in post-training rollouts).
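
Because shapes are dynamic, a warm kernel can be reused across changing rollout shapes. A minimal sketch (km is the handle obtained under "Loading the kernels" below; BNPO chosen arbitrarily):

# One compile serves every (bs, seq_len) that follows - no shape-specialized recompiles.
for bs, seq_len in [(8, 512), (32, 2048), (16, 777)]:
    lp   = torch.randn(bs, seq_len, device="cuda")
    old  = torch.randn(bs, seq_len, device="cuda")
    ref  = torch.randn(bs, seq_len, device="cuda")
    adv  = torch.randn(bs, device="cuda")
    mask = torch.ones(bs, seq_len, dtype=torch.int8, device="cuda")
    loss, grad = km.bnpo_loss(lp, old, ref, adv, mask,
                              epsilon=0.2, epsilon_high=0.2, beta=0.1)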

Kernels

Kernel family    Direct (no autograd)    Autograd-aware         Forward-only
BNPO loss        bnpo_loss               bnpo_loss_autograd     bnpo_loss_fwd
GRPO loss        grpo_loss               grpo_loss_autograd     grpo_loss_fwd
Reverse KL       reverse_kl              reverse_kl_autograd    reverse_kl_fwd

Entry points

Each kernel family exposes three entry points with the same underlying CuteDSL kernel:

  • <name>(...) - fused fwd+bwd, returns (loss, grad) from one @cute.jit dispatch. Lowest-overhead path; the caller chains the gradient into the upstream model with policy_logprobs.backward(grad). Use this in custom training loops where you control gradient flow.
  • <name>_autograd(...) - same kernel, registered via torch.library.custom_op + register_autograd. loss.backward() works and composes with torch.compile(fullgraph=True). There is a noticeable per-call dispatcher overhead vs. the direct path.
  • <name>_fwd(...) - forward-only, returns scalar loss and skips the gradient buffer entirely. Use for inference / validation / reward-model scoring.
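
For example, the autograd entry point can sit inside a compiled region. A hedged sketch (km and the inputs are set up as in the BNPO example below):

@torch.compile(fullgraph=True)
def bnpo_step(policy_lp, old_lp, ref_lp, adv, mask):
    # The torch.library custom op traces into the graph instead of forcing a graph break.
    return km.bnpo_loss_autograd(policy_lp, old_lp, ref_lp, adv, mask,
                                 epsilon=0.2, epsilon_high=0.2, beta=0.1)

loss = bnpo_step(policy_logprobs, old_policy_logprobs, ref_logprobs,
                 advantages, completions_mask)
loss.backward()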

Loading the kernels

pip install kernels apache-tvm-ffi nvidia-cutlass-dsl
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)

BNPO Loss

Batch-Normalized Policy Optimization sums per-token policy and KL terms across the entire batch and divides by the global valid-token count:

loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)

where per_token_loss is the PPO-clipped ratio loss:

ratio      = exp(policy_logprobs - old_policy_logprobs)
clipped    = clip(ratio, 1−ε, 1+ε_high)
per_token  = −advantages · min(ratio, clipped)
kl         = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1

The global denominator is computed entirely on-GPU via cross-CTA atomics - no host-side mask.sum() sync. When beta=0 the KL branch is dead-code-eliminated at compile time.
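
For numerics checks, the same math in plain eager PyTorch (an unfused reference sketch, not the kernel's implementation):

def bnpo_loss_eager(policy_lp, old_lp, ref_lp, advantages, mask,
                    epsilon=0.2, epsilon_high=0.2, beta=0.1):
    ratio     = torch.exp(policy_lp - old_lp)
    clipped   = torch.clamp(ratio, 1 - epsilon, 1 + epsilon_high)
    per_token = -advantages[:, None] * torch.minimum(ratio, clipped)
    delta     = ref_lp - policy_lp
    kl        = torch.exp(delta) - delta - 1.0
    m = mask.to(policy_lp.dtype)
    # Global denominator: one valid-token count across the whole batch.
    return ((per_token + beta * kl) * m).sum() / m.sum().clamp(min=1)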

Inputs:

  • policy_logprobs, old_policy_logprobs, ref_logprobs: (bs, seq_len), fp32/fp16/bf16
  • advantages: (bs,)
  • completions_mask: (bs, seq_len), bool or int8

Returns: (loss, grad_policy_logprobs) from bnpo_loss; scalar loss from bnpo_loss_fwd.

import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs     = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs        = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages          = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask    = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.bnpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.bnpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.bnpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
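
In a real loop, policy_logprobs is gathered from the model's logits, so backward(grad) flows into the model weights. A sketch (model, input_ids, labels, and optimizer are hypothetical placeholders):

logits = model(input_ids).logits                                    # (bs, seq_len, V)
policy_logprobs = torch.log_softmax(logits, dim=-1).gather(
    -1, labels.unsqueeze(-1)).squeeze(-1)                           # (bs, seq_len)

loss, grad = km.bnpo_loss(policy_logprobs, old_policy_logprobs, ref_logprobs,
                          advantages, completions_mask,
                          epsilon=0.2, epsilon_high=0.2, beta=0.1)
policy_logprobs.backward(grad)   # chain the fused gradient into the model
optimizer.step()
optimizer.zero_grad()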

GRPO Loss

Group Relative Policy Optimization implements TRL's default per-response normalization variant - each response is normalized by its own valid-token count before averaging across the batch:

loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )

per_token_loss and kl are the same clipped-ratio and KL expressions as BNPO. completions_mask is required because the per-response denominator is mask-derived. The kernel uses one CTA per row so the per-row mask sum is reduced inside the block - no cross-CTA atomics on the scaling pass.
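
The eager reference differs from BNPO only in the normalization (a sketch for checking numerics, not the kernel's implementation):

def grpo_loss_eager(policy_lp, old_lp, ref_lp, advantages, mask,
                    epsilon=0.2, epsilon_high=0.2, beta=0.1):
    ratio     = torch.exp(policy_lp - old_lp)
    clipped   = torch.clamp(ratio, 1 - epsilon, 1 + epsilon_high)
    per_token = -advantages[:, None] * torch.minimum(ratio, clipped)
    delta     = ref_lp - policy_lp
    kl        = torch.exp(delta) - delta - 1.0
    m = mask.to(policy_lp.dtype)
    # Per-response denominator, then mean across the batch.
    per_response = ((per_token + beta * kl) * m).sum(-1) / m.sum(-1).clamp(min=1)
    return per_response.mean()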

Inputs:

  • policy_logprobs, old_policy_logprobs, ref_logprobs: (bs, seq_len), fp32/fp16/bf16
  • advantages: (bs,)
  • completions_mask: (bs, seq_len), bool or int8 - required

Returns: (loss, grad_policy_logprobs) from grpo_loss; scalar loss from grpo_loss_fwd.

import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs     = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs        = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages          = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask    = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.grpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.grpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.grpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)

Reverse KL

Reverse-KL self-distillation computes KL(student ‖ teacher) over a (num_tokens, vocab) slab using an online normalization algorithm that reads each logit row exactly once on the forward-only path:

p = softmax(student_logits)
q = softmax(teacher_logits)
kl_per_row = Σ_v  p_v · (log p_v − log q_v)
loss = (mask · kl_per_row).sum() / mask.sum()

The gradient through the softmax Jacobian is analytical:

grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)

where scale = mask[r] · inv_n_valid.
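
An eager reference for both the loss and the gradient (a sketch that materializes full log-probabilities, unlike the single-pass online-normalization kernel):

def reverse_kl_eager(student_logits, teacher_logits, mask):
    log_p = torch.log_softmax(student_logits.float(), dim=-1)
    log_q = torch.log_softmax(teacher_logits.float(), dim=-1)
    p = log_p.exp()
    kl_per_row = (p * (log_p - log_q)).sum(-1)           # shape: leading dims
    m = mask.to(kl_per_row.dtype)
    inv_n_valid = 1.0 / m.sum()                          # unclamped, matching the kernel
    loss  = (m * kl_per_row).sum() * inv_n_valid
    scale = (m * inv_n_valid)[..., None]
    grad  = scale * p * (log_p - log_q - kl_per_row[..., None])
    return loss, grad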

Inputs:

  • student_logits, teacher_logits: (*, V) - arbitrary leading dims (typically (bs, seq_len, vocab)); both must share shape and dtype
  • completions_mask: shape matching student_logits.shape[:-1]

⚠️ Fully-masked batches: inv_n_valid = 1 / mask.sum() is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
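
A minimal upstream guard, if fully-masked batches are reachable (a sketch; note the .sum() check is itself a host sync, so hoist it out of the hot path if that matters):

if completions_mask.sum() == 0:
    loss = student_logits.new_zeros(())        # nothing to distill this batch
    grad = torch.zeros_like(student_logits)
else:
    loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)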

Returns: (loss, grad_student_logits) from reverse_kl; scalar loss from reverse_kl_fwd.

import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

# Qwen3.5-style vocab; arbitrary leading dims supported
bs, seq_len, vocab = 4, 256, 248320
student_logits  = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True)
teacher_logits  = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
student_logits.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.reverse_kl_autograd(
    student_logits.requires_grad_(), teacher_logits, completions_mask
)
loss.backward()

# 3) Forward-only - inference / KL monitoring, no gradient buffer
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)

Performance

All numbers are geometric-mean speedups, measured on an H100 SXM (SM90a). Full methodology and per-shape plots are in the release post.

kernels CLI benchmark

Timed with time.perf_counter + cuda.synchronize(), mean over 100 iterations.
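
Roughly the following timing pattern (a sketch of the methodology, not the exact CLI harness; the warmup count is an assumption):

import time
import torch

def bench(fn, iters=100, warmup=10):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters   # mean seconds per call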

Kernel            vs eager    vs torch.compile
grpo_loss_fwd      5.68×          2.45×
grpo_loss         20.79×          1.98×
bnpo_loss_fwd      5.29×          2.52×
bnpo_loss         16.81×          2.27×
reverse_kl_fwd     6.88×          2.45×
reverse_kl         7.03×          2.61×

Benchmark animations

[Animation] BNPO loss latency vs eager PyTorch
[Animation] BNPO loss latency vs torch.compile
[Animation] GRPO loss latency vs eager PyTorch
[Animation] GRPO loss latency vs torch.compile
[Animation] Reverse KL latency vs eager PyTorch
[Animation] Reverse KL latency vs torch.compile