# PartCrafter/src/datasets/objaverse_part.py
from src.utils.typing_utils import *  # provides DictConfig / ListConfig (omegaconf) used below
import json
import os
import random
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from tqdm import tqdm
from src.utils.data_utils import load_surface, load_surfaces
class ObjaversePartDataset(torch.utils.data.Dataset):
def __init__(
self,
configs: DictConfig,
training: bool = True,
):
super().__init__()
self.configs = configs
self.training = training
self.min_num_parts = configs['dataset']['min_num_parts']
self.max_num_parts = configs['dataset']['max_num_parts']
self.val_min_num_parts = configs['val']['min_num_parts']
self.val_max_num_parts = configs['val']['max_num_parts']
self.max_iou_mean = configs['dataset'].get('max_iou_mean', None)
self.max_iou_max = configs['dataset'].get('max_iou_max', None)
self.shuffle_parts = configs['dataset']['shuffle_parts']
self.training_ratio = configs['dataset']['training_ratio']
self.balance_object_and_parts = configs['dataset'].get('balance_object_and_parts', False)
self.rotating_ratio = configs['dataset'].get('rotating_ratio', 0.0)
self.rotating_degree = configs['dataset'].get('rotating_degree', 10.0)
self.transform = transforms.Compose([
transforms.RandomRotation(degrees=(-self.rotating_degree, self.rotating_degree), fill=(255, 255, 255)),
])
if isinstance(configs['dataset']['config'], ListConfig):
data_configs = []
for config in configs['dataset']['config']:
                with open(config, "r") as f:
                    local_data_configs = json.load(f)
if self.balance_object_and_parts:
if self.training:
local_data_configs = local_data_configs[:int(len(local_data_configs) * self.training_ratio)]
else:
local_data_configs = local_data_configs[int(len(local_data_configs) * self.training_ratio):]
local_data_configs = [config for config in local_data_configs if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts]
data_configs += local_data_configs
else:
            with open(configs['dataset']['config'], "r") as f:
                data_configs = json.load(f)
data_configs = [config for config in data_configs if config['valid']]
data_configs = [config for config in data_configs if self.min_num_parts <= config['num_parts'] <= self.max_num_parts]
if self.max_iou_mean is not None and self.max_iou_max is not None:
data_configs = [config for config in data_configs if config['iou_mean'] <= self.max_iou_mean]
data_configs = [config for config in data_configs if config['iou_max'] <= self.max_iou_max]
if not self.balance_object_and_parts:
if self.training:
data_configs = data_configs[:int(len(data_configs) * self.training_ratio)]
else:
data_configs = data_configs[int(len(data_configs) * self.training_ratio):]
data_configs = [config for config in data_configs if self.val_min_num_parts <= config['num_parts'] <= self.val_max_num_parts]
self.data_configs = data_configs
self.image_size = (512, 512)
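    # A hypothetical config entry, inferred from the fields read above; the
    # exact schema is produced by the dataset-preprocessing scripts:
    #   {
    #       "valid": true,
    #       "num_parts": 4,
    #       "iou_mean": 0.02,
    #       "iou_max": 0.10,
    #       "surface_path": "<path to an .npy dict with 'object' and 'parts'>",
    #       "image_path": "<path to the rendered condition image>"
    #   }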
def __len__(self) -> int:
return len(self.data_configs)
def _get_data_by_config(self, data_config):
if 'surface_path' in data_config:
surface_path = data_config['surface_path']
surface_data = np.load(surface_path, allow_pickle=True).item()
# If parts is empty, the object is the only part
part_surfaces = surface_data['parts'] if len(surface_data['parts']) > 0 else [surface_data['object']]
if self.shuffle_parts:
random.shuffle(part_surfaces)
part_surfaces = load_surfaces(part_surfaces) # [N, P, 6]
else:
part_surfaces = []
for surface_path in data_config['surface_paths']:
surface_data = np.load(surface_path, allow_pickle=True).item()
part_surfaces.append(load_surface(surface_data))
part_surfaces = torch.stack(part_surfaces, dim=0) # [N, P, 6]
image_path = data_config['image_path']
        image = Image.open(image_path).convert("RGB").resize(self.image_size)  # force 3 channels for the RGB rotation fill and the [H, W, 3] stacking below
if random.random() < self.rotating_ratio:
image = self.transform(image)
image = np.array(image)
image = torch.from_numpy(image).to(torch.uint8) # [H, W, 3]
images = torch.stack([image] * part_surfaces.shape[0], dim=0) # [N, H, W, 3]
return {
"images": images,
"part_surfaces": part_surfaces,
}
def __getitem__(self, idx: int):
        # This dataset only supports batch_size == 1 training,
        # because the number of parts varies across objects.
        # See BatchedObjaversePartDataset for batched training.
data_config = self.data_configs[idx]
data = self._get_data_by_config(data_config)
return data
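
# Usage sketch for the unbatched dataset (illustrative; the YAML path is an
# assumption, and `configs` must contain the keys read in __init__ above):
#
#   from omegaconf import OmegaConf
#   configs = OmegaConf.load("configs/train.yaml")  # hypothetical path
#   dataset = ObjaversePartDataset(configs, training=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
#   sample = next(iter(loader))
#   # sample["images"]: [1, N, H, W, 3], sample["part_surfaces"]: [1, N, P, 6],
#   # where N is the number of parts of the sampled object.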
class BatchedObjaversePartDataset(ObjaversePartDataset):
def __init__(
self,
configs: DictConfig,
batch_size: int,
is_main_process: bool = False,
shuffle: bool = True,
training: bool = True,
):
        assert training  # the batched dataset is intended for training only
        assert batch_size > 1  # use ObjaversePartDataset for batch_size == 1
super().__init__(configs, training)
self.batch_size = batch_size
self.is_main_process = is_main_process
        if batch_size < self.max_num_parts:
            # Drop objects with more parts than fit into a single batch
            self.data_configs = [config for config in self.data_configs if config['num_parts'] <= batch_size]
if shuffle:
random.shuffle(self.data_configs)
self.object_configs = [config for config in self.data_configs if config['num_parts'] == 1]
self.parts_configs = [config for config in self.data_configs if config['num_parts'] > 1]
self.object_ratio = configs['dataset']['object_ratio']
        # Keep the number of single-part objects proportional to the multi-part ones
        self.object_configs = self.object_configs[:int(len(self.parts_configs) * self.object_ratio)]
        dropped_data_configs = self.parts_configs + self.object_configs  # configs kept after subsampling the objects
if shuffle:
random.shuffle(dropped_data_configs)
self.data_configs = self._get_batched_configs(dropped_data_configs, batch_size)
def _get_batched_configs(self, data_configs, batch_size):
batched_data_configs = []
        num_data_configs = len(data_configs)  # snapshot the total before the list is consumed
        progress_bar = tqdm(
            range(num_data_configs),
desc="Batching Dataset",
ncols=125,
disable=not self.is_main_process,
)
while len(data_configs) > 0:
temp_batch = []
temp_num_parts = 0
unchosen_configs = []
while temp_num_parts < batch_size and len(data_configs) > 0:
config = data_configs.pop() # pop the last config
num_parts = config['num_parts']
if temp_num_parts + num_parts <= batch_size:
temp_batch.append(config)
temp_num_parts += num_parts
progress_bar.update(1)
                else:
                    unchosen_configs.append(config)  # set this config aside; it does not fit in the current batch
            data_configs = data_configs + unchosen_configs  # re-queue the set-aside configs
            if temp_num_parts == batch_size:
                # Successfully formed a batch
                if len(temp_batch) < batch_size:
                    # Pad with empty placeholders so each group has exactly batch_size entries
                    temp_batch += [{}] * (batch_size - len(temp_batch))
                batched_data_configs += temp_batch
            # Otherwise len(data_configs) == 0: the remaining configs cannot
            # fill a batch exactly, so the incomplete batch is dropped.
progress_bar.close()
return batched_data_configs
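    # Worked example of the greedy packing above (batch_size = 4): given
    # configs with num_parts [3, 2, 1, 1], popped from the end as 1, 1, 2, 3,
    # the first pass packs 1 + 1 + 2 == 4 into a full batch of three configs
    # and pads it with one {} placeholder; the remaining 3-part object can no
    # longer fill a batch exactly and is dropped.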
def __getitem__(self, idx: int):
data_config = self.data_configs[idx]
        if len(data_config) == 0:
            # Padding placeholder inserted by _get_batched_configs;
            # filtered out again in collate_fn
            return {}
data = self._get_data_by_config(data_config)
return data
def collate_fn(self, batch):
        batch = [data for data in batch if len(data) > 0]  # drop padding placeholders
images = torch.cat([data['images'] for data in batch], dim=0) # [N, H, W, 3]
surfaces = torch.cat([data['part_surfaces'] for data in batch], dim=0) # [N, P, 6]
num_parts = torch.LongTensor([data['part_surfaces'].shape[0] for data in batch])
assert images.shape[0] == surfaces.shape[0] == num_parts.sum() == self.batch_size
batch = {
"images": images,
"part_surfaces": surfaces,
"num_parts": num_parts,
}
return batch
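
# Usage sketch for the batched dataset (illustrative; the YAML path is an
# assumption). The DataLoader's batch_size must equal the dataset's batch_size
# and shuffle must stay False, since _get_batched_configs has already arranged
# the configs into contiguous groups of batch_size entries:
#
#   from omegaconf import OmegaConf
#   configs = OmegaConf.load("configs/train.yaml")  # hypothetical path
#   dataset = BatchedObjaversePartDataset(configs, batch_size=8, is_main_process=True)
#   loader = torch.utils.data.DataLoader(
#       dataset,
#       batch_size=8,
#       shuffle=False,
#       collate_fn=dataset.collate_fn,
#   )
#   batch = next(iter(loader))
#   # batch["images"]: [8, H, W, 3], batch["part_surfaces"]: [8, P, 6],
#   # batch["num_parts"].sum() == 8.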