vankienemk commited on
Commit
f8da254
·
verified ·
1 Parent(s): d0b7fd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -1,20 +1,23 @@
1
  import gradio as gr
2
  import torch
3
- import soundfile as sf
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
 
6
  # Load model
7
- processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")
8
- model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")
9
 
10
  def transcribe(audio):
11
- # Load audio
12
- speech, rate = sf.read(audio)
13
- if rate != 16000:
14
- return "Vui lòng cung cấp file audio 16kHz."
 
 
 
15
 
16
- # Preprocess and predict
17
- inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
18
  with torch.no_grad():
19
  logits = model(**inputs).logits
20
  predicted_ids = torch.argmax(logits, dim=-1)
@@ -24,8 +27,8 @@ def transcribe(audio):
24
  # Gradio UI
25
  gr.Interface(
26
  fn=transcribe,
27
- inputs=gr.Audio(type="filepath", label="Upload audio (16kHz, mono)"),
28
  outputs="text",
29
- title="Wav2Vec2 Vietnamese STT",
30
- description="Nhận dạng giọng nói tiếng Việt bằng mô hình wav2vec2-base-vietnamese-250h từ VLSP."
31
  ).launch()
 
1
  import gradio as gr
2
  import torch
3
+ import numpy as np
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
 
6
  # Load model
7
+ processor = Wav2Vec2Processor.from_pretrained("FPTAI/vietnamese-wav2vec2-base")
8
+ model = Wav2Vec2ForCTC.from_pretrained("FPTAI/vietnamese-wav2vec2-base")
9
 
10
  def transcribe(audio):
11
+ if audio is None:
12
+ return "Không âm thanh."
13
+
14
+ # Gradio trả về (sample_rate, numpy_array)
15
+ sample_rate, audio_data = audio
16
+ if sample_rate != 16000:
17
+ return f"Sample rate đang là {sample_rate}Hz. Vui lòng nói lại sau khi chọn 16kHz."
18
 
19
+ # Chuyển sang tensor
20
+ inputs = processor(audio_data, sampling_rate=16000, return_tensors="pt", padding=True)
21
  with torch.no_grad():
22
  logits = model(**inputs).logits
23
  predicted_ids = torch.argmax(logits, dim=-1)
 
27
  # Gradio UI
28
  gr.Interface(
29
  fn=transcribe,
30
+ inputs=gr.Audio(sources=["microphone"], type="numpy", label="Ghi âm từ micro (16kHz mono)"),
31
  outputs="text",
32
+ title="STT Tiếng Việt với Wav2Vec2",
33
+ description="Ghi âm và nhận dạng giọng nói tiếng Việt bằng mô hình FPTAI/wav2vec2-base"
34
  ).launch()