import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from tqdm import tqdm
import gc

from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline
from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3


# Registry of selectable LoRAs: where each one is downloaded from, how it is
# applied ("standard" diffusers loading vs. "manual_fuse"), and how the user
# prompt is templated for it.
LORA_CONFIG = {
    "None": {
        "repo_id": None,
        "filename": None,
        "type": "edit",
        "method": "none",
        "prompt_template": "{prompt}",
        "description": "Use the base Qwen-Image-Edit model without any LoRA.",
    },
    "InStyle (Style Transfer)": {
        "repo_id": "peteromallet/Qwen-Image-Edit-InStyle",
        "filename": "InStyle-0.5.safetensors",
        "type": "style",
        "method": "manual_fuse",
        "prompt_template": "Make an image in this style of {prompt}",
        "description": "Transfers the style from a reference image to a new image described by the prompt.",
    },
    "InScene (In-Scene Editing)": {
        "repo_id": "flymy-ai/qwen-image-edit-inscene-lora",
        "filename": "flymy_qwen_image_edit_inscene_lora.safetensors",
        "type": "edit",
        "method": "standard",
        "prompt_template": "{prompt}",
        "description": "Improves in-scene editing, object positioning, and camera perspective changes.",
    },
    "Face Segmentation": {
        "repo_id": "TsienDragon/qwen-image-edit-lora-face-segmentation",
        "filename": "pytorch_lora_weights.safetensors",
        "type": "edit",
        "method": "standard",
        "prompt_template": "change the face to face segmentation mask",
        "description": "Transforms a facial image into a precise segmentation mask.",
    },
    "Object Remover": {
        "repo_id": "valiantcat/Qwen-Image-Edit-Remover-General-LoRA",
        "filename": "qwen-edit-remover.safetensors",
        "type": "edit",
        "method": "standard",
        "prompt_template": "Remove {prompt}",
        "description": "Removes objects from an image while maintaining background consistency.",
    },
}

print("Initializing model...")
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit", 
    torch_dtype=dtype
).to(device)

# Swap in the project's transformer subclass and its FlashAttention-3 (FA3)
# double-stream attention processor.
pipe.transformer.__class__ = QwenImageTransformer2DModel
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())

# state_dict() returns references to the live parameters, so clone them to CPU;
# otherwise the in-place LoRA fusion below would also overwrite this backup and
# the reset in load_and_fuse_lora() would have no effect.
original_transformer_state_dict = {
    k: v.detach().clone().cpu() for k, v in pipe.transformer.state_dict().items()
}
print("Base model loaded and ready.")

def fuse_lora_manual(transformer, lora_state_dict, alpha=1.0):
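    """Merge LoRA weights into the transformer's Linear layers in place.

    Keys are expected as ``diffusion_model.<module>.lora_A/lora_B.weight``;
    each matching layer gets ``alpha * (lora_B @ lora_A)`` added to its weight.
    """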
    key_mapping = {}
    for key in lora_state_dict.keys():
        base_key = key.replace('diffusion_model.', '').rsplit('.lora_', 1)[0]
        if base_key not in key_mapping:
            key_mapping[base_key] = {}
        if 'lora_A' in key:
            key_mapping[base_key]['down'] = lora_state_dict[key]
        elif 'lora_B' in key:
            key_mapping[base_key]['up'] = lora_state_dict[key]

    for name, module in tqdm(transformer.named_modules(), desc="Fusing layers"):
        if name in key_mapping and isinstance(module, torch.nn.Linear):
            lora_weights = key_mapping[name]
            if 'down' in lora_weights and 'up' in lora_weights:
                device = module.weight.device
                dtype = module.weight.dtype
                lora_down = lora_weights['down'].to(device, dtype=dtype)
                lora_up = lora_weights['up'].to(device, dtype=dtype)
                merged_delta = lora_up @ lora_down
                module.weight.data += alpha * merged_delta
    return transformer

def load_and_fuse_lora(lora_name):
    """Carrega uma LoRA, funde-a ao modelo e retorna o pipeline modificado."""
    config = LORA_CONFIG[lora_name]
    
    print("Resetting transformer to original state...")
    pipe.transformer.load_state_dict(original_transformer_state_dict)
    
    if config["method"] == "none":
        print("No LoRA selected. Using base model.")
        return

    print(f"Loading LoRA: {lora_name}")
    lora_path = hf_hub_download(repo_id=config["repo_id"], filename=config["filename"])

    if config["method"] == "standard":
        print("Using standard loading method...")
        pipe.load_lora_weights(lora_path)
        print("Fusing LoRA into the model...")
        pipe.fuse_lora()
    elif config["method"] == "manual_fuse":
        print("Using manual fusion method...")
        lora_state_dict = load_file(lora_path)
        pipe.transformer = fuse_lora_manual(pipe.transformer, lora_state_dict)
    
    gc.collect()
    torch.cuda.empty_cache()
    print(f"LoRA '{lora_name}' is now active.")

@spaces.GPU(duration=60)
def infer(
    lora_name,
    input_image,
    style_image,
    prompt,
    seed,
    randomize_seed,
    true_guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
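    """Run one edit with the selected LoRA and return the result image plus the seed used."""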
    if not lora_name:
        raise gr.Error("Please select a LoRA model.")
    
    config = LORA_CONFIG[lora_name]
    
    if config["type"] == "style":
        if style_image is None:
            raise gr.Error("Style Transfer LoRA requires a Style Reference Image.")
        image_for_pipeline = style_image
    else: # 'edit'
        if input_image is None:
            raise gr.Error("This LoRA requires an Input Image.")
        image_for_pipeline = input_image

    # The face-segmentation LoRA uses a fixed prompt, so a user prompt is optional there.
    if not prompt and config["prompt_template"] != "change the face to face segmentation mask":
        raise gr.Error("A text prompt is required for this LoRA.")
    
    load_and_fuse_lora(lora_name)
    
    final_prompt = config["prompt_template"].format(prompt=prompt)
    
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    generator = torch.Generator(device=device).manual_seed(int(seed))
    
    print("--- Running Inference ---")
    print(f"LoRA: {lora_name}")
    print(f"Prompt: {final_prompt}")
    print(f"Seed: {seed}, Steps: {num_inference_steps}, CFG: {true_guidance_scale}")
    
    with torch.inference_mode():
        result_image = pipe(
            image=image_for_pipeline,
            prompt=final_prompt,
            negative_prompt=" ",
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            true_cfg_scale=true_guidance_scale,
        ).images[0]
        
    # Only a LoRA loaded via load_lora_weights() can be unfused here; a manually
    # fused LoRA is undone by the state-dict reset on the next load_and_fuse_lora() call.
    if config["method"] == "standard":
        pipe.unfuse_lora()
    gc.collect()
    torch.cuda.empty_cache()
    
    return result_image, seed

def on_lora_change(lora_name):
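    """Show or hide the image inputs and prompt box to match the selected LoRA."""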
    config = LORA_CONFIG[lora_name]
    is_style_lora = config["type"] == "style"
    return {
        lora_description: gr.Markdown(visible=True, value=f"**Description:** {config['description']}"),
        input_image_box: gr.Image(visible=not is_style_lora),
        style_image_box: gr.Image(visible=is_style_lora),
        prompt_box: gr.Textbox(visible=(config["prompt_template"] != "change the face to face segmentation mask"))
    }

with gr.Blocks(css="#col-container { margin: 0 auto; max-width: 1024px; }") as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML('<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_edit_logo.png" alt="Qwen-Image Logo" style="width: 400px; margin: 0 auto; display: block;">')
        gr.Markdown("<h2 style='text-align: center;'>Qwen-Image-Edit Multi-LoRA Playground</h2>")

        with gr.Row():
            with gr.Column(scale=1):
                lora_selector = gr.Dropdown(
                    label="Select LoRA Model",
                    choices=list(LORA_CONFIG.keys()),
                    value="InStyle (Style Transfer)"
                )
                lora_description = gr.Markdown(visible=False)
                
                input_image_box = gr.Image(label="Input Image", type="pil", visible=False)
                style_image_box = gr.Image(label="Style Reference Image", type="pil", visible=True)
                
                prompt_box = gr.Textbox(label="Prompt", placeholder="Describe the content or object to remove...")
                
                run_button = gr.Button("Generate!", variant="primary")

            with gr.Column(scale=1):
                result_image = gr.Image(label="Result", type="pil")
                used_seed = gr.Number(label="Used Seed", interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed_slider = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, value=42)
            randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            cfg_slider = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, step=0.1, value=4.0)
            steps_slider = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=25)

        lora_selector.change(
            fn=on_lora_change,
            inputs=lora_selector,
            outputs=[lora_description, input_image_box, style_image_box, prompt_box]
        )

        demo.load(
            fn=on_lora_change,
            inputs=lora_selector,
            outputs=[lora_description, input_image_box, style_image_box, prompt_box]
        )

        run_button.click(
            fn=infer,
            inputs=[
                lora_selector,
                input_image_box, style_image_box,
                prompt_box,
                seed_slider, randomize_seed_checkbox,
                cfg_slider, steps_slider
            ],
            outputs=[result_image, used_seed]
        )

if __name__ == "__main__":
    demo.launch()