File size: 46,879 Bytes
486e6d4
73a8cdd
 
 
 
 
 
 
 
 
 
 
486e6d4
73a8cdd
486e6d4
73a8cdd
 
 
 
486e6d4
90c8956
 
 
 
 
73a8cdd
 
 
 
 
 
 
 
 
bcd8020
73a8cdd
 
 
 
 
486e6d4
73a8cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10e2b36
73a8cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0a2154
 
73a8cdd
 
 
 
f0a2154
 
73a8cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6f1d21
73a8cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0a2154
 
73a8cdd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
import os
import json
import base64
import requests
import gradio as gr
from PIL import Image, PngImagePlugin
import io
import cv2
from datetime import datetime
import time
import dotenv
from pathlib import Path

dotenv.load_dotenv()

TAG_CANDIDATE_GENERATION_TEMPLATE = os.getenv("TAG_CANDIDATE_GENERATION_TEMPLATE")
TAG_NORMALIZATION_TEMPLATE = os.getenv("TAG_NORMALIZATION_TEMPLATE")
TAG_FILTER_TEMPLATE = os.getenv("TAG_FILTER_TEMPLATE")
TAG_WEIGHT_TEMPLATE = os.getenv("TAG_WEIGHT_TEMPLATE")

print(TAG_CANDIDATE_GENERATION_TEMPLATE[0:10])
print(TAG_NORMALIZATION_TEMPLATE[0:10])
print(TAG_FILTER_TEMPLATE[0:10])
print(TAG_WEIGHT_TEMPLATE[0:10])

# --- Model Lists ---
sd_models = [
    # "animagine-xl-3.1.safetensors [e3c47aedb0]",
    "animagine-xl-4.0.safetensors",
    # "animagine-xl.safetensors",
    # "LoRAMergeModel_animepose_outline_sotai_2025_04.fp16.safetensors [22dc3c53d2]",
]
cn_models =[
  "Bodyline2Image_v2-000350 [15294a3f]",
#   "Bodyline2Image_v2-000400 [f87a785e]",
#   "CN-anytest_v3-50000_fp16 [0963df47]",
#   "CN-anytest_v4-marged [4bb64990]",
]
controlnet_modules = ["invert", "none"] # Example modules
controlnet_modes = ["Balanced", "My prompt is more important", "ControlNet is more important"] # Example modes

def save_images(images: list, output_dir: Path, prefix: str = "generated", subfolder: str = None):
    """Base64エンコードされた画像リストを保存する"""
    # タイムスタンプ付きのディレクトリ名を作成
    output_dir = output_dir / subfolder
    output_dir.mkdir(parents=True, exist_ok=True)
    saved_paths = []
    
    for i, img_base64 in enumerate(images):
        img_data = base64.b64decode(img_base64)
        output_path = output_dir / f"{prefix}_{i}.png"
        
        with open(output_path, "wb") as f:
            f.write(img_data)
        saved_paths.append(output_path)
    
    return saved_paths

# テキストからプロンプトを生成するAPI呼び出し
def generate_prompt_from_text(text_prompt, tag_candidate_generation_template=None, 
                             tag_normalization_template=None, tag_filter_template=None, 
                             tag_weight_template=None):
    """
    テキスト説明からプロンプトを生成するRunPod API呼び出し関数
    
    Args:
        text_prompt: テキストプロンプト
        tag_candidate_generation_template: タグ候補生成テンプレート
        tag_normalization_template: タグ正規化テンプレート
        tag_filter_template: タグフィルターテンプレート
        tag_weight_template: タグ重み付けテンプレート
    
    Returns:
        生成されたプロンプト
    """
    if not text_prompt.strip():
        return "", "エラー: テキストプロンプトが空です。プロンプトを生成するためのテキストを入力してください。"
    
    api_key = os.environ.get("RUNPOD_API_KEY", "")
    endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID", "")
    
    if not endpoint_id:
        return "", "エラー: Endpoint IDが設定されていません"
    
    if not api_key:
        return "", "エラー: API KeyまたはEndpoint IDが設定されていません"
    
    url = f"https://api.runpod.ai/v2/{endpoint_id}/run"
    
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    
    payload = {
        "input": {
            "mode": "prompt_only",  # プロンプト生成のみモード
            "prompt": text_prompt,
            "tag_candidate_generation_template": tag_candidate_generation_template,
            "tag_normalization_template": tag_normalization_template,
            "tag_filter_template": tag_filter_template,
            "tag_weight_template": tag_weight_template
        }
    }
    
    # Noneの値を持つキーを削除
    payload["input"] = {k: v for k, v in payload["input"].items() if v is not None}
    
    print(f"プロンプト生成APIリクエスト送信: {json.dumps(payload, indent=2, ensure_ascii=False)}")
    
    try:
        # タイムアウトを2分に設定
        response = requests.post(url, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        
        result = response.json()
        status = result.get("status")
        
        if status == "IN_QUEUE" or status == "IN_PROGRESS":
            task_id = result.get("id")
            status_url = f"https://api.runpod.ai/v2/{endpoint_id}/status/{task_id}"
            
            # セッション作成
            session = requests.Session()
            session.headers.update(headers)
            
            # 最大2分間ポーリング
            for _ in range(24):
                time.sleep(5)
                try:
                    status_response = session.get(status_url, timeout=(3.05, 10))
                    
                    if status_response.status_code != 200:
                        print(f"Unexpected status code: {status_response.status_code}")
                        time.sleep(2)
                        continue
                        
                    status_data = status_response.json()
                    current_status = status_data.get("status")
                    print(f"Current status: {current_status}")
                    
                    if current_status == "COMPLETED":
                        result = status_data
                        break
                    elif current_status in ["FAILED", "CANCELLED"]:
                        raise Exception(f"Task failed with status: {current_status}")
                        
                except Exception as e:
                    print(f"Error during status check: {e}")
                    time.sleep(2)
                    continue
            
            # セッションのクリーンアップ
            session.close()
        
        if result.get("status") != "COMPLETED":
            raise Exception(f"API処理エラー: {result}")
        
        output = result.get("output", {})
        generated_tags = output.get("generated_tags", [])
        
        if not generated_tags:
            return "", "エラー: タグを生成できませんでした。別のテキストで試してください。"
        
        # タグをカンマ区切りの文字列に変換
        prompt_string = ", ".join(generated_tags)
        
        return prompt_string, None
        
    except Exception as e:
        return "", f"エラーが発生しました: {str(e)}"

# RunPod APIを呼び出す関数を変更
def generate_images_from_prompt(generated_prompt, negative_prompt="", 
                guidance_scale=7.0, num_inference_steps=30, 
                width=512, height=768, num_images=1, is_random_seeds=True, seeds=None, 
                bodyline_prompt=None, bodyline_negative_prompt=None,
                bodyline_steps=20, bodyline_guidance_scale=8.0,
                bodyline_input_resolution=256, bodyline_output_size=768,
                is_random_bodyline_seeds=True, bodyline_seeds=None):
    """
    生成されたプロンプトから画像を生成する関数
    
    Args:
        generated_prompt: 生成済みのプロンプト(カンマ区切りのタグ)
        negative_prompt: ネガティブプロンプト
        guidance_scale: ガイダンススケール
        num_inference_steps: 推論ステップ数
        width: 画像の幅
        height: 画像の高さ
        num_images: 生成する画像の枚数
        seeds: 乱数シード
        bodyline_prompt: ボディライン生成用プロンプト
        bodyline_negative_prompt: ボディライン生成用ネガティブプロンプト
        bodyline_steps: ボディライン生成の推論ステップ数
        bodyline_guidance_scale: ボディライン生成のガイダンススケール
        bodyline_input_resolution: ボディライン生成の入力解像度
        bodyline_output_size: ボディライン生成の出力サイズ
        is_random_bodyline_seeds: ボディライン生成用シードをランダムにするかどうか
        bodyline_seeds: ボディライン生成用シード
        api_key: RunPod API Key
        endpoint_id: RunPod Endpoint ID
    
    Returns:
        生成された画像のリスト
    """
    if not generated_prompt.strip():
        return *[None]*4, "エラー: 生成プロンプトが空です。先にプロンプトを生成してください。", seeds, bodyline_seeds
    
    api_key = os.environ.get("RUNPOD_API_KEY", "")
    endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID", "")
    
    if not endpoint_id:
        return *[None]*4, "エラー: Endpoint IDが設定されていません", seeds, bodyline_seeds
    
    if not api_key:
        return *[None]*4, "エラー: API KeyまたはEndpoint IDが設定されていません", seeds, bodyline_seeds
    
    if is_random_seeds:
        seeds = None
    else:
        seeds = seeds.split(",")
        seeds = [int(seed) for seed in seeds]
        if len(seeds) < num_images:
            # 最後のseedを繰り返し使用
            seeds = seeds + [seeds[-1]] * (num_images - len(seeds))

    if is_random_bodyline_seeds:
        bodyline_seeds = None
    else:
        bodyline_seeds = bodyline_seeds.split(",")
        bodyline_seeds = [int(seed) for seed in bodyline_seeds]
        if len(bodyline_seeds) < num_images:
            # 最後のseedを繰り返し使用
            bodyline_seeds = bodyline_seeds + [bodyline_seeds[-1]] * (num_images - len(bodyline_seeds))
    
    url = f"https://api.runpod.ai/v2/{endpoint_id}/run"
    
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    
    # タグリストの作成
    tags = [tag.strip() for tag in generated_prompt.split(",")]
    
    payload = {
        "input": {
            "mode": "image_only",  # 画像生成のみモード
            "tags": tags,
            "negative_prompt": negative_prompt,
            "steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "width": width,
            "height": height,
            "num_images": num_images,
            "seeds": seeds,
            "bodyline_prompt": bodyline_prompt,
            "bodyline_negative_prompt": bodyline_negative_prompt,
            "bodyline_steps": bodyline_steps,
            "bodyline_guidance_scale": bodyline_guidance_scale,
            "bodyline_input_resolution": bodyline_input_resolution,
            "bodyline_output_size": bodyline_output_size,
            "bodyline_seeds": bodyline_seeds
        }
    }
    
    # Noneの値を持つキーを削除
    payload["input"] = {k: v for k, v in payload["input"].items() if v is not None}
    
    print(f"画像生成APIリクエスト送信: {json.dumps(payload, indent=2, ensure_ascii=False)}")
    
    try:
        # タイムアウトを10分に設定
        response = requests.post(url, headers=headers, json=payload, timeout=600)
        response.raise_for_status()
        
        result = response.json()
        status = result.get("status")
        
        if status == "IN_QUEUE" or status == "IN_PROGRESS":
            task_id = result.get("id")
            status_url = f"https://api.runpod.ai/v2/{endpoint_id}/status/{task_id}"
            
            # セッション作成
            session = requests.Session()
            session.headers.update(headers)
            
            # 最大5分間ポーリング
            for _ in range(60):
                time.sleep(5)
                try:
                    # タイムアウトを細かく設定
                    status_response = session.get(
                        status_url,
                        timeout=(3.05, 10)  # (接続タイムアウト, 読み込みタイムアウト)
                    )
                    
                    if status_response.status_code != 200:
                        print(f"Unexpected status code: {status_response.status_code}")
                        time.sleep(2)
                        continue
                        
                    status_data = status_response.json()
                    current_status = status_data.get("status")
                    print(f"Current status: {current_status}")
                    
                    if current_status == "COMPLETED":
                        result = status_data
                        break
                    elif current_status in ["FAILED", "CANCELLED"]:
                        raise Exception(f"Task failed with status: {current_status}")
                        
                except requests.exceptions.Timeout:
                    print("Request timed out, retrying...")
                    time.sleep(2)
                    continue
                except requests.exceptions.RequestException as e:
                    print(f"Request error: {e}")
                    time.sleep(2)
                    continue
                except json.JSONDecodeError:
                    print("Invalid JSON response, retrying...")
                    time.sleep(2)
                    continue
            
            # セッションのクリーンアップ
            session.close()
        
        if result.get("status") != "COMPLETED":
            raise Exception(f"API処理エラー: {result}")
        
        output = result.get("output", {})
        images = [None]*num_images

        # 画像の保存処理
        if "images" in output:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            output_dir = Path("output")
            
            # 生成画像の保存
            save_images(
                output["images"],
                output_dir,
                prefix="generated",
                subfolder=timestamp
            )
            
            # ボディラインの保存
            save_images(
                output["bodylines"],
                output_dir,
                prefix="bodyline",
                subfolder=timestamp
            )
            
            # seedsの保存
            seeds = output["parameters"]["image_parameters"]["seeds"]
            
            # メタデータの保存
            metadata = {
                "used_tags": output.get("used_tags", []),
                "parameters": output.get("parameters", {})
            }
            
            metadata_path = output_dir / timestamp / "metadata.json"
            with open(metadata_path, "w", encoding="utf-8") as f:
                json.dump(metadata, f, ensure_ascii=False, indent=2)
            
            print(f"Images and metadata saved to: {output_dir / timestamp}")
            
            # Gradio用の画像変換
            for i, (img_b64, bodyline_b64, diff_b64) in enumerate(zip(output["images"], output["bodylines"], output["remove_bg_diff"]["diff"])):
                img_data = base64.b64decode(img_b64.split(",")[1] if "," in img_b64 else img_b64)
                img = Image.open(io.BytesIO(img_data))
                bodyline_data = base64.b64decode(bodyline_b64.split(",")[1] if "," in bodyline_b64 else bodyline_b64)
                bodyline = Image.open(io.BytesIO(bodyline_data))
                diff_data = base64.b64decode(diff_b64.split(",")[1] if "," in diff_b64 else diff_b64)
                diff = Image.open(io.BytesIO(diff_data))
                images[i] = [bodyline, img, diff]
                
            print(output["remove_bg_diff"]["white_percentage"])
            print(output["remove_bg_diff"]["white_pixels_mask2"])
            print(output["remove_bg_diff"]["white_pixels_diff"])

        return (
            *images, 
            json.dumps(output["parameters"], indent=2, ensure_ascii=False), 
            ', '.join(str(seed) for seed in output["parameters"]["image_parameters"]["seeds"]),
            ', '.join(str(seed) for seed in output["parameters"]["bodyline_parameters"]["seeds"])
        )

    except Exception as e:
        return *[None]*4, f"エラーが発生しました: {str(e)}", seeds, bodyline_seeds

# テンプレートを再読み込みする関数を追加
def reload_templates():
    """テンプレートモジュールを再読み込みする"""
    import importlib
    import templates
    importlib.reload(templates)
    from templates import (
        TAG_CANDIDATE_GENERATION_TEMPLATE,
        TAG_NORMALIZATION_TEMPLATE,
        TAG_FILTER_TEMPLATE,
        TAG_WEIGHT_TEMPLATE
    )
    return (
        TAG_CANDIDATE_GENERATION_TEMPLATE,
        TAG_NORMALIZATION_TEMPLATE,
        TAG_FILTER_TEMPLATE,
        TAG_WEIGHT_TEMPLATE
    )

# --- イラスト生成用関数 ---
def generate_illustration(
    sd_api_url, input_image, illustration_prompt, illustration_neg_prompt,
    target_long_side, sd_model_name, cn_model_name, controlnet_module,
    sampler_name, cfg_scale, seed_str, cn_weight,
    cn_guidance_start, cn_guidance_end, cn_pixel_perfect, cn_control_mode,
    batch_size
):
    """ControlNetを使用してイラストを生成する関数"""
    # 必要なライブラリをインポート
    import numpy as np
    import cv2
    import base64
    import requests
    import json
    from PIL import Image, PngImagePlugin
    import io

    # 引数名を generated_prompt に合わせる
    generated_prompt_text = illustration_prompt

    if not sd_api_url or not sd_api_url.startswith("http"):
        return [], "エラー: Stable Diffusion APIのURLが無効です。"
    if input_image is None:
        # input_image はPIL ImageまたはNumpy Arrayを想定
        return [], "エラー: 入力画像がありません。"
    if not generated_prompt_text or not generated_prompt_text.strip():
        return [], "エラー: プロンプトが空です。"

    try:
        # 入力画像の形式を確認し、PIL Imageに統一
        print(f"入力画像の型: {type(input_image)}")
        if isinstance(input_image, np.ndarray):
            print("Numpy ArrayをPIL Imageに変換します。")
            # 配列の形状を確認してモードを決定 (高さ, 幅, チャンネル数)
            if input_image.shape[2] == 4:
                print("  入力配列は RGBA です。")
                img_pil = Image.fromarray(input_image, mode='RGBA')
            elif input_image.shape[2] == 3:
                print("  入力配列は RGB です。")
                img_pil = Image.fromarray(input_image, mode='RGB')
            else:
                return [], f"エラー: サポートされていないNumpy配列のチャンネル数です ({input_image.shape[2]})。"
        elif isinstance(input_image, Image.Image):
            # すでにPIL Imageの場合
            print("PIL Imageのままです。")
            img_pil = input_image
        else:
            # 画像が読み込まれていないか、予期しない形式の場合
            return [], f"エラー: サポートされていない入力画像形式です ({type(input_image)})。画像をアップロードしてください。"

        # PIL Image のモードを確認し、必要ならRGBに変換(白背景合成含む)
        print(f"入力画像のモード: {img_pil.mode}")
        if img_pil.mode == 'RGBA':
            print("入力画像はRGBAです。白背景に合成します。")
            # 白い背景を作成
            background = Image.new("RGB", img_pil.size, (255, 255, 255))
            background.paste(img_pil, mask=img_pil.split()[3])
            img_pil = background # RGB画像に置き換え
            img_pil = img_pil.convert('RGB') # ensure RGB mode
        elif img_pil.mode != 'RGB':
            # その他のモード(例: P, L)の場合もRGBに変換
            print(f"入力画像をRGBに変換します (元のモード: {img_pil.mode})。")
            img_pil = img_pil.convert('RGB')

        # OpenCV形式 (BGR) に変換
        img_cv = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

        # --- 画像サイズに基づいて出力サイズを計算 ---
        h, w = img_cv.shape[:2]
        if h == 0 or w == 0:
            return [], "エラー: 入力画像のサイズが無効です。"
        print(f"入力画像のサイズ: {w}x{h}")

        if w >= h:
            target_w = target_long_side
            target_h = int(target_long_side * h / w)
        else:
            target_h = target_long_side
            target_w = int(target_long_side * w / h)

        # 64 の倍数に丸める
        target_w = max(64, round(target_w / 64) * 64)
        target_h = max(64, round(target_h / 64) * 64)
        print(f"計算された出力サイズ (64の倍数に丸め): {target_w}x{target_h}")

        # リサイズ (エラーチェック追加)
        try:
            img_resized = cv2.resize(img_cv, (target_w, target_h), interpolation=cv2.INTER_AREA)
        except cv2.error as resize_err:
             return [], f"エラー: 画像のリサイズに失敗しました ({resize_err})。入力画像を確認してください。"

        # 画像のエンコード
        retval, bytes_data = cv2.imencode('.png', img_resized)
        if not retval:
            raise RuntimeError("画像のPNGエンコードに失敗しました")
        encoded_image = base64.b64encode(bytes_data).decode('utf-8')
        print("ControlNet用画像のエンコード完了")

        # シードの処理 (-1 または数値)
        try:
            seed = int(seed_str) if seed_str else -1
        except ValueError:
            seed = -1 # 無効な文字列の場合は -1 (ランダム)

        # --- ペイロード作成 ---
        payload = {
            "prompt": generated_prompt_text, # 共通プロンプトを使用
            "negative_prompt": illustration_neg_prompt,
            "steps": 20, # 固定値。必要ならUIに追加
            "batch_size": batch_size, # ペイロードに batch_size を追加
            "width": target_w,
            "height": target_h,
            "sampler_name": sampler_name,
            "cfg_scale": cfg_scale,
            "seed": seed,
            "override_settings": {
                "sd_model_checkpoint": sd_model_name,
            },
            "alwayson_scripts": {
                "controlnet": {
                    "args": [
                        {
                            "enabled": True,
                            "image": encoded_image,
                            "module": controlnet_module,
                            "model": cn_model_name,
                            "weight": cn_weight,
                            "resize_mode": "Resize and Fill", # "Just Resize" も選択肢
                            "guidance_start": cn_guidance_start,
                            "guidance_end": cn_guidance_end,
                            "pixel_perfect": cn_pixel_perfect,
                            "control_mode": cn_control_mode,
                        }
                    ]
                }
            }
        }
        # ペイロードログ(画像データ省略)
        payload_log = {k: v for k, v in payload.items() if k != 'alwayson_scripts'}
        if 'alwayson_scripts' in payload and 'controlnet' in payload['alwayson_scripts']:
             payload_log['alwayson_scripts'] = {'controlnet': {'args': []}}
             for arg in payload['alwayson_scripts']['controlnet']['args']:
                 arg_log = {k: v for k, v in arg.items() if k != 'image'}
                 payload_log['alwayson_scripts']['controlnet']['args'].append(arg_log)
        print(f"ペイロード (画像データ省略): {json.dumps(payload_log, indent=2, ensure_ascii=False)}")

        # --- API 呼び出し ---
        if "runpod" in sd_api_url:
            input_rp = {
                "input": {
                    "path": "/sdapi/v1/txt2img",
                    "method": "POST",
                    "payload": payload
                }
            }
            headers = {
                "Authorization": f"Bearer {os.getenv('RUNPOD_API_KEY')}"
            }
            response = requests.post(url=sd_api_url, json=input_rp, headers=headers, timeout=600)
            response.raise_for_status()
            if not "output" in response.json():
                raise Exception(f"API呼び出しに失敗しました。data: {response.json()}")
            r = response.json()['output']['body']
        else:
            txt2img_endpoint = f'{sd_api_url.strip("/")}/sdapi/v1/txt2img'
            png_info_endpoint = f'{sd_api_url.strip("/")}/sdapi/v1/png-info'

            print(f"API呼び出し中 (POST): {txt2img_endpoint}")
            print(f"使用するSDモデル (override): {sd_model_name}")

            response = requests.post(url=txt2img_endpoint, json=payload, timeout=600) # タイムアウト10分
            response.raise_for_status()

            print("API呼び出し成功!")
            r = response.json()

        # --- 結果の処理 ---
        output_images = []
        if 'images' in r and r['images']:
            print(f"{len(r['images'])} 枚の画像を処理中...")
            for i, img_base64 in enumerate(r['images']):
                try:
                    if not img_base64:
                        print(f"警告: 画像 {i+1} のデータが空です。スキップします。")
                        continue

                    # Base64デコード
                    if "," in img_base64:
                        img_data_actual = base64.b64decode(img_base64.split(',', 1)[1])
                    else:
                        img_data_actual = base64.b64decode(img_base64)

                    image = Image.open(io.BytesIO(img_data_actual))

                    # # PNG Info の取得と埋め込み (ベストエフォート)
                    # pnginfo_data = None
                    # try:
                    #     png_payload = {"image": "data:image/png;base64," + (img_base64.split(',', 1)[1] if ',' in img_base64 else img_base64)}
                    #     response2 = requests.post(url=png_info_endpoint, json=png_payload, timeout=10)
                    #     response2.raise_for_status()
                    #     info_text = response2.json().get("info", "")
                    #     if info_text:
                    #          pnginfo_data = PngImagePlugin.PngInfo()
                    #          pnginfo_data.add_text("parameters", info_text)
                    #          print(f"画像 {i+1} の PNG Info 取得成功")
                    #     else:
                    #          print(f"画像 {i+1} の PNG Info は空でした。")
                    # except requests.exceptions.RequestException as png_req_err:
                    #      # PNG Info取得失敗は警告にとどめる
                    #      print(f"警告: 画像 {i+1} の PNG Info 取得リクエスト失敗 ({png_req_err})。")
                    # except Exception as png_err:
                    #     print(f"警告: 画像 {i+1} の PNG Info 処理中にエラー ({png_err})。")

                    # Gradio Gallery 用に PIL Image をリストに追加
                    output_images.append(image.copy())

                except base64.binascii.Error as b64_err:
                     print(f"エラー: 画像 {i+1} のBase64デコードに失敗しました ({b64_err})。レスポンスを確認してください。")
                     continue # 次の画像の処理へ
                except Exception as img_proc_err:
                    print(f"エラー: 画像 {i+1} の処理中にエラーが発生しました: {img_proc_err}")
                    continue # 次の画像の処理へ

            if output_images:
                 print(f"{len(output_images)}枚の画像をGradio Gallery用に準備完了。")
                 return output_images, f"{len(output_images)-1}枚の画像を生成しました。"
            else:
                 # 画像データはあったが、すべて処理に失敗した場合
                 return [], "エラー: 画像は受信しましたが、処理中に問題が発生しました。ログを確認してください。"

        else:
            error_info = r.get('info') or r.get('detail') or r.get('error') or r.get('message') or json.dumps(r)
            print(f"エラー: レスポンスに画像が含まれていませんでした。詳細: {error_info}")
            # APIからのエラーメッセージをユーザーに返す
            return [], f"エラー: サーバーから画像が返されませんでした。詳細: {error_info}"

    except requests.exceptions.Timeout:
        print("エラー: API呼び出しがタイムアウトしました。")
        return [], "エラー: API呼び出しがタイムアウトしました (10分)。サーバーの負荷が高いか、設定を確認してください。"
    except requests.exceptions.RequestException as e:
        error_msg = f"リクエストエラー: {e}"
        status_code_msg = ""
        details_msg = ""
        if e.response is not None:
            status_code_msg = f" ステータスコード: {e.response.status_code}"
            try:
                error_details = e.response.json()
                details_msg = f" 詳細: {json.dumps(error_details, indent=2, ensure_ascii=False)}"
            except json.JSONDecodeError:
                 details_msg = f" 応答(Text): {e.response.text[:500]}"
        full_error_msg = f"{error_msg}{status_code_msg}{details_msg}"
        print(full_error_msg)
        # Gradioには要約したメッセージを表示
        return [], f"APIリクエストエラーが発生しました。{status_code_msg}"
    except ImportError as imp_err:
        missing_lib = str(imp_err).split("'")[1]
        print(f"エラー:必要なライブラリ ({missing_lib}) が不足しています。pip install {missing_lib} を試してください。")
        return [], f"エラー:必要なライブラリ ({missing_lib}) が不足しています。"
    except Exception as e:
        import traceback
        print(f"予期せぬエラーが発生しました: {e}")
        traceback.print_exc()
        return [], f"予期せぬエラーが発生しました: {type(e).__name__} - {e}" 

# Gradio UIの構築
def create_ui():
    # パスワード認証用の状態
    authenticated = gr.State(False)
    # 正しいパスワード (環境変数から取得)
    correct_password = os.getenv("USER_PASSWORD", "testpassword") # デフォルトパスワードを設定

    with gr.Blocks() as app:
        # --- パスワード認証UI ---
        with gr.Row(visible=True) as auth_block:
            with gr.Column():
                gr.Markdown("# 認証が必要です")
                password_input = gr.Textbox(label="パスワード", type="password", placeholder="パスワードを入力してください")
                login_button = gr.Button("ログイン")
                auth_status = gr.Markdown(visible=False)

        # --- メインUI (初期状態では非表示) ---
        with gr.Column(visible=False) as main_ui_block:
            with gr.Row():
                with gr.Column(scale=1):
                    text_prompt = gr.Textbox(label="テキスト説明", placeholder="画像にしたい内容を自然な文章で説明してください(例: ピンクの髪の猫耳メイド)", lines=3)
                    generate_prompt_btn = gr.Button("プロンプト生成", variant="secondary")
                    generated_prompt = gr.Textbox(label="プロンプト", value="original, 1girl, solo, pink hair, cat ears, animal ears, smile, masterpiece, high score, great score, absurdres", placeholder="生成されたプロンプトがここに表示されます...", lines=3)

                with gr.Column(scale=3):
                    # タブ1 ポーズ生成
                    with gr.Tab("ポーズ生成"):            
                        with gr.Accordion("Advanced Settings", open=False):
                            with gr.Row():
                                negative_prompt = gr.Textbox(
                                    label="Negative Prompt",
                                    value="nsfw, sensitive, from behind, lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, missing arms, extra arms, missing legs, extra legs, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                                    lines=2
                                )
                            
                            with gr.Row():
                                guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.1, label="Guidance Scale")
                                num_inference_steps = gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Steps")
                            
                            with gr.Row():
                                width = gr.Slider(minimum=512, maximum=2048, value=832, step=64, label="Width")
                                height = gr.Slider(minimum=512, maximum=2048, value=1216, step=64, label="Height")
                                num_images = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Number of Images")
                                
                            with gr.Row():
                                is_random_seeds = gr.Checkbox(label="Generate Random Seeds", value=True)
                                seeds = gr.Textbox(label="Seeds (Random if empty)", placeholder="Enter seed values. For multiple seeds, separate with commas.")
                            
                            with gr.Accordion("Bodyline Settings", open=False):
                                bodyline_prompt = gr.Textbox(
                                    label="Bodyline Prompt",
                                    value="anime pose, girl, (white background:1.5), (monochrome:1.5), full body, sketch, eyes, breasts, (slim legs, skinny legs:1.2)",
                                    lines=2,
                                    visible=False
                                )
                                bodyline_negative_prompt = gr.Textbox(
                                    label="Bodyline Negative Prompt",
                                    value="(wings:1.6), (clothes:1.4), (garment:1.4), (lighting:1.4), (gray:1.4), (missing limb:1.4), (extra line:1.4), (extra limb:1.4), (extra arm:1.4), (extra legs:1.4), (hair:1.4), (bangs:1.4), (fringe:1.4), (forelock:1.4), (front hair:1.4), (fill:1.4), (ink pool:1.6)",
                                    lines=2,
                                    visible=False
                                )
                                with gr.Row():
                                    bodyline_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=8.0, step=0.1, label="Bodyline Guidance Scale")
                                    bodyline_steps = gr.Slider(minimum=10, maximum=100, value=20, step=1, label="Bodyline Steps")
                                with gr.Row():
                                    bodyline_input_resolution = gr.Slider(minimum=128, maximum=1024, value=256, step=64, label="Bodyline Input Resolution")
                                    bodyline_output_size = gr.Slider(minimum=512, maximum=2048, value=768, step=64, label="Bodyline Output Size")
                                with gr.Row():
                                    is_random_bodyline_seeds = gr.Checkbox(label="Generate Random Bodyline Seeds", value=True)
                                    bodyline_seeds = gr.Textbox(label="Bodyline Seeds (Random if empty)", placeholder="Enter seed values. For multiple seeds, separate with commas.")
                            
                            with gr.Accordion("Tag Template Settings", open=False, visible=False):
                                with gr.Row():
                                    reload_templates_btn = gr.Button("テンプレートを再読み込み", variant="secondary")

                                tag_candidate_generation_template = gr.Textbox(
                                    label="Tag Candidate Generation Template", 
                                    value=TAG_CANDIDATE_GENERATION_TEMPLATE,
                                    lines=6
                                )
                                tag_normalization_template = gr.Textbox(
                                    label="Tag Normalization Template", 
                                    value=TAG_NORMALIZATION_TEMPLATE,
                                    lines=6
                                )
                                tag_filter_template = gr.Textbox(
                                    label="Tag Filter Template", 
                                    value=TAG_FILTER_TEMPLATE,
                                    lines=6
                                )
                                tag_weight_template = gr.Textbox(
                                    label="Tag Weighting Template", 
                                    value=TAG_WEIGHT_TEMPLATE,
                                    lines=6
                                )

                        generate_image_btn = gr.Button("画像生成", variant="primary")

                        with gr.Row():
                            output_gallery1 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain")
                            output_gallery2 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain")
                            output_gallery3 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain")
                            output_gallery4 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain")
                            
                        with gr.Accordion("Status", open=False):
                            status_text = gr.Textbox(label="Status", interactive=False)
                    
                    # タブ2 画像生成
                    with gr.Tab("イラスト生成"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                input_image = gr.Image(label="画像", interactive=True, height=300, image_mode="RGBA", type="pil")
                                examples = gr.Examples(
                                    examples=[
                                        "examples/1.jpg",
                                        "examples/2.jpg",
                                        "examples/3.jpg",
                                        "examples/4.jpg",
                                        "examples/5.jpg"
                                    ],
                                    inputs=input_image
                                )
                                with gr.Row():
                                    cn_weight = gr.Slider(minimum=0.0, maximum=2.0, value=1.7, step=0.05, label="ControlNet Weight")
                                    cn_guidance_start = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Guidance Start (T)")
                                    cn_guidance_end = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Guidance End (T)")
                                with gr.Row():
                                    cn_pixel_perfect = gr.Checkbox(label="Pixel Perfect", value=True)
                                    cn_control_mode = gr.Dropdown(label="Control Mode", choices=controlnet_modes, value=controlnet_modes[2] if controlnet_modes else None)

                                generate_illustration_btn = gr.Button("イラスト生成", variant="primary")

                                with gr.Accordion("Advanced Settings", open=False):
                                    sd_api_url = gr.Textbox(label="Stable Diffusion API URL", value=os.environ.get("SD_API_URL", "http://127.0.0.1:7860"), visible=False)
                                    # Batch size スライダーを追加
                                    batch_size_slider = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Batch Size")
                                    illustration_neg_prompt = gr.Textbox(
                                        label="ネガティブプロンプト",
                                        value="nsfw, sensitive, from behind, lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, missing arms, extra arms, missing legs, extra legs, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                                        lines=2
                                    )
                                    with gr.Row():
                                        cn_model_dropdown = gr.Dropdown(label="ControlNet Model", choices=cn_models, value=cn_models[0] if len(cn_models) > 3 else (cn_models[0] if cn_models else None))
                                        controlnet_module_dropdown = gr.Dropdown(label="ControlNet Module (Preprocessor)", choices=controlnet_modules, value="invert" if "invert" in controlnet_modules else (controlnet_modules[0] if controlnet_modules else None))
                                    target_long_side = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label="ターゲット長辺 (Target Long Side)")
                                    sd_model_dropdown = gr.Dropdown(label="SD Model", choices=sd_models, value=sd_models[1] if len(sd_models) > 1 else (sd_models[0] if sd_models else None))
                                    sampler_name = gr.Textbox(label="Sampler Name", value="Euler a")
                                    with gr.Row():
                                        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, value=7.0, step=0.5, label="CFG Scale")
                                        seed_input = gr.Textbox(label="Seed", value="-1", placeholder="-1 for random")

                                illustration_status_text = gr.Textbox(label="Status", interactive=False)

                            with gr.Column(scale=2):
                                # 出力ギャラリー (input_imageは上のRowで共有)
                                illustration_output_gallery = gr.Gallery(label="生成されたイラスト", columns=2, height=1000, object_fit="contain")

            output_gallery = [output_gallery1, output_gallery2, output_gallery3, output_gallery4]
            
            # プロンプト生成ボタンのクリックイベント
            generate_prompt_btn.click(
                fn=generate_prompt_from_text,
                inputs=[
                    text_prompt, tag_candidate_generation_template,
                    tag_normalization_template, tag_filter_template,
                    tag_weight_template
                ],
                outputs=[
                    generated_prompt,
                    status_text
                ]
            )
            
            # 画像生成ボタンのクリックイベント
            generate_image_btn.click(
                fn=generate_images_from_prompt,
                inputs=[
                    generated_prompt, negative_prompt,
                    guidance_scale, num_inference_steps,
                    width, height, num_images, is_random_seeds, seeds,
                    bodyline_prompt, bodyline_negative_prompt,
                    bodyline_steps, bodyline_guidance_scale,
                    bodyline_input_resolution, bodyline_output_size,
                    is_random_bodyline_seeds, bodyline_seeds,
                ],
                outputs=[
                    output_gallery[0],
                    output_gallery[1],
                    output_gallery[2],
                    output_gallery[3],
                    status_text,
                    seeds,
                    bodyline_seeds
                ]
            )

            # イラスト生成ボタンのクリックイベント
            generate_illustration_btn.click(
                fn=generate_illustration,
                inputs=[
                    sd_api_url,
                    input_image, # 共有のinput_imageを使用
                    generated_prompt, # illustration_prompt の代わりに generated_prompt を使用
                    illustration_neg_prompt,
                    target_long_side,
                    sd_model_dropdown,
                    cn_model_dropdown,
                    controlnet_module_dropdown,
                    sampler_name,
                    cfg_scale,
                    seed_input, # Textboxコンポーネントを渡す
                    cn_weight,
                    cn_guidance_start,
                    cn_guidance_end,
                    cn_pixel_perfect,
                    cn_control_mode,
                    batch_size_slider
                ],
                outputs=[
                    illustration_output_gallery,
                    illustration_status_text
                ]
            )

            # テンプレート再読み込みボタンのクリックイベント
            reload_templates_btn.click(
                fn=reload_templates,
                inputs=[],
                outputs=[
                    tag_candidate_generation_template,
                    tag_normalization_template,
                    tag_filter_template,
                    tag_weight_template
                ]
            )
        
        # --- 認証ロジック ---
        def authenticate_user(password):
            if password == correct_password:
                return {
                    auth_block: gr.update(visible=False),
                    main_ui_block: gr.update(visible=True),
                    auth_status: gr.update(value="認証成功!メインUIを表示します。", visible=True),
                    authenticated: True
                }
            else:
                return {
                    auth_status: gr.update(value="パスワードが間違っています。", visible=True),
                    authenticated: False
                }

        login_button.click(
            authenticate_user,
            inputs=[password_input],
            outputs=[auth_block, main_ui_block, auth_status, authenticated]
        )
    
    return app

app = create_ui()
user_name = os.getenv("USER_NAME")
# user_password = os.getenv("USER_PASSWORD") # create_ui内で使用するため、ここでは不要
app.launch() # authパラメータを削除