from src.utils.typing_utils import *

import json
import os
import random

import accelerate
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from tqdm import tqdm

from src.utils.data_utils import load_surface, load_surfaces


class ObjaversePartDataset(torch.utils.data.Dataset):
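    """Per-object dataset of part surfaces paired with a rendered conditioning image.

    Each item corresponds to one object: the conditioning image is replicated once
    per part and returned together with the part surfaces (shape [N, P, 6]). Because
    the number of parts N varies between objects, this class only supports
    batch_size == 1 training; see BatchedObjaversePartDataset for batched training.
    """
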
    def __init__(
        self,
        configs: DictConfig,
        training: bool = True,
    ):
        super().__init__()

        self.configs = configs
        self.training = training

        self.min_num_parts = configs['dataset']['min_num_parts']
        self.max_num_parts = configs['dataset']['max_num_parts']
        self.val_min_num_parts = configs['val']['min_num_parts']
        self.val_max_num_parts = configs['val']['max_num_parts']
        self.max_iou_mean = configs['dataset'].get('max_iou_mean', None)
        self.max_iou_max = configs['dataset'].get('max_iou_max', None)
        self.shuffle_parts = configs['dataset']['shuffle_parts']
        self.training_ratio = configs['dataset']['training_ratio']
        self.balance_object_and_parts = configs['dataset'].get('balance_object_and_parts', False)
        self.rotating_ratio = configs['dataset'].get('rotating_ratio', 0.0)
        self.rotating_degree = configs['dataset'].get('rotating_degree', 10.0)

        self.transform = transforms.Compose([
            transforms.RandomRotation(degrees=(-self.rotating_degree, self.rotating_degree), fill=(255, 255, 255)),
        ])
        if isinstance(configs['dataset']['config'], ListConfig):
            data_configs = []
            for config in configs['dataset']['config']:
                local_data_configs = json.load(open(config))
                if self.balance_object_and_parts:
                    if self.training:
                        local_data_configs = local_data_configs[:int(len(local_data_configs) * self.training_ratio)]
                    else:
                        local_data_configs = local_data_configs[int(len(local_data_configs) * self.training_ratio):]
                        local_data_configs = [config for config in local_data_configs if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts]
                data_configs += local_data_configs
        else:
            data_configs = json.load(open(configs['dataset']['config']))

        data_configs = [config for config in data_configs if config['valid']]
        data_configs = [config for config in data_configs if self.min_num_parts <= config['num_parts'] <= self.max_num_parts]
        if self.max_iou_mean is not None and self.max_iou_max is not None:
            data_configs = [config for config in data_configs if config['iou_mean'] <= self.max_iou_mean]
            data_configs = [config for config in data_configs if config['iou_max'] <= self.max_iou_max]

        if not self.balance_object_and_parts:
            if self.training:
                data_configs = data_configs[:int(len(data_configs) * self.training_ratio)]
            else:
                data_configs = data_configs[int(len(data_configs) * self.training_ratio):]
                data_configs = [config for config in data_configs if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts]

        self.data_configs = data_configs
        self.image_size = (512, 512)

    def __len__(self) -> int:
        return len(self.data_configs)

    def _get_data_by_config(self, data_config):
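        """Load the part surfaces and conditioning image described by one config entry.

        Returns a dict with:
          - "images": [N, H, W, 3] uint8 tensor, the same image repeated once per part
          - "part_surfaces": [N, P, 6] tensor of part surfaces
        where N is the number of parts of this object.
        """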
        if 'surface_path' in data_config:
            surface_path = data_config['surface_path']
            surface_data = np.load(surface_path, allow_pickle=True).item()
            # If parts is empty, the object is the only part
            part_surfaces = surface_data['parts'] if len(surface_data['parts']) > 0 else [surface_data['object']]
            if self.shuffle_parts:
                random.shuffle(part_surfaces)
            part_surfaces = load_surfaces(part_surfaces)  # [N, P, 6]
        else:
            part_surfaces = []
            for surface_path in data_config['surface_paths']:
                surface_data = np.load(surface_path, allow_pickle=True).item()
                part_surfaces.append(load_surface(surface_data))
            part_surfaces = torch.stack(part_surfaces, dim=0)  # [N, P, 6]

        image_path = data_config['image_path']
        image = Image.open(image_path).resize(self.image_size)
        if random.random() < self.rotating_ratio:
            image = self.transform(image)
        image = np.array(image)
        image = torch.from_numpy(image).to(torch.uint8)  # [H, W, 3]
        images = torch.stack([image] * part_surfaces.shape[0], dim=0)  # [N, H, W, 3]

        return {
            "images": images,
            "part_surfaces": part_surfaces,
        }

    def __getitem__(self, idx: int):
        # This dataset only supports batch_size == 1 training,
        # because the number of parts varies between objects.
        # See BatchedObjaversePartDataset for batched training.
        data_config = self.data_configs[idx]
        data = self._get_data_by_config(data_config)
        return data


class BatchedObjaversePartDataset(ObjaversePartDataset):
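    """Variant of ObjaversePartDataset that packs objects into fixed-size part batches.

    Configs are greedily grouped so that the part counts of every `batch_size`
    consecutive items sum to exactly `batch_size`; shorter groups are padded with
    empty placeholder configs that __getitem__ returns as {} and collate_fn filters
    out. Use it with a DataLoader whose batch_size equals this dataset's batch_size,
    with shuffling disabled and collate_fn=self.collate_fn (see the usage sketch at
    the end of this file).
    """
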
    def __init__(
        self,
        configs: DictConfig,
        batch_size: int,
        is_main_process: bool = False,
        shuffle: bool = True,
        training: bool = True,
    ):
        assert training
        assert batch_size > 1
        super().__init__(configs, training)
        self.batch_size = batch_size
        self.is_main_process = is_main_process

        if batch_size < self.max_num_parts:
            self.data_configs = [config for config in self.data_configs if config['num_parts'] <= batch_size]
        if shuffle:
            random.shuffle(self.data_configs)

        self.object_configs = [config for config in self.data_configs if config['num_parts'] == 1]
        self.parts_configs = [config for config in self.data_configs if config['num_parts'] > 1]
        self.object_ratio = configs['dataset']['object_ratio']
        # Keep a fixed ratio of single-part (whole-object) configs to multi-part configs
        self.object_configs = self.object_configs[:int(len(self.parts_configs) * self.object_ratio)]
        dropped_data_configs = self.parts_configs + self.object_configs
        if shuffle:
            random.shuffle(dropped_data_configs)

        self.data_configs = self._get_batched_configs(dropped_data_configs, batch_size)

    def _get_batched_configs(self, data_configs, batch_size):
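        """Greedily pack configs into groups whose part counts sum to exactly `batch_size`.

        Configs are popped from the end of the list; a config that would overflow the
        current group is set aside and retried for later groups. Each completed group
        is padded to `batch_size` entries with empty dicts, and a final group that
        cannot be completed is dropped.
        """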
        batched_data_configs = []
        num_data_configs = len(data_configs)
        progress_bar = tqdm(
            range(num_data_configs),
            desc="Batching Dataset",
            ncols=125,
            disable=not self.is_main_process,
        )
        while len(data_configs) > 0:
            temp_batch = []
            temp_num_parts = 0
            unchosen_configs = []
            while temp_num_parts < batch_size and len(data_configs) > 0:
                config = data_configs.pop()  # pop the last config
                num_parts = config['num_parts']
                if temp_num_parts + num_parts <= batch_size:
                    temp_batch.append(config)
                    temp_num_parts += num_parts
                    progress_bar.update(1)
                else:
                    unchosen_configs.append(config)  # set aside for a later batch
            data_configs = data_configs + unchosen_configs  # concat the unchosen configs back
            if temp_num_parts == batch_size:
                # Successfully formed a batch
                if len(temp_batch) < batch_size:
                    # Pad the batch with empty placeholder configs
                    temp_batch += [{}] * (batch_size - len(temp_batch))
                batched_data_configs += temp_batch
            # Otherwise the inner loop exited because len(data_configs) == 0,
            # i.e. the remaining configs cannot fill a batch of exactly
            # batch_size parts, so the incomplete batch is dropped.
        progress_bar.close()
        return batched_data_configs

    def __getitem__(self, idx: int):
        data_config = self.data_configs[idx]
        if len(data_config) == 0:
            # Placeholder config from batch padding; filtered out in collate_fn
            return {}
        data = self._get_data_by_config(data_config)
        return data

    def collate_fn(self, batch):
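        """Collate a list of per-object dicts into a single flat part batch.

        Placeholder items ({}) are dropped, then images and part surfaces are
        concatenated along the part dimension, so the resulting batch always
        contains exactly self.batch_size parts.
        """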
        batch = [data for data in batch if len(data) > 0]
        images = torch.cat([data['images'] for data in batch], dim=0)  # [N, H, W, 3]
        surfaces = torch.cat([data['part_surfaces'] for data in batch], dim=0)  # [N, P, 6]
        num_parts = torch.LongTensor([data['part_surfaces'].shape[0] for data in batch])
        assert images.shape[0] == surfaces.shape[0] == num_parts.sum() == self.batch_size
        batch = {
            "images": images,
            "part_surfaces": surfaces,
            "num_parts": num_parts,
        }
        return batch
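

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# The config dict below mirrors the keys read in __init__; the JSON path and
# the exact values are hypothetical and will differ in a real setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf
    from torch.utils.data import DataLoader

    configs = OmegaConf.create({
        "dataset": {
            "config": "data/part_configs.json",  # hypothetical path to the data config JSON
            "min_num_parts": 1,
            "max_num_parts": 8,
            "shuffle_parts": True,
            "training_ratio": 0.95,
            "object_ratio": 0.2,
        },
        "val": {"min_num_parts": 1, "max_num_parts": 8},
    })

    # The dataset pre-packs configs so that every `batch_size` consecutive items
    # hold exactly `batch_size` parts in total. The DataLoader must therefore use
    # the same batch_size, keep shuffle=False, and use the custom collate_fn.
    dataset = BatchedObjaversePartDataset(configs, batch_size=8, is_main_process=True)
    loader = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=dataset.collate_fn)
    for batch in loader:
        print(batch["images"].shape, batch["part_surfaces"].shape, batch["num_parts"])
        break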