File size: 3,088 Bytes
6f5f35c
70722dd
6f5f35c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70722dd
6f5f35c
 
 
 
 
 
23a6dd1
c9914a1
6f5f35c
 
 
 
70722dd
6f5f35c
 
 
 
 
 
 
9a80b9d
 
 
6f5f35c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129b9a6
6f5f35c
 
 
 
 
 
 
 
 
576f70a
6f5f35c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74d019c
6f5f35c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import spaces
import sys
from pathlib import Path

import gradio as gr
import torch
import torchaudio

# Add src to path to import sfi_utmos.
# Makes the in-repo `src/` directory importable without installing the
# package; this must execute before the sfi_utmos import below.
project_root = Path(__file__).resolve().parent
sys.path.insert(0, str(project_root / "src"))

from sfi_utmos.model.ssl_mos import SSLMOSLightningModule

# Global variable for the model.
# Set by load_model(); predict_mos() returns an error string while it is None.
model: SSLMOSLightningModule | None = None
# Device passed as `map_location` when load_model() restores the checkpoint.
# NOTE(review): hard-coded to CUDA, presumably relying on the Spaces GPU
# runtime — confirm behavior on CPU-only hosts.
device: str = "cuda"


def load_model(checkpoint_path: str):
    """Restore the MOS predictor from a checkpoint into the module-level ``model``.

    The checkpoint is mapped onto ``device`` and ``pretrained_model_path=None``
    is forwarded to ``load_from_checkpoint``. Any loading error propagates to
    the caller. On success the model is switched to eval mode and a
    confirmation line is printed.

    Args:
        checkpoint_path: Local path or URL of the ``.ckpt`` file.
    """
    global model
    load_kwargs = {"map_location": device, "pretrained_model_path": None}
    model = SSLMOSLightningModule.load_from_checkpoint(checkpoint_path, **load_kwargs)
    model.eval()
    print("Model loaded from {}".format(checkpoint_path))

@spaces.GPU
def predict_mos(audio_path: str):
    """Predict the MOS of an audio file, averaged over listener IDs 1 through 10.

    Args:
        audio_path: Path to the audio file, as supplied by
            ``gr.Audio(type="filepath")``. May be ``None`` when the form is
            submitted without a file.

    Returns:
        The predicted MOS formatted with three decimals. Returns an
        ``"Error: ..."`` string when the model is not loaded, no file was
        given, or the file's sample rate is not supported by an
        SR-conditioned model.
    """
    if model is None:
        return "Error: Model not loaded. Please provide a valid checkpoint path."
    if audio_path is None:
        return "Error: No audio file provided."

    # Everything except the listener ID is independent of the loop below,
    # so the audio is loaded and preprocessed exactly once.
    wav, sr = torchaudio.load(audio_path)
    if model.condition_sr and sr not in model.sr2id.keys():
        return (
            f"Error: Sample rate {sr} not supported by the model. "
            f"Supported rates: {list(model.sr2id.keys())}"
        )

    # torchaudio.load returns a (channels, frames) tensor. Average the
    # channels into one mono signal; flattening with view(-1) would instead
    # concatenate the channels end to end. Mono input is unaffected.
    mono = wav.mean(dim=0).to(model.device)

    srs = torch.tensor(sr).view(1, -1).to(model.device)
    if model.condition_sr:
        # Replace the raw sample rate with the model's sample-rate ID.
        srs = torch.stack(
            [torch.tensor(model.sr2id[rate.detach().cpu().item()]) for rate in srs]
        ).to(model.device)

    if getattr(model, "is_sfi", False):
        # NOTE(review): when condition_sr is also set, srs holds the
        # sample-rate ID rather than the rate itself — confirm which value
        # set_sample_rate expects.
        model.ssl_model.set_sample_rate(srs[0].item())
        waves = torch.nn.utils.rnn.pad_sequence(
            [mono], batch_first=True
        ).to(model.device)
    else:
        # Non-SFI models receive audio resampled to 16 kHz.
        waves = [torchaudio.functional.resample(mono, sr, 16_000)]

    ratings = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for listener in range(1, 11):
            listener_tensor = torch.tensor(listener).view(-1).to(model.device)
            output = model.forward(waves, listener_tensor, srs)
            ratings.append(output.cpu().item())

    # NOTE(review): presumably maps the model's output range onto the 1-5 MOS
    # scale (e.g. [-1, 1] -> [1, 5]) — confirm against the training code.
    mos_score = 2 * (sum(ratings) / len(ratings)) + 3
    return f"{mos_score:.3f}"


def main():
    """Parse CLI arguments, load the model, and launch the Gradio demo.

    Exits with status 1 if the checkpoint cannot be loaded.
    """
    parser = argparse.ArgumentParser(description="Run MOS prediction demo with Gradio.")
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="https://huggingface.co/sarulab-speech/MSR-UTMOS_w2v2_fold0/resolve/main/model.ckpt",
        help="Path to the model checkpoint (.ckpt file).",
    )
    args = parser.parse_args()

    # load_model raises on failure rather than leaving `model` as None, so the
    # failure has to be caught here to report it cleanly instead of crashing
    # with a raw traceback.
    try:
        load_model(args.checkpoint_path)
    except Exception as err:  # top-level boundary: report and exit
        print(
            f"Failed to load model from {args.checkpoint_path}: {err}. Exiting.",
            file=sys.stderr,
        )
        sys.exit(1)

    # Defensive: load_model should have raised instead of leaving this unset.
    if model is None:
        print("Failed to load model. Exiting.")
        sys.exit(1)

    # Gradio interface
    iface = gr.Interface(
        fn=predict_mos,
        inputs=gr.Audio(type="filepath", label="Upload Audio File"),
        outputs="text",
        title="MSR-UTMOS: MOS Prediction Demo",
        description=(
            "Upload an audio file (WAV, MP3, etc.) to get its predicted Mean Opinion Score (MOS). "
        ),
    )
    iface.launch()


if __name__ == "__main__":
    main()