# DiffSplat: src/models/elevest.py
from typing import Dict

import torch
import torch.nn.functional as tF
from einops import rearrange
from torch import Tensor, nn

from src.options import Options
from src.utils import IMAGENET_MEAN, IMAGENET_STD


class ElevEst(nn.Module):
    """Elevation estimator: classify an elevation bin, then regress an in-bin offset."""

    def __init__(self, opt: Options):
        super().__init__()
        self.opt = opt

        # DINOv2 backbone from torch.hub, optionally frozen
        self.backbone: nn.Module = torch.hub.load("facebookresearch/dinov2", opt.elevest_backbone_name)
        if opt.freeze_backbone:
            self.backbone.requires_grad_(False)
        else:
            self.backbone.mask_token.requires_grad_(False)  # not used by the forward pass

        # Embedding dims of the DINOv2 ViT-S/B/L (with registers) variants
        self.dim = dim = {
            "dinov2_vits14_reg": 384,
            "dinov2_vitb14_reg": 768,
            "dinov2_vitl14_reg": 1024,
        }[opt.elevest_backbone_name]

        # Two MLP heads on the backbone CLS token: bin classification and in-bin offset regression
        self.cls_head = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, opt.elevest_num_classes),
        )
        self.offset_head = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, 1),
        )

        # Uniform bins over [ele_min, ele_max); `lower_bounds[c]` is the left edge (in degrees) of bin c
        self.interval = (opt.ele_max - opt.ele_min) / opt.elevest_num_classes
        self.register_buffer("lower_bounds", torch.linspace(opt.ele_min, opt.ele_max, opt.elevest_num_classes + 1)[:-1])
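
    # A minimal worked example of the binning scheme, assuming hypothetical option
    # values (not taken from the repo's configs) ele_min=-30, ele_max=30,
    # elevest_num_classes=10:
    #   interval     = (30 - (-30)) / 10 = 6 degrees per bin
    #   lower_bounds = [-30, -24, -18, -12, -6, 0, 6, 12, 18, 24]
    # An elevation of 14.5 deg then falls in bin floor((14.5 - (-30)) / 6) = 7,
    # with in-bin offset 14.5 - lower_bounds[7] = 14.5 - 12 = 2.5 deg.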

    def state_dict(self, **kwargs):
        # Remove frozen backbone parameters (they receive no gradients) to keep checkpoints small
        state_dict = super().state_dict(**kwargs)
        if self.opt.freeze_backbone:
            for k in list(state_dict.keys()):
                if "backbone" in k:
                    del state_dict[k]
        return state_dict
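
    # Note: because frozen backbone keys are stripped above, a checkpoint saved this
    # way no longer covers the module's full key set (the backbone weights are
    # re-fetched via torch.hub instead). A sketch of the matching load would be:
    #   model.load_state_dict(torch.load("elevest.pt"), strict=False)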

    def forward(self, *args, func_name="compute_loss", **kwargs):
        # To support different forward functions for models wrapped by `accelerate`
        return getattr(self, func_name)(*args, **kwargs)
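
    # A minimal usage sketch of the dispatcher, assuming `model` is this module
    # (possibly wrapped by `accelerate`, which only exposes `__call__`):
    #   outputs = model(data, func_name="compute_loss")    # training step
    #   elev    = model(images, func_name="predict_elev")  # inference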

    def compute_loss(self, data: Dict[str, Tensor], dtype: torch.dtype = torch.float32):
        outputs = {}

        input_images = data["image"].to(dtype)  # (B, V, 3, H, W)
        gt_elev = data["cam_pose"].to(dtype)[:, :, 0].rad2deg()  # (B, V); elevation is the first pose component
        input_images = rearrange(input_images, "b v c h w -> (b v) c h w")
        gt_elev = rearrange(gt_elev, "b v -> (b v)")  # (B*V,)

        # Discretize ground-truth elevation into a bin index and an in-bin offset
        gt_class = torch.floor((gt_elev - self.opt.ele_min) / self.interval).long()
        gt_offset = gt_elev - self.lower_bounds[gt_class]
        assert torch.all((gt_class >= 0) & (gt_class < self.opt.elevest_num_classes))
        assert torch.all((gt_offset + 1e-8 >= 0) & (gt_offset - 1e-8 < self.interval))  # small tolerance for float error

        # ImageNet normalization
        mean = torch.tensor(IMAGENET_MEAN, device=input_images.device, dtype=dtype).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD, device=input_images.device, dtype=dtype).view(3, 1, 1)
        input_images = (input_images - mean) / std

        # Predict; `is_training=True` makes DINOv2 return its feature dict instead of the pooled output
        features = self.backbone(input_images.to(dtype=dtype), is_training=True)
        cls_token = features["x_norm_clstoken"]  # (B*V, D)
        logits = self.cls_head(cls_token)  # (B*V, C)
        pred_offset = self.offset_head(cls_token).squeeze(-1).clamp(0., self.interval)  # (B*V,); offsets live in [0, interval]

        # Loss: cross-entropy on the bin plus weighted MSE on the offset
        outputs["loss_cls"] = tF.cross_entropy(logits, gt_class)
        outputs["loss_offset"] = tF.mse_loss(pred_offset, gt_offset)
        outputs["loss"] = outputs["loss_cls"] + self.opt.elevest_reg_weight * outputs["loss_offset"]

        # Metric: mean absolute elevation error in degrees
        with torch.no_grad():
            pred_elev = self.lower_bounds[torch.argmax(logits, dim=-1)] + pred_offset  # (B*V,)
            outputs["err_degree"] = torch.mean(torch.abs(pred_elev - gt_elev))
        return outputs
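
    # Continuing the hypothetical 10-bin example above: if argmax picks bin 7
    # (lower bound 12 deg) and the offset head predicts 2.4, then
    # pred_elev = 12 + 2.4 = 14.4 deg, i.e. 0.1 deg of error against gt 14.5.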

    @torch.no_grad()
    def predict_elev(self, input_images: Tensor, dtype: torch.dtype = torch.float32):
        # Input image preprocessing: resize to the backbone's 224x224 input, then ImageNet-normalize
        input_images = tF.interpolate(input_images, size=(224, 224), mode="bilinear", align_corners=False, antialias=True)
        input_images = input_images.to(device=self.lower_bounds.device, dtype=dtype)
        mean = torch.tensor(IMAGENET_MEAN, device=input_images.device, dtype=dtype).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD, device=input_images.device, dtype=dtype).view(3, 1, 1)
        input_images = (input_images - mean) / std

        # Predicted elevation = left edge of the most likely bin + regressed in-bin offset
        features = self.backbone(input_images, is_training=True)
        cls_token = features["x_norm_clstoken"]  # (B, D)
        logits = self.cls_head(cls_token)  # (B, C)
        pred_offset = self.offset_head(cls_token).squeeze(-1).clamp(0., self.interval)  # (B,)
        pred_elev = self.lower_bounds[torch.argmax(logits, dim=-1)] + pred_offset  # (B,)
        return pred_elev
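

# A minimal inference sketch, assuming `Options` accepts the fields used above
# (field names are taken from this file; the values here are hypothetical):
#
#   opt = Options(
#       elevest_backbone_name="dinov2_vitb14_reg",
#       freeze_backbone=True,
#       elevest_num_classes=10, ele_min=-30., ele_max=30.,
#       elevest_reg_weight=1., ...,
#   )
#   model = ElevEst(opt).eval().to("cuda")
#   images = torch.rand(4, 3, 512, 512, device="cuda")   # RGB in [0, 1]
#   elev_deg = model(images, func_name="predict_elev")   # (4,) elevations in degrees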