Spaces:
Runtime error
Runtime error
File size: 46,879 Bytes
486e6d4 73a8cdd 486e6d4 73a8cdd 486e6d4 73a8cdd 486e6d4 90c8956 73a8cdd bcd8020 73a8cdd 486e6d4 73a8cdd 10e2b36 73a8cdd f0a2154 73a8cdd f0a2154 73a8cdd a6f1d21 73a8cdd f0a2154 73a8cdd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 |
import os
import json
import base64
import requests
import gradio as gr
from PIL import Image, PngImagePlugin
import io
import cv2
from datetime import datetime
import time
import dotenv
from pathlib import Path
dotenv.load_dotenv()
TAG_CANDIDATE_GENERATION_TEMPLATE = os.getenv("TAG_CANDIDATE_GENERATION_TEMPLATE")
TAG_NORMALIZATION_TEMPLATE = os.getenv("TAG_NORMALIZATION_TEMPLATE")
TAG_FILTER_TEMPLATE = os.getenv("TAG_FILTER_TEMPLATE")
TAG_WEIGHT_TEMPLATE = os.getenv("TAG_WEIGHT_TEMPLATE")
print(TAG_CANDIDATE_GENERATION_TEMPLATE[0:10])
print(TAG_NORMALIZATION_TEMPLATE[0:10])
print(TAG_FILTER_TEMPLATE[0:10])
print(TAG_WEIGHT_TEMPLATE[0:10])
# --- Model Lists ---
sd_models = [
# "animagine-xl-3.1.safetensors [e3c47aedb0]",
"animagine-xl-4.0.safetensors",
# "animagine-xl.safetensors",
# "LoRAMergeModel_animepose_outline_sotai_2025_04.fp16.safetensors [22dc3c53d2]",
]
cn_models =[
"Bodyline2Image_v2-000350 [15294a3f]",
# "Bodyline2Image_v2-000400 [f87a785e]",
# "CN-anytest_v3-50000_fp16 [0963df47]",
# "CN-anytest_v4-marged [4bb64990]",
]
controlnet_modules = ["invert", "none"] # Example modules
controlnet_modes = ["Balanced", "My prompt is more important", "ControlNet is more important"] # Example modes
def save_images(images: list, output_dir: Path, prefix: str = "generated", subfolder: str = None):
"""Base64エンコードされた画像リストを保存する"""
# タイムスタンプ付きのディレクトリ名を作成
output_dir = output_dir / subfolder
output_dir.mkdir(parents=True, exist_ok=True)
saved_paths = []
for i, img_base64 in enumerate(images):
img_data = base64.b64decode(img_base64)
output_path = output_dir / f"{prefix}_{i}.png"
with open(output_path, "wb") as f:
f.write(img_data)
saved_paths.append(output_path)
return saved_paths
# テキストからプロンプトを生成するAPI呼び出し
def generate_prompt_from_text(text_prompt, tag_candidate_generation_template=None,
tag_normalization_template=None, tag_filter_template=None,
tag_weight_template=None):
"""
テキスト説明からプロンプトを生成するRunPod API呼び出し関数
Args:
text_prompt: テキストプロンプト
tag_candidate_generation_template: タグ候補生成テンプレート
tag_normalization_template: タグ正規化テンプレート
tag_filter_template: タグフィルターテンプレート
tag_weight_template: タグ重み付けテンプレート
Returns:
生成されたプロンプト
"""
if not text_prompt.strip():
return "", "エラー: テキストプロンプトが空です。プロンプトを生成するためのテキストを入力してください。"
api_key = os.environ.get("RUNPOD_API_KEY", "")
endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID", "")
if not endpoint_id:
return "", "エラー: Endpoint IDが設定されていません"
if not api_key:
return "", "エラー: API KeyまたはEndpoint IDが設定されていません"
url = f"https://api.runpod.ai/v2/{endpoint_id}/run"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
payload = {
"input": {
"mode": "prompt_only", # プロンプト生成のみモード
"prompt": text_prompt,
"tag_candidate_generation_template": tag_candidate_generation_template,
"tag_normalization_template": tag_normalization_template,
"tag_filter_template": tag_filter_template,
"tag_weight_template": tag_weight_template
}
}
# Noneの値を持つキーを削除
payload["input"] = {k: v for k, v in payload["input"].items() if v is not None}
print(f"プロンプト生成APIリクエスト送信: {json.dumps(payload, indent=2, ensure_ascii=False)}")
try:
# タイムアウトを2分に設定
response = requests.post(url, headers=headers, json=payload, timeout=120)
response.raise_for_status()
result = response.json()
status = result.get("status")
if status == "IN_QUEUE" or status == "IN_PROGRESS":
task_id = result.get("id")
status_url = f"https://api.runpod.ai/v2/{endpoint_id}/status/{task_id}"
# セッション作成
session = requests.Session()
session.headers.update(headers)
# 最大2分間ポーリング
for _ in range(24):
time.sleep(5)
try:
status_response = session.get(status_url, timeout=(3.05, 10))
if status_response.status_code != 200:
print(f"Unexpected status code: {status_response.status_code}")
time.sleep(2)
continue
status_data = status_response.json()
current_status = status_data.get("status")
print(f"Current status: {current_status}")
if current_status == "COMPLETED":
result = status_data
break
elif current_status in ["FAILED", "CANCELLED"]:
raise Exception(f"Task failed with status: {current_status}")
except Exception as e:
print(f"Error during status check: {e}")
time.sleep(2)
continue
# セッションのクリーンアップ
session.close()
if result.get("status") != "COMPLETED":
raise Exception(f"API処理エラー: {result}")
output = result.get("output", {})
generated_tags = output.get("generated_tags", [])
if not generated_tags:
return "", "エラー: タグを生成できませんでした。別のテキストで試してください。"
# タグをカンマ区切りの文字列に変換
prompt_string = ", ".join(generated_tags)
return prompt_string, None
except Exception as e:
return "", f"エラーが発生しました: {str(e)}"
# RunPod APIを呼び出す関数を変更
def generate_images_from_prompt(generated_prompt, negative_prompt="",
guidance_scale=7.0, num_inference_steps=30,
width=512, height=768, num_images=1, is_random_seeds=True, seeds=None,
bodyline_prompt=None, bodyline_negative_prompt=None,
bodyline_steps=20, bodyline_guidance_scale=8.0,
bodyline_input_resolution=256, bodyline_output_size=768,
is_random_bodyline_seeds=True, bodyline_seeds=None):
"""
生成されたプロンプトから画像を生成する関数
Args:
generated_prompt: 生成済みのプロンプト(カンマ区切りのタグ)
negative_prompt: ネガティブプロンプト
guidance_scale: ガイダンススケール
num_inference_steps: 推論ステップ数
width: 画像の幅
height: 画像の高さ
num_images: 生成する画像の枚数
seeds: 乱数シード
bodyline_prompt: ボディライン生成用プロンプト
bodyline_negative_prompt: ボディライン生成用ネガティブプロンプト
bodyline_steps: ボディライン生成の推論ステップ数
bodyline_guidance_scale: ボディライン生成のガイダンススケール
bodyline_input_resolution: ボディライン生成の入力解像度
bodyline_output_size: ボディライン生成の出力サイズ
is_random_bodyline_seeds: ボディライン生成用シードをランダムにするかどうか
bodyline_seeds: ボディライン生成用シード
api_key: RunPod API Key
endpoint_id: RunPod Endpoint ID
Returns:
生成された画像のリスト
"""
if not generated_prompt.strip():
return *[None]*4, "エラー: 生成プロンプトが空です。先にプロンプトを生成してください。", seeds, bodyline_seeds
api_key = os.environ.get("RUNPOD_API_KEY", "")
endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID", "")
if not endpoint_id:
return *[None]*4, "エラー: Endpoint IDが設定されていません", seeds, bodyline_seeds
if not api_key:
return *[None]*4, "エラー: API KeyまたはEndpoint IDが設定されていません", seeds, bodyline_seeds
if is_random_seeds:
seeds = None
else:
seeds = seeds.split(",")
seeds = [int(seed) for seed in seeds]
if len(seeds) < num_images:
# 最後のseedを繰り返し使用
seeds = seeds + [seeds[-1]] * (num_images - len(seeds))
if is_random_bodyline_seeds:
bodyline_seeds = None
else:
bodyline_seeds = bodyline_seeds.split(",")
bodyline_seeds = [int(seed) for seed in bodyline_seeds]
if len(bodyline_seeds) < num_images:
# 最後のseedを繰り返し使用
bodyline_seeds = bodyline_seeds + [bodyline_seeds[-1]] * (num_images - len(bodyline_seeds))
url = f"https://api.runpod.ai/v2/{endpoint_id}/run"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
# タグリストの作成
tags = [tag.strip() for tag in generated_prompt.split(",")]
payload = {
"input": {
"mode": "image_only", # 画像生成のみモード
"tags": tags,
"negative_prompt": negative_prompt,
"steps": num_inference_steps,
"guidance_scale": guidance_scale,
"width": width,
"height": height,
"num_images": num_images,
"seeds": seeds,
"bodyline_prompt": bodyline_prompt,
"bodyline_negative_prompt": bodyline_negative_prompt,
"bodyline_steps": bodyline_steps,
"bodyline_guidance_scale": bodyline_guidance_scale,
"bodyline_input_resolution": bodyline_input_resolution,
"bodyline_output_size": bodyline_output_size,
"bodyline_seeds": bodyline_seeds
}
}
# Noneの値を持つキーを削除
payload["input"] = {k: v for k, v in payload["input"].items() if v is not None}
print(f"画像生成APIリクエスト送信: {json.dumps(payload, indent=2, ensure_ascii=False)}")
try:
# タイムアウトを10分に設定
response = requests.post(url, headers=headers, json=payload, timeout=600)
response.raise_for_status()
result = response.json()
status = result.get("status")
if status == "IN_QUEUE" or status == "IN_PROGRESS":
task_id = result.get("id")
status_url = f"https://api.runpod.ai/v2/{endpoint_id}/status/{task_id}"
# セッション作成
session = requests.Session()
session.headers.update(headers)
# 最大5分間ポーリング
for _ in range(60):
time.sleep(5)
try:
# タイムアウトを細かく設定
status_response = session.get(
status_url,
timeout=(3.05, 10) # (接続タイムアウト, 読み込みタイムアウト)
)
if status_response.status_code != 200:
print(f"Unexpected status code: {status_response.status_code}")
time.sleep(2)
continue
status_data = status_response.json()
current_status = status_data.get("status")
print(f"Current status: {current_status}")
if current_status == "COMPLETED":
result = status_data
break
elif current_status in ["FAILED", "CANCELLED"]:
raise Exception(f"Task failed with status: {current_status}")
except requests.exceptions.Timeout:
print("Request timed out, retrying...")
time.sleep(2)
continue
except requests.exceptions.RequestException as e:
print(f"Request error: {e}")
time.sleep(2)
continue
except json.JSONDecodeError:
print("Invalid JSON response, retrying...")
time.sleep(2)
continue
# セッションのクリーンアップ
session.close()
if result.get("status") != "COMPLETED":
raise Exception(f"API処理エラー: {result}")
output = result.get("output", {})
images = [None]*num_images
# 画像の保存処理
if "images" in output:
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_dir = Path("output")
# 生成画像の保存
save_images(
output["images"],
output_dir,
prefix="generated",
subfolder=timestamp
)
# ボディラインの保存
save_images(
output["bodylines"],
output_dir,
prefix="bodyline",
subfolder=timestamp
)
# seedsの保存
seeds = output["parameters"]["image_parameters"]["seeds"]
# メタデータの保存
metadata = {
"used_tags": output.get("used_tags", []),
"parameters": output.get("parameters", {})
}
metadata_path = output_dir / timestamp / "metadata.json"
with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
print(f"Images and metadata saved to: {output_dir / timestamp}")
# Gradio用の画像変換
for i, (img_b64, bodyline_b64, diff_b64) in enumerate(zip(output["images"], output["bodylines"], output["remove_bg_diff"]["diff"])):
img_data = base64.b64decode(img_b64.split(",")[1] if "," in img_b64 else img_b64)
img = Image.open(io.BytesIO(img_data))
bodyline_data = base64.b64decode(bodyline_b64.split(",")[1] if "," in bodyline_b64 else bodyline_b64)
bodyline = Image.open(io.BytesIO(bodyline_data))
diff_data = base64.b64decode(diff_b64.split(",")[1] if "," in diff_b64 else diff_b64)
diff = Image.open(io.BytesIO(diff_data))
images[i] = [bodyline, img, diff]
print(output["remove_bg_diff"]["white_percentage"])
print(output["remove_bg_diff"]["white_pixels_mask2"])
print(output["remove_bg_diff"]["white_pixels_diff"])
return (
*images,
json.dumps(output["parameters"], indent=2, ensure_ascii=False),
', '.join(str(seed) for seed in output["parameters"]["image_parameters"]["seeds"]),
', '.join(str(seed) for seed in output["parameters"]["bodyline_parameters"]["seeds"])
)
except Exception as e:
return *[None]*4, f"エラーが発生しました: {str(e)}", seeds, bodyline_seeds
# テンプレートを再読み込みする関数を追加
def reload_templates():
"""テンプレートモジュールを再読み込みする"""
import importlib
import templates
importlib.reload(templates)
from templates import (
TAG_CANDIDATE_GENERATION_TEMPLATE,
TAG_NORMALIZATION_TEMPLATE,
TAG_FILTER_TEMPLATE,
TAG_WEIGHT_TEMPLATE
)
return (
TAG_CANDIDATE_GENERATION_TEMPLATE,
TAG_NORMALIZATION_TEMPLATE,
TAG_FILTER_TEMPLATE,
TAG_WEIGHT_TEMPLATE
)
# --- イラスト生成用関数 ---
def generate_illustration(
sd_api_url, input_image, illustration_prompt, illustration_neg_prompt,
target_long_side, sd_model_name, cn_model_name, controlnet_module,
sampler_name, cfg_scale, seed_str, cn_weight,
cn_guidance_start, cn_guidance_end, cn_pixel_perfect, cn_control_mode,
batch_size
):
"""ControlNetを使用してイラストを生成する関数"""
# 必要なライブラリをインポート
import numpy as np
import cv2
import base64
import requests
import json
from PIL import Image, PngImagePlugin
import io
# 引数名を generated_prompt に合わせる
generated_prompt_text = illustration_prompt
if not sd_api_url or not sd_api_url.startswith("http"):
return [], "エラー: Stable Diffusion APIのURLが無効です。"
if input_image is None:
# input_image はPIL ImageまたはNumpy Arrayを想定
return [], "エラー: 入力画像がありません。"
if not generated_prompt_text or not generated_prompt_text.strip():
return [], "エラー: プロンプトが空です。"
try:
# 入力画像の形式を確認し、PIL Imageに統一
print(f"入力画像の型: {type(input_image)}")
if isinstance(input_image, np.ndarray):
print("Numpy ArrayをPIL Imageに変換します。")
# 配列の形状を確認してモードを決定 (高さ, 幅, チャンネル数)
if input_image.shape[2] == 4:
print(" 入力配列は RGBA です。")
img_pil = Image.fromarray(input_image, mode='RGBA')
elif input_image.shape[2] == 3:
print(" 入力配列は RGB です。")
img_pil = Image.fromarray(input_image, mode='RGB')
else:
return [], f"エラー: サポートされていないNumpy配列のチャンネル数です ({input_image.shape[2]})。"
elif isinstance(input_image, Image.Image):
# すでにPIL Imageの場合
print("PIL Imageのままです。")
img_pil = input_image
else:
# 画像が読み込まれていないか、予期しない形式の場合
return [], f"エラー: サポートされていない入力画像形式です ({type(input_image)})。画像をアップロードしてください。"
# PIL Image のモードを確認し、必要ならRGBに変換(白背景合成含む)
print(f"入力画像のモード: {img_pil.mode}")
if img_pil.mode == 'RGBA':
print("入力画像はRGBAです。白背景に合成します。")
# 白い背景を作成
background = Image.new("RGB", img_pil.size, (255, 255, 255))
background.paste(img_pil, mask=img_pil.split()[3])
img_pil = background # RGB画像に置き換え
img_pil = img_pil.convert('RGB') # ensure RGB mode
elif img_pil.mode != 'RGB':
# その他のモード(例: P, L)の場合もRGBに変換
print(f"入力画像をRGBに変換します (元のモード: {img_pil.mode})。")
img_pil = img_pil.convert('RGB')
# OpenCV形式 (BGR) に変換
img_cv = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
# --- 画像サイズに基づいて出力サイズを計算 ---
h, w = img_cv.shape[:2]
if h == 0 or w == 0:
return [], "エラー: 入力画像のサイズが無効です。"
print(f"入力画像のサイズ: {w}x{h}")
if w >= h:
target_w = target_long_side
target_h = int(target_long_side * h / w)
else:
target_h = target_long_side
target_w = int(target_long_side * w / h)
# 64 の倍数に丸める
target_w = max(64, round(target_w / 64) * 64)
target_h = max(64, round(target_h / 64) * 64)
print(f"計算された出力サイズ (64の倍数に丸め): {target_w}x{target_h}")
# リサイズ (エラーチェック追加)
try:
img_resized = cv2.resize(img_cv, (target_w, target_h), interpolation=cv2.INTER_AREA)
except cv2.error as resize_err:
return [], f"エラー: 画像のリサイズに失敗しました ({resize_err})。入力画像を確認してください。"
# 画像のエンコード
retval, bytes_data = cv2.imencode('.png', img_resized)
if not retval:
raise RuntimeError("画像のPNGエンコードに失敗しました")
encoded_image = base64.b64encode(bytes_data).decode('utf-8')
print("ControlNet用画像のエンコード完了")
# シードの処理 (-1 または数値)
try:
seed = int(seed_str) if seed_str else -1
except ValueError:
seed = -1 # 無効な文字列の場合は -1 (ランダム)
# --- ペイロード作成 ---
payload = {
"prompt": generated_prompt_text, # 共通プロンプトを使用
"negative_prompt": illustration_neg_prompt,
"steps": 20, # 固定値。必要ならUIに追加
"batch_size": batch_size, # ペイロードに batch_size を追加
"width": target_w,
"height": target_h,
"sampler_name": sampler_name,
"cfg_scale": cfg_scale,
"seed": seed,
"override_settings": {
"sd_model_checkpoint": sd_model_name,
},
"alwayson_scripts": {
"controlnet": {
"args": [
{
"enabled": True,
"image": encoded_image,
"module": controlnet_module,
"model": cn_model_name,
"weight": cn_weight,
"resize_mode": "Resize and Fill", # "Just Resize" も選択肢
"guidance_start": cn_guidance_start,
"guidance_end": cn_guidance_end,
"pixel_perfect": cn_pixel_perfect,
"control_mode": cn_control_mode,
}
]
}
}
}
# ペイロードログ(画像データ省略)
payload_log = {k: v for k, v in payload.items() if k != 'alwayson_scripts'}
if 'alwayson_scripts' in payload and 'controlnet' in payload['alwayson_scripts']:
payload_log['alwayson_scripts'] = {'controlnet': {'args': []}}
for arg in payload['alwayson_scripts']['controlnet']['args']:
arg_log = {k: v for k, v in arg.items() if k != 'image'}
payload_log['alwayson_scripts']['controlnet']['args'].append(arg_log)
print(f"ペイロード (画像データ省略): {json.dumps(payload_log, indent=2, ensure_ascii=False)}")
# --- API 呼び出し ---
if "runpod" in sd_api_url:
input_rp = {
"input": {
"path": "/sdapi/v1/txt2img",
"method": "POST",
"payload": payload
}
}
headers = {
"Authorization": f"Bearer {os.getenv('RUNPOD_API_KEY')}"
}
response = requests.post(url=sd_api_url, json=input_rp, headers=headers, timeout=600)
response.raise_for_status()
if not "output" in response.json():
raise Exception(f"API呼び出しに失敗しました。data: {response.json()}")
r = response.json()['output']['body']
else:
txt2img_endpoint = f'{sd_api_url.strip("/")}/sdapi/v1/txt2img'
png_info_endpoint = f'{sd_api_url.strip("/")}/sdapi/v1/png-info'
print(f"API呼び出し中 (POST): {txt2img_endpoint}")
print(f"使用するSDモデル (override): {sd_model_name}")
response = requests.post(url=txt2img_endpoint, json=payload, timeout=600) # タイムアウト10分
response.raise_for_status()
print("API呼び出し成功!")
r = response.json()
# --- 結果の処理 ---
output_images = []
if 'images' in r and r['images']:
print(f"{len(r['images'])} 枚の画像を処理中...")
for i, img_base64 in enumerate(r['images']):
try:
if not img_base64:
print(f"警告: 画像 {i+1} のデータが空です。スキップします。")
continue
# Base64デコード
if "," in img_base64:
img_data_actual = base64.b64decode(img_base64.split(',', 1)[1])
else:
img_data_actual = base64.b64decode(img_base64)
image = Image.open(io.BytesIO(img_data_actual))
# # PNG Info の取得と埋め込み (ベストエフォート)
# pnginfo_data = None
# try:
# png_payload = {"image": "data:image/png;base64," + (img_base64.split(',', 1)[1] if ',' in img_base64 else img_base64)}
# response2 = requests.post(url=png_info_endpoint, json=png_payload, timeout=10)
# response2.raise_for_status()
# info_text = response2.json().get("info", "")
# if info_text:
# pnginfo_data = PngImagePlugin.PngInfo()
# pnginfo_data.add_text("parameters", info_text)
# print(f"画像 {i+1} の PNG Info 取得成功")
# else:
# print(f"画像 {i+1} の PNG Info は空でした。")
# except requests.exceptions.RequestException as png_req_err:
# # PNG Info取得失敗は警告にとどめる
# print(f"警告: 画像 {i+1} の PNG Info 取得リクエスト失敗 ({png_req_err})。")
# except Exception as png_err:
# print(f"警告: 画像 {i+1} の PNG Info 処理中にエラー ({png_err})。")
# Gradio Gallery 用に PIL Image をリストに追加
output_images.append(image.copy())
except base64.binascii.Error as b64_err:
print(f"エラー: 画像 {i+1} のBase64デコードに失敗しました ({b64_err})。レスポンスを確認してください。")
continue # 次の画像の処理へ
except Exception as img_proc_err:
print(f"エラー: 画像 {i+1} の処理中にエラーが発生しました: {img_proc_err}")
continue # 次の画像の処理へ
if output_images:
print(f"{len(output_images)}枚の画像をGradio Gallery用に準備完了。")
return output_images, f"{len(output_images)-1}枚の画像を生成しました。"
else:
# 画像データはあったが、すべて処理に失敗した場合
return [], "エラー: 画像は受信しましたが、処理中に問題が発生しました。ログを確認してください。"
else:
error_info = r.get('info') or r.get('detail') or r.get('error') or r.get('message') or json.dumps(r)
print(f"エラー: レスポンスに画像が含まれていませんでした。詳細: {error_info}")
# APIからのエラーメッセージをユーザーに返す
return [], f"エラー: サーバーから画像が返されませんでした。詳細: {error_info}"
except requests.exceptions.Timeout:
print("エラー: API呼び出しがタイムアウトしました。")
return [], "エラー: API呼び出しがタイムアウトしました (10分)。サーバーの負荷が高いか、設定を確認してください。"
except requests.exceptions.RequestException as e:
error_msg = f"リクエストエラー: {e}"
status_code_msg = ""
details_msg = ""
if e.response is not None:
status_code_msg = f" ステータスコード: {e.response.status_code}"
try:
error_details = e.response.json()
details_msg = f" 詳細: {json.dumps(error_details, indent=2, ensure_ascii=False)}"
except json.JSONDecodeError:
details_msg = f" 応答(Text): {e.response.text[:500]}"
full_error_msg = f"{error_msg}{status_code_msg}{details_msg}"
print(full_error_msg)
# Gradioには要約したメッセージを表示
return [], f"APIリクエストエラーが発生しました。{status_code_msg}"
except ImportError as imp_err:
missing_lib = str(imp_err).split("'")[1]
print(f"エラー:必要なライブラリ ({missing_lib}) が不足しています。pip install {missing_lib} を試してください。")
return [], f"エラー:必要なライブラリ ({missing_lib}) が不足しています。"
except Exception as e:
import traceback
print(f"予期せぬエラーが発生しました: {e}")
traceback.print_exc()
return [], f"予期せぬエラーが発生しました: {type(e).__name__} - {e}"
# Gradio UIの構築
def create_ui():
# パスワード認証用の状態
authenticated = gr.State(False)
# 正しいパスワード (環境変数から取得)
correct_password = os.getenv("USER_PASSWORD", "testpassword") # デフォルトパスワードを設定
with gr.Blocks() as app:
# --- パスワード認証UI ---
with gr.Row(visible=True) as auth_block:
with gr.Column():
gr.Markdown("# 認証が必要です")
password_input = gr.Textbox(label="パスワード", type="password", placeholder="パスワードを入力してください")
login_button = gr.Button("ログイン")
auth_status = gr.Markdown(visible=False)
# --- メインUI (初期状態では非表示) ---
with gr.Column(visible=False) as main_ui_block:
with gr.Row():
with gr.Column(scale=1):
text_prompt = gr.Textbox(label="テキスト説明", placeholder="画像にしたい内容を自然な文章で説明してください(例: ピンクの髪の猫耳メイド)", lines=3)
generate_prompt_btn = gr.Button("プロンプト生成", variant="secondary")
generated_prompt = gr.Textbox(label="プロンプト", value="original, 1girl, solo, pink hair, cat ears, animal ears, smile, masterpiece, high score, great score, absurdres", placeholder="生成されたプロンプトがここに表示されます...", lines=3)
with gr.Column(scale=3):
# タブ1 ポーズ生成
with gr.Tab("ポーズ生成"):
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="nsfw, sensitive, from behind, lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, missing arms, extra arms, missing legs, extra legs, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
lines=2
)
with gr.Row():
guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.1, label="Guidance Scale")
num_inference_steps = gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Steps")
with gr.Row():
width = gr.Slider(minimum=512, maximum=2048, value=832, step=64, label="Width")
height = gr.Slider(minimum=512, maximum=2048, value=1216, step=64, label="Height")
num_images = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Number of Images")
with gr.Row():
is_random_seeds = gr.Checkbox(label="Generate Random Seeds", value=True)
seeds = gr.Textbox(label="Seeds (Random if empty)", placeholder="Enter seed values. For multiple seeds, separate with commas.")
with gr.Accordion("Bodyline Settings", open=False):
bodyline_prompt = gr.Textbox(
label="Bodyline Prompt",
value="anime pose, girl, (white background:1.5), (monochrome:1.5), full body, sketch, eyes, breasts, (slim legs, skinny legs:1.2)",
lines=2,
visible=False
)
bodyline_negative_prompt = gr.Textbox(
label="Bodyline Negative Prompt",
value="(wings:1.6), (clothes:1.4), (garment:1.4), (lighting:1.4), (gray:1.4), (missing limb:1.4), (extra line:1.4), (extra limb:1.4), (extra arm:1.4), (extra legs:1.4), (hair:1.4), (bangs:1.4), (fringe:1.4), (forelock:1.4), (front hair:1.4), (fill:1.4), (ink pool:1.6)",
lines=2,
visible=False
)
with gr.Row():
bodyline_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=8.0, step=0.1, label="Bodyline Guidance Scale")
bodyline_steps = gr.Slider(minimum=10, maximum=100, value=20, step=1, label="Bodyline Steps")
with gr.Row():
bodyline_input_resolution = gr.Slider(minimum=128, maximum=1024, value=256, step=64, label="Bodyline Input Resolution")
bodyline_output_size = gr.Slider(minimum=512, maximum=2048, value=768, step=64, label="Bodyline Output Size")
with gr.Row():
is_random_bodyline_seeds = gr.Checkbox(label="Generate Random Bodyline Seeds", value=True)
bodyline_seeds = gr.Textbox(label="Bodyline Seeds (Random if empty)", placeholder="Enter seed values. For multiple seeds, separate with commas.")
with gr.Accordion("Tag Template Settings", open=False, visible=False):
with gr.Row():
reload_templates_btn = gr.Button("テンプレートを再読み込み", variant="secondary")
tag_candidate_generation_template = gr.Textbox(
label="Tag Candidate Generation Template",
value=TAG_CANDIDATE_GENERATION_TEMPLATE,
lines=6
)
tag_normalization_template = gr.Textbox(
label="Tag Normalization Template",
value=TAG_NORMALIZATION_TEMPLATE,
lines=6
)
tag_filter_template = gr.Textbox(
label="Tag Filter Template",
value=TAG_FILTER_TEMPLATE,
lines=6
)
tag_weight_template = gr.Textbox(
label="Tag Weighting Template",
value=TAG_WEIGHT_TEMPLATE,
lines=6
)
generate_image_btn = gr.Button("画像生成", variant="primary")
with gr.Row():
output_gallery1 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain")
output_gallery2 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain")
output_gallery3 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain")
output_gallery4 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain")
with gr.Accordion("Status", open=False):
status_text = gr.Textbox(label="Status", interactive=False)
# タブ2 画像生成
with gr.Tab("イラスト生成"):
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="画像", interactive=True, height=300, image_mode="RGBA", type="pil")
examples = gr.Examples(
examples=[
"examples/1.jpg",
"examples/2.jpg",
"examples/3.jpg",
"examples/4.jpg",
"examples/5.jpg"
],
inputs=input_image
)
with gr.Row():
cn_weight = gr.Slider(minimum=0.0, maximum=2.0, value=1.7, step=0.05, label="ControlNet Weight")
cn_guidance_start = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Guidance Start (T)")
cn_guidance_end = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Guidance End (T)")
with gr.Row():
cn_pixel_perfect = gr.Checkbox(label="Pixel Perfect", value=True)
cn_control_mode = gr.Dropdown(label="Control Mode", choices=controlnet_modes, value=controlnet_modes[2] if controlnet_modes else None)
generate_illustration_btn = gr.Button("イラスト生成", variant="primary")
with gr.Accordion("Advanced Settings", open=False):
sd_api_url = gr.Textbox(label="Stable Diffusion API URL", value=os.environ.get("SD_API_URL", "http://127.0.0.1:7860"), visible=False)
# Batch size スライダーを追加
batch_size_slider = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Batch Size")
illustration_neg_prompt = gr.Textbox(
label="ネガティブプロンプト",
value="nsfw, sensitive, from behind, lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, missing arms, extra arms, missing legs, extra legs, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
lines=2
)
with gr.Row():
cn_model_dropdown = gr.Dropdown(label="ControlNet Model", choices=cn_models, value=cn_models[0] if len(cn_models) > 3 else (cn_models[0] if cn_models else None))
controlnet_module_dropdown = gr.Dropdown(label="ControlNet Module (Preprocessor)", choices=controlnet_modules, value="invert" if "invert" in controlnet_modules else (controlnet_modules[0] if controlnet_modules else None))
target_long_side = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label="ターゲット長辺 (Target Long Side)")
sd_model_dropdown = gr.Dropdown(label="SD Model", choices=sd_models, value=sd_models[1] if len(sd_models) > 1 else (sd_models[0] if sd_models else None))
sampler_name = gr.Textbox(label="Sampler Name", value="Euler a")
with gr.Row():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, value=7.0, step=0.5, label="CFG Scale")
seed_input = gr.Textbox(label="Seed", value="-1", placeholder="-1 for random")
illustration_status_text = gr.Textbox(label="Status", interactive=False)
with gr.Column(scale=2):
# 出力ギャラリー (input_imageは上のRowで共有)
illustration_output_gallery = gr.Gallery(label="生成されたイラスト", columns=2, height=1000, object_fit="contain")
output_gallery = [output_gallery1, output_gallery2, output_gallery3, output_gallery4]
# プロンプト生成ボタンのクリックイベント
generate_prompt_btn.click(
fn=generate_prompt_from_text,
inputs=[
text_prompt, tag_candidate_generation_template,
tag_normalization_template, tag_filter_template,
tag_weight_template
],
outputs=[
generated_prompt,
status_text
]
)
# 画像生成ボタンのクリックイベント
generate_image_btn.click(
fn=generate_images_from_prompt,
inputs=[
generated_prompt, negative_prompt,
guidance_scale, num_inference_steps,
width, height, num_images, is_random_seeds, seeds,
bodyline_prompt, bodyline_negative_prompt,
bodyline_steps, bodyline_guidance_scale,
bodyline_input_resolution, bodyline_output_size,
is_random_bodyline_seeds, bodyline_seeds,
],
outputs=[
output_gallery[0],
output_gallery[1],
output_gallery[2],
output_gallery[3],
status_text,
seeds,
bodyline_seeds
]
)
# イラスト生成ボタンのクリックイベント
generate_illustration_btn.click(
fn=generate_illustration,
inputs=[
sd_api_url,
input_image, # 共有のinput_imageを使用
generated_prompt, # illustration_prompt の代わりに generated_prompt を使用
illustration_neg_prompt,
target_long_side,
sd_model_dropdown,
cn_model_dropdown,
controlnet_module_dropdown,
sampler_name,
cfg_scale,
seed_input, # Textboxコンポーネントを渡す
cn_weight,
cn_guidance_start,
cn_guidance_end,
cn_pixel_perfect,
cn_control_mode,
batch_size_slider
],
outputs=[
illustration_output_gallery,
illustration_status_text
]
)
# テンプレート再読み込みボタンのクリックイベント
reload_templates_btn.click(
fn=reload_templates,
inputs=[],
outputs=[
tag_candidate_generation_template,
tag_normalization_template,
tag_filter_template,
tag_weight_template
]
)
# --- 認証ロジック ---
def authenticate_user(password):
if password == correct_password:
return {
auth_block: gr.update(visible=False),
main_ui_block: gr.update(visible=True),
auth_status: gr.update(value="認証成功!メインUIを表示します。", visible=True),
authenticated: True
}
else:
return {
auth_status: gr.update(value="パスワードが間違っています。", visible=True),
authenticated: False
}
login_button.click(
authenticate_user,
inputs=[password_input],
outputs=[auth_block, main_ui_block, auth_status, authenticated]
)
return app
app = create_ui()
user_name = os.getenv("USER_NAME")
# user_password = os.getenv("USER_PASSWORD") # create_ui内で使用するため、ここでは不要
app.launch() # authパラメータを削除
|