Spaces:
Running
Running
File size: 2,499 Bytes
eba970d f7212d2 069b4ed eba970d 069b4ed f7212d2 eba970d d95af38 069b4ed f7212d2 eba970d 069b4ed d95af38 069b4ed d95af38 eba970d 069b4ed f7212d2 069b4ed f7212d2 069b4ed f7212d2 069b4ed f7212d2 d95af38 eba970d 069b4ed eba970d 069b4ed eba970d 069b4ed eba970d b16608d eba970d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from __future__ import annotations
import torch
import torchaudio
import gradio as gr
import spaces
from transformers import AutoModel
DESCRIPTION = "IndicConformer-600M Multilingual ASR (CTC + RNNT)"

# Display name -> language code expected by the AI4Bharat IndicConformer model.
# NOTE: codes follow the model's convention (mostly ISO 639-1, falling back to
# ISO 639-3 where no two-letter code exists). Bodo has no ISO 639-1 code —
# "br" is Breton — so the correct code is "brx" (ISO 639-3), as used by
# AI4Bharat's Indic models.
LANGUAGE_NAME_TO_CODE = {
    "Assamese": "as", "Bengali": "bn", "Bodo": "brx", "Dogri": "doi",
    "Gujarati": "gu", "Hindi": "hi", "Kannada": "kn", "Kashmiri": "ks",
    "Konkani": "kok", "Maithili": "mai", "Malayalam": "ml", "Manipuri": "mni",
    "Marathi": "mr", "Nepali": "ne", "Odia": "or", "Punjabi": "pa",
    "Sanskrit": "sa", "Santali": "sat", "Sindhi": "sd", "Tamil": "ta",
    "Telugu": "te", "Urdu": "ur",
}
# Prefer GPU when available; on Spaces the @spaces.GPU decorator below
# provisions the device per-request, so this is also the warm-start default.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Indic Conformer model (assumes custom forward handles decoding strategy)
# trust_remote_code=True is required: the model's forward() signature
# (waveform, lang_code, decoding_strategy) comes from the repo's custom code.
model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True).to(device)
# Inference-only: disable dropout/batch-norm training behavior.
model.eval()
@spaces.GPU
def transcribe_ctc_and_rnnt(audio_path, language_name):
    """Transcribe an audio file with both CTC and RNNT decoding.

    Args:
        audio_path: Filesystem path to the recording (any torchaudio-readable
            format); Gradio passes None when no audio was supplied.
        language_name: Display name used as a key into LANGUAGE_NAME_TO_CODE.

    Returns:
        Tuple of (ctc_text, rnnt_text). On failure the first element carries
        an "Error: ..." message and the second is empty, so the UI shows the
        problem instead of crashing.
    """
    # Guard: Gradio sends None if the user clicks Transcribe with no audio.
    if not audio_path:
        return "Error: no audio provided.", ""
    lang_code = LANGUAGE_NAME_TO_CODE[language_name]
    try:
        # Loading/resampling can raise on corrupt or unsupported files, so it
        # belongs inside the try (previously it crashed the handler instead
        # of returning the error string).
        waveform, sr = torchaudio.load(audio_path)
        # Downmix multi-channel audio to mono; the model expects one channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Model expects 16 kHz input; skip the resample when already there.
        if sr != 16000:
            waveform = torchaudio.functional.resample(waveform, sr, 16000)
        waveform = waveform.to(device)
        # Custom remote-code forward: (waveform, lang_code, decoding_strategy).
        with torch.no_grad():
            transcription_ctc = model(waveform, lang_code, "ctc")
            transcription_rnnt = model(waveform, lang_code, "rnnt")
    except Exception as e:
        # Surface the failure in the first output box rather than raising.
        return f"Error: {str(e)}", ""
    return transcription_ctc.strip(), transcription_rnnt.strip()
# --- Gradio interface ---
with gr.Blocks() as demo:
    gr.Markdown(f"## {DESCRIPTION}")
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
            language_choice = gr.Dropdown(
                label="Select Language",
                choices=list(LANGUAGE_NAME_TO_CODE.keys()),
                value="Hindi",
            )
            run_button = gr.Button("Transcribe (CTC + RNNT)")
        # Right column: one textbox per decoding strategy.
        with gr.Column():
            gr.Markdown("### CTC Transcription")
            ctc_box = gr.Textbox(lines=3)
            gr.Markdown("### RNNT Transcription")
            rnnt_box = gr.Textbox(lines=3)
    run_button.click(
        fn=transcribe_ctc_and_rnnt,
        inputs=[audio_input, language_choice],
        outputs=[ctc_box, rnnt_box],
        api_name="transcribe",
    )

if __name__ == "__main__":
    demo.queue().launch()
|