# Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/_library/triton.py
# The PyTorch implementation simply ignores the schema argument; we modify it
# to pass the user-provided schema through to custom_op.
from typing import Callable, Iterable, Optional, Union

from torch._library.triton import set_wrap_triton_enabled
from torch.library import CustomOpDef, custom_op


def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
    # If allow_decomposition=True, this matches torch.library.triton_op
    # behavior. If set to False, it behaves like torch.library.custom_op
    # instead, which does not decompose the operator, so Inductor cannot
    # trace inside it.
    allow_decomposition: bool = True,
) -> Callable:
    def dec(fn: Callable[..., object]) -> CustomOpDef:
        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
            # Optimization: we're passing regular Tensors into the triton
            # kernel, so there is no need to go through HOP dispatch.
            with set_wrap_triton_enabled(False):
                return fn(*args, **kwargs)

        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            # This is the only difference from the PyTorch implementation:
            # forward the schema instead of ignoring it.
            schema=schema,
        )
        from torch._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx
        # traceable, so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        if allow_decomposition:
            # We decompose the operator when FunctionalTensorMode is active.
            # The goal is to decompose the operator in AOTDispatcher.
            # - With torch.compile, this means that the backend (usually
            #   Inductor) can see a call to the triton kernel(s) and so it can
            #   directly optimize them by inlining them into the lowering
            #   process.
            def functional_decomp(  # type: ignore[no-untyped-def]
                mode, op, types, args, kwargs
            ):
                from torch.export._trace import (
                    custom_triton_ops_decomposition_disabled,
                )

                if custom_triton_ops_decomposition_disabled():
                    return mode.__torch_dispatch__(op, types, args, kwargs)
                else:
                    with mode:
                        return fn(*args, **kwargs)

            result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)

        return result

    if fn is None:
        return dec
    else:
        return dec(fn)
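
# A minimal usage sketch, guarded so it only runs when this file is executed
# directly. Everything below is illustrative rather than part of the adapted
# code: the "mylib::vector_add" op name, the kernel, and the schema string are
# hypothetical, and the example assumes Triton and a CUDA device are available.
if __name__ == "__main__":
    import torch
    import triton
    import triton.language as tl
    from torch.library import wrap_triton

    @triton.jit
    def _vector_add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
                           BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    # Unlike torch.library.triton_op, the schema string passed here is
    # actually honored, so the registered op advertises exactly this
    # signature. The alias annotations in the schema must stay consistent
    # with mutates_args.
    @triton_op(
        "mylib::vector_add",
        mutates_args=(),
        schema="(Tensor x, Tensor y) -> Tensor",
    )
    def vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n_elements = out.numel()
        grid = (triton.cdiv(n_elements, 1024),)
        # wrap_triton is a no-op in the eager backend_fn path (it is disabled
        # via set_wrap_triton_enabled) but keeps fn traceable for
        # register_fake and the FunctionalTensorMode decomposition.
        wrap_triton(_vector_add_kernel)[grid](x, y, out, n_elements,
                                              BLOCK_SIZE=1024)
        return out

    if torch.cuda.is_available():
        a = torch.randn(4096, device="cuda")
        b = torch.randn(4096, device="cuda")
        torch.testing.assert_close(vector_add(a, b), a + b)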