danieldk (HF Staff) committed
Commit 9b61b27 · 1 Parent(s): 9c1f92e

Generate README
Files changed (1): README.md (+73, -4)
README.md CHANGED
@@ -1,11 +1,80 @@
  ---
  license: bsd-3-clause
  tags:
- - kernel
  ---

- ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/triton-layer-norm)

- ## triton-layer-norm

- Triton layer norm [from flash-attention](https://github.com/Dao-AILab/flash-attention).

---
license: bsd-3-clause
tags:
- kernel
---

# Triton layer normalization kernels

This kernel implements layer normalization using Triton. The kernel comes from
the [flash-attention](https://github.com/Dao-AILab/flash-attention) project.
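As a quick orientation, here is a minimal sketch of loading this kernel from the Hub. It assumes the [`kernels`](https://github.com/huggingface/kernels) library and its `get_kernel` function; neither is part of this README, so treat the snippet as illustrative.

```python
# Illustrative loading sketch -- assumes the `kernels` library, which is not
# mentioned in this README. `get_kernel` fetches the kernel from the Hub and
# imports it as a Python module.
from kernels import get_kernel

layer_norm_kernel = get_kernel("kernels-community/triton-layer-norm")
```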
## Functions

### Function `layer_norm`

`(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, residual: Optional[torch.Tensor] = None, x1: Optional[torch.Tensor] = None, weight1: Optional[torch.Tensor] = None, bias1: Optional[torch.Tensor] = None, eps: float = 1e-06, dropout_p: float = 0.0, rowscale=None, prenorm: bool = False, residual_in_fp32: bool = False, is_rms_norm: bool = False, return_dropout_mask: bool = False, out: Optional[torch.Tensor] = None, residual_out: Optional[torch.Tensor] = None)`

Apply layer normalization to the input tensor with Triton acceleration.

### Parameters

- **x** (*torch.Tensor*) --
  Input tensor to normalize.
- **weight** (*torch.Tensor*) --
  Scale parameter for normalization.
- **bias** (*torch.Tensor*) --
  Shift parameter for normalization.
- **residual** (*torch.Tensor*, *optional*) --
  Optional residual tensor to add to the input before normalization.
- **x1** (*torch.Tensor*, *optional*) --
  Optional second input tensor to combine with *x*. When provided, the function
  first adds *x1* to *x* and then applies normalization.
- **weight1** (*torch.Tensor*, *optional*) --
  Scale parameter for the second normalization.
- **bias1** (*torch.Tensor*, *optional*) --
  Shift parameter for the second normalization.
- **eps** (*float*, *optional*, defaults to 1e-6) --
  Small constant added for numerical stability in normalization.
- **dropout_p** (*float*, *optional*, defaults to 0.0) --
  Dropout probability. If greater than 0, applies dropout to the input before
  normalization and residual addition.
- **rowscale** (*torch.Tensor*, *optional*) --
  Optional scaling factor applied to each row of the input tensor.
  Not compatible with the use of *x1*.
- **prenorm** (*bool*, *optional*, defaults to False) --
  If True, returns both the normalized output and the unnormalized input+residual.
- **residual_in_fp32** (*bool*, *optional*, defaults to False) --
  If True, performs the residual connection in FP32 precision.
- **is_rms_norm** (*bool*, *optional*, defaults to False) --
  If True, uses RMS normalization instead of layer normalization.
- **return_dropout_mask** (*bool*, *optional*, defaults to False) --
  If True, returns the dropout mask used for the computation.
- **out** (*torch.Tensor*, *optional*) --
  Output tensor for the normalized result. If *None*, a new tensor is allocated.
- **residual_out** (*torch.Tensor*, *optional*) --
  Output tensor for the residual result when using prenorm. If *None*, a new tensor
  is allocated when needed.
### Returns

**Type**: *torch.Tensor* or tuple of *torch.Tensor*

- The normalized input.
- The second normalization of the input if *weight1* is provided.
- The residual tensor if *prenorm* is set.
- The dropout mask if *return_dropout_mask* is set.
- The dropout mask for *x1* if *x1* is provided and *return_dropout_mask* is set.
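To make the signature and the return ordering concrete, a short usage sketch. It assumes the module was loaded as `layer_norm_kernel` in the earlier snippet; the tensor shapes and dtypes are illustrative, and CUDA tensors are used since the kernel is implemented in Triton.

```python
import torch

# Illustrative inputs; shapes and dtypes are assumptions, not requirements
# documented in this README.
hidden = 64
x = torch.randn(8, hidden, device="cuda", dtype=torch.float16)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)

# Plain layer normalization: returns just the normalized tensor.
out = layer_norm_kernel.layer_norm(x, weight, bias, eps=1e-6)

# With `prenorm=True` and a residual, a tuple is returned: the normalized
# output and the unnormalized x + residual (in FP32 via residual_in_fp32).
out, res = layer_norm_kernel.layer_norm(
    x, weight, bias, residual=residual, prenorm=True, residual_in_fp32=True
)

# RMS normalization instead of layer normalization.
rms_out = layer_norm_kernel.layer_norm(x, weight, bias, is_rms_norm=True)
```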
## Layers

### Class `LlamaRMSNorm`

No documentation available.

#### Methods

##### Method `forward`

`(self, hidden_states: torch.Tensor) -> torch.Tensor`

No documentation available.
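Since `LlamaRMSNorm` ships without documentation, the following sketch is speculative: it assumes the class is exposed under the kernel module's `layers` attribute and that its constructor follows the usual `LlamaRMSNorm(hidden_size, eps)` convention of the transformers implementation of the same name. Verify against the kernel source before relying on either assumption.

```python
import torch

# Speculative sketch -- the `layers` attribute and the constructor arguments
# are assumptions; only `forward(self, hidden_states)` is documented above.
norm = layer_norm_kernel.layers.LlamaRMSNorm(64, eps=1e-6).to("cuda")

hidden_states = torch.randn(8, 64, device="cuda", dtype=torch.float16)
normed = norm(hidden_states)  # RMS-normalized, same shape as the input
```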