import torch
import torch.nn as nn
import torchaudio
import numpy as np
import soundfile as sf
from transformers import WhisperProcessor, WhisperModel, WhisperForConditionalGeneration
import gradio as gr
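# soundfile decodes the uploaded audio, torchaudio handles resampling, and
# transformers provides both the feature extractor and the Whisper weights.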
# Load full model for transcription
whisper_model_gen = WhisperForConditionalGeneration.from_pretrained(
    "tarteel-ai/whisper-base-ar-quran"
)
whisper_model_gen.eval()
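# eval() disables train-time behavior such as dropout; inference below also
# runs without gradient tracking.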
# Model definition
class TajweedClassifier(nn.Module):
    def __init__(self, base_model_name="tarteel-ai/whisper-base-ar-quran", tajweed_dim=3):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(base_model_name).encoder
        hidden_size = self.encoder.config.d_model
        self.tajweed_head = nn.Linear(hidden_size, tajweed_dim)

    def forward(self, input_features):
        encoder_out = self.encoder(input_features).last_hidden_state
        pooled = encoder_out.mean(dim=1)
        tajweed_logits = self.tajweed_head(pooled)
        return torch.sigmoid(tajweed_logits)
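# Shape walkthrough (assuming the whisper-base backbone this checkpoint is
# fine-tuned from): a padded 30 s clip becomes log-mel features of shape
# (1, 80, 3000); the encoder outputs (1, 1500, 512), mean-pooling over the
# time axis gives (1, 512), and the linear head plus sigmoid yields (1, 3)
# scores, each in [0, 1].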
# Load model & processor
processor = WhisperProcessor.from_pretrained("tarteel-ai/whisper-base-ar-quran")
model = TajweedClassifier()
checkpoint = torch.load("stat_tajweed_model_v3.pth", map_location="cpu")

# Remove the "module." prefix if the checkpoint was saved with DataParallel
new_state_dict = {}
for k, v in checkpoint.items():
    new_key = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[new_key] = v
model.load_state_dict(new_state_dict)
model.eval()
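# Note: load_state_dict is strict by default, so any key mismatch between the
# checkpoint and the model raises a RuntimeError rather than loading partially.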
# Prediction function
def predict_from_filepath(audio_path):
    if audio_path is None or not isinstance(audio_path, str):
        return {"error": "Invalid audio input."}
    try:
        waveform, sr = sf.read(audio_path)
    except Exception as e:
        return {"error": f"Could not read audio file: {str(e)}"}

    # Convert stereo to mono by averaging the channels
    if len(waveform.shape) > 1:
        waveform = waveform.mean(axis=1)
    waveform = waveform.astype("float32")

    # Normalize to [-1, 1] if the samples exceed that range
    max_val = np.max(np.abs(waveform))
    if max_val > 1.0:
        waveform /= max_val

    # Resample if needed (Whisper expects a 16 kHz sampling rate)
    if sr != 16000:
        waveform = torchaudio.functional.resample(
            torch.from_numpy(waveform), orig_freq=sr, new_freq=16000
        ).numpy()
    # Prepare log-mel input features
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features

    with torch.no_grad():
        output = model(input_features)

    # Tajweed predictions
    labels = [
        "Separated madd (المد المنفصل)",
        "Doubled noon (النون المشددة)",
        "Concealment (الإخفاء)",
    ]
    values = output.squeeze().tolist()
    tajweed_dict = {label: value for label, value in zip(labels, values)}
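    # The three scores come from independent sigmoids (multi-label), so they
    # do not sum to 1 the way a softmax distribution would.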
    # Transcription using the full Whisper model
    transcription_ids = whisper_model_gen.generate(input_features)
    transcription = processor.batch_decode(
        transcription_ids, skip_special_tokens=True
    )[0]

    return {
        "tajweed": tajweed_dict,
        "transcription": transcription,
    }
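# Example call (hypothetical file name; values are illustrative only):
#   predict_from_filepath("recitation.wav")
#   -> {"tajweed": {"Separated madd (المد المنفصل)": 0.91, ...},
#       "transcription": "..."}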
# Gradio interface
demo = gr.Interface(
    fn=predict_from_filepath,
    inputs=gr.Audio(type="filepath", label="Upload Quran recitation (.wav)"),
    outputs=gr.JSON(label="Prediction Output"),
    title="Tajweed Prediction API",
    description="Upload a .wav file to get three tajweed scores between 0 and 1 and a transcription.",
)
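# With type="filepath", Gradio hands the function the path of a temporary copy
# of the upload as a str, which is what the isinstance check above expects.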
if __name__ == "__main__":
    demo.launch()