import pickle
from typing import Any, Dict

import pytorch_lightning as pl
import torch
from yacs.config import CfgNode

from ..utils.geometry import aa_to_rotmat, perspective_projection
from ..utils.pylogger import get_pylogger
from .backbones import create_backbone
from .heads import build_smal_head
from . import SMAL

log = get_pylogger(__name__)


class AMR(pl.LightningModule):
    """Backbone + SMAL head that regresses SMAL parameters and a camera from an image batch."""

    # Pixel columns removed from EACH side of the last (width) axis of the input
    # image before it is fed to the backbone (the original note was "[256, 192]",
    # i.e. a 256-wide input becomes 192 wide). Must be > 0: a slice of `0:-0`
    # would be empty.
    _BACKBONE_WIDTH_CROP = 32

    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
        """
        Setup AMR model
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode
            init_renderer (bool): Not used by this constructor; it is excluded
                from the saved hyperparameters.
        """
        super().__init__()

        # Save hyperparameters (not sent to the logger; `init_renderer` is skipped)
        self.save_hyperparameters(logger=False, ignore=['init_renderer'])

        self.cfg = cfg
        # Create backbone feature extractor
        self.backbone = create_backbone(cfg)

        # Create SMAL head
        self.smal_head = build_smal_head(cfg)

        # Instantiate SMAL model from the pickled model file given in the config.
        # SECURITY NOTE: pickle.load can execute arbitrary code, so SMAL.MODEL_PATH
        # must point at a trusted file. encoding="latin1" is presumably there
        # because the file was pickled under Python 2 — confirm.
        smal_model_path = cfg.SMAL.MODEL_PATH
        with open(smal_model_path, 'rb') as f:
            smal_cfg = pickle.load(f, encoding="latin1")
        self.smal = SMAL(**smal_cfg)

    def forward_step(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network
        Args:
            batch (Dict): Dictionary containing batch data. Reads 'img' (the
                input image; its first axis is the batch and its last axis is
                cropped before the backbone) and 'focal_length' (column 0 is
                used for the camera depth; reshaped to (-1, 2) for projection).
        Returns:
            Dict: Dictionary containing the regression output, with keys
                'pred_cam', 'pred_smal_params' (cloned raw head output),
                'pred_cam_t', 'focal_length', 'pred_keypoints_3d'
                (batch_size, -1, 3), 'pred_vertices' (batch_size, -1, 3) and
                'pred_keypoints_2d' (batch_size, -1, 2).
        """
        # Use RGB image as input
        x = batch['img']
        batch_size = x.shape[0]

        # Compute conditioning features using the backbone on a width-cropped
        # image; the backbone's second output is not used here.
        crop = self._BACKBONE_WIDTH_CROP
        conditioning_feats, _ = self.backbone(x[:, :, :, crop:-crop])

        # The head returns (SMAL parameter dict, camera, <unused>). The parameter
        # tensors are reshaped into the layout SMAL expects further below.
        pred_smal_params, pred_cam, _ = self.smal_head(conditioning_feats)

        # Store useful regression outputs to the output dict. The params are
        # copied into a new dict of cloned tensors before the entries of
        # `pred_smal_params` are replaced with reshaped versions below.
        output = {}
        output['pred_cam'] = pred_cam
        output['pred_smal_params'] = {k: v.clone() for k, v in pred_smal_params.items()}

        # Compute camera translation: columns 1 and 2 of pred_cam give the x/y
        # translation; the depth is 2 * f / (IMAGE_SIZE * pred_cam[:, 0]), with
        # 1e-9 guarding against division by zero.
        focal_length = batch['focal_length']
        depth = 2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)
        pred_cam_t = torch.stack([pred_cam[:, 1], pred_cam[:, 2], depth], dim=-1)
        output['pred_cam_t'] = pred_cam_t
        output['focal_length'] = focal_length

        # Compute model vertices, joints and the projected joints. Orientation
        # and pose are reshaped to 3x3 matrices and passed with pose2rot=False.
        pred_smal_params['global_orient'] = pred_smal_params['global_orient'].reshape(batch_size, -1, 3, 3)
        pred_smal_params['pose'] = pred_smal_params['pose'].reshape(batch_size, -1, 3, 3)
        pred_smal_params['betas'] = pred_smal_params['betas'].reshape(batch_size, -1)
        smal_output = self.smal(**pred_smal_params, pose2rot=False)

        pred_keypoints_3d = smal_output.joints
        pred_vertices = smal_output.vertices
        output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
        output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)

        # Project with the focal length divided by IMAGE_SIZE, so the 2D
        # keypoints are presumably in image-size-normalized coordinates —
        # confirm against perspective_projection.
        pred_cam_t = pred_cam_t.reshape(-1, 3)
        focal_length = focal_length.reshape(-1, 2)
        pred_keypoints_2d = perspective_projection(
            pred_keypoints_3d,
            translation=pred_cam_t,
            focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE,
        )

        output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
        return output

    def forward(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network in val mode
        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output
        """
        return self.forward_step(batch)