# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
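"""Text-to-3D asset generation pipeline.

Each text prompt is first rendered into a background-removed reference image
(``text_to_image``), which is then lifted into a 3D asset and quality-checked
(``text_to_3d``). GPT-based checkers validate semantic consistency,
segmentation quality, and text/3D alignment, with configurable retries.
"""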
import argparse
import os
import random
from collections import defaultdict

import numpy as np
import torch
from PIL import Image

from embodied_gen.models.image_comm_model import build_hf_image_pipeline
from embodied_gen.models.segment_model import RembgRemover
from embodied_gen.models.text_model import PROMPT_APPEND
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import (
    check_object_edge_truncated,
    render_asset3d,
)
from embodied_gen.validators.quality_checkers import (
    ImageSegChecker,
    SemanticConsistChecker,
    TextGenAlignChecker,
)
# Avoid the huggingface/tokenizers warning: "The current process just got forked".
os.environ["TOKENIZERS_PARALLELISM"] = "false"

random.seed(0)

logger.info("Loading Models...")
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
PIPE_IMG = build_hf_image_pipeline(os.environ.get("TEXT_MODEL", "sd35"))
BG_REMOVER = RembgRemover()
__all__ = [
    "text_to_image",
    "text_to_3d",
]


def text_to_image(
    prompt: str,
    save_path: str,
    n_retry: int,
    img_denoise_step: int,
    text_guidance_scale: float,
    n_img_sample: int,
    image_hw: tuple[int, int] = (1024, 1024),
    seed: int | None = None,
) -> bool:
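    """Generate a background-removed object image from a text prompt.

    Samples ``n_img_sample`` candidates per attempt with the global image
    pipeline, removes their backgrounds, and keeps the first candidate that
    passes the semantic, segmentation, and edge-truncation checks. The
    selected image is written to ``save_path`` and its raw counterpart to
    ``*_raw.png``. Returns True if a candidate passed the checks.
    """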
    select_image = None
    success_flag = False
    assert save_path.endswith(".png"), "Image save path must end with `.png`."

    for try_idx in range(n_retry):
        f_prompt = PROMPT_APPEND.format(object=prompt)
        logger.info(
            f"Image GEN for {os.path.basename(save_path)}\n"
            f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
        )
        torch.cuda.empty_cache()
        images = PIPE_IMG.run(
            f_prompt,
            num_inference_steps=img_denoise_step,
            guidance_scale=text_guidance_scale,
            num_images_per_prompt=n_img_sample,
            height=image_hw[0],
            width=image_hw[1],
            generator=(
                torch.Generator().manual_seed(seed)
                if seed is not None
                else None
            ),
        )
        for idx in range(len(images)):
            raw_image: Image.Image = images[idx]
            image = BG_REMOVER(raw_image)
            image.save(save_path)

            # Validate the candidate: semantic match with the prompt,
            # segmentation quality, and whether the object is cut off
            # at the image border.
            semantic_flag, semantic_result = SEMANTIC_CHECKER(
                prompt, [image.convert("RGB")]
            )
            seg_flag, seg_result = SEG_CHECKER(
                [raw_image, image.convert("RGB")]
            )
            image_mask = np.array(image)[..., -1]
            edge_flag = check_object_edge_truncated(image_mask)
            logger.warning(
                f"SEMANTIC: {semantic_result}. SEG: {seg_result}. EDGE: {edge_flag}"
            )

            # Accept when the edge check passes and both GPT checks pass,
            # or when the edge check passes and either GPT check returned None.
            if (
                (edge_flag and semantic_flag and seg_flag)
                or (edge_flag and semantic_flag is None)
                or (edge_flag and seg_flag is None)
            ):
                select_image = [raw_image, image]
                success_flag = True
                break
        # Stop retrying once a valid candidate has been selected.
        if success_flag:
            break
        seed = random.randint(0, 100000) if seed is not None else None

    # Persist the selected raw image and its background-removed version.
    if select_image is not None:
        select_image[0].save(save_path.replace(".png", "_raw.png"))
        select_image[1].save(save_path)

    return success_flag


def text_to_3d(**kwargs) -> dict:
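    """Run the full text -> image -> 3D asset pipeline.

    Keyword arguments override the CLI defaults from ``parse_args`` (e.g.
    ``prompts``, ``output_root``, ``asset_names``, ``seed_img``, ``seed_3d``).
    For each prompt, an image is generated, lifted to a 3D asset, and
    quality-checked from rendered views; the whole round is retried up to
    ``n_pipe_retry`` times. Returns a dict with per-asset output paths under
    ``"assets"`` and quality-check results under ``"quality"``.

    Illustrative call (prompt and output directory are example values):

        results = text_to_3d(
            prompts=["a wooden chair"],
            output_root="outputs/demo",
        )
    """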
    args = parse_args()
    # Override parsed CLI defaults with any explicitly passed kwargs.
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    # Fall back to generic asset names when none are provided.
    if args.asset_names is None or len(args.asset_names) == 0:
        args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]

    img_save_dir = os.path.join(args.output_root, "images")
    asset_save_dir = os.path.join(args.output_root, "asset3d")
    os.makedirs(img_save_dir, exist_ok=True)
    os.makedirs(asset_save_dir, exist_ok=True)

    results = defaultdict(dict)
    for prompt, node in zip(args.prompts, args.asset_names):
        success_flag = False
        n_pipe_retry = args.n_pipe_retry
        seed_img = args.seed_img
        seed_3d = args.seed_3d
        while success_flag is False and n_pipe_retry > 0:
            logger.info(
                f"GEN pipeline for node {node}\n"
                f"Try round: {args.n_pipe_retry - n_pipe_retry + 1}/{args.n_pipe_retry}, Prompt: {prompt}"
            )
            # Text-to-image GEN
            save_node = node.replace(" ", "_")
            gen_image_path = f"{img_save_dir}/{save_node}.png"
            textgen_flag = text_to_image(
                prompt,
                gen_image_path,
                args.n_image_retry,
                args.img_denoise_step,
                args.text_guidance_scale,
                args.n_img_sample,
                seed=seed_img,
            )

            # Asset 3D GEN
            node_save_dir = f"{asset_save_dir}/{save_node}"
            asset_type = node if "sample3d_" not in node else None
            imageto3d_api(
                image_path=[gen_image_path],
                output_root=node_save_dir,
                asset_type=[asset_type],
                seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
                n_retry=args.n_asset_retry,
                keep_intermediate=args.keep_intermediate,
            )
            mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
            image_path = render_asset3d(
                mesh_path,
                output_root=f"{node_save_dir}/result",
                num_images=6,
                elevation=(30, -30),
                output_subdir="renders",
                no_index_file=True,
            )

            # QA the generated asset against the asset type (or the prompt)
            # using the multi-view renders.
            check_text = asset_type if asset_type is not None else prompt
            qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
            logger.warning(
                f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}"
            )
            results["assets"][node] = f"{node_save_dir}/result"
            results["quality"][node] = qa_result
            if qa_flag is None or qa_flag is True:
                success_flag = True
                break

            # Retry the round, drawing new seeds only when seeds were provided.
            n_pipe_retry -= 1
            seed_img = (
                random.randint(0, 100000) if seed_img is not None else None
            )
            seed_3d = (
                random.randint(0, 100000) if seed_3d is not None else None
            )
            torch.cuda.empty_cache()

    return results


def parse_args():
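    """Parse command-line options for the text-to-3D pipeline."""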
    parser = argparse.ArgumentParser(
        description="Text-to-3D Asset Generation Config"
    )
    parser.add_argument(
        "--prompts",
        nargs="+",
        help="Text descriptions of the assets to generate",
    )
    parser.add_argument(
        "--output_root",
        type=str,
        help="Directory to save outputs",
    )
    parser.add_argument(
        "--asset_names",
        type=str,
        nargs="+",
        default=None,
        help="Asset names to generate",
    )
    parser.add_argument(
        "--n_img_sample",
        type=int,
        default=3,
        help="Number of image samples to generate",
    )
    parser.add_argument(
        "--text_guidance_scale",
        type=float,
        default=7,
        help="Text-to-image guidance scale",
    )
    parser.add_argument(
        "--img_denoise_step",
        type=int,
        default=25,
        help="Denoising steps for image generation",
    )
    parser.add_argument(
        "--n_image_retry",
        type=int,
        default=2,
        help="Max retry count for image generation",
    )
    parser.add_argument(
        "--n_asset_retry",
        type=int,
        default=2,
        help="Max retry count for 3D generation",
    )
    parser.add_argument(
        "--n_pipe_retry",
        type=int,
        default=1,
        help="Max retry count for the full text-to-image-to-3D pipeline",
    )
    parser.add_argument(
        "--seed_img",
        type=int,
        default=None,
        help="Random seed for image generation",
    )
    parser.add_argument(
        "--seed_3d",
        type=int,
        default=0,
        help="Random seed for 3D generation",
    )
    parser.add_argument(
        "--keep_intermediate",
        action="store_true",
        help="Keep intermediate generation results",
    )
    args, unknown = parser.parse_known_args()

    return args
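

# Example CLI usage (module path and argument values are illustrative):
#   python -m embodied_gen.scripts.textto3d \
#       --prompts "a wooden chair" "a ceramic mug" \
#       --output_root outputs/demo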


if __name__ == "__main__":
    text_to_3d()