import os
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from vit_pytorch import ViT
from vit_pytorch.vit import Transformer
from einops import rearrange, repeat


class TriplaneDecoder(nn.Module):
    """Decodes a sequence of latent tokens into triplane features."""

    def __init__(self, token_num, in_dim, depth, heads, mlp_dim, out_channel, out_reso, dim_head=64, dropout=0):
        super().__init__()
        self.token_num = token_num
        self.out_reso = out_reso
        self.out_channel = out_channel

        self.input_net = nn.Linear(in_dim, mlp_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, token_num, mlp_dim))
        self.dropout = nn.Dropout(dropout)
        self.transformer = Transformer(
            mlp_dim, depth, heads, dim_head, mlp_dim, dropout
        )

        # Tokens are unpatchified onto a square grid, so token_num must be a
        # perfect square; the output channels must split evenly into 3 planes.
        assert int(token_num ** 0.5) ** 2 == token_num
        assert out_channel % 3 == 0
        self.H = int(token_num ** 0.5)
        self.out_patch_size = out_reso // self.H
        self.out_patch_dim = (self.out_patch_size ** 2) * out_channel
        self.output_net = nn.Sequential(
            nn.LayerNorm(mlp_dim),
            nn.Linear(mlp_dim, self.out_patch_dim),
            nn.LayerNorm(self.out_patch_dim),
            nn.Linear(self.out_patch_dim, self.out_patch_dim),
        )

    def forward(self, x):
        b, n, _ = x.shape
        assert n == self.token_num
        x = self.input_net(x)
        x = x + self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)
        x = self.output_net(x)
        # Unpatchify: (b, H*H, p*p*c) -> (b, H, H, p, p, c) -> (b, c, h, p, w, q),
        # then split the channel axis into 3 planes of out_channel // 3 features.
        x = x.reshape(b, self.H, self.H, self.out_patch_size, self.out_patch_size, self.out_channel)
        x = torch.einsum('nhwpqc->nchpwq', x)
        x = x.reshape(b, 3, self.out_channel // 3, self.out_reso, self.out_reso).contiguous()
        return x
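

# A minimal shape-check sketch for TriplaneDecoder, assuming illustrative
# hyperparameters: token_num == (256 // 8) ** 2 matches a patch-8 ViT on
# 256px inputs, and the small depth/heads values just keep the test light.
def _demo_triplane_decoder():
    decoder = TriplaneDecoder(
        token_num=1024, in_dim=32, depth=2, heads=4,
        mlp_dim=256, out_channel=18, out_reso=128,
    )
    tokens = torch.randn(2, 1024, 32)           # (batch, tokens, latent dim)
    planes = decoder(tokens)
    assert planes.shape == (2, 3, 6, 128, 128)  # (batch, 3 planes, C//3, reso, reso)
    return planes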


class SingleImageToTriplaneVAE(nn.Module):
    """Encodes a single image with a DINO ViT backbone and decodes token-wise
    VAE latents into a triplane."""

    def __init__(self, backbone='dino_vits8', input_reso=256, out_reso=128, out_channel=18, z_dim=32,
                 decoder_depth=16, decoder_heads=16, decoder_mlp_dim=1024, decoder_dim_head=64, dropout=0):
        super().__init__()
        self.backbone = backbone

        self.input_image_size = input_reso
        self.out_reso = out_reso
        self.out_channel = out_channel
        self.z_dim = z_dim

        self.decoder_depth = decoder_depth
        self.decoder_heads = decoder_heads
        self.decoder_mlp_dim = decoder_mlp_dim
        self.decoder_dim_head = decoder_dim_head

        self.dropout = dropout
        # DINO model names encode the patch size, e.g. 'dino_vits8' -> 8.
        self.patch_size = 8 if '8' in backbone else 16

        if 'dino' in backbone:
            self.vit = torch.hub.load('facebookresearch/dino:main', backbone)
            self.embed_dim = self.vit.embed_dim
            self.preprocess = None
        else:
            raise NotImplementedError

        # Token-wise posterior parameters of the VAE latent.
        self.fc_mu = nn.Linear(self.embed_dim, self.z_dim)
        self.fc_var = nn.Linear(self.embed_dim, self.z_dim)

        self.vit_decoder = TriplaneDecoder(
            (self.input_image_size // self.patch_size) ** 2, self.z_dim,
            depth=self.decoder_depth, heads=self.decoder_heads, mlp_dim=self.decoder_mlp_dim,
            out_channel=self.out_channel, out_reso=self.out_reso, dim_head=self.decoder_dim_head,
            dropout=self.dropout,
        )

    def forward(self, x, is_train):
        assert x.shape[-1] == self.input_image_size
        bs = x.shape[0]
        if 'dino' in self.backbone:
            # Patch tokens from the last ViT block; [:, 1:] drops the [CLS] token.
            z = self.vit.get_intermediate_layers(x, n=1)[0][:, 1:]
        else:
            raise NotImplementedError

        # Flatten tokens, then project to per-token posterior mu / logvar.
        z = z.reshape(-1, z.shape[-1])
        mu = self.fc_mu(z)
        logvar = self.fc_var(z)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        if is_train:
            # Reparameterization trick: sample z ~ N(mu, sigma^2).
            rep_z = eps * std + mu
        else:
            # At eval time, use the posterior mean for a deterministic output.
            rep_z = mu
        rep_z = rep_z.reshape(bs, -1, self.z_dim)
        out = self.vit_decoder(rep_z)

        return out, mu, logvar
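

# A minimal end-to-end smoke test, assuming network access is available
# (torch.hub fetches the DINO ViT-S/8 weights on first use); the random
# tensor stands in for a normalized RGB image batch.
if __name__ == '__main__':
    _demo_triplane_decoder()

    model = SingleImageToTriplaneVAE(backbone='dino_vits8', input_reso=256, out_reso=128)
    model.eval()
    image = torch.randn(1, 3, 256, 256)  # stand-in for a preprocessed RGB batch
    with torch.no_grad():
        planes, mu, logvar = model(image, is_train=False)
    print(planes.shape)  # expected: torch.Size([1, 3, 6, 128, 128])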