from __future__ import annotations

import torch
import torchaudio
import gradio as gr
import spaces
from transformers import AutoModel

DESCRIPTION = "IndicConformer-600M Multilingual ASR (CTC + RNNT)"

# The 22 scheduled Indian languages supported by the model, mapped to the
# language codes it expects (AI4Bharat uses ISO 639-3 "brx" for Bodo).
LANGUAGE_NAME_TO_CODE = {
    "Assamese": "as",
    "Bengali": "bn",
    "Bodo": "brx",
    "Dogri": "doi",
    "Gujarati": "gu",
    "Hindi": "hi",
    "Kannada": "kn",
    "Kashmiri": "ks",
    "Konkani": "kok",
    "Maithili": "mai",
    "Malayalam": "ml",
    "Manipuri": "mni",
    "Marathi": "mr",
    "Nepali": "ne",
    "Odia": "or",
    "Punjabi": "pa",
    "Sanskrit": "sa",
    "Santali": "sat",
    "Sindhi": "sd",
    "Tamil": "ta",
    "Telugu": "te",
    "Urdu": "ur",
}

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the IndicConformer model. trust_remote_code=True is required because the
# custom forward() (waveform, language code, decoding strategy) is defined in
# the model repository, not in transformers itself.
model = AutoModel.from_pretrained(
    "ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True
).to(device)
model.eval()


@spaces.GPU
def transcribe_ctc_and_rnnt(audio_path: str, language_name: str) -> tuple[str, str]:
    if audio_path is None:
        return "Error: no audio provided.", ""
    lang_code = LANGUAGE_NAME_TO_CODE[language_name]

    # Load the audio, downmix to mono, and resample to the 16 kHz the model expects.
    waveform, sr = torchaudio.load(audio_path)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    waveform = waveform.to(device)

    try:
        # The model's custom forward takes (waveform, language code, decoding
        # strategy) and returns the decoded transcription as a string.
        with torch.no_grad():
            transcription_ctc = model(waveform, lang_code, "ctc")
            transcription_rnnt = model(waveform, lang_code, "rnnt")
    except Exception as e:
        return f"Error: {e}", ""

    return transcription_ctc.strip(), transcription_rnnt.strip()


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown(f"## {DESCRIPTION}")
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Upload or Record Audio", type="filepath")
            lang = gr.Dropdown(
                label="Select Language",
                choices=list(LANGUAGE_NAME_TO_CODE.keys()),
                value="Hindi",
            )
            transcribe_btn = gr.Button("Transcribe (CTC + RNNT)")
        with gr.Column():
            gr.Markdown("### CTC Transcription")
            ctc_output = gr.Textbox(lines=3)
            gr.Markdown("### RNNT Transcription")
            rnnt_output = gr.Textbox(lines=3)

    transcribe_btn.click(
        fn=transcribe_ctc_and_rnnt,
        inputs=[audio, lang],
        outputs=[ctc_output, rnnt_output],
        api_name="transcribe",
    )

if __name__ == "__main__":
    demo.queue().launch()
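
# --- Usage sketch (not part of the app) -------------------------------------
# A minimal client-side call against the "transcribe" endpoint exposed above,
# assuming the app is running locally on the default port and `gradio_client`
# (>= 1.0) is installed; the URL and audio file name below are hypothetical.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")  # or the hosted Space id
#   ctc_text, rnnt_text = client.predict(
#       handle_file("sample_hi.wav"),  # hypothetical local audio file
#       "Hindi",
#       api_name="/transcribe",
#   )
#   print(ctc_text)
#   print(rnnt_text)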