danieldk (HF Staff) committed
Commit be5bedb · 1 Parent(s): 6c9920d

Sync with upstream

activation/activation_kernels.cu CHANGED
@@ -9,8 +9,16 @@
 
 namespace vllm {
 
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
+__device__ __forceinline__ scalar_t compute(const scalar_t& x,
+                                            const scalar_t& y) {
+  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+}
 // Activation and gating kernel template.
-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
@@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = ACT_FN(x) * y;
+    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
   }
 }
 
@@ -55,16 +63,21 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 }  // namespace vllm
 
 // Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                            \
+// Use ACT_FIRST (bool) indicating whether to apply the activation function
+// first.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                 \
   int d = input.size(-1) / 2;                                            \
   int64_t num_tokens = input.numel() / input.size(-1);                   \
   dim3 grid(num_tokens);                                                 \
   dim3 block(std::min(d, 1024));                                         \
+  if (num_tokens == 0) {                                                 \
+    return;                                                              \
+  }                                                                      \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
   VLLM_DISPATCH_FLOATING_TYPES(                                          \
       input.scalar_type(), "act_and_mul_kernel", [&] {                   \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>  \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                         input.data_ptr<scalar_t>(), d);  \
      });
@@ -72,19 +85,27 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
+}
+
+void mul_and_silu(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
+{
+  // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
+  // applies the silu to the latter half of the input.
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
 }
 
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
 }
 
 void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
 }
 
 namespace vllm {
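
For reference, the new act_first template parameter only selects which half of the input the activation is applied to before gating. A minimal PyTorch sketch of the two launched variants (mirroring the reference functions used in tests/kernels/test_activation.py; the *_ref names are illustrative only):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # act_first == true: silu on the first half, multiply by the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def mul_and_silu_ref(x: torch.Tensor) -> torch.Tensor:
    # act_first == false: first half unchanged, silu on the second half.
    d = x.shape[-1] // 2
    return x[..., :d] * F.silu(x[..., d:])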
activation/cuda_compat.h CHANGED
@@ -4,10 +4,10 @@
   #include <hip/hip_runtime.h>
 #endif
 
-#ifndef USE_ROCM
-  #define WARP_SIZE 32
+#if defined(USE_ROCM) && defined(__GFX9__)
+  #define WARP_SIZE 64
 #else
-  #define WARP_SIZE warpSize
+  #define WARP_SIZE 32
 #endif
 
 #ifndef USE_ROCM
activation/dispatch_utils.h CHANGED
@@ -6,6 +6,11 @@
 
 #include <torch/all.h>
 
+// Need a special dispatch case macro since we will nest the FP8 dispatch.
+// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
+#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
+  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
+
 #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
@@ -14,6 +19,35 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
+// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
+// A host-based check at runtime will create a preferred FP8 type for ROCm
+// such that the correct kernel is dispatched.
+#ifdef USE_ROCM
+  #define VLLM_DISPATCH_CASE_FP8_TYPES(...)                           \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)  \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
+
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                      \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)   \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#else
+  #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                    \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#endif
+
+// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
+// See AT_DISPATCH_FP8_CASE above.
+#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
+
 #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
@@ -31,5 +65,19 @@
   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)  \
   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
 
+#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
+
 #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                              \
+      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
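
The comment above describes a host-side check that picks a preferred FP8 dtype on ROCm (fn vs. fnuz) so that the matching case of VLLM_DISPATCH_FP8_TYPES is hit. A purely illustrative Python sketch of that idea; the actual architecture-to-format mapping lives elsewhere in vLLM and is not reproduced here:

import torch

def preferred_fp8_dtype(prefer_fnuz_on_rocm: bool = True) -> torch.dtype:
    # On ROCm builds, kernels may be dispatched for either e4m3fn or e4m3fnuz;
    # the fnuz variant has no infinities and a single NaN encoding.
    if torch.version.hip is not None and prefer_fnuz_on_rocm:
        return torch.float8_e4m3fnuz
    # CUDA builds only dispatch on the OCP e4m3fn format.
    return torch.float8_e4m3fn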
build/torch26-cxx98-cu124-x86_64-linux/activation/layers.py CHANGED
@@ -5,6 +5,15 @@ from ._ops import ops
 
 
 class SiluAndMul(nn.Module):
+    """An activation function for SwiGLU.
+
+    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
     can_torch_compile: bool = True
 
     def forward(self, x: torch.Tensor):
@@ -14,8 +23,35 @@ class SiluAndMul(nn.Module):
         ops.silu_and_mul(out, x)
         return out
 
+class MulAndSilu(CustomOp):
+    """An activation function for SwiGLU.
+
+    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        self.mul_and_silu(out, x)
+        return out
 
 class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
     can_torch_compile: bool = True
 
     def forward(self, x: torch.Tensor):
@@ -38,6 +74,17 @@ class GeluTanhAndMul(nn.Module):
 
 
 class FatreluAndMul(nn.Module):
+    """An activation function for FATReLU.
+
+    The function computes x -> FATReLU(x[:d]) * x[d:] where
+    d = x.shape[-1] // 2.
+    This is used in openbmb/MiniCPM-S-1B-sft.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
     can_torch_compile: bool = True
 
     def __init__(self, threshold: float = 0.0):
tests/kernels/test_activation.py CHANGED
@@ -1,3 +1,6 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
 import math
 import random
 from typing import Type
@@ -43,12 +46,19 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     return F.silu(x[..., :d]) * x[..., d:]
 
 
+def mul_and_silu(x: torch.Tensor) -> torch.Tensor:
+    d = x.shape[-1] // 2
+    return x[..., :d] * F.silu(x[..., d:])
+
+
 def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
     d = x.shape[-1] // 2
     return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
 
 
-@pytest.mark.parametrize("activation_name", ["silu", "gelu", "gelu_tanh", "fatrelu"])
+@pytest.mark.parametrize(
+    "activation_name", ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"]
+)
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -67,11 +77,16 @@ def test_act_and_mul(
     torch.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    if activation_name == "silu":
+    if activation_name == "silu_and_mul":
         torch_fn = silu_and_mul
         fn = activation.silu_and_mul
        op = activation.ops.silu_and_mul
         layer = activation.layers.SiluAndMul()
+    elif activation_name == "mul_and_silu":
+        torch_fn = mul_and_silu
+        fn = activation.mul_and_silu
+        op = activation.ops.mul_and_silu
+        layer = activation.layers.MulAndSilu()
     elif activation_name == "gelu":
         torch_fn = lambda x: gelu_and_mul(x, "none")
         fn = activation.gelu_and_mul
torch-ext/activation/__init__.py CHANGED
@@ -10,6 +10,11 @@ def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
     return out
 
 
+def mul_and_silu(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.mul_and_silu(out, x)
+    return out
+
+
 def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
     ops.gelu_and_mul(out, x)
     return out
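
A minimal call sketch for the new functional wrapper, assuming the built package is importable as activation (the import name used in tests/kernels/test_activation.py); the wrapper writes into a caller-allocated output of shape x.shape[:-1] + (d,):

import torch
import activation  # assumed import name, as in the tests

x = torch.randn(8, 2 * 256, dtype=torch.bfloat16, device="cuda")
d = x.shape[-1] // 2
out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
activation.mul_and_silu(out, x)  # computes x[..., :d] * silu(x[..., d:]) into out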
torch-ext/activation/layers.py CHANGED
@@ -5,6 +5,15 @@ from ._ops import ops
 
 
 class SiluAndMul(nn.Module):
+    """An activation function for SwiGLU.
+
+    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
     can_torch_compile: bool = True
 
     def forward(self, x: torch.Tensor):
@@ -15,7 +24,36 @@ class SiluAndMul(nn.Module):
         return out
 
 
+class MulAndSilu(nn.Module):
+    """An activation function for SwiGLU.
+
+    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.mul_and_silu(out, x)
+        return out
+
+
 class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
     can_torch_compile: bool = True
 
     def forward(self, x: torch.Tensor):
@@ -38,6 +76,17 @@ class GeluTanhAndMul(nn.Module):
 
 
 class FatreluAndMul(nn.Module):
+    """An activation function for FATReLU.
+
+    The function computes x -> FATReLU(x[:d]) * x[d:] where
+    d = x.shape[-1] // 2.
+    This is used in openbmb/MiniCPM-S-1B-sft.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
     can_torch_compile: bool = True
 
     def __init__(self, threshold: float = 0.0):
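
Usage sketch for the layer classes defined above, again assuming the package imports as activation (as in the tests); both layers expect the last dimension to be 2 * d and return a tensor whose last dimension is d:

import torch
from activation.layers import SiluAndMul, MulAndSilu

x = torch.randn(4, 2 * 128, dtype=torch.float16, device="cuda")
y1 = SiluAndMul()(x)  # silu(x[..., :128]) * x[..., 128:]
y2 = MulAndSilu()(x)  # x[..., :128] * silu(x[..., 128:])
assert y1.shape == y2.shape == (4, 128)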
torch-ext/torch_binding.cpp CHANGED
@@ -9,6 +9,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
 
+  ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
+  ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
+
   // Activation function used in GeGLU with `none` approximation.
   ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
torch-ext/torch_binding.h CHANGED
@@ -4,6 +4,8 @@
 
 void silu_and_mul(torch::Tensor &out, torch::Tensor &input);
 
+void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
+
 void gelu_and_mul(torch::Tensor &out, torch::Tensor &input);
 
 void gelu_tanh_and_mul(torch::Tensor &out, torch::Tensor &input);