[general]
# Identifier for this kernel project.
name = "batch_invariant"
# NOTE(review): presumably `false` marks this as a compiled, platform-specific
# kernel rather than a pure-Python one, which fits the CUDA sources declared
# further down — confirm against the kernel-builder documentation.
universal = false

# C++ sources that bind the kernels to PyTorch: the binding
# implementation and its header.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]

# CUDA kernel target, compiled from a single source file; `depends`
# lists `torch` so it builds against the PyTorch toolchain.
[kernel.batch_invariant_matmul]
backend = "cuda"
depends = ["torch"]
src = ["csrc/batch_invariant.cu"]