danieldk (HF Staff) committed
Commit 9c1f92e · Parent: ba5e096

Update build

build/torch-universal/triton_layer_norm/__init__.py CHANGED
@@ -1,4 +1,4 @@
-"""Triton layer normalization kernels.
+"""Triton layer normalization kernels
 
 This kernel implements layers normalization using Triton. This kernel is from
 the `flash-attention <https://github.com/Dao-AILab/flash-attention>`_ project.
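For reference, the math that the RMS-norm path exposed by this build's `layers.py` is expected to match can be written in a few lines of plain PyTorch. This is a sketch of the reference semantics only, not the Triton implementation; the function name `rms_norm_ref` and the `eps` default are illustrative.

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the last (feature) dimension,
    # computed in float32 for numerical stability, then scale by the learned
    # weight and cast back to the input dtype.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return weight * x_normed.to(x.dtype)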
build/torch-universal/triton_layer_norm/_ops.py ADDED
@@ -0,0 +1,8 @@
+import torch
+ops = torch.ops._triton_layer_norm_4dc3a9b_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_triton_layer_norm_4dc3a9b_dirty::{op_name}"
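The new `_ops.py` pins the kernel's custom ops to a build-specific `torch.ops` namespace (the `_4dc3a9b_dirty` suffix encodes the commit the build came from), so multiple builds can coexist in one process. A minimal usage sketch, assuming the package is importable; the op name `layer_norm_fwd` is a hypothetical placeholder, not a confirmed op in this build:

from triton_layer_norm._ops import ops, add_op_namespace_prefix

# Fully qualified name, e.g. for registration with torch.library APIs.
# "layer_norm_fwd" is a placeholder op name used only for illustration.
qualified_name = add_op_namespace_prefix("layer_norm_fwd")
print(qualified_name)  # _triton_layer_norm_4dc3a9b_dirty::layer_norm_fwd

# Ops registered under the namespace resolve through the `ops` handle,
# e.g. ops.layer_norm_fwd(...) once the build has registered them.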
build/torch-universal/triton_layer_norm/layers.py CHANGED
@@ -5,10 +5,32 @@ from .layer_norm import rms_norm_fn
 
 
 class LlamaRMSNorm(nn.Module):
+    """
+    RMS Layer Norm for Llama models.
+
+    Triton-optimized RMS layer norm. The interface is compatible with `LLamaRMSNorm` in
+    `transformers`.
+
+    Attributes:
+        weight (`torch.Tensor`): The learnable scaling parameter.
+        variance_epsilon (`float`): The epsilon value for numerical stability.
+    """
     weight: torch.Tensor
     variance_epsilon: float
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Apply RMS normalization to the input hidden states.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                Input tensor of shape `(batch_size, sequence_length, hidden_size)` or any shape
+                where the last dimension is the feature dimension to be normalized.
+
+        Returns:
+            `torch.Tensor`:
+                The normalized tensor with the same shape as the input `hidden_states`.
+        """
         return rms_norm_fn(
             hidden_states,
             self.weight,
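A usage sketch for the documented module. The constructor is assumed to match `transformers`' `LlamaRMSNorm(hidden_size, eps)`, based on the interface-compatibility claim in the new class docstring; a CUDA device is assumed, since Triton kernels do not run on CPU.

import torch
from triton_layer_norm.layers import LlamaRMSNorm

hidden_size = 4096
# Assumed transformers-style constructor (hidden_size, eps); not shown in this diff.
norm = LlamaRMSNorm(hidden_size, eps=1e-6).to("cuda")
x = torch.randn(2, 16, hidden_size, device="cuda", dtype=torch.float16)
y = norm(x)
assert y.shape == x.shape  # normalized over the last dimension, shape preserved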