mgruner committed on
Commit
c913b1a
·
1 Parent(s): 4704268

Add support for timestamps as well

Browse files
Files changed (3) hide show
  1. pyproject.toml +1 -0
  2. run.py +40 -4
  3. uv.lock +2 -0
pyproject.toml CHANGED
@@ -8,5 +8,6 @@ dependencies = [
8
  "gradio>=5.29.0",
9
  "nemo-toolkit[asr]>=2.2.1",
10
  "numpy<2.0",
 
11
  "scipy>=1.15.2",
12
  ]
 
8
  "gradio>=5.29.0",
9
  "nemo-toolkit[asr]>=2.2.1",
10
  "numpy<2.0",
11
+ "pandas>=2.2.3",
12
  "scipy>=1.15.2",
13
  ]
run.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import nemo.collections.asr as nemo_asr
3
  import numpy as np
 
4
  from scipy import signal
5
 
6
  TARGET_SR = 16_000 # Hz
@@ -80,14 +81,32 @@ def _resample(audio: np.ndarray, rate: int, target_rate: int) -> np.ndarray:
80
  return resampled
81
 
82
 
83
- def _invoke_model(audio: np.ndarray):
84
  global _model
85
  if not _model:
86
  _model = nemo_asr.models.ASRModel.from_pretrained(
87
  model_name="nvidia/parakeet-tdt-0.6b-v2"
88
  )
89
 
90
- return _model.transcribe(audio=audio)[0].text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
 
93
  def transcribe(audio: tuple[np.ndarray, int] | None):
@@ -96,9 +115,12 @@ def transcribe(audio: tuple[np.ndarray, int] | None):
96
 
97
  rate, data = audio
98
 
 
99
  data = _to_float32(data)
100
  data = _resample(data, rate, TARGET_SR)
101
- return _invoke_model(data)
 
 
102
 
103
 
104
  app = gr.Interface(
@@ -108,7 +130,21 @@ app = gr.Interface(
108
  type="numpy",
109
  label="Upload or record audio",
110
  ),
111
- outputs=gr.Textbox(label="Transcription", show_copy_button=True),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  title=TITLE,
113
  description=DESCRIPTION,
114
  )
 
1
  import gradio as gr
2
  import nemo.collections.asr as nemo_asr
3
  import numpy as np
4
+ import pandas as pd
5
  from scipy import signal
6
 
7
  TARGET_SR = 16_000 # Hz
 
81
  return resampled
82
 
83
 
84
+ def _load_model():
85
  global _model
86
  if not _model:
87
  _model = nemo_asr.models.ASRModel.from_pretrained(
88
  model_name="nvidia/parakeet-tdt-0.6b-v2"
89
  )
90
 
91
+ return _model
92
+
93
+
94
+ def _to_pandas(prediction, keyword):
95
+ return pd.DataFrame(prediction.timestamp[keyword])[
96
+ [keyword, "start", "end"]
97
+ ]
98
+
99
+
100
+ def _invoke_model(model, audio: np.ndarray):
101
+ prediction = model.transcribe(audio=audio, timestamps=True)[0]
102
+
103
+ text = prediction.text
104
+
105
+ chars = _to_pandas(prediction, "char")
106
+ words = _to_pandas(prediction, "word")
107
+ segments = _to_pandas(prediction, "segment")
108
+
109
+ return text, chars, words, segments
110
 
111
 
112
  def transcribe(audio: tuple[np.ndarray, int] | None):
 
115
 
116
  rate, data = audio
117
 
118
+ model = _load_model()
119
  data = _to_float32(data)
120
  data = _resample(data, rate, TARGET_SR)
121
+ text, chars, words, segments = _invoke_model(model, data)
122
+
123
+ return text, segments, words, chars
124
 
125
 
126
  app = gr.Interface(
 
130
  type="numpy",
131
  label="Upload or record audio",
132
  ),
133
+ outputs=[
134
+ gr.Textbox(label="Transcription", show_copy_button=True),
135
+ gr.Dataframe(
136
+ label="Segments",
137
+ headers=["Segment", "Start", "End"],
138
+ ),
139
+ gr.Dataframe(
140
+ label="Words",
141
+ headers=["Word", "Start", "End"],
142
+ ),
143
+ gr.Dataframe(
144
+ label="Characters",
145
+ headers=["Character", "Start", "End"],
146
+ ),
147
+ ],
148
  title=TITLE,
149
  description=DESCRIPTION,
150
  )
uv.lock CHANGED
@@ -2694,6 +2694,7 @@ dependencies = [
2694
  { name = "gradio" },
2695
  { name = "nemo-toolkit", extra = ["asr"] },
2696
  { name = "numpy" },
 
2697
  { name = "scipy" },
2698
  ]
2699
 
@@ -2702,6 +2703,7 @@ requires-dist = [
2702
  { name = "gradio", specifier = ">=5.29.0" },
2703
  { name = "nemo-toolkit", extras = ["asr"], specifier = ">=2.2.1" },
2704
  { name = "numpy", specifier = "<2.0" },
 
2705
  { name = "scipy", specifier = ">=1.15.2" },
2706
  ]
2707
 
 
2694
  { name = "gradio" },
2695
  { name = "nemo-toolkit", extra = ["asr"] },
2696
  { name = "numpy" },
2697
+ { name = "pandas" },
2698
  { name = "scipy" },
2699
  ]
2700
 
 
2703
  { name = "gradio", specifier = ">=5.29.0" },
2704
  { name = "nemo-toolkit", extras = ["asr"], specifier = ">=2.2.1" },
2705
  { name = "numpy", specifier = "<2.0" },
2706
+ { name = "pandas", specifier = ">=2.2.3" },
2707
  { name = "scipy", specifier = ">=1.15.2" },
2708
  ]
2709