"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
End-to-end inference code for
3D human body mesh reconstruction from an image
"""
from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import torch
import torchvision.models as models
import gc
import numpy as np
import cv2
from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer
from custom_mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network
from custom_mesh_graphormer.modeling._smpl import SMPL, Mesh
from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import custom_mesh_graphormer.modeling.data.config as cfg
from custom_mesh_graphormer.utils.logger import setup_logger
from custom_mesh_graphormer.utils.comm import get_rank
from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection
from PIL import Image
from torchvision import transforms
from comfy.model_management import get_torch_device
device = get_torch_device()
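# Preprocessing for the network input: resize + center-crop to 224x224, then
# normalize with ImageNet mean/std (matching the backbone's pretraining data).
# transform_visualize keeps unnormalized pixels for rendering the overlay.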
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
transform_visualize = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor()])
def run_inference(args, image_list, Graphormer_model, smpl, renderer, mesh_sampler):
# switch to evaluate mode
Graphormer_model.eval()
smpl.eval()
with torch.no_grad():
for image_file in image_list:
if 'pred' not in image_file:
att_all = []
                img = Image.open(image_file).convert('RGB')  # ensure 3 channels for Normalize below
img_tensor = transform(img)
img_visual = transform_visualize(img)
batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)
# forward-pass
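                # The network returns a weak-perspective camera, 3D joints, and
                # vertices at three mesh resolutions (coarse -> mid -> full SMPL
                # mesh), plus transformer hidden states and attention maps.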
pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, smpl, mesh_sampler)
# obtain 3d joints from full mesh
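                # (a Human3.6M 17-joint regressor applied to the full mesh vertices)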
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
pred_3d_pelvis = pred_3d_joints_from_smpl[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
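                # center joints and vertices on the pelvis so the output is translation-invariant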
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :]
pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
                # save attention from the last encoder block
att_max_value = att[-1]
att_cpu = np.asarray(att_max_value.cpu().detach())
att_all.append(att_cpu)
# obtain 3d joints, which are regressed from the full mesh
pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices)
pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:]
# obtain 2d joints, which are projected from 3d joints of smpl mesh
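                # (weak-perspective projection: the XY coordinates are scaled by the
                # predicted camera scale and shifted by its 2D translation)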
pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera)
pred_2d_431_vertices_from_smpl = orthographic_projection(pred_vertices_sub2, pred_camera)
visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0],
pred_vertices[0].detach(),
pred_camera.detach())
# visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0],
# pred_vertices[0].detach(),
# pred_vertices_sub2[0].detach(),
# pred_2d_431_vertices_from_smpl[0].detach(),
# pred_2d_joints_from_smpl[0].detach(),
# pred_camera.detach(),
# att[-1][0].detach())
                visual_imgs = np.asarray(visual_imgs_output).transpose(1, 2, 0)
                temp_fname = op.splitext(image_file)[0] + '_graphormer_pred.jpg'
print('save to ', temp_fname)
cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
return
def visualize_mesh( renderer, images,
pred_vertices_full,
pred_camera):
img = images.cpu().numpy().transpose(1,2,0)
    # Get predicted vertices for the particular example
vertices_full = pred_vertices_full.cpu().numpy()
cam = pred_camera.cpu().numpy()
# Visualize only mesh reconstruction
rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
rend_img = rend_img.transpose(2,0,1)
return rend_img
def visualize_mesh_and_attention( renderer, images,
pred_vertices_full,
pred_vertices,
pred_2d_vertices,
pred_2d_joints,
pred_camera,
attention):
img = images.cpu().numpy().transpose(1,2,0)
    # Get predicted vertices for the particular example
vertices_full = pred_vertices_full.cpu().numpy()
vertices = pred_vertices.cpu().numpy()
vertices_2d = pred_2d_vertices.cpu().numpy()
joints_2d = pred_2d_joints.cpu().numpy()
cam = pred_camera.cpu().numpy()
att = attention.cpu().numpy()
# Visualize reconstruction and attention
rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
rend_img = rend_img.transpose(2,0,1)
return rend_img
def parse_args():
parser = argparse.ArgumentParser()
#########################################################
# Data related arguments
#########################################################
parser.add_argument("--num_workers", default=4, type=int,
help="Workers in dataloader.")
parser.add_argument("--img_scale_factor", default=1, type=int,
help="adjust image resolution.")
parser.add_argument("--image_file_or_path", default='./samples/human-body', type=str,
help="test data")
#########################################################
# Loading/saving checkpoints
#########################################################
parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
help="Path to pre-trained transformer model or model type.")
parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
help="Path to specific checkpoint for resume training.")
parser.add_argument("--output_dir", default='output/', type=str, required=False,
help="The output directory to save checkpoint and test results.")
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name.")
#########################################################
# Model architectures
#########################################################
parser.add_argument('-a', '--arch', default='hrnet-w64',
help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
help="Update model config if given")
parser.add_argument("--hidden_size", default=-1, type=int, required=False,
help="Update model config if given")
parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
help="Update model config if given. Note that the division of "
"hidden_size / num_attention_heads should be in integer.")
parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
help="Update model config if given.")
parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
help="The Image Feature Dimension.")
parser.add_argument("--which_gcn", default='0,0,1', type=str,
help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
parser.add_argument("--mesh_type", default='body', type=str, help="body or hand")
parser.add_argument("--interm_size_scale", default=2, type=int)
#########################################################
# Others
#########################################################
parser.add_argument("--run_eval_only", default=True, action='store_true',)
parser.add_argument("--device", type=str, default='cuda',
help="cuda or cpu")
parser.add_argument('--seed', type=int, default=88,
help="random seed for initialization.")
args = parser.parse_args()
return args
def main(args):
global logger
# Setup CUDA, GPU & distributed training
args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
args.distributed = args.num_gpus > 1
args.device = torch.device(args.device)
mkdir(args.output_dir)
logger = setup_logger("Graphormer", args.output_dir, get_rank())
set_seed(args.seed, args.num_gpus)
logger.info("Using {} GPUs".format(args.num_gpus))
# Mesh and SMPL utils
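    # (SMPL supplies the template mesh, faces, and joint regressors; Mesh holds
    # the sampling matrices used to move between coarse and full mesh resolutions)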
smpl = SMPL().to(args.device)
mesh_sampler = Mesh()
# Renderer for visualization
renderer = Renderer(faces=smpl.faces.cpu().numpy())
# Load model
trans_encoder = []
input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
output_feat_dim = input_feat_dim[1:] + [3]
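    # each encoder block maps its input feature dim to the next block's input
    # dim (2051 -> 512 -> 128); the last block outputs 3D coordinates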
# which encoder block to have graph convs
which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
    if (args.run_eval_only and args.resume_checkpoint is not None
            and args.resume_checkpoint != 'None'
            and 'state_dict' not in args.resume_checkpoint):
# if only run eval, load checkpoint
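        # (checkpoints without 'state_dict' in the filename are assumed to be
        # fully serialized models, loadable directly with torch.load)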
logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
_model = torch.load(args.resume_checkpoint)
else:
# init three transformer-encoder blocks in a loop
for i in range(len(output_feat_dim)):
config_class, model_class = BertConfig, Graphormer
config = config_class.from_pretrained(args.config_name if args.config_name \
else args.model_name_or_path)
config.output_attentions = False
config.img_feature_dim = input_feat_dim[i]
config.output_feature_dim = output_feat_dim[i]
args.hidden_size = hidden_feat_dim[i]
args.intermediate_size = int(args.hidden_size*args.interm_size_scale)
if which_blk_graph[i]==1:
config.graph_conv = True
logger.info("Add Graph Conv")
else:
config.graph_conv = False
config.mesh_type = args.mesh_type
# update model structure if specified in arguments
update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
for idx, param in enumerate(update_params):
arg_param = getattr(args, param)
config_param = getattr(config, param)
if arg_param > 0 and arg_param != config_param:
logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
setattr(config, param, arg_param)
# init a transformer encoder and append it to a list
assert config.hidden_size % config.num_attention_heads == 0
model = model_class(config=config)
logger.info("Init model from scratch.")
trans_encoder.append(model)
# init ImageNet pre-trained backbone model
if args.arch=='hrnet':
hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w40 model')
elif args.arch=='hrnet-w64':
hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
hrnet_update_config(hrnet_config, hrnet_yaml)
backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
logger.info('=> loading hrnet-v2-w64 model')
else:
print("=> using pre-trained model '{}'".format(args.arch))
backbone = models.__dict__[args.arch](pretrained=True)
        # drop the final avgpool and fc layers so the backbone outputs spatial grid features
        backbone = torch.nn.Sequential(*list(backbone.children())[:-2])
trans_encoder = torch.nn.Sequential(*trans_encoder)
total_params = sum(p.numel() for p in trans_encoder.parameters())
logger.info('Graphormer encoders total parameters: {}'.format(total_params))
backbone_total_params = sum(p.numel() for p in backbone.parameters())
logger.info('Backbone total parameters: {}'.format(backbone_total_params))
# build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder)
_model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler)
    if args.resume_checkpoint is not None and args.resume_checkpoint != 'None':
# for fine-tuning or resume training or inference, load weights from checkpoint
logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
# workaround approach to load sparse tensor in graph conv.
states = torch.load(args.resume_checkpoint)
# states = checkpoint_loaded.state_dict()
for k, v in states.items():
states[k] = v.cpu()
# del checkpoint_loaded
_model.load_state_dict(states, strict=False)
del states
gc.collect()
torch.cuda.empty_cache()
# update configs to enable attention outputs
setattr(_model.trans_encoder[-1].config,'output_attentions', True)
setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
_model.trans_encoder[-1].bert.encoder.output_attentions = True
_model.trans_encoder[-1].bert.encoder.output_hidden_states = True
    # enable attention outputs on every self-attention layer of the last block
    for iter_layer in range(len(_model.trans_encoder[-1].bert.encoder.layer)):
        _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
    setattr(_model.trans_encoder[-1].config, 'device', args.device)
_model.to(args.device)
logger.info("Run inference")
image_list = []
if not args.image_file_or_path:
raise ValueError("image_file_or_path not specified")
if op.isfile(args.image_file_or_path):
image_list = [args.image_file_or_path]
elif op.isdir(args.image_file_or_path):
# should be a path with images only
        for filename in os.listdir(args.image_file_or_path):
            # accept .png/.jpg images, skipping outputs from a previous run;
            # the tuple form keeps the suffix test and the 'pred' filter together
            if filename.endswith((".png", ".jpg")) and 'pred' not in filename:
                image_list.append(op.join(args.image_file_or_path, filename))
else:
raise ValueError("Cannot find images at {}".format(args.image_file_or_path))
run_inference(args, image_list, _model, smpl, renderer, mesh_sampler)
if __name__ == "__main__":
args = parse_args()
main(args)
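
# Example invocation (paths are illustrative; point --resume_checkpoint at a
# released Graphormer body-mesh checkpoint on your machine):
#   python run_gphmer_bodymesh_inference.py \
#       --resume_checkpoint ./models/graphormer_release/graphormer_3dpw_state_dict.bin \
#       --image_file_or_path ./samples/human-body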