# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import math
import mimetypes
import os
import textwrap
from glob import glob
from typing import Union

import cv2
import imageio
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import spaces  # noqa: F401 (ZeroGPU support on Hugging Face Spaces)
from matplotlib.patches import Patch
from moviepy.editor import VideoFileClip, clips_array
from PIL import Image

from embodied_gen.data.differentiable_render import entrypoint as render_api
from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

__all__ = [
    "render_asset3d",
    "merge_images_video",
    "merge_video_video",
    "filter_small_connected_components",
    "filter_image_small_connected_components",
    "combine_images_to_grid",
    "SceneTreeVisualizer",
    "load_scene_dict",
    "is_image_file",
    "parse_text_prompts",
    "check_object_edge_truncated",
]


def render_asset3d(
    mesh_path: str,
    output_root: str,
    distance: float = 5.0,
    num_images: int = 1,
    elevation: tuple[float, ...] = (0.0,),
    pbr_light_factor: float = 1.2,
    return_key: str = "image_color/*",
    output_subdir: str = "renders",
    gen_color_mp4: bool = False,
    gen_viewnormal_mp4: bool = False,
    gen_glonormal_mp4: bool = False,
    no_index_file: bool = False,
    with_mtl: bool = True,
) -> list[str]:
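    """Render a 3D mesh to images via the differentiable-render entrypoint.

    All arguments are forwarded to ``render_api``; rendering failures are
    logged rather than raised, so the returned list may be empty.

    Returns:
        Paths of the rendered outputs under ``output_root/output_subdir``
        that match the ``return_key`` glob pattern.
    """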
    input_args = dict(
        mesh_path=mesh_path,
        output_root=output_root,
        uuid=output_subdir,
        distance=distance,
        num_images=num_images,
        elevation=elevation,
        pbr_light_factor=pbr_light_factor,
        with_mtl=with_mtl,
        gen_color_mp4=gen_color_mp4,
        gen_viewnormal_mp4=gen_viewnormal_mp4,
        gen_glonormal_mp4=gen_glonormal_mp4,
        no_index_file=no_index_file,
    )
    try:
        _ = render_api(**input_args)
    except Exception as e:
        # Log and continue: a failed render simply yields an empty glob below.
        logger.error(f"Error occurred during rendering: {e}.")

    dst_paths = glob(os.path.join(output_root, output_subdir, return_key))

    return dst_paths


def merge_images_video(color_images, normal_images, output_path) -> None:
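    """Write a video pairing color and normal frames side by side.

    Each output frame stitches the left half of a color frame to the right
    half of the corresponding normal frame, saved at 50 fps.
    """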
    width = color_images[0].shape[1]
    combined_video = [
        np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
        for rgb_img, normal_img in zip(color_images, normal_images)
    ]
    imageio.mimsave(output_path, combined_video, fps=50)


def merge_video_video(
    video_path1: str, video_path2: str, output_path: str
) -> None:
    """Merge two videos side by side: the left half of the first video and
    the right half of the second."""
    clip1 = VideoFileClip(video_path1)
    clip2 = VideoFileClip(video_path2)
    if clip1.size != clip2.size:
        raise ValueError("The resolutions of the two videos do not match.")

    width, height = clip1.size
    clip1_half = clip1.crop(x1=0, y1=0, x2=width // 2, y2=height)
    clip2_half = clip2.crop(x1=width // 2, y1=0, x2=width, y2=height)
    final_clip = clips_array([[clip1_half, clip2_half]])
    final_clip.write_videofile(output_path, codec="libx264")


def filter_small_connected_components(
    mask: Union[Image.Image, np.ndarray],
    area_ratio: float,
    connectivity: int = 8,
) -> np.ndarray:
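    """Zero out connected components smaller than a fraction of the mask area.

    Components whose pixel area falls below ``mask_area // area_ratio`` are
    removed from the mask; larger components are kept unchanged.
    """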
    if isinstance(mask, Image.Image):
        mask = np.array(mask)

    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        mask,
        connectivity=connectivity,
    )

    small_components = np.zeros_like(mask, dtype=np.uint8)
    mask_area = (mask != 0).sum()
    min_area = mask_area // area_ratio
    # Label 0 is the background; iterate over foreground components only.
    for label in range(1, num_labels):
        area = stats[label, cv2.CC_STAT_AREA]
        if area < min_area:
            small_components[labels == label] = 255

    mask = cv2.bitwise_and(mask, cv2.bitwise_not(small_components))

    return mask


def filter_image_small_connected_components(
    image: Union[Image.Image, np.ndarray],
    area_ratio: float = 10,
    connectivity: int = 8,
) -> np.ndarray:
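    """Apply small-component filtering to the alpha channel of an RGBA image.

    PIL inputs are converted to RGBA first; alpha-channel components below
    the ``area_ratio`` threshold are made fully transparent.
    """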
    if isinstance(image, Image.Image):
        image = image.convert("RGBA")
        image = np.array(image)

    mask = image[..., 3]
    mask = filter_small_connected_components(mask, area_ratio, connectivity)
    image[..., 3] = mask

    return image


def combine_images_to_grid(
    images: list[str | Image.Image],
    cat_row_col: tuple[int, int] | None = None,
    target_wh: tuple[int, int] = (512, 512),
) -> list[str | Image.Image]:
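    """Combine multiple images into a single grid image.

    Inputs may be file paths or PIL images; each is resized to ``target_wh``.
    If ``cat_row_col`` is not given, a near-square layout is chosen with
    ``n_col = ceil(sqrt(n))`` and ``n_row = ceil(n / n_col)``.

    Returns:
        A single-element list holding the grid image, or the input list
        unchanged if it contains only one image.
    """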
    n_images = len(images)
    if n_images == 1:
        return images

    if cat_row_col is None:
        n_col = math.ceil(math.sqrt(n_images))
        n_row = math.ceil(n_images / n_col)
    else:
        n_row, n_col = cat_row_col

    images = [
        Image.open(p).convert("RGB") if isinstance(p, str) else p
        for p in images
    ]
    images = [img.resize(target_wh) for img in images]

    grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
    grid = Image.new("RGB", (grid_w, grid_h), (0, 0, 0))
    for idx, img in enumerate(images):
        row, col = divmod(idx, n_col)
        grid.paste(img, (col * target_wh[0], row * target_wh[1]))

    return [grid]


class SceneTreeVisualizer:
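    """Render a scene layout tree as a color-coded hierarchy diagram.

    Builds a directed graph from ``LayoutInfo.tree``, colors each node by its
    role in the scene (background, context, robot, manipulated or distractor
    objects), and draws edges labeled with their spatial relation.
    """
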
    def __init__(self, layout_info: LayoutInfo) -> None:
        self.tree = layout_info.tree
        self.relation = layout_info.relation
        self.objs_desc = layout_info.objs_desc
        self.G = nx.DiGraph()
        self.root = self._find_root()
        self._build_graph()
        self.role_colors = {
            Scene3DItemEnum.BACKGROUND.value: "plum",
            Scene3DItemEnum.CONTEXT.value: "lightblue",
            Scene3DItemEnum.ROBOT.value: "lightcoral",
            Scene3DItemEnum.MANIPULATED_OBJS.value: "lightgreen",
            Scene3DItemEnum.DISTRACTOR_OBJS.value: "lightgray",
            Scene3DItemEnum.OTHERS.value: "orange",
        }

    def _find_root(self) -> str:
        # A root appears as a parent but never as a child.
        children = {c for cs in self.tree.values() for c, _ in cs}
        parents = set(self.tree.keys())
        roots = parents - children
        if not roots:
            raise ValueError("No root node found.")
        return next(iter(roots))

    def _build_graph(self):
        for parent, children in self.tree.items():
            for child, relation in children:
                self.G.add_edge(parent, child, relation=relation)

    def _get_node_role(self, node: str) -> str:
        if node == self.relation.get(Scene3DItemEnum.BACKGROUND.value):
            return Scene3DItemEnum.BACKGROUND.value
        if node == self.relation.get(Scene3DItemEnum.CONTEXT.value):
            return Scene3DItemEnum.CONTEXT.value
        if node == self.relation.get(Scene3DItemEnum.ROBOT.value):
            return Scene3DItemEnum.ROBOT.value
        if node in self.relation.get(
            Scene3DItemEnum.MANIPULATED_OBJS.value, []
        ):
            return Scene3DItemEnum.MANIPULATED_OBJS.value
        if node in self.relation.get(
            Scene3DItemEnum.DISTRACTOR_OBJS.value, []
        ):
            return Scene3DItemEnum.DISTRACTOR_OBJS.value

        return Scene3DItemEnum.OTHERS.value

    def _get_positions(
        self, root, width=1.0, vert_gap=0.1, vert_loc=1, xcenter=0.5, pos=None
    ):
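        """Recursively compute a top-down tree layout.

        Places ``root`` at ``(xcenter, vert_loc)``, then splits the available
        horizontal ``width`` evenly among its children one level below.
        """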
        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)

        children = list(self.G.successors(root))
        if children:
            dx = width / len(children)
            next_x = xcenter - width / 2 - dx / 2
            for child in children:
                next_x += dx
                pos = self._get_positions(
                    child,
                    width=dx,
                    vert_gap=vert_gap,
                    vert_loc=vert_loc - vert_gap,
                    xcenter=next_x,
                    pos=pos,
                )

        return pos

    def render(
        self,
        save_path: str,
        figsize=(8, 6),
        dpi=300,
        title: str = "Scene 3D Hierarchy Tree",
    ):
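        """Draw the scene tree and save it as an image.

        Nodes are colored by role, edges are labeled with their relation,
        object descriptions are wrapped beneath each node, and a role legend
        is placed below the figure before saving to ``save_path``.
        """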
        node_colors = [
            self.role_colors[self._get_node_role(n)] for n in self.G.nodes
        ]
        pos = self._get_positions(self.root)

        plt.figure(figsize=figsize)
        nx.draw(
            self.G,
            pos,
            with_labels=True,
            arrows=False,
            node_size=2000,
            node_color=node_colors,
            font_size=10,
            font_weight="bold",
        )

        # Draw edge labels.
        edge_labels = nx.get_edge_attributes(self.G, "relation")
        nx.draw_networkx_edge_labels(
            self.G,
            pos,
            edge_labels=edge_labels,
            font_size=9,
            font_color="black",
        )

        # Draw small description text under each node (if available).
        for node, (x, y) in pos.items():
            desc = self.objs_desc.get(node)
            if desc:
                wrapped = "\n".join(textwrap.wrap(desc, width=30))
                plt.text(
                    x,
                    y - 0.006,
                    wrapped,
                    fontsize=6,
                    ha="center",
                    va="top",
                    wrap=True,
                    color="black",
                    bbox=dict(
                        facecolor="dimgray",
                        edgecolor="darkgray",
                        alpha=0.1,
                        boxstyle="round,pad=0.2",
                    ),
                )

        plt.title(title, fontsize=12)
        task_desc = self.relation.get("task_desc", "")
        if task_desc:
            plt.suptitle(
                f"Task Description: {task_desc}", fontsize=10, y=0.999
            )

        plt.axis("off")
        legend_handles = [
            Patch(facecolor=color, edgecolor="black", label=role)
            for role, color in self.role_colors.items()
        ]
        plt.legend(
            handles=legend_handles,
            loc="lower center",
            ncol=3,
            bbox_to_anchor=(0.5, -0.1),
            fontsize=9,
        )

        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
        plt.close()


def load_scene_dict(file_path: str) -> dict:
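    """Parse a ``scene_id: description`` text file into a dict.

    Blank lines and lines without a colon are skipped; only the first colon
    splits each line, so descriptions may themselves contain colons.
    """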
    scene_dict = {}
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or ":" not in line:
                continue
            scene_id, desc = line.split(":", 1)
            scene_dict[scene_id.strip()] = desc.strip()

    return scene_dict


def is_image_file(filename: str) -> bool:
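    """Return True if the filename's guessed MIME type is an image type."""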
    mime_type, _ = mimetypes.guess_type(filename)
    return mime_type is not None and mime_type.startswith("image")


def parse_text_prompts(prompts: list[str]) -> list[str]:
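    """Expand a single ``.txt`` argument into a list of prompts.

    If ``prompts`` holds exactly one path ending in ``.txt``, the file is
    read line by line, skipping blanks and ``#`` comment lines; otherwise
    the input list is returned unchanged.
    """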
    if len(prompts) == 1 and prompts[0].endswith(".txt"):
        with open(prompts[0], "r") as f:
            prompts = [
                line.strip()
                for line in f
                if line.strip() and not line.strip().startswith("#")
            ]

    return prompts


def check_object_edge_truncated(
    mask: np.ndarray, edge_threshold: int = 5
) -> bool:
    """Check whether a binary object mask is truncated at the image edges.

    Args:
        mask: A 2D binary NumPy array where nonzero values indicate the
            object region.
        edge_threshold: Number of pixels from each image edge to consider
            for truncation. Defaults to 5.

    Returns:
        True if the object is fully enclosed (not truncated).
        False if the object touches or crosses any image boundary.
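
    Example:
        >>> m = np.zeros((64, 64), dtype=np.uint8)
        >>> m[20:40, 20:40] = 1  # object well inside the frame
        >>> check_object_edge_truncated(m)
        True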
""" | |
top = mask[:edge_threshold, :].any() | |
bottom = mask[-edge_threshold:, :].any() | |
left = mask[:, :edge_threshold].any() | |
right = mask[:, -edge_threshold:].any() | |
return not (top or bottom or left or right) | |


if __name__ == "__main__":
    merge_video_video(
        "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4",  # noqa
        "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4",  # noqa
        "merge.mp4",
    )