# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from typing import Dict, List, Optional, Sequence, Union | |
import numpy as np | |
import torch | |
from mmengine.config import Config, ConfigDict | |
from mmengine.infer.infer import ModelType | |
from mmengine.structures import InstanceData | |
from rich.progress import track | |
from .base_mmpose_inferencer import BaseMMPoseInferencer | |
from .hand3d_inferencer import Hand3DInferencer | |
from .pose2d_inferencer import Pose2DInferencer | |
from .pose3d_inferencer import Pose3DInferencer | |
# Type aliases shared by the inferencer APIs in this module.
InstanceList = List[InstanceData]
# A single input: an image/video path or a loaded image array.
InputType = Union[str, np.ndarray]
# One input or a sequence of inputs.
InputsType = Union[InputType, Sequence[InputType]]
# Model predictions: one InstanceData or a list of them.
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ConfigType = Union[Config, ConfigDict]
# Post-processed results, as dict(s) or InstanceData(s).
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
class MMPoseInferencer(BaseMMPoseInferencer):
    """MMPose Inferencer. It's a unified inferencer interface for pose
    estimation task, currently including: Pose2D. and it can be used to perform
    2D keypoint detection.

    Args:
        pose2d (str, optional): Pretrained 2D pose estimation algorithm.
            It's the path to the config file or the model name defined in
            metafile. For example, it could be:

            - model alias, e.g. ``'body'``,
            - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
            - config path

            Defaults to ``None``.
        pose2d_weights (str, optional): Path to the custom checkpoint file of
            the selected pose2d model. If it is not specified and "pose2d" is
            a model name of metafile, the weights will be loaded from
            metafile. Defaults to None.
        pose3d (str, optional): Pretrained 3D pose estimation algorithm.
            Same formats as ``pose2d``. If the name contains ``'hand3d'``,
            a hand-specific 3D inferencer is used. Defaults to None.
        pose3d_weights (str, optional): Path to the custom checkpoint file of
            the selected pose3d model. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmpose".
        det_model (str, optional): Config path or alias of detection model.
            Defaults to None.
        det_weights (str, optional): Path to the checkpoints of detection
            model. Defaults to None.
        det_cat_ids (int or list[int], optional): Category id for
            detection model. Defaults to None.
        show_progress (bool): Whether to display a progress bar during
            inference. Defaults to False.
    """

    # Keyword buckets: user kwargs passed to ``__call__`` are routed to
    # preprocess / forward / visualize / postprocess via ``_dispatch_kwargs``.
    preprocess_kwargs: set = {
        'bbox_thr', 'nms_thr', 'bboxes', 'use_oks_tracking', 'tracking_thr',
        'disable_norm_pose_2d'
    }
    forward_kwargs: set = {
        'merge_results', 'disable_rebase_keypoint', 'pose_based_nms'
    }
    visualize_kwargs: set = {
        'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness',
        'kpt_thr', 'vis_out_dir', 'skeleton_style', 'draw_heatmap',
        'black_background', 'num_instances'
    }
    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

    def __init__(self,
                 pose2d: Optional[str] = None,
                 pose2d_weights: Optional[str] = None,
                 pose3d: Optional[str] = None,
                 pose3d_weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: str = 'mmpose',
                 det_model: Optional[Union[ModelType, str]] = None,
                 det_weights: Optional[str] = None,
                 det_cat_ids: Optional[Union[int, List]] = None,
                 show_progress: bool = False) -> None:
        self.visualizer = None
        self.show_progress = show_progress
        # Select the concrete inferencer to delegate to. 3D takes precedence
        # over 2D when both are given.
        if pose3d is not None:
            if 'hand3d' in pose3d:
                self.inferencer = Hand3DInferencer(pose3d, pose3d_weights,
                                                   device, scope, det_model,
                                                   det_weights, det_cat_ids,
                                                   show_progress)
            else:
                self.inferencer = Pose3DInferencer(pose3d, pose3d_weights,
                                                   pose2d, pose2d_weights,
                                                   device, scope, det_model,
                                                   det_weights, det_cat_ids,
                                                   show_progress)
        elif pose2d is not None:
            self.inferencer = Pose2DInferencer(pose2d, pose2d_weights, device,
                                               scope, det_model, det_weights,
                                               det_cat_ids, show_progress)
        else:
            raise ValueError('Either 2d or 3d pose estimation algorithm '
                             'should be provided.')

    def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
        """Process the inputs into a model-feedable format.

        Delegates to the wrapped inferencer's ``preprocess``.

        Args:
            inputs (InputsType): Inputs given by user.
            batch_size (int): batch size. Defaults to 1.

        Yields:
            Any: Data processed by the ``pipeline`` and ``collate_fn``.
            List[str or np.ndarray]: List of original inputs in the batch
        """
        for data in self.inferencer.preprocess(inputs, batch_size, **kwargs):
            yield data

    def forward(self, inputs: InputType, **forward_kwargs) -> PredType:
        """Forward the inputs to the model.

        Args:
            inputs (InputsType): The inputs to be forwarded.

        Returns:
            Dict: The prediction results. Possibly with keys "pose2d".
        """
        return self.inferencer.forward(inputs, **forward_kwargs)

    def __call__(
        self,
        inputs: InputsType,
        return_datasamples: bool = False,
        batch_size: int = 1,
        out_dir: Optional[str] = None,
        **kwargs,
    ) -> dict:
        """Call the inferencer.

        Args:
            inputs (InputsType): Inputs for the inferencer.
            return_datasamples (bool): Whether to return results as
                :obj:`BaseDataElement`. Defaults to False.
            batch_size (int): Batch size. Defaults to 1.
            out_dir (str, optional): directory to save visualization
                results and predictions. Will be overridden if vis_out_dir or
                pred_out_dir are given. Defaults to None
            **kwargs: Key words arguments passed to :meth:`preprocess`,
                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
                Each key in kwargs should be in the corresponding set of
                ``preprocess_kwargs``, ``forward_kwargs``,
                ``visualize_kwargs`` and ``postprocess_kwargs``.

        Returns:
            dict: Inference and visualization results.
        """
        if out_dir is not None:
            if 'vis_out_dir' not in kwargs:
                kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
            if 'pred_out_dir' not in kwargs:
                kwargs['pred_out_dir'] = f'{out_dir}/predictions'

        # Drop kwargs the wrapped inferencer does not understand, so that
        # e.g. 2D-only options are silently ignored by a 3D inferencer.
        kwargs = {
            key: value
            for key, value in kwargs.items()
            if key in set.union(self.inferencer.preprocess_kwargs,
                                self.inferencer.forward_kwargs,
                                self.inferencer.visualize_kwargs,
                                self.inferencer.postprocess_kwargs)
        }
        (
            preprocess_kwargs,
            forward_kwargs,
            visualize_kwargs,
            postprocess_kwargs,
        ) = self._dispatch_kwargs(**kwargs)

        self.inferencer.update_model_visualizer_settings(**kwargs)

        # preprocessing
        if isinstance(inputs, str) and inputs.startswith('webcam'):
            inputs = self.inferencer._get_webcam_inputs(inputs)
            batch_size = 1
            if not visualize_kwargs.get('show', False):
                warnings.warn('The display mode is closed when using webcam '
                              'input. It will be turned on automatically.')
            visualize_kwargs['show'] = True
        else:
            inputs = self.inferencer._inputs_to_list(inputs)

        # Mirror the wrapped inferencer's video state so postprocessing can
        # finalize video outputs below.
        self._video_input = self.inferencer._video_input
        if self._video_input:
            self.video_info = self.inferencer.video_info

        inputs = self.preprocess(
            inputs, batch_size=batch_size, **preprocess_kwargs)

        # forward
        if 'bbox_thr' in self.inferencer.forward_kwargs:
            forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
        for proc_inputs, ori_inputs in (track(inputs, description='Inference')
                                        if self.show_progress else inputs):
            preds = self.forward(proc_inputs, **forward_kwargs)

            visualization = self.visualize(ori_inputs, preds,
                                           **visualize_kwargs)
            results = self.postprocess(
                preds,
                visualization,
                return_datasamples=return_datasamples,
                **postprocess_kwargs)
            yield results

        if self._video_input:
            self._finalize_video_processing(
                postprocess_kwargs.get('pred_out_dir', ''))

    def visualize(self, inputs: InputsType, preds: PredType,
                  **kwargs) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
            preds (Any): Predictions of the model.
            return_vis (bool): Whether to return images with predicted results.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            show_interval (int): The interval of show (s). Defaults to 0
            radius (int): Keypoint radius for visualization. Defaults to 3
            thickness (int): Link thickness for visualization. Defaults to 1
            kpt_thr (float): The threshold to visualize the keypoints.
                Defaults to 0.3
            vis_out_dir (str, optional): directory to save visualization
                results w/o predictions. If left as empty, no file will
                be saved. Defaults to ''.

        Returns:
            List[np.ndarray]: Visualization results.
        """
        # Use the video name as the window title when visualizing video input.
        window_name = ''
        if self.inferencer._video_input:
            window_name = self.inferencer.video_info['name']

        return self.inferencer.visualize(
            inputs, preds, window_name=window_name, **kwargs)