from src.utils.typing_utils import *

import json
import os
import random

import accelerate
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from tqdm import tqdm

from src.utils.data_utils import load_surface, load_surfaces


class ObjaversePartDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        configs: DictConfig,
        training: bool = True,
    ):
        super().__init__()
        self.configs = configs
        self.training = training

        self.min_num_parts = configs['dataset']['min_num_parts']
        self.max_num_parts = configs['dataset']['max_num_parts']
        self.val_min_num_parts = configs['val']['min_num_parts']
        self.val_max_num_parts = configs['val']['max_num_parts']
        self.max_iou_mean = configs['dataset'].get('max_iou_mean', None)
        self.max_iou_max = configs['dataset'].get('max_iou_max', None)
        self.shuffle_parts = configs['dataset']['shuffle_parts']
        self.training_ratio = configs['dataset']['training_ratio']
        self.balance_object_and_parts = configs['dataset'].get('balance_object_and_parts', False)
        self.rotating_ratio = configs['dataset'].get('rotating_ratio', 0.0)
        self.rotating_degree = configs['dataset'].get('rotating_degree', 10.0)
        self.transform = transforms.Compose([
            transforms.RandomRotation(degrees=(-self.rotating_degree, self.rotating_degree), fill=(255, 255, 255)),
        ])

        if isinstance(configs['dataset']['config'], ListConfig):
            # Multiple config files: split each one into train/val locally.
            data_configs = []
            for config in configs['dataset']['config']:
                local_data_configs = json.load(open(config))
                if self.balance_object_and_parts:
                    if self.training:
                        local_data_configs = local_data_configs[:int(len(local_data_configs) * self.training_ratio)]
                    else:
                        local_data_configs = local_data_configs[int(len(local_data_configs) * self.training_ratio):]
                        local_data_configs = [
                            config for config in local_data_configs
                            if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts
                        ]
                data_configs += local_data_configs
        else:
            data_configs = json.load(open(configs['dataset']['config']))

        data_configs = [config for config in data_configs if config['valid']]
        data_configs = [config for config in data_configs if self.min_num_parts <= config['num_parts'] <= self.max_num_parts]
        if self.max_iou_mean is not None and self.max_iou_max is not None:
            data_configs = [config for config in data_configs if config['iou_mean'] <= self.max_iou_mean]
            data_configs = [config for config in data_configs if config['iou_max'] <= self.max_iou_max]

        if not self.balance_object_and_parts:
            # Single global train/val split.
            if self.training:
                data_configs = data_configs[:int(len(data_configs) * self.training_ratio)]
            else:
                data_configs = data_configs[int(len(data_configs) * self.training_ratio):]
                data_configs = [
                    config for config in data_configs
                    if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts
                ]

        self.data_configs = data_configs
        self.image_size = (512, 512)

    def __len__(self) -> int:
        return len(self.data_configs)

    def _get_data_by_config(self, data_config):
        if 'surface_path' in data_config:
            # All part surfaces stored in a single file.
            surface_path = data_config['surface_path']
            surface_data = np.load(surface_path, allow_pickle=True).item()
            # If parts is empty, the object is the only part.
            part_surfaces = surface_data['parts'] if len(surface_data['parts']) > 0 else [surface_data['object']]
            if self.shuffle_parts:
                random.shuffle(part_surfaces)
            part_surfaces = load_surfaces(part_surfaces)  # [N, P, 6]
        else:
            # One file per part surface.
            part_surfaces = []
            for surface_path in data_config['surface_paths']:
                surface_data = np.load(surface_path, allow_pickle=True).item()
                part_surfaces.append(load_surface(surface_data))
            part_surfaces = torch.stack(part_surfaces, dim=0)  # [N, P, 6]

        image_path = data_config['image_path']
        image = Image.open(image_path).resize(self.image_size)
        if random.random() < self.rotating_ratio:
            image = self.transform(image)
        image = np.array(image)
        image = torch.from_numpy(image).to(torch.uint8)  # [H, W, 3]
        # Repeat the same conditioning image for every part of the object.
        images = torch.stack([image] * part_surfaces.shape[0], dim=0)  # [N, H, W, 3]

        return {
            "images": images,
            "part_surfaces": part_surfaces,
        }

    def __getitem__(self, idx: int):
        # This dataset only supports batch_size == 1 training,
        # because the number of parts varies per object.
        # See BatchedObjaversePartDataset for batched training.
        data_config = self.data_configs[idx]
        data = self._get_data_by_config(data_config)
        return data


class BatchedObjaversePartDataset(ObjaversePartDataset):
    def __init__(
        self,
        configs: DictConfig,
        batch_size: int,
        is_main_process: bool = False,
        shuffle: bool = True,
        training: bool = True,
    ):
        assert training
        assert batch_size > 1
        super().__init__(configs, training)
        self.batch_size = batch_size
        self.is_main_process = is_main_process

        if batch_size < self.max_num_parts:
            # Drop objects that cannot fit into a single batch.
            self.data_configs = [config for config in self.data_configs if config['num_parts'] <= batch_size]
        if shuffle:
            random.shuffle(self.data_configs)

        self.object_configs = [config for config in self.data_configs if config['num_parts'] == 1]
        self.parts_configs = [config for config in self.data_configs if config['num_parts'] > 1]
        self.object_ratio = configs['dataset']['object_ratio']
        # Keep the ratio of single-part (whole-object) samples to multi-part samples.
        self.object_configs = self.object_configs[:int(len(self.parts_configs) * self.object_ratio)]
        dropped_data_configs = self.parts_configs + self.object_configs
        if shuffle:
            random.shuffle(dropped_data_configs)

        self.data_configs = self._get_batched_configs(dropped_data_configs, batch_size)

    def _get_batched_configs(self, data_configs, batch_size):
        # Greedy bin-packing: fill each batch until the part counts sum
        # exactly to batch_size.
        batched_data_configs = []
        progress_bar = tqdm(
            range(len(data_configs)),
            desc="Batching Dataset",
            ncols=125,
            disable=not self.is_main_process,
        )
        while len(data_configs) > 0:
            temp_batch = []
            temp_num_parts = 0
            unchosen_configs = []
            while temp_num_parts < batch_size and len(data_configs) > 0:
                config = data_configs.pop()  # pop the last config
                num_parts = config['num_parts']
                if temp_num_parts + num_parts <= batch_size:
                    temp_batch.append(config)
                    temp_num_parts += num_parts
                    progress_bar.update(1)
                else:
                    unchosen_configs.append(config)  # set aside for later batches
            data_configs = data_configs + unchosen_configs  # concat the unchosen configs
            if temp_num_parts == batch_size:
                # Successfully assembled a batch.
                if len(temp_batch) < batch_size:
                    # Pad the batch with empty placeholder configs.
                    temp_batch += [{}] * (batch_size - len(temp_batch))
                batched_data_configs += temp_batch
            # Otherwise, we reach here because len(data_configs) == 0,
            # meaning the remaining configs cannot combine to exactly
            # fill a batch. Thus, drop the incomplete batch.
        progress_bar.close()
        return batched_data_configs

    def __getitem__(self, idx: int):
        data_config = self.data_configs[idx]
        if len(data_config) == 0:
            # Placeholder entry from batch padding; filtered out in collate_fn.
            return {}
        data = self._get_data_by_config(data_config)
        return data

    def collate_fn(self, batch):
        batch = [data for data in batch if len(data) > 0]
        images = torch.cat([data['images'] for data in batch], dim=0)  # [N, H, W, 3]
        surfaces = torch.cat([data['part_surfaces'] for data in batch], dim=0)  # [N, P, 6]
        num_parts = torch.LongTensor([data['part_surfaces'].shape[0] for data in batch])
        assert images.shape[0] == surfaces.shape[0] == num_parts.sum() == self.batch_size
        batch = {
            "images": images,
            "part_surfaces": surfaces,
            "num_parts": num_parts,
        }
        return batch
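

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API): one way to
# wire BatchedObjaversePartDataset into a torch DataLoader. The config keys
# mirror the ones read in __init__ above; the file path and values here are
# hypothetical placeholders. Because the dataset precomputes fixed groups of
# `batch_size` part-entries, the DataLoader must use the same batch_size with
# shuffle=False, and must use the dataset's own collate_fn so padding
# placeholders are filtered out.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    configs = OmegaConf.create({
        "dataset": {
            "config": "data/part_configs.json",  # hypothetical config path
            "min_num_parts": 1,
            "max_num_parts": 8,
            "shuffle_parts": True,
            "training_ratio": 0.95,
            "object_ratio": 0.1,
        },
        "val": {"min_num_parts": 1, "max_num_parts": 8},
    })

    dataset = BatchedObjaversePartDataset(configs, batch_size=8, is_main_process=True)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,      # must equal the dataset's batch_size
        shuffle=False,     # batches are precomputed; shuffling would break them
        collate_fn=dataset.collate_fn,
    )
    batch = next(iter(loader))
    # batch["images"]: [8, 512, 512, 3] uint8 conditioning images
    # batch["part_surfaces"]: [8, P, 6] sampled surface points
    # batch["num_parts"]: per-object part counts, summing to 8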