""" """ import torch from kernels import get_kernel _flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func @torch.library.custom_op("flash::flash_attn_func", mutates_args=()) def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: outputs, lse = _flash_attn_func(q, k, v) return outputs @flash_attn_func.register_fake def _(q, k, v, **kwargs): return torch.empty_like(q).contiguous()