""" This file contains the MANO defination and mesh sampling operations for MANO mesh Adapted from opensource projects MANOPTH (https://github.com/hassony2/manopth) Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) GraphCMR (https://github.com/nkolot/GraphCMR/) """ from __future__ import division import numpy as np import torch import torch.nn as nn import os.path as osp import json import code from custom_manopth.manolayer import ManoLayer import scipy.sparse import custom_mesh_graphormer.modeling.data.config as cfg from pathlib import Path from comfy.model_management import get_torch_device from wrapper_for_mps import sparse_to_dense device = get_torch_device() class MANO(nn.Module): def __init__(self): super(MANO, self).__init__() self.mano_dir = str(Path(__file__).parent / "data") self.layer = self.get_layer() self.vertex_num = 778 self.face = self.layer.th_faces.numpy() self.joint_regressor = self.layer.th_J_regressor.numpy() self.joint_num = 21 self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) ) self.root_joint_idx = self.joints_name.index('Wrist') # add fingertips to joint_regressor self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand) thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot)) self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:] joint_regressor_torch = torch.from_numpy(self.joint_regressor).float() self.register_buffer('joint_regressor_torch', joint_regressor_torch) def get_layer(self): return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model def get_3d_joints(self, vertices): """ This method is used to get the joint locations from the SMPL mesh Input: vertices: size = (B, 778, 3) Output: 3D joints: size = (B, 21, 3) """ joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch]) return joints class SparseMM(torch.autograd.Function): """Redefine sparse @ dense matrix multiplication to enable backpropagation. The builtin matrix multiplication operation does not support backpropagation in some cases. """ @staticmethod def forward(ctx, sparse, dense): ctx.req_grad = dense.requires_grad ctx.save_for_backward(sparse) return torch.matmul(sparse, dense) @staticmethod def backward(ctx, grad_output): grad_input = None sparse, = ctx.saved_tensors if ctx.req_grad: grad_input = torch.matmul(sparse.t(), grad_output) return None, grad_input def spmm(sparse, dense): sparse = sparse.to(device) dense = dense.to(device) return SparseMM.apply(sparse, dense) def scipy_to_pytorch(A, U, D): """Convert scipy sparse matrices to pytorch sparse matrix.""" ptU = [] ptD = [] for i in range(len(U)): u = scipy.sparse.coo_matrix(U[i]) i = torch.LongTensor(np.array([u.row, u.col])) v = torch.FloatTensor(u.data) ptU.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, u.shape))) for i in range(len(D)): d = scipy.sparse.coo_matrix(D[i]) i = torch.LongTensor(np.array([d.row, d.col])) v = torch.FloatTensor(d.data) ptD.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, d.shape))) return ptU, ptD def adjmat_sparse(adjmat, nsize=1): """Create row-normalized sparse graph adjacency matrix.""" adjmat = scipy.sparse.csr_matrix(adjmat) if nsize > 1: orig_adjmat = adjmat.copy() for _ in range(1, nsize): adjmat = adjmat * orig_adjmat adjmat.data = np.ones_like(adjmat.data) for i in range(adjmat.shape[0]): adjmat[i,i] = 1 num_neighbors = np.array(1 / adjmat.sum(axis=-1)) adjmat = adjmat.multiply(num_neighbors) adjmat = scipy.sparse.coo_matrix(adjmat) row = adjmat.row col = adjmat.col data = adjmat.data i = torch.LongTensor(np.array([row, col])) v = torch.from_numpy(data).float() adjmat = sparse_to_dense(torch.sparse_coo_tensor(i, v, adjmat.shape)) return adjmat def get_graph_params(filename, nsize=1): """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" data = np.load(filename, encoding='latin1', allow_pickle=True) A = data['A'] U = data['U'] D = data['D'] U, D = scipy_to_pytorch(A, U, D) A = [adjmat_sparse(a, nsize=nsize) for a in A] return A, U, D class Mesh(object): """Mesh object that is used for handling certain graph operations.""" def __init__(self, filename=cfg.MANO_sampling_matrix, num_downsampling=1, nsize=1, device=torch.device('cuda')): self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) # self._A = [a.to(device) for a in self._A] self._U = [u.to(device) for u in self._U] self._D = [d.to(device) for d in self._D] self.num_downsampling = num_downsampling def downsample(self, x, n1=0, n2=None): """Downsample mesh.""" if n2 is None: n2 = self.num_downsampling if x.ndimension() < 3: for i in range(n1, n2): x = spmm(self._D[i], x) elif x.ndimension() == 3: out = [] for i in range(x.shape[0]): y = x[i] for j in range(n1, n2): y = spmm(self._D[j], y) out.append(y) x = torch.stack(out, dim=0) return x def upsample(self, x, n1=1, n2=0): """Upsample mesh.""" if x.ndimension() < 3: for i in reversed(range(n2, n1)): x = spmm(self._U[i], x) elif x.ndimension() == 3: out = [] for i in range(x.shape[0]): y = x[i] for j in reversed(range(n2, n1)): y = spmm(self._U[j], y) out.append(y) x = torch.stack(out, dim=0) return x