File size: 334 Bytes
e6010fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
[general]
# Package name used by the kernel build.
name = "batch_invariant"
# NOTE(review): `universal = false` presumably marks this package as
# shipping compiled (backend-specific) kernels rather than a universal
# one — confirm against the kernel-builder build.toml reference.
universal = false

# C++ sources and headers that bind the kernels to PyTorch.
# Keep one entry per line with a trailing comma so new files can be
# appended without touching the previous line.
[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]

# CUDA kernel `batch_invariant_matmul` and the source files it is built from.
[kernel.batch_invariant_matmul]
backend = "cuda"
depends = ["torch"]
src = [
  "csrc/batch_invariant.cu",
]