Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -12,7 +12,7 @@ import gradio.themes as gr_themes
 import csv
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-MODEL_NAME="nvidia/parakeet-tdt-0.6b-v2"
+MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
 
 model = ASRModel.from_pretrained(model_name=MODEL_NAME)
 model.eval()
@@ -39,12 +39,12 @@ def get_audio_segment(audio_path, start_second, end_second):
 
         frame_rate = clipped_audio.frame_rate
         if frame_rate <= 0:
-
-
+            print(f"Warning: Invalid frame rate ({frame_rate}) detected for clipped audio.")
+            frame_rate = audio.frame_rate
 
         if samples.size == 0:
-
-
+            print(f"Warning: Clipped audio resulted in empty samples array ({start_second}s to {end_second}s).")
+            return None
 
         return (frame_rate, samples)
     except FileNotFoundError:
@@ -56,9 +56,25 @@ def get_audio_segment(audio_path, start_second, end_second):
 
 @spaces.GPU
 def get_transcripts_and_raw_times(audio_path):
+    """
+    Transcribe an audio file or microphone input and return transcription segments, timestamps, and a CSV download button.
+
+    Exposed as MCP endpoints:
+      - transcribe_mic: for microphone recordings
+      - transcribe_file: for uploaded audio files
+
+    Parameters:
+        audio_path (str): Path to the audio file or microphone recording to transcribe.
+
+    Returns:
+        tuple: A 4-tuple containing:
+            - vis_data (List[List[str]]): Displayable transcription segments [start_str, end_str, segment_text].
+            - raw_times_data (List[List[float]]): Raw timestamps [[start, end], ...].
+            - current_audio_path (str): The path to the audio used for transcription.
+            - download_button (gr.DownloadButton): A Gradio DownloadButton component for downloading the transcript CSV.
+    """
     if not audio_path:
         gr.Error("No audio file path provided for transcription.", duration=None)
-        # Return an update to hide the button
         return [], [], None, gr.DownloadButton(visible=False)
 
     vis_data = [["N/A", "N/A", "Processing failed"]]
@@ -74,34 +90,29 @@ def get_transcripts_and_raw_times(audio_path):
         audio = AudioSegment.from_file(audio_path)
     except Exception as load_e:
         gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
-        # Return an update to hide the button
         return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
     resampled = False
     mono = False
-
     target_sr = 16000
     if audio.frame_rate != target_sr:
         try:
             audio = audio.set_frame_rate(target_sr)
             resampled = True
         except Exception as resample_e:
-
-
-            return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
+            gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
+            return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
     if audio.channels == 2:
         try:
             audio = audio.set_channels(1)
             mono = True
         except Exception as mono_e:
-
-
-            return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
+            gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
+            return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
     elif audio.channels > 2:
-
-
-        return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
+        gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
+        return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
     if resampled or mono:
         try:
@@ -113,9 +124,8 @@ def get_transcripts_and_raw_times(audio_path):
             info_path_name = f"{original_path_name} (processed)"
         except Exception as export_e:
             gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
-            if temp_file and hasattr(temp_file, 'name') and os.path.exists(temp_file.name):
+            if temp_file and hasattr(temp_file, 'name') and os.path.exists(temp_file.name):
                 os.remove(temp_file.name)
-            # Return an update to hide the button
             return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
     else:
         transcribe_path = audio_path
@@ -127,16 +137,14 @@ def get_transcripts_and_raw_times(audio_path):
         output = model.transcribe([transcribe_path], timestamps=True)
 
         if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
-
-
-            return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
+            gr.Error("Transcription failed or produced unexpected output format.", duration=None)
+            return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
         segment_timestamps = output[0].timestamp['segment']
         csv_headers = ["Start (s)", "End (s)", "Segment"]
         vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
         raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
 
-        # Default button update (hidden) in case CSV writing fails
         button_update = gr.DownloadButton(visible=False)
         try:
             temp_csv_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w', newline='', encoding='utf-8')
@@ -145,61 +153,65 @@ def get_transcripts_and_raw_times(audio_path):
             writer.writerows(vis_data)
             csv_file_path = temp_csv_file.name
             temp_csv_file.close()
-            print(f"CSV transcript saved to temporary file: {csv_file_path}")
-            # If CSV is saved, create update to show button with path
             button_update = gr.DownloadButton(value=csv_file_path, visible=True)
         except Exception as csv_e:
             gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
-            print(f"Error writing CSV: {csv_e}")
-            # csv_file_path remains None, button_update remains hidden
 
         gr.Info("Transcription complete.", duration=2)
-        # Return the data and the button update dictionary
         return vis_data, raw_times_data, audio_path, button_update
 
     except torch.cuda.OutOfMemoryError as e:
         error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
-        print(f"CUDA OutOfMemoryError: {e}")
         gr.Error(error_msg, duration=None)
-        # Return an update to hide the button
         return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
     except FileNotFoundError:
         error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
-        print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
         gr.Error(error_msg, duration=None)
-        # Return an update to hide the button
         return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, gr.DownloadButton(visible=False)
 
     except Exception as e:
         error_msg = f"Transcription failed: {e}"
-        print(f"Error during transcription processing: {e}")
         gr.Error(error_msg, duration=None)
         vis_data = [["Error", "Error", error_msg]]
         raw_times_data = [[0.0, 0.0]]
-        # Return an update to hide the button
         return vis_data, raw_times_data, audio_path, gr.DownloadButton(visible=False)
+
     finally:
         try:
             if 'model' in locals() and hasattr(model, 'cpu'):
-
-
+                if device == 'cuda':
+                    model.cpu()
             gc.collect()
             if device == 'cuda':
                 torch.cuda.empty_cache()
         except Exception as cleanup_e:
-            print(f"Error during model cleanup: {cleanup_e}")
            gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
 
        finally:
            if processed_audio_path and os.path.exists(processed_audio_path):
                try:
                    os.remove(processed_audio_path)
-                    print(f"Temporary audio file {processed_audio_path} removed.")
                except Exception as e:
                    print(f"Error removing temporary audio file {processed_audio_path}: {e}")
 
+
+@spaces.API(name="play_segment")
 def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
+    """
+    Play a selected audio segment based on the user's selection event.
+
+    Exposed as MCP endpoint:
+      - play_segment
+
+    Parameters:
+        evt (gr.SelectData): The Gradio SelectData event triggered by selecting a row in the DataFrame.
+        raw_ts_list (List[List[float]]): List of timestamp pairs [[start, end], ...] from transcription.
+        current_audio_path (str): Path to the original audio file.
+
+    Returns:
+        gr.Audio: A Gradio Audio component for the clipped segment or an empty Audio component on error.
+    """
     if not isinstance(raw_ts_list, list):
         print(f"Warning: raw_ts_list is not a list ({type(raw_ts_list)}). Cannot play segment.")
         return gr.Audio(value=None, label="Selected Segment")
@@ -211,15 +223,14 @@ def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
     selected_index = evt.index[0]
 
     if selected_index < 0 or selected_index >= len(raw_ts_list):
-
-
+        print(f"Invalid index {selected_index} selected for list of length {len(raw_ts_list)}.")
+        return gr.Audio(value=None, label="Selected Segment")
 
     if not isinstance(raw_ts_list[selected_index], (list, tuple)) or len(raw_ts_list[selected_index]) != 2:
-
-
+        print(f"Warning: Data at index {selected_index} is not in the expected format [start, end].")
+        return gr.Audio(value=None, label="Selected Segment")
 
     start_time_s, end_time_s = raw_ts_list[selected_index]
-
     print(f"Attempting to play segment: {current_audio_path} from {start_time_s:.2f}s to {end_time_s:.2f}s")
 
     segment_data = get_audio_segment(current_audio_path, start_time_s, end_time_s)
@@ -334,4 +345,4 @@ with gr.Blocks(theme=nvidia_theme) as demo:
 if __name__ == "__main__":
     print("Launching Gradio Demo...")
     demo.queue()
-    demo.launch(mcp_server=True)
+    demo.launch(mcp_server=True)
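
The docstrings added in this commit describe the functions as MCP-exposed tools, and the same endpoints remain callable through the regular Gradio client API. Below is a minimal client sketch for exercising the transcription endpoint once the change is deployed; the Space ID and the `/transcribe_file` endpoint name are assumptions taken from MODEL_NAME and the docstring above, not something this diff pins down.

```python
# Hedged sketch (not part of this commit): calling the transcription endpoint
# remotely via gradio_client. The Space ID and api_name are assumptions taken
# from MODEL_NAME and the docstring above; adjust them to match the live Space.
from gradio_client import Client, handle_file

client = Client("nvidia/parakeet-tdt-0.6b-v2")  # assumed Space ID

result = client.predict(
    handle_file("sample.wav"),    # local audio file to upload
    api_name="/transcribe_file",  # endpoint name per the new docstring
)

# Server-side, get_transcripts_and_raw_times returns
# (vis_data, raw_times_data, audio_path, download_button); the exact
# client-side shape depends on the output components wired up in the UI.
print(result)
```

Because `demo.launch(mcp_server=True)` is set, recent Gradio releases should also expose these tools to MCP clients, typically at the app's `/gradio_api/mcp/sse` endpoint.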