# QuaranTajweedV1 / app.py
import torch
import torch.nn as nn
import torchaudio
import numpy as np
import soundfile as sf
from transformers import WhisperProcessor, WhisperModel, WhisperForConditionalGeneration
import gradio as gr
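# Two components share the same fine-tuned Whisper checkpoint: the full
# seq2seq model below handles transcription, while TajweedClassifier
# (defined further down) reuses only its encoder for rule classification.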
# Load full model for transcription
whisper_model_gen = WhisperForConditionalGeneration.from_pretrained(
"tarteel-ai/whisper-base-ar-quran"
)
whisper_model_gen.eval()
# Model definition
class TajweedClassifier(nn.Module):
def __init__(self, base_model_name="tarteel-ai/whisper-base-ar-quran", tajweed_dim=3):
super().__init__()
self.encoder = WhisperModel.from_pretrained(base_model_name).encoder
hidden_size = self.encoder.config.d_model
self.tajweed_head = nn.Linear(hidden_size, tajweed_dim)
def forward(self, input_features):
encoder_out = self.encoder(input_features).last_hidden_state
pooled = encoder_out.mean(dim=1)
tajweed_logits = self.tajweed_head(pooled)
return torch.sigmoid(tajweed_logits)
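# Shape walkthrough, for reference: Whisper log-mel features enter as
# (batch, 80, 3000); the base encoder halves the time axis, giving
# (batch, 1500, 512); mean-pooling over time yields (batch, 512); and the
# linear head plus sigmoid produces (batch, 3) independent scores in [0, 1].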
# Load model & processor
processor = WhisperProcessor.from_pretrained("tarteel-ai/whisper-base-ar-quran")
model = TajweedClassifier()
checkpoint = torch.load("stat_tajweed_model_v3.pth", map_location="cpu")
# Strip the "module." prefix that nn.DataParallel adds to checkpoint keys
new_state_dict = {}
for k, v in checkpoint.items():
new_key = k[len("module."):] if k.startswith("module.") else k
new_state_dict[new_key] = v
model.load_state_dict(new_state_dict)
model.eval()
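# Optional sanity check (a minimal sketch; assumes the checkpoint loaded
# without missing or unexpected keys):
#   dummy = torch.zeros(1, 80, 3000)  # one 30 s window of log-mel features
#   print(model(dummy).shape)         # expected: torch.Size([1, 3])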
# Prediction function
def predict_from_filepath(audio_path):
if audio_path is None or not isinstance(audio_path, str):
return {"error": "Invalid audio input."}
try:
waveform, sr = sf.read(audio_path)
except Exception as e:
return {"error": f"Could not read audio file: {str(e)}"}
# Handle stereo input by converting to mono
if len(waveform.shape) > 1:
waveform = waveform.mean(axis=1)
waveform = waveform.astype("float32")
    # Normalize only if samples exceed full scale [-1, 1]
max_val = np.max(np.abs(waveform))
if max_val > 1.0:
waveform /= max_val
    # Resample (Whisper expects sr = 16000)
if sr != 16000:
waveform = torchaudio.functional.resample(torch.from_numpy(waveform), orig_freq=sr, new_freq=16000).numpy()
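    # Note: torchaudio.functional.resample operates on tensors, hence the
    # from_numpy(...).numpy() round trip around the NumPy waveform.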
# Prepare input
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
input_features = inputs.input_features
with torch.no_grad():
output = model(input_features)
    # Tajweed predictions: three independent sigmoid scores
    labels = ["Separated madd (المد المنفصل)", "Doubled noon (النون المشددة)", "Concealment (الإخفاء)"]
values = output.squeeze().tolist()
tajweed_dict = {label: value for label, value in zip(labels, values)}
    # Transcription using the full Whisper seq2seq model
    transcription_ids = whisper_model_gen.generate(input_features)
transcription = processor.batch_decode(
transcription_ids, skip_special_tokens=True
)[0]
return {
"tajweed": tajweed_dict,
"transcription": transcription
}
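# Example usage outside Gradio (hypothetical local file path):
#   result = predict_from_filepath("recitation.wav")
#   print(result["tajweed"])        # e.g. {"Separated madd (المد المنفصل)": 0.87, ...}
#   print(result["transcription"])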
# Gradio interface
demo = gr.Interface(
fn=predict_from_filepath,
inputs=gr.Audio(type="filepath", label="Upload Quran recitation (.wav)"),
outputs=gr.JSON(label="Prediction Output"),
title="Tajweed Prediction API",
    description="Upload a .wav recitation to get three Tajweed scores in [0, 1] plus a transcription."
)
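# launch() serves the app locally; share=True would additionally create a
# temporary public Gradio link.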
if __name__ == "__main__":
demo.launch()