"""Gradio demo comparing Whisper, Wav2Vec2 and HuBERT on Common Voice ESL audio.

A random clip matching the selected age / gender / accent is transcribed by
all three models. Each transcript is shown next to the gold sentence together
with its word error rate (WER) and a word-level diff.
"""

import html
import json
import os
import random
import re
from difflib import SequenceMatcher

import gradio as gr
import torch
import torchaudio
import whisper
from jiwer import wer
# HuBERT CTC checkpoints are paired with a Wav2Vec2 processor; transformers
# has no ``HubertProcessor`` class, so importing one raises ImportError.
from transformers import HubertForCTC, Wav2Vec2ForCTC, Wav2Vec2Processor

DATA_DIR = "common_voice_en_validated_249"
METADATA_FILE = "common_voice_en_validated_249_hf_ready.json"
TARGET_SR = 16000  # sampling rate the CTC processors are told the audio has

# Load metadata. Explicit UTF-8 because sentences may contain non-ASCII text
# and the platform default encoding is not UTF-8 everywhere.
with open(METADATA_FILE, encoding="utf-8") as f:
    data = json.load(f)

# Available filter values for the dropdowns.
ages = sorted({entry["age"] for entry in data})
genders = sorted({entry["gender"] for entry in data})
accents = sorted({entry["accent"] for entry in data})

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper
whisper_model = whisper.load_model("medium").to(device)

# Wav2Vec2
WAV2VEC_CHECKPOINT = "facebook/wav2vec2-large-960h-lv60-self"
wav2vec_processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_CHECKPOINT)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(WAV2VEC_CHECKPOINT).to(device)

# HuBERT (uses the Wav2Vec2 processor, see import note above)
HUBERT_CHECKPOINT = "facebook/hubert-large-ls960-ft"
hubert_processor = Wav2Vec2Processor.from_pretrained(HUBERT_CHECKPOINT)
hubert_model = HubertForCTC.from_pretrained(HUBERT_CHECKPOINT).to(device)

# Anything that is not a word character, whitespace or apostrophe.
_NON_WORD = re.compile(r"[^\w\s']+")


def load_audio(file_path):
    """Load *file_path* as a mono float numpy array resampled to TARGET_SR."""
    waveform, sr = torchaudio.load(file_path)  # (channels, samples)
    # Down-mix to mono; for single-channel clips this equals taking channel 0.
    mono = waveform.mean(dim=0)
    return torchaudio.functional.resample(mono, orig_freq=sr, new_freq=TARGET_SR).numpy()


def transcribe_whisper(file_path):
    """Transcribe *file_path* with Whisper; return stripped, lower-cased text."""
    # Force English decoding: the reference sentences are English, and
    # accented speech can push automatic language detection to another
    # language. fp16 decoding is only requested when running on the GPU.
    result = whisper_model.transcribe(file_path, language="en", fp16=(device == "cuda"))
    return result["text"].strip().lower()


def _transcribe_ctc(processor, model, file_path):
    """Greedy-decode *file_path* with a CTC *model*; return stripped, lower-cased text."""
    audio = load_audio(file_path)
    inputs = processor(audio, sampling_rate=TARGET_SR, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**inputs.to(device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0].strip().lower()


def transcribe_wav2vec(file_path):
    """Transcribe *file_path* with Wav2Vec2; return stripped, lower-cased text."""
    return _transcribe_ctc(wav2vec_processor, wav2vec_model, file_path)


def transcribe_hubert(file_path):
    """Transcribe *file_path* with HuBERT; return stripped, lower-cased text."""
    return _transcribe_ctc(hubert_processor, hubert_model, file_path)


def normalize_text(text):
    """Lower-case *text*, drop punctuation (keeping in-word apostrophes), collapse spaces.

    Whisper output and the reference sentences may contain punctuation while
    the CTC models presumably emit bare words; scoring raw strings would count
    every punctuated token as an error.
    """
    text = text.lower().replace("\u2019", "'")  # curly -> straight apostrophe
    words = (word.strip("'") for word in _NON_WORD.sub(" ", text).split())
    return " ".join(word for word in words if word)


def _word_error_rate(ref, hyp):
    """Return the WER of *hyp* against *ref*; NaN when *ref* has no words."""
    if not ref.split():
        return float("nan")  # WER is undefined without reference words
    if not hyp.split():
        return 1.0  # every reference word was deleted
    return wer(ref, hyp)


def highlight_differences(ref, hyp):
    """Return *hyp* as HTML with its word errors against *ref* marked.

    Substituted / inserted hypothesis words are shown in red; reference words
    the hypothesis dropped are shown struck through in grey. Every word is
    HTML-escaped because it is interpolated into markup. Note that
    SequenceMatcher's alignment is not guaranteed to be the minimum-edit
    alignment WER uses, so highlights can differ slightly from the WER count.
    """
    ref_words, hyp_words = ref.split(), hyp.split()
    parts = []
    for opcode, i1, i2, j1, j2 in SequenceMatcher(None, ref_words, hyp_words).get_opcodes():
        if opcode == "equal":
            parts.extend(html.escape(w) for w in hyp_words[j1:j2])
        elif opcode in ("replace", "insert"):
            parts.extend(f'<span style="color:red">{html.escape(w)}</span>' for w in hyp_words[j1:j2])
        else:  # "delete": words present only in the reference
            parts.extend(f'<s style="color:gray">{html.escape(w)}</s>' for w in ref_words[i1:i2])
    return " ".join(parts)


def _results_table(gold_sentence, transcripts):
    """Return an HTML table comparing each transcript with *gold_sentence*.

    *transcripts* maps model name to raw transcript. Both sides go through
    ``normalize_text`` before diffing and scoring, so casing and punctuation
    are not counted as errors.
    """
    gold = normalize_text(gold_sentence)
    rows = [
        "<tr><th>Model</th><th>Transcription</th><th>WER</th></tr>",
        f"<tr><td>Gold</td><td>{html.escape(gold)}</td><td>0.00</td></tr>",
    ]
    for model_name, text in transcripts.items():
        hyp = normalize_text(text)
        rows.append(
            f"<tr><td>{model_name}</td>"
            f"<td>{highlight_differences(gold, hyp)}</td>"
            f"<td>{_word_error_rate(gold, hyp):.2f}</td></tr>"
        )
    return "<table>" + "".join(rows) + "</table>"


def run_demo(age, gender, accent):
    """Transcribe a random clip matching the filters with all three models.

    Returns exactly one value per output component wired in ``btn.click``:
    (gold sentence, audio path, Whisper text, Wav2Vec2 text, HuBERT text,
    HTML results table, path label).
    """
    filtered = [
        entry for entry in data
        if entry["age"] == age and entry["gender"] == gender and entry["accent"] == accent
    ]
    if not filtered:
        return "No matching sample.", None, "", "", "", "", ""

    sample = random.choice(filtered)
    file_path = os.path.join(DATA_DIR, sample["path"])

    transcripts = {
        "Whisper": transcribe_whisper(file_path),
        "Wav2Vec2": transcribe_wav2vec(file_path),
        "HuBERT": transcribe_hubert(file_path),
    }
    table = _results_table(sample["sentence"], transcripts)

    return (
        sample["sentence"],
        file_path,
        transcripts["Whisper"],
        transcripts["Wav2Vec2"],
        transcripts["HuBERT"],
        table,
        f"Audio path: {file_path}",
    )


with gr.Blocks() as demo:
    gr.Markdown("# ASR Model Comparison on ESL Audio")
    gr.Markdown("Filter by age, gender, and accent. Then generate a random ESL learner's audio to compare how Whisper, Wav2Vec2, and HuBERT transcribe it.")

    with gr.Row():
        age = gr.Dropdown(choices=ages, label="Age")
        gender = gr.Dropdown(choices=genders, label="Gender")
        accent = gr.Dropdown(choices=accents, label="Accent")

    btn = gr.Button("Generate and Transcribe")
    audio = gr.Audio(label="Audio", type="filepath")
    wer_output = gr.HTML()

    # Created in the same order (and thus rendered in the same place) as the
    # inline components they replace; named so the wiring below is checkable
    # against run_demo's return tuple.
    gold_box = gr.Textbox(label="Gold (Correct)")
    whisper_box = gr.Textbox(label="Whisper Output")
    wav2vec_box = gr.Textbox(label="Wav2Vec2 Output")
    hubert_box = gr.Textbox(label="HuBERT Output")
    path_box = gr.Textbox(label="Path")

    btn.click(
        fn=run_demo,
        inputs=[age, gender, accent],
        outputs=[gold_box, audio, whisper_box, wav2vec_box, hubert_box, wer_output, path_box],
    )

demo.launch()