import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class SineLayer(nn.Module):
    """
    Sine activation layer from the SIREN paper:
    "Implicit Neural Representations with Periodic Activation Functions".
    """

    def __init__(
        self, in_features, out_features, bias=True, is_first=False, omega_0=30
    ):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        # SIREN initialization: the first layer is drawn from U(-1/n, 1/n),
        # later layers from U(-sqrt(6/n)/omega_0, sqrt(6/n)/omega_0).
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(
                    -1 / self.in_features, 1 / self.in_features
                )
            else:
                self.linear.weight.uniform_(
                    -np.sqrt(6 / self.in_features) / self.omega_0,
                    np.sqrt(6 / self.in_features) / self.omega_0,
                )

    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))


class ToPixel(nn.Module):
    """
    Projects a sequence of patch tokens (B, L, in_dim) back to an image
    (B, in_channels, img_size, img_size) using one of several output heads:
    a linear patch projection, a transposed convolution, a two-layer SIREN,
    or the identity.
    """

    def __init__(
        self, to_pixel="linear", img_size=256, in_channels=3, in_dim=512, patch_size=16
    ) -> None:
        super().__init__()
        self.to_pixel_name = to_pixel
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.in_channels = in_channels
        if to_pixel == "linear":
            # Each token predicts one patch of in_channels * patch_size**2 pixels.
            self.model = nn.Linear(in_dim, in_channels * patch_size * patch_size)
        elif to_pixel == "conv":
            # Reshape tokens to a feature map, then upsample with a transposed conv.
            self.model = nn.Sequential(
                Rearrange("b (h w) c -> b c h w", h=img_size // patch_size),
                nn.ConvTranspose2d(
                    in_dim, in_channels, kernel_size=patch_size, stride=patch_size
                ),
            )
        elif to_pixel == "siren":
            # Two-layer SIREN head; the last layer likewise predicts one patch
            # of in_channels * patch_size**2 pixel values per token.
            self.model = nn.Sequential(
                SineLayer(in_dim, in_dim * 2, is_first=True, omega_0=30.0),
                SineLayer(
                    in_dim * 2,
                    in_channels * patch_size * patch_size,
                    is_first=False,
                    omega_0=30.0,
                ),
            )
        elif to_pixel == "identity":
            self.model = nn.Identity()
        else:
            raise NotImplementedError

    def get_last_layer(self):
        # Return the weight tensor of the final projection layer.
        if self.to_pixel_name == "linear":
            return self.model.weight
        elif self.to_pixel_name == "siren":
            return self.model[1].linear.weight
        elif self.to_pixel_name == "conv":
            return self.model[1].weight
        else:
            return None

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * in_channels)
        imgs: (N, in_channels, H, W)
        """
        c = self.in_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(self, x):
        if self.to_pixel_name == "linear":
            x = self.model(x)
            x = self.unpatchify(x)
        elif self.to_pixel_name == "siren":
            x = self.model(x)
            # Reinterpret the per-token patch predictions as a full image.
            x = x.view(
                x.shape[0],
                self.in_channels,
                self.patch_size * int(self.num_patches**0.5),
                self.patch_size * int(self.num_patches**0.5),
            )
        elif self.to_pixel_name == "conv":
            x = self.model(x)
        elif self.to_pixel_name == "identity":
            pass
        return x
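

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the default configuration (img_size=256, patch_size=16,
# in_dim=512, in_channels=3) and round-trips a batch of random tokens through
# each ToPixel head to check that the output is a (B, C, H, W) image.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, img_size, patch_size, in_dim, in_channels = 2, 256, 16, 512, 3
    num_patches = (img_size // patch_size) ** 2
    tokens = torch.randn(batch, num_patches, in_dim)

    for mode in ("linear", "conv", "siren"):
        head = ToPixel(
            to_pixel=mode,
            img_size=img_size,
            in_channels=in_channels,
            in_dim=in_dim,
            patch_size=patch_size,
        )
        out = head(tokens)
        # Each head maps (B, L, in_dim) tokens to a (B, C, H, W) image.
        print(mode, tuple(out.shape))  # expected: (2, 3, 256, 256)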