Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,088 Bytes
6f5f35c 70722dd 6f5f35c 70722dd 6f5f35c 23a6dd1 c9914a1 6f5f35c 70722dd 6f5f35c 9a80b9d 6f5f35c 129b9a6 6f5f35c 576f70a 6f5f35c 74d019c 6f5f35c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import argparse
import spaces
import sys
from pathlib import Path
import gradio as gr
import torch
import torchaudio
# Add src to path to import sfi_utmos
project_root = Path(__file__).resolve().parent
sys.path.insert(0, str(project_root / "src"))
from sfi_utmos.model.ssl_mos import SSLMOSLightningModule
# Global variable for the model.
# Populated by load_model() at startup; read by predict_mos() on each request.
model: SSLMOSLightningModule | None = None
# Target device for checkpoint loading and inference.
# NOTE(review): assumes a CUDA device is present (ZeroGPU Space) — confirm for CPU-only runs.
device = "cuda"
def load_model(checkpoint_path: str):
    """Load the MOS model from *checkpoint_path* into the module-level global.

    Args:
        checkpoint_path: Local path or URL of the Lightning ``.ckpt`` file.

    Side effects:
        Rebinds the global ``model`` (set to eval mode) and prints a status line.
    """
    global model
    loaded = SSLMOSLightningModule.load_from_checkpoint(
        checkpoint_path,
        map_location=device,
        pretrained_model_path=None,
    )
    loaded.eval()
    model = loaded
    print(f"Model loaded from {checkpoint_path}")
@spaces.GPU
def predict_mos(audio_path: str):
    """Predicts the MOS score for the given audio file.

    Runs the model once per simulated listener ID (1..10), averages the raw
    ratings, and maps the mean onto the 1-5 MOS scale via ``2 * mean + 3``.

    Args:
        audio_path: Filesystem path to the uploaded audio file.

    Returns:
        The predicted MOS formatted to three decimals, or an ``"Error: ..."``
        string when the model is missing or the sample rate is unsupported.
    """
    if model is None:
        return "Error: Model not loaded. Please provide a valid checkpoint path."

    # Hoisted out of the listener loop: the waveform is identical for every
    # simulated listener, so load it from disk exactly once.
    wav, sr = torchaudio.load(audio_path)

    # Validate the sample rate once up front instead of on every iteration.
    if model.condition_sr and sr not in model.sr2id.keys():
        return f"Error: Sample rate {sr} not supported by the model. Supported rates: {list(model.sr2id.keys())}"

    ratings = []
    # Inference only: disable autograd to avoid building gradient graphs.
    with torch.no_grad():
        for listener in range(1, 11):
            waves = [wav.view(-1).to(model.device)]
            srs = torch.tensor(sr).view(1, -1).to(model.device)
            if model.condition_sr:
                # Map raw sample-rate values to the model's SR-conditioning ids.
                srs = torch.stack(
                    [torch.tensor(model.sr2id[s.detach().cpu().item()]) for s in srs]
                ).to(model.device)
            listener_tensor = torch.tensor(listener).view(-1).to(model.device)
            if hasattr(model, "is_sfi") and model.is_sfi:
                # Sampling-frequency-independent backbone: keep native rate,
                # just tell the SSL model which rate it is receiving.
                model.ssl_model.set_sample_rate(srs[0].item())
                waves = torch.nn.utils.rnn.pad_sequence(
                    [w.view(-1) for w in waves], batch_first=True
                ).to(device)
            else:
                # Otherwise the SSL backbone expects 16 kHz input.
                waves = [torchaudio.functional.resample(w, sr, 16_000) for w in waves]
            output = model.forward(
                waves,
                listener_tensor,
                srs,
            )
            ratings.append(output.cpu().item())

    mos_score = 2 * (sum(ratings) / len(ratings)) + 3
    return f"{mos_score:.3f}"
def main():
    """CLI entry point: load the checkpoint, then serve the Gradio demo.

    Exits with status 1 if the model fails to load.
    """
    parser = argparse.ArgumentParser(description="Run MOS prediction demo with Gradio.")
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="https://huggingface.co/sarulab-speech/MSR-UTMOS_w2v2_fold0/resolve/main/model.ckpt",
        help="Path to the model checkpoint (.ckpt file).",
    )
    args = parser.parse_args()

    load_model(args.checkpoint_path)
    if model is None:
        print("Failed to load model. Exiting.")
        sys.exit(1)

    # Build and launch the web UI: a single audio-in, text-out interface.
    demo = gr.Interface(
        fn=predict_mos,
        inputs=gr.Audio(type="filepath", label="Upload Audio File"),
        outputs="text",
        title="MSR-UTMOS: MOS Prediction Demo",
        description=(
            "Upload an audio file (WAV, MP3, etc.) to get its predicted Mean Opinion Score (MOS). "
        ),
    )
    demo.launch()
if __name__ == "__main__":
main()
|