# Build manifest for the "megablocks" kernel package (TOML syntax).

[general]
name = "megablocks"
universal = false

# Torch binding sources.
[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h"
]

# The "megablocks" kernel: built for the ROCm backend and depends on torch.
[kernel.megablocks]
backend = "rocm"
# GPU architectures the kernel is built for.
rocm-archs = [
  "gfx942",
  "gfx1030",
  "gfx1100",
  "gfx1101",
]
depends = ["torch"]
src = [
  "csrc/new_cumsum.h",
  "csrc/new_cumsum.cu",
  "csrc/new_histogram.h",
  "csrc/new_histogram.cu",
  "csrc/new_indices.h",
  "csrc/new_indices.cu",
  "csrc/new_replicate.cu",
  "csrc/new_replicate.h",
  "csrc/new_sort.h",
  "csrc/new_sort.cu",
  # vendored grouped gemm
  # NOTE(review): these sources are disabled; the reason is not recorded in
  # this file - confirm whether they are intentionally excluded for this backend.
  #"csrc/grouped_gemm/fill_arguments.cuh",
  #"csrc/grouped_gemm/grouped_gemm.cu",
  #"csrc/grouped_gemm/grouped_gemm.h",
]