sbv2_onnx / app.py
aka7774's picture
Update app.py
68f523b verified
import gradio as gr
import subprocess
import os
import sys
import yaml
from pathlib import Path
import time
import threading
import tempfile
import shutil
import gc # ガベージコレクション用
# --- Constants ---
# Directory that contains this script; every other path is derived from it.
SCRIPT_DIR = Path(__file__).parent
# Style-Bert-VITS2 checkout; must match the path the Dockerfile clones into.
SBV2_REPO_PATH = SCRIPT_DIR.joinpath("Style-Bert-VITS2")
# Scratch directory (inside the container) holding files offered for download.
OUTPUT_DIR = SCRIPT_DIR.joinpath("outputs")
# --- Helper functions ---
def add_sbv2_to_path():
    """Prepend the Style-Bert-VITS2 checkout to ``sys.path`` so it is importable.

    If the checkout directory is missing, a warning is printed and nothing is
    changed (no exception is raised). If its resolved path is already on
    ``sys.path``, nothing happens.
    """
    resolved = str(SBV2_REPO_PATH.resolve())
    if not SBV2_REPO_PATH.exists():
        print(f"Warning: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}")
        return
    if resolved in sys.path:
        return
    sys.path.insert(0, resolved)
    print(f"Added {resolved} to sys.path")
def stream_process_output(process, log_list):
    """Append a subprocess's stdout and stderr lines to ``log_list`` as they arrive.

    stderr is drained on a helper thread while stdout is drained on the calling
    thread, so both pipes are emptied concurrently. Reading them one after the
    other can deadlock: the child blocks once the stderr pipe buffer is full,
    while this function is still waiting for stdout to reach EOF.

    stdout lines are appended with trailing whitespace removed (leading
    indentation, e.g. in tracebacks, is kept). stderr lines are prefixed with
    ``"stderr: "`` when blank or when they mention "warning", and with
    ``"ERROR (stderr): "`` otherwise. Lines from the two streams may interleave.

    Args:
        process: ``subprocess.Popen``-like object whose ``stdout``/``stderr``
            are text streams (either may be ``None``).
        log_list: List shared with the caller; lines are appended in place.

    Read errors are not raised; they are appended to ``log_list`` instead.
    Returns once both streams have reached EOF.
    """

    def _drain(stream, format_line):
        # Read one stream to EOF, appending each formatted line.
        try:
            for raw_line in iter(stream.readline, ""):
                log_list.append(format_line(raw_line.rstrip()))
        except Exception as e:
            log_list.append(f"Error reading process stream: {e}")

    def _format_stderr(text):
        # Warnings and blank lines are shown plainly; anything else is highlighted.
        if not text or "warning" in text.lower():
            return f"stderr: {text}"
        return f"ERROR (stderr): {text}"

    stderr_thread = None
    if process.stderr:
        stderr_thread = threading.Thread(
            target=_drain, args=(process.stderr, _format_stderr), daemon=True
        )
        stderr_thread.start()
    if process.stdout:
        _drain(process.stdout, lambda text: text)
    if stderr_thread is not None:
        stderr_thread.join()
# --- Gradio app backend functions ---

# Upper bound (seconds) on how long convert_onnx.py may run before it is killed.
# It is checked inside the log-polling loop: a wait() issued only after the
# output streams reach EOF would start once the process had already exited,
# so a timeout placed there could never fire.
CONVERSION_TIMEOUT_SEC = 20 * 60
# Interval (seconds) between log refreshes pushed to the UI during conversion.
LOG_REFRESH_INTERVAL_SEC = 0.3
# Names the auxiliary files are stored under in the temporary model directory,
# whatever names they were uploaded with.
CONFIG_FILENAME = "config.json"
STYLE_VECTORS_FILENAME = "style_vectors.npy"


def _upload_path(file_obj):
    """Return the local filesystem path of a Gradio upload as a string.

    Accepts either an object exposing the path via ``.name`` or a plain path string.
    """
    return str(getattr(file_obj, "name", file_obj))


def _format_leftover_stderr(stderr_text):
    """Label stderr text left over after streaming and wrap it in a log section.

    Non-blank lines that do not mention "warning" become ``ERROR (stderr): ...``;
    all other lines become ``stderr: ...``. The header says "Errors/Warnings" if
    any line was labelled as an error, otherwise "Warnings". Returns ``[]`` when
    ``stderr_text`` is empty.
    """
    if not stderr_text:
        return []
    labelled = []
    for raw_line in stderr_text.strip().split("\n"):
        text = raw_line.strip()
        if text and "warning" not in raw_line.lower():
            labelled.append(f"ERROR (stderr): {text}")
        else:
            labelled.append(f"stderr: {text}")
    if any(line.startswith("ERROR") for line in labelled):
        return ["--- Errors/Warnings (stderr) ---", *labelled, "-----------------------------"]
    return ["--- Warnings (stderr) ---", *labelled, "------------------------"]


def convert_safetensors_to_onnx_gradio(
    safetensors_file_obj,
    config_file_obj,
    style_vectors_file_obj,
):
    """Convert an uploaded Style-Bert-VITS2 model to ONNX, streaming progress.

    The three uploads are copied into a temporary ``<assets_root>/<model_name>/``
    directory. ``configs/paths.yml`` in the Style-Bert-VITS2 checkout is pointed
    at that temporary root for the duration of the run and restored afterwards.
    Then ``convert_onnx.py --model <copied .safetensors>`` runs as a subprocess,
    limited to ``CONVERSION_TIMEOUT_SEC``. On success the resulting ``.onnx``
    file is copied to ``OUTPUT_DIR`` for download.

    Args:
        safetensors_file_obj: Uploaded ``.safetensors`` weights (object with
            ``.name`` or a path string), or ``None`` if missing. The file stem
            becomes the model directory name.
        config_file_obj: Uploaded config JSON; stored as ``config.json``.
        style_vectors_file_obj: Uploaded style vectors; stored as
            ``style_vectors.npy``.

    Yields:
        ``(log_text, onnx_path_or_None)`` pairs for the log textbox and the
        download component. The path is only set on the final yield of a
        successful conversion.
    """
    log = ["Starting ONNX conversion..."]
    # Nothing to download yet.
    yield "\n".join(log), None

    # --- Validate that every upload is present ---
    required_uploads = (
        (safetensors_file_obj, "❌ Error: Safetensors file is missing. Please upload the .safetensors file."),
        (config_file_obj, "❌ Error: config.json file is missing. Please upload the config.json file."),
        (style_vectors_file_obj, "❌ Error: style_vectors.npy file is missing. Please upload the style_vectors.npy file."),
    )
    for upload, missing_message in required_uploads:
        if upload is None:
            log.append(missing_message)
            yield "\n".join(log), None
            return

    # --- Check the Style-Bert-VITS2 checkout ---
    add_sbv2_to_path()
    if not SBV2_REPO_PATH.exists():
        log.append(f"❌ Error: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}. Check Space build logs.")
        yield "\n".join(log), None
        return

    # --- Validate the model file name; its stem becomes the model directory name ---
    safetensors_filename = Path(_upload_path(safetensors_file_obj)).name
    model_name = Path(safetensors_filename).stem
    if "/" in safetensors_filename or "\\" in safetensors_filename:
        log.append(f"❌ Error: Invalid characters found in filename: {safetensors_filename}")
        yield "\n".join(log), None
        return
    # Using `suffix` rejects names such as ".safetensors" (no stem). The stem
    # check rejects stems that would resolve to the temp dir or its parent.
    if Path(safetensors_filename).suffix.lower() != ".safetensors" or model_name in ("", ".", ".."):
        log.append(f"❌ Error: Invalid safetensors filename: {safetensors_filename}")
        yield "\n".join(log), None
        return

    # --- Output directory for downloadable results ---
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    onnx_output_path_str = None  # path of the downloadable ONNX file, set on success
    process = None  # converter subprocess; killed in `finally` if still running
    paths_config_file = SBV2_REPO_PATH / "configs" / "paths.yml"
    original_paths_config = None  # previous paths.yml text (None if it did not exist)
    paths_config_touched = False  # whether paths.yml must be restored in `finally`
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # The temporary directory itself serves as assets_root.
            assets_root = Path(temp_dir)
            log.append(f"📁 Created temporary directory: {assets_root}")
            yield "\n".join(log), None

            # Directory layout used here: <assets_root>/<model_name>/
            model_dir_in_temp = assets_root / model_name
            model_dir_in_temp.mkdir(exist_ok=True)
            log.append(f" - Created model directory: {model_dir_in_temp.relative_to(assets_root)}")
            yield "\n".join(log), None

            # --- Copy the three uploads into the model directory ---
            copy_plan = (
                # (upload, name to store it under, label used in the rename warning)
                (safetensors_file_obj, safetensors_filename, None),
                (config_file_obj, CONFIG_FILENAME, "JSON"),
                (style_vectors_file_obj, STYLE_VECTORS_FILENAME, "NPY"),
            )
            for upload, target_name, kind_label in copy_plan:
                source_path = _upload_path(upload)
                uploaded_name = Path(source_path).name
                if kind_label is not None and uploaded_name.lower() != target_name:
                    log.append(
                        f"⚠️ Warning: Uploaded {kind_label} file name is '{uploaded_name}', "
                        f"expected '{target_name}'. Saving it as '{target_name}'."
                    )
                try:
                    shutil.copy(source_path, model_dir_in_temp / target_name)
                except OSError as e:
                    log.append(f"❌ Error copying file '{uploaded_name}': {e}")
                    yield "\n".join(log), None
                    return
                log.append(f" - Copied '{uploaded_name}' to model directory.")
                yield "\n".join(log), None

            temp_safetensors_path = model_dir_in_temp / safetensors_filename
            log.append("✅ All required files copied to temporary model directory.")
            log.append(f" - Using temporary assets_root: {assets_root}")
            yield "\n".join(log), None

            # --- Point paths.yml at the temporary assets_root for this run only ---
            paths_config_file.parent.mkdir(parents=True, exist_ok=True)
            if paths_config_file.exists():
                original_paths_config = paths_config_file.read_text(encoding="utf-8")
            resolved_root = str(assets_root.resolve())
            # dataset_root is not used for conversion; set it to the same place anyway.
            paths_config = {"dataset_root": resolved_root, "assets_root": resolved_root}
            # Mark before opening so that even a partial write gets restored.
            paths_config_touched = True
            with open(paths_config_file, "w", encoding="utf-8") as f:
                yaml.dump(paths_config, f)
            log.append(f" - Saved temporary paths config to {paths_config_file}")
            yield "\n".join(log), None

            # --- Run the ONNX conversion script ---
            log.append(f"\n🚀 Starting ONNX conversion script for model '{model_name}'...")
            convert_script = SBV2_REPO_PATH / "convert_onnx.py"
            if not convert_script.exists():
                log.append(f"❌ Error: convert_onnx.py not found at '{convert_script}'. Check repository setup.")
                yield "\n".join(log), None
                return

            command = [
                sys.executable,
                str(convert_script.resolve()),
                "--model",
                str(temp_safetensors_path.resolve()),  # the .safetensors copy in the temp dir
            ]
            log.append(f"\n Running command: {' '.join(command)}")
            yield "\n".join(log), None

            process = subprocess.Popen(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding="utf-8",
                errors="replace",
                cwd=SBV2_REPO_PATH,  # run from the repository root
                env=os.environ.copy(),
            )

            # Shared with the reader thread, which appends output lines as they arrive.
            process_output_lines = ["\n--- Conversion Script Output ---"]
            reader = threading.Thread(
                target=stream_process_output,
                args=(process, process_output_lines),
                daemon=True,
            )
            reader.start()

            # Stream progress to the UI until output ends or the time limit is hit.
            started_at = time.monotonic()
            timed_out = False
            while reader.is_alive():
                if time.monotonic() - started_at > CONVERSION_TIMEOUT_SEC:
                    timed_out = True
                    process.kill()
                    break
                yield "\n".join(log + process_output_lines), None
                time.sleep(LOG_REFRESH_INTERVAL_SEC)

            if timed_out:
                process.wait()
                # Killing the child closes its pipes, so the reader should finish
                # promptly; the bound guards against pipes held open by grandchildren.
                reader.join(timeout=10)
                log.extend(process_output_lines)
                log.append(
                    f"\n❌ Error: Conversion process timed out after "
                    f"{CONVERSION_TIMEOUT_SEC // 60} minutes."
                )
                yield "\n".join(log), None
                return

            reader.join()
            # Both streams are already at EOF; this reaps the process and picks
            # up anything left unread.
            final_stdout, final_stderr = process.communicate()
            if final_stdout:
                process_output_lines.extend(final_stdout.strip().split("\n"))
            process_output_lines.extend(_format_leftover_stderr(final_stderr))

            # Merge all script output into the main log.
            log.extend(process_output_lines)
            log.append("--- End Script Output ---")
            log.append("\n-------------------------------")

            # --- Check the result and copy the ONNX file out of the temp dir ---
            if process.returncode == 0:
                log.append("✅ ONNX conversion command finished successfully.")
                # The ONNX file is expected next to the input .safetensors.
                expected_onnx_path_in_temp = temp_safetensors_path.with_suffix(".onnx")
                if expected_onnx_path_in_temp.exists():
                    log.append(f" - Found converted ONNX file: {expected_onnx_path_in_temp.name}")
                    # Copy to the persistent output dir before the temp dir is deleted.
                    final_onnx_path = OUTPUT_DIR / expected_onnx_path_in_temp.name
                    try:
                        shutil.copy(expected_onnx_path_in_temp, final_onnx_path)
                        log.append(f" - Copied ONNX file for download to: {final_onnx_path.relative_to(SCRIPT_DIR)}")
                        onnx_output_path_str = str(final_onnx_path)
                    except Exception as e:
                        log.append(f"❌ Error copying ONNX file to output directory: {e}")
                else:
                    log.append(f"⚠️ Warning: Expected ONNX file not found at '{expected_onnx_path_in_temp.name}'. Check script output above.")
            else:
                log.append(f"❌ ONNX conversion command failed with return code {process.returncode}.")
                log.append(" Please check the logs above for errors (especially lines starting with 'ERROR').")

            # Final state, yielded while the temporary directory still exists.
            yield "\n".join(log), onnx_output_path_str

    except FileNotFoundError as e:
        log.append(f"\n❌ Error: A required command or file was not found: {e.filename}. Check Dockerfile setup and PATH.")
        log.append(f"{e}")
        yield "\n".join(log), None
    except Exception as e:
        log.append(f"\n❌ An unexpected error occurred: {e}")
        import traceback
        log.append(traceback.format_exc())
        yield "\n".join(log), None
    finally:
        # Never leave the converter running, e.g. when the UI abandons this
        # generator (GeneratorExit bypasses the handlers above).
        if process is not None and process.poll() is None:
            process.kill()
            process.wait()
        # Undo the temporary paths.yml change.
        if paths_config_touched:
            try:
                if original_paths_config is None:
                    paths_config_file.unlink()
                else:
                    paths_config_file.write_text(original_paths_config, encoding="utf-8")
            except OSError as e:
                print(f"Warning: could not restore {paths_config_file}: {e}")
        gc.collect()
        print("Conversion function finished.")  # server-side log
# --- Gradio Interface ---
# Builds the UI: a left column with the three upload widgets, the convert
# button and the download slot, and a wider right column with the log.
# Component creation order inside the `with` blocks defines the layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Style-Bert-VITS2 Safetensors to ONNX Converter")
    gr.Markdown(
        "Upload your model's `.safetensors`, `config.json`, and `style_vectors.npy` files. "
        "The application will convert the model to ONNX format, and you can download the resulting `.onnx` file."
    )
    gr.Markdown(
        "_(Environment setup is handled automatically when this Space starts.)_"
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Upload Model Files")
            safetensors_upload = gr.File(
                label="Safetensors Model (.safetensors)",
                file_types=[".safetensors"],
            )
            config_upload = gr.File(
                label="Config File (config.json)",
                file_types=[".json"],
            )
            style_vectors_upload = gr.File(
                label="Style Vectors (style_vectors.npy)",
                file_types=[".npy"],
            )
            convert_button = gr.Button("2. Convert to ONNX", variant="primary", elem_id="convert_button")
            gr.Markdown("---")
            gr.Markdown("### 3. Download Result")
            onnx_download = gr.File(
                label="ONNX Model (.onnx)",
                interactive=False,  # output only; filled in by the conversion callback
            )
            gr.Markdown(
                "**Note:** Conversion can take **several minutes** (5-20+ min depending on model size and hardware). "
                "Please be patient. The log on the right shows the progress."
            )
        with gr.Column(scale=2):
            output_log = gr.Textbox(
                label="Conversion Log",
                lines=30,  # tall box so more of the log is visible at once
                interactive=False,
                autoscroll=True,
                max_lines=2000  # the conversion log can grow long
            )
    # Wire the button: the callback is a generator yielding (log_text, onnx_path)
    # pairs, which map onto the two output components in this order.
    convert_button.click(
        convert_safetensors_to_onnx_gradio,
        inputs=[safetensors_upload, config_upload, style_vectors_upload],
        outputs=[output_log, onnx_download]  # log text, then downloadable file
    )
# --- App startup ---
def _launch_app():
    """Prepare the runtime environment, then start the Gradio app."""
    # Make the Style-Bert-VITS2 checkout importable.
    add_sbv2_to_path()
    # Ensure the download directory exists (no-op if it already does).
    OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
    print(f"Output directory: {OUTPUT_DIR.resolve()}")
    # Enable request queuing, then start the server.
    app_with_queue = demo.queue()
    app_with_queue.launch()


if __name__ == "__main__":
    _launch_app()