from src.utils.typing_utils import *  # provides typing helpers (Optional, Union, Tuple, List, Literal)

import os

# Must be set before pyrender is imported so the headless EGL backend is used.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import math

import numpy as np
from PIL import Image
import trimesh
from trimesh.transformations import rotation_matrix
import pyrender
import torch
from torchvision.utils import make_grid
from diffusers.utils import export_to_video
from diffusers.utils.loading_utils import load_video
def explode_mesh(mesh, explosion_scale=0.4):
    # Ensure we have a Scene
    if isinstance(mesh, trimesh.Trimesh):
        scene = trimesh.Scene(mesh)
    elif isinstance(mesh, trimesh.Scene):
        scene = mesh
    else:
        raise ValueError(f"Expected Trimesh or Scene, got {type(mesh)}")

    if len(scene.geometry) <= 1:
        print("Nothing to explode")
        return scene

    # 1) collect (name, geom, world_center)
    parts = []
    for name, geom in scene.geometry.items():
        # graph.get(name) returns (4x4 world-space transform, geometry name)
        world_tf, _ = scene.graph.get(name)
        pts = trimesh.transformations.transform_points(geom.vertices, world_tf)
        center = pts.mean(axis=0)
        parts.append((name, geom, center))

    # 2) compute the global center of all parts
    all_centers = np.stack([c for _, _, c in parts], axis=0)
    global_center = all_centers.mean(axis=0)

    # 3) push every part away from the global center along its own direction
    exploded = trimesh.Scene()
    for name, geom, center in parts:
        dir_vec = center - global_center
        norm = np.linalg.norm(dir_vec)
        if norm < 1e-6:
            # part sits at the global center: pick a random direction
            dir_vec = np.random.randn(3)
            dir_vec /= np.linalg.norm(dir_vec)
        else:
            dir_vec /= norm

        offset = dir_vec * explosion_scale

        # fetch the same 4x4 transform, then bump just the translation
        world_tf, _ = scene.graph.get(name)
        world_tf = world_tf.copy()
        world_tf[:3, 3] += offset

        exploded.add_geometry(geom, transform=world_tf, geom_name=name)
        print(f"[explode] {name} moved by {np.linalg.norm(offset):.4f}")

    return exploded
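

# Hedged usage sketch for explode_mesh: build a tiny two-part scene and explode it.
# The helper name `_example_explode`, the node names, and the box sizes are
# illustrative only and are not part of the original module.
def _example_explode() -> trimesh.Scene:
    left = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
    right = trimesh.creation.box(extents=(1.0, 1.0, 1.0))
    right.apply_translation([1.5, 0.0, 0.0])

    scene = trimesh.Scene()
    scene.add_geometry(left, geom_name='left')
    scene.add_geometry(right, geom_name='right')

    # each part is pushed away from the combined center by `explosion_scale`
    return explode_mesh(scene, explosion_scale=0.4)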
def render(
    scene: pyrender.Scene,
    renderer: pyrender.OffscreenRenderer,
    camera: pyrender.Camera,
    pose: np.ndarray,
    light: Optional[pyrender.Light] = None,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[Image.Image, Image.Image]]:
    camera_node = scene.add(camera, pose=pose)
    if light is not None:
        light_node = scene.add(light, pose=pose)
    image, depth = renderer.render(
        scene,
        flags=flags
    )
    scene.remove_node(camera_node)
    if light is not None:
        scene.remove_node(light_node)
    if normalize_depth or return_type == 'pil':
        # guard against a constant depth map (e.g. nothing visible) to avoid division by zero
        depth_range = max(depth.max() - depth.min(), 1e-8)
        depth = (depth - depth.min()) / depth_range * 255.0
    if return_type == 'pil':
        image = Image.fromarray(image)
        depth = Image.fromarray(depth.astype(np.uint8))
    return image, depth
def rotation_matrix_from_vectors(vec1, vec2):
    # Return the 3x3 rotation matrix that aligns vec1 with vec2 (Rodrigues' formula).
    a, b = vec1 / np.linalg.norm(vec1), vec2 / np.linalg.norm(vec2)
    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)
    if s == 0:
        # vectors are parallel
        if c > 0:
            return np.eye(3)
        # antiparallel: rotate 180 degrees about any axis perpendicular to a
        perp = np.cross(a, np.array([1.0, 0.0, 0.0]))
        if np.linalg.norm(perp) < 1e-8:
            perp = np.cross(a, np.array([0.0, 1.0, 0.0]))
        perp /= np.linalg.norm(perp)
        return rotation_matrix(np.pi, perp)[:3, :3]
    kmat = np.array([
        [0, -v[2], v[1]],
        [v[2], 0, -v[0]],
        [-v[1], v[0], 0]
    ])
    return np.eye(3) + kmat + kmat @ kmat * ((1 - c) / (s ** 2))
def create_circular_camera_positions(
    num_views: int,
    radius: float,
    axis: np.ndarray = np.array([0.0, 1.0, 0.0])
) -> List[np.ndarray]:
    # Create a list of positions for a circular camera trajectory
    # around the given axis with the given radius.
    positions = []
    axis = axis / np.linalg.norm(axis)
    for i in range(num_views):
        theta = 2 * np.pi * i / num_views
        position = np.array([
            np.sin(theta) * radius,
            0.0,
            np.cos(theta) * radius
        ])
        if not np.allclose(axis, np.array([0.0, 1.0, 0.0])):
            R = rotation_matrix_from_vectors(np.array([0.0, 1.0, 0.0]), axis)
            position = R @ position
        positions.append(position)
    return positions
def create_circular_camera_poses(
    num_views: int,
    radius: float,
    axis: np.ndarray = np.array([0.0, 1.0, 0.0])
) -> List[np.ndarray]:
    # Create a list of poses for a circular camera trajectory
    # around the given axis with the given radius.
    # The camera always looks at the origin.
    # The up vector is always [0, 1, 0].
    canonical_pose = np.array([
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, radius],
        [0.0, 0.0, 0.0, 1.0]
    ])
    poses = []
    for i in range(num_views):
        theta = 2 * np.pi * i / num_views
        R = rotation_matrix(
            angle=theta,
            direction=axis,
            point=[0, 0, 0]
        )
        pose = R @ canonical_pose
        poses.append(pose)
    return poses
def render_views_around_mesh(
    mesh: Union[trimesh.Trimesh, trimesh.Scene],
    num_views: int = 36,
    radius: float = 3.5,
    axis: np.ndarray = np.array([0.0, 1.0, 0.0]),
    image_size: tuple = (512, 512),
    fov: float = 40.0,
    light_intensity: Optional[float] = 5.0,
    znear: float = 0.1,
    zfar: float = 10.0,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_depth: bool = False,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[
    List[Image.Image],
    List[np.ndarray],
    Tuple[List[Image.Image], List[Image.Image]],
    Tuple[List[np.ndarray], List[np.ndarray]]
]:
    if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
        raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")

    # Build one pyrender scene per view; multi-part scenes are exploded by an
    # amount that varies sinusoidally over the turntable.
    scenes = []
    if isinstance(mesh, trimesh.Trimesh):
        for i in range(num_views):
            scenes.append(pyrender.Scene.from_trimesh_scene(trimesh.Scene(mesh)))
    else:
        for i in range(num_views):
            value = math.sin(math.pi * (i - 1) / num_views)
            scenes.append(pyrender.Scene.from_trimesh_scene(
                explode_mesh(mesh, 0.2 * value),
                ambient_light=[0.02, 0.02, 0.02],
                bg_color=[0.0, 0.0, 0.0, 1.0]
            ))

    light = pyrender.DirectionalLight(
        color=np.ones(3),
        intensity=light_intensity
    ) if light_intensity is not None else None
    camera = pyrender.PerspectiveCamera(
        yfov=np.deg2rad(fov),
        aspectRatio=image_size[0] / image_size[1],
        znear=znear,
        zfar=zfar
    )
    renderer = pyrender.OffscreenRenderer(*image_size)

    camera_poses = create_circular_camera_poses(
        num_views,
        radius,
        axis=axis
    )

    images, depths = [], []
    for i, pose in enumerate(camera_poses):
        image, depth = render(
            scenes[i], renderer, camera, pose, light,
            normalize_depth=normalize_depth,
            flags=flags,
            return_type=return_type
        )
        images.append(image)
        depths.append(depth)

    renderer.delete()

    if return_depth:
        return images, depths
    return images
def render_normal_views_around_mesh(
    mesh: Union[trimesh.Trimesh, trimesh.Scene],
    num_views: int = 36,
    radius: float = 3.5,
    axis: np.ndarray = np.array([0.0, 1.0, 0.0]),
    image_size: tuple = (512, 512),
    fov: float = 40.0,
    light_intensity: Optional[float] = 5.0,
    znear: float = 0.1,
    zfar: float = 10.0,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_depth: bool = False,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[
    List[Image.Image],
    List[np.ndarray],
    Tuple[List[Image.Image], List[Image.Image]],
    Tuple[List[np.ndarray], List[np.ndarray]]
]:
    # Render turntable views with per-vertex normals encoded as RGB colors.
    if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
        raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.to_geometry()
    normals = mesh.vertex_normals
    # map normals from [-1, 1] to [0, 255]
    colors = ((normals + 1.0) / 2.0 * 255).astype(np.uint8)
    mesh.visual = trimesh.visual.ColorVisuals(
        mesh=mesh,
        vertex_colors=colors
    )
    mesh = trimesh.Scene(mesh)
    return render_views_around_mesh(
        mesh, num_views, radius, axis,
        image_size, fov, light_intensity, znear, zfar,
        normalize_depth, flags,
        return_depth, return_type
    )
def create_camera_pose_on_sphere(
    azimuth: float = 0.0,  # in degrees
    elevation: float = 0.0,  # in degrees
    radius: float = 3.5,
) -> np.ndarray:
    # Create a camera pose for a given azimuth and elevation
    # with the given radius.
    # The camera always looks at the origin.
    # The up vector is always [0, 1, 0].
    canonical_pose = np.array([
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, radius],
        [0.0, 0.0, 0.0, 1.0]
    ])
    azimuth = np.deg2rad(azimuth)
    elevation = np.deg2rad(elevation)
    position = np.array([
        np.cos(elevation) * np.sin(azimuth),
        np.sin(elevation),
        np.cos(elevation) * np.cos(azimuth),
    ])
    R = np.eye(4)
    R[:3, :3] = rotation_matrix_from_vectors(
        np.array([0.0, 0.0, 1.0]),
        position
    )
    pose = R @ canonical_pose
    return pose
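

# A quick, hedged sanity check for the pose helper above: with zero azimuth and
# elevation the camera should sit on the +Z axis at the given radius, looking
# back at the origin. The function name `_check_canonical_pose` is illustrative,
# not part of the original module.
def _check_canonical_pose(radius: float = 3.5) -> None:
    pose = create_camera_pose_on_sphere(azimuth=0.0, elevation=0.0, radius=radius)
    # camera position is the translation column of the 4x4 pose
    assert np.allclose(pose[:3, 3], [0.0, 0.0, radius], atol=1e-6)
    # the camera looks along its -Z axis, which should point toward the origin
    view_dir = -pose[:3, 2]
    assert np.allclose(view_dir, [0.0, 0.0, -1.0], atol=1e-6)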
def render_single_view(
    mesh: Union[trimesh.Trimesh, trimesh.Scene],
    azimuth: float = 0.0,  # in degrees
    elevation: float = 0.0,  # in degrees
    radius: float = 3.5,
    image_size: tuple = (512, 512),
    fov: float = 40.0,
    light_intensity: Optional[float] = 5.0,
    num_env_lights: int = 0,
    znear: float = 0.1,
    zfar: float = 10.0,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_depth: bool = False,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[
    Image.Image,
    np.ndarray,
    Tuple[Image.Image, Image.Image],
    Tuple[np.ndarray, np.ndarray]
]:
    if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
        raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
    if isinstance(mesh, trimesh.Trimesh):
        mesh = trimesh.Scene(mesh)
    scene = pyrender.Scene.from_trimesh_scene(mesh)

    light = pyrender.DirectionalLight(
        color=np.ones(3),
        intensity=light_intensity
    ) if light_intensity is not None else None
    camera = pyrender.PerspectiveCamera(
        yfov=np.deg2rad(fov),
        aspectRatio=image_size[0] / image_size[1],
        znear=znear,
        zfar=zfar
    )
    renderer = pyrender.OffscreenRenderer(*image_size)

    camera_pose = create_camera_pose_on_sphere(
        azimuth,
        elevation,
        radius
    )

    if num_env_lights > 0:
        # surround the object with directional lights on a circle instead of a
        # single camera-mounted light
        env_light_poses = create_circular_camera_poses(
            num_env_lights,
            radius,
            axis=np.array([0.0, 1.0, 0.0])
        )
        for pose in env_light_poses:
            scene.add(pyrender.DirectionalLight(
                color=np.ones(3),
                intensity=light_intensity
            ), pose=pose)
        # the per-view light is no longer needed
        light = None

    image, depth = render(
        scene, renderer, camera, camera_pose, light,
        normalize_depth=normalize_depth,
        flags=flags,
        return_type=return_type
    )
    renderer.delete()

    if return_depth:
        return image, depth
    return image
def render_normal_single_view(
    mesh: Union[trimesh.Trimesh, trimesh.Scene],
    azimuth: float = 0.0,  # in degrees
    elevation: float = 0.0,  # in degrees
    radius: float = 3.5,
    image_size: tuple = (512, 512),
    fov: float = 40.0,
    light_intensity: Optional[float] = 5.0,
    znear: float = 0.1,
    zfar: float = 10.0,
    normalize_depth: bool = False,
    flags: int = pyrender.constants.RenderFlags.NONE,
    return_depth: bool = False,
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[
    Image.Image,
    np.ndarray,
    Tuple[Image.Image, Image.Image],
    Tuple[np.ndarray, np.ndarray]
]:
    # Render a single view with per-vertex normals encoded as RGB colors.
    if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
        raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.to_geometry()
    normals = mesh.vertex_normals
    # map normals from [-1, 1] to [0, 255]
    colors = ((normals + 1.0) / 2.0 * 255).astype(np.uint8)
    mesh.visual = trimesh.visual.ColorVisuals(
        mesh=mesh,
        vertex_colors=colors
    )
    mesh = trimesh.Scene(mesh)
    # pass the trailing arguments by keyword so they are not shifted into
    # render_single_view's num_env_lights parameter
    return render_single_view(
        mesh, azimuth, elevation, radius,
        image_size, fov, light_intensity,
        znear=znear, zfar=zfar,
        normalize_depth=normalize_depth, flags=flags,
        return_depth=return_depth, return_type=return_type
    )
def export_renderings(
    images: List[Image.Image],
    export_path: str,
    fps: int = 36,
    loop: int = 0
):
    export_type = export_path.split('.')[-1]
    if export_type == 'mp4':
        export_to_video(
            images,
            export_path,
            fps=fps,
        )
    elif export_type == 'gif':
        duration = 1000 / fps  # per-frame duration in milliseconds
        images[0].save(
            export_path,
            save_all=True,
            append_images=images[1:],
            duration=duration,
            loop=loop
        )
    else:
        raise ValueError(f'Unknown export type: {export_type}')
def make_grid_for_images_or_videos(
    images_or_videos: Union[List[Image.Image], List[List[Image.Image]]],
    nrow: int = 4,
    padding: int = 0,
    pad_value: int = 0,
    image_size: tuple = (512, 512),
    return_type: Literal['pil', 'ndarray'] = 'pil'
) -> Union[Image.Image, List[Image.Image], np.ndarray]:
    if isinstance(images_or_videos[0], Image.Image):
        images = [np.array(image.resize(image_size).convert('RGB')) for image in images_or_videos]
        images = np.stack(images, axis=0).transpose(0, 3, 1, 2)  # [N, C, H, W]
        images = torch.from_numpy(images)
        image_grid = make_grid(
            images,
            nrow=nrow,
            padding=padding,
            pad_value=pad_value,
            normalize=False
        )  # [C, H', W']
        image_grid = image_grid.cpu().numpy()
        if return_type == 'pil':
            image_grid = Image.fromarray(image_grid.transpose(1, 2, 0))
        return image_grid
    elif isinstance(images_or_videos[0], list) and isinstance(images_or_videos[0][0], Image.Image):
        # treat the input as a list of videos: build one grid per frame index
        image_grids = []
        for i in range(len(images_or_videos[0])):
            images = [video[i] for video in images_or_videos]
            image_grid = make_grid_for_images_or_videos(
                images,
                nrow=nrow,
                padding=padding,
                pad_value=pad_value,
                image_size=image_size,
                return_type=return_type
            )
            image_grids.append(image_grid)
        if return_type == 'ndarray':
            image_grids = np.stack(image_grids, axis=0)
        return image_grids
    else:
        raise ValueError(f'Unknown input type: {type(images_or_videos[0])}')
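

# Hedged end-to-end sketch of the helpers above: load a mesh, render a shaded and
# a normal-map turntable, combine them into a per-frame grid, export an animation,
# and save one single view with its depth map. The mesh path 'assets/example.glb'
# and the output filenames are placeholders, not files shipped with this module,
# and a working EGL / headless OpenGL context is assumed (see PYOPENGL_PLATFORM above).
if __name__ == '__main__':
    mesh = trimesh.load('assets/example.glb')

    # 24 shaded views and 24 normal-map views on a turntable of radius 3.5
    frames = render_views_around_mesh(mesh, num_views=24, radius=3.5)
    normal_frames = render_normal_views_around_mesh(mesh, num_views=24, radius=3.5)

    # side-by-side grid per frame, exported as an animated GIF
    grid_frames = make_grid_for_images_or_videos([frames, normal_frames], nrow=2)
    export_renderings(grid_frames, 'turntable.gif', fps=12)

    # a single view plus its normalized depth map
    image, depth = render_single_view(
        mesh, azimuth=30.0, elevation=20.0, radius=3.5, return_depth=True
    )
    image.save('single_view.png')
    depth.save('single_view_depth.png')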