roychao19477 commited on
Commit
56efbc8
Β·
1 Parent(s): 8bb81da
Files changed (1) hide show
  1. app.py +86 -15
app.py CHANGED
@@ -1,20 +1,91 @@
1
- import gradio as gr
 
2
  import spaces
3
- import numpy as np
4
- import soundfile as sf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  @spaces.GPU
7
- def dummy_enhance(audio_path):
8
- print("Audio received:", audio_path)
9
- # Return the same file as a dummy operation
10
- return audio_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  with gr.Blocks() as demo:
13
- gr.Markdown("# SEMamba: ZeroGPU Upload Test")
14
- audio_input = gr.Audio(type="filepath", label="Upload Audio", interactive=True)
15
- submit = gr.Button("Run Enhancement")
16
- audio_output = gr.Audio(type="filepath", label="Enhanced Output")
17
-
18
- submit.click(dummy_enhance, inputs=[audio_input], outputs=[audio_output])
19
-
20
- demo.queue().launch() # No ssr_mode=False
 
1
+ import shlex
2
+ import subprocess
3
  import spaces
4
+ import torch
5
+ import gradio as gr
6
+
7
+ # install packages for mamba
8
+ def install_mamba():
9
+ #subprocess.run(shlex.split("pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118"))
10
+ #subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
11
+ subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
12
+ #subprocess.run(shlex.split("pip install numpy==1.26.4"))
13
+
14
+ install_mamba()
15
+
16
+ ABOUT = """
17
+ # SEMamba: Speech Enhancement
18
+ A Mamba-based model that denoises real-world audio.
19
+ Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
20
+ """
21
+
22
+
23
+ import torch
24
+ import yaml
25
+ import librosa
26
+ import librosa.display
27
+ import matplotlib
28
+ from models.stfts import mag_phase_stft, mag_phase_istft
29
+ from models.generator import SEMamba
30
+ from models.pcs400 import cal_pcs
31
+
32
+ ckpt = "ckpts/SEMamba_advanced.pth"
33
+ cfg_f = "recipes/SEMamba_advanced.yaml"
34
+
35
+ # load config
36
+ with open(cfg_f, 'r') as f:
37
+ cfg = yaml.safe_load(f)
38
+
39
+
40
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+ device = "cuda"
42
+ model = SEMamba(cfg).to(device)
43
+ sdict = torch.load(ckpt, map_location=device)
44
+ model.load_state_dict(sdict["generator"])
45
+ model.eval()
46
+
47
 
48
  @spaces.GPU
49
+ def enhance(filepath):
50
+ with torch.no_grad():
51
+ # load & (if needed) resample to model SR
52
+ wav, orig_sr = librosa.load(filepath, sr=None)
53
+
54
+ if orig_sr != 16000:
55
+ wav = librosa.resample(wav, orig_sr, 16000)
56
+ # normalize β†’ tensor
57
+ x = torch.from_numpy(wav).float().to(device)
58
+ norm = torch.sqrt(len(x)/torch.sum(x**2))
59
+ x = (x*norm).unsqueeze(0)
60
+ # STFT β†’ model β†’ ISTFT
61
+ amp ,pha , _ = mag_phase_stft(x, 400, 100, 400, 0.3)
62
+ with torch.no_grad():
63
+ amp2, pha2, comp = model(amp, pha)
64
+ out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
65
+ out = (out/norm).squeeze().cpu().numpy()
66
+ # back to original rate
67
+ if orig_sr != 16000:
68
+ out = librosa.resample(out, 16000, orig_sr, 'PCM_16')
69
+ # write file
70
+ sf.write("enhanced.wav", out, orig_sr)
71
+ # build spectrogram
72
+
73
+ D = librosa.stft(out, n_fft=1024, hop_length=512)
74
+ S = librosa.amplitude_to_db(np.abs(D), ref=np.max)
75
+ fig, ax = plt.subplots(figsize=(6,3))
76
+ librosa.display.specshow(S, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=ax)
77
+ ax.set_title("Enhanced Spectrogram")
78
+ plt.colorbar(format="%+2.0f dB", ax=ax)
79
+
80
+ return "enhanced.wav"#, fig
81
+
82
 
83
  with gr.Blocks() as demo:
84
+ gr.Markdown(ABOUT)
85
+ input_audio = gr.Audio(label="Input Audio", type="filepath")
86
+ enhance_btn = gr.Button("Enhance")
87
+ output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
88
+
89
+ enhance_btn.click(fn=enhance, inputs=input_audio, outputs=output_audio)
90
+
91
+ demo.queue().launch()