import argparse
import sys
from pathlib import Path

import gradio as gr
import spaces
import torch
import torchaudio

# Add src to path to import sfi_utmos
project_root = Path(__file__).resolve().parent
sys.path.insert(0, str(project_root / "src"))

from sfi_utmos.model.ssl_mos import SSLMOSLightningModule

# Global variable for the model
model: SSLMOSLightningModule | None = None
device = "cuda"


def load_model(checkpoint_path: str):
    """Loads the model from the given checkpoint path."""
    global model
    model = SSLMOSLightningModule.load_from_checkpoint(
        checkpoint_path,
        map_location=device,
        pretrained_model_path=None,
    )
    model.eval()
    print(f"Model loaded from {checkpoint_path}")


@spaces.GPU
def predict_mos(audio_path: str):
    """Predicts the MOS score for the given audio file."""
    if model is None:
        return "Error: Model not loaded. Please provide a valid checkpoint path."

    # Load the audio once; the same waveform is scored for each simulated listener.
    wav, sr = torchaudio.load(audio_path)
    if model.condition_sr and sr not in model.sr2id.keys():
        return (
            f"Error: Sample rate {sr} not supported by the model. "
            f"Supported rates: {list(model.sr2id.keys())}"
        )

    ratings = []
    # Average predictions over listener IDs 1-10.
    for listener in range(1, 11):
        waves = [wav.view(-1).to(model.device)]
        srs = torch.tensor(sr).view(1, -1).to(model.device)
        if model.condition_sr:
            # Map raw sample rates to the model's sample-rate condition IDs.
            srs = torch.stack(
                [torch.tensor(model.sr2id[s.detach().cpu().item()]) for s in srs]
            ).to(model.device)
        listener_tensor = torch.tensor(listener).view(-1).to(model.device)

        if hasattr(model, "is_sfi") and model.is_sfi:
            # Sampling-frequency-independent model: keep the native rate.
            model.ssl_model.set_sample_rate(srs[0].item())
            waves = torch.nn.utils.rnn.pad_sequence(
                [w.view(-1) for w in waves], batch_first=True
            ).to(device)
        else:
            # Otherwise resample to the 16 kHz expected by the SSL backbone.
            waves = [torchaudio.functional.resample(w, sr, 16_000) for w in waves]

        output = model.forward(
            waves,
            listener_tensor,
            srs,
        )
        ratings.append(output.cpu().item())

    # Map the averaged model output from [-1, 1] to the 1-5 MOS scale.
    mos_score = 2 * (sum(ratings) / len(ratings)) + 3
    return f"{mos_score:.3f}"


def main():
    parser = argparse.ArgumentParser(description="Run MOS prediction demo with Gradio.")
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="https://huggingface.co/sarulab-speech/MSR-UTMOS_w2v2_fold0/resolve/main/model.ckpt",
        help="Path to the model checkpoint (.ckpt file).",
    )
    args = parser.parse_args()

    load_model(args.checkpoint_path)
    if model is None:
        print("Failed to load model. Exiting.")
        sys.exit(1)

    # Gradio interface
    iface = gr.Interface(
        fn=predict_mos,
        inputs=gr.Audio(type="filepath", label="Upload Audio File"),
        outputs="text",
        title="MSR-UTMOS: MOS Prediction Demo",
        description=(
            "Upload an audio file (WAV, MP3, etc.) to get its predicted Mean Opinion Score (MOS)."
        ),
    )
    iface.launch()


if __name__ == "__main__":
    main()
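# ---------------------------------------------------------------------------
# Example usage (a sketch; assumes this script is saved as app.py and that the
# default checkpoint URL above is reachable):
#
#   python app.py
#   # or with a local checkpoint:
#   python app.py --checkpoint_path /path/to/model.ckpt
#
# Gradio prints a local URL on launch; open it and upload a WAV/MP3 file to
# see the predicted MOS.
# ---------------------------------------------------------------------------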