Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved. | |
from collections import defaultdict | |
from typing import Any, Optional, Sequence, Union | |
from mmengine.evaluator.evaluator import Evaluator | |
from mmengine.evaluator.metric import BaseMetric | |
from mmengine.structures import BaseDataElement | |
from mmpose.datasets.datasets.utils import parse_pose_metainfo | |
from mmpose.registry import DATASETS, EVALUATORS | |
@EVALUATORS.register_module()
class MultiDatasetEvaluator(Evaluator):
    """Wrapper class to compose multiple :class:`BaseMetric` instances,
    one per dataset.

    Metrics are paired positionally with dataset configs. For each pair,
    the dataset class's ``METAINFO`` is parsed with ``parse_pose_metainfo``
    and assigned to the metric. The parsed ``dataset_name`` is then used in
    :meth:`process` to route every sample to the metric of its dataset.

    Args:
        metrics (dict or BaseMetric or Sequence): The configs of metrics.
        datasets (Sequence[dict]): The configs of datasets, in the same
            order as ``metrics``. Each must have a ``type`` entry that is
            either a name registered in ``DATASETS`` or the dataset class.

    Raises:
        ValueError: If the number of datasets differs from the number of
            metrics, or if two datasets resolve to the same
            ``dataset_name``.
    """

    def __init__(
        self,
        metrics: Union[dict, BaseMetric, Sequence],
        datasets: Sequence[dict],
    ):
        super().__init__(metrics)

        # Compare against the metric list actually built by the base class
        # (the same list zipped below) rather than the raw argument: len()
        # of a single metric config dict counts its keys, and a lone
        # BaseMetric instance may not support len() at all.
        if len(self.metrics) != len(datasets):
            raise ValueError(
                'the argument datasets should have same length as metrics '
                f'(got {len(datasets)} datasets and {len(self.metrics)} '
                'metrics)')

        # Map dataset_name -> metric so that process() can route samples.
        metrics_dict = dict()
        for dataset, metric in zip(datasets, self.metrics):
            dataset_cls = self._get_dataset_cls(dataset['type'])
            dataset_meta = parse_pose_metainfo(dataset_cls.METAINFO)
            metric.dataset_meta = dataset_meta
            dataset_name = dataset_meta['dataset_name']
            if dataset_name in metrics_dict:
                # A second metric under the same key would overwrite the
                # first, which would then never receive any samples.
                raise ValueError(
                    f'multiple datasets resolve to dataset_name '
                    f'"{dataset_name}"; each dataset must have a unique '
                    'name so its samples can be routed to one metric')
            metrics_dict[dataset_name] = metric
        self.metrics_dict = metrics_dict

    @staticmethod
    def _get_dataset_cls(dataset_type: Any) -> Any:
        """Resolve a dataset ``type`` (registry name or class) to a class."""
        if isinstance(dataset_type, str):
            return DATASETS.module_dict[dataset_type]
        # Configs may also reference the dataset class object directly.
        return dataset_type

    @property
    def dataset_meta(self) -> Optional[dict]:
        """Optional[dict]: Meta info of the dataset."""
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        """Set the dataset meta info on the evaluator only.

        This override intentionally does not forward the meta to the wrapped
        metrics: each metric was already given the meta of its own dataset
        in ``__init__``, and overwriting it would break per-dataset
        evaluation.
        """
        self._dataset_meta = dataset_meta

    def _get_dataset_name(self, data_sample: dict) -> Any:
        """Return the dataset name a sample belongs to.

        Uses the sample's own ``dataset_name`` when present; otherwise falls
        back to ``self.dataset_meta['dataset_name']``. The fallback is only
        consulted when needed, so a missing or incomplete evaluator-level
        meta does not break samples that carry their own name.

        Raises:
            KeyError: If neither the sample nor ``self.dataset_meta``
                provides a ``dataset_name``.
        """
        if 'dataset_name' in data_sample:
            return data_sample['dataset_name']
        dataset_meta = self.dataset_meta
        if not dataset_meta or 'dataset_name' not in dataset_meta:
            raise KeyError(
                'data sample has no "dataset_name" and the evaluator\'s '
                'dataset_meta does not provide a fallback "dataset_name"')
        return dataset_meta['dataset_name']

    def process(self,
                data_samples: Sequence[BaseDataElement],
                data_batch: Optional[Any] = None):
        """Convert ``BaseDataSample`` to dict and invoke process method of each
        metric, routing every sample to the metric of its dataset.

        Samples are grouped by dataset name. Each group is passed to the
        matching metric's ``process`` together with the aligned entries of
        ``data_batch['inputs']`` and ``data_batch['data_samples']``. Datasets
        with no samples in this batch are skipped. Samples whose dataset has
        no registered metric are ignored.

        Args:
            data_samples (Sequence[BaseDataElement]): predictions of the model,
                and the ground truth of the validation set.
            data_batch (Any, optional): A batch of data from the dataloader,
                indexed here by ``'inputs'`` and ``'data_samples'``, whose
                entries are expected to align with ``data_samples``. If None,
                each metric receives ``None`` as its ``data_batch``.

        Raises:
            KeyError: If a sample's dataset name cannot be determined.
        """
        if data_batch is None:
            placeholders = [None] * len(data_samples)
            batch_inputs, batch_data_samples = placeholders, placeholders
        else:
            batch_inputs = data_batch['inputs']
            batch_data_samples = data_batch['data_samples']

        grouped_samples = defaultdict(list)
        grouped_inputs = defaultdict(list)
        grouped_batch_ds = defaultdict(list)
        for inputs, data_ds, data_sample in zip(batch_inputs,
                                                batch_data_samples,
                                                data_samples):
            if isinstance(data_sample, BaseDataElement):
                data_sample = data_sample.to_dict()
            assert isinstance(data_sample, dict)
            dataset_name = self._get_dataset_name(data_sample)
            grouped_samples[dataset_name].append(data_sample)
            grouped_inputs[dataset_name].append(inputs)
            grouped_batch_ds[dataset_name].append(data_ds)

        for dataset_name, metric in self.metrics_dict.items():
            samples = grouped_samples.get(dataset_name)
            if not samples:
                continue
            if data_batch is None:
                sub_batch = None
            else:
                sub_batch = dict(
                    inputs=grouped_inputs[dataset_name],
                    data_samples=grouped_batch_ds[dataset_name])
            metric.process(sub_batch, samples)