Spaces:
Running
on
Zero
Running
on
Zero
from typing import * | |
from numpy import ndarray | |
from torch import Tensor | |
import os | |
import json | |
from collections import defaultdict | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.nn.functional as tF | |
from kiui.cam import orbit_camera, undo_orbit_camera | |
from src.data.utils.chunk_dataset import ChunkedDataset | |
from src.options import Options | |
from src.utils import normalize_normals, unproject_depth | |
class GObjaverseParquetDataset(ChunkedDataset):
    """Multi-view dataset over GObjaverse samples stored as raw bytes.

    Each raw sample is a dict keyed by file name (e.g. `"00003.png"`, `"00003.json"`,
    `"00003_nd.png"`); `get_trainable_data_from_raw_data` turns one such sample into a
    dict of stacked per-view tensors. Which modalities are loaded is controlled by the
    `load_*` flags on `opt`.
    """

    # View layout constants; all hard-coded for GObjaverse
    _NUM_TOTAL_VIEWS = 40  # view indices are 0~39
    _DUPLICATED_VIEWS = (24, 39)  # filtered out as duplicated views
    _TOPDOWN_VIEWS = (25, 26)  # up / down views
    _MAX_VIEW_SAMPLING_TRIES = 100  # resampling budget for the `V_in` input views

    def __init__(self, opt: Options, training: bool = True, *args, **kwargs):
        """Store options, default intrinsics, optional negative prompt embeddings and backup ids.

        Args:
            opt: Project options; read for `fxfy`, `prompt_embed_dir`,
                `pretrained_model_name_or_path` and `backup_json_path` here.
            training: Whether the dataset is used for training (affects view sampling).
            *args, **kwargs: Forwarded unchanged to `ChunkedDataset.__init__`.
        """
        self.opt = opt
        self.training = training

        # Default camera intrinsics
        self.fxfycxcy = torch.tensor([opt.fxfy, opt.fxfy, 0.5, 0.5], dtype=torch.float32)  # (4,)

        # Always define the negative-prompt attributes so that reading them never raises
        # AttributeError when `opt.prompt_embed_dir` is None
        self.negative_prompt_embed = None
        self.negative_pooled_prompt_embed = None
        self.negative_prompt_attention_mask = None
        if opt.prompt_embed_dir is not None:
            self.negative_prompt_embed = self._load_optional_embed(f"{opt.prompt_embed_dir}/null.npy")
            self.negative_pooled_prompt_embed = self._load_optional_embed(f"{opt.prompt_embed_dir}/null_pooled.npy")
            self.negative_prompt_attention_mask = self._load_optional_embed(f"{opt.prompt_embed_dir}/null_attention_mask.npy")

            if "xl" in opt.pretrained_model_name_or_path:  # SDXL: zero out negative prompt embedding
                if self.negative_prompt_embed is not None and self.negative_pooled_prompt_embed is not None:
                    self.negative_prompt_embed = torch.zeros_like(self.negative_prompt_embed)
                    self.negative_pooled_prompt_embed = torch.zeros_like(self.negative_pooled_prompt_embed)

        # Backup from local disk for error data loading
        with open(opt.backup_json_path, "r") as f:
            self.backup_ids = json.load(f)

        super().__init__(*args, **kwargs)

    def __len__(self):
        """Return the dataset size configured in `opt.dataset_size`."""
        return self.opt.dataset_size

    @staticmethod
    def _load_optional_embed(path: str) -> Optional[Tensor]:
        """Load a `.npy` file as a float tensor, or return None if the file does not exist."""
        try:
            return torch.from_numpy(np.load(path)).float()
        except FileNotFoundError:
            return None

    @staticmethod
    def _view_is_available(sample: Dict[str, Union[str, bytes]], vid: int) -> bool:
        """Return True iff both the image and the camera entry of view `vid` are present and not None."""
        return sample.get(f"{vid:05d}.png") is not None and sample.get(f"{vid:05d}.json") is not None

    def _sample_view_indices(self, sample: Dict[str, Union[str, bytes]], V: int, V_in: int) -> List[int]:
        """Pick `V` view indices for one sample; the first `V_in` of them are the input views.

        The input views come from the even or the random picker (even when
        `opt.load_even_views` is set or when not training) and are resampled until all of
        them exist in `sample`. The remaining views are drawn from a random permutation of
        all views, skipping the input views, the duplicated views and (optionally) the
        top-down views. If still fewer than `V` views are found, already chosen views are
        repeated at random.

        Raises:
            ValueError: If no valid set of `V_in` input views is found within the retry budget.
        """
        if self.opt.load_even_views or not self.training:
            pick_func = self._pick_even_view_indices
        else:
            pick_func = self._pick_random_view_indices

        # Randomly sample `V_in` views (some objects may not appear in the dataset)
        view_idxs = pick_func(V_in)
        num_tries = 0
        while not self._check_views_exist(sample, view_idxs):
            view_idxs = pick_func(V_in)
            num_tries += 1
            if num_tries > self._MAX_VIEW_SAMPLING_TRIES:
                raise ValueError(f"Cannot find {V_in} views in {sample['__key__']}")

        excluded = set(view_idxs) | set(self._DUPLICATED_VIEWS)  # filter duplicated views
        if self.opt.exclude_topdown_views:
            excluded.update(self._TOPDOWN_VIEWS)

        # Randomly sample `V` views (some views may not appear in the dataset)
        for i in np.random.permutation(self._NUM_TOTAL_VIEWS).tolist():
            if len(view_idxs) >= V:
                break
            if i in excluded or not self._view_is_available(sample, i):
                continue
            try:
                # Only checks that the entry is bytes-like; the decoded image is loaded later
                np.frombuffer(sample[f"{i:05d}.png"], np.uint8)
            except (TypeError, ValueError, AttributeError):
                continue  # best effort: silently skip views that cannot be read
            view_idxs.append(i)

        # Randomly repeat views if not enough views
        while len(view_idxs) < V:
            view_idxs.append(int(np.random.choice(view_idxs)))

        return view_idxs

    def get_trainable_data_from_raw_data(self, raw_data_list) -> Dict[str, Tensor]:  # only `sample["__key__"]` is in str type
        """Convert one raw sample into a dict of stacked per-view tensors.

        Args:
            raw_data_list: A list holding exactly one raw sample dict.

        Returns:
            A plain dict of tensors. Always contains `fxfycxcy`, `image`, `mask`, `C2W` and
            `cam_pose`; `canny`, `albedo`, `normal`, `coord`, `depth`, `mr`, `prompt_embed`,
            `pooled_prompt_embed` and `prompt_attention_mask` are added according to `opt`.

        Raises:
            ValueError: If the input views cannot be found in the sample.
        """
        assert len(raw_data_list) == 1
        sample: Dict[str, bytes] = raw_data_list[0]

        V, V_in = self.opt.num_views, self.opt.num_input_views
        assert V >= V_in

        view_idxs = self._sample_view_indices(sample, V, V_in)

        # The normal & depth image is needed by any of the three modalities; `load_depth`
        # alone must also trigger loading, otherwise the depth list stays empty
        load_nd = self.opt.load_normal or self.opt.load_coord or self.opt.load_depth

        return_dict = defaultdict(list)
        init_azi = None
        for vid in view_idxs:
            return_dict["fxfycxcy"].append(self.fxfycxcy)  # (V, 4); fixed intrinsics for GObjaverse

            image = self._load_png(sample[f"{vid:05d}.png"])  # (4, 512, 512)
            mask = image[3:4]  # (1, 512, 512)
            image = image[:3] * mask + (1. - mask)  # (3, 512, 512), to white bg
            return_dict["image"].append(image)  # (V, 3, H, W)
            return_dict["mask"].append(mask)  # (V, 1, H, W)

            if self.opt.load_canny:
                gray = cv2.cvtColor(image.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2GRAY)
                canny = cv2.Canny((gray * 255.).astype(np.uint8), 100., 200.)
                canny = torch.from_numpy(canny).unsqueeze(0).float().repeat(3, 1, 1) / 255.  # (3, 512, 512) in [0, 1]
                canny = -canny + 1.  # 0->1, 1->0, i.e., white bg
                return_dict["canny"].append(canny)  # (V, 3, H, W)

            c2w = self._load_camera_from_json(sample[f"{vid:05d}.json"])
            # Blender world + OpenCV cam -> OpenGL world & cam; https://kit.kiui.moe/camera
            c2w[1] *= -1
            c2w[[1, 2]] = c2w[[2, 1]]
            c2w[:3, 1:3] *= -1  # invert up and forward direction
            return_dict["original_C2W"].append(torch.from_numpy(c2w).float())  # (V, 4, 4); for normal normalization only

            # Relative azimuth w.r.t. the first view
            ele, azi, dis = undo_orbit_camera(c2w)  # elevation: [-90, 90] from +y(-90) to -y(90)
            if init_azi is None:
                init_azi = azi
            azi = (azi - init_azi) % 360.  # azimuth: [0, 360] from +z(0) to +x(90)
            # To avoid numerical errors for elevation +/- 90 (GObjaverse index 25 (up) & 26 (down))
            ele_sign = ele >= 0
            ele = abs(ele) - 1e-8
            ele = ele * (1. if ele_sign else -1.)

            new_c2w = torch.from_numpy(orbit_camera(ele, azi, dis)).float()
            return_dict["C2W"].append(new_c2w)  # (V, 4, 4)
            return_dict["cam_pose"].append(torch.tensor(
                [np.deg2rad(ele), np.deg2rad(azi), dis], dtype=torch.float32))  # (V, 3)

            # Albedo
            if self.opt.load_albedo:
                albedo = self._load_png(sample[f"{vid:05d}_albedo.png"])  # (3, 512, 512)
                albedo = albedo * mask + (1. - mask)  # (3, 512, 512), to white bg
                return_dict["albedo"].append(albedo)  # (V, 3, H, W)

            # Normal & Depth
            if load_nd:
                nd = self._load_png(sample[f"{vid:05d}_nd.png"], uint16=True)  # (4, 512, 512)
                if self.opt.load_normal:
                    normal = nd[:3] * 2. - 1.  # (3, 512, 512) in [-1, 1]
                    normal[0, ...] *= -1  # to OpenGL world convention
                    return_dict["normal"].append(normal)  # (V, 3, H, W)
                if self.opt.load_coord or self.opt.load_depth:
                    depth = nd[3] * 5.  # (512, 512); NOTE: depth is scaled by 1/5 in my data preprocessing
                    return_dict["depth"].append(depth)  # (V, H, W)

            # Metal & Roughness
            if self.opt.load_mr:
                mr = self._load_png(sample[f"{vid:05d}_mr.png"])  # (3, 512, 512); (metallic, roughness, padding)
                mr = mr * mask + (1. - mask)  # (3, 512, 512), to white bg
                return_dict["mr"].append(mr)  # (V, 3, H, W)

        for key in return_dict.keys():
            return_dict[key] = torch.stack(return_dict[key], dim=0)

        if self.opt.load_normal:
            # Normalize normals by the first view and transform the first view to a fixed azimuth (i.e., 0)
            # Ensure `normals` and `original_C2W` are in the same camera convention
            normals = normalize_normals(return_dict["normal"].unsqueeze(0), return_dict["original_C2W"].unsqueeze(0), i=0).squeeze(0)
            normals = torch.einsum("brc,bvchw->bvrhw", return_dict["C2W"][0, :3, :3].unsqueeze(0), normals.unsqueeze(0)).squeeze(0)
            normals = normals * 0.5 + 0.5  # [0, 1]
            normals = normals * return_dict["mask"] + (1. - return_dict["mask"])  # (V, 3, 512, 512), to white bg
            return_dict["normal"] = normals
        return_dict.pop("original_C2W")  # original C2W is only used for normal normalization

        # OpenGL to COLMAP camera for Gaussian renderer
        return_dict["C2W"][:, :3, 1:3] *= -1

        # Whether scale the object w.r.t. the first view to a fixed size
        if self.opt.norm_camera:
            scale = self.opt.norm_radius / (torch.norm(return_dict["C2W"][0, :3, 3], dim=-1))
        else:
            scale = 1.
        return_dict["C2W"][:, :3, 3] *= scale
        return_dict["cam_pose"][:, 2] *= scale

        if self.opt.load_coord:
            # Unproject depth map to 3D world coordinate
            coords = unproject_depth(return_dict["depth"].unsqueeze(0) * scale,
                return_dict["C2W"].unsqueeze(0), return_dict["fxfycxcy"].unsqueeze(0)).squeeze(0)
            coords = coords * 0.5 + 0.5  # [0, 1]
            coords = coords * return_dict["mask"] + (1. - return_dict["mask"])  # (V, 3, 512, 512), to white bg
            return_dict["coord"] = coords
            if not self.opt.load_depth:
                return_dict.pop("depth")  # depth was only needed for unprojection
        if self.opt.load_depth:
            depths = return_dict["depth"].unsqueeze(1) * return_dict["mask"]  # (V, 1, 512, 512), to black bg
            assert depths.min() == 0.
            if self.opt.normalize_depth:
                H, W = depths.shape[-2:]
                depths = depths.reshape(V, -1)
                # Clamp so that a fully masked-out view (max depth 0) does not produce 0/0 = NaN
                depths_max = depths.max(dim=-1, keepdim=True).values.clamp(min=1e-8)
                depths = depths / depths_max  # [0, 1]
                depths = depths.reshape(V, 1, H, W)
            depths = -depths + 1.  # 0->1, 1->0, i.e., white bg
            return_dict["depth"] = depths.repeat(1, 3, 1, 1)

        # Resize to the input resolution
        for key in ["image", "mask", "albedo", "normal", "coord", "depth", "mr", "canny"]:
            if key in return_dict.keys():
                return_dict[key] = tF.interpolate(
                    return_dict[key], size=(self.opt.input_res, self.opt.input_res),
                    mode="bilinear", align_corners=False, antialias=True
                )

        # Handle anti-aliased normal, coord and depth (GObjaverse renders anti-aliased normal & depth)
        if self.opt.load_normal:
            return_dict["normal"] = return_dict["normal"] * return_dict["mask"] + (1. - return_dict["mask"])
        if self.opt.load_coord:
            return_dict["coord"] = return_dict["coord"] * return_dict["mask"] + (1. - return_dict["mask"])
        if self.opt.load_depth:
            return_dict["depth"] = return_dict["depth"] * return_dict["mask"] + (1. - return_dict["mask"])

        # Load precomputed caption embeddings
        if self.opt.prompt_embed_dir is not None:
            uid = sample["uid"].decode("utf-8").split("/")[-1].split(".")[0]
            return_dict["prompt_embed"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}.npy"))
            if "xl" in self.opt.pretrained_model_name_or_path or "3" in self.opt.pretrained_model_name_or_path:  # SDXL or SD3
                return_dict["pooled_prompt_embed"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}_pooled.npy"))
            if "PixArt" in self.opt.pretrained_model_name_or_path:  # PixArt-alpha, PixArt-Sigma
                return_dict["prompt_attention_mask"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}_attention_mask.npy"))

        for key in return_dict.keys():
            assert isinstance(return_dict[key], Tensor), f"Value of the key [{key}] is not a Tensor, but {type(return_dict[key])}."

        return dict(return_dict)

    def _load_png(self, png_bytes: Union[bytes, str], uint16: bool = False) -> Tensor:
        """Decode encoded PNG bytes into a float tensor of shape (C, H, W) scaled to [0, 1].

        Args:
            png_bytes: Encoded image buffer.
            uint16: If True, divide by 65535 (16-bit image); otherwise divide by 255.
        """
        png = np.frombuffer(png_bytes, np.uint8)
        png = cv2.imdecode(png, cv2.IMREAD_UNCHANGED)  # (H, W, C) ndarray in [0, 255] or [0, 65535]
        png = png.astype(np.float32) / (65535. if uint16 else 255.)  # (H, W, C) in [0, 1]
        png[:, :, :3] = png[:, :, :3][..., ::-1]  # BGR -> RGB; a 4th channel (if any) is left in place
        png_tensor = torch.from_numpy(png).nan_to_num_(0.)  # there are nan in GObjaverse gt normal
        return png_tensor.permute(2, 0, 1)  # (C, H, W) in [0, 1]

    def _load_camera_from_json(self, json_bytes: Union[bytes, str]) -> ndarray:
        """Build a 4x4 camera-to-world matrix from a camera JSON.

        Args:
            json_bytes: Either the raw JSON bytes, or (backup path) a name that is resolved
                to `<opt.backup_file_dir>/<name>.json` on local disk.

        Returns:
            A float64 (4, 4) matrix whose first three columns are the JSON's `x`, `y`, `z`
            vectors and whose translation column is `origin`.
        """
        if isinstance(json_bytes, bytes):
            json_dict = json.loads(json_bytes)
        else:  # BACKUP
            path = os.path.join(self.opt.backup_file_dir, f"{json_bytes}.json")
            with open(path, "r") as f:
                json_dict = json.load(f)

        # In OpenCV convention
        c2w = np.eye(4)  # float64
        c2w[:3, 0] = np.array(json_dict["x"])
        c2w[:3, 1] = np.array(json_dict["y"])
        c2w[:3, 2] = np.array(json_dict["z"])
        c2w[:3, 3] = np.array(json_dict["origin"])
        return c2w

    def _pick_even_view_indices(self, num_views: int = 4) -> List[int]:
        """Pick `num_views` evenly spaced view indices from one of the two orbit rings.

        With probability 2/3 the indices come from views 0~23, otherwise from views 27~38;
        the starting index within the ring is random.
        """
        assert 12 % num_views == 0  # `12` for even-view sampling in GObjaverse

        if np.random.rand() < 2/3:
            index0 = np.random.choice(range(24))  # 0~23: 24 views in ele from [5, 30]; hard-coded for GObjaverse
            return [(index0 + (24 // num_views)*i) % 24 for i in range(num_views)]
        else:
            index0 = np.random.choice(range(12))  # 27~38: 12 views in ele from [-5, 5]; hard-coded for GObjaverse
            return [((index0 + (12 // num_views)*i) % 12 + 27) for i in range(num_views)]

    def _pick_random_view_indices(self, num_views: int = 4) -> List[int]:
        """Pick `num_views` distinct view indices uniformly at random.

        The top-down views are left out of the candidates when `opt.exclude_topdown_views` is set.
        """
        assert num_views <= self._NUM_TOTAL_VIEWS
        # Explicit ascending list instead of relying on set iteration order
        candidates = [
            i for i in range(self._NUM_TOTAL_VIEWS)
            if not (self.opt.exclude_topdown_views and i in self._TOPDOWN_VIEWS)
        ]
        return np.random.choice(candidates, num_views, replace=False).tolist()

    def _check_views_exist(self, sample: Dict[str, Union[str, bytes]], vids: List[int]) -> bool:
        """Return True iff every view in `vids` has a non-None image and camera entry in `sample`."""
        # Explicit checks instead of `assert`, which is stripped under `python -O`
        return all(self._view_is_available(sample, vid) for vid in vids)

    def _check_views_exist_disk(self, uid: str, vids: List[int]) -> bool:
        """Return True iff `<uid>.<vid>.png` and `<uid>.<vid>.json` exist in `opt.backup_file_dir` for all `vids`."""
        backup_dir = self.opt.backup_file_dir
        return all(
            os.path.exists(os.path.join(backup_dir, f"{uid}.{vid:05d}.png"))
            and os.path.exists(os.path.join(backup_dir, f"{uid}.{vid:05d}.json"))
            for vid in vids
        )