from __future__ import annotations
import torch
import torchaudio
import gradio as gr
import spaces
from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
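# Pipeline: MMS-LID-1024 identifies the spoken language, the detected code is
# mapped to an IndicConformer language code, and the audio is then transcribed
# with both the CTC and RNNT decoding heads.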
DESCRIPTION = "Wav2Vec2_IndicConformer STT"
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Model Loading ---
print("Loading ASR model (IndicConformer)...")
asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
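# trust_remote_code=True is required: IndicConformer ships its own modeling code.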
asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
asr_model.eval()
print(" ASR Model loaded.")
print("\nLoading Language ID model (MMS-LID-1024)...")
lid_model_id = "facebook/mms-lid-1024"
lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device)
lid_model.eval()
print(" Language ID Model loaded.")
# --- Language Mappings ---
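# MMS-LID label codes (ISO 639-3, sometimes with a script suffix) -> IndicConformer language codes.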
LID_TO_ASR_LANG_MAP = {
    # MMS-style codes with a script suffix (e.g., hin_Deva)
    "asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
    "guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
    "kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml",
    "mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or",
    "pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd",
    "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur",
"asm": "as", "ben": "bn", "brx": "br", "doi": "doi", "guj": "gu", "hin": "hi",
"kan": "kn", "kas": "ks", "gom": "kok", "mai": "mai", "mal": "ml", "mni": "mni",
"mar": "mr", "npi": "ne", "ory": "or", "pan": "pa", "san": "sa", "sat": "sat",
"snd": "sd", "tam": "ta", "tel": "te", "urd": "ur", "eng": "en"
}
ASR_CODE_TO_NAME = {
    "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri",
    "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri",
    "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri",
    "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi",
    "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil",
    "te": "Telugu", "ur": "Urdu", "en": "English",
}
@spaces.GPU
def transcribe_audio_with_lid(audio_path):
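    """Detect the spoken language with MMS-LID, then transcribe with IndicConformer.

    Returns (detected_language_message, ctc_transcription, rnnt_transcription).
    """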
    if not audio_path:
        return "Please provide an audio file.", "", ""
    try:
        waveform, sr = torchaudio.load(audio_path)
        # Downmix multi-channel recordings to mono; both models expect a single channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
    except Exception as e:
        return f"Error loading audio: {e}", "", ""
    try:
        inputs = lid_processor(waveform_16k.squeeze(), sampling_rate=16000, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = lid_model(**inputs)
        logits = outputs.logits
        predicted_lid_id = logits.argmax(-1).item()
        detected_lid_code = lid_model.config.id2label[predicted_lid_id]
        asr_lang_code = LID_TO_ASR_LANG_MAP.get(detected_lid_code)
        if not asr_lang_code:
            detected_lang_str = f"Detected '{detected_lid_code}', which is not supported by the ASR model."
            return detected_lang_str, "N/A", "N/A"
        detected_lang_str = f"Detected Language: {ASR_CODE_TO_NAME.get(asr_lang_code, 'Unknown')}"
        with torch.no_grad():
            transcription_ctc = asr_model(waveform_16k.to(device), asr_lang_code, "ctc")
            transcription_rnnt = asr_model(waveform_16k.to(device), asr_lang_code, "rnnt")
    except Exception as e:
        return f"Error during processing: {str(e)}", "", ""
    return detected_lang_str, transcription_ctc.strip(), transcription_rnnt.strip()
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"## {DESCRIPTION}")
    gr.Markdown("Upload or record audio in any of the 22 supported Indian languages. The app will automatically detect the language and provide the transcription.")
    with gr.Row():
        with gr.Column(scale=1):
            audio = gr.Audio(label="Upload or Record Audio", type="filepath")
            transcribe_btn = gr.Button("Transcribe", variant="primary")
        with gr.Column(scale=2):
            detected_lang_output = gr.Label(label="Language Detection Result")
            gr.Markdown("### RNNT Transcription")
            rnnt_output = gr.Textbox(lines=3, label="RNNT Output")
            gr.Markdown("### CTC Transcription")
            ctc_output = gr.Textbox(lines=3, label="CTC Output")
    transcribe_btn.click(
        fn=transcribe_audio_with_lid,
        inputs=[audio],
        outputs=[detected_lang_output, ctc_output, rnnt_output],
        api_name="transcribe"
    )
if __name__ == "__main__":
    demo.queue().launch()