# Source: HunyuanWorld-Demo / hy3dworld/models/layer_decomposer.py
# Commit 57276d4 (verified) by mooki0: "Initial commit of Gradio app"
import os
import json
import torch
from ..utils import sr_utils, seg_utils, inpaint_utils, layer_utils
class LayerDecomposition:
    r"""Generate scene layers from input images and masks.

    Foreground objects, background layers and sky regions are processed by
    the ``layer_utils`` pipelines, using super-resolution, segmentation and
    panorama-inpainting models that ``__init__`` loads once.

    The constructor takes no arguments; every setting below is a hard-coded
    instance attribute. The processing parameters (``scale``, ``seed``,
    ``threshold``, ``ratio``, ``strength``, ``dilation_size``, ``cfg_scale``,
    ``prompt_config``) and ``device`` are re-read on every call, so they may
    be changed after construction. The model id/path/config attributes are
    only used while the models are loaded in ``__init__``.

    Attributes:
        seed (int): Random seed for reproducibility.
        strength (float): Strength of the layer generation.
        threshold (int): Threshold for object detection.
        ratio (float): Ratio for scaling objects.
        grounding_model (str): Model id or path of the grounding model used
            for object detection.
        zim_model_config (str): Configuration name for the ZIM model.
        zim_checkpoint (str): Path to the ZIM model checkpoint.
        inpaint_model (str): Model id or path of the inpainting model.
        inpaint_fg_lora (str): Source of the LoRA weights for foreground
            inpainting.
        inpaint_sky_lora (str): Source of the LoRA weights for sky inpainting.
        scale (int): Scale factor for super-resolution.
        device (str): ``"cuda"`` if CUDA is available, otherwise ``"cpu"``;
            used as the autocast device type in ``__call__``.
        dilation_size (int): Size of the dilation for mask processing.
        cfg_scale (float): Guidance scale forwarded to the pipelines
            (presumably classifier-free guidance -- confirm in layer_utils).
        prompt_config (dict): Per-scene-type ("indoor"/"outdoor") positive and
            negative prompts.
        sr_model, zim_predictor, gd_processor, gd_model, inpaint_fg_model,
        inpaint_sky_model: Models built in ``__init__`` and passed to the
            ``layer_utils`` pipelines.
    """

    def __init__(self):
        r"""Set the default parameters and load all models."""
        self.seed = 25
        self.strength = 1.0
        self.threshold = 20000
        self.ratio = 1.5
        self.grounding_model = "IDEA-Research/grounding-dino-tiny"
        self.zim_model_config = "vit_l"
        self.zim_checkpoint = "./ZIM/zim_vit_l_2092"  # Add zim anything ckpt here
        self.inpaint_model = "black-forest-labs/FLUX.1-Fill-dev"
        self.inpaint_fg_lora = "tencent/HunyuanWorld-1"
        self.inpaint_sky_lora = "tencent/HunyuanWorld-1"
        self.scale = 2
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dilation_size = 80
        self.cfg_scale = 5.0
        # NOTE: adjacent string literals below are concatenated without a
        # separator (e.g. "sink," + "ceramic" -> "sink,ceramic"); kept as-is so
        # the prompts sent to the model stay identical.
        self.prompt_config = {
            "indoor": {
                "positive_prompt": "",
                "negative_prompt": (
                    "object, table, chair, seat, shelf, sofa, bed, bath, sink,"
                    "ceramic, wood, plant, tree, light, lamp, candle, television, electronics,"
                    "oven, fire, low-resolution, blur, mosaic, people")
            },
            "outdoor": {
                "positive_prompt": "",
                "negative_prompt": (
                    "object, chair, tree, plant, flower, grass, stone, rock,"
                    "building, hill, house, tower, light, lamp, low-resolution, blur, mosaic, people")
            }
        }

        # Load models.
        # NOTE(review): every model below is pinned to GPU 0 / 'cuda:0'
        # regardless of self.device, so construction presumably fails on
        # CPU-only hosts -- confirm against the build_* helpers.
        print("============= now loading models ===============")
        # Super-resolution model.
        self.sr_model = sr_utils.build_sr_model(scale=self.scale, gpu_id=0)
        print("============= load Super-Resolution models done ")
        # Segmentation models: ZIM predictor plus grounding detector.
        self.zim_predictor = seg_utils.build_zim_model(
            self.zim_model_config, self.zim_checkpoint, device='cuda:0')
        self.gd_processor, self.gd_model = seg_utils.build_gd_model(
            self.grounding_model, device='cuda:0')
        print("============= load Segmentation models done ====")
        # Panorama inpainting models: same base model, different LoRA subfolder
        # for foreground-removal vs. sky inpainting.
        self.inpaint_fg_model = inpaint_utils.build_inpaint_model(
            self.inpaint_model,
            self.inpaint_fg_lora,
            subfolder="HunyuanWorld-PanoInpaint-Scene",
            device=0
        )
        self.inpaint_sky_model = inpaint_utils.build_inpaint_model(
            self.inpaint_model,
            self.inpaint_sky_lora,
            subfolder="HunyuanWorld-PanoInpaint-Sky",
            device=0
        )
        print("============= load panorama inpaint models done =")

    @staticmethod
    def _load_img_infos(source):
        """Return the image-info list from a JSON file path or a list.

        Raises:
            FileNotFoundError: If ``source`` is a path that does not exist.
            ValueError: If ``source`` is a path not ending in ``.json``.
            KeyError: If the JSON file has no top-level ``"output"`` key.
            TypeError: If ``source`` is neither a string nor a list.
        """
        if isinstance(source, str):
            if not os.path.exists(source):
                raise FileNotFoundError(f"Input file {source} does not exist.")
            if not source.endswith('.json'):
                raise ValueError("Input file must be a JSON file.")
            # JSON is UTF-8 by spec; don't rely on the platform locale.
            with open(source, "r", encoding="utf-8") as f:
                return json.load(f)["output"]
        if isinstance(source, list):
            return source
        raise TypeError("Input must be a JSON file path or a list.")

    def __call__(self, input, layer):
        r"""Run the layer-generation pipeline selected by ``layer``.

        Args:
            input (str or list): Path to a ``.json`` file whose top-level
                ``"output"`` entry holds the image-info list, or that list
                itself. (The schema of each entry is defined by the
                ``layer_utils`` pipelines.)
            layer (int): Layer index to process: 0 for foreground1,
                1 for foreground2, 2 for sky.

        Raises:
            ValueError: If ``layer`` is not 0, 1 or 2, or if ``input`` is a
                path that does not end in ``.json``.
            FileNotFoundError: If ``input`` is a path that does not exist.
            KeyError: If the JSON file has no top-level ``"output"`` key.
            TypeError: If ``input`` is neither a string nor a list.
        """
        # Fail fast on a bad index instead of silently running the sky pipeline.
        if layer not in (0, 1, 2):
            raise ValueError(
                f"Invalid layer index {layer!r}; expected 0 (foreground1), "
                "1 (foreground2) or 2 (sky).")

        img_infos = self._load_img_infos(input)

        # Processing parameters, re-read from the instance on every call.
        params = {
            'scale': self.scale,
            'seed': self.seed,
            'threshold': self.threshold,
            'ratio': self.ratio,
            'strength': self.strength,
            'dilation_size': self.dilation_size,
            'cfg_scale': self.cfg_scale,
            'prompt_config': self.prompt_config
        }

        # Both foreground layers share the scene-inpainting model; the sky
        # layer uses the sky-inpainting model.
        if layer == 0:
            pipeline = layer_utils.remove_fg1_pipeline
            inpaint_model = self.inpaint_fg_model
        elif layer == 1:
            pipeline = layer_utils.remove_fg2_pipeline
            inpaint_model = self.inpaint_fg_model
        else:
            pipeline = layer_utils.sky_pipeline
            inpaint_model = self.inpaint_sky_model

        # Scope bf16 autocast to the pipeline run with ``with`` so it is
        # switched off again when the pipeline returns or raises, instead of
        # leaking into the caller and stacking a new context on every call.
        with torch.autocast(device_type=self.device, dtype=torch.bfloat16):
            pipeline(
                img_infos=img_infos,
                sr_model=self.sr_model,
                zim_predictor=self.zim_predictor,
                gd_processor=self.gd_processor,
                gd_model=self.gd_model,
                inpaint_model=inpaint_model,
                params=params
            )