# NOTE: hosting-page residue, not part of the program: uploader "JasonSmithSO",
# commit message "Upload 777 files", revision 0034848 (verified).
from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import zipfile
import torch
import numpy as np
from custom_mesh_graphormer.utils.metric_pampjpe import get_alignMesh
def load_pred_json(filepath):
    """Read the ``pred.json`` member of a prediction archive.

    Args:
        filepath: Path to a zip archive that contains a ``pred.json`` member
            holding a UTF-8 encoded JSON array with at least two elements.

    Returns:
        A 2-tuple ``(first, second)`` with the first two elements of the
        top-level JSON array. Callers in this file treat them as the joint
        and vertex predictions respectively.

    Raises:
        KeyError: If the archive has no ``pred.json`` member.
        zipfile.BadZipFile: If ``filepath`` is not a valid zip archive.
    """
    # The context manager guarantees the archive handle is released even when
    # reading the member fails.
    with zipfile.ZipFile(filepath, 'r') as archive:
        jsondata = archive.read('pred.json')
    reference = json.loads(jsondata.decode("utf-8"))
    return reference[0], reference[1]
def multiscale_fusion(output_dir):
    """Average the per-scale / per-rotation predictions into one prediction.

    The prediction stored in ``ckpt200-sc10_rot0-pred.zip`` is used as the
    reference. Every ``ckpt200-sc<SS>_rot<R>-pred.zip`` archive (5 scales x 19
    rotations) is loaded, passed through ``get_alignMesh`` together with the
    reference, and the returned arrays are averaged together with the
    reference itself. The average is stored as ``pred.json`` inside
    ``ckpt200-multisc-pred.zip``.

    Args:
        output_dir: Prefix that is prepended verbatim to every archive file
            name, so a directory must be given with its trailing separator
            (e.g. ``'output/'``).

    Returns:
        None. Progress and error figures are printed; the fused prediction is
        written to disk.
    """
    ref_joints, ref_vertices = load_pred_json(output_dir + 'ckpt200-sc10_rot0-pred.zip')
    ref_joints_array = np.asarray(ref_joints)
    ref_vertices_array = np.asarray(ref_vertices)

    # 0, then +/-10 ... +/-90; the order only affects the order of the log output
    # and of the floating-point accumulation below.
    rotations = [0.0]
    for i in range(1, 10):
        rotations.append(i * 10)
        rotations.append(i * -10)
    scale = [0.7, 0.8, 0.9, 1.0, 1.1]

    # Running sums start from the reference, so the reference contributes one
    # extra sample to the average (hence the ``1 + counter`` divisor below).
    overall_joints_array = ref_joints_array.copy()
    overall_vertices_array = ref_vertices_array.copy()
    counter = 0
    for s in scale:
        for r in rotations:
            setting = 'sc%02d_rot%s' % (int(s * 10), str(int(r)))
            filepath = output_dir + 'ckpt200-' + setting + '-pred.zip'
            joints, vertices = load_pred_json(filepath)
            joints_array = np.asarray(joints)
            vertices_array = np.asarray(vertices)

            # NOTE(review): get_alignMesh is assumed to return
            # (per-sample error, prediction aligned to the reference, extra);
            # confirm against custom_mesh_graphormer.utils.metric_pampjpe.
            pa_joint_error, pa_joint_array, _ = get_alignMesh(
                joints_array, ref_joints_array, reduction=None)
            pa_vertices_error, pa_vertices_array, _ = get_alignMesh(
                vertices_array, ref_vertices_array, reduction=None)
            print('--------------------------')
            print('scale:', s, 'rotate', r)
            print('PAMPJPE:', 1000*np.mean(pa_joint_error))
            print('PAMPVPE:', 1000*np.mean(pa_vertices_error))

            # Accumulate immediately instead of keeping all 95 aligned arrays
            # alive; the summation order is the same as summing them afterwards.
            overall_joints_array += pa_joint_array
            overall_vertices_array += pa_vertices_array
            counter += 1

    overall_joints_array /= (1 + counter)
    overall_vertices_array /= (1 + counter)

    pa_joint_error, pa_joint_array, _ = get_alignMesh(
        overall_joints_array, ref_joints_array, reduction=None)
    pa_vertices_error, pa_vertices_array, _ = get_alignMesh(
        overall_vertices_array, ref_vertices_array, reduction=None)
    print('--------------------------')
    print('overall:')
    print('PAMPJPE:', 1000*np.mean(pa_joint_error))
    print('PAMPVPE:', 1000*np.mean(pa_vertices_error))

    joint_output_save = overall_joints_array.tolist()
    mesh_output_save = overall_vertices_array.tolist()

    # Write the archive member directly: no temporary pred.json in the current
    # working directory, no dependency on external ``zip``/``rm`` binaries, and
    # no shell parsing of a path that may contain spaces or metacharacters.
    filepath = output_dir + 'ckpt200-multisc-pred.zip'
    print('save results to pred.json in', filepath)
    with zipfile.ZipFile(filepath, 'w', zipfile.ZIP_DEFLATED) as archive:
        archive.writestr('pred.json', json.dumps([joint_output_save, mesh_output_save]))
def run_multiscale_inference(model_path, mode, output_dir):
    """Launch the hand-mesh evaluation script once per (scale, rotation) pair.

    Each run is started with ``os.system``; the command line is printed first
    and the exit status is not inspected.

    Args:
        model_path: Value passed to ``--resume_checkpoint``.
        mode: When equal to ``True``, sweep 5 scales (0.7 .. 1.1) and 19
            rotations (0, +/-10 .. +/-90); otherwise run a single pass with
            scale 1.0 and rotation 0.0.
        output_dir: Value passed to ``--output_dir``.

    Returns:
        None.
    """
    # Equality (not truthiness) is intentional: only values equal to True
    # enable the sweep.
    if mode == True:  # noqa: E712
        angles = [0.0]
        for step in range(1, 10):
            angles.extend((step * 10, step * -10))
        scales = [0.7, 0.8, 0.9, 1.0, 1.1]
    else:
        angles, scales = [0.0], [1.0]

    # Placeholders, in order: checkpoint path, rotation (%f), scale (%s),
    # output directory. --multiscale_inference is passed in both modes.
    command_parts = (
        "python ./src/tools/run_gphmer_handmesh.py",
        "--val_yaml freihand_v3/test.yaml",
        "--resume_checkpoint %s",
        "--per_gpu_eval_batch_size 32 --run_eval_only --num_worker 2",
        "--multiscale_inference",
        "--rot %f",
        "--sc %s",
        "--arch hrnet-w64",
        "--num_hidden_layers 4",
        "--num_attention_heads 4",
        "--input_feat_dim 2051,512,128",
        "--hidden_feat_dim 1024,256,64",
        "--output_dir %s",
    )
    template = " ".join(command_parts)

    # Scales form the outer loop, rotations the inner one.
    sweep = [(factor, angle) for factor in scales for angle in angles]
    for factor, angle in sweep:
        command = template % (model_path, angle, factor, output_dir)
        print(command)
        os.system(command)
def main(args):
    """Run inference and, in multi-scale mode, fuse the resulting predictions.

    Args:
        args: Object exposing ``model_path``, ``multiscale_inference`` and
            ``output_dir`` attributes (the parsed command-line namespace).
    """
    checkpoint, multiscale, out_prefix = (
        args.model_path,
        args.multiscale_inference,
        args.output_dir,
    )
    run_multiscale_inference(checkpoint, multiscale, out_prefix)

    # Fusion only makes sense when the full scale/rotation sweep was produced.
    if multiscale != True:  # noqa: E712 - mirrors the equality test used for the sweep
        return
    multiscale_fusion(out_prefix)
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    cli = argparse.ArgumentParser(
        description="Evaluate a checkpoint in the folder",
    )
    # Checkpoint forwarded to the evaluation script; no default is defined.
    cli.add_argument("--model_path")
    # Flag: enables the scale/rotation sweep followed by fusion.
    cli.add_argument(
        "--multiscale_inference",
        action='store_true',
        default=False,
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default='output/',
        required=False,
        help="The output directory to save checkpoint and test results.",
    )
    args = cli.parse_args()
    main(args)