import gradio as gr
import random
import json
import os
from difflib import SequenceMatcher

from jiwer import wer
import torchaudio
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, HubertForCTC
import whisper

# Load metadata
with open("common_voice_en_validated_249_hf_ready.json") as f:
    data = json.load(f)

# Available filter values
ages = sorted(set(entry["age"] for entry in data))
genders = sorted(set(entry["gender"] for entry in data))
accents = sorted(set(entry["accent"] for entry in data))

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper
whisper_model = whisper.load_model("medium").to(device)

# Wav2Vec2
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(device)

# HuBERT -- transformers has no HubertProcessor class; the fine-tuned
# checkpoint ships a Wav2Vec2Processor, so we reuse that class here.
hubert_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
hubert_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device)


def load_audio(file_path):
    """Load an audio file, resample it to the 16 kHz the CTC models expect,
    and return the first channel as a 1-D numpy array."""
    waveform, sr = torchaudio.load(file_path)
    return torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0].numpy()


def transcribe_whisper(file_path):
    result = whisper_model.transcribe(file_path)
    return result["text"].strip().lower()


def transcribe_wav2vec(file_path):
    audio = load_audio(file_path)
    inputs = wav2vec_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = wav2vec_model(**inputs.to(device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return wav2vec_processor.batch_decode(predicted_ids)[0].strip().lower()


def transcribe_hubert(file_path):
    audio = load_audio(file_path)
    inputs = hubert_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = hubert_model(**inputs.to(device)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return hubert_processor.batch_decode(predicted_ids)[0].strip().lower()


def highlight_differences(ref, hyp):
    """Align the hypothesis against the reference word by word and wrap
    mismatched hypothesis words in markdown bold (the results table below is
    markdown). Pure deletions leave no hypothesis token to mark, so missing
    words simply drop out of the rendered row."""
    ref_words, hyp_words = ref.split(), hyp.split()
    sm = SequenceMatcher(None, ref_words, hyp_words)
    result = []
    for opcode, i1, i2, j1, j2 in sm.get_opcodes():
        if opcode == "equal":
            result.extend(hyp_words[j1:j2])
        else:  # replace, insert, delete
            result.extend(f"**{w}**" for w in hyp_words[j1:j2])
    return " ".join(result)


def run_demo(age, gender, accent):
    # Pick a random validated clip matching the selected speaker attributes
    filtered = [
        entry for entry in data
        if entry["age"] == age and entry["gender"] == gender and entry["accent"] == accent
    ]
    if not filtered:
        return "No matching sample.", None, "", "", "", "", "", ""
    sample = random.choice(filtered)
    file_path = os.path.join("common_voice_en_validated_249", sample["path"])
    gold = sample["sentence"].strip().lower()
    whisper_text = transcribe_whisper(file_path)
    wav2vec_text = transcribe_wav2vec(file_path)
    hubert_text = transcribe_hubert(file_path)
    table = f"""
| Model    | Transcription                               | WER                              |
|----------|---------------------------------------------|----------------------------------|
| Gold     | {gold}                                      | 0.00                             |
| Whisper  | {highlight_differences(gold, whisper_text)} | {wer(gold, whisper_text):.2f}    |
| Wav2Vec2 | {highlight_differences(gold, wav2vec_text)} | {wer(gold, wav2vec_text):.2f}    |
| HuBERT   | {highlight_differences(gold, hubert_text)}  | {wer(gold, hubert_text):.2f}     |