from typing import Optional

import torch

from ._ops import ops


def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    alibi_slopes: torch.Tensor,
    p_dropout: float,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    return_softmax: bool,
    gen: Optional[torch.Generator],
) -> torch.Tensor:
    """Invoke the native ``ops.mha_fwd`` kernel and hand back ``out``.

    Every argument is forwarded positionally, in declaration order and
    without modification, to ``ops.mha_fwd``. This wrapper does not
    validate or transform any of them.

    Args:
        q, k, v: Input tensors passed straight through to the op.
            Assumes the query/key/value layout expected by the native
            kernel (TODO confirm shapes/dtypes against the op).
        out: Output tensor handed to the op and returned unchanged by
            reference. Presumably the op writes its result into this
            tensor in place (verify against the op implementation).
        alibi_slopes: Forwarded as-is to the op.
        p_dropout: Forwarded as-is to the op.
        softmax_scale: Forwarded as-is to the op.
        is_causal: Forwarded as-is to the op.
        window_size_left: Forwarded as-is to the op.
        window_size_right: Forwarded as-is to the op.
        softcap: Forwarded as-is to the op.
        return_softmax: Forwarded as-is to the op.
        gen: Optional ``torch.Generator`` forwarded as-is to the op.

    Returns:
        The same ``out`` object that was passed in.
    """
    # NOTE(review): whatever ops.mha_fwd returns is intentionally dropped;
    # confirm callers do not need any auxiliary outputs from the op.
    ops.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
    return out