Add support for timestamps as well
Browse files- pyproject.toml +1 -0
- run.py +40 -4
- uv.lock +2 -0
pyproject.toml
CHANGED
@@ -8,5 +8,6 @@ dependencies = [
|
|
8 |
"gradio>=5.29.0",
|
9 |
"nemo-toolkit[asr]>=2.2.1",
|
10 |
"numpy<2.0",
|
|
|
11 |
"scipy>=1.15.2",
|
12 |
]
|
|
|
8 |
"gradio>=5.29.0",
|
9 |
"nemo-toolkit[asr]>=2.2.1",
|
10 |
"numpy<2.0",
|
11 |
+
"pandas>=2.2.3",
|
12 |
"scipy>=1.15.2",
|
13 |
]
|
run.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import nemo.collections.asr as nemo_asr
|
3 |
import numpy as np
|
|
|
4 |
from scipy import signal
|
5 |
|
6 |
TARGET_SR = 16_000 # Hz
|
@@ -80,14 +81,32 @@ def _resample(audio: np.ndarray, rate: int, target_rate: int) -> np.ndarray:
|
|
80 |
return resampled
|
81 |
|
82 |
|
83 |
-
def
|
84 |
global _model
|
85 |
if not _model:
|
86 |
_model = nemo_asr.models.ASRModel.from_pretrained(
|
87 |
model_name="nvidia/parakeet-tdt-0.6b-v2"
|
88 |
)
|
89 |
|
90 |
-
return _model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
|
93 |
def transcribe(audio: tuple[np.ndarray, int] | None):
|
@@ -96,9 +115,12 @@ def transcribe(audio: tuple[np.ndarray, int] | None):
|
|
96 |
|
97 |
rate, data = audio
|
98 |
|
|
|
99 |
data = _to_float32(data)
|
100 |
data = _resample(data, rate, TARGET_SR)
|
101 |
-
|
|
|
|
|
102 |
|
103 |
|
104 |
app = gr.Interface(
|
@@ -108,7 +130,21 @@ app = gr.Interface(
|
|
108 |
type="numpy",
|
109 |
label="Upload or record audio",
|
110 |
),
|
111 |
-
outputs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
title=TITLE,
|
113 |
description=DESCRIPTION,
|
114 |
)
|
|
|
1 |
import gradio as gr
|
2 |
import nemo.collections.asr as nemo_asr
|
3 |
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
from scipy import signal
|
6 |
|
7 |
TARGET_SR = 16_000 # Hz
|
|
|
81 |
return resampled
|
82 |
|
83 |
|
84 |
+
def _load_model():
    """Return the shared ASR model, loading it on first use.

    The model is kept in the module-level ``_model`` variable so that the
    pretrained checkpoint is only fetched/instantiated once per process.
    """
    global _model

    # Fast path: a previous call already populated the cache.
    if _model:
        return _model

    _model = nemo_asr.models.ASRModel.from_pretrained(
        model_name="nvidia/parakeet-tdt-0.6b-v2"
    )
    return _model
|
92 |
+
|
93 |
+
|
94 |
+
def _to_pandas(prediction, keyword):
|
95 |
+
return pd.DataFrame(prediction.timestamp[keyword])[
|
96 |
+
[keyword, "start", "end"]
|
97 |
+
]
|
98 |
+
|
99 |
+
|
100 |
+
def _invoke_model(model, audio: np.ndarray):
    """Transcribe ``audio`` with timestamps and tabulate the results.

    Args:
        model: Object whose ``transcribe(audio=..., timestamps=True)`` call
            returns an indexable result; only the first element is used.
        audio: Audio samples passed through to ``model.transcribe``.

    Returns:
        A 4-tuple ``(text, chars, words, segments)``: the transcription
        text followed by the character-, word- and segment-level tables
        produced by ``_to_pandas``.
    """
    # Only the first result is read from the transcribe output.
    result = model.transcribe(audio=audio, timestamps=True)[0]
    transcript = result.text

    # Build the tables in char -> word -> segment order.
    char_table, word_table, segment_table = (
        _to_pandas(result, level) for level in ("char", "word", "segment")
    )

    return transcript, char_table, word_table, segment_table
|
110 |
|
111 |
|
112 |
def transcribe(audio: tuple[np.ndarray, int] | None):
|
|
|
115 |
|
116 |
rate, data = audio
|
117 |
|
118 |
+
model = _load_model()
|
119 |
data = _to_float32(data)
|
120 |
data = _resample(data, rate, TARGET_SR)
|
121 |
+
text, chars, words, segments = _invoke_model(model, data)
|
122 |
+
|
123 |
+
return text, segments, words, chars
|
124 |
|
125 |
|
126 |
app = gr.Interface(
|
|
|
130 |
type="numpy",
|
131 |
label="Upload or record audio",
|
132 |
),
|
133 |
+
outputs=[
|
134 |
+
gr.Textbox(label="Transcription", show_copy_button=True),
|
135 |
+
gr.Dataframe(
|
136 |
+
label="Segments",
|
137 |
+
headers=["Segment", "Start", "End"],
|
138 |
+
),
|
139 |
+
gr.Dataframe(
|
140 |
+
label="Words",
|
141 |
+
headers=["Word", "Start", "End"],
|
142 |
+
),
|
143 |
+
gr.Dataframe(
|
144 |
+
label="Characters",
|
145 |
+
headers=["Character", "Start", "End"],
|
146 |
+
),
|
147 |
+
],
|
148 |
title=TITLE,
|
149 |
description=DESCRIPTION,
|
150 |
)
|
uv.lock
CHANGED
@@ -2694,6 +2694,7 @@ dependencies = [
|
|
2694 |
{ name = "gradio" },
|
2695 |
{ name = "nemo-toolkit", extra = ["asr"] },
|
2696 |
{ name = "numpy" },
|
|
|
2697 |
{ name = "scipy" },
|
2698 |
]
|
2699 |
|
@@ -2702,6 +2703,7 @@ requires-dist = [
|
|
2702 |
{ name = "gradio", specifier = ">=5.29.0" },
|
2703 |
{ name = "nemo-toolkit", extras = ["asr"], specifier = ">=2.2.1" },
|
2704 |
{ name = "numpy", specifier = "<2.0" },
|
|
|
2705 |
{ name = "scipy", specifier = ">=1.15.2" },
|
2706 |
]
|
2707 |
|
|
|
2694 |
{ name = "gradio" },
|
2695 |
{ name = "nemo-toolkit", extra = ["asr"] },
|
2696 |
{ name = "numpy" },
|
2697 |
+
{ name = "pandas" },
|
2698 |
{ name = "scipy" },
|
2699 |
]
|
2700 |
|
|
|
2703 |
{ name = "gradio", specifier = ">=5.29.0" },
|
2704 |
{ name = "nemo-toolkit", extras = ["asr"], specifier = ">=2.2.1" },
|
2705 |
{ name = "numpy", specifier = "<2.0" },
|
2706 |
+
{ name = "pandas", specifier = ">=2.2.3" },
|
2707 |
{ name = "scipy", specifier = ">=1.15.2" },
|
2708 |
]
|
2709 |
|