File size: 580 Bytes
56238f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn as nn

class _SwiGLU(nn.Module):
    """SwiGLU feed-forward layer implemented with plain PyTorch ops.

    The input is projected once to ``2 * hidden_dim`` features, the result is
    split in half along the last dimension, the first half is passed through
    SiLU and multiplied element-wise with the second half, and the product is
    projected back to ``dim`` features. Neither projection has a bias.

    Args:
        dim: Size of the last dimension of the input and of the output.
        hidden_dim: Size of each of the two gated halves.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # One fused matmul yields both the gate half and the value half.
        self.w12 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        # Output projection back to the model dimension.
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``w3(silu(a) * b)`` where ``a, b`` are the halves of ``w12(x)``."""
        fused = self.w12(x)
        gate, value = torch.chunk(fused, 2, dim=-1)
        gated = nn.functional.silu(gate) * value
        return self.w3(gated)


# NOTE: an optional import of xformers' SwiGLU (with a fallback to this "slow"
# implementation) appears to have been attempted here and is currently
# disabled, so the pure-PyTorch module is always the one exported.
SwiGLU = _SwiGLU