Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,386 Bytes
56efbc8 8bb81da 56efbc8 ab0cc1f c7fba2d db2559d 56efbc8 0ff1354 56efbc8 18c8531 776c247 0ff1354 56efbc8 3af0ebe 658a305 56efbc8 9d66cc0 56efbc8 3c23ad1 3af0ebe 3c23ad1 77a842a 484532b 3c23ad1 3af0ebe 56efbc8 9d66cc0 3af0ebe 2881f71 d49889f 7bf14b8 d49889f 2881f71 56efbc8 c7fba2d 56efbc8 3af0ebe 42fbee6 3af0ebe caa9ec7 3af0ebe caa9ec7 3af0ebe 411fe00 3af0ebe caa9ec7 3af0ebe 56efbc8 3af0ebe 56efbc8 3af0ebe ecbb90e 0ff1354 17efe4f 56efbc8 3af0ebe 0ff1354 eca5e51 0ff1354 56efbc8 3af0ebe 56efbc8 0ff1354 6e13067 56efbc8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import shlex
import subprocess
import spaces
import torch
import gradio as gr
# Prebuilt mamba_ssm 2.2.2 wheel; per its filename it targets CUDA 12.2,
# torch 2.3, the cxx11abi=FALSE ABI and CPython 3.10 on linux x86_64.
MAMBA_WHEEL_URL = (
    "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
    "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
)


def install_mamba(wheel_url=MAMBA_WHEEL_URL):
    """Install the mamba_ssm wheel into the interpreter running this script.

    Uses ``sys.executable -m pip`` rather than a bare ``pip`` so the package
    lands in the same environment that will later import it, even when the
    ``pip`` on PATH belongs to another Python.

    Args:
        wheel_url: URL (or local path) of the wheel to install.

    Returns:
        pip's exit status (0 on success). A failure is reported on stderr but
        not raised, matching the original best-effort behaviour.
    """
    import sys

    result = subprocess.run([sys.executable, "-m", "pip", "install", wheel_url])
    if result.returncode != 0:
        print(
            f"WARNING: pip install of {wheel_url} failed "
            f"(exit status {result.returncode})",
            file=sys.stderr,
        )
    return result.returncode


install_mamba()
ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
"""
import torch
import yaml
import librosa
import librosa.display
import matplotlib
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
# Recipe (architecture config) and default checkpoint locations, relative to
# the working directory.
cfg_f = "recipes/SEMamba_advanced.yaml"
ckpt = "ckpts/SEMamba_advanced.pth"  # NOTE(review): not read anywhere in this file

# Always CUDA: the torch.cuda.is_available() CPU fallback was commented out
# in the original, so this is kept hard-coded.
device = "cuda"

# Parse the YAML recipe describing the network.
with open(cfg_f, mode="r") as f:
    cfg = yaml.safe_load(f)

# Construct the network once at import time and move it to the device.
# No weights are loaded here; this block only builds the module.
model = SEMamba(cfg)
model = model.to(device)
# UI model name -> checkpoint file whose "generator" entry holds the weights.
_CKPT_PATHS = {
    "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
    "VCTK+DNS": "ckpts/vd.pth",
}

# Sample rate the model is run at; audio is processed in 4 s windows.
_MODEL_SR = 16000
_SEGMENT_LEN = 4 * _MODEL_SR  # 64000 samples


def _load_weights(model_name):
    """Load the checkpoint selected by *model_name* into the global ``model``."""
    ckpt_path = _CKPT_PATHS[model_name]  # KeyError for an unknown name
    print("Loading:", ckpt_path)
    model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
    model.eval()


def _denoise_16k(wav):
    """Enhance a 1-D float waveform sampled at 16 kHz and return a numpy array.

    The signal is scaled by ``sqrt(N / sum(x**2))`` (unit mean power), split
    into 4 s windows, and each window is passed through STFT -> model -> iSTFT.
    The scaling is then undone, and the result is trimmed back to the input
    length. Must be called under ``torch.no_grad()``.
    """
    x = torch.from_numpy(wav).float().to(device)
    energy = torch.sum(x ** 2)
    if energy > 0:
        norm = torch.sqrt(len(x) / energy)
    else:
        # All-zero input: sqrt(N / 0) is inf and 0 * inf is NaN, which would
        # poison the whole output. Leave the signal unscaled instead.
        norm = torch.tensor(1.0, device=x.device)
    x = x * norm

    enhanced_chunks = []
    for chunk in x.split(_SEGMENT_LEN):
        if len(chunk) < _SEGMENT_LEN:
            # Pad the final window with very low-level noise rather than zeros.
            pad = torch.randn(_SEGMENT_LEN - len(chunk), device=chunk.device) * 1e-4
            chunk = torch.cat([chunk, pad])
        amp, pha, _ = mag_phase_stft(chunk.unsqueeze(0), 400, 100, 400, 0.3)
        amp2, pha2, _ = model(amp, pha)
        out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
        enhanced_chunks.append((out / norm).squeeze(0))
    # Drop the padding added to the last window.
    return torch.cat(enhanced_chunks)[: len(x)].cpu().numpy()


def _plot_spectrograms(noisy_wav, enhanced_wav, sr):
    """Return a figure with side-by-side dB spectrograms (noisy | enhanced)."""
    # Build the Figure without pyplot. pyplot keeps every figure it creates
    # in a global registry until it is explicitly closed, so plt.subplots()
    # per request leaks one figure per call in a long-running server.
    fig = matplotlib.figure.Figure(figsize=(16, 4))
    axs = fig.subplots(1, 2)
    panels = (
        (noisy_wav, "Noisy Spectrogram"),
        (enhanced_wav, "Enhanced Spectrogram"),
    )
    for ax, (signal, title) in zip(axs, panels):
        D = librosa.stft(signal, n_fft=512, hop_length=256)
        S = librosa.amplitude_to_db(np.abs(D), ref=np.max)
        librosa.display.specshow(
            S, sr=sr, hop_length=256, x_axis="time", y_axis="hz", ax=ax, vmax=0
        )
        ax.set_title(title)
    fig.tight_layout()
    return fig


@spaces.GPU
def enhance(filepath, model_name):
    """Denoise the clip at *filepath* with the checkpoint named *model_name*.

    Args:
        filepath: path of the uploaded/recorded audio, or ``None`` when the
            user clicks Enhance without providing a clip.
        model_name: a key of ``_CKPT_PATHS`` ("VCTK-Demand" or "VCTK+DNS").

    Returns:
        ``("enhanced.wav", fig)``: the path of the enhanced audio, written at
        the input's original sample rate, and a matplotlib figure comparing
        the noisy and enhanced spectrograms.

    Raises:
        gr.Error: if no audio was provided or the clip contains no samples.
        KeyError: if *model_name* is not a known checkpoint.
    """
    if filepath is None:
        raise gr.Error("Please upload or record an audio clip first.")
    _load_weights(model_name)

    # Load at the native rate; the model itself runs at 16 kHz.
    wav, orig_sr = librosa.load(filepath, sr=None)
    if wav.size == 0:
        raise gr.Error("The input audio is empty.")
    noisy_wav = wav.copy()
    if orig_sr != _MODEL_SR:
        wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=_MODEL_SR)

    with torch.no_grad():
        out = _denoise_16k(wav)

    # Return to the caller's sample rate.
    if orig_sr != _MODEL_SR:
        out = librosa.resample(out, orig_sr=_MODEL_SR, target_sr=orig_sr)

    # Peak-normalise to 0.85, but leave near-silent output (peak <= 0.05) as is.
    peak = np.max(np.abs(out))
    if peak > 0.05:
        out = out / peak * 0.85

    sf.write("enhanced.wav", out, orig_sr)
    return "enhanced.wav", _plot_spectrograms(noisy_wav, out, orig_sr)
# Gradio front-end. Components render in creation order: intro text, input
# clip, checkpoint selector, button, enhanced audio, spectrogram plot, footnote.
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)

    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)

    model_options = ["VCTK-Demand", "VCTK+DNS"]
    model_choice = gr.Radio(
        choices=model_options,
        value=model_options[0],
        label="Choose Model (The use of VCTK+DNS is recommended)",
    )

    enhance_btn = gr.Button("Enhance")
    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
    plot_output = gr.Plot(label="Spectrograms")

    # enhance() receives the clip path and the selected checkpoint name and
    # returns (audio path, figure).
    enhance_btn.click(
        fn=enhance,
        inputs=[input_audio, model_choice],
        outputs=[output_audio, plot_output],
    )

    gr.Markdown(
        "**Note**: The current models are trained on 16kHz audio. "
        "Therefore, any input audio not sampled at 16kHz will be automatically "
        "resampled before enhancement."
    )

# Enable the request queue, then start the server.
demo.queue()
demo.launch()
|