File size: 2,388 Bytes
2174b21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import torchaudio
import os

# Select the compute device: prefer GPU when CUDA is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def load_model():
    """Load the base Whisper model and attach the PEFT (LoRA) adapter.

    Returns:
        PeftModel: The adapter-wrapped Whisper model, moved to ``device``
        and set to eval mode for inference.
    """
    peft_model_id = "TArtx/MinD_CH_PEFT_ID"

    # Load the base checkpoint on CPU first (device_map=None), then move
    # it explicitly so base weights and adapter end up on the same device.
    model = WhisperForConditionalGeneration.from_pretrained(
        "BELLE-2/Belle-whisper-large-v3-zh",
        device_map=None
    ).to(device)

    # Wrap the base model with the fine-tuned adapter weights.
    model = PeftModel.from_pretrained(model, peft_model_id)
    model.eval()  # inference only: disable dropout etc.
    return model

def transcribe(audio_path):
    """Transcribe an audio file to Chinese text with the Whisper+PEFT model.

    Args:
        audio_path: Filesystem path supplied by the Gradio Audio component,
            or None when the user submitted without uploading a file.

    Returns:
        str: The transcription on success, otherwise a human-readable
        error message (the Gradio UI displays whatever string is returned).
    """
    if audio_path is None:
        return "Please upload an audio file."

    try:
        # Load and preprocess entirely on CPU: the Resample module's
        # kernel lives on CPU, so moving the waveform to CUDA here would
        # raise a device-mismatch error — and we need a CPU numpy array
        # for the processor anyway, so a GPU round-trip is pure waste.
        waveform, sample_rate = torchaudio.load(audio_path)

        # Downmix multi-channel audio to mono by averaging channels.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Whisper's feature extractor expects 16 kHz input.
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        audio_array = waveform.squeeze().cpu().numpy()

        # Extract log-mel features and move them to the model's device.
        inputs = processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        ).to(device)

        # Generate token ids and decode them to text.
        predicted_ids = model.generate(**inputs)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        return transcription

    except Exception as e:
        # Surface the failure to the UI instead of crashing the app.
        return f"Error during transcription: {str(e)}"

# Initialize model and processor
# NOTE: this runs at import time and downloads the checkpoints on first use.
print("Loading model...")
model = load_model()
# The processor bundles feature extraction (log-mel spectrogram) and
# tokenization; language/task pin decoding to Chinese transcription.
processor = WhisperProcessor.from_pretrained(
    "BELLE-2/Belle-whisper-large-v3-zh",
    language="Chinese",
    task="transcribe"
)
print("Model loaded!")

# Create Gradio interface: a single audio-upload input mapped to a text
# output. type="filepath" makes Gradio pass a temp-file path to transcribe().
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="Chinese-Mindong Speech Recognition",
    description="Upload an audio file for transcription. Model optimized for Eastern Min dialect."
)

# Launch the interface only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()