"""Gradio benchmark app: compare ASR models across seven Indian languages.

Loads each selected model from the Hugging Face Hub, transcribes an uploaded
audio file, and reports WER, CER, and real-time factor (RTF) per model.
"""

import time

import gradio as gr
import librosa
import torch
from jiwer import cer, wer
from transformers import (
    AutoModel,
    AutoModelForCTC,
    AutoProcessor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
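

# UI language label -> ISO 639-1 code, writing script, and the models known
# to support that language (used to filter the model checkboxes).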
LANGUAGE_CONFIGS = {
    "Hindi (हिंदी)": {
        "code": "hi",
        "script": "Devanagari",
        "models": ["AudioX-North", "IndicConformer", "MMS"],
    },
    "Gujarati (ગુજરાતી)": {
        "code": "gu",
        "script": "Gujarati",
        "models": ["AudioX-North", "IndicConformer", "MMS"],
    },
    "Marathi (मराठी)": {
        "code": "mr",
        "script": "Devanagari",
        "models": ["AudioX-North", "IndicConformer", "MMS"],
    },
    "Tamil (தமிழ்)": {
        "code": "ta",
        "script": "Tamil",
        "models": ["AudioX-South", "IndicConformer", "MMS"],
    },
    "Telugu (తెలుగు)": {
        "code": "te",
        "script": "Telugu",
        "models": ["AudioX-South", "IndicConformer", "MMS"],
    },
    "Kannada (ಕನ್ನಡ)": {
        "code": "kn",
        "script": "Kannada",
        "models": ["AudioX-South", "IndicConformer", "MMS"],
    },
    "Malayalam (മലയാളം)": {
        "code": "ml",
        "script": "Malayalam",
        "models": ["AudioX-South", "IndicConformer", "MMS"],
    },
}
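

# Model registry: Hugging Face repo, loading strategy ("model_type" selects
# the inference code path below), and the language codes each checkpoint covers.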
MODEL_CONFIGS = {
    "AudioX-North": {
        "repo": "jiviai/audioX-north-v1",
        "model_type": "whisper",
        "description": "Supports Hindi, Gujarati, Marathi",
        "languages": ["hi", "gu", "mr"],
    },
    "AudioX-South": {
        "repo": "jiviai/audioX-south-v1",
        "model_type": "whisper",
        "description": "Supports Tamil, Telugu, Kannada, Malayalam",
        "languages": ["ta", "te", "kn", "ml"],
    },
    "IndicConformer": {
        "repo": "ai4bharat/indic-conformer-600m-multilingual",
        "model_type": "ctc_rnnt",
        "description": "Supports 22 Indian languages",
        "trust_remote_code": True,
        "languages": ["hi", "gu", "mr", "ta", "te", "kn", "ml", "bn", "pa", "or", "as", "ur"],
    },
    "MMS": {
        "repo": "facebook/mms-1b-all",
        "model_type": "ctc",
        "description": "Supports 1,400+ languages",
        "languages": ["hi", "gu", "mr", "ta", "te", "kn", "ml"],
    },
}
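

# Returns (model, processor, model_type) on success, or
# (None, None, "Error loading model: ...") on failure; callers detect failure
# by checking whether model_type starts with "Error".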
def load_model_and_processor(model_name):
    config = MODEL_CONFIGS[model_name]
    repo = config["repo"]
    model_type = config["model_type"]
    trust_remote_code = config.get("trust_remote_code", False)

    try:
        if model_name == "IndicConformer":
            print(f"Loading {model_name}...")
            try:
                model = AutoModel.from_pretrained(
                    repo,
                    trust_remote_code=trust_remote_code,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                )
            except Exception as e1:
                print(f"Primary loading failed, trying fallback: {e1}")
                model = AutoModel.from_pretrained(repo, trust_remote_code=trust_remote_code)
            # The remote-code model handles its own preprocessing, so there is
            # no separate processor.
            processor = None
            return model, processor, model_type

        elif model_name in ["AudioX-North", "AudioX-South"]:
            processor = WhisperProcessor.from_pretrained(repo)
            model = WhisperForConditionalGeneration.from_pretrained(repo)
            # Clear forced decoder ids so language/task can be passed to generate().
            model.config.forced_decoder_ids = None
            return model, processor, model_type

        elif model_name == "MMS":
            model = AutoModelForCTC.from_pretrained(repo)
            processor = AutoProcessor.from_pretrained(repo)
            return model, processor, model_type

        # Defensive fallback so callers never have to unpack a bare None.
        return None, None, f"Error loading model: unknown model '{model_name}'"

    except Exception as e:
        return None, None, f"Error loading model: {str(e)}"
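

# WER/CER are computed with jiwer after whitespace/case normalization; RTF is
# total processing time divided by audio duration (RTF < 1.0 means faster than
# real time).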
def compute_metrics(reference, hypothesis, audio_duration, total_time):
    if not reference or not hypothesis:
        return None, None, None, None
    try:
        reference = reference.strip().lower()
        hypothesis = hypothesis.strip().lower()
        wer_score = wer(reference, hypothesis)
        cer_score = cer(reference, hypothesis)
        rtf = total_time / audio_duration if audio_duration > 0 else None
        return wer_score, cer_score, rtf, total_time
    except Exception:
        return None, None, None, None
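

# Main pipeline: validate inputs, decode the audio once at 16 kHz, then run
# every selected model and collect one results row per model.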
def transcribe_audio(audio_file, selected_language, selected_models, reference_text=""):
    if not audio_file:
        return "Please upload an audio file.", [], ""

    if not selected_models:
        return "Please select at least one model.", [], ""

    if not selected_language:
        return "Please select a language.", [], ""

    lang_info = LANGUAGE_CONFIGS[selected_language]
    lang_code = lang_info["code"]

    table_data = []
    try:
        audio, sr = librosa.load(audio_file, sr=16000)
        audio_duration = len(audio) / sr
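
        # One pass per model: skip unsupported language/model pairs, load the
        # model, and time inference only (model loading is excluded from RTF).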
        for model_name in selected_models:
            if model_name not in lang_info["models"]:
                table_data.append([
                    model_name,
                    f"Language {selected_language} not supported by this model",
                    "-", "-", "-", "-",
                ])
                continue

            model, processor, model_type = load_model_and_processor(model_name)
            if isinstance(model_type, str) and model_type.startswith("Error"):
                table_data.append([
                    model_name,
                    model_type,
                    "-", "-", "-", "-",
                ])
                continue

            start_time = time.time()
            try:
                if model_name == "IndicConformer":
                    # Peak-normalize and add a batch dimension; the remote-code
                    # model is called as model(waveform, language_code, decoder).
                    wav = torch.from_numpy(audio).unsqueeze(0)
                    if torch.max(torch.abs(wav)) > 0:
                        wav = wav / torch.max(torch.abs(wav))

                    with torch.no_grad():
                        transcription = model(wav, lang_code, "rnnt")
                    if isinstance(transcription, list):
                        transcription = transcription[0] if transcription else ""
                    transcription = str(transcription).strip()

                elif model_name in ["AudioX-North", "AudioX-South"]:
                    # Audio is already 16 kHz (librosa resampled on load), so no
                    # further resampling is needed here.
                    input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

                    with torch.no_grad():
                        predicted_ids = model.generate(
                            input_features,
                            task="transcribe",
                            language=lang_code,
                        )
                    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

                else:
                    # MMS selects its output language via per-language adapters;
                    # without switching, it decodes with the default (English)
                    # vocabulary. It expects ISO 639-3 codes, so map from the
                    # ISO 639-1 codes used elsewhere in this app.
                    iso3_codes = {"hi": "hin", "gu": "guj", "mr": "mar", "ta": "tam",
                                  "te": "tel", "kn": "kan", "ml": "mal"}
                    iso3 = iso3_codes.get(lang_code)
                    if iso3:
                        processor.tokenizer.set_target_lang(iso3)
                        model.load_adapter(iso3)

                    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

                    with torch.no_grad():
                        input_values = inputs["input_values"]
                        logits = model(input_values).logits
                    predicted_ids = torch.argmax(logits, dim=-1)
                    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

            except Exception as e:
                transcription = f"Processing error: {str(e)}"

            total_time = time.time() - start_time
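
            # Per-row metrics: WER/CER need a reference transcript; RTF only
            # needs the measured inference time.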
            wer_score, cer_score = "-", "-"
            # RTF does not depend on a reference transcript, so report it
            # whenever the audio duration is known.
            rtf = f"{total_time / audio_duration:.3f}" if audio_duration > 0 else "-"
            if reference_text and transcription and not transcription.startswith("Processing error"):
                wer_val, cer_val, _, _ = compute_metrics(
                    reference_text, transcription, audio_duration, total_time
                )
                wer_score = f"{wer_val:.3f}" if wer_val is not None else "-"
                cer_score = f"{cer_val:.3f}" if cer_val is not None else "-"

            table_data.append([
                model_name,
                transcription,
                wer_score,
                cer_score,
                rtf,
                f"{total_time:.2f}s",
            ])
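
        # Build the on-screen summary plus a plain-text export of the same results.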
        summary = f"**Language:** {selected_language} ({lang_code})\n"
        summary += f"**Audio Duration:** {audio_duration:.2f}s\n"
        summary += f"**Models Tested:** {len(selected_models)}\n"
        if reference_text:
            summary += f"**Reference Text:** {reference_text[:100]}{'...' if len(reference_text) > 100 else ''}\n"

        copyable_text = "MULTILINGUAL SPEECH-TO-TEXT BENCHMARK RESULTS\n" + "=" * 55 + "\n\n"
        copyable_text += f"Language: {selected_language} ({lang_code})\n"
        copyable_text += f"Script: {lang_info['script']}\n"
        copyable_text += f"Audio Duration: {audio_duration:.2f}s\n"
        copyable_text += f"Models Tested: {len(selected_models)}\n"
        if reference_text:
            copyable_text += f"Reference Text: {reference_text}\n"
        copyable_text += "\n" + "-" * 55 + "\n\n"

        for i, row in enumerate(table_data):
            copyable_text += f"MODEL {i + 1}: {row[0]}\n"
            copyable_text += f"Transcription: {row[1]}\n"
            copyable_text += f"WER: {row[2]}\n"
            copyable_text += f"CER: {row[3]}\n"
            copyable_text += f"RTF: {row[4]}\n"
            copyable_text += f"Time Taken: {row[5]}\n"
            copyable_text += "\n" + "-" * 35 + "\n\n"

        return summary, table_data, copyable_text

    except Exception as e:
        error_msg = f"Error during transcription: {str(e)}"
        return error_msg, [], error_msg
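

# Builds the Gradio Blocks UI: inputs on the left, results on the right, and a
# language-change callback that re-filters the model checkboxes.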
def create_interface():
    language_choices = list(LANGUAGE_CONFIGS.keys())

    with gr.Blocks(title="Multilingual Speech-to-Text Benchmark", css="""
        .language-info { background: #f0f8ff; padding: 10px; border-radius: 5px; margin: 10px 0; }
        .copy-area { font-family: monospace; font-size: 12px; }
    """) as iface:
        gr.Markdown("""
        # 🌐 Multilingual Speech-to-Text Benchmark

        Compare ASR models across **7 Indian Languages** with comprehensive metrics.

        **Supported Languages:** Hindi, Gujarati, Marathi, Tamil, Telugu, Kannada, Malayalam
        """)
        with gr.Row():
            with gr.Column(scale=1):
                language_selection = gr.Dropdown(
                    choices=language_choices,
                    label="🗣️ Select Language",
                    value=language_choices[0],
                    interactive=True,
                )

                audio_input = gr.Audio(
                    label="📹 Upload Audio File (16kHz recommended)",
                    type="filepath",
                )

                model_selection = gr.CheckboxGroup(
                    choices=["AudioX-North", "IndicConformer", "MMS"],
                    label="🤖 Select Models",
                    value=["AudioX-North", "IndicConformer"],
                    interactive=True,
                )

                reference_input = gr.Textbox(
                    label="📄 Reference Text (optional, paste supported)",
                    placeholder="Paste reference transcription here...",
                    lines=4,
                    interactive=True,
                )

                submit_btn = gr.Button("🚀 Run Multilingual Benchmark", variant="primary", size="lg")
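
            # Right column: run summary, per-model comparison table, and a
            # copy-friendly text export.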
            with gr.Column(scale=2):
                summary_output = gr.Markdown(
                    label="📊 Summary",
                    value="Select language, upload audio file and choose models to begin...",
                )

                results_table = gr.Dataframe(
                    headers=["Model", "Transcription", "WER", "CER", "RTF", "Time"],
                    datatype=["str", "str", "str", "str", "str", "str"],
                    label="🏆 Results Comparison",
                    interactive=False,
                    wrap=True,
                    column_widths=[120, 350, 60, 60, 60, 80],
                )

                with gr.Group():
                    gr.Markdown("### 📋 Export Results")
                    copyable_output = gr.Textbox(
                        label="Copy-Paste Friendly Results",
                        lines=12,
                        max_lines=25,
                        show_copy_button=True,
                        interactive=False,
                        elem_classes="copy-area",
                        placeholder="Benchmark results will appear here...",
                    )
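
        # When the language changes, offer only the models that support it and
        # preselect the first two.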
        def update_model_choices(selected_language):
            if not selected_language:
                return gr.CheckboxGroup(choices=[], value=[])

            # The registry stores model names directly, so no extra mapping is
            # needed; slicing handles languages with fewer than two models.
            available_choices = LANGUAGE_CONFIGS[selected_language]["models"]
            default_selection = available_choices[:2]

            return gr.CheckboxGroup(choices=available_choices, value=default_selection)
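
        # Event wiring: language changes re-filter the models; the button (or
        # pressing Enter in the reference box) runs the benchmark.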
        language_selection.change(
            fn=update_model_choices,
            inputs=[language_selection],
            outputs=[model_selection],
        )

        submit_btn.click(
            fn=transcribe_audio,
            inputs=[audio_input, language_selection, model_selection, reference_input],
            outputs=[summary_output, results_table, copyable_output],
        )

        reference_input.submit(
            fn=transcribe_audio,
            inputs=[audio_input, language_selection, model_selection, reference_input],
            outputs=[summary_output, results_table, copyable_output],
        )

        gr.Markdown("""
        ---
        ### 📤 Language & Model Support Matrix

        | Language | Script | AudioX-North | AudioX-South | IndicConformer | MMS |
        |----------|--------|--------------|--------------|----------------|-----|
        | Hindi | Devanagari | ✅ | ❌ | ✅ | ✅ |
        | Gujarati | Gujarati | ✅ | ❌ | ✅ | ✅ |
        | Marathi | Devanagari | ✅ | ❌ | ✅ | ✅ |
        | Tamil | Tamil | ❌ | ✅ | ✅ | ✅ |
        | Telugu | Telugu | ❌ | ✅ | ✅ | ✅ |
        | Kannada | Kannada | ❌ | ✅ | ✅ | ✅ |
        | Malayalam | Malayalam | ❌ | ✅ | ✅ | ✅ |

        ### 💡 Tips
        - **Models auto-filter** based on the selected language
        - **Reference Text**: provide a ground-truth transcript to enable WER/CER
        - **Copy Results**: export formatted results using the copy button
        - **Best Performance**: use the AudioX models for their specialized languages
        """)

    return iface
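

# Entry point: binds to all interfaces on port 7860 (the Gradio/Spaces
# convention); set share=True for a temporary public link.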
if __name__ == "__main__":
    iface = create_interface()
    iface.launch(
        share=False,
        debug=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )