# File: AniMer/amr/models/amr.py
# Repository page header: luoxue-star, "init commit" (48cafca)
import torch
import pickle
import pytorch_lightning as pl
from typing import Any, Dict
from yacs.config import CfgNode
from ..utils.geometry import aa_to_rotmat, perspective_projection
from ..utils.pylogger import get_pylogger
from .backbones import create_backbone
from .heads import build_smal_head
from . import SMAL
# Module-level logger, obtained from the project's get_pylogger helper using this module's name.
log = get_pylogger(__name__)
class AMR(pl.LightningModule):
    """
    Lightning module that regresses SMAL parameters and a camera from an image.

    Data flows backbone -> SMAL head -> SMAL model -> perspective projection
    of the resulting 3D joints.
    """

    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
        """
        Build the backbone, the SMAL head and the SMAL model.
        Args:
            cfg (CfgNode): Config as a yacs CfgNode. ``cfg.SMAL.MODEL_PATH`` is opened
                and unpickled to obtain the keyword arguments for ``SMAL``.
            init_renderer (bool): Not read by this constructor; it is only excluded
                from the saved hyperparameters.
        """
        super().__init__()

        # Record the constructor arguments (minus init_renderer) without forwarding them to the logger
        self.save_hyperparameters(logger=False, ignore=['init_renderer'])
        self.cfg = cfg

        # Feature extractor followed by the head that regresses SMAL parameters
        self.backbone = create_backbone(cfg)
        self.smal_head = build_smal_head(cfg)

        # The SMAL constructor arguments are stored as a pickle on disk.
        # NOTE(review): unpickling can execute arbitrary code -- only point
        # cfg.SMAL.MODEL_PATH at model files from a trusted source.
        with open(cfg.SMAL.MODEL_PATH, 'rb') as model_file:
            smal_kwargs = pickle.load(model_file, encoding="latin1")
        self.smal = SMAL(**smal_kwargs)

    def forward_step(self, batch: Dict) -> Dict:
        """
        Run the full regression pipeline on one batch.
        Args:
            batch (Dict): Batch data; the entries 'img' and 'focal_length' are read.
        Returns:
            Dict: Regression output with the keys 'pred_cam', 'pred_smal_params',
                'pred_cam_t', 'focal_length', 'pred_keypoints_3d', 'pred_vertices'
                and 'pred_keypoints_2d'.
        """
        images = batch['img']
        focal_length = batch['focal_length']
        batch_size = images.shape[0]
        image_size = self.cfg.MODEL.IMAGE_SIZE

        # Trim 32 entries from both ends of the last image axis before the backbone.
        # The original note reads "[256, 192]" -- presumably a 256x256 input cropped
        # to 256x192; TODO confirm against the data pipeline.
        conditioning_feats, _ = self.backbone(images[:, :, :, 32:-32])

        # The head returns the SMAL parameter dict and the camera; its third output is unused here.
        # Expected keys used below: 'betas', 'global_orient', 'pose'.
        # Shapes per the original note (TODO confirm): global_orient [batch_size, 1, 3, 3],
        # pose [batch_size, 33, 3, 3], betas [batch_size, 10], pred_cam [batch_size, 3].
        pred_smal_params, pred_cam, _ = self.smal_head(conditioning_feats)

        # Snapshot the raw head predictions before they are reshaped below
        output = {
            'pred_cam': pred_cam,
            'pred_smal_params': {name: value.clone() for name, value in pred_smal_params.items()},
        }

        # Camera translation: columns 1 and 2 of pred_cam are used directly as x/y,
        # while z is derived from column 0 (presumably a weak-perspective scale --
        # confirm against the head). The 1e-9 term guards the division.
        cam_s = pred_cam[:, 0]
        cam_tz = 2 * focal_length[:, 0] / (image_size * cam_s + 1e-9)
        pred_cam_t = torch.stack([pred_cam[:, 1], pred_cam[:, 2], cam_tz], dim=-1)
        output['pred_cam_t'] = pred_cam_t
        output['focal_length'] = focal_length

        # Bring the parameters into the layout passed to SMAL: stacks of 3x3
        # matrices for the rotations, a flat vector per sample for betas.
        for rot_key in ('global_orient', 'pose'):
            pred_smal_params[rot_key] = pred_smal_params[rot_key].reshape(batch_size, -1, 3, 3)
        pred_smal_params['betas'] = pred_smal_params['betas'].reshape(batch_size, -1)

        # pose2rot=False is passed because the rotations were reshaped to 3x3 matrices above
        smal_output = self.smal(**pred_smal_params, pose2rot=False)
        joints_3d = smal_output.joints
        output['pred_keypoints_3d'] = joints_3d.reshape(batch_size, -1, 3)
        output['pred_vertices'] = smal_output.vertices.reshape(batch_size, -1, 3)

        # Project the 3D joints with the focal length divided by the image size
        flat_cam_t = pred_cam_t.reshape(-1, 3)
        flat_focal = focal_length.reshape(-1, 2)
        joints_2d = perspective_projection(joints_3d,
                                           translation=flat_cam_t,
                                           focal_length=flat_focal / image_size)
        output['pred_keypoints_2d'] = joints_2d.reshape(batch_size, -1, 2)
        return output

    def forward(self, batch: Dict) -> Dict:
        """
        Entry point used when the module is called; delegates to ``forward_step``.
        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output
        """
        output = self.forward_step(batch)
        return output