Spaces:
Running
on
Zero
Running
on
Zero
import shlex | |
import subprocess | |
import spaces | |
import torch | |
import gradio as gr | |
# install packages for mamba | |
def install_mamba():
    """Install the prebuilt ``mamba_ssm`` 2.2.2 wheel at startup.

    The wheel filename targets CUDA 12.2 / torch 2.3 / CPython 3.10 on
    linux x86_64. Installation is best-effort: a failure is reported on
    stdout but does not raise, so the caller keeps running either way.

    Returns:
        None
    """
    import sys  # local import: only this function needs the interpreter path

    wheel_url = (
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
        "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
    )
    # Invoke pip through the running interpreter so the wheel is installed
    # into the same environment that will import it; a bare `pip` on PATH
    # may belong to a different Python.
    result = subprocess.run([sys.executable, "-m", "pip", "install", wheel_url])
    if result.returncode != 0:
        # Surface the failure instead of ignoring it silently.
        print(
            "WARNING: installing mamba_ssm failed "
            f"(pip exit code {result.returncode})"
        )
    # NOTE: pinned installs of torch/torchvision/torchaudio (cu118) and
    # numpy==1.26.4 were previously tried here and are intentionally disabled.
install_mamba() | |
# Markdown header passed to gr.Markdown() at the top of the demo UI.
ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
"""
import torch | |
import yaml | |
import librosa | |
import librosa.display | |
import matplotlib | |
import numpy as np | |
import soundfile as sf | |
import matplotlib.pyplot as plt | |
from models.stfts import mag_phase_stft, mag_phase_istft | |
from models.generator import SEMamba | |
from models.pcs400 import cal_pcs | |
# Default checkpoint and the recipe (YAML) describing the network.
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# Parse the recipe into the config object handed to the model constructor.
with open(cfg_f) as cfg_file:
    cfg = yaml.safe_load(cfg_file)

# The model is always placed on CUDA; there is no CPU fallback here.
device = "cuda"

# Only the architecture is built at import time. No weights are loaded and
# eval mode is not set here: enhance() loads the selected checkpoint and
# switches to eval mode on each request.
model = SEMamba(cfg).to(device)
@spaces.GPU  # ZeroGPU only provides CUDA inside functions decorated this way
def enhance(filepath, model_name):
    """Denoise an audio file with the selected SEMamba checkpoint.

    The clip is loaded at its native sample rate, resampled to 16 kHz for the
    model, RMS-normalised, enhanced in 4-second segments, rescaled, resampled
    back to the native rate, peak-normalised and written to a WAV file.

    Side effects: loads the chosen checkpoint into the module-level ``model``
    on every call and prints the checkpoint path.

    Args:
        filepath: Path of the noisy input audio, or ``None``/empty when the
            user has not provided a clip.
        model_name: ``"VCTK-Demand"`` or ``"VCTK+DNS"``; selects the checkpoint.

    Returns:
        tuple: ``(wav_path, fig)`` where ``wav_path`` is the path of the
        enhanced WAV file (written at the input's original sample rate) and
        ``fig`` is a Matplotlib figure showing the noisy and enhanced
        spectrograms side by side.

    Raises:
        gr.Error: If no audio was provided or the clip contains no samples.
        KeyError: If ``model_name`` is not one of the known models.
    """
    import tempfile  # local import: used only for the per-request output file

    # Fail early with a readable message instead of an obscure loader error.
    if not filepath:
        raise gr.Error("Please upload or record an audio clip before clicking Enhance.")

    # Load model based on selection
    ckpt_path = {
        "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
        "VCTK+DNS": "ckpts/vd.pth",
    }[model_name]
    print("Loading:", ckpt_path)
    model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
    model.eval()

    model_sr = 16000
    segment_len = 4 * model_sr  # 4 s segments (64000 samples)
    # Same positional settings for analysis and synthesis; presumably
    # (n_fft, hop, win, compression factor) -- confirm against models.stfts.
    stft_args = (400, 100, 400, 0.3)

    # load & resample
    wav, orig_sr = librosa.load(filepath, sr=None)
    if wav.size == 0:
        raise gr.Error("The provided audio clip is empty.")
    noisy_wav = wav.copy()  # untouched copy, used for the "noisy" spectrogram
    if orig_sr != model_sr:
        wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=model_sr)

    with torch.no_grad():
        x = torch.from_numpy(wav).float().to(device)

        # Scale to unit RMS. Pure digital silence has zero energy, which
        # would make the scale infinite and turn the signal into NaNs, so
        # fall back to a neutral scale of 1 in that case.
        energy = torch.sum(x ** 2)
        if energy > 0:
            norm = torch.sqrt(len(x) / energy)
        else:
            norm = torch.ones((), device=x.device)
        x = x * norm

        enhanced_chunks = []
        for chunk in x.split(segment_len):
            if len(chunk) < segment_len:
                # Pad the last segment with very low-level noise rather than
                # zeros (presumably to avoid feeding exact silence to the model).
                pad = torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4
                chunk = torch.cat([chunk, pad])
            chunk = chunk.unsqueeze(0)
            amp, pha, _ = mag_phase_stft(chunk, *stft_args)
            amp2, pha2, _ = model(amp, pha)
            out = mag_phase_istft(amp2, pha2, *stft_args)
            # Undo the RMS scaling applied before enhancement.
            enhanced_chunks.append((out / norm).squeeze(0))

        out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding

    # back to original rate
    if orig_sr != model_sr:
        out = librosa.resample(out, orig_sr=model_sr, target_sr=orig_sr)

    # Normalize: only rescale when the signal is not already very quiet.
    peak = np.max(np.abs(out))
    if peak > 0.05:
        out = out / peak * 0.85

    # Write to a unique temp file so concurrent requests never overwrite
    # each other's result (a fixed "enhanced.wav" is shared by all users).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    sf.write(out_path, out, orig_sr)

    # spectrograms
    fig, axs = plt.subplots(1, 2, figsize=(16, 4))

    # noisy
    D_noisy = librosa.stft(noisy_wav, n_fft=512, hop_length=256)
    S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
    librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
    axs[0].set_title("Noisy Spectrogram")

    # enhanced
    D_clean = librosa.stft(out, n_fft=512, hop_length=256)
    S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
    librosa.display.specshow(S_clean, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
    axs[1].set_title("Enhanced Spectrogram")

    # Lay out this specific figure rather than pyplot's "current" one.
    fig.tight_layout()
    return out_path, fig
#with gr.Blocks() as demo: | |
# gr.Markdown(ABOUT) | |
# input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True) | |
# enhance_btn = gr.Button("Enhance") | |
# output_audio = gr.Audio(label="Enhanced Audio", type="filepath") | |
# plot_output = gr.Plot(label="Spectrograms") | |
# | |
# enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output]) | |
# | |
#demo.queue().launch() | |
# Gradio front end: audio input + model selector -> enhanced audio + plot.
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)

    # Inputs
    input_audio = gr.Audio(type="filepath", label="Input Audio", interactive=True)
    model_choice = gr.Radio(
        choices=["VCTK-Demand", "VCTK+DNS"],
        value="VCTK-Demand",
        label="Choose Model (The use of VCTK+DNS is recommended)",
    )
    enhance_btn = gr.Button("Enhance")

    # Outputs
    output_audio = gr.Audio(type="filepath", label="Enhanced Audio")
    plot_output = gr.Plot(label="Spectrograms")

    # Wire the button to the enhancement routine.
    enhance_btn.click(
        enhance,
        inputs=[input_audio, model_choice],
        outputs=[output_audio, plot_output],
    )

    gr.Markdown("**Note**: The current models are trained on 16kHz audio. Therefore, any input audio not sampled at 16kHz will be automatically resampled before enhancement.")

# Enable request queuing, then start the server.
demo.queue().launch()