Gapeleon committed
Commit df84307 · verified · 1 Parent(s): fee2275

Upload 3 files

Files changed (3)
  1. app.py +427 -0
  2. packages.txt +2 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,427 @@
+ from nemo.collections.asr.models import ASRModel
+ import torch
+ import gradio as gr
+ import spaces
+ import gc
+ from pathlib import Path
+ from pydub import AudioSegment
+ import numpy as np
+ import os
+ import tempfile
+ import gradio.themes as gr_themes
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
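+ # Parakeet TDT 0.6B v2 is an English-only ASR model (hence the "(en)" in the
+ # UI title below); from_pretrained() fetches it from the Hugging Face Hub on
+ # first launch, so the download happens once at startup rather than per request.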
+
+ model = ASRModel.from_pretrained(model_name=MODEL_NAME)
+ model.eval()
+
+ def get_audio_segment(audio_path, start_second, end_second):
+     if not audio_path or not Path(audio_path).exists():
+         print(f"Warning: Audio path '{audio_path}' not found or invalid for clipping.")
+         return None
+     try:
+         start_ms = int(start_second * 1000)
+         end_ms = int(end_second * 1000)
+         start_ms = max(0, start_ms)
+         if end_ms <= start_ms:
+             print(f"Warning: End time ({end_second}s) is not after start time ({start_second}s). Adjusting end time.")
+             end_ms = start_ms + 100
+         # Unconditionally use pydub for all supported types (.mp3, .wav, .mp4, etc.)
+         audio = AudioSegment.from_file(audio_path)  # pydub/ffmpeg supports most formats
+         clipped_audio = audio[start_ms:end_ms]
+         samples = np.array(clipped_audio.get_array_of_samples())
+         if clipped_audio.channels == 2:
+             samples = samples.reshape((-1, 2)).mean(axis=1).astype(samples.dtype)
+         frame_rate = clipped_audio.frame_rate
+         if frame_rate <= 0:
+             print(f"Warning: Invalid frame rate ({frame_rate}) detected for clipped audio.")
+             frame_rate = audio.frame_rate
+         if samples.size == 0:
+             print(f"Warning: Clipped audio resulted in an empty samples array ({start_second}s to {end_second}s).")
+             return None
+         return (frame_rate, samples)
+     except FileNotFoundError:
+         print(f"Error: Audio file not found at path: {audio_path}")
+         return None
+     except Exception as e:
+         print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
+         return None
+
+ def seconds_to_srt_ts(seconds: float):
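+     # Converts a float second count to an SRT timestamp,
+     # e.g. seconds_to_srt_ts(3661.5) -> "01:01:01,500"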
+     hours = int(seconds // 3600)
+     minutes = int((seconds % 3600) // 60)
+     secs = int(seconds % 60)
+     ms = int((seconds - int(seconds)) * 1000)
+     return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"
+
+ @spaces.GPU
+ def get_transcripts_and_raw_times(file_path):
+     if not file_path:
+         gr.Error("No file path provided for transcription.", duration=None)
+         return [], [], None, gr.DownloadButton(visible=False)
+
+     vis_data = [["N/A", "N/A", "Processing failed"]]
+     raw_times_data = [[0.0, 0.0]]
+     temp_files = []  # To track all temporary files created
+     srt_file_path = None
+     original_path_name = Path(file_path).name
+
+     try:
+         try:
+             gr.Info(f"Loading file: {original_path_name}", duration=2)
+             # pydub/ffmpeg supports .mp3, .wav, .mp4, .m4a, .aac, etc.
+             audio = AudioSegment.from_file(file_path)  # pydub handles mp4 via ffmpeg
+         except Exception as load_e:
+             gr.Error(f"Failed to load file {original_path_name}: {load_e}", duration=None)
+             return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
+
+         # Process audio for transcription
+         try:
+             target_sr = 16000
+             if audio.frame_rate != target_sr:
+                 audio = audio.set_frame_rate(target_sr)
+             if audio.channels == 2:
+                 audio = audio.set_channels(1)
+             elif audio.channels > 2:
+                 gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
+                 return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
+         except Exception as process_e:
+             gr.Error(f"Failed to process audio: {process_e}", duration=None)
+             return [["Error", "Error", "Audio processing failed"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
+
+         # Check if the audio is longer than the chunk size
+         audio_length_sec = len(audio) / 1000.0  # pydub lengths are in milliseconds
+
+         # Configuration for chunking - 10 minutes works on a 24GB RTX3090.
+         chunk_size_sec = 10 * 60
+         overlap_sec = 5  # 5 seconds of overlap between adjacent chunks
+
+         # Convert to milliseconds for pydub
+         chunk_size_ms = chunk_size_sec * 1000
+         overlap_ms = overlap_sec * 1000
+
+         # Determine if we need chunking
+         need_chunking = audio_length_sec > chunk_size_sec
+
+         # List to hold ALL segments from ALL chunks
+         all_segments = []
+
+         if need_chunking:
+             # Calculate the number of chunks
+             total_chunks = max(1, int(np.ceil(audio_length_sec / chunk_size_sec)))
+             print(f"Audio length: {audio_length_sec:.2f} seconds ({audio_length_sec/60:.2f} minutes)")
+             print(f"Chunk size: {chunk_size_sec} seconds ({chunk_size_sec/60:.2f} minutes)")
+             print(f"Total chunks needed: {total_chunks}")
+
+             gr.Info(f"Audio is {audio_length_sec/60:.1f} minutes long. Processing in {total_chunks} chunks...", duration=3)
+
+             # Process each chunk
+             for i in range(total_chunks):
+                 # Calculate chunk boundaries in milliseconds
+                 chunk_start_ms = i * chunk_size_ms
+                 chunk_end_ms = min(len(audio), (i + 1) * chunk_size_ms)
+
+                 # Add overlap, except at the start of the first chunk and the end of the last
+                 if i > 0:
+                     chunk_start_ms -= overlap_ms  # Extend start earlier
+
+                 if i < total_chunks - 1 and chunk_end_ms + overlap_ms <= len(audio):
+                     chunk_end_ms += overlap_ms  # Extend end later
+
+                 # Calculate the effective region (excluding overlaps)
+                 effective_start_ms = chunk_start_ms
+                 effective_end_ms = chunk_end_ms
+
+                 # Don't count overlap in the effective region
+                 if i > 0:
+                     effective_start_ms += overlap_ms
+                 if i < total_chunks - 1:
+                     effective_end_ms -= overlap_ms
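+                 # Worked example for a 25-minute (1500 s) file with these settings:
+                 # chunk boundaries are [0, 605], [595, 1205], [1195, 1500] (seconds),
+                 # while the effective regions are [0, 600], [600, 1200], [1200, 1500],
+                 # which tile the timeline with no gaps and no double-counting.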
+
+                 # Convert to seconds for logging
+                 chunk_start_sec = chunk_start_ms / 1000
+                 chunk_end_sec = chunk_end_ms / 1000
+                 effective_start_sec = effective_start_ms / 1000
+                 effective_end_sec = effective_end_ms / 1000
+
+                 print(f"Chunk {i+1} boundaries: {chunk_start_sec:.2f}s - {chunk_end_sec:.2f}s")
+                 print(f"Chunk {i+1} effective: {effective_start_sec:.2f}s - {effective_end_sec:.2f}s")
+
+                 # Extract the chunk
+                 chunk = audio[chunk_start_ms:chunk_end_ms]
+
+                 # Save the chunk to a temporary file
+                 chunk_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+                 chunk.export(chunk_file.name, format="wav")
+                 temp_files.append(chunk_file.name)
+                 chunk_file.close()
+
+                 try:
+                     # Move the model to the GPU at the latest possible time
+                     model.to(device)
+
+                     # Process the chunk
+                     chunk_duration = (chunk_end_ms - chunk_start_ms) / 1000.0
+                     gr.Info(f"Transcribing chunk {i+1}/{total_chunks} ({chunk_start_sec:.1f}s - {chunk_end_sec:.1f}s, {chunk_duration:.1f}s)...", duration=2)
+
+                     output = model.transcribe([chunk_file.name], timestamps=True)
+
+                     # Move the model back to the CPU immediately after processing
+                     if device == 'cuda':
+                         model.cpu()
+
+                     if (output and isinstance(output, list) and output[0] and
+                             hasattr(output[0], 'timestamp') and output[0].timestamp and
+                             'segment' in output[0].timestamp):
+
+                         chunk_segments = output[0].timestamp['segment']
+                         segments_before = len(all_segments)
+
+                         print(f"Chunk {i+1}: Got {len(chunk_segments)} segments")
+
+                         # Add all segments from this chunk, adjusting timestamps
+                         for segment in chunk_segments:
+                             # Adjust timestamps to the global timeline
+                             segment_start = segment['start'] + chunk_start_sec
+                             segment_end = segment['end'] + chunk_start_sec
+
+                             # Only keep segments that are mostly within the effective
+                             # region, using the segment midpoint to determine inclusion
+                             segment_midpoint = (segment_start + segment_end) / 2
+                             if effective_start_sec <= segment_midpoint <= effective_end_sec:
+                                 all_segments.append({
+                                     'start': segment_start,
+                                     'end': segment_end,
+                                     'segment': segment['segment']
+                                 })
+
+                         print(f"Chunk {i+1}: Added {len(all_segments) - segments_before} segments (total now: {len(all_segments)})")
+
+                     # Clean memory between chunks
+                     gc.collect()
+                     if device == 'cuda':
+                         torch.cuda.empty_cache()
+
+                 except torch.cuda.OutOfMemoryError as oom_e:
+                     print(f"CUDA Out of Memory error on chunk {i+1}: {oom_e}")
+                     gr.Warning(f"CUDA memory error on chunk {i+1}. Trying to continue with remaining chunks.", duration=3)
+                     if device == 'cuda':
+                         model.cpu()  # Make sure we move back to CPU
+                         torch.cuda.empty_cache()
+                     gc.collect()
+                     # Continue with the next chunk
+
+                 except Exception as chunk_e:
+                     gr.Warning(f"Error processing chunk {i+1}: {chunk_e}", duration=3)
+                     print(f"Error processing chunk {i+1}: {chunk_e}")
+                     if device == 'cuda':
+                         model.cpu()  # Make sure we move back to CPU
+                     # Continue with other chunks even if one fails
+
+         else:
+             # For shorter audio, process the entire file at once
+             temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+             audio.export(temp_file.name, format="wav")
+             temp_files.append(temp_file.name)
+             temp_file.close()
+
+             try:
+                 model.to(device)
+                 gr.Info(f"Transcribing {original_path_name} on {device}...", duration=2)
+                 output = model.transcribe([temp_file.name], timestamps=True)
+
+                 # Move the model back to the CPU immediately
+                 if device == 'cuda':
+                     model.cpu()
+
+                 if (not output or not isinstance(output, list) or not output[0]
+                         or not hasattr(output[0], 'timestamp') or not output[0].timestamp
+                         or 'segment' not in output[0].timestamp):
+                     gr.Error("Transcription failed or produced unexpected output format.", duration=None)
+                     return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
+
+                 chunk_segments = output[0].timestamp['segment']
+                 for segment in chunk_segments:
+                     all_segments.append({
+                         'start': segment['start'],
+                         'end': segment['end'],
+                         'segment': segment['segment']
+                     })
+                 print(f"Single chunk processing: Got {len(all_segments)} segments")
+
+             except torch.cuda.OutOfMemoryError as e:
+                 error_msg = 'CUDA out of memory. The file may be too large for available GPU memory.'
+                 print(f"CUDA OutOfMemoryError: {e}")
+                 gr.Error(error_msg, duration=None)
+                 return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
+
+             except Exception as e:
+                 error_msg = f"Transcription failed: {e}"
+                 print(f"Error during transcription processing: {e}")
+                 gr.Error(error_msg, duration=None)
+                 return [["Error", "Error", error_msg]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
+
+         # If we have no segments (all chunks failed), return an error
+         if len(all_segments) == 0:
+             gr.Error("Failed to transcribe any portion of the audio.", duration=None)
+             return [["Error", "Error", "No transcription segments generated"]], [[0.0, 0.0]], file_path, gr.DownloadButton(visible=False)
+
+         # Debug: print a few segments to check timestamps
+         print(f"All segments: {len(all_segments)}")
+         all_segments.sort(key=lambda x: x['start'])  # Ensure chronological order
+         print(f"First segment: {all_segments[0]['start']:.2f}s - {all_segments[0]['end']:.2f}s: {all_segments[0]['segment']}")
+         if len(all_segments) > 1:
+             print(f"Second segment: {all_segments[1]['start']:.2f}s - {all_segments[1]['end']:.2f}s: {all_segments[1]['segment']}")
+         if len(all_segments) > 2:
+             middle_idx = len(all_segments) // 2
+             print(f"Middle segment: {all_segments[middle_idx]['start']:.2f}s - {all_segments[middle_idx]['end']:.2f}s: {all_segments[middle_idx]['segment']}")
+             print(f"Last segment: {all_segments[-1]['start']:.2f}s - {all_segments[-1]['end']:.2f}s: {all_segments[-1]['segment']}")
+
+         # Create visualization data
+         vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in all_segments]
+         raw_times_data = [[ts['start'], ts['end']] for ts in all_segments]
+
+         # Generate SRT entries with the corrected timestamps
+         srt_lines = []
+         for i, ts in enumerate(all_segments, 1):
+             start = seconds_to_srt_ts(ts['start'])
+             end = seconds_to_srt_ts(ts['end'])
+             text = ts['segment'].replace('\n', ' ').strip()
+             srt_lines.append(f"{i}\n{start} --> {end}\n{text}\n")
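+         # Each appended entry is a standard SRT block, e.g.:
+         #   1
+         #   00:00:01,000 --> 00:00:04,200
+         #   First subtitle text
+         # The '\n'.join() below leaves a blank line between entries.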
292
+
293
+ # Save SRT file
294
+ button_update = gr.DownloadButton(visible=False)
295
+ try:
296
+ temp_srt_file = tempfile.NamedTemporaryFile(delete=False, suffix=".srt", mode='w', encoding='utf-8')
297
+ temp_srt_file.write('\n'.join(srt_lines))
298
+ srt_file_path = temp_srt_file.name
299
+ temp_srt_file.close()
300
+ print(f"SRT transcript saved to temporary file: {srt_file_path}")
301
+ button_update = gr.DownloadButton(value=srt_file_path, visible=True, label="Download Subtitle File (.srt)")
302
+ except Exception as srt_e:
303
+ gr.Error(f"Failed to create transcript SRT file: {srt_e}", duration=None)
304
+ print(f"Error writing SRT: {srt_e}")
305
+
306
+ gr.Info(f"Transcription complete! Generated {len(all_segments)} segments.", duration=2)
307
+ return vis_data, raw_times_data, file_path, button_update
308
+
309
+ finally:
310
+ # Clean up all temporary files
311
+ for temp_path in temp_files:
312
+ if temp_path and os.path.exists(temp_path):
313
+ try:
314
+ os.remove(temp_path)
315
+ print(f"Temporary file {temp_path} removed.")
316
+ except Exception as e:
317
+ print(f"Error removing temporary file {temp_path}: {e}")
318
+
319
+ # Final cleanup
320
+ try:
321
+ if 'model' in locals() and hasattr(model, 'cpu'):
322
+ if device == 'cuda':
323
+ model.cpu()
324
+ gc.collect()
325
+ if device == 'cuda':
326
+ torch.cuda.empty_cache()
327
+ except Exception as cleanup_e:
328
+ print(f"Error during model cleanup: {cleanup_e}")
329
+ gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
330
+
+ def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
+     if not isinstance(raw_ts_list, list):
+         print(f"Warning: raw_ts_list is not a list ({type(raw_ts_list)}). Cannot play segment.")
+         return gr.Audio(value=None, label="Selected Segment")
+     if not current_audio_path:
+         print("No audio path available to play segment from.")
+         return gr.Audio(value=None, label="Selected Segment")
+     selected_index = evt.index[0]
+     if selected_index < 0 or selected_index >= len(raw_ts_list):
+         print(f"Invalid index {selected_index} selected for list of length {len(raw_ts_list)}.")
+         return gr.Audio(value=None, label="Selected Segment")
+     if (not isinstance(raw_ts_list[selected_index], (list, tuple))
+             or len(raw_ts_list[selected_index]) != 2):
+         print(f"Warning: Data at index {selected_index} is not in the expected format [start, end].")
+         return gr.Audio(value=None, label="Selected Segment")
+     start_time_s, end_time_s = raw_ts_list[selected_index]
+     print(f"Attempting to play segment: {current_audio_path} from {start_time_s:.2f}s to {end_time_s:.2f}s")
+     segment_data = get_audio_segment(current_audio_path, start_time_s, end_time_s)
+     if segment_data:
+         print("Segment data retrieved successfully.")
+         return gr.Audio(value=segment_data, autoplay=True, label=f"Segment: {start_time_s:.2f}s - {end_time_s:.2f}s", interactive=False)
+     else:
+         print("Failed to get audio segment data.")
+         return gr.Audio(value=None, label="Selected Segment")
+
+ article = (
+     "<p style='font-size: 1.1em;'>"
+     "Upload an <b>audio file</b> (wav, mp3, etc.) <b>or a video file</b> (mp4, mov, etc.) and this tool will extract the audio stream and generate subtitles in .srt format.<br>"
+     "Files longer than 10 minutes will be automatically split into chunks for processing.</p>"
+ )
+
+ # NVIDIA-inspired theme
+ nvidia_theme = gr_themes.Default(
+     primary_hue=gr_themes.Color(
+         c50="#E6F1D9",
+         c100="#CEE3B3",
+         c200="#B5D58C",
+         c300="#9CC766",
+         c400="#84B940",
+         c500="#76B900",
+         c600="#68A600",
+         c700="#5A9200",
+         c800="#4C7E00",
+         c900="#3E6A00",
+         c950="#2F5600"
+     ),
+     neutral_hue="gray",
+     font=[gr_themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
+ ).set()
+
+ with gr.Blocks(theme=nvidia_theme) as demo:
+     model_display_name = MODEL_NAME.split('/')[-1] if '/' in MODEL_NAME else MODEL_NAME
+     gr.Markdown(f"<h1 style='text-align: center; margin: 0 auto;'>Subtitle Generation (en) with {model_display_name}</h1>")
+     gr.HTML(article)
+
+     current_audio_path_state = gr.State(None)
+     raw_timestamps_list_state = gr.State([])
+
+     # Use gr.File instead of gr.Audio so video files are accepted as well
+     file_input = gr.File(
+         label="Upload Audio or Video File (MP3, WAV, MP4, etc)",
+         file_types=[".mp3", ".wav", ".mp4", ".m4a", ".aac", ".ogg", ".flac", ".mov", ".mkv"],
+     )
+
+     file_transcribe_btn = gr.Button("Transcribe Uploaded File", variant="primary")
+
+     gr.Markdown("---")
+     gr.Markdown("<p><strong style='color: #FF0000; font-size: 1.2em;'>Transcription Results (Click row to play segment)</strong></p>")
+
+     download_btn = gr.DownloadButton(label="Download Subtitle File (.srt)", visible=False)
+
+     vis_timestamps_df = gr.DataFrame(
+         headers=["Start (s)", "End (s)", "Segment"],
+         datatype=["number", "number", "str"],
+         wrap=True,
+         label="Transcription Segments"
+     )
+
+     selected_segment_player = gr.Audio(label="Selected Segment", interactive=False)
+
+     file_transcribe_btn.click(
+         fn=get_transcripts_and_raw_times,
+         inputs=[file_input],
+         outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
+         api_name="transcribe_file"
+     )
+
+     vis_timestamps_df.select(
+         fn=play_segment,
+         inputs=[raw_timestamps_list_state, current_audio_path_state],
+         outputs=[selected_segment_player],
+     )
+
+ if __name__ == "__main__":
+     print("Launching Gradio Demo...")
+     demo.queue()
+     demo.launch(server_name="0.0.0.0")
packages.txt ADDED
@@ -0,0 +1,2 @@
+ ffmpeg
+ libsndfile1
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ Cython
+ git+https://github.com/NVIDIA/NeMo.git@r2.3.0#egg=nemo_toolkit[asr]
+ numpy<2.0
+ spaces
+ gradio