"""Gradio demo for SEMamba speech enhancement.

Downloads the SEMamba checkpoint and recipe from the Hugging Face Hub, builds
the generator, and serves an upload -> enhance -> playback interface.
"""
import gradio as gr
import librosa
import numpy as np
import torch
import yaml
from huggingface_hub import hf_hub_download

from models.generator import SEMamba
from models.pcs400 import cal_pcs
from models.stfts import mag_phase_istft, mag_phase_stft

# Download model files from the HF repo.
ckpt = hf_hub_download("rc19477/Speech_Enhancement_Mamba", "ckpts/SEMamba_advanced.pth")
cfg_f = hf_hub_download("rc19477/Speech_Enhancement_Mamba", "recipes/SEMamba_advanced.yaml")

# Load config.
with open(cfg_f, encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

stft_cfg = cfg["stft_cfg"]
model_cfg = cfg["model_cfg"]
sr = stft_cfg["sampling_rate"]
n_fft = stft_cfg["n_fft"]
hop_size = stft_cfg["hop_size"]
win_size = stft_cfg["win_size"]
compress_ff = model_cfg["compress_factor"]

# Init model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()


def _to_mono_float(wav_np):
    """Return *wav_np* as a 1-D float32 array, scaling integer PCM to [-1, 1]."""
    wav = np.asarray(wav_np)
    if np.issubdtype(wav.dtype, np.integer):
        # gr.Audio(type="numpy") typically yields int16 PCM; librosa.resample
        # only accepts floating-point audio.
        scale = float(np.iinfo(wav.dtype).max)
        wav = wav.astype(np.float32) / scale
    else:
        wav = wav.astype(np.float32)
    if wav.ndim > 1:
        # Multi-channel input arrives as (samples, channels); downmix to mono
        # so resampling and the model operate along the time axis.
        wav = wav.mean(axis=1)
    return wav


def enhance(audio, do_pcs):
    """Enhance a noisy recording with SEMamba.

    Args:
        audio: ``(sample_rate, waveform)`` tuple from ``gr.Audio(type="numpy")``,
            or ``None`` when nothing was uploaded.
        do_pcs: If true, apply the ``cal_pcs`` post-filter to the enhanced signal.

    Returns:
        ``(sample_rate, waveform)`` at the input's original sample rate.

    Raises:
        gr.Error: If no audio was provided or the uploaded audio is empty.
    """
    if audio is None:
        raise gr.Error("Please upload an audio file.")
    orig_sr, wav_np = audio
    wav_np = _to_mono_float(wav_np)

    if wav_np.size == 0:
        raise gr.Error("The uploaded audio is empty.")
    if not np.any(wav_np):
        # All-zero input: the RMS normalisation below would divide by zero and
        # yield NaNs. Enhancing silence gives silence, so return it as-is.
        return orig_sr, wav_np

    # 1) Resample to the model's rate. librosa>=0.10 requires keyword rates.
    if orig_sr != sr:
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)

    # Inference only: no_grad avoids building an autograd graph, which would
    # both waste memory and make .numpy() fail on grad-requiring tensors.
    with torch.no_grad():
        wav = torch.from_numpy(wav_np).float().to(device)

        # Normalize to unit mean power; undone after synthesis.
        norm = torch.sqrt(len(wav) / torch.sum(wav ** 2))
        wav = (wav * norm).unsqueeze(0)

        # STFT -> model -> ISTFT
        amp, pha, _ = mag_phase_stft(wav, n_fft, hop_size, win_size, compress_ff)
        amp_g, pha_g = model(amp, pha)
        out = mag_phase_istft(amp_g, pha_g, n_fft, hop_size, win_size, compress_ff)
        out = (out / norm).squeeze().cpu().numpy()

    # Optional PCS filter.
    if do_pcs:
        out = cal_pcs(out)

    # 2) Resample back to the original rate.
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    return orig_sr, out


demo = gr.Interface(
    fn=enhance,
    inputs=[
        # `source=` was removed in Gradio 4 (now `sources=`); the default
        # allows uploads on both Gradio 3 and 4, so the argument is omitted.
        gr.Audio(type="numpy", label="Noisy wav"),
        gr.Checkbox(label="Apply PCS post-processing", value=False),
    ],
    outputs=gr.Audio(type="numpy", label="Enhanced wav"),
    title="SEMamba Speech Enhancement",
    description="Upload a noisy WAV; tick **Apply PCS** for the pcs400 filter.",
)

if __name__ == "__main__":
    demo.launch()