import torch
import torch.nn as nn
import torchaudio
import numpy as np
import soundfile as sf
from transformers import WhisperProcessor, WhisperModel, WhisperForConditionalGeneration
import gradio as gr

# Load the full Whisper model for transcription
whisper_model_gen = WhisperForConditionalGeneration.from_pretrained(
    "tarteel-ai/whisper-base-ar-quran"
)
whisper_model_gen.eval()


# Model definition: Whisper encoder + linear head scoring three Tajweed rules
class TajweedClassifier(nn.Module):
    def __init__(self, base_model_name="tarteel-ai/whisper-base-ar-quran", tajweed_dim=3):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(base_model_name).encoder
        hidden_size = self.encoder.config.d_model
        self.tajweed_head = nn.Linear(hidden_size, tajweed_dim)

    def forward(self, input_features):
        encoder_out = self.encoder(input_features).last_hidden_state
        # Mean-pool over time, then project to per-rule scores in [0, 1]
        pooled = encoder_out.mean(dim=1)
        tajweed_logits = self.tajweed_head(pooled)
        return torch.sigmoid(tajweed_logits)


# Load processor and classifier weights
processor = WhisperProcessor.from_pretrained("tarteel-ai/whisper-base-ar-quran")
model = TajweedClassifier()
checkpoint = torch.load("stat_tajweed_model_v3.pth", map_location="cpu")

# Strip the "module." prefix left over from DataParallel training, if present
new_state_dict = {}
for k, v in checkpoint.items():
    new_key = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[new_key] = v
model.load_state_dict(new_state_dict)
model.eval()


# Prediction function
def predict_from_filepath(audio_path):
    if audio_path is None or not isinstance(audio_path, str):
        return {"error": "Invalid audio input."}

    try:
        waveform, sr = sf.read(audio_path)
    except Exception as e:
        return {"error": f"Could not read audio file: {str(e)}"}

    # Convert stereo input to mono
    if len(waveform.shape) > 1:
        waveform = waveform.mean(axis=1)
    waveform = waveform.astype("float32")

    # Normalize only if the signal exceeds the [-1, 1] range
    max_val = np.max(np.abs(waveform))
    if max_val > 1.0:
        waveform /= max_val

    # Resample to 16 kHz, the sampling rate Whisper expects
    if sr != 16000:
        waveform = torchaudio.functional.resample(
            torch.from_numpy(waveform), orig_freq=sr, new_freq=16000
        ).numpy()

    # Prepare log-mel input features
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features

    # Tajweed predictions
    with torch.no_grad():
        output = model(input_features)

    labels = [
        "Separated elongation (المد المنفصل)",
        "Doubled noon (النون المشددة)",
        "Concealment (الإخفاء)",
    ]
    values = output.squeeze().tolist()
    tajweed_dict = {label: value for label, value in zip(labels, values)}

    # Transcription with the full Whisper model
    with torch.no_grad():
        transcription_ids = whisper_model_gen.generate(input_features)
    transcription = processor.batch_decode(
        transcription_ids, skip_special_tokens=True
    )[0]

    return {
        "tajweed": tajweed_dict,
        "transcription": transcription,
    }


# Gradio interface
demo = gr.Interface(
    fn=predict_from_filepath,
    inputs=gr.Audio(type="filepath", label="Upload Quran recitation (.wav)"),
    outputs=gr.JSON(label="Prediction Output"),
    title="Tajweed Prediction API",
    description="Upload a .wav file and get three float scores between 0 and 1, one per Tajweed rule, plus a transcription.",
)

if __name__ == "__main__":
    demo.launch()
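
# Quick local check (illustrative only): call predict_from_filepath() directly instead of
# going through the Gradio UI. "sample_recitation.wav" is a hypothetical path; replace it
# with a real recitation file before running.
#
#     result = predict_from_filepath("sample_recitation.wav")
#     print(result["tajweed"])        # three scores in [0, 1], one per Tajweed rule
#     print(result["transcription"])  # Whisper transcription of the recitation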