import gradio as gr
import torch
import torchaudio
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoModelForCTC,
AutoModel,
WhisperProcessor,
WhisperForConditionalGeneration,
)
import librosa
import numpy as np
from jiwer import wer, cer
import time
# Language configurations
LANGUAGE_CONFIGS = {
"Hindi (हिंदी)": {
"code": "hi",
"script": "Devanagari",
"models": ["AudioX-North", "IndicConformer", "MMS"]
},
"Gujarati (ગુજરાતી)": {
"code": "gu",
"script": "Gujarati",
"models": ["AudioX-North", "IndicConformer", "MMS"]
},
"Marathi (मराठी)": {
"code": "mr",
"script": "Devanagari",
"models": ["AudioX-North", "IndicConformer", "MMS"]
},
"Tamil (தமிழ்)": {
"code": "ta",
"script": "Tamil",
"models": ["AudioX-South", "IndicConformer", "MMS"]
},
"Telugu (తెలుగు)": {
"code": "te",
"script": "Telugu",
"models": ["AudioX-South", "IndicConformer", "MMS"]
},
"Kannada (ಕನ್ನಡ)": {
"code": "kn",
"script": "Kannada",
"models": ["AudioX-South", "IndicConformer", "MMS"]
},
"Malayalam (മലയാളം)": {
"code": "ml",
"script": "Malayalam",
"models": ["AudioX-South", "IndicConformer", "MMS"]
}
}
# Model configurations
MODEL_CONFIGS = {
"AudioX-North": {
"repo": "jiviai/audioX-north-v1",
"model_type": "whisper",
"description": "Supports Hindi, Gujarati, Marathi",
"languages": ["hi", "gu", "mr"]
},
"AudioX-South": {
"repo": "jiviai/audioX-south-v1",
"model_type": "whisper",
"description": "Supports Tamil, Telugu, Kannada, Malayalam",
"languages": ["ta", "te", "kn", "ml"]
},
"IndicConformer": {
"repo": "ai4bharat/indic-conformer-600m-multilingual",
"model_type": "ctc_rnnt",
"description": "Supports 22 Indian languages",
"trust_remote_code": True,
"languages": ["hi", "gu", "mr", "ta", "te", "kn", "ml", "bn", "pa", "or", "as", "ur"]
},
"MMS": {
"repo": "facebook/mms-1b-all",
"model_type": "ctc",
"description": "Supports 1,400+ languages",
"languages": ["hi", "gu", "mr", "ta", "te", "kn", "ml"]
},
}
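# MMS keys its per-language adapters by ISO 639-3 codes, not the two-letter
# codes in LANGUAGE_CONFIGS above; this mapping (added for the MMS branch in
# transcribe_audio below) covers the seven benchmark languages.
MMS_LANG_CODES = {
    "hi": "hin",
    "gu": "guj",
    "mr": "mar",
    "ta": "tam",
    "te": "tel",
    "kn": "kan",
    "ml": "mal",
}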
# Load model and processor
def load_model_and_processor(model_name):
config = MODEL_CONFIGS[model_name]
repo = config["repo"]
model_type = config["model_type"]
trust_remote_code = config.get("trust_remote_code", False)
try:
if model_name == "IndicConformer":
print(f"Loading {model_name}...")
try:
                model = AutoModel.from_pretrained(
                    repo,
                    trust_remote_code=trust_remote_code,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
            except Exception as e1:
                print(f"Primary loading failed, trying fallback: {e1}")
                model = AutoModel.from_pretrained(repo, trust_remote_code=trust_remote_code)
processor = None
return model, processor, model_type
elif model_name in ["AudioX-North", "AudioX-South"]:
# Use Whisper processor and model for AudioX variants
processor = WhisperProcessor.from_pretrained(repo)
model = WhisperForConditionalGeneration.from_pretrained(repo)
model.config.forced_decoder_ids = None
return model, processor, model_type
elif model_name == "MMS":
model = AutoModelForCTC.from_pretrained(repo)
processor = AutoProcessor.from_pretrained(repo)
return model, processor, model_type
except Exception as e:
return None, None, f"Error loading model: {str(e)}"
# Compute metrics (WER, CER, RTF)
def compute_metrics(reference, hypothesis, audio_duration, total_time):
if not reference or not hypothesis:
return None, None, None, None
try:
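        # .lower() only changes Latin characters; Indic scripts pass through unchanged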
reference = reference.strip().lower()
hypothesis = hypothesis.strip().lower()
wer_score = wer(reference, hypothesis)
cer_score = cer(reference, hypothesis)
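        # RTF (real-time factor) = processing time / audio duration; < 1.0 means faster than real time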
rtf = total_time / audio_duration if audio_duration > 0 else None
return wer_score, cer_score, rtf, total_time
except Exception:
return None, None, None, None
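# Worked example: identical reference and hypothesis over a 3.2 s clip
# transcribed in 1.6 s:
#   compute_metrics("नमस्ते दुनिया", "नमस्ते दुनिया", 3.2, 1.6)
#   -> (0.0, 0.0, 0.5, 1.6)  # WER, CER, RTF, elapsed seconds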
# Main transcription function
def transcribe_audio(audio_file, selected_language, selected_models, reference_text=""):
if not audio_file:
return "Please upload an audio file.", [], ""
if not selected_models:
return "Please select at least one model.", [], ""
if not selected_language:
return "Please select a language.", [], ""
# Get language info
lang_info = LANGUAGE_CONFIGS[selected_language]
lang_code = lang_info["code"]
table_data = []
try:
# Load and preprocess audio once
audio, sr = librosa.load(audio_file, sr=16000)
audio_duration = len(audio) / sr
for model_name in selected_models:
# Check if model supports the selected language
            if model_name not in lang_info["models"]:
table_data.append([
model_name,
f"Language {selected_language} not supported by this model",
"-", "-", "-", "-"
])
continue
model, processor, model_type = load_model_and_processor(model_name)
            if model is None or (isinstance(model_type, str) and model_type.startswith("Error")):
                table_data.append([
                    model_name,
                    str(model_type),
"-", "-", "-", "-"
])
continue
start_time = time.time()
try:
if model_name == "IndicConformer":
# AI4Bharat specific processing
wav = torch.from_numpy(audio).unsqueeze(0)
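                    # Peak-normalize to [-1, 1]; the model expects a mono 16 kHz waveform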
if torch.max(torch.abs(wav)) > 0:
wav = wav / torch.max(torch.abs(wav))
with torch.no_grad():
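                        # Custom forward signature: (waveform, language code, decoding mode);
                        # "rnnt" selects the transducer decoder, "ctc" the CTC head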
transcription = model(wav, lang_code, "rnnt")
if isinstance(transcription, list):
transcription = transcription[0] if transcription else ""
transcription = str(transcription).strip()
elif model_name in ["AudioX-North", "AudioX-South"]:
# AudioX Whisper-based processing
if sr != 16000:
audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
with torch.no_grad():
predicted_ids = model.generate(
input_features,
task="transcribe",
language=lang_code
)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                else:  # MMS
                    # MMS has one adapter per language; switch the tokenizer vocab
                    # and model adapter to the target language (the default is English)
                    mms_code = MMS_LANG_CODES.get(lang_code, lang_code)
                    processor.tokenizer.set_target_lang(mms_code)
                    model.load_adapter(mms_code)
                    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
                    with torch.no_grad():
                        logits = model(inputs["input_values"]).logits
                    predicted_ids = torch.argmax(logits, dim=-1)
                    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
except Exception as e:
transcription = f"Processing error: {str(e)}"
total_time = time.time() - start_time
# Compute metrics
wer_score, cer_score, rtf = "-", "-", "-"
if reference_text and transcription and not transcription.startswith("Processing error"):
wer_val, cer_val, rtf_val, _ = compute_metrics(
reference_text, transcription, audio_duration, total_time
)
wer_score = f"{wer_val:.3f}" if wer_val is not None else "-"
cer_score = f"{cer_val:.3f}" if cer_val is not None else "-"
rtf = f"{rtf_val:.3f}" if rtf_val is not None else "-"
# Add row to table
table_data.append([
model_name,
transcription,
wer_score,
cer_score,
rtf,
f"{total_time:.2f}s"
])
# Create summary text
summary = f"**Language:** {selected_language} ({lang_code})\n"
summary += f"**Audio Duration:** {audio_duration:.2f}s\n"
summary += f"**Models Tested:** {len(selected_models)}\n"
if reference_text:
summary += f"**Reference Text:** {reference_text[:100]}{'...' if len(reference_text) > 100 else ''}\n"
# Create copyable text output
copyable_text = "MULTILINGUAL SPEECH-TO-TEXT BENCHMARK RESULTS\n" + "="*55 + "\n\n"
copyable_text += f"Language: {selected_language} ({lang_code})\n"
copyable_text += f"Script: {lang_info['script']}\n"
copyable_text += f"Audio Duration: {audio_duration:.2f}s\n"
copyable_text += f"Models Tested: {len(selected_models)}\n"
if reference_text:
copyable_text += f"Reference Text: {reference_text}\n"
copyable_text += "\n" + "-"*55 + "\n\n"
for i, row in enumerate(table_data):
copyable_text += f"MODEL {i+1}: {row[0]}\n"
copyable_text += f"Transcription: {row[1]}\n"
copyable_text += f"WER: {row[2]}\n"
copyable_text += f"CER: {row[3]}\n"
copyable_text += f"RTF: {row[4]}\n"
copyable_text += f"Time Taken: {row[5]}\n"
copyable_text += "\n" + "-"*35 + "\n\n"
return summary, table_data, copyable_text
except Exception as e:
error_msg = f"Error during transcription: {str(e)}"
return error_msg, [], error_msg
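# Example call (hypothetical file path):
#   summary, rows, export_text = transcribe_audio(
#       "sample_hi.wav", "Hindi (हिंदी)", ["AudioX-North", "MMS"], "नमस्ते दुनिया"
#   )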
# Create Gradio interface
def create_interface():
language_choices = list(LANGUAGE_CONFIGS.keys())
with gr.Blocks(title="Multilingual Speech-to-Text Benchmark", css="""
.language-info { background: #f0f8ff; padding: 10px; border-radius: 5px; margin: 10px 0; }
.copy-area { font-family: monospace; font-size: 12px; }
""") as iface:
gr.Markdown("""
# 🌐 Multilingual Speech-to-Text Benchmark
Compare ASR models across **7 Indian Languages** with comprehensive metrics.
**Supported Languages:** Hindi, Gujarati, Marathi, Tamil, Telugu, Kannada, Malayalam
""")
with gr.Row():
with gr.Column(scale=1):
# Language selection
language_selection = gr.Dropdown(
choices=language_choices,
label="🗣️ Select Language",
value=language_choices[0],
interactive=True
)
audio_input = gr.Audio(
label="📹 Upload Audio File (16kHz recommended)",
type="filepath"
)
# Dynamic model selection based on language
model_selection = gr.CheckboxGroup(
choices=["AudioX-North", "IndicConformer", "MMS"],
label="🤖 Select Models",
value=["AudioX-North", "IndicConformer"],
interactive=True
)
reference_input = gr.Textbox(
label="📄 Reference Text (optional, paste supported)",
placeholder="Paste reference transcription here...",
lines=4,
interactive=True
)
submit_btn = gr.Button("🚀 Run Multilingual Benchmark", variant="primary", size="lg")
with gr.Column(scale=2):
summary_output = gr.Markdown(
label="📊 Summary",
value="Select language, upload audio file and choose models to begin..."
)
results_table = gr.Dataframe(
headers=["Model", "Transcription", "WER", "CER", "RTF", "Time"],
datatype=["str", "str", "str", "str", "str", "str"],
label="🏆 Results Comparison",
interactive=False,
wrap=True,
column_widths=[120, 350, 60, 60, 60, 80]
)
# Copyable results section
with gr.Group():
gr.Markdown("### 📋 Export Results")
copyable_output = gr.Textbox(
label="Copy-Paste Friendly Results",
lines=12,
max_lines=25,
show_copy_button=True,
interactive=False,
elem_classes="copy-area",
placeholder="Benchmark results will appear here..."
)
# Update model choices based on language selection
        def update_model_choices(selected_language):
            if not selected_language:
                return gr.CheckboxGroup(choices=[], value=[])
            available_choices = LANGUAGE_CONFIGS[selected_language]["models"]
            # Preselect the first two models; slicing also covers shorter lists
            default_selection = available_choices[:2]
            return gr.CheckboxGroup(choices=available_choices, value=default_selection)
# Connect language selection to model updates
language_selection.change(
fn=update_model_choices,
inputs=[language_selection],
outputs=[model_selection]
)
# Connect the main function
submit_btn.click(
fn=transcribe_audio,
inputs=[audio_input, language_selection, model_selection, reference_input],
outputs=[summary_output, results_table, copyable_output]
)
reference_input.submit(
fn=transcribe_audio,
inputs=[audio_input, language_selection, model_selection, reference_input],
outputs=[summary_output, results_table, copyable_output]
)
# Language information display
gr.Markdown("""
---
### 📤 Language & Model Support Matrix
| Language | Script | AudioX-North | AudioX-South | IndicConformer | MMS |
|----------|---------|-------------|-------------|---------------|-----|
| Hindi | Devanagari | ✅ | ❌ | ✅ | ✅ |
| Gujarati | Gujarati | ✅ | ❌ | ✅ | ✅ |
| Marathi | Devanagari | ✅ | ❌ | ✅ | ✅ |
| Tamil | Tamil | ❌ | ✅ | ✅ | ✅ |
| Telugu | Telugu | ❌ | ✅ | ✅ | ✅ |
| Kannada | Kannada | ❌ | ✅ | ✅ | ✅ |
| Malayalam | Malayalam | ❌ | ✅ | ✅ | ✅ |
### 💡 Tips:
- **Models auto-filter** based on selected language
- **Reference Text**: Enable WER/CER calculation by providing ground truth
- **Copy Results**: Export formatted results using the copy button
- **Best Performance**: Use AudioX models for their specialized languages
""")
return iface
if __name__ == "__main__":
iface = create_interface()
iface.launch(
share=False,
debug=True,
server_name="0.0.0.0",
server_port=7860,
show_error=True
)