KevinNg99's picture
Initial commit.
43c5292
from typing import Callable
import torch
import torch.nn as nn
class ModulateDiT(nn.Module):
"""Modulation layer for DiT."""
def __init__(
self,
hidden_size: int,
factor: int,
act_layer: Callable,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.act = act_layer()
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
# Zero-initialize the modulation
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.act(x))
def modulate(x, shift=None, scale=None):
"""modulate by shift and scale
Args:
x (torch.Tensor): input tensor.
shift (torch.Tensor, optional): shift tensor. Defaults to None.
scale (torch.Tensor, optional): scale tensor. Defaults to None.
Returns:
torch.Tensor: the output tensor after modulate.
"""
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
torch.Tensor: the output tensor after apply gate.
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
def ckpt_wrapper(module):
def ckpt_forward(*inputs):
outputs = module(*inputs)
return outputs
return ckpt_forward
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")