PartCrafter / src /utils /data_utils.py
alexnasa's picture
Upload 85 files
bef5729 verified
from src.utils.typing_utils import *
import os
import numpy as np
import trimesh
import torch
def normalize_mesh(
mesh: Union[trimesh.Trimesh, trimesh.Scene],
scale: float = 2.0,
):
# if not isinstance(mesh, trimesh.Trimesh) and not isinstance(mesh, trimesh.Scene):
# raise ValueError("Input mesh is not a trimesh.Trimesh or trimesh.Scene object.")
bbox = mesh.bounding_box
translation = -bbox.centroid
scale = scale / bbox.primitive.extents.max()
mesh.apply_translation(translation)
mesh.apply_scale(scale)
return mesh
def remove_overlapping_vertices(mesh: trimesh.Trimesh, reserve_material: bool = False):
if not isinstance(mesh, trimesh.Trimesh):
raise ValueError("Input mesh is not a trimesh.Trimesh object.")
vertices = mesh.vertices
faces = mesh.faces
unique_vertices, index_map, inverse_map = np.unique(
vertices, axis=0, return_index=True, return_inverse=True
)
clean_faces = inverse_map[faces]
clean_mesh = trimesh.Trimesh(vertices=unique_vertices, faces=clean_faces, process=True)
if reserve_material:
uv = mesh.visual.uv
material = mesh.visual.material
clean_uv = uv[index_map]
clean_visual = trimesh.visual.TextureVisuals(uv=clean_uv, material=material)
clean_mesh.visual = clean_visual
return clean_mesh
RGB = [
(82, 170, 220),
(215, 91, 78),
(45, 136, 117),
(247, 172, 83),
(124, 121, 121),
(127, 171, 209),
(243, 152, 101),
(145, 204, 192),
(150, 59, 121),
(181, 206, 78),
(189, 119, 149),
(199, 193, 222),
(200, 151, 54),
(236, 110, 102),
(238, 182, 212),
]
def get_colored_mesh_composition(
meshes: Union[List[trimesh.Trimesh], trimesh.Scene],
is_random: bool = True,
is_sorted: bool = False,
RGB: List[Tuple] = RGB
):
if isinstance(meshes, trimesh.Scene):
meshes = meshes.dump()
if is_sorted:
volumes = []
for mesh in meshes:
try:
volume = mesh.volume
except:
volume = 0.0
volumes.append(volume)
# sort by volume from large to small
meshes = [x for _, x in sorted(zip(volumes, meshes), key=lambda pair: pair[0], reverse=True)]
colored_scene = trimesh.Scene()
for idx, mesh in enumerate(meshes):
if is_random:
color = (np.random.rand(3) * 256).astype(int)
else:
color = np.array(RGB[idx % len(RGB)])
mesh.visual = trimesh.visual.ColorVisuals(
mesh=mesh,
vertex_colors=color,
)
colored_scene.add_geometry(mesh)
return colored_scene
def mesh_to_surface(
mesh: trimesh.Trimesh,
num_pc: int = 204800,
clip_to_num_vertices: bool = False,
return_dict: bool = False,
):
# if not isinstance(mesh, trimesh.Trimesh):
# raise ValueError("mesh must be a trimesh.Trimesh object")
if clip_to_num_vertices:
num_pc = min(num_pc, mesh.vertices.shape[0])
points, face_indices = mesh.sample(num_pc, return_index=True)
normals = mesh.face_normals[face_indices]
if return_dict:
return {
"surface_points": points,
"surface_normals": normals,
}
return points, normals
def scene_to_parts(
mesh: trimesh.Scene,
normalize: bool = True,
scale: float = 2.0,
num_part_pc: int = 204800,
clip_to_num_part_vertices: bool = False,
return_type: Literal["mesh", "point"] = "mesh",
) -> Union[List[trimesh.Geometry], List[Dict[str, np.ndarray]]]:
if not isinstance(mesh, trimesh.Scene):
raise ValueError("mesh must be a trimesh.Scene object")
if normalize:
mesh = normalize_mesh(mesh, scale=scale)
parts: List[trimesh.Geometry] = mesh.dump()
if return_type == "point":
datas: List[Dict[str, np.ndarray]] = []
for geom in parts:
data = mesh_to_surface(
geom,
num_pc=num_part_pc,
clip_to_num_vertices=clip_to_num_part_vertices,
return_dict=True,
)
datas.append(data)
return datas
elif return_type == "mesh":
return parts
else:
raise ValueError("return_type must be 'mesh' or 'point'")
def get_center(mesh: trimesh.Trimesh, method: Literal['mass', 'bbox']):
if method == 'mass':
return mesh.center_mass
elif method =='bbox':
return mesh.bounding_box.centroid
else:
raise ValueError('type must be mass or bbox')
def get_direction(vector: np.ndarray):
return vector / np.linalg.norm(vector)
def move_mesh_by_center(mesh: trimesh.Trimesh, scale: float, method: Literal['mass', 'bbox'] = 'mass'):
offset = scale - 1
center = get_center(mesh, method)
direction = get_direction(center)
translation = direction * offset
mesh = mesh.copy()
mesh.apply_translation(translation)
return mesh
def move_meshes_by_center(meshes: Union[List[trimesh.Trimesh], trimesh.Scene], scale: float):
if isinstance(meshes, trimesh.Scene):
meshes = meshes.dump()
moved_meshes = []
for mesh in meshes:
moved_mesh = move_mesh_by_center(mesh, scale)
moved_meshes.append(moved_mesh)
moved_meshes = trimesh.Scene(moved_meshes)
return moved_meshes
def get_series_splited_meshes(meshes: List[trimesh.Trimesh], scale: float, num_steps: int) -> List[trimesh.Scene]:
series_meshes = []
for i in range(num_steps):
temp_scale = 1 + (scale - 1) * i / (num_steps - 1)
temp_meshes = move_meshes_by_center(meshes, temp_scale)
series_meshes.append(temp_meshes)
return series_meshes
def load_surface(data, num_pc=204800):
surface = data["surface_points"] # Nx3
normal = data["surface_normals"] # Nx3
rng = np.random.default_rng()
ind = rng.choice(surface.shape[0], num_pc, replace=False)
surface = torch.FloatTensor(surface[ind])
normal = torch.FloatTensor(normal[ind])
surface = torch.cat([surface, normal], dim=-1)
return surface
def load_surfaces(surfaces, num_pc=204800):
surfaces = [load_surface(surface, num_pc) for surface in surfaces]
surfaces = torch.stack(surfaces, dim=0)
return surfaces