Nithin Rao Koluguri commited on
Commit
f4154c5
·
1 Parent(s): 0e3aa4b

Add SRT download button

Browse files

Signed-off-by: Nithin Rao Koluguri <nithinraok>

Files changed (1) hide show
  1. app.py +65 -30
app.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
10
  import os
11
  import gradio.themes as gr_themes
12
  import csv
 
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  MODEL_NAME="nvidia/parakeet-tdt-0.6b-v2"
@@ -72,20 +73,52 @@ def get_audio_segment(audio_path, start_second, end_second):
72
  print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
73
  return None
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  @spaces.GPU
76
  def get_transcripts_and_raw_times(audio_path, session_dir):
77
  if not audio_path:
78
  gr.Error("No audio file path provided for transcription.", duration=None)
79
- # Return an update to hide the button
80
- return [], [], None, gr.DownloadButton(visible=False)
81
 
82
  vis_data = [["N/A", "N/A", "Processing failed"]]
83
  raw_times_data = [[0.0, 0.0]]
84
  processed_audio_path = None
85
  csv_file_path = None
 
86
  original_path_name = Path(audio_path).name
87
  audio_name = Path(audio_path).stem
88
 
 
 
 
 
89
  try:
90
  try:
91
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
@@ -93,8 +126,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
93
  duration_sec = audio.duration_seconds
94
  except Exception as load_e:
95
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
96
- # Return an update to hide the button
97
- return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
98
 
99
  resampled = False
100
  mono = False
@@ -106,8 +138,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
106
  resampled = True
107
  except Exception as resample_e:
108
  gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
109
- # Return an update to hide the button
110
- return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
111
 
112
  if audio.channels == 2:
113
  try:
@@ -115,12 +146,10 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
115
  mono = True
116
  except Exception as mono_e:
117
  gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
118
- # Return an update to hide the button
119
- return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
120
  elif audio.channels > 2:
121
  gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
122
- # Return an update to hide the button
123
- return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
124
 
125
  if resampled or mono:
126
  try:
@@ -132,8 +161,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
132
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
133
  if processed_audio_path and os.path.exists(processed_audio_path):
134
  os.remove(processed_audio_path)
135
- # Return an update to hide the button
136
- return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
137
  else:
138
  transcribe_path = audio_path
139
  info_path_name = original_path_name
@@ -163,46 +191,52 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
163
 
164
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
165
  gr.Error("Transcription failed or produced unexpected output format.", duration=None)
166
- # Return an update to hide the button
167
- return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
168
 
169
  segment_timestamps = output[0].timestamp['segment']
170
  csv_headers = ["Start (s)", "End (s)", "Segment"]
171
  vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
172
  raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
173
 
174
- # Default button update (hidden) in case CSV writing fails
175
- button_update = gr.DownloadButton(visible=False)
176
  try:
177
  csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
178
  writer = csv.writer(open(csv_file_path, 'w'))
179
  writer.writerow(csv_headers)
180
  writer.writerows(vis_data)
181
  print(f"CSV transcript saved to temporary file: {csv_file_path}")
182
- # If CSV is saved, create update to show button with path
183
- button_update = gr.DownloadButton(value=csv_file_path, visible=True)
184
  except Exception as csv_e:
185
  gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
186
  print(f"Error writing CSV: {csv_e}")
187
- # csv_file_path remains None, button_update remains hidden
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  gr.Info("Transcription complete.", duration=2)
190
- # Return the data and the button update dictionary
191
- return vis_data, raw_times_data, audio_path, button_update
192
 
193
  except torch.cuda.OutOfMemoryError as e:
194
  error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
195
  print(f"CUDA OutOfMemoryError: {e}")
196
  gr.Error(error_msg, duration=None)
197
- # Return an update to hide the button
198
- return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
199
 
200
  except FileNotFoundError:
201
  error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
202
  print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
203
  gr.Error(error_msg, duration=None)
204
- # Return an update to hide the button
205
- return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
206
 
207
  except Exception as e:
208
  error_msg = f"Transcription failed: {e}"
@@ -210,8 +244,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
210
  gr.Error(error_msg, duration=None)
211
  vis_data = [["Error", "Error", error_msg]]
212
  raw_times_data = [[0.0, 0.0]]
213
- # Return an update to hide the button
214
- return vis_data, raw_times_data, audio_path, gr.DownloadButton(visible=False)
215
  finally:
216
  # --- Model Cleanup ---
217
  try:
@@ -349,7 +382,9 @@ with gr.Blocks(theme=nvidia_theme) as demo:
349
  gr.Markdown("<p><strong style='color: #FF0000; font-size: 1.2em;'>Transcription Results (Click row to play segment)</strong></p>")
350
 
351
  # Define the DownloadButton *before* the DataFrame
352
- download_btn = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
 
 
353
 
354
  vis_timestamps_df = gr.DataFrame(
355
  headers=["Start (s)", "End (s)", "Segment"],
@@ -364,14 +399,14 @@ with gr.Blocks(theme=nvidia_theme) as demo:
364
  mic_transcribe_btn.click(
365
  fn=get_transcripts_and_raw_times,
366
  inputs=[mic_input, session_dir],
367
- outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
368
  api_name="transcribe_mic"
369
  )
370
 
371
  file_transcribe_btn.click(
372
  fn=get_transcripts_and_raw_times,
373
  inputs=[file_input, session_dir],
374
- outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn],
375
  api_name="transcribe_file"
376
  )
377
 
 
10
  import os
11
  import gradio.themes as gr_themes
12
  import csv
13
+ import datetime
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  MODEL_NAME="nvidia/parakeet-tdt-0.6b-v2"
 
73
  print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
74
  return None
75
 
76
+ def format_srt_time(seconds: float) -> str:
77
+ """Converts seconds to SRT time format HH:MM:SS,mmm using datetime.timedelta"""
78
+ sanitized_total_seconds = max(0.0, seconds)
79
+ delta = datetime.timedelta(seconds=sanitized_total_seconds)
80
+ total_int_seconds = int(delta.total_seconds())
81
+
82
+ hours = total_int_seconds // 3600
83
+ remainder_seconds_after_hours = total_int_seconds % 3600
84
+ minutes = remainder_seconds_after_hours // 60
85
+ seconds_part = remainder_seconds_after_hours % 60
86
+ milliseconds = delta.microseconds // 1000
87
+
88
+ return f"{hours:02d}:{minutes:02d}:{seconds_part:02d},{milliseconds:03d}"
89
+
90
+ def generate_srt_content(segment_timestamps: list) -> str:
91
+ """Generates SRT formatted string from segment timestamps."""
92
+ srt_content = []
93
+ for i, ts in enumerate(segment_timestamps):
94
+ start_time = format_srt_time(ts['start'])
95
+ end_time = format_srt_time(ts['end'])
96
+ text = ts['segment']
97
+ srt_content.append(str(i + 1))
98
+ srt_content.append(f"{start_time} --> {end_time}")
99
+ srt_content.append(text)
100
+ srt_content.append("")
101
+ return "\n".join(srt_content)
102
+
103
  @spaces.GPU
104
  def get_transcripts_and_raw_times(audio_path, session_dir):
105
  if not audio_path:
106
  gr.Error("No audio file path provided for transcription.", duration=None)
107
+ # Return an update to hide the buttons
108
+ return [], [], None, gr.DownloadButton(label="Download Transcript (CSV)", visible=False), gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
109
 
110
  vis_data = [["N/A", "N/A", "Processing failed"]]
111
  raw_times_data = [[0.0, 0.0]]
112
  processed_audio_path = None
113
  csv_file_path = None
114
+ srt_file_path = None
115
  original_path_name = Path(audio_path).name
116
  audio_name = Path(audio_path).stem
117
 
118
+ # Initialize button states
119
+ csv_button_update = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
120
+ srt_button_update = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
121
+
122
  try:
123
  try:
124
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
 
126
  duration_sec = audio.duration_seconds
127
  except Exception as load_e:
128
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
129
+ return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
130
 
131
  resampled = False
132
  mono = False
 
138
  resampled = True
139
  except Exception as resample_e:
140
  gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
141
+ return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
142
 
143
  if audio.channels == 2:
144
  try:
 
146
  mono = True
147
  except Exception as mono_e:
148
  gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
149
+ return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
150
  elif audio.channels > 2:
151
  gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
152
+ return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
153
 
154
  if resampled or mono:
155
  try:
 
161
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
162
  if processed_audio_path and os.path.exists(processed_audio_path):
163
  os.remove(processed_audio_path)
164
+ return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
165
  else:
166
  transcribe_path = audio_path
167
  info_path_name = original_path_name
 
191
 
192
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
193
  gr.Error("Transcription failed or produced unexpected output format.", duration=None)
194
+ # Return an update to hide the buttons
195
+ return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
196
 
197
  segment_timestamps = output[0].timestamp['segment']
198
  csv_headers = ["Start (s)", "End (s)", "Segment"]
199
  vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
200
  raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
201
 
202
+ # CSV file generation
 
203
  try:
204
  csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
205
  writer = csv.writer(open(csv_file_path, 'w'))
206
  writer.writerow(csv_headers)
207
  writer.writerows(vis_data)
208
  print(f"CSV transcript saved to temporary file: {csv_file_path}")
209
+ csv_button_update = gr.DownloadButton(value=csv_file_path, visible=True, label="Download Transcript (CSV)")
 
210
  except Exception as csv_e:
211
  gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
212
  print(f"Error writing CSV: {csv_e}")
213
+
214
+ if segment_timestamps:
215
+ try:
216
+ srt_content = generate_srt_content(segment_timestamps)
217
+ srt_file_path = Path(session_dir, f"transcription_{audio_name}.srt")
218
+ with open(srt_file_path, 'w', encoding='utf-8') as f:
219
+ f.write(srt_content)
220
+ print(f"SRT transcript saved to temporary file: {srt_file_path}")
221
+ srt_button_update = gr.DownloadButton(value=srt_file_path, visible=True, label="Download Transcript (SRT)")
222
+ except Exception as srt_e:
223
+ gr.Warning(f"Failed to create transcript SRT file: {srt_e}", duration=5)
224
+ print(f"Error writing SRT: {srt_e}")
225
 
226
  gr.Info("Transcription complete.", duration=2)
227
+ return vis_data, raw_times_data, audio_path, csv_button_update, srt_button_update
 
228
 
229
  except torch.cuda.OutOfMemoryError as e:
230
  error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
231
  print(f"CUDA OutOfMemoryError: {e}")
232
  gr.Error(error_msg, duration=None)
233
+ return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
234
 
235
  except FileNotFoundError:
236
  error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
237
  print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
238
  gr.Error(error_msg, duration=None)
239
+ return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
 
240
 
241
  except Exception as e:
242
  error_msg = f"Transcription failed: {e}"
 
244
  gr.Error(error_msg, duration=None)
245
  vis_data = [["Error", "Error", error_msg]]
246
  raw_times_data = [[0.0, 0.0]]
247
+ return vis_data, raw_times_data, audio_path, csv_button_update, srt_button_update
 
248
  finally:
249
  # --- Model Cleanup ---
250
  try:
 
382
  gr.Markdown("<p><strong style='color: #FF0000; font-size: 1.2em;'>Transcription Results (Click row to play segment)</strong></p>")
383
 
384
  # Define the DownloadButton *before* the DataFrame
385
+ with gr.Row():
386
+ download_btn_csv = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
387
+ download_btn_srt = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)
388
 
389
  vis_timestamps_df = gr.DataFrame(
390
  headers=["Start (s)", "End (s)", "Segment"],
 
399
  mic_transcribe_btn.click(
400
  fn=get_transcripts_and_raw_times,
401
  inputs=[mic_input, session_dir],
402
+ outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
403
  api_name="transcribe_mic"
404
  )
405
 
406
  file_transcribe_btn.click(
407
  fn=get_transcripts_and_raw_times,
408
  inputs=[file_input, session_dir],
409
+ outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
410
  api_name="transcribe_file"
411
  )
412