File size: 14,918 Bytes
476e0f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from typing import *
from numpy import ndarray
from torch import Tensor

import os
import json
from collections import defaultdict

import cv2
import numpy as np
import torch
import torch.nn.functional as tF
from kiui.cam import orbit_camera, undo_orbit_camera

from src.data.utils.chunk_dataset import ChunkedDataset
from src.options import Options
from src.utils import normalize_normals, unproject_depth


class GObjaverseParquetDataset(ChunkedDataset):
    def __init__(self, opt: Options, training: bool = True, *args, **kwargs):
        """Set up intrinsics, optional negative prompt embeddings and the backup id list.

        Args:
            opt: Run options; this method reads `fxfy`, `prompt_embed_dir`,
                `pretrained_model_name_or_path` and `backup_json_path`.
            training: Stored on the instance as `self.training`.
            *args: Forwarded untouched to the parent constructor.
            **kwargs: Forwarded untouched to the parent constructor.
        """
        self.opt = opt
        self.training = training

        # Default camera intrinsics
        self.fxfycxcy = torch.tensor([opt.fxfy, opt.fxfy, 0.5, 0.5], dtype=torch.float32)  # (4,)

        # Always define the negative-prompt attributes so that reading them never raises
        # `AttributeError` when no prompt embedding directory is configured.
        self.negative_prompt_embed = None
        self.negative_pooled_prompt_embed = None
        self.negative_prompt_attention_mask = None

        if opt.prompt_embed_dir is not None:
            # Each file is optional: a missing one simply leaves its attribute as `None`
            optional_embeds = (
                ("negative_prompt_embed", "null.npy"),
                ("negative_pooled_prompt_embed", "null_pooled.npy"),
                ("negative_prompt_attention_mask", "null_attention_mask.npy"),
            )
            for attr_name, file_name in optional_embeds:
                try:
                    embed = torch.from_numpy(np.load(f"{opt.prompt_embed_dir}/{file_name}")).float()
                except FileNotFoundError:
                    embed = None
                setattr(self, attr_name, embed)

            if "xl" in opt.pretrained_model_name_or_path:  # SDXL: zero out negative prompt embedding
                if self.negative_prompt_embed is not None and self.negative_pooled_prompt_embed is not None:
                    self.negative_prompt_embed = torch.zeros_like(self.negative_prompt_embed)
                    self.negative_pooled_prompt_embed = torch.zeros_like(self.negative_pooled_prompt_embed)

        # Backup from local disk for error data loading
        with open(opt.backup_json_path, "r", encoding="utf-8") as f:
            self.backup_ids = json.load(f)

        super().__init__(*args, **kwargs)

    def __len__(self):
        return self.opt.dataset_size

    def get_trainable_data_from_raw_data(self, raw_data_list) -> Dict[str, Tensor]:  # only `sample["__key__"]` is in str type
        """Turn one raw multi-view sample into a dict of stacked per-view tensors.

        The method selects `opt.num_input_views` input views followed by extra views
        (up to `opt.num_views` in total), decodes the requested modalities for each
        view, re-expresses the cameras with azimuth relative to the first view, resizes
        every image-like tensor to `opt.input_res` and optionally attaches precomputed
        prompt embeddings.

        Args:
            raw_data_list: A list holding exactly one sample that maps entry names such
                as `"00000.png"` / `"00000.json"` to their raw content.

        Returns:
            A plain dict whose values are all tensors; which keys are present depends on
            the `opt.load_*` flags and on `opt.prompt_embed_dir`.

        Raises:
            ValueError: If no valid set of input views is found after 100 retries.
        """
        assert len(raw_data_list) == 1
        sample: Dict[str, bytes] = raw_data_list[0]

        V, V_in = self.opt.num_views, self.opt.num_input_views
        assert V >= V_in

        if self.opt.load_even_views or not self.training:
            _pick_func = self._pick_even_view_indices
        else:
            _pick_func = self._pick_random_view_indices

        # Randomly sample `V_in` views (some objects may not appear in the dataset)
        random_idxs = _pick_func(V_in)
        _num_tries = 0
        while not self._check_views_exist(sample, random_idxs):
            random_idxs = _pick_func(V_in)
            _num_tries += 1
            if _num_tries > 100:  # TODO: make `100` configurable
                raise ValueError(f"Cannot find {V_in} views in {sample['__key__']}")

        except_idxs = random_idxs + [24, 39]  # filter duplicated views; hard-coded for GObjaverse
        if self.opt.exclude_topdown_views:
            except_idxs += [25, 26]

        # Randomly sample `V` views (some views may not appear in the dataset)
        for i in np.random.permutation(40):  # `40` is hard-coded for GObjaverse
            if len(random_idxs) >= V:
                break
            if f"{i:05d}.png" in sample and i not in except_idxs:
                # Explicit check instead of `assert`, so it still runs under `python -O`;
                # `Exception` (not a bare `except`) keeps the best-effort skip without
                # swallowing `KeyboardInterrupt` / `SystemExit`.
                try:
                    _ = np.frombuffer(sample[f"{i:05d}.png"], np.uint8)
                    camera_available = sample[f"{i:05d}.json"] is not None
                except Exception:  # TypeError: a bytes-like object is required, not 'NoneType'; KeyError: '00001.json'
                    camera_available = False
                if camera_available:
                    random_idxs.append(i)
        # Randomly repeat views if not enough views
        while len(random_idxs) < V:
            random_idxs.append(np.random.choice(random_idxs))

        return_dict = defaultdict(list)
        init_azi = None
        for vid in random_idxs:
            return_dict["fxfycxcy"].append(self.fxfycxcy)  # (V, 4); fixed intrinsics for GObjaverse

            image = self._load_png(sample[f"{vid:05d}.png"])  # (4, 512, 512)
            mask = image[3:4]  # (1, 512, 512)
            image = image[:3] * mask + (1. - mask)  # (3, 512, 512), to white bg
            return_dict["image"].append(image)  # (V, 3, H, W)
            return_dict["mask"].append(mask)  # (V, 1, H, W)

            if self.opt.load_canny:
                gray = cv2.cvtColor(image.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2GRAY)
                canny = cv2.Canny((gray * 255.).astype(np.uint8), 100., 200.)
                canny = torch.from_numpy(canny).unsqueeze(0).float().repeat(3, 1, 1) / 255.  # (3, 512, 512) in [0, 1]
                canny = -canny + 1.  # 0->1, 1->0, i.e., white bg
                return_dict["canny"].append(canny)  # (V, 3, H, W)

            c2w = self._load_camera_from_json(sample[f"{vid:05d}.json"])
            # Blender world + OpenCV cam -> OpenGL world & cam; https://kit.kiui.moe/camera
            c2w[1] *= -1
            c2w[[1, 2]] = c2w[[2, 1]]
            c2w[:3, 1:3] *= -1  # invert up and forward direction
            return_dict["original_C2W"].append(torch.from_numpy(c2w).float())  # (V, 4, 4); for normal normalization only

            # Relative azimuth w.r.t. the first view
            ele, azi, dis = undo_orbit_camera(c2w)  # elevation: [-90, 90] from +y(-90) to -y(90)
            if init_azi is None:
                init_azi = azi
            azi = (azi - init_azi) % 360.  # azimuth: [0, 360] from +z(0) to +x(90)
            # To avoid numerical errors for elevation +/- 90 (GObjaverse index 25 (up) & 26 (down))
            ele_sign = ele >= 0
            ele = abs(ele) - 1e-8
            ele = ele * (1. if ele_sign else -1.)

            new_c2w = torch.from_numpy(orbit_camera(ele, azi, dis)).float()
            return_dict["C2W"].append(new_c2w)  # (V, 4, 4)
            return_dict["cam_pose"].append(torch.tensor(
                [np.deg2rad(ele), np.deg2rad(azi), dis], dtype=torch.float32))  # (V, 3)

            # Albedo
            if self.opt.load_albedo:
                albedo = self._load_png(sample[f"{vid:05d}_albedo.png"])  # (3, 512, 512)
                albedo = albedo * mask + (1. - mask)  # (3, 512, 512), to white bg
                return_dict["albedo"].append(albedo)  # (V, 3, H, W)
            # Normal & Depth
            # `load_depth` must also trigger the decode: otherwise a depth-only configuration
            # never fills `return_dict["depth"]` and the depth post-processing below fails.
            if self.opt.load_normal or self.opt.load_coord or self.opt.load_depth:
                nd = self._load_png(sample[f"{vid:05d}_nd.png"], uint16=True)  # (4, 512, 512)
                if self.opt.load_normal:
                    normal = nd[:3] * 2. - 1.  # (3, 512, 512) in [-1, 1]
                    normal[0, ...] *= -1  # to OpenGL world convention
                    return_dict["normal"].append(normal)  # (V, 3, H, W)
                if self.opt.load_coord or self.opt.load_depth:
                    depth = nd[3] * 5.  # (512, 512); NOTE: depth is scaled by 1/5 in my data preprocessing
                    return_dict["depth"].append(depth)  # (V, H, W)
            # Metal & Roughness
            if self.opt.load_mr:
                mr = self._load_png(sample[f"{vid:05d}_mr.png"])  # (3, 512, 512); (metallic, roughness, padding)
                mr = mr * mask + (1. - mask)  # (3, 512, 512), to white bg
                return_dict["mr"].append(mr)  # (V, 3, H, W)

        for key in return_dict.keys():
            return_dict[key] = torch.stack(return_dict[key], dim=0)

        # NOTE(review): `original_C2W` is only removed inside the `load_normal` branch, so it
        # stays in the returned dict when normals are not loaded - confirm this is intended.
        if self.opt.load_normal:
            # Normalize normals by the first view and transform the first view to a fixed azimuth (i.e., 0)
            # Ensure `normals` and `original_C2W` are in the same camera convention
            normals = normalize_normals(return_dict["normal"].unsqueeze(0), return_dict["original_C2W"].unsqueeze(0), i=0).squeeze(0)
            normals = torch.einsum("brc,bvchw->bvrhw", return_dict["C2W"][0, :3, :3].unsqueeze(0), normals.unsqueeze(0)).squeeze(0)
            normals = normals * 0.5 + 0.5  # [0, 1]
            normals = normals * return_dict["mask"] + (1. - return_dict["mask"])  # (V, 3, 512, 512), to white bg
            return_dict["normal"] = normals
            return_dict.pop("original_C2W")  # original C2W is only used for normal normalization

        # OpenGL to COLMAP camera for Gaussian renderer
        return_dict["C2W"][:, :3, 1:3] *= -1

        # Whether scale the object w.r.t. the first view to a fixed size
        if self.opt.norm_camera:
            scale = self.opt.norm_radius / (torch.norm(return_dict["C2W"][0, :3, 3], dim=-1))
        else:
            scale = 1.
        return_dict["C2W"][:, :3, 3] *= scale
        return_dict["cam_pose"][:, 2] *= scale

        if self.opt.load_coord:
            # Unproject depth map to 3D world coordinate
            coords = unproject_depth(return_dict["depth"].unsqueeze(0) * scale,
                return_dict["C2W"].unsqueeze(0), return_dict["fxfycxcy"].unsqueeze(0)).squeeze(0)
            coords = coords * 0.5 + 0.5  # [0, 1]
            coords = coords * return_dict["mask"] + (1. - return_dict["mask"])  # (V, 3, 512, 512), to white bg
            return_dict["coord"] = coords
            if not self.opt.load_depth:
                return_dict.pop("depth")

        if self.opt.load_depth:
            depths = return_dict["depth"].unsqueeze(1) * return_dict["mask"]  # (V, 1, 512, 512), to black bg
            assert depths.min() == 0.
            if self.opt.normalize_depth:
                H, W = depths.shape[-2:]
                depths = depths.reshape(V, -1)
                # Clamp the per-view maximum so that an all-zero view yields 0 instead of NaN (0 / 0)
                depths_max = torch.clamp(depths.max(dim=-1, keepdim=True).values, min=1e-8)
                depths = depths / depths_max  # [0, 1]
                depths = depths.reshape(V, 1, H, W)
            depths = -depths + 1.  # 0->1, 1->0, i.e., white bg
            return_dict["depth"] = depths.repeat(1, 3, 1, 1)

        # Resize to the input resolution
        for key in ["image", "mask", "albedo", "normal", "coord", "depth", "mr", "canny"]:
            if key in return_dict.keys():
                return_dict[key] = tF.interpolate(
                    return_dict[key], size=(self.opt.input_res, self.opt.input_res),
                    mode="bilinear", align_corners=False, antialias=True
                )

        # Handle anti-aliased normal, coord and depth (GObjaverse renders anti-aliased normal & depth)
        if self.opt.load_normal:
            return_dict["normal"] = return_dict["normal"] * return_dict["mask"] + (1. - return_dict["mask"])
        if self.opt.load_coord:
            return_dict["coord"] = return_dict["coord"] * return_dict["mask"] + (1. - return_dict["mask"])
        if self.opt.load_depth:
            return_dict["depth"] = return_dict["depth"] * return_dict["mask"] + (1. - return_dict["mask"])

        # Load precomputed caption embeddings
        if self.opt.prompt_embed_dir is not None:
            uid = sample["uid"].decode("utf-8").split("/")[-1].split(".")[0]
            return_dict["prompt_embed"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}.npy"))
            if "xl" in self.opt.pretrained_model_name_or_path or "3" in self.opt.pretrained_model_name_or_path:  # SDXL or SD3
                return_dict["pooled_prompt_embed"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}_pooled.npy"))
            if "PixArt" in self.opt.pretrained_model_name_or_path:  # PixArt-alpha, PixArt-Sigma
                return_dict["prompt_attention_mask"] = torch.from_numpy(np.load(f"{self.opt.prompt_embed_dir}/{uid}_attention_mask.npy"))

        for key in return_dict.keys():
            assert isinstance(return_dict[key], Tensor), f"Value of the key [{key}] is not a Tensor, but {type(return_dict[key])}."

        return dict(return_dict)

    def _load_png(self, png_bytes: Union[bytes, str], uint16: bool = False) -> Tensor:
        """Decode an in-memory PNG into a float32 `(C, H, W)` tensor scaled to [0, 1].

        Args:
            png_bytes: Encoded PNG content, read through `np.frombuffer`.
            uint16: If True, divide by 65535 (16-bit PNG); otherwise divide by 255.

        Returns:
            The decoded image with its first three channels reordered from BGR to RGB
            and NaN values replaced by 0.

        Raises:
            ValueError: If OpenCV cannot decode the buffer.
        """
        buffer = np.frombuffer(png_bytes, np.uint8)
        png = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)  # (H, W, C) ndarray in [0, 255] or [0, 65535]
        if png is None:  # `cv2.imdecode` signals a failed decode by returning None
            raise ValueError("Failed to decode PNG bytes: `cv2.imdecode` returned None")
        if png.ndim == 2:  # single-channel PNG: add the channel axis the code below indexes
            png = png[:, :, None]

        png = png.astype(np.float32) / (65535. if uint16 else 255.)  # (H, W, C) in [0, 1]
        png[:, :, :3] = png[:, :, :3][..., ::-1]  # BGR -> RGB
        png_tensor = torch.from_numpy(png).nan_to_num_(0.)  # there are nan in GObjaverse gt normal
        return png_tensor.permute(2, 0, 1)  # (C, H, W) in [0, 1]

    def _load_camera_from_json(self, json_bytes: Union[bytes, str]) -> ndarray:
        if isinstance(json_bytes, bytes):
            json_dict = json.loads(json_bytes)
        else:  # BACKUP
            path = os.path.join(self.opt.backup_file_dir, f"{json_bytes}.json")
            with open(path, "r") as f:
                json_dict = json.load(f)

        # In OpenCV convention
        c2w = np.eye(4)  # float64
        c2w[:3, 0] = np.array(json_dict["x"])
        c2w[:3, 1] = np.array(json_dict["y"])
        c2w[:3, 2] = np.array(json_dict["z"])
        c2w[:3, 3] = np.array(json_dict["origin"])
        return c2w

    def _pick_even_view_indices(self, num_views: int = 4) -> List[int]:
        assert 12 % num_views == 0  # `12` for even-view sampling in GObjaverse

        if np.random.rand() < 2/3:
            index0 = np.random.choice(range(24))  # 0~23: 24 views in ele from [5, 30]; hard-coded for GObjaverse
            return [(index0 + (24 // num_views)*i) % 24 for i in range(num_views)]
        else:
            index0 = np.random.choice(range(12))  # 27~38: 12 views in ele from [-5, 5]; hard-coded for GObjaverse
            return [((index0 + (12 // num_views)*i) % 12 + 27) for i in range(num_views)]

    def _pick_random_view_indices(self, num_views: int = 4) -> List[int]:
        assert num_views <= 40  # `40` is hard-coded for GObjaverse

        indices = (set(range(40)) - set([25, 26])) if self.opt.exclude_topdown_views else (set(range(40)))  # `40` is hard-coded for GObjaverse
        return np.random.choice(list(indices), num_views, replace=False).tolist()

    def _check_views_exist(self, sample: Dict[str, Union[str, bytes]], vids: List[int]) -> bool:
        for vid in vids:
            if f"{vid:05d}.png" not in sample:
                return False
            try:
                assert sample[f"{vid:05d}.png"] is not None and sample[f"{vid:05d}.json"] is not None
            except:  # TypeError: a bytes-like object is required, not 'NoneType'; KeyError: '00001.json'
                return False
        return True

    def _check_views_exist_disk(self, uid: str, vids: List[int]) -> bool:
        for vid in vids:
            if not (os.path.exists(os.path.join(self.opt.backup_file_dir, f"{uid}.{vid:05d}.png"))
                and os.path.exists(os.path.join(self.opt.backup_file_dir, f"{uid}.{vid:05d}.json"))):
                return False
        return True