from __future__ import annotations
import torch
import torchaudio
import gradio as gr
import spaces
from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor

DESCRIPTION = "Wav2Vec2_IndicConformer STT"
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Model Loading ---
print("Loading ASR model (IndicConformer)...")
asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
asr_model.eval()
print(" ASR Model loaded.")

print("\nLoading Language ID model (MMS-LID-1024)...")
lid_model_id = "facebook/mms-lid-1024"
lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device)
lid_model.eval()
print(" Language ID Model loaded.")


# --- Language Mappings ---
# Map MMS-LID output labels to IndicConformer language codes.
LID_TO_ASR_LANG_MAP = {
    # Script-suffixed labels (e.g., hin_Deva)
    "asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
    "guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
    "kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml",
    "mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or",
    "pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd",
    "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur",
    "asm": "as", "ben": "bn", "brx": "br", "doi": "doi", "guj": "gu", "hin": "hi",
    "kan": "kn", "kas": "ks", "gom": "kok", "mai": "mai", "mal": "ml", "mni": "mni",
    "mar": "mr", "npi": "ne", "ory": "or", "pan": "pa", "san": "sa", "sat": "sat",
    "snd": "sd", "tam": "ta", "tel": "te", "urd": "ur", "eng": "en"
}

ASR_CODE_TO_NAME = {
    "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri",
    "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri",
    "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri",
    "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi",
    "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil",
    "te": "Telugu", "ur": "Urdu", "en": "English",
}


@spaces.GPU  # ZeroGPU: request a GPU for the duration of each call
def transcribe_audio_with_lid(audio_path):
    """Detect the spoken language with MMS-LID, then transcribe with IndicConformer."""
    if not audio_path:
        return "Please provide an audio file.", "", ""

    try:
        waveform, sr = torchaudio.load(audio_path)
        # Downmix to mono if needed; both models expect single-channel audio.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Both models were trained on 16 kHz audio.
        waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
    except Exception as e:
        return f"Error loading audio: {e}", "", ""

    try:
        # Stage 1: language identification with MMS-LID.
        inputs = lid_processor(waveform_16k.squeeze(0), sampling_rate=16000, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = lid_model(**inputs)

        logits = outputs.logits
        predicted_lid_id = logits.argmax(-1).item()
        detected_lid_code = lid_model.config.id2label[predicted_lid_id]

        asr_lang_code = LID_TO_ASR_LANG_MAP.get(detected_lid_code)
        
        if not asr_lang_code:
            detected_lang_str = f"Detected '{detected_lid_code}', which is not supported by the ASR model."
            return detected_lang_str, "N/A", "N/A"

        detected_lang_str = f"Detected Language: {ASR_CODE_TO_NAME.get(asr_lang_code, 'Unknown')}"

        # Stage 2: transcribe with IndicConformer using both decoder heads.
        with torch.no_grad():
            transcription_ctc = asr_model(waveform_16k.to(device), asr_lang_code, "ctc")
            transcription_rnnt = asr_model(waveform_16k.to(device), asr_lang_code, "rnnt")

    except Exception as e:
        return f"Error during processing: {str(e)}", "", ""

    return detected_lang_str, transcription_ctc.strip(), transcription_rnnt.strip()


# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"## {DESCRIPTION}")
    gr.Markdown("Upload or record audio in any of the 22 supported Indian languages. The app will automatically detect the language and provide the transcription.")
    
    with gr.Row():
        with gr.Column(scale=1):
            audio = gr.Audio(label="Upload or Record Audio", type="filepath")
            transcribe_btn = gr.Button("Transcribe", variant="primary")
        
        with gr.Column(scale=2):
            detected_lang_output = gr.Label(label="Language Detection Result")
            gr.Markdown("### RNNT Transcription")
            rnnt_output = gr.Textbox(lines=3, label="RNNT Output")
            gr.Markdown("### CTC Transcription")
            ctc_output = gr.Textbox(lines=3, label="CTC Output")

    transcribe_btn.click(
        fn=transcribe_audio_with_lid, 
        inputs=[audio], 
        outputs=[detected_lang_output, ctc_output, rnnt_output],
        api_name="transcribe"
    )

if __name__ == "__main__":
    demo.queue().launch()
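
# A minimal client-side sketch, for reference only (assumptions: the Space is
# running locally on Gradio's default port, gradio_client >= 1.0 is installed,
# and "sample.wav" is a placeholder path):
#
#     from gradio_client import Client, handle_file
#
#     client = Client("http://127.0.0.1:7860/")
#     lang, ctc, rnnt = client.predict(handle_file("sample.wav"), api_name="/transcribe")
#     print(lang, ctc, rnnt)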