File size: 6,773 Bytes
57276d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import json
import torch
from ..utils import sr_utils, seg_utils, inpaint_utils, layer_utils


class LayerDecomposition():
    r"""Generates scene layers (foreground objects, background, sky) from input
    panorama images and masks, using super-resolution, segmentation
    (Grounding-DINO + ZIM), and panorama inpainting models.

    All configuration lives in instance attributes set in ``__init__``; the
    heavy model loading also happens there, so constructing this class is
    expensive and requires the checkpoints below to be available on disk.

    Attributes:
        seed (int): Random seed for reproducibility.
        strength (float): Strength of the layer generation.
        threshold (int): Threshold for object detection.
        ratio (float): Ratio for scaling objects.
        grounding_model (str): HF model id for Grounding-DINO object detection.
        zim_model_config (str): ZIM backbone configuration name.
        zim_checkpoint (str): Path to the ZIM model checkpoint.
        inpaint_model (str): HF model id for the FLUX fill/inpainting model.
        inpaint_fg_lora (str): LoRA weights repo for foreground inpainting.
        inpaint_sky_lora (str): LoRA weights repo for sky inpainting.
        scale (int): Scale factor for super-resolution.
        device (str): "cuda" when available, else "cpu".
        dilation_size (int): Dilation size for mask processing.
        cfg_scale (float): Classifier-free guidance scale.
        prompt_config (dict): Positive/negative prompts per scene class
            ("indoor" / "outdoor").
    """

    def __init__(self):
        r"""Set hyperparameters and eagerly load all models.

        NOTE(review): the segmentation/inpaint builders below are pinned to
        'cuda:0' / device 0 even though ``self.device`` falls back to "cpu"
        when CUDA is unavailable — on a CPU-only host loading will likely
        fail. Left as-is because the builder signatures are project-local;
        confirm whether they accept a "cpu" device before changing.
        """
        # Generation / detection hyperparameters.
        self.seed = 25
        self.strength = 1.0
        self.threshold = 20000
        self.ratio = 1.5
        # Model identifiers and checkpoint paths.
        self.grounding_model = "IDEA-Research/grounding-dino-tiny"
        self.zim_model_config = "vit_l"
        self.zim_checkpoint = "./ZIM/zim_vit_l_2092"  # Add zim anything ckpt here
        self.inpaint_model = "black-forest-labs/FLUX.1-Fill-dev"
        self.inpaint_fg_lora = "tencent/HunyuanWorld-1"
        self.inpaint_sky_lora = "tencent/HunyuanWorld-1"
        self.scale = 2
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dilation_size = 80
        self.cfg_scale = 5.0
        # Negative prompts suppress object classes that should not be
        # hallucinated into the inpainted background of each scene type.
        self.prompt_config = {
            "indoor": {
                "positive_prompt": "",
                "negative_prompt": (
                    "object, table, chair, seat, shelf, sofa, bed, bath, sink,"
                    "ceramic, wood, plant, tree, light, lamp, candle, television, electronics,"
                    "oven, fire, low-resolution, blur, mosaic, people")
            },
            "outdoor": {
                "positive_prompt": "",
                "negative_prompt": (
                    "object, chair, tree, plant, flower, grass, stone, rock," 
                    "building, hill, house, tower, light, lamp, low-resolution, blur, mosaic, people")
            }
        }

        # Load models
        print("=============  now loading models ===============")
        # super-resolution model
        self.sr_model = sr_utils.build_sr_model(scale=self.scale, gpu_id=0)
        print("=============  load Super-Resolution models done ")
        # segmentation model
        self.zim_predictor = seg_utils.build_zim_model(
            self.zim_model_config, self.zim_checkpoint, device='cuda:0')
        self.gd_processor, self.gd_model = seg_utils.build_gd_model(
            self.grounding_model, device='cuda:0')
        print("=============  load Segmentation models done ====")
        # panorama inpaint model: same base model, two LoRA variants
        # (scene-object inpainting vs. sky inpainting).
        self.inpaint_fg_model = inpaint_utils.build_inpaint_model(
            self.inpaint_model,
            self.inpaint_fg_lora,
            subfolder="HunyuanWorld-PanoInpaint-Scene",
            device=0
        )
        self.inpaint_sky_model = inpaint_utils.build_inpaint_model(
            self.inpaint_model,
            self.inpaint_sky_lora,
            subfolder="HunyuanWorld-PanoInpaint-Sky",
            device=0
        )
        print("=============  load panorama inpaint models done =")

    def _load_img_infos(self, input):
        r"""Resolve *input* into a list of image-info records.

        Args:
            input (str or list): Path to a JSON file (whose "output" key
                holds the records) or an already-built list of records.

        Returns:
            list: Image-info records to process.

        Raises:
            FileNotFoundError: If the input path does not exist.
            ValueError: If the input path is not a JSON file.
            TypeError: If the input is neither a string nor a list.
        """
        if isinstance(input, str):
            if not os.path.exists(input):
                raise FileNotFoundError(f"Input file {input} does not exist.")
            if not input.endswith('.json'):
                raise ValueError("Input file must be a JSON file.")
            with open(input, "r") as f:
                return json.load(f)["output"]
        if isinstance(input, list):
            return input
        raise TypeError("Input must be a JSON file path or a list.")

    def __call__(self, input, layer):
        r"""Generate layers based on the input images and masks.

        Args:
            input (str or list): Path to the input JSON file or a list of
                image information records.
            layer (int): Layer index to process: 0 for foreground1,
                1 for foreground2, any other value is treated as sky.

        Raises:
            FileNotFoundError: If the input file does not exist.
            ValueError: If the input file is not a JSON file.
            TypeError: If the input is neither a string nor a list.
        """
        img_infos = self._load_img_infos(input)

        # Processing parameters shared by all pipelines.
        params = {
            'scale': self.scale,
            'seed': self.seed,
            'threshold': self.threshold,
            'ratio': self.ratio,
            'strength': self.strength,
            'dilation_size': self.dilation_size,
            'cfg_scale': self.cfg_scale,
            'prompt_config': self.prompt_config
        }

        # Select the layer-specific pipeline and its inpaint model; the two
        # foreground layers share the scene inpaint LoRA, sky uses its own.
        if layer == 0:
            pipeline, inpaint_model = (
                layer_utils.remove_fg1_pipeline, self.inpaint_fg_model)
        elif layer == 1:
            pipeline, inpaint_model = (
                layer_utils.remove_fg2_pipeline, self.inpaint_fg_model)
        else:
            pipeline, inpaint_model = (
                layer_utils.sky_pipeline, self.inpaint_sky_model)

        # Run under bfloat16 autocast. The original code called
        # autocast(...).__enter__() without ever exiting, leaking autocast
        # state into the caller; a `with` block scopes it correctly.
        with torch.autocast(device_type=self.device, dtype=torch.bfloat16):
            pipeline(
                img_infos=img_infos,
                sr_model=self.sr_model,
                zim_predictor=self.zim_predictor,
                gd_processor=self.gd_processor,
                gd_model=self.gd_model,
                inpaint_model=inpaint_model,
                params=params
            )