fdaudens HF Staff committed on
Commit
ec29942
·
verified ·
1 Parent(s): 59b3391

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -44
app.py CHANGED
@@ -12,7 +12,7 @@ import gradio.themes as gr_themes
12
  import csv
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
- MODEL_NAME="nvidia/parakeet-tdt-0.6b-v2"
16
 
17
  model = ASRModel.from_pretrained(model_name=MODEL_NAME)
18
  model.eval()
@@ -39,12 +39,12 @@ def get_audio_segment(audio_path, start_second, end_second):
39
 
40
  frame_rate = clipped_audio.frame_rate
41
  if frame_rate <= 0:
42
- print(f"Warning: Invalid frame rate ({frame_rate}) detected for clipped audio.")
43
- frame_rate = audio.frame_rate
44
 
45
  if samples.size == 0:
46
- print(f"Warning: Clipped audio resulted in empty samples array ({start_second}s to {end_second}s).")
47
- return None
48
 
49
  return (frame_rate, samples)
50
  except FileNotFoundError:
@@ -56,9 +56,25 @@ def get_audio_segment(audio_path, start_second, end_second):
56
 
57
  @spaces.GPU
58
  def get_transcripts_and_raw_times(audio_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  if not audio_path:
60
  gr.Error("No audio file path provided for transcription.", duration=None)
61
- # Return an update to hide the button
62
  return [], [], None, gr.DownloadButton(visible=False)
63
 
64
  vis_data = [["N/A", "N/A", "Processing failed"]]
@@ -74,34 +90,29 @@ def get_transcripts_and_raw_times(audio_path):
74
  audio = AudioSegment.from_file(audio_path)
75
  except Exception as load_e:
76
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
77
- # Return an update to hide the button
78
  return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
79
 
80
  resampled = False
81
  mono = False
82
-
83
  target_sr = 16000
84
  if audio.frame_rate != target_sr:
85
  try:
86
  audio = audio.set_frame_rate(target_sr)
87
  resampled = True
88
  except Exception as resample_e:
89
- gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
90
- # Return an update to hide the button
91
- return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
92
 
93
  if audio.channels == 2:
94
  try:
95
  audio = audio.set_channels(1)
96
  mono = True
97
  except Exception as mono_e:
98
- gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
99
- # Return an update to hide the button
100
- return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
101
  elif audio.channels > 2:
102
- gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
103
- # Return an update to hide the button
104
- return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
105
 
106
  if resampled or mono:
107
  try:
@@ -113,9 +124,8 @@ def get_transcripts_and_raw_times(audio_path):
113
  info_path_name = f"{original_path_name} (processed)"
114
  except Exception as export_e:
115
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
116
- if temp_file and hasattr(temp_file, 'name') and os.path.exists(temp_file.name): # Check temp_file has 'name' attribute
117
  os.remove(temp_file.name)
118
- # Return an update to hide the button
119
  return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
120
  else:
121
  transcribe_path = audio_path
@@ -127,16 +137,14 @@ def get_transcripts_and_raw_times(audio_path):
127
  output = model.transcribe([transcribe_path], timestamps=True)
128
 
129
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
130
- gr.Error("Transcription failed or produced unexpected output format.", duration=None)
131
- # Return an update to hide the button
132
- return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
133
 
134
  segment_timestamps = output[0].timestamp['segment']
135
  csv_headers = ["Start (s)", "End (s)", "Segment"]
136
  vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
137
  raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
138
 
139
- # Default button update (hidden) in case CSV writing fails
140
  button_update = gr.DownloadButton(visible=False)
141
  try:
142
  temp_csv_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w', newline='', encoding='utf-8')
@@ -145,61 +153,65 @@ def get_transcripts_and_raw_times(audio_path):
145
  writer.writerows(vis_data)
146
  csv_file_path = temp_csv_file.name
147
  temp_csv_file.close()
148
- print(f"CSV transcript saved to temporary file: {csv_file_path}")
149
- # If CSV is saved, create update to show button with path
150
  button_update = gr.DownloadButton(value=csv_file_path, visible=True)
151
  except Exception as csv_e:
152
  gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
153
- print(f"Error writing CSV: {csv_e}")
154
- # csv_file_path remains None, button_update remains hidden
155
 
156
  gr.Info("Transcription complete.", duration=2)
157
- # Return the data and the button update dictionary
158
  return vis_data, raw_times_data, audio_path, button_update
159
 
160
  except torch.cuda.OutOfMemoryError as e:
161
  error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
162
- print(f"CUDA OutOfMemoryError: {e}")
163
  gr.Error(error_msg, duration=None)
164
- # Return an update to hide the button
165
  return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
166
 
167
  except FileNotFoundError:
168
  error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
169
- print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
170
  gr.Error(error_msg, duration=None)
171
- # Return an update to hide the button
172
  return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
173
 
174
  except Exception as e:
175
  error_msg = f"Transcription failed: {e}"
176
- print(f"Error during transcription processing: {e}")
177
  gr.Error(error_msg, duration=None)
178
  vis_data = [["Error", "Error", error_msg]]
179
  raw_times_data = [[0.0, 0.0]]
180
- # Return an update to hide the button
181
  return vis_data, raw_times_data, audio_path, gr.DownloadButton(visible=False)
 
182
  finally:
183
  try:
184
  if 'model' in locals() and hasattr(model, 'cpu'):
185
- if device == 'cuda':
186
- model.cpu()
187
  gc.collect()
188
  if device == 'cuda':
189
  torch.cuda.empty_cache()
190
  except Exception as cleanup_e:
191
- print(f"Error during model cleanup: {cleanup_e}")
192
  gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
193
 
194
  finally:
195
  if processed_audio_path and os.path.exists(processed_audio_path):
196
  try:
197
  os.remove(processed_audio_path)
198
- print(f"Temporary audio file {processed_audio_path} removed.")
199
  except Exception as e:
200
  print(f"Error removing temporary audio file {processed_audio_path}: {e}")
201
 
 
 
202
  def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  if not isinstance(raw_ts_list, list):
204
  print(f"Warning: raw_ts_list is not a list ({type(raw_ts_list)}). Cannot play segment.")
205
  return gr.Audio(value=None, label="Selected Segment")
@@ -211,15 +223,14 @@ def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
211
  selected_index = evt.index[0]
212
 
213
  if selected_index < 0 or selected_index >= len(raw_ts_list):
214
- print(f"Invalid index {selected_index} selected for list of length {len(raw_ts_list)}.")
215
- return gr.Audio(value=None, label="Selected Segment")
216
 
217
  if not isinstance(raw_ts_list[selected_index], (list, tuple)) or len(raw_ts_list[selected_index]) != 2:
218
- print(f"Warning: Data at index {selected_index} is not in the expected format [start, end].")
219
- return gr.Audio(value=None, label="Selected Segment")
220
 
221
  start_time_s, end_time_s = raw_ts_list[selected_index]
222
-
223
  print(f"Attempting to play segment: {current_audio_path} from {start_time_s:.2f}s to {end_time_s:.2f}s")
224
 
225
  segment_data = get_audio_segment(current_audio_path, start_time_s, end_time_s)
@@ -334,4 +345,4 @@ with gr.Blocks(theme=nvidia_theme) as demo:
334
  if __name__ == "__main__":
335
  print("Launching Gradio Demo...")
336
  demo.queue()
337
- demo.launch(mcp_server=True)
 
12
  import csv
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
16
 
17
  model = ASRModel.from_pretrained(model_name=MODEL_NAME)
18
  model.eval()
 
39
 
40
  frame_rate = clipped_audio.frame_rate
41
  if frame_rate <= 0:
42
+ print(f"Warning: Invalid frame rate ({frame_rate}) detected for clipped audio.")
43
+ frame_rate = audio.frame_rate
44
 
45
  if samples.size == 0:
46
+ print(f"Warning: Clipped audio resulted in empty samples array ({start_second}s to {end_second}s).")
47
+ return None
48
 
49
  return (frame_rate, samples)
50
  except FileNotFoundError:
 
56
 
57
  @spaces.GPU
58
  def get_transcripts_and_raw_times(audio_path):
59
+ """
60
+ Transcribe an audio file or microphone input and return transcription segments, timestamps, and a CSV download button.
61
+
62
+ Exposed as MCP endpoints:
63
+ - transcribe_mic: for microphone recordings
64
+ - transcribe_file: for uploaded audio files
65
+
66
+ Parameters:
67
+ audio_path (str): Path to the audio file or microphone recording to transcribe.
68
+
69
+ Returns:
70
+ tuple: A 4-tuple containing:
71
+ - vis_data (List[List[str]]): Displayable transcription segments [start_str, end_str, segment_text].
72
+ - raw_times_data (List[List[float]]): Raw timestamps [[start, end], ...].
73
+ - current_audio_path (str): The path to the audio used for transcription.
74
+ - download_button (gr.DownloadButton): A Gradio DownloadButton component for downloading the transcript CSV.
75
+ """
76
  if not audio_path:
77
  gr.Error("No audio file path provided for transcription.", duration=None)
 
78
  return [], [], None, gr.DownloadButton(visible=False)
79
 
80
  vis_data = [["N/A", "N/A", "Processing failed"]]
 
90
  audio = AudioSegment.from_file(audio_path)
91
  except Exception as load_e:
92
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
 
93
  return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
94
 
95
  resampled = False
96
  mono = False
 
97
  target_sr = 16000
98
  if audio.frame_rate != target_sr:
99
  try:
100
  audio = audio.set_frame_rate(target_sr)
101
  resampled = True
102
  except Exception as resample_e:
103
+ gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
104
+ return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
105
 
106
  if audio.channels == 2:
107
  try:
108
  audio = audio.set_channels(1)
109
  mono = True
110
  except Exception as mono_e:
111
+ gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
112
+ return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
113
  elif audio.channels > 2:
114
+ gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
115
+ return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
116
 
117
  if resampled or mono:
118
  try:
 
124
  info_path_name = f"{original_path_name} (processed)"
125
  except Exception as export_e:
126
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
127
+ if temp_file and hasattr(temp_file, 'name') and os.path.exists(temp_file.name):
128
  os.remove(temp_file.name)
 
129
  return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
130
  else:
131
  transcribe_path = audio_path
 
137
  output = model.transcribe([transcribe_path], timestamps=True)
138
 
139
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
140
+ gr.Error("Transcription failed or produced unexpected output format.", duration=None)
141
+ return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
142
 
143
  segment_timestamps = output[0].timestamp['segment']
144
  csv_headers = ["Start (s)", "End (s)", "Segment"]
145
  vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
146
  raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
147
 
 
148
  button_update = gr.DownloadButton(visible=False)
149
  try:
150
  temp_csv_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w', newline='', encoding='utf-8')
 
153
  writer.writerows(vis_data)
154
  csv_file_path = temp_csv_file.name
155
  temp_csv_file.close()
 
 
156
  button_update = gr.DownloadButton(value=csv_file_path, visible=True)
157
  except Exception as csv_e:
158
  gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
 
 
159
 
160
  gr.Info("Transcription complete.", duration=2)
 
161
  return vis_data, raw_times_data, audio_path, button_update
162
 
163
  except torch.cuda.OutOfMemoryError as e:
164
  error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
 
165
  gr.Error(error_msg, duration=None)
 
166
  return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
167
 
168
  except FileNotFoundError:
169
  error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
 
170
  gr.Error(error_msg, duration=None)
 
171
  return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
172
 
173
  except Exception as e:
174
  error_msg = f"Transcription failed: {e}"
 
175
  gr.Error(error_msg, duration=None)
176
  vis_data = [["Error", "Error", error_msg]]
177
  raw_times_data = [[0.0, 0.0]]
 
178
  return vis_data, raw_times_data, audio_path, gr.DownloadButton(visible=False)
179
+
180
  finally:
181
  try:
182
  if 'model' in locals() and hasattr(model, 'cpu'):
183
+ if device == 'cuda':
184
+ model.cpu()
185
  gc.collect()
186
  if device == 'cuda':
187
  torch.cuda.empty_cache()
188
  except Exception as cleanup_e:
 
189
  gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
190
 
191
  finally:
192
  if processed_audio_path and os.path.exists(processed_audio_path):
193
  try:
194
  os.remove(processed_audio_path)
 
195
  except Exception as e:
196
  print(f"Error removing temporary audio file {processed_audio_path}: {e}")
197
 
198
+
199
+ @spaces.API(name="play_segment")
200
  def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
201
+ """
202
+ Play a selected audio segment based on the user's selection event.
203
+
204
+ Exposed as MCP endpoint:
205
+ - play_segment
206
+
207
+ Parameters:
208
+ evt (gr.SelectData): The Gradio SelectData event triggered by selecting a row in the DataFrame.
209
+ raw_ts_list (List[List[float]]): List of timestamp pairs [[start, end], ...] from transcription.
210
+ current_audio_path (str): Path to the original audio file.
211
+
212
+ Returns:
213
+ gr.Audio: A Gradio Audio component for the clipped segment or an empty Audio component on error.
214
+ """
215
  if not isinstance(raw_ts_list, list):
216
  print(f"Warning: raw_ts_list is not a list ({type(raw_ts_list)}). Cannot play segment.")
217
  return gr.Audio(value=None, label="Selected Segment")
 
223
  selected_index = evt.index[0]
224
 
225
  if selected_index < 0 or selected_index >= len(raw_ts_list):
226
+ print(f"Invalid index {selected_index} selected for list of length {len(raw_ts_list)}.")
227
+ return gr.Audio(value=None, label="Selected Segment")
228
 
229
  if not isinstance(raw_ts_list[selected_index], (list, tuple)) or len(raw_ts_list[selected_index]) != 2:
230
+ print(f"Warning: Data at index {selected_index} is not in the expected format [start, end].")
231
+ return gr.Audio(value=None, label="Selected Segment")
232
 
233
  start_time_s, end_time_s = raw_ts_list[selected_index]
 
234
  print(f"Attempting to play segment: {current_audio_path} from {start_time_s:.2f}s to {end_time_s:.2f}s")
235
 
236
  segment_data = get_audio_segment(current_audio_path, start_time_s, end_time_s)
 
345
  if __name__ == "__main__":
346
  print("Launching Gradio Demo...")
347
  demo.queue()
348
+ demo.launch(mcp_server=True)