# PartCrafter — src/utils/render_utils.py
# Author: alex — "split gif preview added" (commit 166aa76)
import os
# PyOpenGL selects its platform backend at import time, so PYOPENGL_PLATFORM
# must be set *before* pyrender (which imports OpenGL) is first imported.
# The original code set it after `import pyrender`, which has no effect.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import math

import numpy as np
from PIL import Image
import trimesh
from trimesh.transformations import rotation_matrix
import pyrender
from diffusers.utils import export_to_video
from diffusers.utils.loading_utils import load_video
import torch
from torchvision.utils import make_grid

from src.utils.typing_utils import *
def explode_mesh(mesh, explosion_scale=0.4):
    """Return a new Scene whose parts are displaced radially away from the
    common center ("exploded view").

    Each part is shifted by ``explosion_scale`` along the unit vector from the
    global centroid to that part's world-space centroid. A single-part scene
    (or bare mesh) is returned unchanged.

    Args:
        mesh: a ``trimesh.Trimesh`` or ``trimesh.Scene``.
        explosion_scale: distance each part is pushed outward.

    Raises:
        ValueError: if ``mesh`` is neither a Trimesh nor a Scene.
    """
    if isinstance(mesh, trimesh.Trimesh):
        scene = trimesh.Scene(mesh)
    elif isinstance(mesh, trimesh.Scene):
        scene = mesh
    else:
        raise ValueError(f"Expected Trimesh or Scene, got {type(mesh)}")

    if len(scene.geometry) <= 1:
        print("Nothing to explode")
        return scene

    # World-space centroid of every part.
    # NOTE(review): scene.graph.get() is keyed here by the *geometry* name;
    # this assumes geometry names match graph node names — confirm for scenes
    # where the two differ.
    centers = {}
    for part_name, part in scene.geometry.items():
        world_tf, _ = scene.graph.get(part_name)
        world_pts = trimesh.transformations.transform_points(part.vertices, world_tf)
        centers[part_name] = world_pts.mean(axis=0)

    scene_center = np.mean(list(centers.values()), axis=0)

    result = trimesh.Scene()
    for part_name, part in scene.geometry.items():
        direction = centers[part_name] - scene_center
        length = np.linalg.norm(direction)
        if length < 1e-6:
            # Degenerate: part sits on the global center — pick a random direction.
            direction = np.random.randn(3)
            direction /= np.linalg.norm(direction)
        else:
            direction /= length
        offset = direction * explosion_scale
        # Re-fetch the world transform and bump only its translation column.
        world_tf, _ = scene.graph.get(part_name)
        shifted = world_tf.copy()
        shifted[:3, 3] += offset
        result.add_geometry(part, transform=shifted, geom_name=part_name)
        print(f"[explode] {part_name} moved by {np.linalg.norm(offset):.4f}")
    return result
def render(
    scene: pyrender.Scene,
    renderer: pyrender.Renderer,
    camera: pyrender.Camera,
    pose: np.ndarray,
    light: Optional[pyrender.Light] = None,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[Image.Image, Image.Image]]:
    """Render one view of ``scene`` from camera ``pose``.

    Temporarily attaches ``camera`` (and optionally ``light``) at ``pose``,
    renders, then detaches them so the scene is left unchanged.

    Args:
        scene: pyrender scene to render.
        renderer: an offscreen (or other) pyrender renderer.
        camera: camera intrinsics node content.
        pose: 4x4 camera-to-world matrix.
        light: optional light attached at the same pose.
        normalize_depth: min-max scale depth to [0, 255] (forced for 'pil').
        flags: pyrender render flags.
        return_type: 'pil' for PIL images, 'ndarray' for raw arrays.

    Returns:
        (image, depth) as PIL images or numpy arrays per ``return_type``.
    """
    camera_node = scene.add(camera, pose=pose)
    light_node = scene.add(light, pose=pose) if light is not None else None
    try:
        image, depth = renderer.render(
            scene,
            flags=flags
        )
    finally:
        # Always detach the temporary nodes, even if rendering raises,
        # so repeated calls don't accumulate cameras/lights in the scene.
        scene.remove_node(camera_node)
        if light_node is not None:
            scene.remove_node(light_node)
    if normalize_depth or return_type == 'pil':
        depth_range = depth.max() - depth.min()
        if depth_range > 0:
            depth = (depth - depth.min()) / depth_range * 255.0
        else:
            # Constant depth map (e.g. nothing visible): the original code
            # divided by zero here, producing NaNs.
            depth = np.zeros_like(depth)
    if return_type == 'pil':
        image = Image.fromarray(image)
        depth = Image.fromarray(depth.astype(np.uint8))
    return image, depth
def rotation_matrix_from_vectors(vec1, vec2):
    """Return the 3x3 rotation matrix that rotates unit(vec1) onto unit(vec2).

    Uses the Rodrigues formula. Both inputs are normalized internally.

    Args:
        vec1: source 3-vector (any nonzero length).
        vec2: target 3-vector (any nonzero length).

    Returns:
        A proper rotation matrix R (det(R) == +1) with R @ unit(vec1) == unit(vec2).
    """
    a, b = vec1 / np.linalg.norm(vec1), vec2 / np.linalg.norm(vec2)
    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)
    if s == 0:
        if c > 0:
            return np.eye(3)  # already aligned
        # Anti-parallel vectors: the original code returned -eye(3), which has
        # det == -1 (a reflection, not a rotation) and corrupts any pose built
        # from it. Instead rotate 180 degrees about any axis perpendicular to
        # `a`: R = 2 * u u^T - I for a unit axis u.
        helper = np.array([1.0, 0.0, 0.0])
        if abs(a[0]) > 0.9:
            # `a` is nearly along x; pick y so the cross product is well-conditioned.
            helper = np.array([0.0, 1.0, 0.0])
        axis = np.cross(a, helper)
        axis /= np.linalg.norm(axis)
        return 2.0 * np.outer(axis, axis) - np.eye(3)
    kmat = np.array([
        [0, -v[2], v[1]],
        [v[2], 0, -v[0]],
        [-v[1], v[0], 0]
    ])
    return np.eye(3) + kmat + kmat @ kmat * ((1 - c) / (s ** 2))
def create_circular_camera_positions(
    num_views: int,
    radius: float,
    axis: np.ndarray = np.array([0.0, 1.0, 0.0])
) -> List[np.ndarray]:
    """Evenly spaced positions on a circle of ``radius`` around ``axis``.

    The circle starts at +Z and lies in the plane perpendicular to ``axis``;
    for the default y-up axis the points stay in the y=0 plane.

    Args:
        num_views: number of positions to generate.
        radius: circle radius.
        axis: rotation axis (normalized internally).

    Returns:
        List of ``num_views`` 3-vectors.
    """
    unit_axis = axis / np.linalg.norm(axis)
    y_up = np.array([0.0, 1.0, 0.0])
    # Precompute the tilt that maps the default y-up circle onto `axis`.
    needs_tilt = not np.allclose(unit_axis, y_up)
    tilt = rotation_matrix_from_vectors(y_up, unit_axis) if needs_tilt else None
    positions = []
    for view_idx in range(num_views):
        angle = 2 * np.pi * view_idx / num_views
        point = np.array([
            np.sin(angle) * radius,
            0.0,
            np.cos(angle) * radius
        ])
        positions.append(tilt @ point if tilt is not None else point)
    return positions
def create_circular_camera_poses(
    num_views: int,
    radius: float,
    axis: np.ndarray = np.array([0.0, 1.0, 0.0])
) -> List[np.ndarray]:
    """Camera-to-world poses on a circular orbit of ``radius`` around ``axis``.

    The orbit starts from a camera on +Z at distance ``radius`` looking at the
    origin, then spins that pose around ``axis`` in ``num_views`` equal steps.

    Args:
        num_views: number of poses to generate.
        radius: orbit radius.
        axis: orbit axis passed to trimesh's ``rotation_matrix``.

    Returns:
        List of ``num_views`` 4x4 pose matrices.
    """
    # Base pose: identity orientation, translated to +Z by `radius`.
    base_pose = np.eye(4)
    base_pose[2, 3] = radius
    poses = []
    for view_idx in range(num_views):
        angle = 2 * np.pi * view_idx / num_views
        spin = rotation_matrix(
            angle=angle,
            direction=axis,
            point=[0, 0, 0]
        )
        poses.append(spin @ base_pose)
    return poses
def render_views_around_mesh(
    mesh: Union[trimesh.Trimesh, trimesh.Scene],
    num_views: int = 36,
    radius: float = 3.5,
    axis: np.ndarray = np.array([0.0, 1.0, 0.0]),
    image_size: tuple = (512, 512),
    fov: float = 40.0,
    light_intensity: Optional[float] = 5.0,
    znear: float = 0.1,
    zfar: float = 10.0,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_depth: bool = False,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[
    List[Image.Image],
    List[np.ndarray],
    Tuple[List[Image.Image], List[Image.Image]],
    Tuple[List[np.ndarray], List[np.ndarray]]
]:
    """Render a turntable of ``num_views`` images around ``mesh``.

    A bare Trimesh is rendered as-is; a multi-part Scene is re-exploded per
    frame with a sinusoidally varying scale, producing an animated
    explode/implode preview ("split gif preview").

    Returns a list of images, or (images, depths) when ``return_depth``.
    """
    meshes = []  # NOTE(review): never used — candidate for removal.
    scenes = []
    if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
        raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
    if isinstance(mesh, trimesh.Trimesh):
        # Static geometry: one (identical) pyrender scene per view.
        for i in range(num_views):
            scenes.append(pyrender.Scene.from_trimesh_scene(trimesh.Scene(mesh)))
    else:
        # Animated explode: scale follows sin(pi * t) over the turntable.
        # NOTE(review): the (i - 1) phase makes frame 0 use a slightly
        # *negative* scale (parts pulled inward) — confirm this is intended.
        # NOTE(review): ambient_light/bg_color are only set on this branch,
        # not on the Trimesh branch above — confirm the asymmetry is intended.
        for i in range(num_views):
            value = math.sin(math.pi * (i - 1) / num_views)
            scenes.append(pyrender.Scene.from_trimesh_scene(explode_mesh(mesh, 0.2 * value),
                ambient_light=[0.02, 0.02, 0.02],
                bg_color=[0.0, 0.0, 0.0, 1.0]))
    light = pyrender.DirectionalLight(
        color=np.ones(3),
        intensity=light_intensity
    ) if light_intensity is not None else None
    camera = pyrender.PerspectiveCamera(
        yfov=np.deg2rad(fov),
        aspectRatio=image_size[0]/image_size[1],
        znear=znear,
        zfar=zfar
    )
    renderer = pyrender.OffscreenRenderer(*image_size)
    camera_poses = create_circular_camera_poses(
        num_views,
        radius,
        axis = axis
    )
    images, depths = [], []
    # One render per pose; scenes[i] pairs each pose with its exploded frame.
    for i, pose in enumerate(camera_poses):
        image, depth = render(
            scenes[i], renderer, camera, pose, light,
            normalize_depth=normalize_depth,
            flags=flags,
            return_type=return_type
        )
        images.append(image)
        depths.append(depth)
    renderer.delete()
    if return_depth:
        return images, depths
    return images
def render_normal_views_around_mesh(
    mesh: Union[trimesh.Trimesh, trimesh.Scene],
    num_views: int = 36,
    radius: float = 3.5,
    axis: np.ndarray = np.array([0.0, 1.0, 0.0]),
    image_size: tuple = (512, 512),
    fov: float = 40.0,
    light_intensity: Optional[float] = 5.0,
    znear: float = 0.1,
    zfar: float = 10.0,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_depth: bool = False,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[
    List[Image.Image],
    List[np.ndarray],
    Tuple[List[Image.Image], List[Image.Image]],
    Tuple[List[np.ndarray], List[np.ndarray]]
]:
    """Turntable render with vertex normals baked into RGB vertex colors.

    A Scene input is flattened to a single mesh first; normals in [-1, 1]
    are mapped to colors in [0, 255]. Delegates to render_views_around_mesh.
    """
    if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
        raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.to_geometry()
    # Map per-vertex normals from [-1, 1] into [0, 255] RGB.
    vertex_colors = ((mesh.vertex_normals + 1.0) / 2.0 * 255).astype(np.uint8)
    mesh.visual = trimesh.visual.ColorVisuals(
        mesh=mesh,
        vertex_colors=vertex_colors
    )
    return render_views_around_mesh(
        trimesh.Scene(mesh), num_views, radius, axis,
        image_size, fov, light_intensity, znear, zfar,
        normalize_depth, flags,
        return_depth, return_type
    )
def create_camera_pose_on_sphere(
    azimuth: float = 0.0, # in degrees
    elevation: float = 0.0, # in degrees
    radius: float = 3.5,
) -> np.ndarray:
    """Camera-to-world pose at (azimuth, elevation) on a sphere of ``radius``.

    The camera looks at the origin; azimuth is measured around the y axis
    from +Z, elevation above the y=0 plane. The up direction follows the
    minimal rotation that maps +Z onto the view direction (so it stays
    roughly y-up for moderate elevations).

    Returns:
        4x4 camera-to-world pose matrix.
    """
    az = np.deg2rad(azimuth)
    el = np.deg2rad(elevation)
    # Unit direction from origin to the camera position on the sphere.
    view_dir = np.array([
        np.cos(el) * np.sin(az),
        np.sin(el),
        np.cos(el) * np.cos(az),
    ])
    # Base pose: camera on +Z at `radius`, identity orientation.
    base_pose = np.eye(4)
    base_pose[2, 3] = radius
    # Rotate the whole base pose so its +Z maps onto the view direction.
    orient = np.eye(4)
    orient[:3, :3] = rotation_matrix_from_vectors(
        np.array([0.0, 0.0, 1.0]),
        view_dir
    )
    return orient @ base_pose
def render_single_view(
    mesh: Union[trimesh.Trimesh, trimesh.Scene],
    azimuth: float = 0.0, # in degrees
    elevation: float = 0.0, # in degrees
    radius: float = 3.5,
    image_size: tuple = (512, 512),
    fov: float = 40.0,
    light_intensity: Optional[float] = 5.0,
    num_env_lights: int = 0,
    znear: float = 0.1,
    zfar: float = 10.0,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_depth: bool = False,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[
    Image.Image,
    np.ndarray,
    Tuple[Image.Image, Image.Image],
    Tuple[np.ndarray, np.ndarray]
]:
    """Render one view of ``mesh`` from a spherical camera position.

    With ``num_env_lights > 0``, a ring of directional lights replaces the
    single camera-attached light. Returns the image, or (image, depth) when
    ``return_depth``.
    """
    if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
        raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
    if isinstance(mesh, trimesh.Trimesh):
        mesh = trimesh.Scene(mesh)
    scene = pyrender.Scene.from_trimesh_scene(mesh)
    light = pyrender.DirectionalLight(
        color=np.ones(3),
        intensity=light_intensity
    ) if light_intensity is not None else None
    camera = pyrender.PerspectiveCamera(
        yfov=np.deg2rad(fov),
        aspectRatio=image_size[0]/image_size[1],
        znear=znear,
        zfar=zfar
    )
    renderer = pyrender.OffscreenRenderer(*image_size)
    camera_pose = create_camera_pose_on_sphere(
        azimuth,
        elevation,
        radius
    )
    if num_env_lights > 0:
        # Environment lighting: a ring of directional lights around the y axis.
        # NOTE(review): these reuse `light_intensity`, so combining
        # num_env_lights > 0 with light_intensity=None would pass
        # intensity=None to DirectionalLight — confirm callers never do this.
        env_light_poses = create_circular_camera_poses(
            num_env_lights,
            radius,
            axis = np.array([0.0, 1.0, 0.0])
        )
        for pose in env_light_poses:
            scene.add(pyrender.DirectionalLight(
                color=np.ones(3),
                intensity=light_intensity
            ), pose=pose)
        # Env lights replace the single camera-attached light.
        light = None
    image, depth = render(
        scene, renderer, camera, camera_pose, light,
        normalize_depth=normalize_depth,
        flags=flags,
        return_type=return_type
    )
    renderer.delete()
    if return_depth:
        return image, depth
    return image
def render_normal_single_view(
    mesh: Union[trimesh.Trimesh, trimesh.Scene],
    azimuth: float = 0.0, # in degrees
    elevation: float = 0.0, # in degrees
    radius: float = 3.5,
    image_size: tuple = (512, 512),
    fov: float = 40.0,
    light_intensity: Optional[float] = 5.0,
    znear: float = 0.1,
    zfar: float = 10.0,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_depth: bool = False,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[
    Image.Image,
    np.ndarray,
    Tuple[Image.Image, Image.Image],
    Tuple[np.ndarray, np.ndarray]
]:
    """Single-view render with vertex normals baked into RGB vertex colors.

    A Scene input is flattened to a single mesh first; normals in [-1, 1]
    are mapped to colors in [0, 255]. Delegates to render_single_view.
    """
    if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
        raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.to_geometry()
    # Map per-vertex normals from [-1, 1] into [0, 255] RGB.
    vertex_colors = ((mesh.vertex_normals + 1.0) / 2.0 * 255).astype(np.uint8)
    mesh.visual = trimesh.visual.ColorVisuals(
        mesh=mesh,
        vertex_colors=vertex_colors
    )
    return render_single_view(
        trimesh.Scene(mesh), azimuth, elevation, radius,
        image_size, fov, light_intensity, znear, zfar,
        normalize_depth, flags,
        return_depth, return_type
    )
def export_renderings(
    images: List[Image.Image],
    export_path: str,
    fps: int = 36,
    loop: int = 0
):
    """Save rendered frames to ``export_path`` as an .mp4 or animated .gif.

    Args:
        images: frames in playback order (PIL images).
        export_path: output file; the extension selects the format.
        fps: playback frame rate.
        loop: GIF loop count (0 = loop forever).

    Raises:
        ValueError: if the extension is neither mp4 nor gif.
    """
    # splitext is robust to dots elsewhere in the path (unlike split('.')[-1]),
    # and lowercasing makes the check case-insensitive (e.g. 'preview.GIF').
    export_type = os.path.splitext(export_path)[1].lstrip('.').lower()
    if export_type == 'mp4':
        export_to_video(
            images,
            export_path,
            fps=fps,
        )
    elif export_type == 'gif':
        # PIL expects the per-frame duration in milliseconds.
        duration = 1000 / fps
        images[0].save(
            export_path,
            save_all=True,
            append_images=images[1:],
            duration=duration,
            loop=loop
        )
    else:
        raise ValueError(f'Unknown export type: {export_type}')
def make_grid_for_images_or_videos(
    images_or_videos: Union[List[Image.Image], List[List[Image.Image]]],
    nrow: int = 4,
    padding: int = 0,
    pad_value: int = 0,
    image_size: tuple = (512, 512),
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[Image.Image, List[Image.Image], np.ndarray]:
    """Arrange images (or per-frame images of several videos) into a grid.

    For a list of images, returns one grid image. For a list of videos
    (lists of frames), builds one grid per frame index and returns the list
    (stacked into an ndarray when ``return_type == 'ndarray'``).

    Args:
        images_or_videos: images, or videos as lists of frames.
        nrow: images per grid row.
        padding: pixel padding between cells.
        pad_value: fill value for the padding.
        image_size: every image is resized to this before gridding.
        return_type: 'pil' or 'ndarray'.

    Raises:
        ValueError: if the input is neither images nor lists of images.
    """
    if isinstance(images_or_videos[0], Image.Image):
        images = [np.array(image.resize(image_size).convert('RGB')) for image in images_or_videos]
        images = np.stack(images, axis=0).transpose(0, 3, 1, 2)  # [N, C, H, W]
        images = torch.from_numpy(images)
        image_grid = make_grid(
            images,
            nrow=nrow,
            padding=padding,
            pad_value=pad_value,
            normalize=False
        )  # [C, H', W']
        image_grid = image_grid.cpu().numpy()
        if return_type == 'pil':
            image_grid = Image.fromarray(image_grid.transpose(1, 2, 0))
        return image_grid
    elif isinstance(images_or_videos[0], list) and isinstance(images_or_videos[0][0], Image.Image):
        # Videos: build one grid per frame index.
        # NOTE(review): the frame count is taken from the first video —
        # assumes all videos have equal length; verify at the call sites.
        image_grids = []
        for i in range(len(images_or_videos[0])):
            frames = [video[i] for video in images_or_videos]
            image_grid = make_grid_for_images_or_videos(
                frames,
                nrow=nrow,
                padding=padding,
                # Bug fix: pad_value and image_size were previously dropped
                # from the recursive call, so video grids silently ignored
                # those arguments.
                pad_value=pad_value,
                image_size=image_size,
                return_type=return_type
            )
            image_grids.append(image_grid)
        if return_type == 'ndarray':
            image_grids = np.stack(image_grids, axis=0)
        return image_grids
    else:
        raise ValueError(f'Unknown input type: {type(images_or_videos[0])}')