File size: 17,324 Bytes
0d64808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68f523b
0d64808
 
 
 
 
 
 
68f523b
 
 
 
 
0d64808
68f523b
 
0d64808
 
 
 
 
68f523b
0d64808
68f523b
 
 
 
 
 
 
 
 
0d64808
 
 
68f523b
0d64808
 
68f523b
0d64808
 
 
68f523b
0d64808
 
 
 
 
 
68f523b
0d64808
 
68f523b
 
 
 
 
 
 
 
0d64808
 
68f523b
0d64808
 
 
 
 
 
68f523b
 
0d64808
68f523b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d64808
 
 
 
 
68f523b
 
0d64808
 
68f523b
0d64808
 
 
68f523b
0d64808
 
68f523b
0d64808
 
 
 
 
 
 
 
68f523b
0d64808
68f523b
0d64808
 
 
 
 
 
 
 
 
68f523b
0d64808
 
 
 
 
 
68f523b
0d64808
68f523b
0d64808
 
68f523b
0d64808
 
68f523b
 
0d64808
 
 
 
 
 
 
 
 
 
 
68f523b
0d64808
 
68f523b
 
0d64808
68f523b
 
 
 
 
 
0d64808
 
 
68f523b
0d64808
 
 
 
68f523b
0d64808
 
 
 
68f523b
0d64808
 
68f523b
 
 
 
 
 
0d64808
68f523b
0d64808
68f523b
 
0d64808
68f523b
0d64808
 
 
68f523b
0d64808
 
 
68f523b
0d64808
 
 
 
68f523b
0d64808
 
68f523b
 
0d64808
 
 
 
 
 
68f523b
 
 
 
 
0d64808
 
 
 
68f523b
0d64808
68f523b
0d64808
 
68f523b
 
 
 
 
 
 
 
 
 
 
0d64808
68f523b
0d64808
68f523b
0d64808
 
 
68f523b
 
0d64808
 
 
 
 
68f523b
0d64808
 
68f523b
0d64808
 
 
 
 
68f523b
0d64808
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import gradio as gr
import subprocess
import os
import sys
import yaml
from pathlib import Path
import time
import threading
import tempfile
import shutil
import gc # ガベージコレクション用

# --- Constants ---
# Directory containing this script; every other path below is derived from it.
SCRIPT_DIR = Path(__file__).parent
# Style-Bert-VITS2 checkout; must match the location the Dockerfile clones it to.
SBV2_REPO_PATH = SCRIPT_DIR.joinpath("Style-Bert-VITS2")
# Staging area (inside the container) for converted files offered for download.
OUTPUT_DIR = SCRIPT_DIR.joinpath("outputs")

# --- Helper functions ---
def add_sbv2_to_path():
    """Prepend the Style-Bert-VITS2 checkout to ``sys.path``.

    Does nothing when the resolved path is already present. When the checkout
    directory is missing, a warning is printed and ``sys.path`` is left untouched.
    """
    resolved_repo = str(SBV2_REPO_PATH.resolve())
    if not SBV2_REPO_PATH.exists():
        print(f"Warning: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}")
        return
    if resolved_repo in sys.path:
        return
    sys.path.insert(0, resolved_repo)
    print(f"Added {resolved_repo} to sys.path")

def _format_stderr_line(line):
    """Label one raw stderr line for the log.

    Blank lines and lines mentioning "warning" (case-insensitive) are prefixed
    ``stderr:``. Every other line is prefixed ``ERROR (stderr):`` so it stands out.
    """
    text = line.rstrip()
    if text and "warning" not in text.lower():
        return f"ERROR (stderr): {text}"
    return f"stderr: {text}"


def _drain_stream(stream, log_list, format_line):
    """Append ``format_line(line)`` for every line of text-mode *stream* until EOF.

    A read error is recorded in *log_list* instead of being raised.
    """
    try:
        for line in iter(stream.readline, ''):
            log_list.append(format_line(line))
    except Exception as e:
        log_list.append(f"Error reading process stream: {e}")


def stream_process_output(process, log_list):
    """Copy a subprocess's stdout and stderr into *log_list* as lines arrive.

    stdout lines are appended with trailing whitespace removed. Leading
    indentation is kept, e.g. for tracebacks. stderr lines are labelled by
    ``_format_stderr_line``.

    stderr is drained on a helper thread while stdout is read here. Reading the
    two pipes one after the other can deadlock: the child blocks once the
    unread stderr pipe fills up, while this reader waits for stdout EOF.
    Returns once both streams reach EOF. Lines from the two streams may
    interleave, but each stream keeps its own order.

    Args:
        process: A ``subprocess.Popen`` started in text mode, whose ``stdout``
            and/or ``stderr`` are pipes (either may be ``None``).
        log_list: List that receives the formatted lines. It may be read
            concurrently by another thread.
    """
    stderr_reader = None
    if process.stderr:
        stderr_reader = threading.Thread(
            target=_drain_stream,
            args=(process.stderr, log_list, _format_stderr_line),
            daemon=True,
        )
        stderr_reader.start()
    try:
        if process.stdout:
            # rstrip only: drop the newline but keep indentation.
            _drain_stream(process.stdout, log_list, str.rstrip)
    finally:
        if stderr_reader is not None:
            stderr_reader.join()

# --- Gradio app backend ---

# Upper bound on how long the conversion subprocess may run before it is killed.
_CONVERSION_TIMEOUT_SEC = 20 * 60
# How often (seconds) the streamed log is pushed to the UI while the script runs.
_LOG_POLL_INTERVAL_SEC = 0.3


def _upload_path(file_obj):
    """Return the local filesystem path of an uploaded file.

    Accepts either a plain path string or an object exposing the path as a
    ``.name`` attribute (e.g. a tempfile wrapper). Gradio may hand over
    either form, depending on version and ``type=``.
    """
    return str(getattr(file_obj, "name", file_obj))


def _label_stderr_lines(stderr_text):
    """Split leftover stderr text into log lines.

    Blank lines and lines mentioning "warning" are prefixed ``stderr:``. All
    other lines are prefixed ``ERROR (stderr):``.
    """
    labeled = []
    for line in stderr_text.strip().split('\n'):
        text = line.strip()
        if text and "warning" not in line.lower():
            labeled.append(f"ERROR (stderr): {text}")
        else:
            labeled.append(f"stderr: {text}")
    return labeled


def _terminate_process(process):
    """Kill *process* if it is still running, then reap it. No-op for ``None``."""
    if process is not None and process.poll() is None:
        process.kill()
        process.wait()


def convert_safetensors_to_onnx_gradio(
    safetensors_file_obj,
    config_file_obj,
    style_vectors_file_obj
    ):
    """
    Convert uploaded Style-Bert-VITS2 model files to ONNX and offer the result for download.

    This is a generator used as a Gradio event handler. Every ``yield`` is a
    ``(log_text, onnx_path)`` pair. ``log_text`` is the accumulated log.
    ``onnx_path`` is ``None`` until the converted ``.onnx`` file has been
    copied into ``OUTPUT_DIR``.

    Steps:
    1. Validate the uploads.
    2. Stage the three files in a temporary ``<assets_root>/<model_name>/``
       directory.
    3. Point the checkout's ``configs/paths.yml`` at that directory.
    4. Run ``convert_onnx.py --model <safetensors>`` in a subprocess while
       streaming its output.
    5. Copy ``<model_name>.onnx`` from next to the staged weights into
       ``OUTPUT_DIR``.

    Args:
        safetensors_file_obj: Uploaded ``.safetensors`` weights (path string or
            object with ``.name``). Its file stem becomes the model name.
        config_file_obj: Uploaded config. Always staged as ``config.json``.
        style_vectors_file_obj: Uploaded style vectors. Always staged as
            ``style_vectors.npy``.

    Yields:
        tuple[str, str | None]: log text and the downloadable ONNX path (or ``None``).

    Notes:
        The subprocess is killed when it exceeds ``_CONVERSION_TIMEOUT_SEC`` or
        when the generator is closed early (e.g. a cancelled request). It is
        killed before the temporary directory it works in is removed.
        ``configs/paths.yml`` in the checkout is overwritten and not restored.
    """
    log = ["Starting ONNX conversion..."]
    # Nothing to download yet.
    yield "\n".join(log), None

    # --- Validate uploads ---
    if safetensors_file_obj is None:
        log.append("❌ Error: Safetensors file is missing. Please upload the .safetensors file.")
        yield "\n".join(log), None
        return
    if config_file_obj is None:
        log.append("❌ Error: config.json file is missing. Please upload the config.json file.")
        yield "\n".join(log), None
        return
    if style_vectors_file_obj is None:
        log.append("❌ Error: style_vectors.npy file is missing. Please upload the style_vectors.npy file.")
        yield "\n".join(log), None
        return

    # --- Make sure the Style-Bert-VITS2 checkout is available ---
    add_sbv2_to_path()
    if not SBV2_REPO_PATH.exists():
        log.append(f"❌ Error: Style-Bert-VITS2 repository not found at {SBV2_REPO_PATH}. Check Space build logs.")
        yield "\n".join(log), None
        return

    # --- Create the download directory ---
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    onnx_output_path_str = None  # final ONNX path (str) handed to the download component
    current_log = log[:]

    try:
        # --- Work inside a temporary directory ---
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir_path = Path(temp_dir)
            current_log.append(f"📁 Created temporary directory: {temp_dir_path}")
            yield "\n".join(current_log), None

            # --- Build the directory layout Style-Bert-VITS2 expects inside the temp dir ---
            # The model name is the .safetensors file name without its extension.
            safetensors_filename = Path(_upload_path(safetensors_file_obj)).name
            if not safetensors_filename.lower().endswith(".safetensors"):
                current_log.append(f"❌ Error: Invalid safetensors filename: {safetensors_filename}")
                yield "\n".join(current_log), None
                return
            model_name = Path(safetensors_filename).stem

            # The temp dir itself is assets_root; the model lives in <assets_root>/<model_name>/.
            assets_root = temp_dir_path
            model_dir_in_temp = assets_root / model_name
            model_dir_in_temp.mkdir(exist_ok=True)
            current_log.append(f"   - Created model directory: {model_dir_in_temp.relative_to(assets_root)}")
            yield "\n".join(current_log), None

            # --- Copy the three uploads into the model directory ---
            # For config and style vectors the key is also the file name they are staged under.
            files_to_copy = {
                "safetensors": safetensors_file_obj,
                "config.json": config_file_obj,
                "style_vectors.npy": style_vectors_file_obj,
            }
            copied_paths = {}

            for file_key, file_obj in files_to_copy.items():
                source_path = _upload_path(file_obj)
                original_filename = Path(source_path).name
                # Basic filename sanity check (stricter sanitisation is possible).
                if "/" in original_filename or "\\" in original_filename or ".." in original_filename:
                    current_log.append(f"❌ Error: Invalid characters found in filename: {original_filename}")
                    yield "\n".join(current_log), None
                    return

                if file_key == "safetensors":
                    # The weights keep their uploaded name because it defines model_name.
                    target_name = original_filename
                else:
                    # Stage under the expected name. NOTE(review): the conversion script
                    # presumably looks these files up by exact name next to the model.
                    target_name = file_key
                    if original_filename.lower() != file_key:
                        kind_label = "JSON" if file_key == "config.json" else "NPY"
                        current_log.append(
                            f"⚠️ Warning: Uploaded {kind_label} file name is '{original_filename}', "
                            f"expected '{file_key}'. Saving it as '{file_key}'."
                        )

                destination_path = model_dir_in_temp / target_name
                try:
                    shutil.copy(source_path, destination_path)
                except Exception as e:
                    current_log.append(f"❌ Error copying file '{original_filename}': {e}")
                    yield "\n".join(current_log), None
                    return
                renamed_note = "" if target_name == original_filename else f" as '{target_name}'"
                current_log.append(f"   - Copied '{original_filename}' to model directory{renamed_note}.")
                if file_key == "safetensors":
                    copied_paths["safetensors"] = destination_path
                yield "\n".join(current_log), None  # refresh the UI after each copy

            temp_safetensors_path = copied_paths.get("safetensors")
            if not temp_safetensors_path:
                current_log.append("❌ Error: Failed to locate the copied safetensors file in the temporary directory.")
                yield "\n".join(current_log), None
                return

            current_log.append("✅ All required files copied to temporary model directory.")
            current_log.append(f"   - Using temporary assets_root: {assets_root}")
            yield "\n".join(current_log), None

            # --- Point the checkout's paths.yml at the temp dir ---
            config_path = SBV2_REPO_PATH / "configs" / "paths.yml"
            config_path.parent.mkdir(parents=True, exist_ok=True)
            # dataset_root is not used here but is set anyway (same place as assets_root).
            paths_config = {"dataset_root": str(assets_root.resolve()), "assets_root": str(assets_root.resolve())}
            with open(config_path, "w", encoding="utf-8") as f:
                yaml.dump(paths_config, f)
            current_log.append(f"   - Saved temporary paths config to {config_path}")
            yield "\n".join(current_log), None

            # --- Run the ONNX conversion script ---
            current_log.append(f"\n🚀 Starting ONNX conversion script for model '{model_name}'...")
            convert_script = SBV2_REPO_PATH / "convert_onnx.py"
            if not convert_script.exists():
                current_log.append(f"❌ Error: convert_onnx.py not found at '{convert_script}'. Check repository setup.")
                yield "\n".join(current_log), None
                return

            command = [
                sys.executable,
                str(convert_script.resolve()),
                "--model",
                str(temp_safetensors_path.resolve()),  # staged .safetensors inside the temp dir
            ]
            current_log.append(f"\n   Running command: {' '.join(command)}")
            yield "\n".join(current_log), None

            process = subprocess.Popen(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding='utf-8',
                errors='replace',
                cwd=SBV2_REPO_PATH,  # run from the repository root
                env=os.environ.copy(),
            )

            # Output lines collected by the reader thread; shared with this generator.
            process_output_lines = ["\n--- Conversion Script Output ---"]
            thread = threading.Thread(target=stream_process_output, args=(process, process_output_lines))
            thread.start()

            deadline = time.monotonic() + _CONVERSION_TIMEOUT_SEC
            timed_out = False
            try:
                # Stream the growing log until the script closes its output or the deadline passes.
                while thread.is_alive():
                    if time.monotonic() >= deadline:
                        timed_out = True
                        break
                    yield "\n".join(current_log + process_output_lines), None
                    time.sleep(_LOG_POLL_INTERVAL_SEC)

                if not timed_out:
                    thread.join()
                    try:
                        # Output is closed; give the process the remaining budget (min 1 s) to exit.
                        process.wait(timeout=max(1.0, deadline - time.monotonic()))
                    except subprocess.TimeoutExpired:
                        timed_out = True

                if timed_out:
                    _terminate_process(process)
                    # Killing the process closes its pipes, so the reader should finish promptly.
                    thread.join(timeout=10)
                    current_log.extend(process_output_lines)
                    current_log.append(
                        f"\n❌ Error: Conversion process timed out after {_CONVERSION_TIMEOUT_SEC // 60} minutes."
                    )
                    yield "\n".join(current_log), None
                    return
            finally:
                # Kill the converter on any early exit (timeout, cancelled request,
                # unexpected error) before the temporary directory is deleted.
                _terminate_process(process)

            # Pick up anything still in the pipes. This is normally empty, since the
            # reader thread reads both streams to EOF.
            final_stdout, final_stderr = process.communicate()
            if final_stdout:
                process_output_lines.extend(final_stdout.strip().split('\n'))
            if final_stderr:
                processed_stderr = _label_stderr_lines(final_stderr)
                if any(line.startswith("ERROR") for line in processed_stderr):
                    process_output_lines.append("--- Errors/Warnings (stderr) ---")
                    process_output_lines.extend(processed_stderr)
                    process_output_lines.append("-----------------------------")
                elif processed_stderr:  # warnings / blank lines only
                    process_output_lines.append("--- Warnings (stderr) ---")
                    process_output_lines.extend(processed_stderr)
                    process_output_lines.append("------------------------")

            # Merge the script output into the main log.
            current_log.extend(process_output_lines)
            current_log.append("--- End Script Output ---")
            current_log.append("\n-------------------------------")

            # --- Check the result and copy the ONNX file out of the temp dir ---
            if process.returncode == 0:
                current_log.append("✅ ONNX conversion command finished successfully.")
                # Expected output: same directory and stem as the input weights.
                expected_onnx_path_in_temp = temp_safetensors_path.with_suffix(".onnx")

                if expected_onnx_path_in_temp.exists():
                    current_log.append(f"   - Found converted ONNX file: {expected_onnx_path_in_temp.name}")
                    # Copy to the persistent output dir; the temp dir is deleted on exit.
                    final_onnx_path = OUTPUT_DIR / expected_onnx_path_in_temp.name
                    try:
                        shutil.copy(expected_onnx_path_in_temp, final_onnx_path)
                        current_log.append(f"   - Copied ONNX file for download to: {final_onnx_path.relative_to(SCRIPT_DIR)}")
                        onnx_output_path_str = str(final_onnx_path)
                    except Exception as e:
                        current_log.append(f"❌ Error copying ONNX file to output directory: {e}")
                else:
                    current_log.append(f"⚠️ Warning: Expected ONNX file not found at '{expected_onnx_path_in_temp.name}'. Check script output above.")
            else:
                current_log.append(f"❌ ONNX conversion command failed with return code {process.returncode}.")
                current_log.append("   Please check the logs above for errors (especially lines starting with 'ERROR').")

            # Final state, emitted before the temporary directory is removed.
            yield "\n".join(current_log), onnx_output_path_str

    except FileNotFoundError as e:
        current_log.append(f"\n❌ Error: A required command or file was not found: {e.filename}. Check Dockerfile setup and PATH.")
        current_log.append(f"{e}")
        yield "\n".join(current_log), None
    except Exception as e:
        current_log.append(f"\n❌ An unexpected error occurred: {e}")
        import traceback
        current_log.append(traceback.format_exc())
        yield "\n".join(current_log), None
    finally:
        gc.collect()
        print("Conversion function finished.")  # server-side log


# --- Gradio Interface ---
# Layout: the left column holds the three uploads, the convert button and the
# download slot; the right (wider) column holds the streaming conversion log.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Style-Bert-VITS2 Safetensors to ONNX Converter")
    gr.Markdown(
        "Upload your model's `.safetensors`, `config.json`, and `style_vectors.npy` files. "
        "The application will convert the model to ONNX format, and you can download the resulting `.onnx` file."
    )
    gr.Markdown(
        "_(Environment setup is handled automatically when this Space starts.)_"
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Upload Model Files")
            # One upload widget per required input file, each limited to its extension.
            safetensors_upload = gr.File(
                label="Safetensors Model (.safetensors)",
                file_types=[".safetensors"],
            )
            config_upload = gr.File(
                label="Config File (config.json)",
                file_types=[".json"],
            )
            style_vectors_upload = gr.File(
                label="Style Vectors (style_vectors.npy)",
                file_types=[".npy"],
            )

            convert_button = gr.Button("2. Convert to ONNX", variant="primary", elem_id="convert_button")

            gr.Markdown("---")
            gr.Markdown("### 3. Download Result")
            onnx_download = gr.File(
                label="ONNX Model (.onnx)",
                interactive=False, # output only: filled by the conversion handler
            )
            gr.Markdown(
                "**Note:** Conversion can take **several minutes** (5-20+ min depending on model size and hardware). "
                "Please be patient. The log on the right shows the progress."
            )

        with gr.Column(scale=2):
            output_log = gr.Textbox(
                label="Conversion Log",
                lines=30, # tall box so long logs stay readable
                interactive=False,
                autoscroll=True,
                max_lines=2000 # the conversion log can grow long
            )

    # Wire the button to the conversion handler. The handler is a generator that
    # yields (log_text, onnx_path) pairs, which map onto the two outputs below.
    convert_button.click(
        convert_safetensors_to_onnx_gradio,
        inputs=[safetensors_upload, config_upload, style_vectors_upload],
        outputs=[output_log, onnx_download] # log text and downloadable file, in that order
    )

# --- App entry point ---
if __name__ == "__main__":
    # Make the Style-Bert-VITS2 checkout importable.
    add_sbv2_to_path()
    # Ensure the download directory exists before any request is served.
    OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
    print("Output directory: " + str(OUTPUT_DIR.resolve()))

    # Enable request queuing, then start the Gradio server.
    queued_app = demo.queue()
    queued_app.launch()