import copy
import math
import numbers
import random

import numpy as np
import torch

from .tools import parse_info_name
from ..utils.tensors import collate
from ..utils.misc import to_torch
import src.utils.rotation_conversions as geometry
from src.utils.action_label_to_idx import action_label_to_idx

POSE_REPS = ["xyz", "rotvec", "rotmat", "rotquat", "rot6d"]
UNSUPERVISED_BABEL_ACTION_CAT_LABELS_IDXS = [48, 50, 28, 38, 52, 11, 29, 19, 51, 22, 14, 21, 26, 10, 24]


class Dataset(torch.utils.data.Dataset):
    """Base motion dataset handling frame sampling and pose representation.

    This class reads, but never assigns, a number of attributes that a
    subclass is therefore expected to provide: ``_train`` / ``_test`` (index
    lists), ``_actions``, ``_action_to_label``, ``_label_to_action``,
    ``_action_classes``, ``_num_frames_in_video``, ``dataname``,
    ``num_classes`` and the loaders ``_load_joints3D`` / ``_load_translation``
    / ``_load_rotvec``. Optional CLIP-related data is read from ``db``,
    ``clip_label_text``, ``_clip_images``, ``_clip_pathes``, ``_clip_texts``
    and ``_actions_cat`` when present.

    Any split other than ``"train"`` (i.e. ``"val"`` and ``"test"``) is served
    from ``_test``.
    """

    def __init__(self, num_frames=1, sampling="conseq", sampling_step=1, split="train",
                 pose_rep="rot6d", translation=True, glob=True, max_len=-1, min_len=-1,
                 num_seq_max=-1, **kwargs):
        """Store the sampling / representation configuration.

        Args:
            num_frames: Frames per sample. ``-1`` keeps the whole sequence
                (capped by ``max_len``); ``-2`` draws a random length between
                ``min_len`` and the sequence length / ``max_len``.
            sampling: ``"conseq"``, ``"random_conseq"`` or ``"random"``.
            sampling_step: Stride for ``"conseq"`` sampling; ``-1`` selects
                the largest stride that fits.
            split: One of ``"train"``, ``"val"``, ``"test"``.
            pose_rep: Pose representation name (see ``POSE_REPS``).
            translation: Whether to append the root translation.
            glob: Whether to keep the first (index 0) rotation of each frame.
            max_len: Maximum sequence length, ``-1`` for no limit.
            min_len: Minimum length for the ``num_frames == -2`` mode.
            num_seq_max: Cap on ``len(self)``, ``-1`` for no cap.
            **kwargs: Optional flags ``use_action_cat_as_text_labels``,
                ``only_60_classes``, ``leave_out_15_classes`` and
                ``use_only_15_classes`` (all default to ``False``).

        Raises:
            ValueError: If ``split`` is not a recognised split name.
        """
        self.num_frames = num_frames
        self.sampling = sampling
        self.sampling_step = sampling_step
        self.split = split
        self.pose_rep = pose_rep
        self.translation = translation
        self.glob = glob
        self.max_len = max_len
        self.min_len = min_len
        self.num_seq_max = num_seq_max

        self.use_action_cat_as_text_labels = kwargs.get('use_action_cat_as_text_labels', False)
        self.only_60_classes = kwargs.get('only_60_classes', False)
        self.leave_out_15_classes = kwargs.get('leave_out_15_classes', False)
        self.use_only_15_classes = kwargs.get('use_only_15_classes', False)

        if self.split not in ["train", "val", "test"]:
            raise ValueError(f"{self.split} is not a valid split")

        super().__init__()

        # Snapshots of the un-shuffled index order, filled by reset_shuffle().
        self._original_train = None
        self._original_test = None

    def action_to_label(self, action):
        """Map an action id to its label via ``_action_to_label``."""
        return self._action_to_label[action]

    def label_to_action(self, label):
        """Map a label to its action id.

        ``label`` may be an integer or a one-hot vector; the latter is
        reduced with ``np.argmax`` first.
        """
        if not isinstance(label, numbers.Integral):
            # One-hot vector: recover the class index.
            label = np.argmax(label)
        return self._label_to_action[label]

    def get_pose_data(self, data_index, frame_ix):
        """Return ``(pose, label)`` for a sequence and its selected frames."""
        pose = self._load(data_index, frame_ix)
        label = self.get_label(data_index)
        return pose, label

    def get_clip_image(self, ind):
        """Return the entry of ``_clip_images`` for sequence ``ind``."""
        return self._clip_images[ind]

    def get_clip_path(self, ind):
        """Return the entry of ``_clip_pathes`` for sequence ``ind``."""
        return self._clip_pathes[ind]

    def get_clip_text(self, ind, frame_ix):
        """Return the per-frame text labels of sequence ``ind`` at ``frame_ix``."""
        return self._clip_texts[ind][frame_ix]

    def get_clip_action_cat(self, ind, frame_ix):
        """Return the per-frame action categories of sequence ``ind`` at ``frame_ix``."""
        return self._actions_cat[ind][frame_ix]

    def get_label(self, ind):
        """Return the label of sequence ``ind``."""
        action = self.get_action(ind)
        return self.action_to_label(action)

    def parse_action(self, path, return_int=True):
        """Extract the ``"A"`` field parsed from ``path``.

        Returns it as ``int`` when ``return_int`` is true, otherwise as
        returned by ``parse_info_name``.
        """
        info = parse_info_name(path)["A"]
        return int(info) if return_int else info

    def get_action(self, ind):
        """Return the action id of sequence ``ind``."""
        return self._actions[ind]

    def action_to_action_name(self, action):
        """Return the entry of ``_action_classes`` for ``action``."""
        return self._action_classes[action]

    def label_to_action_name(self, label):
        """Return the action name corresponding to ``label``."""
        action = self.label_to_action(label)
        return self.action_to_action_name(action)

    def __getitem__(self, index):
        """Return the sample at position ``index`` of the active split.

        The result may be ``None`` (see ``_get_item_data_index``).
        """
        if self.split == 'train':
            data_index = self._train[index]
        else:
            data_index = self._test[index]
        return self._get_item_data_index(data_index)

    def _load(self, ind, frame_ix):
        """Load the frames ``frame_ix`` of sequence ``ind`` in ``self.pose_rep``.

        The loaders presumably return arrays laid out as
        (frames, joints, features) — TODO confirm against subclasses. The
        result is permuted so that the first axis becomes the last one, and
        is returned as a float tensor.

        Raises:
            ValueError: If the requested representation or the translation
                cannot be produced with the loaders available on ``self``.
        """
        pose_rep = self.pose_rep

        if pose_rep == "xyz" or self.translation:
            if getattr(self, "_load_joints3D", None) is not None:
                # Locate the root joint of initial pose at origin
                joints3D = self._load_joints3D(ind, frame_ix)
                joints3D = joints3D - joints3D[0, 0, :]
                ret = to_torch(joints3D)
                if self.translation:
                    ret_tr = ret[:, 0, :]
            else:
                if pose_rep == "xyz":
                    raise ValueError("This representation is not possible.")
                # The default is required: without it a missing loader raises
                # AttributeError and the ValueError below is unreachable.
                if getattr(self, "_load_translation", None) is None:
                    raise ValueError("Can't extract translations.")
                ret_tr = self._load_translation(ind, frame_ix)
                # Express the translation relative to the first frame.
                ret_tr = to_torch(ret_tr - ret_tr[0])

        if pose_rep != "xyz":
            if getattr(self, "_load_rotvec", None) is None:
                raise ValueError("This representation is not possible.")
            pose = self._load_rotvec(ind, frame_ix)
            if not self.glob:
                # Drop the first rotation of every frame.
                pose = pose[:, 1:, :]
            pose = to_torch(pose)
            if pose_rep == "rotvec":
                ret = pose
            elif pose_rep == "rotmat":
                ret = geometry.axis_angle_to_matrix(pose).view(*pose.shape[:2], 9)
            elif pose_rep == "rotquat":
                ret = geometry.axis_angle_to_quaternion(pose)
            elif pose_rep == "rot6d":
                ret = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(pose))
            else:
                # Previously fell through and failed later with UnboundLocalError.
                raise ValueError("This representation is not possible.")

        if pose_rep != "xyz" and self.translation:
            # Append the translation as one extra entry along axis 1,
            # zero-padded to the feature width of the rotations.
            padded_tr = torch.zeros((ret.shape[0], ret.shape[2]), dtype=ret.dtype)
            padded_tr[:, :3] = ret_tr
            ret = torch.cat((ret, padded_tr[:, None]), 1)

        ret = ret.permute(1, 2, 0).contiguous()
        return ret.float()

    def _get_item_data_index(self, data_index):
        """Sample frames from sequence ``data_index`` and build the sample dict.

        Returns:
            A dict with at least ``'inp'`` and ``'target'``; optionally
            ``'clip_image'``, ``'clip_path'``, ``'clip_text'``, ``'y'`` and
            ``'all_categories'`` depending on what ``self.db`` holds.
            ``None`` when action categories are used as text labels and no
            category of the sampled frames passes the class filters.

        Raises:
            ValueError: On an unknown sampling mode, or when
                ``num_frames == -2`` is used without a positive ``min_len``.
        """
        nframes = self._num_frames_in_video[data_index]

        if self.num_frames == -1 and (self.max_len == -1 or nframes <= self.max_len):
            frame_ix = np.arange(nframes)
        else:
            if self.num_frames == -2:
                if self.min_len <= 0:
                    raise ValueError("You should put a min_len > 0 for num_frames == -2 mode")
                if self.max_len != -1:
                    max_frame = min(nframes, self.max_len)
                else:
                    max_frame = nframes

                num_frames = random.randint(self.min_len, max(max_frame, self.min_len))
            else:
                num_frames = self.num_frames if self.num_frames != -1 else self.max_len

            # sampling goal: input: ----------- 11 nframes
            #                       o--o--o--o- 4  ninputs
            #
            # step number is computed like that: [(11-1)/(4-1)] = 3
            #                   [---][---][---][-
            # So step = 3, and we take 0 to step*ninputs+1 with steps
            #                   [o--][o--][o--][o-]
            # then we can randomly shift the vector
            #                   -[o--][o--][o--]o

            # If there are too much frames required
            if num_frames > nframes:
                fair = False  # True
                if fair:
                    # distills redundancy everywhere
                    choices = np.random.choice(range(nframes), num_frames, replace=True)
                    frame_ix = sorted(choices)
                else:
                    # adding the last frame until done
                    ntoadd = max(0, num_frames - nframes)
                    lastframe = nframes - 1
                    padding = lastframe * np.ones(ntoadd, dtype=int)
                    frame_ix = np.concatenate((np.arange(0, nframes), padding))

            elif self.sampling in ["conseq", "random_conseq"]:
                # A single requested frame has no stride; fall back to 1 so
                # that num_frames == 1 (the default) does not divide by zero.
                if num_frames > 1:
                    step_max = (nframes - 1) // (num_frames - 1)
                else:
                    step_max = 1

                if self.sampling == "conseq":
                    if self.sampling_step == -1 or self.sampling_step * (num_frames - 1) >= nframes:
                        step = step_max
                    else:
                        step = self.sampling_step
                elif self.sampling == "random_conseq":
                    step = random.randint(1, step_max)

                lastone = step * (num_frames - 1)
                shift_max = nframes - lastone - 1
                shift = random.randint(0, max(0, shift_max - 1))
                frame_ix = shift + np.arange(0, lastone + 1, step)

            elif self.sampling == "random":
                choices = np.random.choice(range(nframes), num_frames, replace=False)
                frame_ix = sorted(choices)

            else:
                raise ValueError("Sampling not recognized.")

        inp, target = self.get_pose_data(data_index, frame_ix)

        output = {'inp': inp, 'target': target}

        has_db = hasattr(self, 'db')
        if has_db and 'clip_images' in self.db.keys():
            output['clip_image'] = self.get_clip_image(data_index)

        if has_db and 'clip_pathes' in self.db.keys():
            output['clip_path'] = self.get_clip_path(data_index)

        if has_db and self.clip_label_text in self.db.keys():
            text_labels = self.get_clip_text(data_index, frame_ix)
            output['clip_text'] = " and ".join(np.unique(text_labels))

        if has_db and 'action_cat' in self.db.keys() and self.use_action_cat_as_text_labels:
            categories = self.get_clip_action_cat(data_index, frame_ix)
            all_valid_cats = []
            for multi_cats in np.unique(categories):
                # One entry may hold several comma-separated categories.
                for cat in multi_cats.split(","):
                    if cat not in action_label_to_idx:
                        continue
                    cat_idx = action_label_to_idx[cat]
                    if (cat_idx >= 120) or (self.only_60_classes and cat_idx >= 60) or (
                            self.leave_out_15_classes and cat_idx in UNSUPERVISED_BABEL_ACTION_CAT_LABELS_IDXS):
                        continue
                    if self.use_only_15_classes and (cat_idx not in UNSUPERVISED_BABEL_ACTION_CAT_LABELS_IDXS):
                        continue
                    all_valid_cats.append(cat)

            if not all_valid_cats:  # No valid category available
                return None

            choosen_cat = np.random.choice(all_valid_cats, size=1)[0]
            # Replace clip text
            output['clip_text'] = choosen_cat
            output['y'] = action_label_to_idx[choosen_cat]
            output['all_categories'] = all_valid_cats

        return output

    def get_label_sample(self, label, n=1, return_labels=False, return_index=False):
        """Draw ``n`` random samples whose target equals ``label``.

        Returns ``x``, optionally followed by ``y`` (``return_labels``) and
        ``data_index`` (``return_index``).

        NOTE(review): for ``dataname == 'amass'`` samples are found by
        rejection sampling and the loop does not terminate if no sample
        matches ``label``. The drawn position is also passed to
        ``_get_item_data_index`` directly rather than through the split
        index — confirm this is intended. In the other branch
        ``data_index`` is a dataset index for ``n == 1`` but positions
        within the split for ``n > 1``.
        """
        if self.split == 'train':
            index = self._train
        else:
            index = self._test

        action = self.label_to_action(label)
        choices = np.argwhere(np.array(self._actions)[index] == action).squeeze(1)

        if self.dataname == 'amass':
            if n == 1:
                while True:
                    idx = np.random.randint(0, len(self))
                    data = self._get_item_data_index(idx)
                    if data is None:
                        continue
                    x, y = data['inp'], data['target']
                    if y == label:
                        # Record the index so return_index=True works here too.
                        data_index = idx
                        break
            else:
                x = []
                data_index = []
                while len(x) < n:
                    idx = np.random.randint(0, len(self))
                    data = self._get_item_data_index(idx)
                    if data is None:
                        # Same guard as the n == 1 case: skip filtered samples.
                        continue
                    x_inp, y = data['inp'], data['target']
                    if y == label:
                        x.append(x_inp)
                        data_index.append(idx)
                x = np.stack(x)
                y = label * np.ones(n, dtype=int)
        else:
            if n == 1:
                data_index = index[np.random.choice(choices)]
                data = self._get_item_data_index(data_index)
                x, y = data['inp'], data['target']
                assert (label == y)
                y = label
            else:
                data_index = np.random.choice(choices, n)
                x = np.stack([self._get_item_data_index(index[di])['inp'] for di in data_index])
                y = label * np.ones(n, dtype=int)

        if return_labels:
            if return_index:
                return x, y, data_index
            return x, y
        if return_index:
            return x, data_index
        return x

    def get_label_sample_batch(self, labels):
        """Draw one sample per label and collate them.

        Returns:
            ``(x, mask, lengths)`` where ``x`` and ``mask`` come from
            ``collate`` and ``lengths`` is ``mask.sum(1)``.
        """
        samples = [self.get_label_sample(label, n=1, return_labels=True, return_index=False)
                   for label in labels]
        # Fix this to adapt new collate func
        samples = [{'inp': sample[0], 'target': sample[1]} for sample in samples]
        batch = collate(samples)
        x = batch["x"]
        mask = batch["mask"]
        lengths = mask.sum(1)
        return x, mask, lengths

    def get_mean_length_label(self, label):
        """Return the mean sequence length for ``label`` in the active split.

        When ``num_frames`` is not ``-1`` that value is returned as is.
        Otherwise lengths are clipped to ``max_len`` (if set) before
        averaging.
        """
        if self.num_frames != -1:
            return self.num_frames

        if self.split == 'train':
            index = self._train
        else:
            index = self._test

        action = self.label_to_action(label)
        # np.array(...) keeps this consistent with get_label_sample.
        choices = np.argwhere(np.array(self._actions)[index] == action).squeeze(1)
        lengths = self._num_frames_in_video[np.array(index)[choices]]

        if self.max_len != -1:
            # make the lengths less than max_len
            lengths = np.minimum(lengths, self.max_len)
        return np.mean(lengths)

    def get_stats(self):
        """Return a dict of summary statistics for the active split."""
        if self.split == 'train':
            index = self._train
        else:
            index = self._test

        numframes = self._num_frames_in_video[index]
        allmeans = np.array([self.get_mean_length_label(x) for x in range(self.num_classes)])

        stats = {"name": self.dataname,
                 "number of classes": self.num_classes,
                 "number of sequences": len(self),
                 "duration: min": int(numframes.min()),
                 "duration: max": int(numframes.max()),
                 "duration: mean": int(numframes.mean()),
                 "duration mean/action: min": int(allmeans.min()),
                 "duration mean/action: max": int(allmeans.max()),
                 "duration mean/action: mean": int(allmeans.mean())}
        return stats

    def __len__(self):
        """Return the size of the active split, capped by ``num_seq_max``."""
        num_seq_max = getattr(self, "num_seq_max", -1)
        if num_seq_max == -1:
            num_seq_max = math.inf

        if self.split == 'train':
            return min(len(self._train), num_seq_max)
        else:
            return min(len(self._test), num_seq_max)

    def __repr__(self):
        return f"{self.dataname} dataset: ({len(self)}, _, ..)"

    def update_parameters(self, parameters):
        """Fill ``parameters`` with ``num_classes``, ``nfeats`` and ``njoints``.

        ``njoints`` and ``nfeats`` are read from the shape of the first
        sample that is not ``None``.
        """
        for i in range(len(self)):
            # Fetch once: sampling is random, so a second self[i] could
            # yield a different (possibly None) result.
            sample = self[i]
            if sample is not None:
                self.njoints, self.nfeats, _ = sample['inp'].shape
                break

        parameters["num_classes"] = self.num_classes
        parameters["nfeats"] = self.nfeats
        parameters["njoints"] = self.njoints

    def shuffle(self):
        """Shuffle the index of the active split in place."""
        if self.split == 'train':
            random.shuffle(self._train)
        else:
            random.shuffle(self._test)

    def reset_shuffle(self):
        """Snapshot the index order on first call, restore it afterwards.

        Copies are stored and restored because ``shuffle`` works in place;
        keeping a plain reference would shuffle the snapshot as well.
        """
        if self.split == 'train':
            if self._original_train is None:
                self._original_train = copy.copy(self._train)
            else:
                self._train = copy.copy(self._original_train)
        else:
            if self._original_test is None:
                self._original_test = copy.copy(self._test)
            else:
                self._test = copy.copy(self._original_test)