"""SEMamba speech-enhancement demo (Gradio app).

At import time this script installs a prebuilt ``mamba_ssm`` wheel, builds the
SEMamba generator from its YAML recipe, and launches a Gradio UI. The UI
denoises an uploaded or recorded clip with a user-selected checkpoint and shows
noisy vs. enhanced spectrograms side by side.
"""
import shlex  # noqa: F401  (no longer used below; kept so nothing importing it breaks)
import subprocess
import sys

import spaces
import torch
import gradio as gr

# Prebuilt mamba_ssm wheel (CUDA 12.2 / torch 2.3 / CPython 3.10 build, per its name).
MAMBA_WHEEL_URL = (
    "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
    "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
)


# install packages for mamba
def install_mamba():
    """Install the prebuilt ``mamba_ssm`` wheel into the running interpreter.

    Raises:
        subprocess.CalledProcessError: if pip exits non-zero, so a failed
            install is reported here instead of as a confusing import error
            further down the script.
    """
    # ``sys.executable -m pip`` guarantees the wheel lands in the same
    # environment this process imports from (a bare ``pip`` on PATH may not).
    subprocess.run(
        [sys.executable, "-m", "pip", "install", MAMBA_WHEEL_URL],
        check=True,
    )


install_mamba()

ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
"""

# These imports stay below install_mamba(): the local ``models`` package
# presumably depends on mamba_ssm being installed first.
import yaml
import librosa
import librosa.display
import matplotlib
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs

# Checkpoints selectable in the UI, keyed by the label shown to the user.
CKPT_PATHS = {
    "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
    "VCTK+DNS": "ckpts/vd.pth",
}
ckpt = CKPT_PATHS["VCTK-Demand"]  # default checkpoint path
cfg_f = "recipes/SEMamba_advanced.yaml"

SAMPLE_RATE = 16000  # the models are trained on 16 kHz audio (see the UI note)
SEGMENT_LEN = 4 * SAMPLE_RATE  # enhance in 4 s (64000-sample) segments
# Positional arguments shared by mag_phase_stft / mag_phase_istft; presumably
# (n_fft, hop_size, win_size, compress_factor) -- confirm in models/stfts.py.
STFT_ARGS = (400, 100, 400, 0.3)
SPEC_N_FFT = 512  # STFT size for the displayed spectrograms only
SPEC_HOP = 256
# NOTE(review): one fixed output path shared by every request; this is only
# safe while the Gradio queue processes a single job at a time.
OUTPUT_WAV = "enhanced.wav"

# load config
with open(cfg_f, "r") as f:
    cfg = yaml.safe_load(f)

# Hard-coded to CUDA: the installed mamba_ssm wheel is a CUDA build.
device = "cuda"
model = SEMamba(cfg).to(device)


def _plot_spectrogram(ax, signal, sr, title):
    """Draw the dB-scaled STFT magnitude of *signal* (0 dB = its peak) onto *ax*."""
    spec_db = librosa.amplitude_to_db(
        np.abs(librosa.stft(signal, n_fft=SPEC_N_FFT, hop_length=SPEC_HOP)),
        ref=np.max,
    )
    librosa.display.specshow(
        spec_db,
        sr=sr,
        hop_length=SPEC_HOP,
        x_axis="time",
        y_axis="hz",
        ax=ax,
        vmax=0,
    )
    ax.set_title(title)


@spaces.GPU
def enhance(filepath, model_name):
    """Denoise one audio file and plot its noisy vs. enhanced spectrograms.

    The clip is resampled to 16 kHz and scaled to unit RMS. It is then enhanced
    in 4-second segments, rescaled, resampled back to its original rate,
    peak-normalised and written to ``OUTPUT_WAV``.

    Args:
        filepath: Path of the uploaded/recorded clip (Gradio ``type="filepath"``),
            or ``None`` if the user clicked Enhance without providing audio.
        model_name: Key of ``CKPT_PATHS`` selecting which checkpoint to use.

    Returns:
        ``(output_wav_path, figure)`` for the Audio and Plot output components.

    Raises:
        gr.Error: If no audio was provided, or the clip is empty or all zeros.
        KeyError: If *model_name* is not a key of ``CKPT_PATHS``.
    """
    if filepath is None:
        raise gr.Error("Please upload or record an audio clip before clicking Enhance.")

    # Load model based on selection.
    ckpt_path = CKPT_PATHS[model_name]
    print("Loading:", ckpt_path)
    # NOTE(review): the checkpoint is re-read from disk on every request.
    model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
    model.eval()

    with torch.no_grad():
        wav, orig_sr = librosa.load(filepath, sr=None)  # keep the file's own rate
        noisy_wav = wav.copy()  # original-rate copy for the "noisy" plot
        if orig_sr != SAMPLE_RATE:
            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=SAMPLE_RATE)

        x = torch.from_numpy(wav).float().to(device)
        energy = torch.sum(x ** 2)
        # Zero energy (empty or all-zero clip) would make the normalisation
        # below divide by zero and turn the whole output into NaN.
        if not (energy.item() > 0):
            raise gr.Error("The input audio is empty or completely silent.")
        # Scale to unit RMS; undone on each enhanced segment below.
        norm = torch.sqrt(len(x) / energy)
        x = x * norm

        enhanced_chunks = []
        for chunk in x.split(SEGMENT_LEN):
            if len(chunk) < SEGMENT_LEN:
                # Pad the final partial segment with very low-level noise rather
                # than zeros (presumably to avoid exact digital silence).
                pad = torch.randn(SEGMENT_LEN - len(chunk), device=chunk.device) * 1e-4
                chunk = torch.cat([chunk, pad])
            amp, pha, _ = mag_phase_stft(chunk.unsqueeze(0), *STFT_ARGS)
            amp2, pha2, _ = model(amp, pha)
            seg = mag_phase_istft(amp2, pha2, *STFT_ARGS)
            enhanced_chunks.append((seg / norm).squeeze(0))
        # Concatenate and trim away the padding added to the last segment.
        out = torch.cat(enhanced_chunks)[: len(x)].cpu().numpy()

        # Back to the original rate.
        if orig_sr != SAMPLE_RATE:
            out = librosa.resample(out, orig_sr=SAMPLE_RATE, target_sr=orig_sr)

        # Rescale to a 0.85 peak unless the result is very quiet (peak <= 0.05).
        peak = np.max(np.abs(out))
        if peak > 0.05:
            out = out / peak * 0.85

        sf.write(OUTPUT_WAV, out, orig_sr)

    # Build the figure without pyplot: plt.subplots would register a new global
    # figure on every request that is never closed.
    fig = Figure(figsize=(16, 4))
    axs = fig.subplots(1, 2)
    _plot_spectrogram(axs[0], noisy_wav, orig_sr, "Noisy Spectrogram")
    _plot_spectrogram(axs[1], out, orig_sr, "Enhanced Spectrogram")
    fig.tight_layout()
    return OUTPUT_WAV, fig


with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
    # NOTE(review): the label recommends VCTK+DNS, but the default is VCTK-Demand.
    model_choice = gr.Radio(
        label="Choose Model (The use of VCTK+DNS is recommended)",
        choices=list(CKPT_PATHS),  # same names enhance() looks up
        value="VCTK-Demand",
    )
    enhance_btn = gr.Button("Enhance")
    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
    plot_output = gr.Plot(label="Spectrograms")

    enhance_btn.click(
        fn=enhance,
        inputs=[input_audio, model_choice],
        outputs=[output_audio, plot_output],
    )
    gr.Markdown("**Note**: The current models are trained on 16kHz audio. Therefore, any input audio not sampled at 16kHz will be automatically resampled before enhancement.")

demo.queue().launch()