import torch
import torch.nn as nn
import torchaudio
import numpy as np
import soundfile as sf
from transformers import WhisperProcessor, WhisperModel, WhisperForConditionalGeneration
import gradio as gr

# Load the full Whisper model (encoder + decoder) for transcription
whisper_model_gen = WhisperForConditionalGeneration.from_pretrained(
    "tarteel-ai/whisper-base-ar-quran"
)
whisper_model_gen.eval()

# Model definition: Whisper encoder with a linear tajweed classification head
class TajweedClassifier(nn.Module):
    def __init__(self, base_model_name="tarteel-ai/whisper-base-ar-quran", tajweed_dim=3):
        super().__init__()
        
        self.encoder = WhisperModel.from_pretrained(base_model_name).encoder
        hidden_size = self.encoder.config.d_model
        self.tajweed_head = nn.Linear(hidden_size, tajweed_dim)

    def forward(self, input_features):
        # Encode the log-mel features, mean-pool over the time axis, and
        # project the pooled vector to one sigmoid score per tajweed rule
        encoder_out = self.encoder(input_features).last_hidden_state
        pooled = encoder_out.mean(dim=1)
        tajweed_logits = self.tajweed_head(pooled)
        return torch.sigmoid(tajweed_logits)

# Load model & processor
processor = WhisperProcessor.from_pretrained("tarteel-ai/whisper-base-ar-quran")
model = TajweedClassifier()

checkpoint = torch.load("stat_tajweed_model_v3.pth", map_location="cpu")

# Strip the "module." prefix left behind when the checkpoint was saved from a DataParallel-wrapped model
new_state_dict = {}

for k, v in checkpoint.items():
    new_key = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[new_key] = v

model.load_state_dict(new_state_dict)
model.eval()
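
# Optional sanity check (a sketch; uncomment to run): one second of silence
# should yield a (1, 3) tensor of scores in (0, 1).
# with torch.no_grad():
#     _silence = processor(
#         np.zeros(16000, dtype="float32"), sampling_rate=16000, return_tensors="pt"
#     ).input_features
#     assert model(_silence).shape == (1, 3)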


# Prediction function
def predict_from_filepath(audio_path):
    if audio_path is None or not isinstance(audio_path, str):
        return {"error": "Invalid audio input."}

    try:
        waveform, sr = sf.read(audio_path)
    except Exception as e:
        return {"error": f"Could not read audio file: {str(e)}"}

    # Handle stereo input by converting to mono
    if len(waveform.shape) > 1:
        waveform = waveform.mean(axis=1)

    waveform = waveform.astype("float32")

    # Normalize only if the samples exceed the [-1, 1] full-scale range
    max_val = np.max(np.abs(waveform))
    if max_val > 1.0:
        waveform /= max_val

    # Resample if needed (Whisper expects 16 kHz audio)
    if sr != 16000:
        waveform = torchaudio.functional.resample(
            torch.from_numpy(waveform), orig_freq=sr, new_freq=16000
        ).numpy()

    # Prepare input
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features

    with torch.no_grad():
        output = model(input_features)

        # Tajweed predictions: one score per rule
        labels = [
            "Separated madd (المد المنفصل)",
            "Doubled noon (النون المشددة)",
            "Concealment (الإخفاء)",
        ]
        values = output.squeeze().tolist()
        tajweed_dict = dict(zip(labels, values))

        # Transcription using Whisper model
        transcription_ids = whisper_model_gen.generate(inputs.input_features)
        transcription = processor.batch_decode(
            transcription_ids, skip_special_tokens=True
        )[0]
        
    return {
        "tajweed": tajweed_dict,
        "transcription": transcription
    }
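
# The function can also be called directly, without the UI; "recitation.wav"
# below is a placeholder path:
#   predict_from_filepath("recitation.wav")
#   # -> {"tajweed": {<rule>: <score>, ...}, "transcription": "..."}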


# Gradio interface
demo = gr.Interface(
    fn=predict_from_filepath,
    inputs=gr.Audio(type="filepath", label="Upload Quran recitation (.wav)"),
    outputs=gr.JSON(label="Prediction Output"),
    title="Tajweed Prediction API",
    description="Upload a .wav file, get 3 raw float values between 0 and 1."
)

if __name__ == "__main__":
    demo.launch()
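
# Programmatic access once the app is running (a sketch assuming the default
# local URL; "recitation.wav" is a placeholder). Recent gradio_client releases
# wrap file inputs with handle_file:
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict(handle_file("recitation.wav"), api_name="/predict"))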