Spaces: Running on Zero
import torch.nn as nn | |
def modulate(x, shift, scale):
    """Apply a conditioning-driven affine transform to ``x``.

    Computes ``x * (1 + scale) + shift``. The ``1 +`` offset means a zero
    ``scale`` leaves ``x`` unscaled.

    Args:
        x: Input values (typically the output of a normalization layer).
        shift: Additive offset.
        scale: Multiplicative offset around 1.

    Returns:
        The transformed values.

    NOTE(review): operands combine by ordinary broadcasting only; no
    dimensions are inserted here. Callers must therefore pass ``shift`` and
    ``scale`` already broadcast-compatible with ``x``.
    """
    gain = 1 + scale
    scaled = x * gain
    return scaled + shift
class FinalLayer(nn.Module):
    """Output head: adaptive LayerNorm modulation followed by a linear projection.

    The conditioning input ``c`` is projected to a concatenated (shift, scale)
    pair. That pair modulates the normalized ``x`` before it is mapped to
    ``out_channels`` features.

    Args:
        hidden_size: Last-dimension width of ``x`` and of ``c``.
        out_channels: Last-dimension width of the output.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        # Affine-free LayerNorm: the per-feature shift/scale come from ``c`` instead.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        # Emits shift and scale (each ``hidden_size`` wide) concatenated on the last dim.
        # Kept inside an nn.Sequential so parameter names stay "adaLN_modulation.0.*".
        # NOTE(review): no nonlinearity is applied to ``c`` before this projection;
        # confirm that is intended.
        modulation_proj = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        self.adaLN_modulation = nn.Sequential(modulation_proj)

    def forward(self, x, c):
        """Normalize ``x``, modulate it with ``c``, and project to ``out_channels``.

        NOTE(review): ``shift``/``scale`` keep the leading dims of ``c`` and are
        broadcast directly against ``x``. This presumably expects ``c`` shaped like
        ``x`` or with singleton token dims (e.g. (B, 1, D) for x of (B, N, D));
        TODO confirm.
        """
        modulation = self.adaLN_modulation(c)
        # Split the projection into two equal halves: first shift, then scale.
        shift, scale = modulation.chunk(2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))