File size: 18,029 Bytes
4f11bbf
ba5e3a9
ce8a201
7c6ede0
4f11bbf
0a98475
 
 
9eb42b2
aaa4bd6
7c6ede0
4f11bbf
 
 
 
 
 
 
 
6a535bc
 
c9be4ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f11bbf
7c6ede0
4f11bbf
 
 
0a98475
ce8a201
4f11bbf
 
 
 
c9be4ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c6ede0
4f11bbf
6a535bc
4f11bbf
 
6a535bc
d4575dc
4f11bbf
 
d4575dc
 
 
4f11bbf
 
 
d206e43
 
4f11bbf
d206e43
4f11bbf
 
 
d206e43
4f11bbf
 
 
 
 
 
21b4fcb
4f11bbf
c9be4ad
d4575dc
 
4f11bbf
d4575dc
4f11bbf
 
 
d4575dc
4f11bbf
 
 
 
 
 
 
 
c9be4ad
4f11bbf
d4575dc
6a535bc
 
 
4f11bbf
 
 
d4575dc
 
 
c9be4ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4575dc
c9be4ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f11bbf
d4575dc
 
4f11bbf
 
 
 
 
d4575dc
4f11bbf
 
 
d4575dc
 
4f11bbf
6a535bc
 
4f11bbf
 
 
 
6a535bc
4f11bbf
 
 
 
 
 
 
 
 
6a535bc
4f11bbf
 
d4575dc
af69235
4f11bbf
d4575dc
4f11bbf
 
d4575dc
4f11bbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1873d1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29074da
4f11bbf
9eb42b2
6a535bc
 
4f11bbf
 
9eb42b2
4f11bbf
 
6a535bc
4f11bbf
6a535bc
 
4f11bbf
0a98475
4f11bbf
 
 
 
 
 
 
 
0a98475
4f11bbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29074da
 
4f11bbf
 
29074da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
# NeMo ASRモデル、PyTorch、Gradioなどをインポート
from nemo.collections.asr.models import ASRModel
import torch
import gradio as gr
import spaces # Hugging Face Spaces ライブラリをインポート
import gc
from pathlib import Path
import os
import json
from typing import List, Tuple, Optional

# pydub をインポート (音声ファイルの長さ取得のため)
try:
    from pydub import AudioSegment
    PYDUB_AVAILABLE = True
except ImportError:
    PYDUB_AVAILABLE = False
    print("Warning: pydub not found. Audio duration cannot be determined automatically for long audio optimization.")

# グローバル設定
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
TARGET_SAMPLE_RATE = 16000
# 音声の長さに関する閾値 (秒)
LONG_AUDIO_THRESHOLD_SECONDS = 480  # 8分
VERY_LONG_AUDIO_THRESHOLD_SECONDS = 10800 # 3時間
# チャンク分割時の設定
CHUNK_LENGTH_SECONDS = 1800 # 30分
CHUNK_OVERLAP_SECONDS = 60  # 1分
# セグメント処理の設定
MAX_SEGMENT_LENGTH_SECONDS = 15  # 最大セグメント長(秒)を15秒に短縮
MAX_SEGMENT_CHARS = 100  # 最大セグメント文字数を100文字に短縮
MIN_SEGMENT_GAP_SECONDS = 0.3  # 最小セグメント間隔(秒)
# VTTファイルの最大サイズ(バイト)
MAX_VTT_SIZE_BYTES = 10 * 1024 * 1024  # 10MB
# 文の区切り文字
SENTENCE_ENDINGS = ['.', '!', '?', '。', '!', '?']
SENTENCE_PAUSES = [',', '、', ';', ';', ':', ':']

device = "cuda" if torch.cuda.is_available() else "cpu" # スクリプト起動時のデバイス検出

# モデルの初期化 (グローバルに一度だけ行う)
print(f"Initializing ASR model: {MODEL_NAME}")
print(f"Initial device check: {device}") # 初期デバイス確認
model = ASRModel.from_pretrained(model_name=MODEL_NAME)
model.eval()
# 初期状態ではモデルをCPUに置いておく (GPU関数内で .to(device) する)
model.cpu()
print("ASR model initialized and moved to CPU.")

def find_natural_break_point(text: str, max_length: int) -> int:
    """テキスト内で自然な区切り点を探す"""
    if len(text) <= max_length:
        return len(text)
    
    # 文末で区切る
    for i in range(max_length, 0, -1):
        if i < len(text) and text[i] in SENTENCE_ENDINGS:
            return i + 1
    
    # 文の区切りで区切る
    for i in range(max_length, 0, -1):
        if i < len(text) and text[i] in SENTENCE_PAUSES:
            return i + 1
    
    # スペースで区切る
    for i in range(max_length, 0, -1):
        if i < len(text) and text[i].isspace():
            return i + 1
    
    # それでも見つからない場合は最大長で区切る
    return max_length

def split_segment(segment: dict, max_length_seconds: float, max_chars: int) -> List[dict]:
    """セグメントを自然な区切りで分割する"""
    if (segment['end'] - segment['start']) <= max_length_seconds and len(segment['segment']) <= max_chars:
        return [segment]
    
    result = []
    current_text = segment['segment']
    current_start = segment['start']
    total_duration = segment['end'] - segment['start']
    
    while current_text:
        # 文字数に基づく分割点を探す
        break_point = find_natural_break_point(current_text, max_chars)
        
        # 時間に基づく分割点を計算
        text_ratio = break_point / len(segment['segment'])
        segment_duration = total_duration * text_ratio
        
        # 分割点が最大長を超えないように調整
        if segment_duration > max_length_seconds:
            time_ratio = max_length_seconds / total_duration
            break_point = int(len(segment['segment']) * time_ratio)
            break_point = find_natural_break_point(current_text, break_point)
            segment_duration = max_length_seconds
        
        # 新しいセグメントを作成
        new_segment = {
            'start': current_start,
            'end': current_start + segment_duration,
            'segment': current_text[:break_point].strip()
        }
        result.append(new_segment)
        
        # 残りのテキストと開始時間を更新
        current_text = current_text[break_point:].strip()
        current_start = new_segment['end']
    
    return result

def transcribe_audio_core(
    audio_path: str,
    duration_sec: float,
    current_device: str # 実行時のデバイスを引数で受け取る
) -> Tuple[Optional[List], Optional[List], Optional[List]]:
    """
    音声ファイルを文字起こしし、タイムスタンプを取得する(コア処理)。
    この関数は実際にGPU上で実行されることを想定。
    """
    long_audio_settings_applied = False
    try:
        gr.Info(f"Starting transcription on {current_device} for: {Path(audio_path).name}", duration=3)
        
        if current_device == 'cuda':
            torch.cuda.empty_cache()
            gc.collect()
            print(f"CUDA memory before loading model to GPU: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        # モデルを実行デバイスに移動
        model.to(current_device)
        model.to(torch.float32) # 推論前にfloat32に戻す (bfloat16は後段)

        if current_device == 'cuda':
            print(f"CUDA memory after loading model to GPU (float32): {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        # 長尺音声用の設定 (閾値を超え、かつpydubで長さが取得できた場合)
        if PYDUB_AVAILABLE and duration_sec > LONG_AUDIO_THRESHOLD_SECONDS:
            gr.Info(f"Audio duration ({duration_sec:.2f}s) exceeds threshold. Applying long audio settings.", duration=3)
            try:
                print("Applying long audio settings: Local Attention and Chunking.")
                model.change_attention_model("rel_pos_local_attn", [128, 128])  # 256,256から128,128に変更
                model.change_subsampling_conv_chunking_factor(1)
                long_audio_settings_applied = True
                print("Successfully applied long audio settings.")
            except Exception as setting_e:
                warning_msg = f"Warning: Failed to apply long audio settings: {setting_e}"
                print(warning_msg)
                gr.Warning(warning_msg, duration=5)
        
        # bfloat16への変換 (CUDAの場合のみ)
        if current_device == 'cuda':
            print("Converting model to bfloat16 for inference on CUDA.")
            model.to(torch.bfloat16)
            print(f"CUDA memory after converting to bfloat16: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        # 文字起こし実行
        print(f"Transcribing {audio_path}...")
        output = model.transcribe([audio_path], timestamps=True, batch_size=2)  # バッチサイズを2に設定
        print("Transcription API call finished.")

        if not output or not isinstance(output, list) or not output[0] or \
           not hasattr(output[0], 'timestamp') or not output[0].timestamp or \
           'segment' not in output[0].timestamp:
            error_msg = "Transcription failed or produced unexpected output format."
            print(error_msg)
            gr.Error(error_msg, duration=5)
            return None, None, None

        segment_timestamps = output[0].timestamp['segment']
        
        # セグメントの前処理:より適切なセグメント分割
        processed_segments = []
        current_segment = None
        
        for ts in segment_timestamps:
            if current_segment is None:
                current_segment = ts
            else:
                # セグメント結合の条件を厳格化
                time_gap = ts['start'] - current_segment['end']
                current_text = current_segment['segment']
                next_text = ts['segment']
                
                # 結合条件のチェック
                should_merge = (
                    time_gap < MIN_SEGMENT_GAP_SECONDS and  # 時間間隔が短い
                    len(current_text) + len(next_text) < MAX_SEGMENT_CHARS and  # 文字数制限
                    (current_segment['end'] - current_segment['start']) < MAX_SEGMENT_LENGTH_SECONDS and  # 現在のセグメントが短い
                    (ts['end'] - ts['start']) < MAX_SEGMENT_LENGTH_SECONDS and  # 次のセグメントが短い
                    not any(current_text.strip().endswith(p) for p in SENTENCE_ENDINGS)  # 文の区切りでない
                )
                
                if should_merge:
                    current_segment['end'] = ts['end']
                    current_segment['segment'] += ' ' + ts['segment']
                else:
                    # 現在のセグメントを分割
                    split_segments = split_segment(current_segment, MAX_SEGMENT_LENGTH_SECONDS, MAX_SEGMENT_CHARS)
                    processed_segments.extend(split_segments)
                    current_segment = ts
        
        if current_segment is not None:
            # 最後のセグメントも分割
            split_segments = split_segment(current_segment, MAX_SEGMENT_LENGTH_SECONDS, MAX_SEGMENT_CHARS)
            processed_segments.extend(split_segments)
        
        # 処理済みセグメントからデータを生成
        vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in processed_segments]
        raw_times_data = [[ts['start'], ts['end']] for ts in processed_segments]
        
        # 単語タイムスタンプの処理を改善
        word_timestamps_raw = output[0].timestamp.get("word", [])
        word_vis_data = []
        
        for w in word_timestamps_raw:
            if not isinstance(w, dict) or not all(k in w for k in ['start', 'end', 'word']):
                continue
                
            # 単語のタイムスタンプを最も近いセグメントに割り当て
            word_start = float(w['start'])
            word_end = float(w['end'])
            
            # 単語が完全に含まれるセグメントを探す
            for seg in processed_segments:
                if word_start >= seg['start'] - 0.05 and word_end <= seg['end'] + 0.05:
                    word_vis_data.append([f"{word_start:.2f}", f"{word_end:.2f}", w["word"]])
                    break
        
        gr.Info("Transcription successful!", duration=3)
        return vis_data, raw_times_data, word_vis_data

    except torch.cuda.OutOfMemoryError as oom_e:
        error_msg = f"CUDA out of memory during transcription: {oom_e}. Try a shorter audio file or a more powerful GPU."
        print(error_msg)
        gr.Error(error_msg, duration=None) # 長く表示
        return None, None, None
    except Exception as e:
        error_msg = f"Error during transcription: {e}"
        print(error_msg)
        gr.Error(error_msg, duration=None) # 長く表示
        return None, None, None
    finally:
        print("Starting transcription cleanup...")
        if long_audio_settings_applied:
            try:
                print("Reverting long audio settings...")
                model.change_attention_model("rel_pos") # 元のAttentionに戻す
                model.change_subsampling_conv_chunking_factor(-1) # 元のChunking Factorに戻す
                print("Successfully reverted long audio settings.")
            except Exception as revert_e:
                warning_msg = f"Warning: Failed to revert long audio settings: {revert_e}"
                print(warning_msg)
                gr.Warning(warning_msg, duration=5)
        
        # モデルをCPUに戻し、CUDAキャッシュをクリア
        model.cpu() # 必ずCPUに戻す
        print("Model moved to CPU.")
        if current_device == 'cuda': # current_device を使う
            gc.collect()
            torch.cuda.empty_cache()
            print("CUDA cache cleared.")
        print("Transcription cleanup finished.")

@spaces.GPU(duration=60) # GPUリソースを要求し、タイムアウトを60秒に設定
def process_audio_file(audio_filepath: str) -> dict: # Gradioから渡されるのは一時ファイルのパス
    """
    アップロードされた音声ファイルを処理し、文字起こし結果をJSONで返す。
    この関数がGradioのコールバックとなり、GPU環境で実行される。
    """
    # この関数が呼ばれた時点でGPUが利用可能になっているはず (Hugging Face Spacesの場合)
    # なので、再度デバイスチェックを行う
    current_processing_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device check inside @spaces.GPU function: {current_processing_device}")
    gr.Info(f"Processing on: {current_processing_device}", duration=3)

    if not PYDUB_AVAILABLE:
        gr.Warning("pydub library is not available. Audio duration cannot be determined, long audio optimizations might not be applied correctly.", duration=5)
        duration_sec = 0 # 長さが不明な場合は0とする
    else:
        try:
            gr.Info(f"Loading audio file: {Path(audio_filepath).name}", duration=2)
            audio = AudioSegment.from_file(audio_filepath)
            duration_sec = audio.duration_seconds
            print(f"Audio duration: {duration_sec:.2f} seconds.")
        except Exception as e:
            error_msg = f"Failed to load audio or get duration using pydub: {e}"
            print(error_msg)
            gr.Error(error_msg, duration=5)
            # pydubが失敗しても、NeMoは処理を試みることができるので、duration_sec = 0 で続行
            duration_sec = 0

    # 文字起こしコア処理を呼び出し
    vis_data, raw_times_data, word_vis_data = transcribe_audio_core(audio_filepath, duration_sec, current_processing_device)

    if not vis_data:
        # transcribe_audio_core内でエラー通知はされているはず
        return {"error": "Transcription failed. Check logs and messages for details."}

    # 結果をJSON形式で返却 (ユーザー指定の形式に合わせる)
    output_segments = []
    word_idx = 0
    for seg_data in vis_data:
        s_start_time = float(seg_data[0])
        s_end_time = float(seg_data[1])
        s_text = seg_data[2]
        segment_words_list: List[dict] = []
        
        if word_vis_data: # word_vis_data が存在する場合のみ処理
            temp_current_word_idx = word_idx
            while temp_current_word_idx < len(word_vis_data):
                w_data = word_vis_data[temp_current_word_idx]
                w_start_time = float(w_data[0])
                w_end_time = float(w_data[1])
                
                # 単語がセグメントの範囲内にあるかチェック (多少の誤差を許容)
                if w_start_time >= s_start_time and w_end_time <= s_end_time + 0.1:
                    segment_words_list.append({
                        "start": w_start_time,
                        "end": w_end_time,
                        "word": w_data[2]
                    })
                    temp_current_word_idx += 1
                elif w_start_time < s_start_time: # 単語がセグメントより前に開始している場合はスキップ
                    temp_current_word_idx += 1
                elif w_start_time > s_end_time: # 単語がセグメントより後に開始している場合はループを抜ける
                    break
                else: # その他のケース (ほぼありえないが念のため)
                    temp_current_word_idx += 1
            word_idx = temp_current_word_idx
        
        output_segments.append({
            "start": s_start_time,
            "end": s_end_time,
            "text": s_text,
            "words": segment_words_list
        })
    
    result = {"segments": output_segments}
    
    return result

# Gradioインターフェースの設定
with gr.Blocks() as demo:
    gr.Markdown("# GPU Transcription Service (Improved)")
    gr.Markdown("Upload an audio file for transcription. Processing will use GPU if available on the server.")
    
    file_input = gr.File(label="Upload Audio File", type="filepath") # type="filepath" を明示
    output_json = gr.JSON(label="Transcription Result")
    
    file_input.change( # ファイルがアップロード/変更されたら実行
        fn=process_audio_file,
        inputs=[file_input],
        outputs=[output_json]
    )
    gr.Examples(
        examples=[
            [os.path.join(os.path.dirname(__file__), "audio_example.wav") if os.path.exists(os.path.join(os.path.dirname(__file__), "audio_example.wav")) else "https://www.kozco.com/tech/piano2-CoolEdit.mp3"]
        ],
        inputs=[file_input],
        label="Example Audio (Click to load)"
    )

if __name__ == "__main__":
    # ダミーの音声ファイルを作成 (Examples用、もし存在しなければ)
    example_dir = os.path.dirname(__file__)
    dummy_audio_path = os.path.join(example_dir, "audio_example.wav")
    if not os.path.exists(dummy_audio_path) and PYDUB_AVAILABLE:
        try:
            print(f"Creating a dummy audio file for example: {dummy_audio_path}")
            silence = AudioSegment.silent(duration=1000) # 1秒の無音
            # 簡単な音を追加 (pydubの機能で)
            tone1 = AudioSegment.sine(440, duration=200) # A4
            tone2 = AudioSegment.sine(880, duration=200) # A5
            dummy_segment = silence + tone1 + silence[:200] + tone2 + silence
            dummy_segment.export(dummy_audio_path, format="wav")
            print("Dummy audio file created.")
        except Exception as e:
            print(f"Could not create dummy audio file: {e}")
    elif not PYDUB_AVAILABLE:
        print("Skipping dummy audio file creation as pydub is not available.")
    
    print("Launching Gradio demo...")
    demo.queue()  # リクエストキューを有効化
    demo.launch(show_error=True)  # エラー詳細を表示