# NOTE(review): the three lines below are leftover commit metadata
# (author / message / hash) accidentally pasted above the license header;
# commented out so the module remains valid Python.
# xinjie.wang
# update
# 631a83a
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import os
import random
from collections import defaultdict
import numpy as np
import torch
from PIL import Image
from embodied_gen.models.image_comm_model import build_hf_image_pipeline
from embodied_gen.models.segment_model import RembgRemover
from embodied_gen.models.text_model import PROMPT_APPEND
from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.log import logger
from embodied_gen.utils.process_media import (
check_object_edge_truncated,
render_asset3d,
)
from embodied_gen.validators.quality_checkers import (
ImageSegChecker,
SemanticConsistChecker,
TextGenAlignChecker,
)
# Avoid huggingface/tokenizers: The current process just got forked.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Fixed seed so retry-seed draws (random.randint below) are reproducible.
random.seed(0)
logger.info("Loading Models...")
# Module-level singletons shared by text_to_image / text_to_3d:
# GPT-backed quality checkers, the text-to-image pipeline (model selected
# via the TEXT_MODEL env var, defaulting to "sd35"), and the background
# remover used to segment the generated object from its background.
SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
PIPE_IMG = build_hf_image_pipeline(os.environ.get("TEXT_MODEL", "sd35"))
BG_REMOVER = RembgRemover()
# Public API of this module.
__all__ = [
    "text_to_image",
    "text_to_3d",
]
def text_to_image(
    prompt: str,
    save_path: str,
    n_retry: int,
    img_denoise_step: int,
    text_guidance_scale: float,
    n_img_sample: int,
    image_hw: tuple[int, int] = (1024, 1024),
    seed: int | None = None,
) -> bool:
    """Generate a background-removed object image for ``prompt``.

    Runs the text-to-image pipeline up to ``n_retry`` times, sampling
    ``n_img_sample`` candidates per try. Each candidate is background-removed
    and screened by the semantic-consistency checker, the segmentation
    checker, and an edge-truncation check on the alpha mask. The first
    accepted candidate is written to ``save_path`` (RGBA) and its raw version
    to ``save_path`` with a ``_raw.png`` suffix.

    Args:
        prompt: Object description fed into ``PROMPT_APPEND``.
        save_path: Output path; must end with ``.png``.
        n_retry: Max generation rounds before giving up.
        img_denoise_step: Diffusion denoising steps per image.
        text_guidance_scale: Classifier-free guidance scale.
        n_img_sample: Candidates sampled per round.
        image_hw: (height, width) of generated images.
        seed: Optional seed; when set, a fresh random seed is drawn for each
            retry round. When ``None``, generation is unseeded.

    Returns:
        True if an accepted image was saved, False otherwise.
    """
    assert save_path.endswith(".png"), "Image save path must end with `.png`."
    select_image = None
    success_flag = False
    for try_idx in range(n_retry):
        f_prompt = PROMPT_APPEND.format(object=prompt)
        logger.info(
            f"Image GEN for {os.path.basename(save_path)}\n"
            f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
        )
        torch.cuda.empty_cache()
        images = PIPE_IMG.run(
            f_prompt,
            num_inference_steps=img_denoise_step,
            guidance_scale=text_guidance_scale,
            num_images_per_prompt=n_img_sample,
            height=image_hw[0],
            width=image_hw[1],
            generator=(
                torch.Generator().manual_seed(seed)
                if seed is not None
                else None
            ),
        )
        for idx in range(len(images)):
            raw_image: Image.Image = images[idx]
            image = BG_REMOVER(raw_image)
            image.save(save_path)
            semantic_flag, semantic_result = SEMANTIC_CHECKER(
                prompt, [image.convert("RGB")]
            )
            seg_flag, seg_result = SEG_CHECKER(
                [raw_image, image.convert("RGB")]
            )
            # Alpha channel of the RGBA output; object must not touch edges.
            image_mask = np.array(image)[..., -1]
            edge_flag = check_object_edge_truncated(image_mask)
            logger.warning(
                f"SEMANTIC: {semantic_result}. SEG: {seg_result}. EDGE: {edge_flag}"
            )
            # Accept when the edge check passes and neither checker rejects
            # (a checker returning None counts as "no opinion", not a veto).
            if (
                (edge_flag and semantic_flag and seg_flag)
                or (edge_flag and semantic_flag is None)
                or (edge_flag and seg_flag is None)
            ):
                select_image = [raw_image, image]
                success_flag = True
                break
        if success_flag:
            break
        # Draw a new seed for the next round (only when seeded generation).
        seed = random.randint(0, 100000) if seed is not None else None
    # BUGFIX: save the selected pair after the loop. The original saved it at
    # the TOP of the next retry iteration, so a success on the final retry
    # never wrote the `_raw.png` companion file.
    if select_image is not None:
        select_image[0].save(save_path.replace(".png", "_raw.png"))
        select_image[1].save(save_path)
    return success_flag
def text_to_3d(**kwargs) -> dict:
    """End-to-end text -> image -> 3D asset pipeline with quality retries.

    Keyword arguments override the matching CLI options parsed by
    ``parse_args`` (e.g. ``prompts``, ``output_root``, ``seed_img``);
    ``None`` values are ignored so CLI defaults still apply.

    Returns:
        dict with two sub-dicts keyed by asset name:
        ``results["assets"]`` -> path to the node's result directory, and
        ``results["quality"]`` -> the text/3D-alignment checker verdict.
    """
    args = parse_args()
    # Explicit (non-None) kwargs take precedence over parsed CLI values.
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)
    # Auto-name assets when no names were provided.
    if args.asset_names is None or len(args.asset_names) == 0:
        args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
    img_save_dir = os.path.join(args.output_root, "images")
    asset_save_dir = os.path.join(args.output_root, "asset3d")
    os.makedirs(img_save_dir, exist_ok=True)
    os.makedirs(asset_save_dir, exist_ok=True)
    results = defaultdict(dict)
    # One full image+3D round per (prompt, asset-name) pair, retried up to
    # n_pipe_retry times until the alignment checker accepts the renders.
    for prompt, node in zip(args.prompts, args.asset_names):
        success_flag = False
        n_pipe_retry = args.n_pipe_retry
        seed_img = args.seed_img
        seed_3d = args.seed_3d
        while success_flag is False and n_pipe_retry > 0:
            logger.info(
                f"GEN pipeline for node {node}\n"
                f"Try round: {args.n_pipe_retry-n_pipe_retry+1}/{args.n_pipe_retry}, Prompt: {prompt}"
            )
            # Text-to-image GEN
            save_node = node.replace(" ", "_")
            gen_image_path = f"{img_save_dir}/{save_node}.png"
            # NOTE(review): textgen_flag is never consulted — the pipeline
            # proceeds to 3D generation even when image generation reported
            # failure. Possibly intentional best-effort; confirm.
            textgen_flag = text_to_image(
                prompt,
                gen_image_path,
                args.n_image_retry,
                args.img_denoise_step,
                args.text_guidance_scale,
                args.n_img_sample,
                seed=seed_img,
            )
            # Asset 3D GEN
            node_save_dir = f"{asset_save_dir}/{save_node}"
            # Auto-generated names carry no semantics; use the prompt instead
            # for downstream checks in that case.
            asset_type = node if "sample3d_" not in node else None
            imageto3d_api(
                image_path=[gen_image_path],
                output_root=node_save_dir,
                asset_type=[asset_type],
                seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
                n_retry=args.n_asset_retry,
                keep_intermediate=args.keep_intermediate,
            )
            mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
            # Render 6 views of the produced mesh for the alignment check.
            image_path = render_asset3d(
                mesh_path,
                output_root=f"{node_save_dir}/result",
                num_images=6,
                elevation=(30, -30),
                output_subdir="renders",
                no_index_file=True,
            )
            check_text = asset_type if asset_type is not None else prompt
            qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
            logger.warning(
                f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}"
            )
            results["assets"][node] = f"{node_save_dir}/result"
            results["quality"][node] = qa_result
            # None (checker abstained) counts as acceptance.
            if qa_flag is None or qa_flag is True:
                success_flag = True
                break
            n_pipe_retry -= 1
            # Re-roll seeds for the next round only when seeded runs were
            # requested in the first place.
            seed_img = (
                random.randint(0, 100000) if seed_img is not None else None
            )
            seed_3d = (
                random.randint(0, 100000) if seed_3d is not None else None
            )
            torch.cuda.empty_cache()
    return results
def parse_args():
parser = argparse.ArgumentParser(description="3D Layout Generation Config")
parser.add_argument("--prompts", nargs="+", help="text descriptions")
parser.add_argument(
"--output_root",
type=str,
help="Directory to save outputs",
)
parser.add_argument(
"--asset_names",
type=str,
nargs="+",
default=None,
help="Asset names to generate",
)
parser.add_argument(
"--n_img_sample",
type=int,
default=3,
help="Number of image samples to generate",
)
parser.add_argument(
"--text_guidance_scale",
type=float,
default=7,
help="Text-to-image guidance scale",
)
parser.add_argument(
"--img_denoise_step",
type=int,
default=25,
help="Denoising steps for image generation",
)
parser.add_argument(
"--n_image_retry",
type=int,
default=2,
help="Max retry count for image generation",
)
parser.add_argument(
"--n_asset_retry",
type=int,
default=2,
help="Max retry count for 3D generation",
)
parser.add_argument(
"--n_pipe_retry",
type=int,
default=1,
help="Max retry count for 3D asset generation",
)
parser.add_argument(
"--seed_img",
type=int,
default=None,
help="Random seed for image generation",
)
parser.add_argument(
"--seed_3d",
type=int,
default=0,
help="Random seed for 3D generation",
)
parser.add_argument("--keep_intermediate", action="store_true")
args, unknown = parser.parse_known_args()
return args
# Script entry point: run the full text-to-3D pipeline with CLI arguments.
if __name__ == "__main__":
    text_to_3d()