TArtx committed
Commit 2174b21 · verified · 1 Parent(s): 181a3df

Create app.py

Files changed (1)
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
+ import gradio as gr
+ from peft import PeftModel, PeftConfig
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
+ import torch
+ import torchaudio
+ import os
+
+ # Check if CUDA is available and set the device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def load_model():
+     peft_model_id = "TArtx/MinD_CH_PEFT_ID"
+     peft_config = PeftConfig.from_pretrained(peft_model_id)
+
+     model = WhisperForConditionalGeneration.from_pretrained(
+         "BELLE-2/Belle-whisper-large-v3-zh",
+         device_map=None
+     ).to(device)
+
+     model = PeftModel.from_pretrained(model, peft_model_id)
+     return model
+
+ def transcribe(audio_path):
+     if audio_path is None:
+         return "Please upload an audio file."
+
+     try:
+         # Load the audio file
+         waveform, sample_rate = torchaudio.load(audio_path)
+         waveform = waveform.to(device)
+
+         # Convert to mono if stereo
+         if waveform.shape[0] > 1:
+             waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+         # Resample to 16 kHz if needed; the resampler must live on the same device as the waveform
+         if sample_rate != 16000:
+             resampler = torchaudio.transforms.Resample(sample_rate, 16000).to(device)
+             waveform = resampler(waveform)
+
+         # Convert to a numpy array for the processor
+         audio_array = waveform.squeeze().cpu().numpy()
+
+         # Process audio input
+         inputs = processor(
+             audio_array,
+             sampling_rate=16000,
+             return_tensors="pt"
+         ).to(device)
+
+         # Generate transcription
+         predicted_ids = model.generate(**inputs)
+         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+         return transcription
+
+     except Exception as e:
+         return f"Error during transcription: {str(e)}"
+
+ # Initialize model and processor
+ print("Loading model...")
+ model = load_model()
+ processor = WhisperProcessor.from_pretrained(
+     "BELLE-2/Belle-whisper-large-v3-zh",
+     language="Chinese",
+     task="transcribe"
+ )
+ print("Model loaded!")
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=transcribe,
+     inputs=gr.Audio(type="filepath"),
+     outputs="text",
+     title="Chinese-Mindong Speech Recognition",
+     description="Upload an audio file for transcription. Model optimized for Eastern Min dialect."
+ )
+
+ # Launch the interface
+ if __name__ == "__main__":
+     iface.launch()
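
Once the app is launched, the interface can also be queried programmatically. The sketch below uses gradio_client against the default gr.Interface endpoint; the server URL and the sample.wav path are placeholder assumptions, and it presumes a recent gradio_client (older versions accept a plain filepath string instead of handle_file).

from gradio_client import Client, handle_file

# Placeholder URL for a locally running instance (a hosted Space ID would also work)
client = Client("http://127.0.0.1:7860/")

# "/predict" is the default endpoint name for a gr.Interface;
# "sample.wav" is a hypothetical test file
result = client.predict(handle_file("sample.wav"), api_name="/predict")
print(result)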