from typing import Dict

import numpy as np
import torch
from torch import Tensor, nn
import torch.nn.functional as tF
from einops import rearrange
from lpips import LPIPS
from skimage.metrics import structural_similarity as calculate_ssim

from src.models.networks.attention import *  # `Transformer`, etc.
from src.models.gs_render import GaussianRenderer
from src.options import Options
from src.utils import plucker_ray, patchify, unpatchify


class GSRecon(nn.Module):
    def __init__(self, opt: Options):
        super().__init__()

        self.opt = opt

        # Image tokenizer
        in_channels = 3 + 6  # RGB + Plucker ray embedding
        if opt.input_normal:
            in_channels += 3  # normal map
        if opt.input_coord:
            in_channels += 3  # coordinate map
        if opt.input_mr:
            in_channels += 2  # first 2 channels of `data["mr"]` (metallic/roughness)
        self.x_embedder = nn.Linear(in_channels * (opt.patch_size**2), opt.dim)

        # Transformer backbone
        self.transformer = Transformer(opt.num_blocks, opt.dim, opt.num_heads, llama_style=opt.llama_style)
        self.ln_out = nn.LayerNorm(opt.dim)
        if opt.grad_checkpoint:
            self.transformer.set_grad_checkpointing()

        # Output heads
        self.inter_res = opt.input_res // opt.patch_size
        self.out_depth = nn.Linear(opt.dim, 1 * (opt.patch_size**2), bias=False)
        self.out_rgb = nn.Linear(opt.dim, 3 * (opt.patch_size**2), bias=False)
        self.out_scale = nn.Linear(opt.dim, 3 * (opt.patch_size**2), bias=False)
        self.out_rotation = nn.Linear(opt.dim, 4 * (opt.patch_size**2), bias=False)
        self.out_opacity = nn.Linear(opt.dim, 1 * (opt.patch_size**2), bias=False)

        # Rendering
        self.gs_renderer = GaussianRenderer(opt)

        # Initialize weights
        nn.init.xavier_uniform_(self.x_embedder.weight)
        nn.init.zeros_(self.x_embedder.bias)
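        # Zero-init the depth/scale/opacity heads: with zero weights their outputs
        # start at the activation midpoints (and, with the `-2.` shift in
        # `forward_gaussians`, at a low initial opacity), which presumably
        # stabilizes early training.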
        nn.init.zeros_(self.out_depth.weight)  # zero init.
        nn.init.xavier_uniform_(self.out_rgb.weight)
        nn.init.zeros_(self.out_scale.weight)  # zero init.
        nn.init.xavier_uniform_(self.out_rotation.weight)
        nn.init.zeros_(self.out_opacity.weight)  # zero init.

    def forward(self, *args, func_name="compute_loss", **kwargs):
        # To support different forward functions for models wrapped by `accelerate`
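        # e.g. `model(data, lpips_loss, step)` runs `compute_loss`, while
        # `model(imgs, C2W, fxfycxcy, func_name="forward_gaussians")` predicts Gaussians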
        return getattr(self, func_name)(*args, **kwargs)

    def compute_loss(self, data: Dict[str, Tensor], lpips_loss: LPIPS, step: int, dtype: torch.dtype = torch.float32):
        outputs = {}

        color_name = "albedo" if self.opt.input_albedo else "image"
        images = data[color_name].to(dtype)  # (B, V, 3, H, W)
        masks = data["mask"].to(dtype)  # (B, V, 1, H, W)
        C2W = data["C2W"].to(dtype)  # (B, V, 4, 4)
        fxfycxcy = data["fxfycxcy"].to(dtype)  # (B, V, 4)

        # Input views
        V_in = self.opt.num_input_views
        input_images = images[:, :V_in, ...]
        input_C2W = C2W[:, :V_in, ...]
        input_fxfycxcy = fxfycxcy[:, :V_in, ...]
        if self.opt.input_normal:
            input_images = torch.cat([input_images, data["normal"][:, :V_in, ...]], dim=2)
        if self.opt.input_coord:
            input_images = torch.cat([input_images, data["coord"][:, :V_in, ...]], dim=2)
        if self.opt.input_mr:
            input_images = torch.cat([input_images, data["mr"][:, :V_in, :2]], dim=2)

        model_outputs = self.forward_gaussians(input_images, input_C2W, input_fxfycxcy)
        render_outputs = self.gs_renderer.render(model_outputs, input_C2W, input_fxfycxcy, C2W, fxfycxcy)
        for k in render_outputs.keys():
            render_outputs[k] = render_outputs[k].to(dtype)
        render_images = render_outputs["image"]  # (B, V, 3, H, W)
        render_masks = render_outputs["alpha"]  # (B, V, 1, H, W)
        render_coords = render_outputs["coord"]  # (B, V, 3, H, W)
        render_normals = render_outputs["normal"]  # (B, V, 3, H, W)

        # For visualization
        outputs["images_render"] = render_images
        outputs["images_gt"] = images
        if self.opt.vis_coords:
            outputs["images_coord"] = render_coords
            if self.opt.load_coord:
                outputs["images_gt_coord"] = data["coord"]
        if self.opt.vis_normals:
            outputs["images_normal"] = render_normals
            if self.opt.load_normal:
                outputs["images_gt_normal"] = data["normal"]
        # if self.opt.input_mr:
        #     outputs["images_mr"] = data["mr"]

        ################################ Compute reconstruction losses/metrics ################################

        outputs["image_mse"] = image_mse = tF.mse_loss(images, render_images)
        outputs["mask_mse"] = mask_mse = tF.mse_loss(masks, render_masks)
        loss = image_mse + mask_mse

        # Coord & Normal
        if self.opt.coord_weight > 0:
            assert self.opt.load_coord
            outputs["coord_mse"] = coord_mse = tF.mse_loss(data["coord"], render_coords)
            loss += self.opt.coord_weight * coord_mse
        if self.opt.normal_weight > 0:
            assert self.opt.load_normal
            outputs["normal_cosim"] = normal_cosim = tF.cosine_similarity(data["normal"], render_normals, dim=2).mean()
            loss += self.opt.normal_weight * (1. - normal_cosim)

        # LPIPS
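        # Linearly ramp the LPIPS weight from 0 to `lpips_weight` between steps
        # `lpips_warmup_start` and `lpips_warmup_end`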
        if step < self.opt.lpips_warmup_start:
            lpips_weight = 0.
        elif step > self.opt.lpips_warmup_end:
            lpips_weight = self.opt.lpips_weight
        else:
            lpips_weight = self.opt.lpips_weight * (step - self.opt.lpips_warmup_start) / (
                self.opt.lpips_warmup_end - self.opt.lpips_warmup_start)
        if lpips_weight > 0.:
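            # LPIPS expects inputs in [-1, 1], hence the `* 2. - 1.` rescaling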
outputs["lpips"] = lpips = lpips_loss( | |
# Downsampled to at most 256 to reduce memory cost | |
tF.interpolate( | |
rearrange(images, "b v c h w -> (b v) c h w") * 2. - 1., | |
(self.opt.lpips_resize, self.opt.lpips_resize), mode="bilinear", align_corners=False | |
) if self.opt.lpips_resize > 0 else rearrange(images, "b v c h w -> (b v) c h w") * 2. - 1., | |
tF.interpolate( | |
rearrange(render_images, "b v c h w -> (b v) c h w") * 2. - 1., | |
(self.opt.lpips_resize, self.opt.lpips_resize), mode="bilinear", align_corners=False | |
) if self.opt.lpips_resize > 0 else rearrange(render_images, "b v c h w -> (b v) c h w") * 2. - 1., | |
).mean() | |
loss += lpips_weight * lpips | |
outputs["loss"] = loss | |
# Metric: PSNR, SSIM and LPIPS | |
with torch.no_grad(): | |
outputs["psnr"] = -10 * torch.log10(torch.mean((images - render_images.detach()) ** 2)) | |
outputs["ssim"] = torch.tensor(calculate_ssim( | |
(rearrange(images, "b v c h w -> (b v c) h w") | |
.cpu().float().numpy() * 255.).astype(np.uint8), | |
(rearrange(render_images.detach(), "b v c h w -> (b v c) h w") | |
.cpu().float().numpy() * 255.).astype(np.uint8), | |
channel_axis=0, | |
), device=images.device) | |
            if lpips_weight <= 0.:  # not already computed above for the training loss
                outputs["lpips"] = lpips = lpips_loss(
                    # Downsampled to at most 256 to reduce memory cost
                    tF.interpolate(
                        rearrange(images, "b v c h w -> (b v) c h w") * 2. - 1.,
                        (self.opt.lpips_resize, self.opt.lpips_resize), mode="bilinear", align_corners=False
                    ) if self.opt.lpips_resize > 0 else rearrange(images, "b v c h w -> (b v) c h w") * 2. - 1.,
                    tF.interpolate(
                        rearrange(render_images.detach(), "b v c h w -> (b v) c h w") * 2. - 1.,
                        (self.opt.lpips_resize, self.opt.lpips_resize), mode="bilinear", align_corners=False
                    ) if self.opt.lpips_resize > 0 else rearrange(render_images.detach(), "b v c h w -> (b v) c h w") * 2. - 1.,
                ).mean()

        return outputs

    def forward_gaussians(self, input_images: Tensor, input_C2W: Tensor, input_fxfycxcy: Tensor):
        """
        Inputs:
            - `input_images`: (B, V_in, C, H, W)
            - `input_C2W`: (B, V_in, 4, 4)
            - `input_fxfycxcy`: (B, V_in, 4)
        """
        _, V_in, _, H, W = input_images.shape
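        # Per-pixel Plucker ray embedding (direction & moment, 6 channels) densely
        # encodes the camera pose alongside the rescaled image channels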
        plucker, _ = plucker_ray(H, W, input_C2W, input_fxfycxcy)  # (B, V_in, 6, H, W)
        images_plucker = torch.cat([input_images * 2. - 1., plucker], dim=2)
        images_plucker = rearrange(images_plucker, "b v c h w -> (b v) c h w")

        x = patchify(images_plucker, self.opt.patch_size)  # (B*V_in, N, C)
        x = rearrange(x, "(b v) n c -> b v n c", v=V_in)
        x = self.x_embedder(x)  # (B, V_in, N, D)
        x = rearrange(x, "b v n d -> b (v n) d")
        x = self.transformer(x)
        x = self.ln_out(x)

        def _reshape_feature(features: Tensor):
            features = rearrange(features, "b (v h w) d -> (b v) (h w) d", v=V_in, h=self.inter_res)
            features = unpatchify(features, self.opt.patch_size, int(features.shape[1]**0.5))  # sqrt(N) = self.inter_res
            features = rearrange(features, "(b v) c h w -> b v c h w", v=V_in)  # (B, V_in, `dim`, H, W)
            return features

        depth = _reshape_feature(self.out_depth(x))
        rgb = _reshape_feature(self.out_rgb(x))
        scale = _reshape_feature(self.out_scale(x))
        rotation = _reshape_feature(self.out_rotation(x))
        opacity = _reshape_feature(self.out_opacity(x))

        depth = torch.sigmoid(depth) * 2. - 1.  # [0, 1] -> [-1, 1]
        rgb = torch.sigmoid(rgb) * 2. - 1.  # [0, 1] -> [-1, 1]
        scale = torch.sigmoid(scale) * 2. - 1.  # [0, 1] -> [-1, 1]
        rotation = tF.normalize(rotation, p=2, dim=2)  # L2-normalized quaternion in [-1, 1]
        opacity = torch.sigmoid(opacity - 2.) * 2. - 1.  # [0, 1] -> [-1, 1]; `-2.` cf. GS-LRM Appendix A.4

        return {
            "depth": depth,
            "rgb": rgb,
            "scale": scale,
            "rotation": rotation,
            "opacity": opacity,
        }
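

if __name__ == "__main__":
    # Minimal smoke test: a sketch only, assuming `Options()` is constructible
    # with defaults and that the extra input modalities (normal/coord/mr) are
    # disabled, so the tokenizer sees C = 3 (RGB) + 6 (Plucker) channels.
    opt = Options()
    model = GSRecon(opt)

    B, V_in = 1, opt.num_input_views
    H = W = opt.input_res
    images = torch.rand(B, V_in, 3, H, W)
    C2W = torch.eye(4).repeat(B, V_in, 1, 1)  # identity camera poses
    fxfycxcy = torch.tensor([1., 1., .5, .5]).repeat(B, V_in, 1)  # placeholder intrinsics

    gaussians = model.forward_gaussians(images, C2W, fxfycxcy)
    for k, v in gaussians.items():
        print(k, tuple(v.shape))  # each: (B, V_in, C_k, H, W)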