Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| import os.path as osp | |
| from collections import defaultdict | |
| from typing import Any, Dict, List, Optional | |
| import numpy as np | |
| from mmengine.dataset import BaseDataset | |
| from mmengine.utils import check_file_exist | |
| from mmdet.registry import DATASETS | |
class ReIDDataset(BaseDataset):
    """Dataset for ReID (person re-identification).

    Each line of the annotation file is ``<filename> <person_id>``.

    Args:
        triplet_sampler (dict, optional): Configuration for the hard-mining
            triplet sampler. Expected keys:

            - num_ids (int): The number of person ids per sample.
            - ins_per_id (int): The number of images for each person id.

            Defaults to None, which disables triplet sampling.
    """

    def __init__(self,
                 triplet_sampler: Optional[dict] = None,
                 *args,
                 **kwargs):
        self.triplet_sampler = triplet_sampler
        super().__init__(*args, **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotations from the annotation file ``self.ann_file``.

        Returns:
            list[dict]: One info dict per image, with ``img_prefix``,
            ``img_path`` and ``gt_label`` (person id as int64 ndarray) keys.
        """
        assert isinstance(self.ann_file, str)
        check_file_exist(self.ann_file)
        data_list = []
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
        for filename, gt_label in samples:
            info = dict(img_prefix=self.data_prefix)
            if self.data_prefix['img_path'] is not None:
                info['img_path'] = osp.join(self.data_prefix['img_path'],
                                            filename)
            else:
                info['img_path'] = filename
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_list.append(info)
        self._parse_ann_info(data_list)
        return data_list

    def _parse_ann_info(self, data_list: List[dict]):
        """Index samples by person id for triplet sampling.

        Builds ``self.index_dic`` (pid -> ndarray of sample indices) and
        ``self.pids`` (sorted-insertion array of all distinct pids).
        """
        index_tmp_dic = defaultdict(list)  # pid -> [idx1, ..., idxN]
        self.index_dic = dict()  # pid -> array([idx1, ..., idxN])
        for idx, info in enumerate(data_list):
            pid = info['gt_label']
            index_tmp_dic[int(pid)].append(idx)
        for pid, idxs in index_tmp_dic.items():
            self.index_dic[pid] = np.asarray(idxs, dtype=np.int64)
        self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)

    def prepare_data(self, idx: int) -> Any:
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        data_info = self.get_data_info(idx)
        if self.triplet_sampler is not None:
            img_info = self.triplet_sampling(data_info['gt_label'],
                                             **self.triplet_sampler)
            data_info = copy.deepcopy(img_info)  # triplet -> list
        else:
            data_info = copy.deepcopy(data_info)  # no triplet -> dict
        return self.pipeline(data_info)

    def triplet_sampling(self,
                         pos_pid,
                         num_ids: int = 8,
                         ins_per_id: int = 4) -> Dict:
        """Triplet sampler for hard mining triplet loss.

        First, for ``pos_pid``, randomly sample ``ins_per_id`` images with
        the same person id. Then, randomly sample ``num_ids - 1`` negative
        person ids, and randomly sample ``ins_per_id`` images for each of
        them.

        Args:
            pos_pid (ndarray): The person id of the anchor.
            num_ids (int): The number of person ids.
            ins_per_id (int): The number of images for each person.

        Returns:
            Dict: Annotation information of num_ids X ins_per_id images,
            collated as a dict of lists (one list entry per sampled image).
        """
        assert len(self.pids) >= num_ids, \
            'The number of person ids in the training set must ' \
            'be greater than or equal to the number of person ids ' \
            'in the sample.'
        # All sample indices sharing the anchor's person id.
        pos_idxs = self.index_dic[int(pos_pid)]
        idxs_list = []
        # Select positive samples (with replacement, in case the id has
        # fewer than ins_per_id images).
        idxs_list.extend(pos_idxs[np.random.choice(
            pos_idxs.shape[0], ins_per_id, replace=True)])
        # Select negative person ids. Sample actual pid values (not indices
        # into self.pids): the previous index-based sampling was only
        # correct when pids happened to be contiguous 0..N-1.
        neg_pids = np.random.choice(
            self.pids[self.pids != int(pos_pid)],
            num_ids - 1,
            replace=False)
        # Select samples for each negative id.
        for neg_pid in neg_pids:
            neg_idxs = self.index_dic[int(neg_pid)]
            idxs_list.extend(neg_idxs[np.random.choice(
                neg_idxs.shape[0], ins_per_id, replace=True)])
        # Gather the info dicts for the final triplet batch.
        triplet_img_infos = []
        for idx in idxs_list:
            triplet_img_infos.append(copy.deepcopy(self.get_data_info(idx)))
        # Collect data_list scatters (list of dict -> dict of list).
        out = dict()
        for key in triplet_img_infos[0].keys():
            out[key] = [_info[key] for _info in triplet_img_infos]
        return out