File size: 2,499 Bytes
eba970d
f7212d2
 
069b4ed
eba970d
069b4ed
 
 
f7212d2
eba970d
 
 
 
 
 
 
 
 
d95af38
 
069b4ed
 
 
f7212d2
eba970d
 
069b4ed
d95af38
069b4ed
d95af38
eba970d
069b4ed
f7212d2
069b4ed
 
 
 
 
 
 
f7212d2
069b4ed
f7212d2
069b4ed
f7212d2
d95af38
eba970d
 
069b4ed
eba970d
069b4ed
 
eba970d
 
 
 
069b4ed
 
 
 
eba970d
b16608d
eba970d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import annotations
import torch
import torchaudio
import gradio as gr
import spaces
from transformers import AutoModel

DESCRIPTION = "IndicConformer-600M Multilingual ASR (CTC + RNNT)"

# Display name shown in the UI -> language code handed to the model.
# Insertion order is significant: the language dropdown lists the keys in
# this order.
# NOTE(review): Bodo is mapped to "br", which is the ISO 639-1 code for
# Breton; Bodo's ISO 639-3 code is "brx". Confirm which code the model's
# remote code actually expects before changing it.
LANGUAGE_NAME_TO_CODE = dict(
    Assamese="as",
    Bengali="bn",
    Bodo="br",
    Dogri="doi",
    Gujarati="gu",
    Hindi="hi",
    Kannada="kn",
    Kashmiri="ks",
    Konkani="kok",
    Maithili="mai",
    Malayalam="ml",
    Manipuri="mni",
    Marathi="mr",
    Nepali="ne",
    Odia="or",
    Punjabi="pa",
    Sanskrit="sa",
    Santali="sat",
    Sindhi="sd",
    Tamil="ta",
    Telugu="te",
    Urdu="ur",
)

# Use the GPU when torch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the IndicConformer checkpoint. trust_remote_code=True executes the
# modeling code shipped in the model repository. The transcription function
# below calls the model as model(waveform, lang_code, decoding_type);
# presumably that remote forward() performs CTC/RNNT decoding itself —
# this cannot be verified from this file.
model = AutoModel.from_pretrained(
    "ai4bharat/indic-conformer-600m-multilingual",
    trust_remote_code=True,
)
model = model.to(device)
# Switch the module to evaluation mode for inference.
model.eval()

@spaces.GPU
def transcribe_ctc_and_rnnt(audio_path, language_name):
    """Transcribe one audio file with both CTC and RNNT decoding.

    Args:
        audio_path: Path to the audio file. The Gradio ``Audio`` input is
            configured with ``type="filepath"``. It may be ``None`` when the
            button is clicked before any audio was uploaded or recorded.
        language_name: A display name that is a key of ``LANGUAGE_NAME_TO_CODE``.

    Returns:
        A ``(ctc_text, rnnt_text)`` tuple of strings. On any failure, the first
        element is an ``"Error: ..."`` message and the second is ``""``. This
        way the two output textboxes always receive a value.
    """
    # Without audio there is nothing to load. Report it instead of letting
    # torchaudio.load(None) fail with an opaque error in the UI.
    if not audio_path:
        return "Error: no audio provided. Please upload or record audio.", ""

    lang_code = LANGUAGE_NAME_TO_CODE.get(language_name)
    if lang_code is None:
        return f"Error: unsupported language {language_name!r}", ""

    try:
        # Load and preprocess inside the try so that unreadable or corrupt
        # files are reported the same way as model failures. Downmix
        # multi-channel audio to mono, then resample to 16 kHz.
        waveform, sr = torchaudio.load(audio_path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = torchaudio.functional.resample(waveform, sr, 16000).to(device)

        # Assumes the model's remote forward() takes (waveform, language
        # code, decoding type) and returns the transcription text.
        with torch.no_grad():
            transcription_ctc = model(waveform, lang_code, "ctc")
            transcription_rnnt = model(waveform, lang_code, "rnnt")
    except Exception as e:  # UI boundary: show the failure in the output box
        return f"Error: {e}", ""

    return transcription_ctc.strip(), transcription_rnnt.strip()

# Gradio UI: the inputs (audio and language) are in the left column; the
# right column has one output box per decoding strategy.
with gr.Blocks() as demo:
    gr.Markdown("## " + DESCRIPTION)

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Upload or Record Audio")
            lang = gr.Dropdown(
                choices=list(LANGUAGE_NAME_TO_CODE),
                value="Hindi",
                label="Select Language",
            )
            transcribe_btn = gr.Button("Transcribe (CTC + RNNT)")

        # Right column: the CTC result, followed by the RNNT result.
        with gr.Column():
            gr.Markdown("### CTC Transcription")
            ctc_output = gr.Textbox(lines=3)
            gr.Markdown("### RNNT Transcription")
            rnnt_output = gr.Textbox(lines=3)

    # The handler returns (ctc_text, rnnt_text), which map onto the two boxes
    # in order. api_name also exposes the handler as a named API endpoint.
    transcribe_btn.click(
        fn=transcribe_ctc_and_rnnt,
        inputs=[audio, lang],
        outputs=[ctc_output, rnnt_output],
        api_name="transcribe",
    )

if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()