Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,20 +1,23 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
-
import
|
4 |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
5 |
|
6 |
# Load model
|
7 |
-
processor = Wav2Vec2Processor.from_pretrained("
|
8 |
-
model = Wav2Vec2ForCTC.from_pretrained("
|
9 |
|
10 |
def transcribe(audio):
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
15 |
|
16 |
-
#
|
17 |
-
inputs = processor(
|
18 |
with torch.no_grad():
|
19 |
logits = model(**inputs).logits
|
20 |
predicted_ids = torch.argmax(logits, dim=-1)
|
@@ -24,8 +27,8 @@ def transcribe(audio):
|
|
24 |
# Gradio UI
|
25 |
gr.Interface(
|
26 |
fn=transcribe,
|
27 |
-
inputs=gr.Audio(type="
|
28 |
outputs="text",
|
29 |
-
title="
|
30 |
-
description="
|
31 |
).launch()
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
import numpy as np
|
4 |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
5 |
|
6 |
# Load model
|
7 |
+
processor = Wav2Vec2Processor.from_pretrained("FPTAI/vietnamese-wav2vec2-base")
|
8 |
+
model = Wav2Vec2ForCTC.from_pretrained("FPTAI/vietnamese-wav2vec2-base")
|
9 |
|
10 |
def transcribe(audio):
|
11 |
+
if audio is None:
|
12 |
+
return "Không có âm thanh."
|
13 |
+
|
14 |
+
# Gradio trả về (sample_rate, numpy_array)
|
15 |
+
sample_rate, audio_data = audio
|
16 |
+
if sample_rate != 16000:
|
17 |
+
return f"Sample rate đang là {sample_rate}Hz. Vui lòng nói lại sau khi chọn 16kHz."
|
18 |
|
19 |
+
# Chuyển sang tensor
|
20 |
+
inputs = processor(audio_data, sampling_rate=16000, return_tensors="pt", padding=True)
|
21 |
with torch.no_grad():
|
22 |
logits = model(**inputs).logits
|
23 |
predicted_ids = torch.argmax(logits, dim=-1)
|
|
|
27 |
# Gradio UI
|
28 |
gr.Interface(
|
29 |
fn=transcribe,
|
30 |
+
inputs=gr.Audio(sources=["microphone"], type="numpy", label="Ghi âm từ micro (16kHz mono)"),
|
31 |
outputs="text",
|
32 |
+
title="STT Tiếng Việt với Wav2Vec2",
|
33 |
+
description="Ghi âm và nhận dạng giọng nói tiếng Việt bằng mô hình FPTAI/wav2vec2-base"
|
34 |
).launch()
|