Upload 10 files
- trellis/utils/__init__.py +0 -0
- trellis/utils/data_utils.py +226 -0
- trellis/utils/dist_utils.py +93 -0
- trellis/utils/elastic_utils.py +228 -0
- trellis/utils/general_utils.py +202 -0
- trellis/utils/grad_clip_utils.py +81 -0
- trellis/utils/loss_utils.py +92 -0
- trellis/utils/postprocessing_utils.py +587 -0
- trellis/utils/random_utils.py +30 -0
- trellis/utils/render_utils.py +120 -0
trellis/utils/__init__.py
ADDED
File without changes

trellis/utils/data_utils.py
ADDED
@@ -0,0 +1,226 @@
from typing import *
import math
import torch
import numpy as np
from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler
import torch.distributed as dist


def recursive_to_device(
    data: Any,
    device: torch.device,
    non_blocking: bool = False,
) -> Any:
    """
    Recursively move all tensors in a data structure to a device.
    """
    if hasattr(data, "to"):
        return data.to(device, non_blocking=non_blocking)
    elif isinstance(data, (list, tuple)):
        return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
    elif isinstance(data, dict):
        return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
    else:
        return data


def load_balanced_group_indices(
    load: List[int],
    num_groups: int,
    equal_size: bool = False,
) -> List[List[int]]:
    """
    Split indices into groups with balanced load.
    """
    if equal_size:
        group_size = len(load) // num_groups
    indices = np.argsort(load)[::-1]
    groups = [[] for _ in range(num_groups)]
    group_load = np.zeros(num_groups)
    for idx in indices:
        min_group_idx = np.argmin(group_load)
        groups[min_group_idx].append(idx)
        if equal_size and len(groups[min_group_idx]) == group_size:
            group_load[min_group_idx] = float('inf')
        else:
            group_load[min_group_idx] += load[idx]
    return groups


def cycle(data_loader: DataLoader) -> Iterator:
    while True:
        for data in data_loader:
            if isinstance(data_loader.sampler, ResumableSampler):
                data_loader.sampler.idx += data_loader.batch_size  # type: ignore[attr-defined]
            yield data
        if isinstance(data_loader.sampler, DistributedSampler):
            data_loader.sampler.epoch += 1
        if isinstance(data_loader.sampler, ResumableSampler):
            data_loader.sampler.epoch += 1
            data_loader.sampler.idx = 0


class ResumableSampler(Sampler):
    """
    Distributed sampler that is resumable.

    Args:
        dataset: Dataset used for sampling.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
    """

    def __init__(
        self,
        dataset: Dataset,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        self.dataset = dataset
        self.epoch = 0
        self.idx = 0
        self.drop_last = drop_last
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.rank = dist.get_rank() if dist.is_initialized() else 0
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.world_size != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.world_size) / self.world_size  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.world_size)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.world_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.world_size]

        # resume from previous state
        indices = indices[self.idx:]

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def state_dict(self) -> dict[str, int]:
        return {
            'epoch': self.epoch,
            'idx': self.idx,
        }

    def load_state_dict(self, state_dict):
        self.epoch = state_dict['epoch']
        self.idx = state_dict['idx']


class BalancedResumableSampler(ResumableSampler):
    """
    Distributed sampler that is resumable and balances the load among the processes.

    Args:
        dataset: Dataset used for sampling.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
    """

    def __init__(
        self,
        dataset: Dataset,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 1,
    ) -> None:
        assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
        super().__init__(dataset, shuffle, seed, drop_last)
        self.batch_size = batch_size
        self.loads = dataset.loads

    def __iter__(self) -> Iterator:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # balance load among processes
        num_batches = len(indices) // (self.batch_size * self.world_size)
        balanced_indices = []
        for i in range(num_batches):
            start_idx = i * self.batch_size * self.world_size
            end_idx = (i + 1) * self.batch_size * self.world_size
            batch_indices = indices[start_idx:end_idx]
            batch_loads = [self.loads[idx] for idx in batch_indices]
            groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
            balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])

        # resume from previous state
        indices = balanced_indices[self.idx:]

        return iter(indices)
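
A quick usage sketch (not part of the upload): the sampler checkpoints and resumes mid-epoch via `state_dict`/`load_state_dict`, and `BalancedResumableSampler` expects the dataset to expose a per-sample `loads` list of cost estimates. `ToyDataset` below is a hypothetical stand-in, run single-process so the world size is 1:

# Hedged sketch: assumes the module above is importable; ToyDataset is hypothetical.
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self, n=16):
        self.loads = [i % 4 + 1 for i in range(n)]  # per-sample cost estimates
    def __len__(self):
        return len(self.loads)
    def __getitem__(self, idx):
        return torch.tensor(idx)

dataset = ToyDataset()
sampler = BalancedResumableSampler(dataset, shuffle=True, seed=0, batch_size=4)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
it = cycle(loader)          # infinite iterator; bumps sampler.idx per batch
batch = next(it)
state = sampler.state_dict()    # {'epoch': ..., 'idx': ...} for checkpointing
sampler.load_state_dict(state)  # restore to resume from the same position
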
trellis/utils/dist_utils.py
ADDED
@@ -0,0 +1,93 @@
import os
import io
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)


def read_file_dist(path):
    """
    Read the binary file distributedly.
    File is only read once by the rank 0 process and broadcasted to other processes.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # read file
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            with open(path, 'rb') as f:
                data = f.read()
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # broadcast size
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            data = torch.ByteTensor(size[0].item()).cuda()
        # broadcast data
        dist.broadcast(data, src=0)
        # convert to io.BytesIO
        data = data.cpu().numpy().tobytes()
        data = io.BytesIO(data)
        return data
    else:
        with open(path, 'rb') as f:
            data = f.read()
        data = io.BytesIO(data)
        return data


def unwrap_dist(model):
    """
    Unwrap the model from distributed training.
    """
    if isinstance(model, DDP):
        return model.module
    return model


@contextmanager
def master_first():
    """
    A context manager that ensures master process executes first.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield


@contextmanager
def local_master_first():
    """
    A context manager that ensures local master process executes first.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() % torch.cuda.device_count() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield
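
A hedged sketch of how these helpers might be wired together under a torchrun-style launcher; the environment-variable defaults and the 'checkpoint.pt' path are illustrative, not from the commit:

# Assumes one process per GPU with RANK/LOCAL_RANK/WORLD_SIZE exported by the launcher.
import os
import torch

def main():
    setup_dist(
        rank=int(os.environ.get('RANK', 0)),
        local_rank=int(os.environ.get('LOCAL_RANK', 0)),
        world_size=int(os.environ.get('WORLD_SIZE', 1)),
        master_addr=os.environ.get('MASTER_ADDR', 'localhost'),
        master_port=os.environ.get('MASTER_PORT', '29500'),
    )
    # Do per-node setup (e.g. unpacking data to local disk) once; other local ranks wait.
    with local_master_first():
        pass
    # Read a checkpoint once on rank 0 and broadcast the raw bytes to every rank.
    buf = read_file_dist('checkpoint.pt')  # hypothetical path
    state = torch.load(buf, map_location='cpu')
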
trellis/utils/elastic_utils.py
ADDED
@@ -0,0 +1,228 @@
from abc import abstractmethod
from contextlib import contextmanager
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np


class MemoryController:
    """
    Base class for memory management during training.
    """

    _last_input_size = None
    _last_mem_ratio = []

    @contextmanager
    def record(self):
        pass

    def update_run_states(self, input_size=None, mem_ratio=None):
        if self._last_input_size is None:
            self._last_input_size = input_size
        elif self._last_input_size != input_size:
            raise ValueError('Input size should not change for different ElasticModules.')
        self._last_mem_ratio.append(mem_ratio)

    @abstractmethod
    def get_mem_ratio(self, input_size):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def log(self):
        pass


class LinearMemoryController(MemoryController):
    """
    A simple controller for memory management during training.
    The memory usage is modeled as a linear function of:
    - the number of input parameters
    - the ratio of memory the model uses compared to the maximum usage (with no checkpointing)
        memory_usage = k * input_size * mem_ratio + b
    The controller keeps track of the memory usage and gives the
    expected memory ratio to keep the memory usage under a target
    fraction of the available memory.
    """
    def __init__(
        self,
        buffer_size=1000,
        update_every=500,
        target_ratio=0.8,
        available_memory=None,
        max_mem_ratio_start=0.1,
        params=None,
        device=None
    ):
        self.buffer_size = buffer_size
        self.update_every = update_every
        self.target_ratio = target_ratio
        self.device = device or torch.cuda.current_device()
        self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3

        self._memory = np.zeros(buffer_size, dtype=np.float32)
        self._input_size = np.zeros(buffer_size, dtype=np.float32)
        self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
        self._buffer_ptr = 0
        self._buffer_length = 0
        self._params = tuple(params) if params is not None else (0.0, 0.0)
        self._max_mem_ratio = max_mem_ratio_start
        self.step = 0

    def __repr__(self):
        return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'

    def _add_sample(self, memory, input_size, mem_ratio):
        self._memory[self._buffer_ptr] = memory
        self._input_size[self._buffer_ptr] = input_size
        self._mem_ratio[self._buffer_ptr] = mem_ratio
        self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
        self._buffer_length = min(self._buffer_length + 1, self.buffer_size)

    @contextmanager
    def record(self):
        torch.cuda.reset_peak_memory_stats(self.device)
        self._last_input_size = None
        self._last_mem_ratio = []
        yield
        self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
        self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
        self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
        self.step += 1
        if self.step % self.update_every == 0:
            self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
            self._fit_params()

    def _fit_params(self):
        memory_usage = self._memory[:self._buffer_length]
        input_size = self._input_size[:self._buffer_length]
        mem_ratio = self._mem_ratio[:self._buffer_length]

        x = input_size * mem_ratio
        y = memory_usage
        k, b = np.polyfit(x, y, 1)
        self._params = (k, b)
        # self._visualize()

    def _visualize(self):
        import matplotlib.pyplot as plt
        memory_usage = self._memory[:self._buffer_length]
        input_size = self._input_size[:self._buffer_length]
        mem_ratio = self._mem_ratio[:self._buffer_length]
        k, b = self._params

        plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
        x = np.array([0.0, 20000.0])
        plt.plot(x, k * x + b, c='r')
        plt.savefig(f'linear_memory_controller_{self.step}.png')
        plt.cla()

    def get_mem_ratio(self, input_size):
        k, b = self._params
        if k == 0: return np.random.rand() * self._max_mem_ratio
        pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
        return min(self._max_mem_ratio, max(0.0, pred))

    def state_dict(self):
        return {
            'params': self._params,
        }

    def load_state_dict(self, state_dict):
        self._params = tuple(state_dict['params'])

    def log(self):
        return {
            'params/k': self._params[0],
            'params/b': self._params[1],
            'memory': self._last_memory,
            'input_size': self._last_input_size,
            'mem_ratio': self._last_mem_ratio,
        }


class ElasticModule(nn.Module):
    """
    Module for training with elastic memory management.
    """
    def __init__(self):
        super().__init__()
        self._memory_controller: MemoryController = None

    @abstractmethod
    def _get_input_size(self, *args, **kwargs) -> int:
        """
        Get the size of the input data.

        Returns:
            int: The size of the input data.
        """
        pass

    @abstractmethod
    def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
        """
        Forward with a given memory ratio.
        """
        pass

    def register_memory_controller(self, memory_controller: MemoryController):
        self._memory_controller = memory_controller

    def forward(self, *args, **kwargs):
        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
            _, ret = self._forward_with_mem_ratio(*args, **kwargs)
        else:
            input_size = self._get_input_size(*args, **kwargs)
            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
            mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
            self._memory_controller.update_run_states(input_size, mem_ratio)
        return ret


class ElasticModuleMixin:
    """
    Mixin for training with elastic memory management.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._memory_controller: MemoryController = None

    @abstractmethod
    def _get_input_size(self, *args, **kwargs) -> int:
        """
        Get the size of the input data.

        Returns:
            int: The size of the input data.
        """
        pass

    @abstractmethod
    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0) -> float:
        """
        Context manager for training with a reduced memory ratio compared to the full memory usage.

        Returns:
            float: The exact memory ratio used during the forward pass.
        """
        pass

    def register_memory_controller(self, memory_controller: MemoryController):
        self._memory_controller = memory_controller

    def forward(self, *args, **kwargs):
        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
            ret = super().forward(*args, **kwargs)
        else:
            input_size = self._get_input_size(*args, **kwargs)
            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
            with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
                ret = super().forward(*args, **kwargs)
            self._memory_controller.update_run_states(input_size, exact_mem_ratio)
        return ret
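
For illustration only: one plausible way to satisfy the `ElasticModuleMixin` contract is to gradient-checkpoint a fraction of a module's blocks, so `mem_ratio` maps to the share of activations kept. `MLPStack` and `ElasticMLPStack` below are hypothetical; TRELLIS's real models implement `with_mem_ratio` on their own architectures:

# Minimal sketch, assuming a CUDA device; not the project's actual model code.
import torch
import torch.nn as nn
from contextlib import contextmanager
from torch.utils.checkpoint import checkpoint

class MLPStack(nn.Module):  # hypothetical base module doing the actual compute
    def __init__(self, dim=64, depth=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)
        )
        self._ckpt_until = 0  # number of leading blocks to checkpoint

    def forward(self, x):
        for i, blk in enumerate(self.blocks):
            if self.training and i < self._ckpt_until:
                x = checkpoint(blk, x, use_reentrant=False)  # recompute in backward
            else:
                x = blk(x)
        return x

class ElasticMLPStack(ElasticModuleMixin, MLPStack):
    def _get_input_size(self, x):
        return x.shape[0] * x.shape[1]  # e.g. total elements per batch

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        n = len(self.blocks)
        self._ckpt_until = round((1 - mem_ratio) * n)
        try:
            yield 1 - self._ckpt_until / n  # quantized ratio actually achieved
        finally:
            self._ckpt_until = 0

model = ElasticMLPStack().cuda().train()
ctrl = LinearMemoryController(buffer_size=100, update_every=10)
model.register_memory_controller(ctrl)
with ctrl.record():  # records peak memory and the ratio used this step
    loss = model(torch.randn(32, 64, device='cuda')).square().mean()
    loss.backward()
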
trellis/utils/general_utils.py
ADDED
@@ -0,0 +1,202 @@
import re
import numpy as np
import cv2
import torch
import contextlib


# Dictionary utils
def _dict_merge(dicta, dictb, prefix=''):
    """
    Merge two dictionaries.
    """
    assert isinstance(dicta, dict), 'input must be a dictionary'
    assert isinstance(dictb, dict), 'input must be a dictionary'
    dict_ = {}
    all_keys = set(dicta.keys()).union(set(dictb.keys()))
    for key in all_keys:
        if key in dicta.keys() and key in dictb.keys():
            if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
                dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
            else:
                raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
        elif key in dicta.keys():
            dict_[key] = dicta[key]
        else:
            dict_[key] = dictb[key]
    return dict_


def dict_merge(dicta, dictb):
    """
    Merge two dictionaries.
    """
    return _dict_merge(dicta, dictb, prefix='')


def dict_foreach(dic, func, special_func={}):
    """
    Recursively apply a function to all non-dictionary leaf values in a dictionary.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            dic[key] = dict_foreach(dic[key], func)
        else:
            if key in special_func.keys():
                dic[key] = special_func[key](dic[key])
            else:
                dic[key] = func(dic[key])
    return dic


def dict_reduce(dicts, func, special_func={}):
    """
    Reduce a list of dictionaries. Leaf values must be scalars.
    """
    assert isinstance(dicts, list), 'input must be a list of dictionaries'
    assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
    assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
    all_keys = set([key for dict_ in dicts for key in dict_.keys()])
    reduced_dict = {}
    for key in all_keys:
        vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
        if isinstance(vlist[0], dict):
            reduced_dict[key] = dict_reduce(vlist, func, special_func)
        else:
            if key in special_func.keys():
                reduced_dict[key] = special_func[key](vlist)
            else:
                reduced_dict[key] = func(vlist)
    return reduced_dict


def dict_any(dic, func):
    """
    Recursively check whether any non-dictionary leaf value in a dictionary satisfies a predicate.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            if dict_any(dic[key], func):
                return True
        else:
            if func(dic[key]):
                return True
    return False


def dict_all(dic, func):
    """
    Recursively check whether all non-dictionary leaf values in a dictionary satisfy a predicate.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            if not dict_all(dic[key], func):
                return False
        else:
            if not func(dic[key]):
                return False
    return True


def dict_flatten(dic, sep='.'):
    """
    Flatten a nested dictionary into a dictionary with no nested dictionaries.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    flat_dict = {}
    for key in dic.keys():
        if isinstance(dic[key], dict):
            sub_dict = dict_flatten(dic[key], sep=sep)
            for sub_key in sub_dict.keys():
                flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
        else:
            flat_dict[key] = dic[key]
    return flat_dict


# Context utils
@contextlib.contextmanager
def nested_contexts(*contexts):
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx())
        yield


# Image utils
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
    num_images = len(images)
    if nrow is None and ncol is None:
        if aspect_ratio is not None:
            nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
        else:
            nrow = int(np.sqrt(num_images))
        ncol = (num_images + nrow - 1) // nrow
    elif nrow is None and ncol is not None:
        nrow = (num_images + ncol - 1) // ncol
    elif nrow is not None and ncol is None:
        ncol = (num_images + nrow - 1) // nrow
    else:
        assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'

    if images[0].ndim == 2:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
    else:
        grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
    for i, img in enumerate(images):
        row = i // ncol
        col = i % ncol
        grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
    return grid


def notes_on_image(img, notes=None):
    img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    if notes is not None:
        img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def save_image_with_notes(img, path, notes=None):
    """
    Save an image with notes.
    """
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy().transpose(1, 2, 0)
    if img.dtype == np.float32 or img.dtype == np.float64:
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = notes_on_image(img, notes)
    cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


# debug utils

def atol(x, y):
    """
    Absolute tolerance.
    """
    return torch.abs(x - y)


def rtol(x, y):
    """
    Relative tolerance.
    """
    return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)


# print utils
def indent(s, n=4):
    """
    Indent a string.
    """
    lines = s.split('\n')
    for i in range(1, len(lines)):
        lines[i] = ' ' * n + lines[i]
    return '\n'.join(lines)
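
A quick illustration of the dictionary helpers above (toy values; assumes the functions are imported):

a = {'optimizer': {'lr': 1e-4}, 'steps': 1000}
b = {'optimizer': {'betas': (0.9, 0.999)}}
merged = dict_merge(a, b)
# {'optimizer': {'lr': 0.0001, 'betas': (0.9, 0.999)}, 'steps': 1000}
flat = dict_flatten(merged)
# {'optimizer.lr': 0.0001, 'optimizer.betas': (0.9, 0.999), 'steps': 1000}
has_float = dict_any(merged, lambda v: isinstance(v, float))   # True ('lr' is a float)
all_floats = dict_all(merged, lambda v: isinstance(v, float))  # False (e.g. 'steps' is an int)
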
trellis/utils/grad_clip_utils.py
ADDED
@@ -0,0 +1,81 @@
from typing import *
import torch
import numpy as np
import torch.utils


class AdaptiveGradClipper:
    """
    Adaptive gradient clipping for training.
    """
    def __init__(
        self,
        max_norm=None,
        clip_percentile=95.0,
        buffer_size=1000,
    ):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.buffer_size = buffer_size

        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
        self._max_norm = max_norm
        self._buffer_ptr = 0
        self._buffer_length = 0

    def __repr__(self):
        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'

    def state_dict(self):
        return {
            'grad_norm': self._grad_norm,
            'max_norm': self._max_norm,
            'buffer_ptr': self._buffer_ptr,
            'buffer_length': self._buffer_length,
        }

    def load_state_dict(self, state_dict):
        self._grad_norm = state_dict['grad_norm']
        self._max_norm = state_dict['max_norm']
        self._buffer_ptr = state_dict['buffer_ptr']
        self._buffer_length = state_dict['buffer_length']

    def log(self):
        return {
            'max_norm': self._max_norm,
        }

    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
        """Clip the gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Args:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
            error_if_nonfinite (bool): if True, an error is thrown if the total
                norm of the gradients from :attr:`parameters` is ``nan``,
                ``inf``, or ``-inf``. Default: False (will switch to True in the future)
            foreach (bool): use the faster foreach-based implementation.
                If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
                fall back to the slow implementation for other device types.
                Default: ``None``

        Returns:
            Total norm of the parameter gradients (viewed as a single vector).
        """
        max_norm = self._max_norm if self._max_norm is not None else float('inf')
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)

        if torch.isfinite(grad_norm):
            self._grad_norm[self._buffer_ptr] = grad_norm
            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
            if self._buffer_length == self.buffer_size:
                self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
                self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm

        return grad_norm
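
A minimal sketch of the clipper in a training loop (toy model, not from the commit): the threshold stays at `max_norm` until the buffer fills, after which it adapts to the configured percentile of recent gradient norms.

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
clipper = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0, buffer_size=100)

for step in range(200):
    x = torch.randn(32, 8)
    loss = model(x).square().mean()
    optimizer.zero_grad()
    loss.backward()
    grad_norm = clipper(model.parameters())  # clips in-place and records the norm
    optimizer.step()
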
trellis/utils/loss_utils.py
ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
from lpips import LPIPS


def smooth_l1_loss(pred, target, beta=1.0):
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean()


def l1_loss(network_output, gt):
    return torch.abs((network_output - gt)).mean()


def l2_loss(network_output, gt):
    return ((network_output - gt) ** 2).mean()


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def psnr(img1, img2, max_val=1.0):
    mse = F.mse_loss(img1, img2)
    return 20 * torch.log10(max_val / torch.sqrt(mse))


def ssim(img1, img2, window_size=11, size_average=True):
    channel = img1.size(-3)
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


loss_fn_vgg = None
def lpips(img1, img2, value_range=(0, 1)):
    global loss_fn_vgg
    if loss_fn_vgg is None:
        loss_fn_vgg = LPIPS(net='vgg').cuda().eval()
    # normalize to [-1, 1]
    img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
    img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
    return loss_fn_vgg(img1, img2).mean()


def normal_angle(pred, gt):
    pred = pred * 2.0 - 1.0
    gt = gt * 2.0 - 1.0
    norms = pred.norm(dim=-1) * gt.norm(dim=-1)
    cos_sim = (pred * gt).sum(-1) / (norms + 1e-9)
    cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
    ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean()
    if ang.isnan():
        return -1
    return ang
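
Quick sanity check of the image metrics on random CPU tensors; note that `lpips` as written requires the `lpips` package and a CUDA device, so it is omitted here:

import torch

a = torch.rand(1, 3, 64, 64)  # N, C, H, W in [0, 1]
b = (a + 0.05 * torch.randn_like(a)).clamp(0, 1)

print(psnr(a, b))                          # high PSNR for a mildly perturbed copy
print(ssim(a, b).item())                   # close to 1.0 for near-identical images
print(ssim(a, torch.rand_like(a)).item())  # much lower for unrelated images
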
trellis/utils/postprocessing_utils.py
ADDED
@@ -0,0 +1,587 @@
| 1 |
+
from typing import *
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import utils3d
|
| 5 |
+
import nvdiffrast.torch as dr
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import trimesh
|
| 8 |
+
import trimesh.visual
|
| 9 |
+
import xatlas
|
| 10 |
+
import pyvista as pv
|
| 11 |
+
from pymeshfix import _meshfix
|
| 12 |
+
import igraph
|
| 13 |
+
import cv2
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from .random_utils import sphere_hammersley_sequence
|
| 16 |
+
from .render_utils import render_multiview
|
| 17 |
+
from ..renderers import GaussianRenderer
|
| 18 |
+
from ..representations import Strivec, Gaussian, MeshExtractResult
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@torch.no_grad()
|
| 22 |
+
def _fill_holes(
|
| 23 |
+
verts,
|
| 24 |
+
faces,
|
| 25 |
+
max_hole_size=0.04,
|
| 26 |
+
max_hole_nbe=32,
|
| 27 |
+
resolution=128,
|
| 28 |
+
num_views=500,
|
| 29 |
+
debug=False,
|
| 30 |
+
verbose=False
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Rasterize a mesh from multiple views and remove invisible faces.
|
| 34 |
+
Also includes postprocessing to:
|
| 35 |
+
1. Remove connected components that are have low visibility.
|
| 36 |
+
2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
|
| 40 |
+
faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
|
| 41 |
+
max_hole_size (float): Maximum area of a hole to fill.
|
| 42 |
+
resolution (int): Resolution of the rasterization.
|
| 43 |
+
num_views (int): Number of views to rasterize the mesh.
|
| 44 |
+
verbose (bool): Whether to print progress.
|
| 45 |
+
"""
|
| 46 |
+
# Construct cameras
|
| 47 |
+
yaws = []
|
| 48 |
+
pitchs = []
|
| 49 |
+
for i in range(num_views):
|
| 50 |
+
y, p = sphere_hammersley_sequence(i, num_views)
|
| 51 |
+
yaws.append(y)
|
| 52 |
+
pitchs.append(p)
|
| 53 |
+
yaws = torch.tensor(yaws).cuda()
|
| 54 |
+
pitchs = torch.tensor(pitchs).cuda()
|
| 55 |
+
radius = 2.0
|
| 56 |
+
fov = torch.deg2rad(torch.tensor(40)).cuda()
|
| 57 |
+
projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
|
| 58 |
+
views = []
|
| 59 |
+
for (yaw, pitch) in zip(yaws, pitchs):
|
| 60 |
+
orig = torch.tensor([
|
| 61 |
+
torch.sin(yaw) * torch.cos(pitch),
|
| 62 |
+
torch.cos(yaw) * torch.cos(pitch),
|
| 63 |
+
torch.sin(pitch),
|
| 64 |
+
]).cuda().float() * radius
|
| 65 |
+
view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
| 66 |
+
views.append(view)
|
| 67 |
+
views = torch.stack(views, dim=0)
|
| 68 |
+
|
| 69 |
+
# Rasterize
|
| 70 |
+
visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
|
| 71 |
+
rastctx = utils3d.torch.RastContext(backend='cuda')
|
| 72 |
+
for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
|
| 73 |
+
view = views[i]
|
| 74 |
+
buffers = utils3d.torch.rasterize_triangle_faces(
|
| 75 |
+
rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
|
| 76 |
+
)
|
| 77 |
+
face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
|
| 78 |
+
face_id = torch.unique(face_id).long()
|
| 79 |
+
visblity[face_id] += 1
|
| 80 |
+
visblity = visblity.float() / num_views
|
| 81 |
+
|
| 82 |
+
# Mincut
|
| 83 |
+
## construct outer faces
|
| 84 |
+
edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
|
| 85 |
+
boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
|
| 86 |
+
connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
|
| 87 |
+
outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
|
| 88 |
+
for i in range(len(connected_components)):
|
| 89 |
+
outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
|
| 90 |
+
outer_face_indices = outer_face_indices.nonzero().reshape(-1)
|
| 91 |
+
|
| 92 |
+
## construct inner faces
|
| 93 |
+
inner_face_indices = torch.nonzero(visblity == 0).reshape(-1)
|
| 94 |
+
if verbose:
|
| 95 |
+
tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
|
| 96 |
+
if inner_face_indices.shape[0] == 0:
|
| 97 |
+
return verts, faces
|
| 98 |
+
|
| 99 |
+
## Construct dual graph (faces as nodes, edges as edges)
|
| 100 |
+
dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
|
| 101 |
+
dual_edge2edge = edges[dual_edge2edge]
|
| 102 |
+
dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
|
| 103 |
+
if verbose:
|
| 104 |
+
tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')
|
| 105 |
+
|
| 106 |
+
## solve mincut problem
|
| 107 |
+
### construct main graph
|
| 108 |
+
g = igraph.Graph()
|
| 109 |
+
g.add_vertices(faces.shape[0])
|
| 110 |
+
g.add_edges(dual_edges.cpu().numpy())
|
| 111 |
+
g.es['weight'] = dual_edges_weights.cpu().numpy()
|
| 112 |
+
|
| 113 |
+
### source and target
|
| 114 |
+
g.add_vertex('s')
|
| 115 |
+
g.add_vertex('t')
|
| 116 |
+
|
| 117 |
+
### connect invisible faces to source
|
| 118 |
+
g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
|
| 119 |
+
|
| 120 |
+
### connect outer faces to target
|
| 121 |
+
g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
|
| 122 |
+
|
| 123 |
+
### solve mincut
|
| 124 |
+
cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
|
| 125 |
+
remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
|
| 126 |
+
if verbose:
|
| 127 |
+
tqdm.write(f'Mincut solved, start checking the cut')
|
| 128 |
+
|
| 129 |
+
### check if the cut is valid with each connected component
|
| 130 |
+
to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
|
| 131 |
+
if debug:
|
| 132 |
+
tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
|
| 133 |
+
valid_remove_cc = []
|
| 134 |
+
cutting_edges = []
|
| 135 |
+
for cc in to_remove_cc:
|
| 136 |
+
#### check if the connected component has low visibility
|
| 137 |
+
visblity_median = visblity[remove_face_indices[cc]].median()
|
| 138 |
+
if debug:
|
| 139 |
+
tqdm.write(f'visblity_median: {visblity_median}')
|
| 140 |
+
if visblity_median > 0.25:
|
| 141 |
+
continue
|
| 142 |
+
|
| 143 |
+
#### check if the cuting loop is small enough
|
| 144 |
+
cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
|
| 145 |
+
cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
|
| 146 |
+
cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
|
| 147 |
+
if len(cc_new_boundary_edge_indices) > 0:
|
| 148 |
+
cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
|
| 149 |
+
cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
|
| 150 |
+
cc_new_boundary_edges_cc_area = []
|
| 151 |
+
for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
|
| 152 |
+
_e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
|
| 153 |
+
_e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
|
| 154 |
+
cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
|
| 155 |
+
if debug:
|
| 156 |
+
cutting_edges.append(cc_new_boundary_edge_indices)
|
| 157 |
+
tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
|
| 158 |
+
if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]):
|
| 159 |
+
continue
|
| 160 |
+
|
| 161 |
+
valid_remove_cc.append(cc)
|
| 162 |
+
|
| 163 |
+
if debug:
|
| 164 |
+
face_v = verts[faces].mean(dim=1).cpu().numpy()
|
| 165 |
+
vis_dual_edges = dual_edges.cpu().numpy()
|
| 166 |
+
vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
|
| 167 |
+
vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
|
| 168 |
+
vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
|
| 169 |
+
vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
|
| 170 |
+
if len(valid_remove_cc) > 0:
|
| 171 |
+
vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
|
| 172 |
+
utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)
|
| 173 |
+
|
| 174 |
+
vis_verts = verts.cpu().numpy()
|
| 175 |
+
vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
|
| 176 |
+
utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
if len(valid_remove_cc) > 0:
|
| 180 |
+
remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
|
| 181 |
+
mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
|
| 182 |
+
mask[remove_face_indices] = 0
|
| 183 |
+
faces = faces[mask]
|
| 184 |
+
faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
|
| 185 |
+
if verbose:
|
| 186 |
+
tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
|
| 187 |
+
else:
|
| 188 |
+
if verbose:
|
| 189 |
+
tqdm.write(f'Removed 0 faces by mincut')
|
| 190 |
+
|
| 191 |
+
mesh = _meshfix.PyTMesh()
|
| 192 |
+
mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
|
| 193 |
+
mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
|
| 194 |
+
verts, faces = mesh.return_arrays()
|
| 195 |
+
verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)
|
| 196 |
+
|
| 197 |
+
return verts, faces
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def postprocess_mesh(
    vertices: np.ndarray,
    faces: np.ndarray,
    simplify: bool = True,
    simplify_ratio: float = 0.9,
    fill_holes: bool = True,
    fill_holes_max_hole_size: float = 0.04,
    fill_holes_max_hole_nbe: int = 32,
    fill_holes_resolution: int = 1024,
    fill_holes_num_views: int = 1000,
    debug: bool = False,
    verbose: bool = False,
):
    """
    Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
        simplify_ratio (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
        fill_holes_resolution (int): Resolution of the rasterization.
        fill_holes_num_views (int): Number of views to rasterize the mesh.
        debug (bool): Whether to write debug outputs (e.g. intermediate PLY files).
        verbose (bool): Whether to print progress.
    """
    if verbose:
        tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Simplify
    if simplify and simplify_ratio > 0:
        # pyvista expects faces flattened as [n, i0, i1, i2, ...], so prepend the
        # vertex count 3 to every triangle
        mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
        mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
        vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
        if verbose:
            tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Remove invisible faces and fill holes
    if fill_holes:
        vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
        vertices, faces = _fill_holes(
            vertices, faces,
            max_hole_size=fill_holes_max_hole_size,
            max_hole_nbe=fill_holes_max_hole_nbe,
            resolution=fill_holes_resolution,
            num_views=fill_holes_num_views,
            debug=debug,
            verbose=verbose,
        )
        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
        if verbose:
            tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    return vertices, faces

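# A minimal usage sketch, assuming `vertices`/`faces` are numpy arrays from an
# extracted mesh (the 0.95 ratio mirrors the default used by to_glb below):
#
#   vertices, faces = postprocess_mesh(
#       vertices, faces,
#       simplify=True, simplify_ratio=0.95,
#       fill_holes=True, verbose=True,
#   )
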
def parametrize_mesh(vertices: np.ndarray, faces: np.ndarray):
    """
    Parametrize a mesh to a texture space, using xatlas.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
    """
    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)

    # xatlas duplicates vertices that lie on UV seams; remap accordingly
    vertices = vertices[vmapping]
    faces = indices

    return vertices, faces, uvs

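# A short sketch of the resulting layout, assuming numpy inputs: every output
# vertex carries exactly one UV, so `uvs` aligns with `vertices` row-for-row.
#
#   vertices, faces, uvs = parametrize_mesh(vertices, faces)
#   assert uvs.shape[0] == vertices.shape[0] and uvs.shape[1] == 2
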
def bake_texture(
    vertices: np.ndarray,
    faces: np.ndarray,
    uvs: np.ndarray,
    observations: List[np.ndarray],
    masks: List[np.ndarray],
    extrinsics: List[np.ndarray],
    intrinsics: List[np.ndarray],
    texture_size: int = 2048,
    near: float = 0.1,
    far: float = 10.0,
    mode: Literal['fast', 'opt'] = 'opt',
    lambda_tv: float = 1e-2,
    verbose: bool = False,
):
    """
    Bake texture to a mesh from multiple observations.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        uvs (np.ndarray): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.ndarray]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
        masks (List[np.ndarray]): List of masks. Each mask is a 2D image. Shape (H, W).
        extrinsics (List[np.ndarray]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.ndarray]): List of intrinsics. Shape (3, 3).
        texture_size (int): Size of the texture.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        mode (Literal['fast', 'opt']): Mode of texture baking. 'fast' averages projected colors; 'opt' optimizes the texture against the observations.
        lambda_tv (float): Weight of total variation loss in optimization.
        verbose (bool): Whether to print progress.
    """
    vertices = torch.tensor(vertices).cuda()
    faces = torch.tensor(faces.astype(np.int32)).cuda()
    uvs = torch.tensor(uvs).cuda()
    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
    masks = [torch.tensor(m > 0).bool().cuda() for m in masks]
    views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
    projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]

    if mode == 'fast':
        texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
        texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
        rastctx = utils3d.torch.RastContext(backend='cuda')
        for observation, fg_mask, view, projection in tqdm(zip(observations, masks, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                uv_map = rast['uv'][0].detach().flip(0)
                # restrict to pixels covered by the mesh and this view's foreground mask
                mask = rast['mask'][0].detach().bool() & fg_mask

            # nearest neighbor interpolation
            uv_map = (uv_map * texture_size).floor().long()
            obs = observation[mask]
            uv_map = uv_map[mask]
            idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
            texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))

        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)

        # inpaint texels that received no observation
        mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

    elif mode == 'opt':
        rastctx = utils3d.torch.RastContext(backend='cuda')
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        _uv = []
        _uv_dr = []
        for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                _uv.append(rast['uv'].detach())
                _uv_dr.append(rast['uv_dr'].detach())

        texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
            return start_lr * (end_lr / start_lr) ** (step / total_steps)

        def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
            return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

        def tv_loss(texture):
            return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
                   torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])

        total_steps = 2500
        with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                selected = np.random.randint(0, len(views))
                uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()
                # annealing
                optimizer.param_groups[0]['lr'] = cosine_annealing(optimizer, step, total_steps, 1e-2, 1e-5)
                pbar.set_postfix({'loss': loss.item()})
                pbar.update()
        texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
        )['mask'][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
    else:
        raise ValueError(f'Unknown mode: {mode}')

    return texture

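# A minimal sketch of the fast path, assuming the inputs produced by
# parametrize_mesh and render_multiview above (uint8 HxWx3 observations with
# per-view masks and camera matrices):
#
#   texture = bake_texture(
#       vertices, faces, uvs,
#       observations, masks, extrinsics, intrinsics,
#       texture_size=1024, mode='fast', verbose=True,
#   )
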
def to_glb(
    app_rep: Union[Strivec, Gaussian],
    mesh: MeshExtractResult,
    simplify: float = 0.95,
    fill_holes: bool = True,
    fill_holes_max_size: float = 0.04,
    texture_size: int = 1024,
    debug: bool = False,
    verbose: bool = True,
) -> trimesh.Trimesh:
    """
    Convert a generated asset to a glb file.

    Args:
        app_rep (Union[Strivec, Gaussian]): Appearance representation.
        mesh (MeshExtractResult): Extracted mesh.
        simplify (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_size (float): Maximum area of a hole to fill.
        texture_size (int): Size of the texture.
        debug (bool): Whether to print debug information.
        verbose (bool): Whether to print progress.
    """
    vertices = mesh.vertices.cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    # mesh postprocess
    vertices, faces = postprocess_mesh(
        vertices, faces,
        simplify=simplify > 0,
        simplify_ratio=simplify,
        fill_holes=fill_holes,
        fill_holes_max_hole_size=fill_holes_max_size,
        fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)),
        fill_holes_resolution=1024,
        fill_holes_num_views=1000,
        debug=debug,
        verbose=verbose,
    )

    # parametrize mesh
    vertices, faces, uvs = parametrize_mesh(vertices, faces)

    # bake texture
    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
    masks = [np.any(observation > 0, axis=-1) for observation in observations]
    extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
    intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
    texture = bake_texture(
        vertices, faces, uvs,
        observations, masks, extrinsics, intrinsics,
        texture_size=texture_size, mode='opt',
        lambda_tv=0.01,
        verbose=verbose,
    )
    texture = Image.fromarray(texture)

    # rotate mesh (from z-up to y-up)
    vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    material = trimesh.visual.material.PBRMaterial(
        roughnessFactor=1.0,
        baseColorTexture=texture,
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
    )
    mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material))
    return mesh

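# A minimal sketch of the full export, assuming `gs` and `mesh_result` come from
# the TRELLIS pipeline outputs; trimesh infers the GLB format from the extension:
#
#   glb = to_glb(gs, mesh_result, simplify=0.95, texture_size=1024)
#   glb.export('sample.glb')
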
def simplify_gs(
    gs: Gaussian,
    simplify: float = 0.95,
    verbose: bool = True,
):
    """
    Simplify 3D Gaussians.
    NOTE: this function is not used in the current implementation due to its unsatisfactory performance.

    Args:
        gs (Gaussian): 3D Gaussian.
        simplify (float): Ratio of Gaussians to remove in simplification.
    """
    if simplify <= 0:
        return gs

    # simplify
    observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
    observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations]

    # Following https://arxiv.org/pdf/2411.06019
    renderer = GaussianRenderer({
        "resolution": 1024,
        "near": 0.8,
        "far": 1.6,
        "ssaa": 1,
        "bg_color": (0, 0, 0),
    })
    new_gs = Gaussian(**gs.init_params)
    new_gs._features_dc = gs._features_dc.clone()
    new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None
    new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
    new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
    new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
    new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())

    start_lr = [1e-4, 1e-3, 5e-3, 0.025]
    end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
    optimizer = torch.optim.Adam([
        {"params": new_gs._xyz, "lr": start_lr[0]},
        {"params": new_gs._rotation, "lr": start_lr[1]},
        {"params": new_gs._scaling, "lr": start_lr[2]},
        {"params": new_gs._opacity, "lr": start_lr[3]},
    ], lr=start_lr[0])

    def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
        return start_lr * (end_lr / start_lr) ** (step / total_steps)

    def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
        return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

    _zeta = new_gs.get_opacity.clone().detach().squeeze()
    _lambda = torch.zeros_like(_zeta)
    _delta = 1e-7
    _interval = 10
    num_target = int((1 - simplify) * _zeta.shape[0])

    with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar:
        for i in range(2500):
            # prune Gaussians whose opacity has dropped below threshold
            if i % 100 == 0:
                mask = new_gs.get_opacity.squeeze() > 0.05
                mask = torch.nonzero(mask).squeeze()
                new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
                new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
                new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
                new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
                new_gs._features_dc = new_gs._features_dc[mask]
                new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None
                _zeta = _zeta[mask]
                _lambda = _lambda[mask]
                # update optimizer state to track the pruned parameters
                for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]):
                    stored_state = optimizer.state[param_group['params'][0]]
                    if 'exp_avg' in stored_state:
                        stored_state['exp_avg'] = stored_state['exp_avg'][mask]
                        stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask]
                    del optimizer.state[param_group['params'][0]]
                    param_group['params'][0] = new_param
                    optimizer.state[param_group['params'][0]] = stored_state

            opacity = new_gs.get_opacity.squeeze()

            # sparsify: project opacities onto the num_target largest entries
            if i % _interval == 0:
                _zeta = _lambda + opacity.detach()
                if opacity.shape[0] > num_target:
                    index = _zeta.topk(num_target)[1]
                    _m = torch.ones_like(_zeta, dtype=torch.bool)
                    _m[index] = 0
                    _zeta[_m] = 0
                _lambda = _lambda + opacity.detach() - _zeta

            # sample a random view
            view_idx = np.random.randint(len(observations))
            observation = observations[view_idx]
            extrinsic = extrinsics[view_idx]
            intrinsic = intrinsics[view_idx]

            color = renderer.render(new_gs, extrinsic, intrinsic)['color']
            rgb_loss = torch.nn.functional.l1_loss(color, observation)
            loss = rgb_loss + \
                   _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update lr
            for j in range(len(optimizer.param_groups)):
                optimizer.param_groups[j]['lr'] = cosine_annealing(optimizer, i, 2500, start_lr[j], end_lr[j])

            pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
            pbar.update()

    new_gs._xyz = new_gs._xyz.data
    new_gs._rotation = new_gs._rotation.data
    new_gs._scaling = new_gs._scaling.data
    new_gs._opacity = new_gs._opacity.data

    return new_gs

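# A minimal sketch (note the caveat above: the function is kept for reference but
# unused in the current pipeline):
#
#   slim_gs = simplify_gs(gs, simplify=0.9, verbose=True)
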
trellis/utils/random_utils.py
ADDED
@@ -0,0 +1,30 @@

import numpy as np

PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]

def radical_inverse(base, n):
    # mirror the base-`base` digits of n about the radix point
    val = 0
    inv_base = 1.0 / base
    inv_base_n = inv_base
    while n > 0:
        digit = n % base
        val += digit * inv_base_n
        n //= base
        inv_base_n *= inv_base
    return val

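# Worked example: radical_inverse(2, 5) reflects 5 = 0b101 about the binary point,
# giving 0.101 in binary = 1/2 + 0/4 + 1/8 = 0.625.
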
def halton_sequence(dim, n):
    return [radical_inverse(PRIMES[i], n) for i in range(dim)]

def hammersley_sequence(dim, n, num_samples):
    # first coordinate is the evenly spaced n / num_samples; the rest follow Halton
    return [n / num_samples] + halton_sequence(dim - 1, n)

def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
    u, v = hammersley_sequence(2, n, num_samples)
    u += offset[0] / num_samples
    v += offset[1]
    if remap:
        u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    phi = v * 2 * np.pi
    return [phi, theta]
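
# A short sketch: drawing 4 well-spread viewpoints on the sphere; each entry is
# [phi, theta] with phi in [0, 2*pi) and theta in [-pi/2, pi/2].
#
#   cams = [sphere_hammersley_sequence(i, 4) for i in range(4)]
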
trellis/utils/render_utils.py
ADDED
@@ -0,0 +1,120 @@

import torch
import numpy as np
from tqdm import tqdm
import utils3d
from PIL import Image

from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
from ..representations import Octree, Gaussian, MeshExtractResult
from ..modules import sparse as sp
from .random_utils import sphere_hammersley_sequence


def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
    is_list = isinstance(yaws, list)
    if not is_list:
        yaws = [yaws]
        pitchs = [pitchs]
    if not isinstance(rs, list):
        rs = [rs] * len(yaws)
    if not isinstance(fovs, list):
        fovs = [fovs] * len(yaws)
    extrinsics = []
    intrinsics = []
    for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
        fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
        yaw = torch.tensor(float(yaw)).cuda()
        pitch = torch.tensor(float(pitch)).cuda()
        # camera position on a sphere of radius r, looking at the origin (z-up)
        orig = torch.tensor([
            torch.sin(yaw) * torch.cos(pitch),
            torch.cos(yaw) * torch.cos(pitch),
            torch.sin(pitch),
        ]).cuda() * r
        extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
        intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
        extrinsics.append(extr)
        intrinsics.append(intr)
    if not is_list:
        extrinsics = extrinsics[0]
        intrinsics = intrinsics[0]
    return extrinsics, intrinsics

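# A minimal sketch: one camera at radius 2 with a 40-degree FOV, 30 degrees of yaw,
# level pitch (scalar inputs return single matrices rather than lists):
#
#   extr, intr = yaw_pitch_r_fov_to_extrinsics_intrinsics(np.deg2rad(30), 0.0, 2.0, 40.0)
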
def get_renderer(sample, **kwargs):
    if isinstance(sample, Octree):
        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 0.8)
        renderer.rendering_options.far = kwargs.get('far', 1.6)
        renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
        renderer.pipe.primitive = sample.primitive
    elif isinstance(sample, Gaussian):
        renderer = GaussianRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 0.8)
        renderer.rendering_options.far = kwargs.get('far', 1.6)
        renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 1)
        renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
        renderer.pipe.use_mip_gaussian = True
    elif isinstance(sample, MeshExtractResult):
        renderer = MeshRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 1)
        renderer.rendering_options.far = kwargs.get('far', 100)
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
    else:
        raise ValueError(f'Unsupported sample type: {type(sample)}')
    return renderer

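# A short sketch: the renderer is picked by the sample's representation type, so
# the same call works for octrees, Gaussians, and extracted meshes.
#
#   renderer = get_renderer(sample, resolution=1024, bg_color=(0, 0, 0))
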
def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
    renderer = get_renderer(sample, **options)
    rets = {}
    for extr, intr in tqdm(zip(extrinsics, intrinsics), total=len(extrinsics), desc='Rendering', disable=not verbose):
        if isinstance(sample, MeshExtractResult):
            res = renderer.render(sample, extr, intr)
            if 'normal' not in rets: rets['normal'] = []
            rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
        else:
            res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
            if 'color' not in rets: rets['color'] = []
            if 'depth' not in rets: rets['depth'] = []
            rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
            if 'percent_depth' in res:
                rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
            elif 'depth' in res:
                rets['depth'].append(res['depth'].detach().cpu().numpy())
            else:
                rets['depth'].append(None)
    return rets

def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
    yaws = torch.linspace(0, 2 * np.pi, num_frames)
    pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * np.pi, num_frames))
    yaws = yaws.tolist()
    pitch = pitch.tolist()
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
    return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)

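# A minimal sketch of writing the orbit to disk; imageio is an assumption here,
# and any encoder that accepts a list of HxWx3 uint8 frames works:
#
#   import imageio
#   frames = render_video(sample, resolution=512, num_frames=120)['color']
#   imageio.mimsave('orbit.mp4', frames, fps=30)
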
def render_multiview(sample, resolution=512, nviews=30):
    r = 2
    fov = 40
    # Hammersley sampling gives well-spread viewpoints on the sphere
    cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
    yaws = [cam[0] for cam in cams]
    pitchs = [cam[1] for cam in cams]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
    res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
    return res['color'], extrinsics, intrinsics

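# A minimal sketch, matching how to_glb gathers observations for texture baking:
#
#   observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
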
def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
    yaw = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
    yaw_offset = offset[0]
    yaw = [y + yaw_offset for y in yaw]
    pitch = [offset[1] for _ in range(4)]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
    return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)