# app.py
import gradio as gr
import torch
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC
)
import librosa
import numpy as np
import warnings

warnings.filterwarnings("ignore")


class NigerianWhisperTranscriber:
    def __init__(self):
        self.models = {}
        self.processors = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Model configurations with their architectures
        self.model_configs = {
            "Yoruba": {
                "model_name": "DereAbdulhameed/Whisper-Yoruba",
                "architecture": "whisper"
            },
            "Hausa": {
                "model_name": "Baghdad99/saad-speech-recognition-hausa-audio-to-text",
                "architecture": "whisper"
            },
            "Igbo": {
                "model_name": "AstralZander/igbo_ASR",
                "architecture": "wav2vec2"
            }
        }

        print(f"Using device: {self.device}")

    def load_model(self, language):
        """Load model and processor for specific language"""
        if language not in self.models:
            try:
                print(f"Loading {language} model...")
                config = self.model_configs[language]
                model_name = config["model_name"]
                architecture = config["architecture"]

                if architecture == "whisper":
                    # Load Whisper model
                    processor = WhisperProcessor.from_pretrained(model_name)
                    model = WhisperForConditionalGeneration.from_pretrained(model_name)
                    model = model.to(self.device)
                elif architecture == "wav2vec2":
                    # Load Wav2Vec2 model
                    processor = Wav2Vec2Processor.from_pretrained(model_name)
                    model = Wav2Vec2ForCTC.from_pretrained(model_name)
                    model = model.to(self.device)

                self.processors[language] = processor
                self.models[language] = model
                print(f"{language} model loaded successfully!")
                return True
            except Exception as e:
                print(f"Error loading {language} model: {str(e)}")
                return False
        return True

    def preprocess_audio(self, audio_path):
        """Preprocess audio file for transcription"""
        try:
            # Load audio file and resample to 16 kHz
            audio, sr = librosa.load(audio_path, sr=16000)

            # Ensure audio is not empty
            if len(audio) == 0:
                raise ValueError("Audio file is empty")

            # Ensure float32 dtype (librosa already returns normalized floats)
            audio = audio.astype(np.float32)

            return audio
        except Exception as e:
            raise ValueError(f"Error processing audio: {str(e)}")

    def chunk_audio(self, audio, chunk_length=25):
        """Split audio into chunks for processing longer recordings"""
        sample_rate = 16000
        chunk_samples = chunk_length * sample_rate
        chunks = []

        for i in range(0, len(audio), chunk_samples):
            chunk = audio[i:i + chunk_samples]
            if len(chunk) > sample_rate:  # Only process chunks longer than 1 second
                chunks.append(chunk)

        return chunks

    def transcribe_chunk(self, audio_chunk, language):
        """Transcribe a single audio chunk"""
        processor = self.processors[language]
        model = self.models[language]
        config = self.model_configs[language]

        if config["architecture"] == "whisper":
            # Whisper processing
            inputs = processor(
                audio_chunk,
                sampling_rate=16000,
                return_tensors="pt"
            )
            input_features = inputs.input_features.to(self.device)

            # Create attention mask if available
            attention_mask = None
            if hasattr(inputs, 'attention_mask') and inputs.attention_mask is not None:
                attention_mask = inputs.attention_mask.to(self.device)

            # Generate transcription
            with torch.no_grad():
                if attention_mask is not None:
                    predicted_ids = model.generate(
                        input_features,
                        attention_mask=attention_mask,
                        max_new_tokens=400,
                        num_beams=5,
                        temperature=0.0,
                        do_sample=False,
                        use_cache=True,
                        pad_token_id=processor.tokenizer.eos_token_id
                    )
                else:
                    predicted_ids = model.generate(
                        input_features,
                        max_new_tokens=400,
                        num_beams=5,
                        temperature=0.0,
                        do_sample=False,
                        use_cache=True,
                        pad_token_id=processor.tokenizer.eos_token_id
                    )

            # Decode transcription
            transcription = processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]

            return transcription.strip()

        elif config["architecture"] == "wav2vec2":
            # Wav2Vec2 processing
            inputs = processor(
                audio_chunk,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True
            )
            input_values = inputs.input_values.to(self.device)

            # Generate transcription
            with torch.no_grad():
                logits = model(input_values).logits
                predicted_ids = torch.argmax(logits, dim=-1)

            # Decode transcription for Wav2Vec2. skip_special_tokens=True removes
            # most special tokens, but some Wav2Vec2 tokenizers may still leave
            # padding markers behind depending on how they were trained.
            transcription = processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]

            # Fallback clean-up: strip any remaining [PAD] markers and collapse
            # extra whitespace.
            transcription = transcription.replace("[PAD]", "").strip()
            transcription = " ".join(transcription.split())

            return transcription.strip()

    def transcribe(self, audio_path, language):
        """Transcribe audio file in specified language"""
        try:
            # Load model if not already loaded
            if not self.load_model(language):
                return f"Error: Could not load {language} model"

            # Preprocess audio
            audio = self.preprocess_audio(audio_path)

            # Check audio length (25 seconds = 400,000 samples at 16 kHz)
            if len(audio) > 400000:  # If longer than 25 seconds
                # Process in chunks
                chunks = self.chunk_audio(audio, chunk_length=25)
                transcriptions = []

                for i, chunk in enumerate(chunks):
                    print(f"Processing chunk {i+1}/{len(chunks)}")
                    # Transcribe chunk
                    chunk_transcription = self.transcribe_chunk(chunk, language)
                    transcriptions.append(chunk_transcription)

                # Combine all transcriptions
                full_transcription = " ".join(transcriptions)
                return full_transcription
            else:
                # Process short audio normally
                return self.transcribe_chunk(audio, language)

        except Exception as e:
            return f"Error during transcription: {str(e)}"


# Initialize transcriber
transcriber = NigerianWhisperTranscriber()


def transcribe_audio_unified(audio_file, audio_mic, language):
    """Gradio function for transcription from either file or microphone"""
    # Determine which audio source to use
    audio_source = audio_file if audio_file is not None else audio_mic

    if audio_source is None:
        return "Please upload an audio file or record from microphone"

    try:
        result = transcriber.transcribe(audio_source, language)
        return result
    except Exception as e:
        return f"Transcription failed: {str(e)}"


def get_model_info(language):
    """Get information about the selected model"""
    model_info = {
        "Yoruba": "DereAbdulhameed/Whisper-Yoruba - Whisper model specialized for Yoruba language",
        "Hausa": "Baghdad99/saad-speech-recognition-hausa-audio-to-text - Fine-tuned Whisper model for Hausa (WER: 44.4%)",
        "Igbo": "AstralZander/igbo_ASR - Wav2Vec2-XLS-R model fine-tuned for Igbo language (WER: 51%)"
    }
    return model_info.get(language, "Model information not available")


# Create Gradio interface
with gr.Blocks(
    title="Nigerian Languages Speech Transcription",
    theme=gr.themes.Soft(),
    css="""
    .main-header { text-align: center; color: #2E7D32; margin-bottom: 20px; }
    .language-info { background-color: #f5f5f5; padding: 10px; border-radius: 5px; margin: 10px 0; }
    """
) as demo:
    gr.HTML("""
        <div class="main-header">
            <h1>Nigerian Languages Speech Transcription</h1>
            <p>Transcribe audio in Yoruba, Hausa, and Igbo using specialized Whisper models</p>
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Language selection
            language_dropdown = gr.Dropdown(
                choices=["Yoruba", "Hausa", "Igbo"],
                value="Yoruba",
                label="Select Language",
                info="Choose the language of your audio file"
            )

            # Usage notes
            gr.HTML("""
                <div class="language-info">
                    <strong>Note:</strong> First-time model loading may take a few minutes.<br>
                    <strong>Recording Tip:</strong> Speak clearly and ensure good audio quality for better transcription accuracy.<br>
                    <strong>Long Audio:</strong> Audio longer than 25 seconds will be automatically processed in chunks.
                </div>
            """)
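
            # Audio input options
            # NOTE: the remainder of the interface was missing from the source file.
            # The components and event wiring below are a minimal reconstruction, not
            # the original code: component names, labels, and the Gradio 4.x `sources=`
            # argument are assumptions, wired to the existing transcribe_audio_unified()
            # and get_model_info() helpers defined above.
            audio_file_input = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="Upload Audio File"
            )
            audio_mic_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Or Record with Microphone"
            )
            transcribe_button = gr.Button("Transcribe", variant="primary")

        with gr.Column(scale=1):
            # Output column (reconstructed)
            model_info_box = gr.Textbox(
                label="Model Information",
                value=get_model_info("Yoruba"),
                interactive=False
            )
            transcription_output = gr.Textbox(
                label="Transcription",
                lines=10,
                placeholder="Transcription will appear here..."
            )

    # Event wiring (reconstructed): refresh model info when the language changes,
    # and transcribe whichever audio source was provided on button click.
    language_dropdown.change(
        fn=get_model_info,
        inputs=language_dropdown,
        outputs=model_info_box
    )
    transcribe_button.click(
        fn=transcribe_audio_unified,
        inputs=[audio_file_input, audio_mic_input, language_dropdown],
        outputs=transcription_output
    )

if __name__ == "__main__":
    demo.launch()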