# TArtx's picture
# Create app.py
# 2174b21 verified
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
import torchaudio
import os
# Pick the compute device: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_model():
    """Load the base Whisper checkpoint and attach the PEFT adapter.

    The base model "BELLE-2/Belle-whisper-large-v3-zh" is loaded and moved to
    the module-level ``device``; the adapter "TArtx/MinD_CH_PEFT_ID" is then
    applied on top of it.

    Returns:
        The adapter-wrapped model, placed on ``device`` and in eval mode.
    """
    peft_model_id = "TArtx/MinD_CH_PEFT_ID"
    base_model = WhisperForConditionalGeneration.from_pretrained(
        "BELLE-2/Belle-whisper-large-v3-zh",
        device_map=None,
    ).to(device)
    # The adapter config was previously fetched separately and never used;
    # PeftModel.from_pretrained is given the adapter id directly.
    model = PeftModel.from_pretrained(base_model, peft_model_id)
    # This app only runs inference, so switch the wrapped model to eval mode
    # to make sure dropout is disabled.
    model.eval()
    return model
def transcribe(audio_path):
    """Transcribe an audio file using the module-level ``model`` and ``processor``.

    Args:
        audio_path: Path to the audio file, or ``None`` when nothing was
            uploaded.

    Returns:
        The decoded transcription string. A human-readable message is
        returned instead when no file was supplied, when the audio contains
        no samples, or when any step of the transcription raises.
    """
    if audio_path is None:
        return "Please upload an audio file."

    target_rate = 16000  # sample rate handed to the processor below

    try:
        # Preprocessing stays on the CPU: the Resample transform below is
        # created on the CPU, so moving the waveform to CUDA first would cause
        # a device mismatch, and the samples are converted to NumPy afterwards
        # anyway.
        waveform, sample_rate = torchaudio.load(audio_path)

        if waveform.numel() == 0:
            return "The uploaded audio file is empty."

        # Down-mix multi-channel audio to a single channel
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample only when the file is not already at the target rate
        if sample_rate != target_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
            waveform = resampler(waveform)

        # squeeze(0) removes only the channel axis, so a one-sample clip
        # stays 1-D instead of collapsing to a 0-d array
        audio_array = waveform.squeeze(0).cpu().numpy()

        # Only the extracted features are moved to the model's device
        inputs = processor(
            audio_array,
            sampling_rate=target_rate,
            return_tensors="pt",
        ).to(device)

        # Inference only: no gradients are needed
        with torch.no_grad():
            predicted_ids = model.generate(**inputs)

        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: report the failure as text instead of raising
        return f"Error during transcription: {e}"
# Module-level setup: build the model and the processor once so that
# transcribe() can reuse them on every call
print("Loading model...")
model = load_model()
_processor_options = {"language": "Chinese", "task": "transcribe"}
processor = WhisperProcessor.from_pretrained(
    "BELLE-2/Belle-whisper-large-v3-zh",
    **_processor_options,
)
print("Model loaded!")
# Build the Gradio UI: one audio input (delivered to transcribe() as a file
# path) and one text output
_audio_input = gr.Audio(type="filepath")
iface = gr.Interface(
    fn=transcribe,
    inputs=_audio_input,
    outputs="text",
    title="Chinese-Mindong Speech Recognition",
    description="Upload an audio file for transcription. Model optimized for Eastern Min dialect.",
)

# Start the web interface only when this file is run as a script
if __name__ == "__main__":
    iface.launch()