Spaces:
Configuration error
Configuration error
""" | |
Copyright (c) Microsoft Corporation. | |
Licensed under the MIT license. | |
End-to-end inference codes for | |
3D human body mesh reconstruction from an image | |
""" | |
from __future__ import absolute_import, division, print_function | |
import argparse | |
import os | |
import os.path as op | |
import code | |
import json | |
import time | |
import datetime | |
import torch | |
import torchvision.models as models | |
from torchvision.utils import make_grid | |
import gc | |
import numpy as np | |
import cv2 | |
from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer | |
from custom_mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network | |
from custom_mesh_graphormer.modeling._smpl import SMPL, Mesh | |
from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat | |
from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config | |
from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config | |
import custom_mesh_graphormer.modeling.data.config as cfg | |
from custom_mesh_graphormer.datasets.build import make_data_loader | |
from custom_mesh_graphormer.utils.logger import setup_logger | |
from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather | |
from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed | |
from custom_mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger | |
from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text | |
from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error | |
from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection | |
from PIL import Image | |
from torchvision import transforms | |
from comfy.model_management import get_torch_device | |
def _center_crop_steps(size=224):
    """Return fresh transform objects that resize, centre-crop and tensorize.

    The shorter image side is resized to ``size``, the central
    ``size`` x ``size`` square is cropped, and the PIL image is converted
    to a tensor.
    """
    return [
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
    ]


# Default torch device as reported by ComfyUI's model management.
device = get_torch_device()

# Network input: the 224x224 crop normalised with the mean/std values
# commonly used for ImageNet-pretrained backbones.
transform = transforms.Compose(
    _center_crop_steps()
    + [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

# The same crop without normalisation, used as the image the predicted mesh
# is rendered onto.
transform_visualize = transforms.Compose(_center_crop_steps())
def _inference_device(model):
    """Return the device of ``model``'s first parameter, else the module-level ``device``."""
    first_param = next(model.parameters(), None)
    return device if first_param is None else first_param.device


def run_inference(args, image_list, Graphormer_model, smpl, renderer, mesh_sampler):
    """Run Graphormer on each image and save a rendered mesh overlay next to it.

    For every path in ``image_list`` whose *file name* does not contain
    ``'pred'`` (so previously written outputs are skipped), the image is
    preprocessed with ``transform``, passed through ``Graphormer_model``, and
    the predicted full mesh is centred on the pelvis joint regressed from it.
    The mesh rendered over the un-normalised crop is written to
    ``<image path without extension>_graphormer_pred.jpg``.

    Args:
        args: Parsed command-line arguments (not used; kept for interface
            compatibility).
        image_list: Iterable of image file paths.
        Graphormer_model: End-to-end network, called as
            ``Graphormer_model(images, smpl, mesh_sampler)`` and returning a
            7-tuple whose first element is the camera and fifth the vertices.
        smpl: SMPL model providing ``get_h36m_joints``.
        renderer: Renderer passed through to ``visualize_mesh``.
        mesh_sampler: Mesh sampler passed to the network.
    """
    # switch to evaluate mode
    Graphormer_model.eval()
    smpl.eval()
    # Put inputs on the model's own device: main() moves the model to
    # args.device, which need not equal the module-level ``device``.
    run_device = _inference_device(Graphormer_model)
    with torch.no_grad():
        for image_file in image_list:
            # Only the file name is tested, so a parent directory such as
            # "predictions/" does not hide every image.
            if 'pred' in op.basename(image_file):
                continue

            # Force 3 channels (RGBA/greyscale would break the 3-channel
            # Normalize in ``transform``); ``with`` closes the file handle.
            with Image.open(image_file) as raw_img:
                img = raw_img.convert('RGB')
            batch_imgs = torch.unsqueeze(transform(img), 0).to(run_device)
            batch_visual_imgs = torch.unsqueeze(transform_visualize(img), 0).to(run_device)

            # forward-pass
            (pred_camera, _pred_3d_joints, _pred_vertices_sub2, _pred_vertices_sub,
             pred_vertices, _hidden_states, _att) = Graphormer_model(batch_imgs, smpl, mesh_sampler)

            # Centre the full mesh on the pelvis joint regressed from it.
            pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
            pelvis_index = cfg.H36M_J17_NAME.index('Pelvis')
            pred_3d_pelvis = pred_3d_joints_from_smpl[:, pelvis_index, :]
            pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]

            # Render the mesh over the crop of the single image in the batch.
            # visualize_mesh_and_attention() can be used instead to also draw
            # attention; it additionally needs the sub-sampled vertices, their
            # 2-D projections, projected joints and the attention weights.
            visual_imgs_output = visualize_mesh(renderer, batch_visual_imgs[0],
                                                pred_vertices[0].detach(),
                                                pred_camera.detach())
            visual_imgs = np.asarray(visual_imgs_output.transpose(1, 2, 0))

            # RGB -> BGR for OpenCV. The render is scaled by 255, i.e. assumed
            # to lie in [0, 1]; convert to uint8 explicitly rather than
            # relying on imwrite's handling of float arrays.
            bgr_uint8 = np.clip(np.rint(visual_imgs[:, :, ::-1] * 255), 0, 255).astype(np.uint8)

            # splitext handles extensions of any length (e.g. ".jpeg").
            temp_fname = op.splitext(image_file)[0] + '_graphormer_pred.jpg'
            print('save to ', temp_fname)
            if not cv2.imwrite(temp_fname, bgr_uint8):
                print('failed to write ', temp_fname)
def visualize_mesh(renderer, images,
                   pred_vertices_full,
                   pred_camera):
    """Render the predicted full mesh on top of a single input image.

    Args:
        renderer: Renderer forwarded to ``visualize_reconstruction_no_text``.
        images: One image tensor in channel-first layout; it is moved to the
            CPU and transposed to channel-last before rendering.
        pred_vertices_full: Tensor of predicted mesh vertices for this image.
        pred_camera: Tensor of predicted camera parameters.

    Returns:
        The rendered image as a NumPy array, transposed back to channel-first.
    """
    frame = images.cpu().numpy().transpose(1, 2, 0)
    rendered = visualize_reconstruction_no_text(
        frame,
        224,
        pred_vertices_full.cpu().numpy(),
        pred_camera.cpu().numpy(),
        renderer,
        color='light_blue',
    )
    return rendered.transpose(2, 0, 1)
def visualize_mesh_and_attention(renderer, images,
                                 pred_vertices_full,
                                 pred_vertices,
                                 pred_2d_vertices,
                                 pred_2d_joints,
                                 pred_camera,
                                 attention):
    """Render the predicted mesh together with an attention visualisation.

    Every tensor argument is moved to the CPU and converted to a NumPy array.
    The results, together with the channel-last version of ``images``, are
    handed to ``visualize_reconstruction_and_att_local`` at resolution 224.

    Args:
        renderer: Renderer forwarded to the visualisation helper.
        images: One image tensor in channel-first layout.
        pred_vertices_full: Predicted full mesh vertices.
        pred_vertices: Predicted vertices of a second mesh (presumably the
            sub-sampled one; confirm with the caller).
        pred_2d_vertices: 2-D vertex positions (presumably projections of
            ``pred_vertices``).
        pred_2d_joints: 2-D joint positions.
        pred_camera: Predicted camera parameters.
        attention: Attention weights to visualise.

    Returns:
        The rendered visualisation as a NumPy array in channel-first layout.
    """
    frame = images.cpu().numpy().transpose(1, 2, 0)
    # Host copies, converted in the same order as the parameter list.
    vertices_full, vertices, vertices_2d, joints_2d, cam, att = (
        tensor.cpu().numpy()
        for tensor in (pred_vertices_full, pred_vertices, pred_2d_vertices,
                       pred_2d_joints, pred_camera, attention)
    )
    rendered = visualize_reconstruction_and_att_local(
        frame, 224, vertices_full, vertices, vertices_2d, cam,
        renderer, joints_2d, att, color='light_blue')
    return rendered.transpose(2, 0, 1)
def parse_args(argv=None):
    """Parse the command-line options of the body-mesh inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser()
    #########################################################
    # Data related arguments
    #########################################################
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--img_scale_factor", default=1, type=int,
                        help="adjust image resolution.")
    parser.add_argument("--image_file_or_path", default='./samples/human-body', type=str,
                        help="test data")
    #########################################################
    # Loading/saving checkpoints
    #########################################################
    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
                        help="Path to pre-trained transformer model or model type.")
    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                        help="Path to specific checkpoint for resume training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    #########################################################
    # Model architectures
    #########################################################
    parser.add_argument('-a', '--arch', default='hrnet-w64',
                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
    parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
                        help="Update model config if given")
    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                        help="Update model config if given")
    parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
                        help="Update model config if given. Note that the division of "
                        "hidden_size / num_attention_heads should be an integer.")
    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    # Comma-separated, one value per encoder block.
    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
                        help="The Image Feature Dimension.")
    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
                        help="The Hidden Feature Dimension.")
    parser.add_argument("--which_gcn", default='0,0,1', type=str,
                        help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
    parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
    parser.add_argument("--interm_size_scale", default=2, type=int)
    #########################################################
    # Others
    #########################################################
    # NOTE: default=True together with action='store_true' means this flag is
    # always True and cannot be switched off from the command line. It is
    # kept as-is because main() uses it to decide whether to load a whole
    # pickled model checkpoint.
    parser.add_argument("--run_eval_only", default=True, action='store_true',)
    parser.add_argument("--device", type=str, default='cuda',
                        help="cuda or cpu")
    parser.add_argument('--seed', type=int, default=88,
                        help="random seed for initialization.")
    return parser.parse_args(argv)
def _has_checkpoint(path):
    """Return True if ``path`` names a checkpoint, i.e. is neither None nor 'None'."""
    return path is not None and path != 'None'


def _collect_image_files(image_file_or_path):
    """Resolve ``--image_file_or_path`` into a list of image file paths.

    A file path is returned as a one-element list. For a directory, every
    ``.png``/``.jpg`` entry whose name does not contain ``'pred'`` (outputs
    of earlier runs) is returned in sorted order.

    Raises:
        ValueError: if the argument is empty or names nothing on disk.
    """
    if not image_file_or_path:
        raise ValueError("image_file_or_path not specified")
    if op.isfile(image_file_or_path):
        return [image_file_or_path]
    if op.isdir(image_file_or_path):
        # should be a path with images only; the 'pred' filter applies to
        # both extensions.
        return [
            op.join(image_file_or_path, filename)
            for filename in sorted(os.listdir(image_file_or_path))
            if filename.endswith(('.png', '.jpg')) and 'pred' not in filename
        ]
    raise ValueError("Cannot find images at {}".format(image_file_or_path))


def _build_backbone(arch):
    """Create the image backbone named by ``arch``.

    ``'hrnet'`` / ``'hrnet-w64'`` build an HRNet-W40 / HRNet-W64 grid-feature
    network from the YAML config and ImageNet checkpoint under
    ``models/hrnet``. Any other name is looked up in ``torchvision.models``,
    instantiated with pretrained weights, and its last two children removed.
    """
    hrnet_backbones = {
        'hrnet': ('models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml',
                  'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth',
                  '=> loading hrnet-v2-w40 model'),
        'hrnet-w64': ('models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml',
                      'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth',
                      '=> loading hrnet-v2-w64 model'),
    }
    if arch in hrnet_backbones:
        hrnet_yaml, hrnet_checkpoint, message = hrnet_backbones[arch]
        hrnet_update_config(hrnet_config, hrnet_yaml)
        backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
        logger.info(message)
        return backbone
    logger.info("=> using pre-trained model '{}'".format(arch))
    backbone = models.__dict__[arch](pretrained=True)
    # Drop the last two children (for torchvision ResNets: the global average
    # pool and the fc classifier) so the backbone returns spatial features.
    return torch.nn.Sequential(*list(backbone.children())[:-2])


def _build_graphormer(args, mesh_sampler, input_feat_dim, hidden_feat_dim,
                      output_feat_dim, which_blk_graph):
    """Assemble the CNN backbone and stacked Graphormer encoders; load weights.

    One encoder block is created per entry of ``output_feat_dim``. Note that
    ``args.hidden_size`` and ``args.intermediate_size`` are overwritten for
    every block, so afterwards they hold the last block's values. If
    ``args.resume_checkpoint`` names a checkpoint, it is loaded as a state
    dict with ``strict=False``.

    Raises:
        ValueError: if a block's hidden size is not divisible by its number
            of attention heads.
    """
    trans_encoder = []
    config = None
    # init three transformer-encoder blocks in a loop
    for i, output_dim in enumerate(output_feat_dim):
        config = BertConfig.from_pretrained(args.config_name or args.model_name_or_path)
        config.output_attentions = False
        config.img_feature_dim = input_feat_dim[i]
        config.output_feature_dim = output_dim
        args.hidden_size = hidden_feat_dim[i]
        args.intermediate_size = int(args.hidden_size * args.interm_size_scale)
        config.graph_conv = which_blk_graph[i] == 1
        if config.graph_conv:
            logger.info("Add Graph Conv")
        config.mesh_type = args.mesh_type

        # update model structure if specified in arguments
        for param in ('num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'):
            arg_param = getattr(args, param)
            config_param = getattr(config, param)
            if arg_param > 0 and arg_param != config_param:
                logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
                setattr(config, param, arg_param)

        # Explicit check rather than `assert`, which `python -O` strips.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "hidden_size ({}) must be divisible by num_attention_heads ({})".format(
                    config.hidden_size, config.num_attention_heads))
        model = Graphormer(config=config)
        logger.info("Init model from scratch.")
        trans_encoder.append(model)

    # init ImageNet pre-trained backbone model
    backbone = _build_backbone(args.arch)

    trans_encoder = torch.nn.Sequential(*trans_encoder)
    total_params = sum(p.numel() for p in trans_encoder.parameters())
    logger.info('Graphormer encoders total parameters: {}'.format(total_params))
    backbone_total_params = sum(p.numel() for p in backbone.parameters())
    logger.info('Backbone total parameters: {}'.format(backbone_total_params))

    # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
    _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)

    if not _has_checkpoint(args.resume_checkpoint):
        logger.warning("No --resume_checkpoint given: the Graphormer encoders keep "
                       "their from-scratch initialisation.")
        return _model

    # for fine-tuning or resume training or inference, load weights from checkpoint
    logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
    # workaround approach to load sparse tensor in graph conv: keep every
    # tensor on the CPU. map_location also lets GPU-saved checkpoints load on
    # CPU-only hosts.
    states = torch.load(args.resume_checkpoint, map_location='cpu')
    # Mutate in place (not a new dict) so any `_metadata` attribute that
    # load_state_dict consults is preserved.
    for k, v in states.items():
        states[k] = v.cpu()
    load_result = _model.load_state_dict(states, strict=False)
    missing = getattr(load_result, 'missing_keys', [])
    unexpected = getattr(load_result, 'unexpected_keys', [])
    if missing or unexpected:
        logger.info("Checkpoint loaded with strict=False: {} missing, {} unexpected keys".format(
            len(missing), len(unexpected)))
    del states
    gc.collect()
    torch.cuda.empty_cache()
    return _model


def _enable_attention_outputs(model, device):
    """Make the last encoder block return attentions and hidden states.

    Also stores ``device`` on that block's config.
    """
    last_block = model.trans_encoder[-1]
    last_block.config.output_attentions = True
    last_block.config.output_hidden_states = True
    last_block.bert.encoder.output_attentions = True
    last_block.bert.encoder.output_hidden_states = True
    # Iterate the actual layers instead of assuming four, so any
    # --num_hidden_layers value works.
    for layer in last_block.bert.encoder.layer:
        layer.attention.self.output_attentions = True
    last_block.config.device = device


def main(args):
    """Run end-to-end body-mesh inference as configured by ``args``.

    Sets up the environment, logger and seed, and collects the input images.
    It then creates the SMPL model, mesh sampler and renderer. The model is
    obtained in one of two ways:

    * In eval-only mode, with a checkpoint whose path does not contain
      ``'state_dict'``, a whole pickled model is loaded.
    * Otherwise the network is built and a state dict is loaded.

    Finally it enables attention outputs and calls ``run_inference``.

    Raises:
        ValueError: if no input images can be located, or the encoder
            configuration is invalid.
    """
    global logger
    # Setup CUDA, GPU & distributed training
    args.num_gpus = int(os.environ.get('WORLD_SIZE', 1))
    os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
    print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
    args.distributed = args.num_gpus > 1
    args.device = torch.device(args.device)

    mkdir(args.output_dir)
    logger = setup_logger("Graphormer", args.output_dir, get_rank())
    set_seed(args.seed, args.num_gpus)
    logger.info("Using {} GPUs".format(args.num_gpus))

    # Validate the input location before the expensive model construction.
    image_list = _collect_image_files(args.image_file_or_path)

    # Mesh and SMPL utils
    smpl = SMPL().to(args.device)
    mesh_sampler = Mesh()
    # Renderer for visualization
    renderer = Renderer(faces=smpl.faces.cpu().numpy())

    # Load model
    input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
    hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
    output_feat_dim = input_feat_dim[1:] + [3]
    # which encoder block to have graph convs
    which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

    if (args.run_eval_only and _has_checkpoint(args.resume_checkpoint)
            and 'state_dict' not in args.resume_checkpoint):
        # if only run eval, load checkpoint
        logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
        # NOTE(security): this unpickles a whole model object, which can run
        # arbitrary code, so only load trusted files. PyTorch releases that
        # default to weights_only=True may reject such checkpoints; confirm
        # against the installed torch version.
        _model = torch.load(args.resume_checkpoint)
    else:
        _model = _build_graphormer(args, mesh_sampler, input_feat_dim, hidden_feat_dim,
                                   output_feat_dim, which_blk_graph)

    # update configs to enable attention outputs
    _enable_attention_outputs(_model, args.device)
    _model.to(args.device)
    logger.info("Run inference")

    run_inference(args, image_list, _model, smpl, renderer, mesh_sampler)
if __name__ == "__main__":
    # Script entry point: parse the command line, then build or load the
    # model and run inference on the requested images.
    main(parse_args())