Spaces:
Running
Running
File size: 6,086 Bytes
48cafca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import json
from torch import nn
import torch
import numpy as np
import pickle
import cv2
from typing import Optional, Tuple, NewType
from dataclasses import dataclass
import smplx
from smplx.lbs import vertices2joints, lbs
from smplx.utils import MANOOutput, to_tensor, ModelOutput
from smplx.vertex_ids import vertex_ids
# Alias used only in type annotations below (a NewType over torch.Tensor).
Tensor = NewType('Tensor', torch.Tensor)
# Groups of mesh-vertex indices that define surface keypoints. SMAL.forward
# takes each inner list, averages the positions of those vertices, and stacks
# the results in list order to form its `joints` output.
# NOTE(review): which anatomical landmark each group denotes is not recorded
# here; presumably a fixed annotation keypoint set — confirm against the source
# of these indices before reordering or editing any entry.
keypoint_vertices_idx = [[1068, 1080, 1029, 1226], [2660, 3030, 2675, 3038], [910], [360, 1203, 1235, 1230],
[3188, 3156, 2327, 3183], [1976, 1974, 1980, 856], [3854, 2820, 3852, 3858], [452, 1811],
[416, 235, 182], [2156, 2382, 2203], [829], [2793], [60, 114, 186, 59],
[2091, 2037, 2036, 2160], [384, 799, 1169, 431], [2351, 2763, 2397, 3127],
[221, 104], [2754, 2192], [191, 1158, 3116, 2165],
[28, 1109, 1110, 1111, 1835, 1836, 3067, 3068, 3069],
[498, 499, 500, 501, 502, 503], [2463, 2464, 2465, 2466, 2467, 2468],
[764, 915, 916, 917, 934, 935, 956], [2878, 2879, 2880, 2897, 2898, 2919, 3751],
[1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762],
[0, 464, 465, 726, 1824, 2429, 2430, 2690]]
# Lookup from joint/landmark name to integer id (not used in the code shown
# here). The table holds 41 entries with ids 0-40. Ids 0-34 (35 entries, which
# is presumably what the "35" in the name refers to) look like skeleton joints
# (root, pelvis, spine, legs, feet, tail, neck, head, mouth, ears). Ids 35-40
# (EndNose, Chin, ear tips, eyes) look like extra facial landmarks.
# NOTE(review): confirm the id semantics with whatever consumes this table;
# entries are not in id order, so do not rely on iteration order matching ids.
name2id35 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1,
'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15,
'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11,
'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27,
'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0, 'LEar': 33, 'REar': 34, 'EndNose': 35, 'Chin': 36,
'RightEarTip': 37, 'LeftEarTip': 38, 'LeftEye': 39, 'RightEye': 40}
@dataclass
class SMALOutput(ModelOutput):
    """Result container returned by SMALLayer.forward and SMAL.forward.

    Extends ``smplx.utils.ModelOutput`` with the SMAL-specific ``betas`` and
    ``pose`` fields. SMALLayer.forward also populates the inherited fields
    ``vertices``, ``joints``, ``global_orient``, ``transl`` and ``full_pose``.
    """
    # Shape coefficients used for the forward pass, [batch_size, num_betas].
    betas: Optional[Tensor] = None
    # Per-joint rotation matrices excluding the global orientation,
    # [batch_size, num_body_joints, 3, 3] (rotation matrices, since the layer
    # calls lbs with pose2rot=False).
    pose: Optional[Tensor] = None
class SMALLayer(nn.Module):
    """SMAL-style articulated body model evaluated with ``smplx.lbs.lbs``.

    The constructor receives the arrays of a SMAL model file as keyword
    arguments and stores them as buffers. ``forward`` applies shape blend
    shapes, pose-corrective blend shapes and linear blend skinning to the
    template mesh, then translates the result.
    """

    def __init__(self, num_betas=41, **kwargs):
        """Register the model arrays as buffers.

        Args:
            num_betas: Number of shape components to keep; ``shapedirs`` is
                truncated to its first ``num_betas`` entries on the last axis.
            **kwargs: Model arrays: ``shapedirs``, ``v_template``,
                ``posedirs``, ``J_regressor`` (must provide ``.toarray()``,
                presumably a scipy sparse matrix — TODO confirm), ``weights``,
                ``f`` (faces) and ``kintree_table`` (2 x num_joints array).
                Row 1 of ``kintree_table`` holds the joint ids and row 0 the
                parent id of each column.
        """
        super().__init__()
        self.num_betas = num_betas

        kintree_table = kwargs['kintree_table']
        # One kintree column per joint, column 0 being the root. Every non-root
        # joint contributes one 3x3 rotation (9 values) to the pose-corrective
        # blend shapes. lbs needs this count to match posedirs, J_regressor and
        # parents, so derive it here instead of hard-coding it.
        self.num_body_joints = kintree_table.shape[1] - 1

        # Assumed [V, 3, >=num_betas] in the file (V = 3889 per the original
        # comments — TODO confirm); truncated to [V, 3, num_betas].
        self.register_buffer(
            "shapedirs",
            torch.from_numpy(np.array(kwargs['shapedirs'], dtype=np.float32))[:, :, :num_betas])
        # Rest-pose mesh vertices, presumably [V, 3].
        self.register_buffer(
            "v_template",
            torch.from_numpy(np.array(kwargs['v_template']).astype(np.float32)))
        # Flattened and transposed to [num_body_joints * 9, V * 3], the layout
        # lbs multiplies the pose feature vector against.
        self.register_buffer(
            "posedirs",
            torch.from_numpy(np.array(kwargs['posedirs'], dtype=np.float32)).reshape(
                -1, self.num_body_joints * 9).T)
        # Dense joint regressor, presumably [num_body_joints + 1, V].
        self.register_buffer(
            "J_regressor",
            torch.from_numpy(kwargs['J_regressor'].toarray().astype(np.float32)))
        # Skinning weights, presumably [V, num_body_joints + 1].
        self.register_buffer(
            "lbs_weights",
            torch.from_numpy(np.array(kwargs['weights'], dtype=np.float32)))
        # Triangle vertex indices, presumably [F, 3].
        self.register_buffer(
            "faces",
            torch.from_numpy(np.array(kwargs['f'], dtype=np.int32)))

        # Convert parent *ids* (row 0) into parent *column indices*. The root's
        # entry is a placeholder 0; lbs only reads parents[1:].
        id_to_col = {kintree_table[1, i]: i for i in range(kintree_table.shape[1])}
        self.register_buffer(
            "parents",
            torch.tensor(
                [0] + [id_to_col[kintree_table[0, i]] for i in range(1, kintree_table.shape[1])],
                dtype=torch.long))

    def forward(
            self,
            betas: Optional[Tensor] = None,
            global_orient: Optional[Tensor] = None,
            pose: Optional[Tensor] = None,
            transl: Optional[Tensor] = None,
            return_verts: bool = True,
            return_full_pose: bool = False,
            **kwargs):
        """Shape, pose and translate the model.

        Any input left as None gets a neutral default: zero shape
        coefficients, identity rotations or zero translation. The batch size
        of the defaults is the largest leading dimension among the provided
        inputs, or 1 if none are given. Device and dtype come from the first
        provided input, falling back to the model buffers.

        Args:
            betas: [batch_size, num_betas] shape coefficients.
            global_orient: [batch_size, 1, 3, 3] root rotation matrix.
            pose: [batch_size, num_body_joints, 3, 3] joint rotation matrices.
                These must be rotation matrices, not axis-angle, because lbs
                is called with pose2rot=False.
            transl: [batch_size, 3] translation added to joints and vertices.
            return_verts: If False, ``vertices`` in the output is None.
            return_full_pose: If True, the output carries the concatenated
                [batch_size, num_body_joints + 1, 3, 3] rotations as
                ``full_pose``.
            **kwargs: Ignored; accepted for call-site compatibility.

        Returns:
            SMALOutput with vertices, joints, betas, global_orient, pose,
            transl and (optionally) full_pose.
        """
        provided = [t for t in (betas, global_orient, pose, transl) if t is not None]
        # Reading betas.device unconditionally crashed for betas=None; use the
        # first available input (betas first, as before) or the buffers.
        reference = provided[0] if provided else self.shapedirs
        device, dtype = reference.device, reference.dtype
        # Defaults must match the batch of whatever was supplied, otherwise
        # the cat / lbs reshapes below fail.
        batch_size = max((t.shape[0] for t in provided), default=1)

        if global_orient is None:
            global_orient = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
        if pose is None:
            pose = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, self.num_body_joints, -1, -1).contiguous()
        if betas is None:
            betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat([global_orient, pose], dim=1)
        vertices, joints = lbs(betas, full_pose, self.v_template,
                               self.shapedirs, self.posedirs,
                               self.J_regressor, self.parents,
                               self.lbs_weights, pose2rot=False)

        # transl is always a tensor at this point (defaulted above).
        offset = transl.unsqueeze(dim=1)
        joints = joints + offset
        vertices = vertices + offset

        return SMALOutput(
            vertices=vertices if return_verts else None,
            # Joints are always returned; return_verts only governs vertices.
            joints=joints,
            betas=betas,
            global_orient=global_orient,
            pose=pose,
            transl=transl,
            full_pose=full_pose if return_full_pose else None,
        )
class SMAL(SMALLayer):
    """SMALLayer whose ``joints`` output is a set of surface keypoints.

    Each keypoint is the mean position of the mesh vertices listed in the
    corresponding entry of the module-level ``keypoint_vertices_idx``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, *args, **kwargs):
        """Evaluate SMALLayer.forward and replace ``joints`` with keypoints.

        All positional and keyword arguments are forwarded unchanged to
        SMALLayer.forward.

        Returns:
            The parent's SMALOutput, with ``joints`` replaced by a
            [batch_size, len(keypoint_vertices_idx), 3] tensor of keypoints.

        Raises:
            ValueError: If the parent returned no vertices (return_verts=False),
                since keypoints are computed from them.
        """
        # Forward positional arguments too; dropping them silently discarded
        # every input of a call like ``model(betas, global_orient, pose)``.
        smal_output = super().forward(*args, **kwargs)
        vertices = smal_output.vertices
        if vertices is None:
            raise ValueError(
                "SMAL keypoints are computed from mesh vertices; "
                "call forward with return_verts=True.")
        smal_output.joints = torch.stack(
            [vertices[:, idx, :].mean(dim=1) for idx in keypoint_vertices_idx],
            dim=1)
        return smal_output
|