File size: 16,617 Bytes
66dfd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image
import math

# --- Features ported from nag_app.py ---
# Translation library imports
from deep_translator import GoogleTranslator
from langdetect import detect

# Import the NAG-enabled pipeline.
# NOTE: to run this code, the `src` directory from the nag_app.py Hugging Face Space
# (containing pipeline_flux_kontext_nag.py and transformer_flux.py)
# must be placed in the same directory as this file.
from src.pipeline_flux_kontext_nag import NAGFluxKontextPipeline
from src.transformer_flux import NAGFluxTransformer2DModel
# --- End of the section ported from nag_app.py ---

# Import diffusers' internal adapter-scale mapping (needed to work around an error)
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING


# Constants
# Upper bound for the seed: the largest signed 32-bit integer. Used both as the
# seed slider maximum and as the upper limit when a random seed is drawn.
MAX_SEED = np.iinfo(np.int32).max
# Initial text of the "Negative Prompt for NAG" field.
DEFAULT_NAG_NEGATIVE_PROMPT = "Low resolution, blurry, lack of details, big head"
# Edge length in pixels: multi-image runs produce a square of this size, and
# single-image runs target this value squared as their total pixel count.
OUTPUT_RESOLUTION = 1024

# --- Model loading ported from nag_app.py ---
# Load the NAG-enabled Kontext transformer in bfloat16, build the pipeline around
# it, and move the whole pipeline to the CUDA device.
transformer = NAGFluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = NAGFluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")
# --- End of the ported section ---

# --- Load the five LoRA adapters ---
print("Loading LoRA weights...")
# Maps the LoRA label shown in the UI to the adapter name used on the pipeline.
LORA_MAPPING = {
    "Hyper-SD": "hyper",
    "Relighting": "relight",
    "LoRA 3": "lora_3",
    "LoRA 4": "lora_4",
    "LoRA 5": "lora_5",
}

# Adapters 1 and 2 are loaded without a guard, so a failure here propagates
# and aborts start-up.
for _repo_id, _weight_file, _label in (
    ("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors", "Hyper-SD"),
    ("linoyts/relighting-kontext-dev-lora", "relighting-kontext-dev-lora.safetensors", "Relighting"),
):
    pipe.load_lora_weights(
        _repo_id,
        weight_name=_weight_file,
        adapter_name=LORA_MAPPING[_label],
    )

# Adapters 3-5 are slots to be configured later.
# NOTE: the repository and file names below are placeholders; replace them with
# real ones (e.g. "cagliostrolab/animagine-xl-3.0" / "animagine-xl-3.0.safetensors").
# A failed load is reported and start-up continues.
for _repo_id, _weight_file, _label in (
    ("author/repo_name_3", "lora_file_3.safetensors", "LoRA 3"),
    ("author/repo_name_4", "lora_file_4.safetensors", "LoRA 4"),
    ("author/repo_name_5", "lora_file_5.safetensors", "LoRA 5"),
):
    try:
        pipe.load_lora_weights(
            _repo_id,
            weight_name=_weight_file,
            adapter_name=LORA_MAPPING[_label],
        )
    except Exception as e:
        print(f"Warning: Could not load {_label}. Please check repository and file names. Error:", e)

print("LoRA weights loading process finished.")
# --- End of LoRA loading ---

# Register the custom transformer class in diffusers' adapter-scale table by
# reusing the entry of the stock FluxTransformer2DModel.
_flux_scale_fn = _SET_ADAPTER_SCALE_FN_MAPPING["FluxTransformer2DModel"]
_SET_ADAPTER_SCALE_FN_MAPPING[NAGFluxTransformer2DModel.__name__] = _flux_scale_fn
print("Custom model 'NAGFluxTransformer2DModel' registered for LoRA.")


def round_to_multiple(number, multiple=8):
    """Return the multiple of ``multiple`` nearest to ``number``.

    The quotient is rounded with the built-in ``round``, so exact halves go
    to the even neighbour (e.g. 12 -> 16 and 20 -> 16 for ``multiple=8``).
    """
    step_count = round(number / multiple)
    return step_count * multiple

def concatenate_images(images, direction="horizontal"):
    """Join several images into one RGB image on a white canvas.

    Args:
        images: Sequence of image objects; ``None`` entries are ignored.
        direction: ``"horizontal"`` places the images left to right; any
            other value stacks them top to bottom.

    Returns:
        ``None`` when there is nothing to join, the single image converted
        to RGB when only one is present, otherwise a new image in which each
        input is centred along the axis perpendicular to ``direction``.
    """
    if not images:
        return None

    present = list(filter(lambda picture: picture is not None, images))
    if not present:
        return None

    rgb_pictures = [picture.convert("RGB") for picture in present]
    if len(rgb_pictures) == 1:
        return rgb_pictures[0]

    widths = [picture.width for picture in rgb_pictures]
    heights = [picture.height for picture in rgb_pictures]
    side_by_side = direction == "horizontal"

    # The canvas is as long as all images together along the joining axis and
    # as large as the biggest image along the other one.
    if side_by_side:
        canvas_width, canvas_height = sum(widths), max(heights)
    else:
        canvas_width, canvas_height = max(widths), sum(heights)
    canvas = Image.new('RGB', (canvas_width, canvas_height), (255, 255, 255))

    cursor = 0  # running offset along the joining axis
    for picture in rgb_pictures:
        if side_by_side:
            centred_top = (canvas_height - picture.height) // 2
            canvas.paste(picture, (cursor, centred_top))
            cursor += picture.width
        else:
            centred_left = (canvas_width - picture.width) // 2
            canvas.paste(picture, (centred_left, cursor))
            cursor += picture.height
    return canvas

@spaces.GPU(duration=25)
def infer(input_images, prompt, negative_prompt, seed, randomize_seed, guidance_scale, nag_negative_prompt, nag_scale, num_inference_steps,
          # Enable flag and strength for each of the five LoRA adapters
          enable_lora1, weight_lora1,
          enable_lora2, weight_lora2,
          enable_lora3, weight_lora3,
          enable_lora4, weight_lora4,
          enable_lora5, weight_lora5,
          progress=gr.Progress(track_tqdm=True)):
    """Run one Kontext generation with the selected LoRA adapters.

    Args:
        input_images: Gallery value; a list whose entries are indexed with
            ``[0]`` to obtain the image. ``None`` entries are ignored.
        prompt: User prompt. Non-English text is translated to English on a
            best-effort basis; on failure the original text is used.
        negative_prompt: Standard negative prompt; blank becomes ``None``.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` when true.
        guidance_scale: Passed through to the pipeline.
        nag_negative_prompt: Passed through to the pipeline.
        nag_scale: Passed through to the pipeline.
        num_inference_steps: Passed through to the pipeline.
        enable_lora1..enable_lora5: Whether the matching adapter is applied.
        weight_lora1..weight_lora5: Strength of the matching adapter.
        progress: Gradio progress tracker (not referenced in the body).

    Returns:
        ``(image, seed, update)``: the generated image, the seed actually
        used, and a Gradio update that makes the reuse button visible.

    Raises:
        gr.Error: If no usable image was supplied or the images could not be
            combined.
    """
    # Validate the input first so that a missing image fails fast, before the
    # network translation call and before any adapter state is changed.
    if input_images is None:
        raise gr.Error("Please upload at least one image.")

    if not isinstance(input_images, list):
        input_images = [input_images]

    valid_images = [img[0] for img in input_images if img is not None]

    if not valid_images:
        raise gr.Error("Please upload at least one valid image.")

    lora_params = [
        (enable_lora1, weight_lora1, "Hyper-SD"),
        (enable_lora2, weight_lora2, "Relighting"),
        (enable_lora3, weight_lora3, "LoRA 3"),
        (enable_lora4, weight_lora4, "LoRA 4"),
        (enable_lora5, weight_lora5, "LoRA 5"),
    ]

    # Adapters that are really present on the pipeline. Placeholder LoRAs whose
    # load failed at start-up are absent, and asking set_adapters for them
    # would raise.
    loaded_adapters = {
        adapter
        for adapters in pipe.get_list_adapters().values()
        for adapter in adapters
    }

    active_adapters = []
    active_weights = []
    for is_enabled, weight, name in lora_params:
        if not is_enabled:
            continue
        adapter_name = LORA_MAPPING[name]
        if adapter_name not in loaded_adapters:
            message = f"{name} LoRA is not loaded and will be skipped."
            print(f"Warning: {message}")
            gr.Warning(message)
            continue
        active_adapters.append(adapter_name)
        active_weights.append(weight)
        print(f"Applying {name} LoRA with weight {weight}")

    if active_adapters:
        pipe.set_adapters(active_adapters, adapter_weights=active_weights)
        # Every run ends with disable_lora(); selecting adapters does not undo
        # that, so the LoRA layers must be switched back on explicitly or all
        # runs after the first would silently ignore the adapters.
        pipe.enable_lora()
    else:
        print("No LoRA selected. Running without LoRA.")
        pipe.disable_lora()

    prompt = (prompt or "").strip()
    if prompt:
        print(f"Original prompt: {prompt}")
        try:
            detected_lang = detect(prompt)
            if detected_lang != 'en':
                print(f"Detected language: {detected_lang}. Translating to English...")
                translated_prompt = GoogleTranslator(source=detected_lang, target='en').translate(prompt)
                prompt = translated_prompt
                print(f"Translated prompt: {prompt}")
            else:
                print("Prompt is already in English.")
        except Exception as e:
            # Best effort: translation is a convenience, never a hard failure.
            print(f"Warning: Translation or language detection failed: {e}. Using original prompt.")

    # A blank or whitespace-only negative prompt is passed on as None.
    negative_prompt = negative_prompt.strip() if negative_prompt and negative_prompt.strip() else None

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # The seed may arrive from the UI as a float; manual_seed needs an int.
    seed = int(seed)

    if len(valid_images) == 1:
        print("Single image detected. Calculating aspect-ratio aware dimensions.")
        input_for_pipe = valid_images[0]

        # Keep the input aspect ratio while targeting roughly
        # OUTPUT_RESOLUTION ** 2 pixels, then snap both sides to multiples of 8.
        input_width, input_height = input_for_pipe.size
        aspect_ratio = input_width / input_height
        target_pixels = OUTPUT_RESOLUTION * OUTPUT_RESOLUTION

        final_height = int(math.sqrt(target_pixels / aspect_ratio))
        final_width = int(aspect_ratio * final_height)

        final_width = round_to_multiple(final_width, 8)
        final_height = round_to_multiple(final_height, 8)

        print(f"Output dimensions set to: {final_width}x{final_height}")

    else:
        print(f"Multiple ({len(valid_images)}) images detected. Using fixed 1024x1024 output.")
        input_for_pipe = concatenate_images(valid_images, "horizontal")
        if input_for_pipe is None:
            raise gr.Error("Failed to process the input images.")

        final_width = OUTPUT_RESOLUTION
        final_height = OUTPUT_RESOLUTION

    final_prompt = f"From the provided reference images, create a unified, cohesive image such that {prompt}. Maintain the identity and characteristics of each subject while adjusting their proportions, scale, and positioning to create a harmonious, naturally balanced composition. Blend and integrate all elements seamlessly with consistent lighting, perspective, and style.the final result should look like a single naturally captured scene where all subjects are properly sized and positioned relative to each other, not assembled from multiple sources."

    try:
        image = pipe(
            image=input_for_pipe,
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            nag_negative_prompt=nag_negative_prompt,
            nag_scale=nag_scale,
            width=final_width,
            height=final_height,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
    finally:
        # Leave the shared pipeline without active LoRA layers even when the
        # generation fails.
        pipe.disable_lora()

    return image, seed, gr.update(visible=True)

# Custom stylesheet handed to gr.Blocks: centres the "col-container" column
# with a 960px maximum width, and vertically centres the rows tagged with the
# "lora-row" class while adding a small bottom margin.
css="""
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
.lora-row {
    align-items: center;
    margin-bottom: 8px;
}
"""

# Build the Gradio UI: inputs and settings in the left column, the result in
# the right column, followed by the event wiring.
with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 Kontext [dev] - Multi-Image with NAG
        Compose a new image from multiple images using FLUX.1 Kontext, enhanced with Normalized Attention Guidance (NAG) and automatic prompt translation.
        - **Single Image Input**: Output will match the input aspect ratio.
        - **Multiple Image Inputs**: Output will be a fixed 1024x1024 resolution.
        """)
        with gr.Row():
            with gr.Column():
                input_images = gr.Gallery(
                    label="Upload image(s) for editing", 
                    show_label=True,
                    elem_id="gallery_input",
                    columns=3,
                    rows=2,
                    object_fit="contain",
                    height="auto",
                    file_types=['image'],
                    type='pil'
                )
                
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt (auto-translates to English)",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)
                    
                with gr.Accordion("Advanced Settings", open=False):
                    # --- LoRA controls: every checkbox/slider is bound to its own variable ---
                    gr.Markdown("### LoRA Settings")
                    
                    # Each slider's initial visibility matches its checkbox's initial value.
                    with gr.Row(elem_classes="lora-row"):
                        enable_lora1 = gr.Checkbox(label="Hyper-SD", value=True, scale=1)
                        weight_lora1 = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.02, value=0.12, scale=3, visible=True)
                    
                    with gr.Row(elem_classes="lora-row"):
                        enable_lora2 = gr.Checkbox(label="Relighting", value=False, scale=1)
                        weight_lora2 = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=1.0, scale=3, visible=False)

                    with gr.Row(elem_classes="lora-row"):
                        enable_lora3 = gr.Checkbox(label="LoRA 3", value=False, scale=1)
                        weight_lora3 = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.8, scale=3, visible=False)

                    with gr.Row(elem_classes="lora-row"):
                        enable_lora4 = gr.Checkbox(label="LoRA 4", value=False, scale=1)
                        weight_lora4 = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.8, scale=3, visible=False)

                    with gr.Row(elem_classes="lora-row"):
                        enable_lora5 = gr.Checkbox(label="LoRA 5", value=False, scale=1)
                        weight_lora5 = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.8, scale=3, visible=False)
                    # --- End of LoRA controls ---

                    gr.Markdown("### Generation Settings")
                    
                    # Standard negative prompt field
                    negative_prompt = gr.Text(
                        label="Negative Prompt (Standard)",
                        placeholder="Enter concepts to avoid (e.g., ugly, deformed)",
                        max_lines=2,
                    )
                    
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=8,
                        maximum=50,
                        step=1,
                        value=8,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.25,
                        value=4.5,
                    )
                    nag_negative_prompt = gr.Text(
                        label="Negative Prompt for NAG",
                        value=DEFAULT_NAG_NEGATIVE_PROMPT,
                        max_lines=2,
                        placeholder="Enter concepts to avoid with NAG",
                    )
                    nag_scale = gr.Slider(
                        label="NAG Scale",
                        minimum=0.0,
                        maximum=20.0,
                        step=0.25,
                        value=3.5
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    
            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False, format="png")
                # Hidden at first; infer returns an update that makes it visible.
                reuse_button = gr.Button("Reuse this image", visible=False)
    
    # Inputs handed to infer, listed in the order of its positional parameters
    # (negative_prompt included).
    all_inputs = [
        input_images, prompt, negative_prompt, seed, randomize_seed, guidance_scale, 
        nag_negative_prompt, nag_scale, num_inference_steps,
        enable_lora1, weight_lora1,
        enable_lora2, weight_lora2,
        enable_lora3, weight_lora3,
        enable_lora4, weight_lora4,
        enable_lora5, weight_lora5,
    ]

    # Run inference on a button click or when the prompt is submitted.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs = all_inputs,
        outputs = [result, seed, reuse_button]
    )
    # --- End of inference wiring ---
    
    # Feed the current result back into the input gallery as a one-item list
    # (or an empty list when there is no result).
    reuse_button.click(
        fn = lambda image: [image] if image is not None else [],
        inputs = [result],
        outputs = [input_images]
    )

    # --- Tie each weight slider's visibility to its own checkbox ---
    def update_visibility(is_checked):
        """Return a Gradio update whose ``visible`` flag equals ``is_checked``."""
        return gr.update(visible=is_checked)

    enable_lora1.change(fn=update_visibility, inputs=enable_lora1, outputs=weight_lora1)
    enable_lora2.change(fn=update_visibility, inputs=enable_lora2, outputs=weight_lora2)
    enable_lora3.change(fn=update_visibility, inputs=enable_lora3, outputs=weight_lora3)
    enable_lora4.change(fn=update_visibility, inputs=enable_lora4, outputs=weight_lora4)
    enable_lora5.change(fn=update_visibility, inputs=enable_lora5, outputs=weight_lora5)
    # --- End of visibility wiring ---

# Start the app.
demo.launch()