"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

End-to-end inference code for
3D hand mesh reconstruction from an image.
"""

from __future__ import absolute_import, division, print_function
# NOTE(review): several imports below (code, json, time, datetime, make_grid,
# make_hand_data_loader, synchronize, AverageMeter, reconstruction_error, ...)
# appear unused in this file; kept because import order/side effects may matter.
import argparse
import os
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import gc
import numpy as np
import cv2
from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer
from custom_mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
from custom_mesh_graphormer.modeling._mano import MANO, Mesh
from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import custom_mesh_graphormer.modeling.data.config as cfg
from custom_mesh_graphormer.datasets.build import make_hand_data_loader
from custom_mesh_graphormer.utils.logger import setup_logger
from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from custom_mesh_graphormer.utils.metric_logger import AverageMeter
from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error
from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection
from PIL import Image
from torchvision import transforms
from comfy.model_management import get_torch_device

# Torch device chosen by ComfyUI's model management; evaluated once at import time.
device = get_torch_device()

# Network input preprocessing: resize (shorter side to 224), center-crop to
# 224x224, convert to a tensor, then normalize each channel with the commonly
# used ImageNet mean/std values.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])])
# Same resize / center-crop as `transform` but without normalization, so the
# tensor can be used directly as the background of the rendered result.
transform_visualize = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor()])

# Rendered results are written next to each input as "<stem>" + this suffix.
_PRED_SUFFIX = '_graphormer_pred.jpg'
# Extensions picked up when a directory is given (compared case-insensitively).
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _is_prediction_output(image_file):
    """Return True if the file *name* of ``image_file`` contains 'pred'.

    Used to skip images rendered by a previous run. Only the base name is
    checked, so a directory such as ``predictions/`` does not cause every
    image inside it to be ignored.
    """
    return 'pred' in op.basename(image_file)


def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler):
    """Reconstruct a hand mesh for each image and save a rendered overlay.

    For every image the mesh predicted by ``Graphormer_model`` is re-centered
    on the wrist joint regressed by ``mano``, rendered over the 224x224
    center crop of the image, and written next to the input as
    ``<stem>_graphormer_pred.jpg``. Images whose file name contains 'pred'
    are skipped so outputs of a previous run are not re-processed.

    Args:
        args: parsed command-line arguments (not used here; kept for
            interface compatibility).
        image_list: iterable of image file paths.
        Graphormer_model: end-to-end Graphormer hand network; called as
            ``Graphormer_model(batch_imgs, mano, mesh_sampler)``.
        mano: MANO model, used to regress 3D joints from mesh vertices.
        renderer: Renderer used by ``visualize_mesh``.
        mesh_sampler: mesh sampling helper passed through to the network.
    """
    # switch to evaluate mode
    Graphormer_model.eval()
    mano.eval()
    wrist_index = cfg.J_NAME.index('Wrist')

    with torch.no_grad():
        for image_file in image_list:
            if _is_prediction_output(image_file):
                continue
            print(image_file)

            # convert('RGB') so RGBA / grayscale / palette inputs yield the
            # 3 channels expected by Normalize and the backbone; the context
            # manager closes the underlying file handle.
            with Image.open(image_file) as raw_img:
                img = raw_img.convert('RGB')
            batch_imgs = torch.unsqueeze(transform(img), 0).to(device)
            batch_visual_imgs = torch.unsqueeze(transform_visualize(img), 0).to(device)

            # forward-pass
            pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = \
                Graphormer_model(batch_imgs, mano, mesh_sampler)

            # re-center the full mesh on the wrist joint regressed from it
            pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
            pred_3d_wrist = pred_3d_joints_from_mesh[:, wrist_index, :]
            pred_vertices = pred_vertices - pred_3d_wrist[:, None, :]

            # Mesh-only rendering. visualize_mesh_and_attention can be used
            # instead; it additionally needs orthographic projections (with
            # pred_camera) of pred_vertices_sub and of the mesh joints, plus
            # att[-1][0].
            visual_imgs_output = visualize_mesh(renderer,
                                                batch_visual_imgs[0],
                                                pred_vertices[0].detach(),
                                                pred_camera.detach())
            visual_imgs = np.asarray(visual_imgs_output.transpose(1, 2, 0))

            # splitext handles any extension length (e.g. '.jpeg')
            temp_fname = op.splitext(image_file)[0] + _PRED_SUFFIX
            print('save to ', temp_fname)
            # reverse channel order (RGB -> BGR for OpenCV) and scale up;
            # presumably the rendered image is in [0, 1] — TODO confirm
            cv2.imwrite(temp_fname, np.asarray(visual_imgs[:, :, ::-1] * 255))
    return


def visualize_mesh(renderer, images, pred_vertices_full, pred_camera):
    """Render the predicted full mesh over a single image.

    Args:
        renderer: Renderer passed to ``visualize_reconstruction_no_text``.
        images: one channel-first image tensor; it is transposed to
            channel-last before rendering.
        pred_vertices_full: mesh vertices for this image.
        pred_camera: predicted camera parameters.

    Returns:
        The rendered image transposed back to channel-first order.
    """
    img = images.cpu().numpy().transpose(1, 2, 0)
    # Get predict vertices for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    # Visualize only mesh reconstruction
    rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
    rend_img = rend_img.transpose(2, 0, 1)
    return rend_img


def visualize_mesh_and_attention(renderer, images,
                                 pred_vertices_full,
                                 pred_vertices,
                                 pred_2d_vertices,
                                 pred_2d_joints,
                                 pred_camera,
                                 attention):
    """Render the predicted mesh together with the local attention map.

    All tensor arguments are moved to the CPU and converted to numpy before
    being passed to ``visualize_reconstruction_and_att_local``.

    Returns:
        The rendered image transposed back to channel-first order.
    """
    img = images.cpu().numpy().transpose(1, 2, 0)
    # Get predict vertices for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    vertices = pred_vertices.cpu().numpy()
    vertices_2d = pred_2d_vertices.cpu().numpy()
    joints_2d = pred_2d_joints.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    att = attention.cpu().numpy()
    # Visualize reconstruction and attention
    rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
    rend_img = rend_img.transpose(2, 0, 1)
    return rend_img


def parse_args():
    """Parse the command-line options of the hand-mesh inference script."""
    parser = argparse.ArgumentParser()
    #########################################################
    # Data related arguments
    #########################################################
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--img_scale_factor", default=1, type=int,
                        help="adjust image resolution.")
    parser.add_argument("--image_file_or_path", default='./samples/hand', type=str,
                        help="test data")
    #########################################################
    # Loading/saving checkpoints
    #########################################################
    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
                        help="Path to pre-trained transformer model or model type.")
    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                        help="Path to specific checkpoint for resume training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    parser.add_argument('-a', '--arch', default='hrnet-w64',
                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
    #########################################################
    # Model architectures
    #########################################################
    parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
                        help="Update model config if given")
    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                        help="Update model config if given")
    parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
                        help="Update model config if given. Note that the division of "
                        "hidden_size / num_attention_heads should be in integer.")
    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
                        help="The Image Feature Dimension.")
    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
                        help="The Image Feature Dimension.")
    parser.add_argument("--which_gcn", default='0,0,1', type=str,
                        help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
    parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")
    #########################################################
    # Others
    #########################################################
    # NOTE(review): default=True with store_true means this flag is always True.
    parser.add_argument("--run_eval_only", default=True, action='store_true',)
    parser.add_argument("--device", type=str, default='cuda',
                        help="cuda or cpu")
    parser.add_argument('--seed', type=int, default=88,
                        help="random seed for initialization.")

    args = parser.parse_args()
    return args


def main(args):
    """Build the Graphormer hand network, load weights and run inference.

    Either unpickles a whole model from ``--resume_checkpoint`` (eval mode,
    path not containing 'state_dict') or builds the backbone plus three
    transformer encoder blocks and optionally loads a state dict into them.
    Attention outputs are then enabled on the last encoder block, and the
    images found at ``--image_file_or_path`` are processed by
    ``run_inference``.

    Raises:
        ValueError: if ``--image_file_or_path`` is empty or does not exist.
    """
    global logger
    # Setup CUDA, GPU & distributed training
    args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
    print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))

    mkdir(args.output_dir)
    logger = setup_logger("Graphormer", args.output_dir, get_rank())
    set_seed(args.seed, args.num_gpus)
    logger.info("Using {} GPUs".format(args.num_gpus))

    # Mesh and MANO utils
    # NOTE(review): MANO and the network go to args.device, but MANO's layer
    # and the input batches in run_inference go to the module-level `device`;
    # these must resolve to the same device.
    mano_model = MANO().to(args.device)
    mano_model.layer = mano_model.layer.to(device)
    mesh_sampler = Mesh()

    # Renderer for visualization
    renderer = Renderer(faces=mano_model.face)

    # Load pretrained model
    trans_encoder = []

    input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
    hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
    output_feat_dim = input_feat_dim[1:] + [3]

    # which encoder block to have graph convs
    which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

    has_checkpoint = args.resume_checkpoint is not None and args.resume_checkpoint != 'None'

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    if args.run_eval_only and has_checkpoint and 'state_dict' not in args.resume_checkpoint:
        # if only run eval, load checkpoint
        logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
        _model = torch.load(args.resume_checkpoint)
    else:
        # init three transformer-encoder blocks in a loop
        for i in range(len(output_feat_dim)):
            config_class, model_class = BertConfig, Graphormer
            config = config_class.from_pretrained(args.config_name if args.config_name
                                                  else args.model_name_or_path)

            config.output_attentions = False
            config.img_feature_dim = input_feat_dim[i]
            config.output_feature_dim = output_feat_dim[i]
            args.hidden_size = hidden_feat_dim[i]
            args.intermediate_size = int(args.hidden_size * 2)

            if which_blk_graph[i] == 1:
                config.graph_conv = True
                logger.info("Add Graph Conv")
            else:
                config.graph_conv = False

            config.mesh_type = args.mesh_type

            # update model structure if specified in arguments
            update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
            for param in update_params:
                arg_param = getattr(args, param)
                config_param = getattr(config, param)
                if arg_param > 0 and arg_param != config_param:
                    logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
                    setattr(config, param, arg_param)

            # init a transformer encoder and append it to a list
            assert config.hidden_size % config.num_attention_heads == 0
            model = model_class(config=config)
            logger.info("Init model from scratch.")
            trans_encoder.append(model)

        # create backbone model
        if args.arch == 'hrnet':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w40 model')
        elif args.arch == 'hrnet-w64':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w64 model')
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            backbone = models.__dict__[args.arch](pretrained=True)
            # remove the last fc layer
            backbone = torch.nn.Sequential(*list(backbone.children())[:-1])

        trans_encoder = torch.nn.Sequential(*trans_encoder)
        total_params = sum(p.numel() for p in trans_encoder.parameters())
        logger.info('Graphormer encoders total parameters: {}'.format(total_params))
        backbone_total_params = sum(p.numel() for p in backbone.parameters())
        logger.info('Backbone total parameters: {}'.format(backbone_total_params))

        # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
        _model = Graphormer_Network(args, config, backbone, trans_encoder)

        if has_checkpoint:
            # for fine-tuning or resume training or inference, load weights from checkpoint
            logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
            # workaround approach to load sparse tensor in graph conv.
            state_dict = torch.load(args.resume_checkpoint)
            _model.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
            torch.cuda.empty_cache()

    # update configs to enable attention outputs
    last_encoder = _model.trans_encoder[-1]
    last_encoder.config.output_attentions = True
    last_encoder.config.output_hidden_states = True
    last_encoder.bert.encoder.output_attentions = True
    last_encoder.bert.encoder.output_hidden_states = True
    # Every self-attention layer of the last block; a hard-coded range(4)
    # breaks whenever --num_hidden_layers (or the checkpoint) differs from 4.
    for layer in last_encoder.bert.encoder.layer:
        layer.attention.self.output_attentions = True
    # (this used to be repeated three times on the same config)
    last_encoder.config.device = args.device

    _model.to(args.device)
    logger.info("Run inference")

    image_list = []
    if not args.image_file_or_path:
        raise ValueError("image_file_or_path not specified")
    if op.isfile(args.image_file_or_path):
        image_list = [args.image_file_or_path]
    elif op.isdir(args.image_file_or_path):
        # should be a path with images only; sorted for a deterministic order
        for filename in sorted(os.listdir(args.image_file_or_path)):
            # parenthesization matters: the extension test and the 'pred'
            # exclusion must both hold for every extension
            if filename.lower().endswith(_IMAGE_EXTENSIONS) and not _is_prediction_output(filename):
                image_list.append(op.join(args.image_file_or_path, filename))
    else:
        raise ValueError("Cannot find images at {}".format(args.image_file_or_path))

    run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler)


if __name__ == "__main__":
    args = parse_args()
    main(args)