kernel
danieldk HF Staff committed on
Commit
bc10fdc
·
1 Parent(s): 48fe103

Prepare for Torch 2.8

Browse files
Files changed (3) hide show
  1. build.toml +0 -4
  2. flake.lock +13 -12
  3. flake.nix +2 -16
build.toml CHANGED
@@ -20,7 +20,6 @@ cuda-flags = [
20
  "--ftemplate-backtrace-limit=0", # To debug template code
21
  "--use_fast_math",
22
  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
23
- "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
24
  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
25
  "--expt-relaxed-constexpr",
26
  "--expt-extended-lambda",
@@ -53,7 +52,6 @@ cuda-flags = [
53
  "--ftemplate-backtrace-limit=0", # To debug template code
54
  "--use_fast_math",
55
  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
56
- "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
57
  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
58
  "--expt-relaxed-constexpr",
59
  "--expt-extended-lambda",
@@ -202,7 +200,6 @@ cuda-flags = [
202
  "--ftemplate-backtrace-limit=0", # To debug template code
203
  "--use_fast_math",
204
  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
205
- "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
206
  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
207
  "--expt-relaxed-constexpr",
208
  "--expt-extended-lambda",
@@ -551,7 +548,6 @@ depends = ["torch", "cutlass_3_9"]
551
  # "--ftemplate-backtrace-limit=0", # To debug template code
552
  # "--use_fast_math",
553
  # "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
554
- # "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
555
  # "-DCUTLASS_ENABLE_GDC_FOR_SM90",
556
  # "--expt-relaxed-constexpr",
557
  # "--expt-extended-lambda",
 
20
  "--ftemplate-backtrace-limit=0", # To debug template code
21
  "--use_fast_math",
22
  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
 
23
  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
24
  "--expt-relaxed-constexpr",
25
  "--expt-extended-lambda",
 
52
  "--ftemplate-backtrace-limit=0", # To debug template code
53
  "--use_fast_math",
54
  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
 
55
  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
56
  "--expt-relaxed-constexpr",
57
  "--expt-extended-lambda",
 
200
  "--ftemplate-backtrace-limit=0", # To debug template code
201
  "--use_fast_math",
202
  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
 
203
  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
204
  "--expt-relaxed-constexpr",
205
  "--expt-extended-lambda",
 
548
  # "--ftemplate-backtrace-limit=0", # To debug template code
549
  # "--use_fast_math",
550
  # "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
 
551
  # "-DCUTLASS_ENABLE_GDC_FOR_SM90",
552
  # "--expt-relaxed-constexpr",
553
  # "--expt-extended-lambda",
flake.lock CHANGED
@@ -73,11 +73,11 @@
73
  "nixpkgs": "nixpkgs"
74
  },
75
  "locked": {
76
- "lastModified": 1750234878,
77
- "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
78
  "owner": "huggingface",
79
  "repo": "hf-nix",
80
- "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
81
  "type": "github"
82
  },
83
  "original": {
@@ -98,32 +98,33 @@
98
  ]
99
  },
100
  "locked": {
101
- "lastModified": 1751014803,
102
- "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
103
  "owner": "huggingface",
104
  "repo": "kernel-builder",
105
- "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
106
  "type": "github"
107
  },
108
  "original": {
109
  "owner": "huggingface",
 
110
  "repo": "kernel-builder",
111
  "type": "github"
112
  }
113
  },
114
  "nixpkgs": {
115
  "locked": {
116
- "lastModified": 1747820358,
117
- "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
118
- "owner": "danieldk",
119
  "repo": "nixpkgs",
120
- "rev": "d3c1681180717528068082103bf323147de6ab0b",
121
  "type": "github"
122
  },
123
  "original": {
124
- "owner": "danieldk",
125
- "ref": "cudatoolkit-12.9-kernel-builder",
126
  "repo": "nixpkgs",
 
127
  "type": "github"
128
  }
129
  },
 
73
  "nixpkgs": "nixpkgs"
74
  },
75
  "locked": {
76
+ "lastModified": 1753354560,
77
+ "narHash": "sha256-vmOfRmr0Qm/IbZTWB2sBn+UFrABSTTA/cTg+m27Yt/E=",
78
  "owner": "huggingface",
79
  "repo": "hf-nix",
80
+ "rev": "7f2aceda2a2e72cd573bdb25e5c0667fd75f89d3",
81
  "type": "github"
82
  },
83
  "original": {
 
98
  ]
99
  },
100
  "locked": {
101
+ "lastModified": 1753354632,
102
+ "narHash": "sha256-31SX3Raiyx0qCuY9JSlx9ZZgxljeUxvW+JdujjxbofQ=",
103
  "owner": "huggingface",
104
  "repo": "kernel-builder",
105
+ "rev": "524b628fd8e58525dbd28455bffb0628092c5265",
106
  "type": "github"
107
  },
108
  "original": {
109
  "owner": "huggingface",
110
+ "ref": "torch-2.8",
111
  "repo": "kernel-builder",
112
  "type": "github"
113
  }
114
  },
115
  "nixpkgs": {
116
  "locked": {
117
+ "lastModified": 1752785354,
118
+ "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
119
+ "owner": "nixos",
120
  "repo": "nixpkgs",
121
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
122
  "type": "github"
123
  },
124
  "original": {
125
+ "owner": "nixos",
 
126
  "repo": "nixpkgs",
127
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
128
  "type": "github"
129
  }
130
  },
flake.nix CHANGED
@@ -2,7 +2,7 @@
2
  description = "Flake for Hopper Flash Attention kernel";
3
 
4
  inputs = {
5
- kernel-builder.url = "github:huggingface/kernel-builder";
6
  };
7
 
8
  outputs =
@@ -21,21 +21,7 @@
21
  # by hand (which works fine thanks to backward compat).
22
  torchVersions = [
23
  {
24
- torchVersion = "2.6";
25
- cudaVersion = "12.4";
26
- cxx11Abi = false;
27
- systems = [ "x86_64-linux" ];
28
- upstreamVariant = true;
29
- }
30
- {
31
- torchVersion = "2.6";
32
- cudaVersion = "12.4";
33
- cxx11Abi = true;
34
- systems = [ "x86_64-linux" ];
35
- upstreamVariant = true;
36
- }
37
- {
38
- torchVersion = "2.7";
39
  cudaVersion = "12.4";
40
  cxx11Abi = true;
41
  systems = [
 
2
  description = "Flake for Hopper Flash Attention kernel";
3
 
4
  inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8";
6
  };
7
 
8
  outputs =
 
21
  # by hand (which works fine thanks to backward compat).
22
  torchVersions = [
23
  {
24
+ torchVersion = "2.8";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  cudaVersion = "12.4";
26
  cxx11Abi = true;
27
  systems = [