[general]
name = "batch_invariant"
# Not a universal (pure-Python) kernel: this package has compiled C++/CUDA
# sources, declared in the tables below.
universal = false

# C++ sources (and their header) that bind the kernel to PyTorch.
[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]

# CUDA kernel for the batch-invariant matmul.
[kernel.batch_invariant_matmul]
backend = "cuda"
# NOTE(review): presumably makes the [torch] headers available when compiling
# these CUDA sources — confirm against the kernel-builder docs.
depends = ["torch"]
src = [
  "csrc/batch_invariant.cu",
]