from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
class Modulation(nn.Module):
    r"""
    Adaptive modulation layer: predicts a per-channel scale and shift from a conditioning vector
    and applies them to the input sequence.
    """

    def __init__(
        self,
        embedding_dim: int,
        condition_dim: int,
        zero_init: bool = False,
        single_layer: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        if single_layer:
            self.linear1 = nn.Identity()
        else:
            self.linear1 = nn.Linear(condition_dim, condition_dim)

        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        # Zero-initializing the output projection makes the modulation an identity mapping at the
        # start of training.
        if zero_init:
            nn.init.zeros_(self.linear2.weight)
            nn.init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        emb = self.linear2(self.silu(self.linear1(condition)))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x
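# Usage sketch (illustrative only; the sizes below are arbitrary assumptions, not part of the API):
#
#     mod = Modulation(embedding_dim=512, condition_dim=256, zero_init=True)
#     tokens = torch.randn(2, 77, 512)   # (batch, sequence, embedding_dim)
#     cond = torch.randn(2, 256)         # (batch, condition_dim)
#     out = mod(tokens, cond)            # same shape as `tokens`; identity at init when zero_init=True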
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout after the output projection.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(f"Unknown activation function: {activation_fn}")

        self.net = nn.ModuleList([])
        self.net.append(act_fn)
        self.net.append(nn.Dropout(dropout))
        self.net.append(nn.Linear(inner_dim, dim_out))
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
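# Usage sketch (illustrative only; the sizes below are arbitrary assumptions):
#
#     ff = FeedForward(dim=512, activation_fn="geglu")
#     x = torch.randn(2, 77, 512)
#     y = ff(x)                          # (2, 77, 512); hidden width is dim * mult = 2048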
class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        out_bias: bool = True,
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.num_heads = heads
        self.scale = dim_head**-0.5
        self.dropout = dropout

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        self.to_out = nn.ModuleList(
            [
                nn.Linear(self.inner_dim, query_dim, bias=out_bias),
                nn.Dropout(dropout),
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = hidden_states.shape

        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        # Split the projections into heads: (batch, heads, sequence, head_dim).
        query = query.reshape(
            batch_size, sequence_length, self.num_heads, -1
        ).transpose(1, 2)
        key = key.reshape(
            batch_size, sequence_length, self.num_heads, -1
        ).transpose(1, 2)
        value = value.reshape(
            batch_size, sequence_length, self.num_heads, -1
        ).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            scale=self.scale,
        )

        # Merge the heads back into a single hidden dimension.
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, sequence_length, self.inner_dim
        )

        # Output projection followed by dropout.
        for module in self.to_out:
            hidden_states = module(hidden_states)

        return hidden_states
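# Usage sketch (illustrative only; the sizes below are arbitrary assumptions):
#
#     attn = Attention(query_dim=512, heads=8, dim_head=64)
#     x = torch.randn(2, 77, 512)
#     y = attn(x)                        # (2, 77, 512), self-attention over the sequence dimension
#
# An optional `attention_mask` is passed straight to `scaled_dot_product_attention`, so it must be
# broadcastable to (batch, heads, sequence, sequence) and either boolean or an additive float mask.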
class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(
            dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
        )
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
        )

        self.norm3 = nn.LayerNorm(
            dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
        )
        self.ff = FeedForward(
            dim,
            activation_fn=activation_fn,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # Pre-norm self-attention with a residual connection.
        norm_hidden_states = self.norm1(hidden_states)
        hidden_states = (
            self.attn1(
                norm_hidden_states,
                attention_mask=attention_mask,
            )
            + hidden_states
        )

        # Pre-norm feed-forward with a residual connection.
        ff_output = self.ff(self.norm3(hidden_states))
        hidden_states = ff_output + hidden_states

        return hidden_states
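# Usage sketch (illustrative only; the sizes below are arbitrary assumptions):
#
#     block = BasicTransformerBlock(dim=512, num_attention_heads=8, attention_head_dim=64)
#     x = torch.randn(2, 77, 512)
#     y = block(x)                       # (2, 77, 512): self-attention then feed-forward, both residual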
class GELU(nn.Module):
    r"""
    GELU activation function with optional tanh approximation via `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)

        # mps: half-precision GELU is not supported on older PyTorch versions, so compute in
        # float32 and cast back.
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
            dtype=gate.dtype
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate)

        # mps: half-precision GELU is not supported on older PyTorch versions, so compute in
        # float32 and cast back.
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Split the fused projection into a value branch and a gate branch.
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)
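# Note: this computes (x @ W + b) * GELU(x @ V + c), i.e. a GELU-gated linear unit, with both
# projections fused into the single `nn.Linear(dim_in, dim_out * 2)` above.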
class ApproximateGELU(nn.Module):
    r"""
    The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of
    https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # Sigmoid approximation of GELU: x * sigmoid(1.702 * x).
        return x * torch.sigmoid(1.702 * x)
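

if __name__ == "__main__":
    # Minimal smoke test (illustrative; the sizes here are arbitrary assumptions, not canonical values).
    block = BasicTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)
    sample = torch.randn(2, 10, 64)
    output = block(sample)
    print(output.shape)  # expected: torch.Size([2, 10, 64])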