Spaces:
Runtime error
Runtime error
import os | |
import json | |
import base64 | |
import requests | |
import gradio as gr | |
from PIL import Image, PngImagePlugin | |
import io | |
import cv2 | |
from datetime import datetime | |
import time | |
import dotenv | |
from pathlib import Path | |
dotenv.load_dotenv() | |
TAG_CANDIDATE_GENERATION_TEMPLATE = os.getenv("TAG_CANDIDATE_GENERATION_TEMPLATE") | |
TAG_NORMALIZATION_TEMPLATE = os.getenv("TAG_NORMALIZATION_TEMPLATE") | |
TAG_FILTER_TEMPLATE = os.getenv("TAG_FILTER_TEMPLATE") | |
TAG_WEIGHT_TEMPLATE = os.getenv("TAG_WEIGHT_TEMPLATE") | |
print(TAG_CANDIDATE_GENERATION_TEMPLATE[0:10]) | |
print(TAG_NORMALIZATION_TEMPLATE[0:10]) | |
print(TAG_FILTER_TEMPLATE[0:10]) | |
print(TAG_WEIGHT_TEMPLATE[0:10]) | |
# --- Model Lists --- | |
sd_models = [ | |
# "animagine-xl-3.1.safetensors [e3c47aedb0]", | |
"animagine-xl-4.0.safetensors", | |
# "animagine-xl.safetensors", | |
# "LoRAMergeModel_animepose_outline_sotai_2025_04.fp16.safetensors [22dc3c53d2]", | |
] | |
cn_models =[ | |
"Bodyline2Image_v2-000350 [15294a3f]", | |
# "Bodyline2Image_v2-000400 [f87a785e]", | |
# "CN-anytest_v3-50000_fp16 [0963df47]", | |
# "CN-anytest_v4-marged [4bb64990]", | |
] | |
controlnet_modules = ["invert", "none"] # Example modules | |
controlnet_modes = ["Balanced", "My prompt is more important", "ControlNet is more important"] # Example modes | |
def save_images(images: list, output_dir: Path, prefix: str = "generated", subfolder: str = None): | |
"""Base64エンコードされた画像リストを保存する""" | |
# タイムスタンプ付きのディレクトリ名を作成 | |
output_dir = output_dir / subfolder | |
output_dir.mkdir(parents=True, exist_ok=True) | |
saved_paths = [] | |
for i, img_base64 in enumerate(images): | |
img_data = base64.b64decode(img_base64) | |
output_path = output_dir / f"{prefix}_{i}.png" | |
with open(output_path, "wb") as f: | |
f.write(img_data) | |
saved_paths.append(output_path) | |
return saved_paths | |
# テキストからプロンプトを生成するAPI呼び出し | |
def generate_prompt_from_text(text_prompt, tag_candidate_generation_template=None, | |
tag_normalization_template=None, tag_filter_template=None, | |
tag_weight_template=None): | |
""" | |
テキスト説明からプロンプトを生成するRunPod API呼び出し関数 | |
Args: | |
text_prompt: テキストプロンプト | |
tag_candidate_generation_template: タグ候補生成テンプレート | |
tag_normalization_template: タグ正規化テンプレート | |
tag_filter_template: タグフィルターテンプレート | |
tag_weight_template: タグ重み付けテンプレート | |
Returns: | |
生成されたプロンプト | |
""" | |
if not text_prompt.strip(): | |
return "", "エラー: テキストプロンプトが空です。プロンプトを生成するためのテキストを入力してください。" | |
api_key = os.environ.get("RUNPOD_API_KEY", "") | |
endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID", "") | |
if not endpoint_id: | |
return "", "エラー: Endpoint IDが設定されていません" | |
if not api_key: | |
return "", "エラー: API KeyまたはEndpoint IDが設定されていません" | |
url = f"https://api.runpod.ai/v2/{endpoint_id}/run" | |
headers = { | |
"Authorization": f"Bearer {api_key}", | |
"Content-Type": "application/json" | |
} | |
payload = { | |
"input": { | |
"mode": "prompt_only", # プロンプト生成のみモード | |
"prompt": text_prompt, | |
"tag_candidate_generation_template": tag_candidate_generation_template, | |
"tag_normalization_template": tag_normalization_template, | |
"tag_filter_template": tag_filter_template, | |
"tag_weight_template": tag_weight_template | |
} | |
} | |
# Noneの値を持つキーを削除 | |
payload["input"] = {k: v for k, v in payload["input"].items() if v is not None} | |
print(f"プロンプト生成APIリクエスト送信: {json.dumps(payload, indent=2, ensure_ascii=False)}") | |
try: | |
# タイムアウトを2分に設定 | |
response = requests.post(url, headers=headers, json=payload, timeout=120) | |
response.raise_for_status() | |
result = response.json() | |
status = result.get("status") | |
if status == "IN_QUEUE" or status == "IN_PROGRESS": | |
task_id = result.get("id") | |
status_url = f"https://api.runpod.ai/v2/{endpoint_id}/status/{task_id}" | |
# セッション作成 | |
session = requests.Session() | |
session.headers.update(headers) | |
# 最大2分間ポーリング | |
for _ in range(24): | |
time.sleep(5) | |
try: | |
status_response = session.get(status_url, timeout=(3.05, 10)) | |
if status_response.status_code != 200: | |
print(f"Unexpected status code: {status_response.status_code}") | |
time.sleep(2) | |
continue | |
status_data = status_response.json() | |
current_status = status_data.get("status") | |
print(f"Current status: {current_status}") | |
if current_status == "COMPLETED": | |
result = status_data | |
break | |
elif current_status in ["FAILED", "CANCELLED"]: | |
raise Exception(f"Task failed with status: {current_status}") | |
except Exception as e: | |
print(f"Error during status check: {e}") | |
time.sleep(2) | |
continue | |
# セッションのクリーンアップ | |
session.close() | |
if result.get("status") != "COMPLETED": | |
raise Exception(f"API処理エラー: {result}") | |
output = result.get("output", {}) | |
generated_tags = output.get("generated_tags", []) | |
if not generated_tags: | |
return "", "エラー: タグを生成できませんでした。別のテキストで試してください。" | |
# タグをカンマ区切りの文字列に変換 | |
prompt_string = ", ".join(generated_tags) | |
return prompt_string, None | |
except Exception as e: | |
return "", f"エラーが発生しました: {str(e)}" | |
# RunPod APIを呼び出す関数を変更 | |
def generate_images_from_prompt(generated_prompt, negative_prompt="", | |
guidance_scale=7.0, num_inference_steps=30, | |
width=512, height=768, num_images=1, is_random_seeds=True, seeds=None, | |
bodyline_prompt=None, bodyline_negative_prompt=None, | |
bodyline_steps=20, bodyline_guidance_scale=8.0, | |
bodyline_input_resolution=256, bodyline_output_size=768, | |
is_random_bodyline_seeds=True, bodyline_seeds=None): | |
""" | |
生成されたプロンプトから画像を生成する関数 | |
Args: | |
generated_prompt: 生成済みのプロンプト(カンマ区切りのタグ) | |
negative_prompt: ネガティブプロンプト | |
guidance_scale: ガイダンススケール | |
num_inference_steps: 推論ステップ数 | |
width: 画像の幅 | |
height: 画像の高さ | |
num_images: 生成する画像の枚数 | |
seeds: 乱数シード | |
bodyline_prompt: ボディライン生成用プロンプト | |
bodyline_negative_prompt: ボディライン生成用ネガティブプロンプト | |
bodyline_steps: ボディライン生成の推論ステップ数 | |
bodyline_guidance_scale: ボディライン生成のガイダンススケール | |
bodyline_input_resolution: ボディライン生成の入力解像度 | |
bodyline_output_size: ボディライン生成の出力サイズ | |
is_random_bodyline_seeds: ボディライン生成用シードをランダムにするかどうか | |
bodyline_seeds: ボディライン生成用シード | |
api_key: RunPod API Key | |
endpoint_id: RunPod Endpoint ID | |
Returns: | |
生成された画像のリスト | |
""" | |
if not generated_prompt.strip(): | |
return *[None]*4, "エラー: 生成プロンプトが空です。先にプロンプトを生成してください。", seeds, bodyline_seeds | |
api_key = os.environ.get("RUNPOD_API_KEY", "") | |
endpoint_id = os.environ.get("RUNPOD_ENDPOINT_ID", "") | |
if not endpoint_id: | |
return *[None]*4, "エラー: Endpoint IDが設定されていません", seeds, bodyline_seeds | |
if not api_key: | |
return *[None]*4, "エラー: API KeyまたはEndpoint IDが設定されていません", seeds, bodyline_seeds | |
if is_random_seeds: | |
seeds = None | |
else: | |
seeds = seeds.split(",") | |
seeds = [int(seed) for seed in seeds] | |
if len(seeds) < num_images: | |
# 最後のseedを繰り返し使用 | |
seeds = seeds + [seeds[-1]] * (num_images - len(seeds)) | |
if is_random_bodyline_seeds: | |
bodyline_seeds = None | |
else: | |
bodyline_seeds = bodyline_seeds.split(",") | |
bodyline_seeds = [int(seed) for seed in bodyline_seeds] | |
if len(bodyline_seeds) < num_images: | |
# 最後のseedを繰り返し使用 | |
bodyline_seeds = bodyline_seeds + [bodyline_seeds[-1]] * (num_images - len(bodyline_seeds)) | |
url = f"https://api.runpod.ai/v2/{endpoint_id}/run" | |
headers = { | |
"Authorization": f"Bearer {api_key}", | |
"Content-Type": "application/json" | |
} | |
# タグリストの作成 | |
tags = [tag.strip() for tag in generated_prompt.split(",")] | |
payload = { | |
"input": { | |
"mode": "image_only", # 画像生成のみモード | |
"tags": tags, | |
"negative_prompt": negative_prompt, | |
"steps": num_inference_steps, | |
"guidance_scale": guidance_scale, | |
"width": width, | |
"height": height, | |
"num_images": num_images, | |
"seeds": seeds, | |
"bodyline_prompt": bodyline_prompt, | |
"bodyline_negative_prompt": bodyline_negative_prompt, | |
"bodyline_steps": bodyline_steps, | |
"bodyline_guidance_scale": bodyline_guidance_scale, | |
"bodyline_input_resolution": bodyline_input_resolution, | |
"bodyline_output_size": bodyline_output_size, | |
"bodyline_seeds": bodyline_seeds | |
} | |
} | |
# Noneの値を持つキーを削除 | |
payload["input"] = {k: v for k, v in payload["input"].items() if v is not None} | |
print(f"画像生成APIリクエスト送信: {json.dumps(payload, indent=2, ensure_ascii=False)}") | |
try: | |
# タイムアウトを10分に設定 | |
response = requests.post(url, headers=headers, json=payload, timeout=600) | |
response.raise_for_status() | |
result = response.json() | |
status = result.get("status") | |
if status == "IN_QUEUE" or status == "IN_PROGRESS": | |
task_id = result.get("id") | |
status_url = f"https://api.runpod.ai/v2/{endpoint_id}/status/{task_id}" | |
# セッション作成 | |
session = requests.Session() | |
session.headers.update(headers) | |
# 最大5分間ポーリング | |
for _ in range(60): | |
time.sleep(5) | |
try: | |
# タイムアウトを細かく設定 | |
status_response = session.get( | |
status_url, | |
timeout=(3.05, 10) # (接続タイムアウト, 読み込みタイムアウト) | |
) | |
if status_response.status_code != 200: | |
print(f"Unexpected status code: {status_response.status_code}") | |
time.sleep(2) | |
continue | |
status_data = status_response.json() | |
current_status = status_data.get("status") | |
print(f"Current status: {current_status}") | |
if current_status == "COMPLETED": | |
result = status_data | |
break | |
elif current_status in ["FAILED", "CANCELLED"]: | |
raise Exception(f"Task failed with status: {current_status}") | |
except requests.exceptions.Timeout: | |
print("Request timed out, retrying...") | |
time.sleep(2) | |
continue | |
except requests.exceptions.RequestException as e: | |
print(f"Request error: {e}") | |
time.sleep(2) | |
continue | |
except json.JSONDecodeError: | |
print("Invalid JSON response, retrying...") | |
time.sleep(2) | |
continue | |
# セッションのクリーンアップ | |
session.close() | |
if result.get("status") != "COMPLETED": | |
raise Exception(f"API処理エラー: {result}") | |
output = result.get("output", {}) | |
images = [None]*num_images | |
# 画像の保存処理 | |
if "images" in output: | |
timestamp = time.strftime("%Y%m%d_%H%M%S") | |
output_dir = Path("output") | |
# 生成画像の保存 | |
save_images( | |
output["images"], | |
output_dir, | |
prefix="generated", | |
subfolder=timestamp | |
) | |
# ボディラインの保存 | |
save_images( | |
output["bodylines"], | |
output_dir, | |
prefix="bodyline", | |
subfolder=timestamp | |
) | |
# seedsの保存 | |
seeds = output["parameters"]["image_parameters"]["seeds"] | |
# メタデータの保存 | |
metadata = { | |
"used_tags": output.get("used_tags", []), | |
"parameters": output.get("parameters", {}) | |
} | |
metadata_path = output_dir / timestamp / "metadata.json" | |
with open(metadata_path, "w", encoding="utf-8") as f: | |
json.dump(metadata, f, ensure_ascii=False, indent=2) | |
print(f"Images and metadata saved to: {output_dir / timestamp}") | |
# Gradio用の画像変換 | |
for i, (img_b64, bodyline_b64, diff_b64) in enumerate(zip(output["images"], output["bodylines"], output["remove_bg_diff"]["diff"])): | |
img_data = base64.b64decode(img_b64.split(",")[1] if "," in img_b64 else img_b64) | |
img = Image.open(io.BytesIO(img_data)) | |
bodyline_data = base64.b64decode(bodyline_b64.split(",")[1] if "," in bodyline_b64 else bodyline_b64) | |
bodyline = Image.open(io.BytesIO(bodyline_data)) | |
diff_data = base64.b64decode(diff_b64.split(",")[1] if "," in diff_b64 else diff_b64) | |
diff = Image.open(io.BytesIO(diff_data)) | |
images[i] = [bodyline, img, diff] | |
print(output["remove_bg_diff"]["white_percentage"]) | |
print(output["remove_bg_diff"]["white_pixels_mask2"]) | |
print(output["remove_bg_diff"]["white_pixels_diff"]) | |
return ( | |
*images, | |
json.dumps(output["parameters"], indent=2, ensure_ascii=False), | |
', '.join(str(seed) for seed in output["parameters"]["image_parameters"]["seeds"]), | |
', '.join(str(seed) for seed in output["parameters"]["bodyline_parameters"]["seeds"]) | |
) | |
except Exception as e: | |
return *[None]*4, f"エラーが発生しました: {str(e)}", seeds, bodyline_seeds | |
# テンプレートを再読み込みする関数を追加 | |
def reload_templates(): | |
"""テンプレートモジュールを再読み込みする""" | |
import importlib | |
import templates | |
importlib.reload(templates) | |
from templates import ( | |
TAG_CANDIDATE_GENERATION_TEMPLATE, | |
TAG_NORMALIZATION_TEMPLATE, | |
TAG_FILTER_TEMPLATE, | |
TAG_WEIGHT_TEMPLATE | |
) | |
return ( | |
TAG_CANDIDATE_GENERATION_TEMPLATE, | |
TAG_NORMALIZATION_TEMPLATE, | |
TAG_FILTER_TEMPLATE, | |
TAG_WEIGHT_TEMPLATE | |
) | |
# --- イラスト生成用関数 --- | |
def generate_illustration( | |
sd_api_url, input_image, illustration_prompt, illustration_neg_prompt, | |
target_long_side, sd_model_name, cn_model_name, controlnet_module, | |
sampler_name, cfg_scale, seed_str, cn_weight, | |
cn_guidance_start, cn_guidance_end, cn_pixel_perfect, cn_control_mode, | |
batch_size | |
): | |
"""ControlNetを使用してイラストを生成する関数""" | |
# 必要なライブラリをインポート | |
import numpy as np | |
import cv2 | |
import base64 | |
import requests | |
import json | |
from PIL import Image, PngImagePlugin | |
import io | |
# 引数名を generated_prompt に合わせる | |
generated_prompt_text = illustration_prompt | |
if not sd_api_url or not sd_api_url.startswith("http"): | |
return [], "エラー: Stable Diffusion APIのURLが無効です。" | |
if input_image is None: | |
# input_image はPIL ImageまたはNumpy Arrayを想定 | |
return [], "エラー: 入力画像がありません。" | |
if not generated_prompt_text or not generated_prompt_text.strip(): | |
return [], "エラー: プロンプトが空です。" | |
try: | |
# 入力画像の形式を確認し、PIL Imageに統一 | |
print(f"入力画像の型: {type(input_image)}") | |
if isinstance(input_image, np.ndarray): | |
print("Numpy ArrayをPIL Imageに変換します。") | |
# 配列の形状を確認してモードを決定 (高さ, 幅, チャンネル数) | |
if input_image.shape[2] == 4: | |
print(" 入力配列は RGBA です。") | |
img_pil = Image.fromarray(input_image, mode='RGBA') | |
elif input_image.shape[2] == 3: | |
print(" 入力配列は RGB です。") | |
img_pil = Image.fromarray(input_image, mode='RGB') | |
else: | |
return [], f"エラー: サポートされていないNumpy配列のチャンネル数です ({input_image.shape[2]})。" | |
elif isinstance(input_image, Image.Image): | |
# すでにPIL Imageの場合 | |
print("PIL Imageのままです。") | |
img_pil = input_image | |
else: | |
# 画像が読み込まれていないか、予期しない形式の場合 | |
return [], f"エラー: サポートされていない入力画像形式です ({type(input_image)})。画像をアップロードしてください。" | |
# PIL Image のモードを確認し、必要ならRGBに変換(白背景合成含む) | |
print(f"入力画像のモード: {img_pil.mode}") | |
if img_pil.mode == 'RGBA': | |
print("入力画像はRGBAです。白背景に合成します。") | |
# 白い背景を作成 | |
background = Image.new("RGB", img_pil.size, (255, 255, 255)) | |
background.paste(img_pil, mask=img_pil.split()[3]) | |
img_pil = background # RGB画像に置き換え | |
img_pil = img_pil.convert('RGB') # ensure RGB mode | |
elif img_pil.mode != 'RGB': | |
# その他のモード(例: P, L)の場合もRGBに変換 | |
print(f"入力画像をRGBに変換します (元のモード: {img_pil.mode})。") | |
img_pil = img_pil.convert('RGB') | |
# OpenCV形式 (BGR) に変換 | |
img_cv = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR) | |
# --- 画像サイズに基づいて出力サイズを計算 --- | |
h, w = img_cv.shape[:2] | |
if h == 0 or w == 0: | |
return [], "エラー: 入力画像のサイズが無効です。" | |
print(f"入力画像のサイズ: {w}x{h}") | |
if w >= h: | |
target_w = target_long_side | |
target_h = int(target_long_side * h / w) | |
else: | |
target_h = target_long_side | |
target_w = int(target_long_side * w / h) | |
# 64 の倍数に丸める | |
target_w = max(64, round(target_w / 64) * 64) | |
target_h = max(64, round(target_h / 64) * 64) | |
print(f"計算された出力サイズ (64の倍数に丸め): {target_w}x{target_h}") | |
# リサイズ (エラーチェック追加) | |
try: | |
img_resized = cv2.resize(img_cv, (target_w, target_h), interpolation=cv2.INTER_AREA) | |
except cv2.error as resize_err: | |
return [], f"エラー: 画像のリサイズに失敗しました ({resize_err})。入力画像を確認してください。" | |
# 画像のエンコード | |
retval, bytes_data = cv2.imencode('.png', img_resized) | |
if not retval: | |
raise RuntimeError("画像のPNGエンコードに失敗しました") | |
encoded_image = base64.b64encode(bytes_data).decode('utf-8') | |
print("ControlNet用画像のエンコード完了") | |
# シードの処理 (-1 または数値) | |
try: | |
seed = int(seed_str) if seed_str else -1 | |
except ValueError: | |
seed = -1 # 無効な文字列の場合は -1 (ランダム) | |
# --- ペイロード作成 --- | |
payload = { | |
"prompt": generated_prompt_text, # 共通プロンプトを使用 | |
"negative_prompt": illustration_neg_prompt, | |
"steps": 20, # 固定値。必要ならUIに追加 | |
"batch_size": batch_size, # ペイロードに batch_size を追加 | |
"width": target_w, | |
"height": target_h, | |
"sampler_name": sampler_name, | |
"cfg_scale": cfg_scale, | |
"seed": seed, | |
"override_settings": { | |
"sd_model_checkpoint": sd_model_name, | |
}, | |
"alwayson_scripts": { | |
"controlnet": { | |
"args": [ | |
{ | |
"enabled": True, | |
"image": encoded_image, | |
"module": controlnet_module, | |
"model": cn_model_name, | |
"weight": cn_weight, | |
"resize_mode": "Resize and Fill", # "Just Resize" も選択肢 | |
"guidance_start": cn_guidance_start, | |
"guidance_end": cn_guidance_end, | |
"pixel_perfect": cn_pixel_perfect, | |
"control_mode": cn_control_mode, | |
} | |
] | |
} | |
} | |
} | |
# ペイロードログ(画像データ省略) | |
payload_log = {k: v for k, v in payload.items() if k != 'alwayson_scripts'} | |
if 'alwayson_scripts' in payload and 'controlnet' in payload['alwayson_scripts']: | |
payload_log['alwayson_scripts'] = {'controlnet': {'args': []}} | |
for arg in payload['alwayson_scripts']['controlnet']['args']: | |
arg_log = {k: v for k, v in arg.items() if k != 'image'} | |
payload_log['alwayson_scripts']['controlnet']['args'].append(arg_log) | |
print(f"ペイロード (画像データ省略): {json.dumps(payload_log, indent=2, ensure_ascii=False)}") | |
# --- API 呼び出し --- | |
if "runpod" in sd_api_url: | |
input_rp = { | |
"input": { | |
"path": "/sdapi/v1/txt2img", | |
"method": "POST", | |
"payload": payload | |
} | |
} | |
headers = { | |
"Authorization": f"Bearer {os.getenv('RUNPOD_API_KEY')}" | |
} | |
response = requests.post(url=sd_api_url, json=input_rp, headers=headers, timeout=600) | |
response.raise_for_status() | |
if not "output" in response.json(): | |
raise Exception(f"API呼び出しに失敗しました。data: {response.json()}") | |
r = response.json()['output']['body'] | |
else: | |
txt2img_endpoint = f'{sd_api_url.strip("/")}/sdapi/v1/txt2img' | |
png_info_endpoint = f'{sd_api_url.strip("/")}/sdapi/v1/png-info' | |
print(f"API呼び出し中 (POST): {txt2img_endpoint}") | |
print(f"使用するSDモデル (override): {sd_model_name}") | |
response = requests.post(url=txt2img_endpoint, json=payload, timeout=600) # タイムアウト10分 | |
response.raise_for_status() | |
print("API呼び出し成功!") | |
r = response.json() | |
# --- 結果の処理 --- | |
output_images = [] | |
if 'images' in r and r['images']: | |
print(f"{len(r['images'])} 枚の画像を処理中...") | |
for i, img_base64 in enumerate(r['images']): | |
try: | |
if not img_base64: | |
print(f"警告: 画像 {i+1} のデータが空です。スキップします。") | |
continue | |
# Base64デコード | |
if "," in img_base64: | |
img_data_actual = base64.b64decode(img_base64.split(',', 1)[1]) | |
else: | |
img_data_actual = base64.b64decode(img_base64) | |
image = Image.open(io.BytesIO(img_data_actual)) | |
# # PNG Info の取得と埋め込み (ベストエフォート) | |
# pnginfo_data = None | |
# try: | |
# png_payload = {"image": "data:image/png;base64," + (img_base64.split(',', 1)[1] if ',' in img_base64 else img_base64)} | |
# response2 = requests.post(url=png_info_endpoint, json=png_payload, timeout=10) | |
# response2.raise_for_status() | |
# info_text = response2.json().get("info", "") | |
# if info_text: | |
# pnginfo_data = PngImagePlugin.PngInfo() | |
# pnginfo_data.add_text("parameters", info_text) | |
# print(f"画像 {i+1} の PNG Info 取得成功") | |
# else: | |
# print(f"画像 {i+1} の PNG Info は空でした。") | |
# except requests.exceptions.RequestException as png_req_err: | |
# # PNG Info取得失敗は警告にとどめる | |
# print(f"警告: 画像 {i+1} の PNG Info 取得リクエスト失敗 ({png_req_err})。") | |
# except Exception as png_err: | |
# print(f"警告: 画像 {i+1} の PNG Info 処理中にエラー ({png_err})。") | |
# Gradio Gallery 用に PIL Image をリストに追加 | |
output_images.append(image.copy()) | |
except base64.binascii.Error as b64_err: | |
print(f"エラー: 画像 {i+1} のBase64デコードに失敗しました ({b64_err})。レスポンスを確認してください。") | |
continue # 次の画像の処理へ | |
except Exception as img_proc_err: | |
print(f"エラー: 画像 {i+1} の処理中にエラーが発生しました: {img_proc_err}") | |
continue # 次の画像の処理へ | |
if output_images: | |
print(f"{len(output_images)}枚の画像をGradio Gallery用に準備完了。") | |
return output_images, f"{len(output_images)-1}枚の画像を生成しました。" | |
else: | |
# 画像データはあったが、すべて処理に失敗した場合 | |
return [], "エラー: 画像は受信しましたが、処理中に問題が発生しました。ログを確認してください。" | |
else: | |
error_info = r.get('info') or r.get('detail') or r.get('error') or r.get('message') or json.dumps(r) | |
print(f"エラー: レスポンスに画像が含まれていませんでした。詳細: {error_info}") | |
# APIからのエラーメッセージをユーザーに返す | |
return [], f"エラー: サーバーから画像が返されませんでした。詳細: {error_info}" | |
except requests.exceptions.Timeout: | |
print("エラー: API呼び出しがタイムアウトしました。") | |
return [], "エラー: API呼び出しがタイムアウトしました (10分)。サーバーの負荷が高いか、設定を確認してください。" | |
except requests.exceptions.RequestException as e: | |
error_msg = f"リクエストエラー: {e}" | |
status_code_msg = "" | |
details_msg = "" | |
if e.response is not None: | |
status_code_msg = f" ステータスコード: {e.response.status_code}" | |
try: | |
error_details = e.response.json() | |
details_msg = f" 詳細: {json.dumps(error_details, indent=2, ensure_ascii=False)}" | |
except json.JSONDecodeError: | |
details_msg = f" 応答(Text): {e.response.text[:500]}" | |
full_error_msg = f"{error_msg}{status_code_msg}{details_msg}" | |
print(full_error_msg) | |
# Gradioには要約したメッセージを表示 | |
return [], f"APIリクエストエラーが発生しました。{status_code_msg}" | |
except ImportError as imp_err: | |
missing_lib = str(imp_err).split("'")[1] | |
print(f"エラー:必要なライブラリ ({missing_lib}) が不足しています。pip install {missing_lib} を試してください。") | |
return [], f"エラー:必要なライブラリ ({missing_lib}) が不足しています。" | |
except Exception as e: | |
import traceback | |
print(f"予期せぬエラーが発生しました: {e}") | |
traceback.print_exc() | |
return [], f"予期せぬエラーが発生しました: {type(e).__name__} - {e}" | |
# Gradio UIの構築 | |
def create_ui(): | |
# パスワード認証用の状態 | |
authenticated = gr.State(False) | |
# 正しいパスワード (環境変数から取得) | |
correct_password = os.getenv("USER_PASSWORD", "testpassword") # デフォルトパスワードを設定 | |
with gr.Blocks() as app: | |
# --- パスワード認証UI --- | |
with gr.Row(visible=True) as auth_block: | |
with gr.Column(): | |
gr.Markdown("# 認証が必要です") | |
password_input = gr.Textbox(label="パスワード", type="password", placeholder="パスワードを入力してください") | |
login_button = gr.Button("ログイン") | |
auth_status = gr.Markdown(visible=False) | |
# --- メインUI (初期状態では非表示) --- | |
with gr.Column(visible=False) as main_ui_block: | |
with gr.Row(): | |
with gr.Column(scale=1): | |
text_prompt = gr.Textbox(label="テキスト説明", placeholder="画像にしたい内容を自然な文章で説明してください(例: ピンクの髪の猫耳メイド)", lines=3) | |
generate_prompt_btn = gr.Button("プロンプト生成", variant="secondary") | |
generated_prompt = gr.Textbox(label="プロンプト", value="original, 1girl, solo, pink hair, cat ears, animal ears, smile, masterpiece, high score, great score, absurdres", placeholder="生成されたプロンプトがここに表示されます...", lines=3) | |
with gr.Column(scale=3): | |
# タブ1 ポーズ生成 | |
with gr.Tab("ポーズ生成"): | |
with gr.Accordion("Advanced Settings", open=False): | |
with gr.Row(): | |
negative_prompt = gr.Textbox( | |
label="Negative Prompt", | |
value="nsfw, sensitive, from behind, lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, missing arms, extra arms, missing legs, extra legs, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry", | |
lines=2 | |
) | |
with gr.Row(): | |
guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.1, label="Guidance Scale") | |
num_inference_steps = gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Steps") | |
with gr.Row(): | |
width = gr.Slider(minimum=512, maximum=2048, value=832, step=64, label="Width") | |
height = gr.Slider(minimum=512, maximum=2048, value=1216, step=64, label="Height") | |
num_images = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Number of Images") | |
with gr.Row(): | |
is_random_seeds = gr.Checkbox(label="Generate Random Seeds", value=True) | |
seeds = gr.Textbox(label="Seeds (Random if empty)", placeholder="Enter seed values. For multiple seeds, separate with commas.") | |
with gr.Accordion("Bodyline Settings", open=False): | |
bodyline_prompt = gr.Textbox( | |
label="Bodyline Prompt", | |
value="anime pose, girl, (white background:1.5), (monochrome:1.5), full body, sketch, eyes, breasts, (slim legs, skinny legs:1.2)", | |
lines=2, | |
visible=False | |
) | |
bodyline_negative_prompt = gr.Textbox( | |
label="Bodyline Negative Prompt", | |
value="(wings:1.6), (clothes:1.4), (garment:1.4), (lighting:1.4), (gray:1.4), (missing limb:1.4), (extra line:1.4), (extra limb:1.4), (extra arm:1.4), (extra legs:1.4), (hair:1.4), (bangs:1.4), (fringe:1.4), (forelock:1.4), (front hair:1.4), (fill:1.4), (ink pool:1.6)", | |
lines=2, | |
visible=False | |
) | |
with gr.Row(): | |
bodyline_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=8.0, step=0.1, label="Bodyline Guidance Scale") | |
bodyline_steps = gr.Slider(minimum=10, maximum=100, value=20, step=1, label="Bodyline Steps") | |
with gr.Row(): | |
bodyline_input_resolution = gr.Slider(minimum=128, maximum=1024, value=256, step=64, label="Bodyline Input Resolution") | |
bodyline_output_size = gr.Slider(minimum=512, maximum=2048, value=768, step=64, label="Bodyline Output Size") | |
with gr.Row(): | |
is_random_bodyline_seeds = gr.Checkbox(label="Generate Random Bodyline Seeds", value=True) | |
bodyline_seeds = gr.Textbox(label="Bodyline Seeds (Random if empty)", placeholder="Enter seed values. For multiple seeds, separate with commas.") | |
with gr.Accordion("Tag Template Settings", open=False, visible=False): | |
with gr.Row(): | |
reload_templates_btn = gr.Button("テンプレートを再読み込み", variant="secondary") | |
tag_candidate_generation_template = gr.Textbox( | |
label="Tag Candidate Generation Template", | |
value=TAG_CANDIDATE_GENERATION_TEMPLATE, | |
lines=6 | |
) | |
tag_normalization_template = gr.Textbox( | |
label="Tag Normalization Template", | |
value=TAG_NORMALIZATION_TEMPLATE, | |
lines=6 | |
) | |
tag_filter_template = gr.Textbox( | |
label="Tag Filter Template", | |
value=TAG_FILTER_TEMPLATE, | |
lines=6 | |
) | |
tag_weight_template = gr.Textbox( | |
label="Tag Weighting Template", | |
value=TAG_WEIGHT_TEMPLATE, | |
lines=6 | |
) | |
generate_image_btn = gr.Button("画像生成", variant="primary") | |
with gr.Row(): | |
output_gallery1 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain") | |
output_gallery2 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain") | |
output_gallery3 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain") | |
output_gallery4 = gr.Gallery(label="Generated Results", columns=1, height=600, object_fit="contain") | |
with gr.Accordion("Status", open=False): | |
status_text = gr.Textbox(label="Status", interactive=False) | |
# タブ2 画像生成 | |
with gr.Tab("イラスト生成"): | |
with gr.Row(): | |
with gr.Column(scale=1): | |
input_image = gr.Image(label="画像", interactive=True, height=300, image_mode="RGBA", type="pil") | |
examples = gr.Examples( | |
examples=[ | |
"examples/1.jpg", | |
"examples/2.jpg", | |
"examples/3.jpg", | |
"examples/4.jpg", | |
"examples/5.jpg" | |
], | |
inputs=input_image | |
) | |
with gr.Row(): | |
cn_weight = gr.Slider(minimum=0.0, maximum=2.0, value=1.7, step=0.05, label="ControlNet Weight") | |
cn_guidance_start = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Guidance Start (T)") | |
cn_guidance_end = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Guidance End (T)") | |
with gr.Row(): | |
cn_pixel_perfect = gr.Checkbox(label="Pixel Perfect", value=True) | |
cn_control_mode = gr.Dropdown(label="Control Mode", choices=controlnet_modes, value=controlnet_modes[2] if controlnet_modes else None) | |
generate_illustration_btn = gr.Button("イラスト生成", variant="primary") | |
with gr.Accordion("Advanced Settings", open=False): | |
sd_api_url = gr.Textbox(label="Stable Diffusion API URL", value=os.environ.get("SD_API_URL", "http://127.0.0.1:7860"), visible=False) | |
# Batch size スライダーを追加 | |
batch_size_slider = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Batch Size") | |
illustration_neg_prompt = gr.Textbox( | |
label="ネガティブプロンプト", | |
value="nsfw, sensitive, from behind, lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, missing arms, extra arms, missing legs, extra legs, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry", | |
lines=2 | |
) | |
with gr.Row(): | |
cn_model_dropdown = gr.Dropdown(label="ControlNet Model", choices=cn_models, value=cn_models[0] if len(cn_models) > 3 else (cn_models[0] if cn_models else None)) | |
controlnet_module_dropdown = gr.Dropdown(label="ControlNet Module (Preprocessor)", choices=controlnet_modules, value="invert" if "invert" in controlnet_modules else (controlnet_modules[0] if controlnet_modules else None)) | |
target_long_side = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label="ターゲット長辺 (Target Long Side)") | |
sd_model_dropdown = gr.Dropdown(label="SD Model", choices=sd_models, value=sd_models[1] if len(sd_models) > 1 else (sd_models[0] if sd_models else None)) | |
sampler_name = gr.Textbox(label="Sampler Name", value="Euler a") | |
with gr.Row(): | |
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, value=7.0, step=0.5, label="CFG Scale") | |
seed_input = gr.Textbox(label="Seed", value="-1", placeholder="-1 for random") | |
illustration_status_text = gr.Textbox(label="Status", interactive=False) | |
with gr.Column(scale=2): | |
# 出力ギャラリー (input_imageは上のRowで共有) | |
illustration_output_gallery = gr.Gallery(label="生成されたイラスト", columns=2, height=1000, object_fit="contain") | |
output_gallery = [output_gallery1, output_gallery2, output_gallery3, output_gallery4] | |
# プロンプト生成ボタンのクリックイベント | |
generate_prompt_btn.click( | |
fn=generate_prompt_from_text, | |
inputs=[ | |
text_prompt, tag_candidate_generation_template, | |
tag_normalization_template, tag_filter_template, | |
tag_weight_template | |
], | |
outputs=[ | |
generated_prompt, | |
status_text | |
] | |
) | |
# 画像生成ボタンのクリックイベント | |
generate_image_btn.click( | |
fn=generate_images_from_prompt, | |
inputs=[ | |
generated_prompt, negative_prompt, | |
guidance_scale, num_inference_steps, | |
width, height, num_images, is_random_seeds, seeds, | |
bodyline_prompt, bodyline_negative_prompt, | |
bodyline_steps, bodyline_guidance_scale, | |
bodyline_input_resolution, bodyline_output_size, | |
is_random_bodyline_seeds, bodyline_seeds, | |
], | |
outputs=[ | |
output_gallery[0], | |
output_gallery[1], | |
output_gallery[2], | |
output_gallery[3], | |
status_text, | |
seeds, | |
bodyline_seeds | |
] | |
) | |
# イラスト生成ボタンのクリックイベント | |
generate_illustration_btn.click( | |
fn=generate_illustration, | |
inputs=[ | |
sd_api_url, | |
input_image, # 共有のinput_imageを使用 | |
generated_prompt, # illustration_prompt の代わりに generated_prompt を使用 | |
illustration_neg_prompt, | |
target_long_side, | |
sd_model_dropdown, | |
cn_model_dropdown, | |
controlnet_module_dropdown, | |
sampler_name, | |
cfg_scale, | |
seed_input, # Textboxコンポーネントを渡す | |
cn_weight, | |
cn_guidance_start, | |
cn_guidance_end, | |
cn_pixel_perfect, | |
cn_control_mode, | |
batch_size_slider | |
], | |
outputs=[ | |
illustration_output_gallery, | |
illustration_status_text | |
] | |
) | |
# テンプレート再読み込みボタンのクリックイベント | |
reload_templates_btn.click( | |
fn=reload_templates, | |
inputs=[], | |
outputs=[ | |
tag_candidate_generation_template, | |
tag_normalization_template, | |
tag_filter_template, | |
tag_weight_template | |
] | |
) | |
# --- 認証ロジック --- | |
def authenticate_user(password): | |
if password == correct_password: | |
return { | |
auth_block: gr.update(visible=False), | |
main_ui_block: gr.update(visible=True), | |
auth_status: gr.update(value="認証成功!メインUIを表示します。", visible=True), | |
authenticated: True | |
} | |
else: | |
return { | |
auth_status: gr.update(value="パスワードが間違っています。", visible=True), | |
authenticated: False | |
} | |
login_button.click( | |
authenticate_user, | |
inputs=[password_input], | |
outputs=[auth_block, main_ui_block, auth_status, authenticated] | |
) | |
return app | |
app = create_ui() | |
user_name = os.getenv("USER_NAME") | |
# user_password = os.getenv("USER_PASSWORD") # create_ui内で使用するため、ここでは不要 | |
app.launch() # authパラメータを削除 | |