import random

import numpy as np
import torch

from .tools import parse_info_name
from ..utils.tensors import collate
from ..utils.misc import to_torch
import src.utils.rotation_conversions as geometry
from src.utils.action_label_to_idx import action_label_to_idx

POSE_REPS = ["xyz", "rotvec", "rotmat", "rotquat", "rot6d"]
UNSUPERVISED_BABEL_ACTION_CAT_LABELS_IDXS = [48, 50, 28, 38, 52, 11, 29, 19, 51, 22, 14, 21, 26, 10, 24]
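# Per-joint feature sizes implied by each entry of POSE_REPS, as produced by
# Dataset._load below: xyz -> 3, rotvec -> 3, rotquat -> 4, rot6d -> 6, rotmat -> 9.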


class Dataset(torch.utils.data.Dataset):
    """Base motion dataset; concrete subclasses provide the raw data and loaders."""

    def __init__(self, num_frames=1, sampling="conseq", sampling_step=1, split="train",
                 pose_rep="rot6d", translation=True, glob=True, max_len=-1, min_len=-1,
                 num_seq_max=-1, **kwargs):
        self.num_frames = num_frames
        self.sampling = sampling
        self.sampling_step = sampling_step
        self.split = split
        self.pose_rep = pose_rep
        self.translation = translation
        self.glob = glob
        self.max_len = max_len
        self.min_len = min_len
        self.num_seq_max = num_seq_max

        self.use_action_cat_as_text_labels = kwargs.get('use_action_cat_as_text_labels', False)
        self.only_60_classes = kwargs.get('only_60_classes', False)
        self.leave_out_15_classes = kwargs.get('leave_out_15_classes', False)
        self.use_only_15_classes = kwargs.get('use_only_15_classes', False)

        if self.split not in ["train", "val", "test"]:
            raise ValueError(f"{self.split} is not a valid split")

        super().__init__()

        # Kept so that reset_shuffle() can restore the original sequence order.
        self._original_train = None
        self._original_test = None
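    # The methods below assume the concrete subclass defines (as referenced in this file):
    #   self._train / self._test                 -- per-split index lists into the data
    #   self._actions, self._action_classes      -- per-sequence action ids and id -> name map
    #   self._action_to_label, self._label_to_action
    #   self._num_frames_in_video                -- per-sequence frame counts
    #   self.dataname, self.num_classes
    #   self._load_rotvec / self._load_joints3D / self._load_translation loaders
    #   optionally self.db, self.clip_label_text and the self._clip_* / self._actions_cat arrays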
    def action_to_label(self, action):
        return self._action_to_label[action]

    def label_to_action(self, label):
        import numbers
        if isinstance(label, numbers.Integral):
            return self._label_to_action[label]
        else:
            label = np.argmax(label)
            return self._label_to_action[label]
    def get_pose_data(self, data_index, frame_ix):
        pose = self._load(data_index, frame_ix)
        label = self.get_label(data_index)
        return pose, label

    def get_clip_image(self, ind):
        clip_image = self._clip_images[ind]
        return clip_image

    def get_clip_path(self, ind):
        clip_path = self._clip_pathes[ind]
        return clip_path

    def get_clip_text(self, ind, frame_ix):
        clip_text = self._clip_texts[ind][frame_ix]
        return clip_text

    def get_clip_action_cat(self, ind, frame_ix):
        actions_cat = self._actions_cat[ind][frame_ix]
        return actions_cat

    def get_label(self, ind):
        action = self.get_action(ind)
        return self.action_to_label(action)

    def parse_action(self, path, return_int=True):
        info = parse_info_name(path)["A"]
        if return_int:
            return int(info)
        else:
            return info

    def get_action(self, ind):
        return self._actions[ind]

    def action_to_action_name(self, action):
        return self._action_classes[action]

    def label_to_action_name(self, label):
        action = self.label_to_action(label)
        return self.action_to_action_name(action)
    def __getitem__(self, index):
        if self.split == 'train':
            data_index = self._train[index]
        else:
            data_index = self._test[index]

        return self._get_item_data_index(data_index)
    def _load(self, ind, frame_ix):
        pose_rep = self.pose_rep
        if pose_rep == "xyz" or self.translation:
            if getattr(self, "_load_joints3D", None) is not None:
                # center the sequence: root joint of the first frame at the origin
                joints3D = self._load_joints3D(ind, frame_ix)
                joints3D = joints3D - joints3D[0, 0, :]
                ret = to_torch(joints3D)
                if self.translation:
                    ret_tr = ret[:, 0, :]
            else:
                if pose_rep == "xyz":
                    raise ValueError("This representation is not possible.")
                if getattr(self, "_load_translation", None) is None:
                    raise ValueError("Can't extract translations.")
                ret_tr = self._load_translation(ind, frame_ix)
                ret_tr = to_torch(ret_tr - ret_tr[0])

        if pose_rep != "xyz":
            if getattr(self, "_load_rotvec", None) is None:
                raise ValueError("This representation is not possible.")
            else:
                pose = self._load_rotvec(ind, frame_ix)
                if not self.glob:
                    # drop the global (root) orientation
                    pose = pose[:, 1:, :]
                pose = to_torch(pose)
                if pose_rep == "rotvec":
                    ret = pose
                elif pose_rep == "rotmat":
                    ret = geometry.axis_angle_to_matrix(pose).view(*pose.shape[:2], 9)
                elif pose_rep == "rotquat":
                    ret = geometry.axis_angle_to_quaternion(pose)
                elif pose_rep == "rot6d":
                    ret = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(pose))

        if pose_rep != "xyz" and self.translation:
            # append the root translation as an extra, zero-padded joint row
            padded_tr = torch.zeros((ret.shape[0], ret.shape[2]), dtype=ret.dtype)
            padded_tr[:, :3] = ret_tr
            ret = torch.cat((ret, padded_tr[:, None]), 1)
        # (nframes, njoints, nfeats) -> (njoints, nfeats, nframes)
        ret = ret.permute(1, 2, 0).contiguous()
        return ret.float()
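    # Shape note (illustrative): _load returns a float tensor of shape
    # (njoints [+1 with translation], nfeats, nframes). For example, assuming a
    # 24-joint SMPL-style skeleton with glob=True, pose_rep="rot6d" and
    # translation=True, the result is (25, 6, nframes), where the extra row carries
    # the root translation in its first 3 channels and zeros elsewhere.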

    def _get_item_data_index(self, data_index):
        nframes = self._num_frames_in_video[data_index]

        if self.num_frames == -1 and (self.max_len == -1 or nframes <= self.max_len):
            frame_ix = np.arange(nframes)
        else:
            if self.num_frames == -2:
                if self.min_len <= 0:
                    raise ValueError("You should put a min_len > 0 for num_frames == -2 mode")
                if self.max_len != -1:
                    max_frame = min(nframes, self.max_len)
                else:
                    max_frame = nframes

                num_frames = random.randint(self.min_len, max(max_frame, self.min_len))
            else:
                num_frames = self.num_frames if self.num_frames != -1 else self.max_len
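            # Frame-sampling sketch (illustrative numbers): with nframes=11 and
            # num_frames=4, "conseq" sampling uses step = (11 - 1) // (4 - 1) = 3,
            # so frames 0, 3, 6, 9 are taken, optionally shifted by a random offset
            # when the clip is longer than needed.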
            if num_frames > nframes:
                # the "fair" path (uniform resampling with replacement) is disabled;
                # instead the last frame is repeated until the requested length is reached
                fair = False
                if fair:
                    choices = np.random.choice(range(nframes),
                                               num_frames,
                                               replace=True)
                    frame_ix = sorted(choices)
                else:
                    ntoadd = max(0, num_frames - nframes)
                    lastframe = nframes - 1
                    padding = lastframe * np.ones(ntoadd, dtype=int)
                    frame_ix = np.concatenate((np.arange(0, nframes),
                                               padding))

            elif self.sampling in ["conseq", "random_conseq"]:
                step_max = (nframes - 1) // (num_frames - 1)
                if self.sampling == "conseq":
                    if self.sampling_step == -1 or self.sampling_step * (num_frames - 1) >= nframes:
                        step = step_max
                    else:
                        step = self.sampling_step
                elif self.sampling == "random_conseq":
                    step = random.randint(1, step_max)

                lastone = step * (num_frames - 1)
                shift_max = nframes - lastone - 1
                shift = random.randint(0, max(0, shift_max - 1))
                frame_ix = shift + np.arange(0, lastone + 1, step)

            elif self.sampling == "random":
                choices = np.random.choice(range(nframes),
                                           num_frames,
                                           replace=False)
                frame_ix = sorted(choices)

            else:
                raise ValueError("Sampling not recognized.")

        inp, target = self.get_pose_data(data_index, frame_ix)

        output = {'inp': inp, 'target': target}

        if hasattr(self, 'db') and 'clip_images' in self.db.keys():
            output['clip_image'] = self.get_clip_image(data_index)

        if hasattr(self, 'db') and 'clip_pathes' in self.db.keys():
            output['clip_path'] = self.get_clip_path(data_index)

        if hasattr(self, 'db') and self.clip_label_text in self.db.keys():
            text_labels = self.get_clip_text(data_index, frame_ix)
            text_labels = " and ".join(list(np.unique(text_labels)))
            output['clip_text'] = text_labels

        if hasattr(self, 'db') and 'action_cat' in self.db.keys() and self.use_action_cat_as_text_labels:
            categories = self.get_clip_action_cat(data_index, frame_ix)
            unique_cats = np.unique(categories)
            all_valid_cats = []
            for multi_cats in unique_cats:
                for cat in multi_cats.split(","):
                    if cat not in action_label_to_idx:
                        continue
                    cat_idx = action_label_to_idx[cat]
                    if (cat_idx >= 120) or (self.only_60_classes and cat_idx >= 60) or \
                            (self.leave_out_15_classes and cat_idx in UNSUPERVISED_BABEL_ACTION_CAT_LABELS_IDXS):
                        continue
                    if self.use_only_15_classes and (cat_idx not in UNSUPERVISED_BABEL_ACTION_CAT_LABELS_IDXS):
                        continue
                    all_valid_cats.append(cat)

            # no valid category survives the filtering -> skip this sequence
            if len(all_valid_cats) == 0:
                return None

            chosen_cat = np.random.choice(all_valid_cats, size=1)[0]

            output['clip_text'] = chosen_cat
            output['y'] = action_label_to_idx[chosen_cat]
            output['all_categories'] = all_valid_cats

        return output
    def get_label_sample(self, label, n=1, return_labels=False, return_index=False):
        if self.split == 'train':
            index = self._train
        else:
            index = self._test

        action = self.label_to_action(label)
        choices = np.argwhere(np.array(self._actions)[index] == action).squeeze(1)

        if self.dataname == 'amass':
            if n == 1:
                while True:
                    idx = np.random.randint(0, len(self))
                    data = self._get_item_data_index(idx)
                    if data is None:
                        continue
                    x, y = data['inp'], data['target']
                    if y == label:
                        break
            else:
                x = []
                data_index = []
                while len(x) < n:
                    idx = np.random.randint(0, len(self))
                    data = self._get_item_data_index(idx)
                    if data is None:
                        # skip filtered-out sequences, as in the n == 1 branch
                        continue
                    x_inp, y = data['inp'], data['target']
                    if y == label:
                        x.append(x_inp)
                        data_index.append(idx)
                x = np.stack(x)
                y = label * np.ones(n, dtype=int)
        else:
            if n == 1:
                data_index = index[np.random.choice(choices)]
                data = self._get_item_data_index(data_index)
                x, y = data['inp'], data['target']
                assert label == y
                y = label
            else:
                data_index = np.random.choice(choices, n)
                x = np.stack([self._get_item_data_index(index[di])['inp'] for di in data_index])
                y = label * np.ones(n, dtype=int)

        if return_labels:
            if return_index:
                return x, y, data_index
            return x, y
        else:
            if return_index:
                return x, data_index
            return x
    def get_label_sample_batch(self, labels):
        samples = [self.get_label_sample(label, n=1, return_labels=True, return_index=False) for label in labels]
        samples = [{'inp': x[0], 'target': x[1]} for x in samples]
        batch = collate(samples)
        x = batch["x"]
        mask = batch["mask"]
        lengths = mask.sum(1)
        return x, mask, lengths
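    # Illustrative call: x, mask, lengths = dataset.get_label_sample_batch([3, 7, 11])
    # draws one sequence per requested label; mask marks the valid (non-padded) frames
    # of each collated sequence and lengths = mask.sum(1) recovers the true frame counts.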

    def get_mean_length_label(self, label):
        if self.num_frames != -1:
            return self.num_frames

        if self.split == 'train':
            index = self._train
        else:
            index = self._test

        action = self.label_to_action(label)
        choices = np.argwhere(np.array(self._actions)[index] == action).squeeze(1)
        lengths = self._num_frames_in_video[np.array(index)[choices]]

        if self.max_len == -1:
            return np.mean(lengths)
        else:
            # clip durations to the maximum allowed length
            lengths[lengths > self.max_len] = self.max_len
            return np.mean(lengths)
    def get_stats(self):
        if self.split == 'train':
            index = self._train
        else:
            index = self._test

        numframes = self._num_frames_in_video[index]
        allmeans = np.array([self.get_mean_length_label(x) for x in range(self.num_classes)])

        stats = {"name": self.dataname,
                 "number of classes": self.num_classes,
                 "number of sequences": len(self),
                 "duration: min": int(numframes.min()),
                 "duration: max": int(numframes.max()),
                 "duration: mean": int(numframes.mean()),
                 "duration mean/action: min": int(allmeans.min()),
                 "duration mean/action: max": int(allmeans.max()),
                 "duration mean/action: mean": int(allmeans.mean())}

        return stats
    def __len__(self):
        num_seq_max = getattr(self, "num_seq_max", -1)
        if num_seq_max == -1:
            from math import inf
            num_seq_max = inf

        if self.split == 'train':
            return min(len(self._train), num_seq_max)
        else:
            return min(len(self._test), num_seq_max)
    def __repr__(self):
        return f"{self.dataname} dataset: ({len(self)}, _, ..)"
    def update_parameters(self, parameters):
        # infer the motion tensor layout from the first non-empty item
        for i in range(len(self)):
            if self[i] is not None:
                self.njoints, self.nfeats, _ = self[i]['inp'].shape
                break
        parameters["num_classes"] = self.num_classes
        parameters["nfeats"] = self.nfeats
        parameters["njoints"] = self.njoints
    def shuffle(self):
        if self.split == 'train':
            random.shuffle(self._train)
        else:
            random.shuffle(self._test)

    def reset_shuffle(self):
        if self.split == 'train':
            if self._original_train is None:
                self._original_train = self._train
            else:
                self._train = self._original_train
        else:
            if self._original_test is None:
                self._original_test = self._test
            else:
                self._test = self._original_test
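
# Minimal usage sketch (illustrative; "BabelMotionDataset" is a hypothetical subclass
# providing the attributes and loaders listed near the top of the class):
#
#   dataset = BabelMotionDataset(num_frames=60, pose_rep="rot6d", split="train")
#   item = dataset[0]   # {'inp': motion tensor, 'target': label, plus optional clip_* keys}
#
# Note that __getitem__ may return None when action-category filtering leaves no valid
# label, so downstream batching code should be prepared to skip such items.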