from typing import Dict

import torch
import torch.nn.functional as tF
from torch import Tensor, nn
from einops import rearrange

from src.options import Options
from src.utils import IMAGENET_MEAN, IMAGENET_STD


class ElevEst(nn.Module):
    """Elevation estimator on a DINOv2 backbone: classifies the elevation into
    uniform bins and regresses a continuous offset within the predicted bin."""

    def __init__(self, opt: Options):
        super().__init__()

        self.opt = opt

        self.backbone: nn.Module = torch.hub.load("facebookresearch/dinov2", opt.elevest_backbone_name)
        if opt.freeze_backbone:
            self.backbone.requires_grad_(False)
        else:
            self.backbone.mask_token.requires_grad_(False)  # not used

        self.dim = dim = {
            "dinov2_vits14_reg": 384,
            "dinov2_vitb14_reg": 768,
            "dinov2_vitl14_reg": 1024,
        }[opt.elevest_backbone_name]

        self.cls_head = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, opt.elevest_num_classes),
        )
        self.offset_head = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, 1),
        )

        # Uniform bins over [ele_min, ele_max); `lower_bounds[i]` is the left edge of bin i
        self.interval = (opt.ele_max - opt.ele_min) / opt.elevest_num_classes
        self.register_buffer("lower_bounds", torch.linspace(opt.ele_min, opt.ele_max, opt.elevest_num_classes + 1)[:-1])

    def state_dict(self, **kwargs):
        # Remove frozen parameters without gradients
        state_dict = super().state_dict(**kwargs)
        if self.opt.freeze_backbone:
            for k in list(state_dict.keys()):
                if "backbone" in k:
                    del state_dict[k]
        return state_dict

    def forward(self, *args, func_name="compute_loss", **kwargs):
        # To support different forward functions for models wrapped by `accelerate`
        return getattr(self, func_name)(*args, **kwargs)

    def compute_loss(self, data: Dict[str, Tensor], dtype: torch.dtype = torch.float32):
        outputs = {}

        input_images = data["image"].to(dtype)  # (B, V, 3, H, W)
        gt_elev = data["cam_pose"].to(dtype)[:, :, 0].rad2deg()  # (B, V); radians -> degrees
        input_images = rearrange(input_images, "b v c h w -> (b v) c h w")
        gt_elev = rearrange(gt_elev, "b v -> (b v)")  # (B*V,)

        # Classification target (bin index) and regression target (within-bin offset)
        gt_class = torch.floor((gt_elev - self.opt.ele_min) / self.interval).long()
        gt_offset = gt_elev - self.lower_bounds[gt_class]
        assert torch.all((gt_class >= 0) & (gt_class < self.opt.elevest_num_classes))
        assert torch.all((gt_offset + 1e-8 >= 0) & (gt_offset - 1e-8 < self.interval))

        # ImageNet normalization
        mean = torch.tensor(IMAGENET_MEAN, device=input_images.device, dtype=dtype).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD, device=input_images.device, dtype=dtype).view(3, 1, 1)
        input_images = (input_images - mean) / std

        # Predict
        features = self.backbone(input_images.to(dtype=dtype), is_training=True)
        cls_token = features["x_norm_clstoken"]  # (B*V, D)
        logits = self.cls_head(cls_token)  # (B*V, C)
        pred_offset = self.offset_head(cls_token).squeeze(-1).clamp(0., self.interval)  # (B*V,)

        # Loss
        outputs["loss_cls"] = tF.cross_entropy(logits, gt_class)
        outputs["loss_offset"] = tF.mse_loss(pred_offset, gt_offset)
        outputs["loss"] = outputs["loss_cls"] + self.opt.elevest_reg_weight * outputs["loss_offset"]

        with torch.no_grad():
            pred_elev = self.lower_bounds[torch.argmax(logits, dim=-1)] + pred_offset  # (B*V,)
            outputs["err_degree"] = torch.mean(torch.abs(pred_elev - gt_elev))

        return outputs

    @torch.no_grad()
    def predict_elev(self, input_images: Tensor, dtype: torch.dtype = torch.float32):
        # Input image preprocessing: resize to the backbone resolution, then ImageNet-normalize
        input_images = tF.interpolate(input_images, size=(224, 224), mode="bilinear", align_corners=False, antialias=True)
        input_images = input_images.to(device=self.lower_bounds.device, dtype=dtype)
        mean = torch.tensor(IMAGENET_MEAN, device=input_images.device, dtype=dtype).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD, device=input_images.device, dtype=dtype).view(3, 1, 1)
        input_images = (input_images - mean) / std

        # Predicted elevation = left edge of the argmax bin + regressed within-bin offset
        features = self.backbone(input_images, is_training=True)
        cls_token = features["x_norm_clstoken"]  # (B, D)
        logits = self.cls_head(cls_token)  # (B, C)
        pred_offset = self.offset_head(cls_token).squeeze(-1).clamp(0., self.interval)  # (B,)
        pred_elev = self.lower_bounds[torch.argmax(logits, dim=-1)] + pred_offset  # (B,)
        return pred_elev
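
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The `Options` fields
# are inferred from the attribute accesses above; `_DemoOptions` and its values
# (an elevation range of [-10, 45] degrees, 11 bins) are hypothetical stand-ins
# for the real `src.options.Options` config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _DemoOptions:  # hypothetical duck-typed stand-in for Options
        elevest_backbone_name: str = "dinov2_vits14_reg"
        freeze_backbone: bool = True
        elevest_num_classes: int = 11
        ele_min: float = -10.0
        ele_max: float = 45.0
        elevest_reg_weight: float = 1.0

    model = ElevEst(_DemoOptions()).eval()  # downloads DINOv2 weights via torch.hub
    images = torch.rand(4, 3, 256, 256)     # (B, 3, H, W), assumed to be in [0, 1]
    elev_deg = model.predict_elev(images)   # (B,) predicted elevation in degrees
    print(elev_deg)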