|
import torch
|
|
import src.utils.rotation_conversions as geometry
|
|
|
|
from .smpl import SMPL, JOINTSTYPE_ROOT
|
|
from src.models.tools.jointstypes import JOINTSTYPES
|
|
|
|
|
|
class Rotation2xyz:
|
|
def __init__(self, device):
|
|
self.device = device
|
|
self.smpl_model = SMPL().eval().to(device)
|
|
|
|
def __call__(self, x, mask, pose_rep, translation, glob,
|
|
jointstype, vertstrans, betas=None, beta=0,
|
|
glob_rot=None, get_rotations_back=False, **kwargs):
|
|
if pose_rep == "xyz":
|
|
return x
|
|
|
|
if mask is None:
|
|
mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device)
|
|
|
|
if not glob and glob_rot is None:
|
|
raise TypeError("You must specify global rotation if glob is False")
|
|
|
|
if jointstype not in JOINTSTYPES:
|
|
raise NotImplementedError("This jointstype is not implemented.")
|
|
|
|
if translation:
|
|
x_translations = x[:, -1, :3]
|
|
x_rotations = x[:, :-1]
|
|
else:
|
|
x_rotations = x
|
|
|
|
x_rotations = x_rotations.permute(0, 3, 1, 2)
|
|
nsamples, time, njoints, feats = x_rotations.shape
|
|
|
|
|
|
if pose_rep == "rotvec":
|
|
rotations = geometry.axis_angle_to_matrix(x_rotations[mask])
|
|
elif pose_rep == "rotmat":
|
|
rotations = x_rotations[mask].view(-1, njoints, 3, 3)
|
|
elif pose_rep == "rotquat":
|
|
rotations = geometry.quaternion_to_matrix(x_rotations[mask])
|
|
elif pose_rep == "rot6d":
|
|
rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])
|
|
else:
|
|
raise NotImplementedError("No geometry for this one.")
|
|
|
|
if not glob:
|
|
global_orient = torch.tensor(glob_rot, device=x.device)
|
|
global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3)
|
|
global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
|
|
else:
|
|
global_orient = rotations[:, 0]
|
|
rotations = rotations[:, 1:]
|
|
|
|
if betas is None:
|
|
betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas],
|
|
dtype=rotations.dtype, device=rotations.device)
|
|
betas[:, 1] = beta
|
|
|
|
out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas)
|
|
|
|
|
|
joints = out[jointstype]
|
|
|
|
x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype)
|
|
x_xyz[~mask] = 0
|
|
x_xyz[mask] = joints
|
|
|
|
x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous()
|
|
|
|
|
|
if jointstype != "vertices":
|
|
rootindex = JOINTSTYPE_ROOT[jointstype]
|
|
x_xyz = x_xyz - x_xyz[:, [rootindex], :, :]
|
|
|
|
if translation and vertstrans:
|
|
|
|
x_translations = x_translations - x_translations[:, :, [0]]
|
|
|
|
|
|
x_xyz = x_xyz + x_translations[:, None, :, :]
|
|
|
|
if get_rotations_back:
|
|
return x_xyz, rotations, global_orient
|
|
else:
|
|
return x_xyz
|
|
|