import torch.nn as nn
import torch.nn.functional as F


def get_activation_layer(act_type):
    """Return a zero-argument factory that builds the requested activation module."""
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        # GELU with the tanh approximation.
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")


class SwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        """
        Initialize the SwiGLU feed-forward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feed-forward layer.
            out_dim (int): Output dimension.

        Attributes:
            w1: Linear projection from dim to hidden_dim (gate branch).
            w2: Linear projection from hidden_dim to out_dim.
            w3: Linear projection from dim to hidden_dim (value branch).
        """
        super().__init__()

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU: SiLU(x @ W1) gates (x @ W3), then W2 projects back to out_dim.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
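

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module);
    # the dimensions below are arbitrary and chosen only for demonstration.
    import torch

    ff = SwiGLU(dim=64, hidden_dim=256, out_dim=64)
    x = torch.randn(2, 16, 64)            # (batch, sequence, dim)
    y = ff(x)
    assert y.shape == (2, 16, 64)

    act = get_activation_layer("silu")()  # calling the factory builds nn.SiLU()
    assert act(x).shape == x.shape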