# Extracted from a Hugging Face Space file view (Space status: "Running on Zero").
# File size: 580 bytes; revision 56238f0.
import torch
import torch.nn as nn
class _SwiGLU(nn.Module):
    """SwiGLU gated feed-forward block implemented in plain PyTorch.

    The input is projected once by ``w12`` to ``2 * hidden_dim`` features,
    which are split along the last dimension into two halves ``x1`` and
    ``x2``. The result is ``w3(silu(x1) * x2)``. Both linear layers are
    created without a bias term.

    Args:
        dim: Number of features in the last dimension of the input; the
            output has the same number of features.
        hidden_dim: Number of features in each of the two halves produced
            by the fused ``w12`` projection.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ) -> None:
        super().__init__()
        # A single fused projection yields both halves; it is split in forward().
        self.w12 = nn.Linear(dim, hidden_dim * 2, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform over the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has ``dim`` features.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x1, x2 = self.w12(x).chunk(2, dim=-1)
        return self.w3(nn.functional.silu(x1) * x2)


# NOTE: an optional fast path that tried to import ``SwiGLU`` from
# ``xformers.ops`` (falling back to this class on failure) used to live here
# but is disabled; the pure-PyTorch implementation is always used.
# NOTE(review): before re-enabling it, confirm that the xformers constructor
# signature and parameter names are compatible with ``_SwiGLU``.
SwiGLU = _SwiGLU