Prepare for Torch 2.8
Browse files
- build.toml +0 -4
- flake.lock +13 -12
- flake.nix +2 -16
build.toml
CHANGED
@@ -20,7 +20,6 @@ cuda-flags = [
|
|
20 |
"--ftemplate-backtrace-limit=0", # To debug template code
|
21 |
"--use_fast_math",
|
22 |
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
23 |
-
"-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
|
24 |
"-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
25 |
"--expt-relaxed-constexpr",
|
26 |
"--expt-extended-lambda",
|
@@ -53,7 +52,6 @@ cuda-flags = [
|
|
53 |
"--ftemplate-backtrace-limit=0", # To debug template code
|
54 |
"--use_fast_math",
|
55 |
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
56 |
-
"-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
|
57 |
"-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
58 |
"--expt-relaxed-constexpr",
|
59 |
"--expt-extended-lambda",
|
@@ -202,7 +200,6 @@ cuda-flags = [
|
|
202 |
"--ftemplate-backtrace-limit=0", # To debug template code
|
203 |
"--use_fast_math",
|
204 |
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
205 |
-
"-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
|
206 |
"-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
207 |
"--expt-relaxed-constexpr",
|
208 |
"--expt-extended-lambda",
|
@@ -551,7 +548,6 @@ depends = ["torch", "cutlass_3_9"]
|
|
551 |
# "--ftemplate-backtrace-limit=0", # To debug template code
|
552 |
# "--use_fast_math",
|
553 |
# "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
554 |
-
# "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
|
555 |
# "-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
556 |
# "--expt-relaxed-constexpr",
|
557 |
# "--expt-extended-lambda",
|
|
|
20 |
"--ftemplate-backtrace-limit=0", # To debug template code
|
21 |
"--use_fast_math",
|
22 |
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
|
|
23 |
"-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
24 |
"--expt-relaxed-constexpr",
|
25 |
"--expt-extended-lambda",
|
|
|
52 |
"--ftemplate-backtrace-limit=0", # To debug template code
|
53 |
"--use_fast_math",
|
54 |
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
|
|
55 |
"-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
56 |
"--expt-relaxed-constexpr",
|
57 |
"--expt-extended-lambda",
|
|
|
200 |
"--ftemplate-backtrace-limit=0", # To debug template code
|
201 |
"--use_fast_math",
|
202 |
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
|
|
203 |
"-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
204 |
"--expt-relaxed-constexpr",
|
205 |
"--expt-extended-lambda",
|
|
|
548 |
# "--ftemplate-backtrace-limit=0", # To debug template code
|
549 |
# "--use_fast_math",
|
550 |
# "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
|
|
551 |
# "-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
552 |
# "--expt-relaxed-constexpr",
|
553 |
# "--expt-extended-lambda",
|
flake.lock
CHANGED
@@ -73,11 +73,11 @@
|
|
73 |
"nixpkgs": "nixpkgs"
|
74 |
},
|
75 |
"locked": {
|
76 |
-
"lastModified":
|
77 |
-
"narHash": "sha256-
|
78 |
"owner": "huggingface",
|
79 |
"repo": "hf-nix",
|
80 |
-
"rev": "
|
81 |
"type": "github"
|
82 |
},
|
83 |
"original": {
|
@@ -98,32 +98,33 @@
|
|
98 |
]
|
99 |
},
|
100 |
"locked": {
|
101 |
-
"lastModified":
|
102 |
-
"narHash": "sha256-
|
103 |
"owner": "huggingface",
|
104 |
"repo": "kernel-builder",
|
105 |
-
"rev": "
|
106 |
"type": "github"
|
107 |
},
|
108 |
"original": {
|
109 |
"owner": "huggingface",
|
|
|
110 |
"repo": "kernel-builder",
|
111 |
"type": "github"
|
112 |
}
|
113 |
},
|
114 |
"nixpkgs": {
|
115 |
"locked": {
|
116 |
-
"lastModified":
|
117 |
-
"narHash": "sha256-
|
118 |
-
"owner": "
|
119 |
"repo": "nixpkgs",
|
120 |
-
"rev": "
|
121 |
"type": "github"
|
122 |
},
|
123 |
"original": {
|
124 |
-
"owner": "
|
125 |
-
"ref": "cudatoolkit-12.9-kernel-builder",
|
126 |
"repo": "nixpkgs",
|
|
|
127 |
"type": "github"
|
128 |
}
|
129 |
},
|
|
|
73 |
"nixpkgs": "nixpkgs"
|
74 |
},
|
75 |
"locked": {
|
76 |
+
"lastModified": 1753354560,
|
77 |
+
"narHash": "sha256-vmOfRmr0Qm/IbZTWB2sBn+UFrABSTTA/cTg+m27Yt/E=",
|
78 |
"owner": "huggingface",
|
79 |
"repo": "hf-nix",
|
80 |
+
"rev": "7f2aceda2a2e72cd573bdb25e5c0667fd75f89d3",
|
81 |
"type": "github"
|
82 |
},
|
83 |
"original": {
|
|
|
98 |
]
|
99 |
},
|
100 |
"locked": {
|
101 |
+
"lastModified": 1753354632,
|
102 |
+
"narHash": "sha256-31SX3Raiyx0qCuY9JSlx9ZZgxljeUxvW+JdujjxbofQ=",
|
103 |
"owner": "huggingface",
|
104 |
"repo": "kernel-builder",
|
105 |
+
"rev": "524b628fd8e58525dbd28455bffb0628092c5265",
|
106 |
"type": "github"
|
107 |
},
|
108 |
"original": {
|
109 |
"owner": "huggingface",
|
110 |
+
"ref": "torch-2.8",
|
111 |
"repo": "kernel-builder",
|
112 |
"type": "github"
|
113 |
}
|
114 |
},
|
115 |
"nixpkgs": {
|
116 |
"locked": {
|
117 |
+
"lastModified": 1752785354,
|
118 |
+
"narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
|
119 |
+
"owner": "nixos",
|
120 |
"repo": "nixpkgs",
|
121 |
+
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
122 |
"type": "github"
|
123 |
},
|
124 |
"original": {
|
125 |
+
"owner": "nixos",
|
|
|
126 |
"repo": "nixpkgs",
|
127 |
+
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
128 |
"type": "github"
|
129 |
}
|
130 |
},
|
flake.nix
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
description = "Flake for Hopper Flash Attention kernel";
|
3 |
|
4 |
inputs = {
|
5 |
-
kernel-builder.url = "github:huggingface/kernel-builder";
|
6 |
};
|
7 |
|
8 |
outputs =
|
@@ -21,21 +21,7 @@
|
|
21 |
# by hand (which works fine thanks to backward compat).
|
22 |
torchVersions = [
|
23 |
{
|
24 |
-
torchVersion = "2.
|
25 |
-
cudaVersion = "12.4";
|
26 |
-
cxx11Abi = false;
|
27 |
-
systems = [ "x86_64-linux" ];
|
28 |
-
upstreamVariant = true;
|
29 |
-
}
|
30 |
-
{
|
31 |
-
torchVersion = "2.6";
|
32 |
-
cudaVersion = "12.4";
|
33 |
-
cxx11Abi = true;
|
34 |
-
systems = [ "x86_64-linux" ];
|
35 |
-
upstreamVariant = true;
|
36 |
-
}
|
37 |
-
{
|
38 |
-
torchVersion = "2.7";
|
39 |
cudaVersion = "12.4";
|
40 |
cxx11Abi = true;
|
41 |
systems = [
|
|
|
2 |
description = "Flake for Hopper Flash Attention kernel";
|
3 |
|
4 |
inputs = {
|
5 |
+
kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8";
|
6 |
};
|
7 |
|
8 |
outputs =
|
|
|
21 |
# by hand (which works fine thanks to backward compat).
|
22 |
torchVersions = [
|
23 |
{
|
24 |
+
torchVersion = "2.8";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
cudaVersion = "12.4";
|
26 |
cxx11Abi = true;
|
27 |
systems = [
|