import torch
import torch.nn as nn
import torchaudio
import numpy as np
import soundfile as sf
from transformers import WhisperProcessor, WhisperModel, WhisperForConditionalGeneration
import gradio as gr
# Load the full seq2seq Whisper model (fine-tuned on Quranic recitation) for transcription
whisper_model_gen = WhisperForConditionalGeneration.from_pretrained(
    "tarteel-ai/whisper-base-ar-quran"
)
whisper_model_gen.eval()
# Model definition: Whisper encoder with a linear multi-label Tajweed head
class TajweedClassifier(nn.Module):
    def __init__(self, base_model_name="tarteel-ai/whisper-base-ar-quran", tajweed_dim=3):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(base_model_name).encoder
        hidden_size = self.encoder.config.d_model
        self.tajweed_head = nn.Linear(hidden_size, tajweed_dim)

    def forward(self, input_features):
        encoder_out = self.encoder(input_features).last_hidden_state
        pooled = encoder_out.mean(dim=1)
        tajweed_logits = self.tajweed_head(pooled)
        return torch.sigmoid(tajweed_logits)
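
# Note: only Whisper's encoder is reused here; mean-pooling over time yields a
# single utterance-level vector, and the sigmoid head produces three independent
# scores in [0, 1] (multi-label detection, not a softmax over exclusive classes).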
# Load model & processor
processor = WhisperProcessor.from_pretrained("tarteel-ai/whisper-base-ar-quran")
model = TajweedClassifier()
checkpoint = torch.load("stat_tajweed_model_v3.pth", map_location="cpu")
# Strip the "module." prefix that nn.DataParallel adds to state-dict keys
new_state_dict = {}
for k, v in checkpoint.items():
    new_key = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[new_key] = v
model.load_state_dict(new_state_dict)
model.eval()
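
# Optional smoke test (a minimal sketch; uncomment to run). Whisper-base takes
# log-mel features of shape (batch, 80 mel bins, 3000 frames) for a 30 s window:
# with torch.no_grad():
#     assert model(torch.randn(1, 80, 3000)).shape == (1, 3)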
# Prediction function
def predict_from_filepath(audio_path):
    if audio_path is None or not isinstance(audio_path, str):
        return {"error": "Invalid audio input."}
    try:
        waveform, sr = sf.read(audio_path)
    except Exception as e:
        return {"error": f"Could not read audio file: {str(e)}"}

    # Convert stereo to mono by averaging channels
    if len(waveform.shape) > 1:
        waveform = waveform.mean(axis=1)
    waveform = waveform.astype("float32")

    # Normalize only if samples fall outside [-1, 1]
    max_val = np.max(np.abs(waveform))
    if max_val > 1.0:
        waveform /= max_val

    # Resample: Whisper expects a 16 kHz sampling rate
    if sr != 16000:
        waveform = torchaudio.functional.resample(
            torch.from_numpy(waveform), orig_freq=sr, new_freq=16000
        ).numpy()

    # Prepare log-mel input features
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features

    with torch.no_grad():
        output = model(input_features)

    # Tajweed predictions: three independent scores in [0, 1]
    labels = [
        "Separated elongation (المد المنفصل): ",
        "Doubled noon (النون المشددة): ",
        "Concealment (الإخفاء): ",
    ]
    values = output.squeeze().tolist()
    tajweed_dict = {label: value for label, value in zip(labels, values)}

    # Transcription using the full seq2seq Whisper model
    transcription_ids = whisper_model_gen.generate(input_features)
    transcription = processor.batch_decode(
        transcription_ids, skip_special_tokens=True
    )[0]

    return {
        "tajweed": tajweed_dict,
        "transcription": transcription,
    }
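
# Example usage (hypothetical local file path):
# result = predict_from_filepath("recitation.wav")
# print(result["tajweed"], result["transcription"])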
# Gradio interface
demo = gr.Interface(
    fn=predict_from_filepath,
    inputs=gr.Audio(type="filepath", label="Upload Quran recitation (.wav)"),
    outputs=gr.JSON(label="Prediction Output"),
    title="Tajweed Prediction API",
    description="Upload a .wav file to get three Tajweed scores (floats between 0 and 1) and a transcription.",
)
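
# The Space can also be called programmatically via gradio_client (a sketch;
# "user/space-name" is a placeholder for the actual Space id):
# from gradio_client import Client, handle_file
# client = Client("user/space-name")
# result = client.predict(handle_file("recitation.wav"), api_name="/predict")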
if __name__ == "__main__":
    demo.launch()