# Copyright (c) OpenMMLab. All rights reserved. import logging import mimetypes import os import time from argparse import ArgumentParser from functools import partial import cv2 import json_tricks as json import mmcv import mmengine import numpy as np from mmengine.logging import print_log from mmpose.apis import (_track_by_iou, _track_by_oks, convert_keypoint_definition, extract_pose_sequence, inference_pose_lifter_model, inference_topdown, init_model) from mmpose.models.pose_estimators import PoseLifter from mmpose.models.pose_estimators.topdown import TopdownPoseEstimator from mmpose.registry import VISUALIZERS from mmpose.structures import (PoseDataSample, merge_data_samples, split_instances) from mmpose.utils import adapt_mmdet_pipeline try: from mmdet.apis import inference_detector, init_detector has_mmdet = True except (ImportError, ModuleNotFoundError): has_mmdet = False def parse_args(): parser = ArgumentParser() parser.add_argument('det_config', help='Config file for detection') parser.add_argument('det_checkpoint', help='Checkpoint file for detection') parser.add_argument( 'pose_estimator_config', type=str, default=None, help='Config file for the 1st stage 2D pose estimator') parser.add_argument( 'pose_estimator_checkpoint', type=str, default=None, help='Checkpoint file for the 1st stage 2D pose estimator') parser.add_argument( 'pose_lifter_config', help='Config file for the 2nd stage pose lifter model') parser.add_argument( 'pose_lifter_checkpoint', help='Checkpoint file for the 2nd stage pose lifter model') parser.add_argument('--input', type=str, default='', help='Video path') parser.add_argument( '--show', action='store_true', default=False, help='Whether to show visualizations') parser.add_argument( '--disable-rebase-keypoint', action='store_true', default=False, help='Whether to disable rebasing the predicted 3D pose so its ' 'lowest keypoint has a height of 0 (landing on the ground). Rebase ' 'is useful for visualization when the model do not predict the ' 'global position of the 3D pose.') parser.add_argument( '--disable-norm-pose-2d', action='store_true', default=False, help='Whether to scale the bbox (along with the 2D pose) to the ' 'average bbox scale of the dataset, and move the bbox (along with the ' '2D pose) to the average bbox center of the dataset. This is useful ' 'when bbox is small, especially in multi-person scenarios.') parser.add_argument( '--num-instances', type=int, default=1, help='The number of 3D poses to be visualized in every frame. If ' 'less than 0, it will be set to the number of pose results in the ' 'first frame.') parser.add_argument( '--output-root', type=str, default='', help='Root of the output video file. ' 'Default not saving the visualization video.') parser.add_argument( '--save-predictions', action='store_true', default=False, help='Whether to save predicted results') parser.add_argument( '--device', default='cuda:0', help='Device used for inference') parser.add_argument( '--det-cat-id', type=int, default=0, help='Category id for bounding box detection model') parser.add_argument( '--bbox-thr', type=float, default=0.3, help='Bounding box score threshold') parser.add_argument('--kpt-thr', type=float, default=0.3) parser.add_argument( '--use-oks-tracking', action='store_true', help='Using OKS tracking') parser.add_argument( '--tracking-thr', type=float, default=0.3, help='Tracking threshold') parser.add_argument( '--show-interval', type=int, default=0, help='Sleep seconds per frame') parser.add_argument( '--thickness', type=int, default=1, help='Link thickness for visualization') parser.add_argument( '--radius', type=int, default=3, help='Keypoint radius for visualization') parser.add_argument( '--online', action='store_true', default=False, help='Inference mode. If set to True, can not use future frame' 'information when using multi frames for inference in the 2D pose' 'detection stage. Default: False.') args = parser.parse_args() return args def process_one_image(args, detector, frame, frame_idx, pose_estimator, pose_est_results_last, pose_est_results_list, next_id, pose_lifter, visualize_frame, visualizer): """Visualize detected and predicted keypoints of one image. Pipeline of this function: frame | V +-----------------+ | detector | +-----------------+ | det_result V +-----------------+ | pose_estimator | +-----------------+ | pose_est_results V +--------------------------------------------+ | convert 2d kpts into pose-lifting format | +--------------------------------------------+ | pose_est_results_list V +-----------------------+ | extract_pose_sequence | +-----------------------+ | pose_seq_2d V +-------------+ | pose_lifter | +-------------+ | pose_lift_results V +-----------------+ | post-processing | +-----------------+ | pred_3d_data_samples V +------------+ | visualizer | +------------+ Args: args (Argument): Custom command-line arguments. detector (mmdet.BaseDetector): The mmdet detector. frame (np.ndarray): The image frame read from input image or video. frame_idx (int): The index of current frame. pose_estimator (TopdownPoseEstimator): The pose estimator for 2d pose. pose_est_results_last (list(PoseDataSample)): The results of pose estimation from the last frame for tracking instances. pose_est_results_list (list(list(PoseDataSample))): The list of all pose estimation results converted by ``convert_keypoint_definition`` from previous frames. In pose-lifting stage it is used to obtain the 2d estimation sequence. next_id (int): The next track id to be used. pose_lifter (PoseLifter): The pose-lifter for estimating 3d pose. visualize_frame (np.ndarray): The image for drawing the results on. visualizer (Visualizer): The visualizer for visualizing the 2d and 3d pose estimation results. Returns: pose_est_results (list(PoseDataSample)): The pose estimation result of the current frame. pose_est_results_list (list(list(PoseDataSample))): The list of all converted pose estimation results until the current frame. pred_3d_instances (InstanceData): The result of pose-lifting. Specifically, the predicted keypoints and scores are saved at ``pred_3d_instances.keypoints`` and ``pred_3d_instances.keypoint_scores``. next_id (int): The next track id to be used. """ pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset pose_lift_dataset_name = pose_lifter.dataset_meta['dataset_name'] # First stage: conduct 2D pose detection in a Topdown manner # use detector to obtain person bounding boxes det_result = inference_detector(detector, frame) pred_instance = det_result.pred_instances.cpu().numpy() # filter out the person instances with category and bbox threshold # e.g. 0 for person in COCO bboxes = pred_instance.bboxes bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id, pred_instance.scores > args.bbox_thr)] # estimate pose results for current image pose_est_results = inference_topdown(pose_estimator, frame, bboxes) if args.use_oks_tracking: _track = partial(_track_by_oks) else: _track = _track_by_iou pose_det_dataset_name = pose_estimator.dataset_meta['dataset_name'] pose_est_results_converted = [] # convert 2d pose estimation results into the format for pose-lifting # such as changing the keypoint order, flipping the keypoint, etc. for i, data_sample in enumerate(pose_est_results): pred_instances = data_sample.pred_instances.cpu().numpy() keypoints = pred_instances.keypoints # calculate area and bbox if 'bboxes' in pred_instances: areas = np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in pred_instances.bboxes]) pose_est_results[i].pred_instances.set_field(areas, 'areas') else: areas, bboxes = [], [] for keypoint in keypoints: xmin = np.min(keypoint[:, 0][keypoint[:, 0] > 0], initial=1e10) xmax = np.max(keypoint[:, 0]) ymin = np.min(keypoint[:, 1][keypoint[:, 1] > 0], initial=1e10) ymax = np.max(keypoint[:, 1]) areas.append((xmax - xmin) * (ymax - ymin)) bboxes.append([xmin, ymin, xmax, ymax]) pose_est_results[i].pred_instances.areas = np.array(areas) pose_est_results[i].pred_instances.bboxes = np.array(bboxes) # track id track_id, pose_est_results_last, _ = _track(data_sample, pose_est_results_last, args.tracking_thr) if track_id == -1: if np.count_nonzero(keypoints[:, :, 1]) >= 3: track_id = next_id next_id += 1 else: # If the number of keypoints detected is small, # delete that person instance. keypoints[:, :, 1] = -10 pose_est_results[i].pred_instances.set_field( keypoints, 'keypoints') pose_est_results[i].pred_instances.set_field( pred_instances.bboxes * 0, 'bboxes') pose_est_results[i].set_field(pred_instances, 'pred_instances') track_id = -1 pose_est_results[i].set_field(track_id, 'track_id') # convert keypoints for pose-lifting pose_est_result_converted = PoseDataSample() pose_est_result_converted.set_field( pose_est_results[i].pred_instances.clone(), 'pred_instances') pose_est_result_converted.set_field( pose_est_results[i].gt_instances.clone(), 'gt_instances') keypoints = convert_keypoint_definition(keypoints, pose_det_dataset_name, pose_lift_dataset_name) pose_est_result_converted.pred_instances.set_field( keypoints, 'keypoints') pose_est_result_converted.set_field(pose_est_results[i].track_id, 'track_id') pose_est_results_converted.append(pose_est_result_converted) pose_est_results_list.append(pose_est_results_converted.copy()) # Second stage: Pose lifting # extract and pad input pose2d sequence pose_seq_2d = extract_pose_sequence( pose_est_results_list, frame_idx=frame_idx, causal=pose_lift_dataset.get('causal', False), seq_len=pose_lift_dataset.get('seq_len', 1), step=pose_lift_dataset.get('seq_step', 1)) # conduct 2D-to-3D pose lifting norm_pose_2d = not args.disable_norm_pose_2d pose_lift_results = inference_pose_lifter_model( pose_lifter, pose_seq_2d, image_size=visualize_frame.shape[:2], norm_pose_2d=norm_pose_2d) # post-processing for idx, pose_lift_result in enumerate(pose_lift_results): pose_lift_result.track_id = pose_est_results[idx].get('track_id', 1e4) pred_instances = pose_lift_result.pred_instances keypoints = pred_instances.keypoints keypoint_scores = pred_instances.keypoint_scores if keypoint_scores.ndim == 3: keypoint_scores = np.squeeze(keypoint_scores, axis=1) pose_lift_results[ idx].pred_instances.keypoint_scores = keypoint_scores if keypoints.ndim == 4: keypoints = np.squeeze(keypoints, axis=1) keypoints = keypoints[..., [0, 2, 1]] keypoints[..., 0] = -keypoints[..., 0] keypoints[..., 2] = -keypoints[..., 2] # rebase height (z-axis) if not args.disable_rebase_keypoint: keypoints[..., 2] -= np.min( keypoints[..., 2], axis=-1, keepdims=True) pose_lift_results[idx].pred_instances.keypoints = keypoints pose_lift_results = sorted( pose_lift_results, key=lambda x: x.get('track_id', 1e4)) pred_3d_data_samples = merge_data_samples(pose_lift_results) det_data_sample = merge_data_samples(pose_est_results) pred_3d_instances = pred_3d_data_samples.get('pred_instances', None) if args.num_instances < 0: args.num_instances = len(pose_lift_results) # Visualization if visualizer is not None: visualizer.add_datasample( 'result', visualize_frame, data_sample=pred_3d_data_samples, det_data_sample=det_data_sample, draw_gt=False, dataset_2d=pose_det_dataset_name, dataset_3d=pose_lift_dataset_name, show=args.show, draw_bbox=True, kpt_thr=args.kpt_thr, num_instances=args.num_instances, wait_time=args.show_interval) return pose_est_results, pose_est_results_list, pred_3d_instances, next_id def main(): assert has_mmdet, 'Please install mmdet to run the demo.' args = parse_args() assert args.show or (args.output_root != '') assert args.input != '' assert args.det_config is not None assert args.det_checkpoint is not None detector = init_detector( args.det_config, args.det_checkpoint, device=args.device.lower()) detector.cfg = adapt_mmdet_pipeline(detector.cfg) pose_estimator = init_model( args.pose_estimator_config, args.pose_estimator_checkpoint, device=args.device.lower()) assert isinstance(pose_estimator, TopdownPoseEstimator), 'Only "TopDown"' \ 'model is supported for the 1st stage (2D pose detection)' det_kpt_color = pose_estimator.dataset_meta.get('keypoint_colors', None) det_dataset_skeleton = pose_estimator.dataset_meta.get( 'skeleton_links', None) det_dataset_link_color = pose_estimator.dataset_meta.get( 'skeleton_link_colors', None) pose_lifter = init_model( args.pose_lifter_config, args.pose_lifter_checkpoint, device=args.device.lower()) assert isinstance(pose_lifter, PoseLifter), \ 'Only "PoseLifter" model is supported for the 2nd stage ' \ '(2D-to-3D lifting)' pose_lifter.cfg.visualizer.radius = args.radius pose_lifter.cfg.visualizer.line_width = args.thickness pose_lifter.cfg.visualizer.det_kpt_color = det_kpt_color pose_lifter.cfg.visualizer.det_dataset_skeleton = det_dataset_skeleton pose_lifter.cfg.visualizer.det_dataset_link_color = det_dataset_link_color visualizer = VISUALIZERS.build(pose_lifter.cfg.visualizer) # the dataset_meta is loaded from the checkpoint visualizer.set_dataset_meta(pose_lifter.dataset_meta) if args.input == 'webcam': input_type = 'webcam' else: input_type = mimetypes.guess_type(args.input)[0].split('/')[0] if args.output_root == '': save_output = False else: mmengine.mkdir_or_exist(args.output_root) output_file = os.path.join(args.output_root, os.path.basename(args.input)) if args.input == 'webcam': output_file += '.mp4' save_output = True if args.save_predictions: assert args.output_root != '' args.pred_save_path = f'{args.output_root}/results_' \ f'{os.path.splitext(os.path.basename(args.input))[0]}.json' if save_output: fourcc = cv2.VideoWriter_fourcc(*'mp4v') pose_est_results_list = [] pred_instances_list = [] if input_type == 'image': frame = mmcv.imread(args.input, channel_order='rgb') _, _, pred_3d_instances, _ = process_one_image( args=args, detector=detector, frame=frame, frame_idx=0, pose_estimator=pose_estimator, pose_est_results_last=[], pose_est_results_list=pose_est_results_list, next_id=0, pose_lifter=pose_lifter, visualize_frame=frame, visualizer=visualizer) if args.save_predictions: # save prediction results pred_instances_list = split_instances(pred_3d_instances) if save_output: frame_vis = visualizer.get_image() mmcv.imwrite(mmcv.rgb2bgr(frame_vis), output_file) elif input_type in ['webcam', 'video']: next_id = 0 pose_est_results = [] if args.input == 'webcam': video = cv2.VideoCapture(0) else: video = cv2.VideoCapture(args.input) (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.') if int(major_ver) < 3: fps = video.get(cv2.cv.CV_CAP_PROP_FPS) else: fps = video.get(cv2.CAP_PROP_FPS) video_writer = None frame_idx = 0 while video.isOpened(): success, frame = video.read() frame_idx += 1 if not success: break pose_est_results_last = pose_est_results # First stage: 2D pose detection # make person results for current image (pose_est_results, pose_est_results_list, pred_3d_instances, next_id) = process_one_image( args=args, detector=detector, frame=frame, frame_idx=frame_idx, pose_estimator=pose_estimator, pose_est_results_last=pose_est_results_last, pose_est_results_list=pose_est_results_list, next_id=next_id, pose_lifter=pose_lifter, visualize_frame=mmcv.bgr2rgb(frame), visualizer=visualizer) if args.save_predictions: # save prediction results pred_instances_list.append( dict( frame_id=frame_idx, instances=split_instances(pred_3d_instances))) if save_output: frame_vis = visualizer.get_image() if video_writer is None: # the size of the image with visualization may vary # depending on the presence of heatmaps video_writer = cv2.VideoWriter(output_file, fourcc, fps, (frame_vis.shape[1], frame_vis.shape[0])) video_writer.write(mmcv.rgb2bgr(frame_vis)) if args.show: # press ESC to exit if cv2.waitKey(5) & 0xFF == 27: break time.sleep(args.show_interval) video.release() if video_writer: video_writer.release() else: args.save_predictions = False raise ValueError( f'file {os.path.basename(args.input)} has invalid format.') if args.save_predictions: with open(args.pred_save_path, 'w') as f: json.dump( dict( meta_info=pose_lifter.dataset_meta, instance_info=pred_instances_list), f, indent='\t') print(f'predictions have been saved at {args.pred_save_path}') if save_output: input_type = input_type.replace('webcam', 'video') print_log( f'the output {input_type} has been saved at {output_file}', logger='current', level=logging.INFO) if __name__ == '__main__': main()