roychao19477 committed on
Commit
3af0ebe
·
1 Parent(s): 56efbc8
Files changed (1) hide show
  1. app.py +28 -20
app.py CHANGED
@@ -44,48 +44,56 @@ sdict = torch.load(ckpt, map_location=device)
44
  model.load_state_dict(sdict["generator"])
45
  model.eval()
46
 
47
-
48
  @spaces.GPU
49
  def enhance(filepath):
50
  with torch.no_grad():
51
- # load & (if needed) resample to model SR
52
  wav, orig_sr = librosa.load(filepath, sr=None)
53
-
54
  if orig_sr != 16000:
55
  wav = librosa.resample(wav, orig_sr, 16000)
56
- # normalize β†’ tensor
57
  x = torch.from_numpy(wav).float().to(device)
58
  norm = torch.sqrt(len(x)/torch.sum(x**2))
59
- x = (x*norm).unsqueeze(0)
 
60
  # STFT β†’ model β†’ ISTFT
61
- amp ,pha , _ = mag_phase_stft(x, 400, 100, 400, 0.3)
62
- with torch.no_grad():
63
- amp2, pha2, comp = model(amp, pha)
64
  out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
65
- out = (out/norm).squeeze().cpu().numpy()
 
66
  # back to original rate
67
  if orig_sr != 16000:
68
- out = librosa.resample(out, 16000, orig_sr, 'PCM_16')
 
69
  # write file
70
  sf.write("enhanced.wav", out, orig_sr)
71
- # build spectrogram
72
 
73
- D = librosa.stft(out, n_fft=1024, hop_length=512)
74
- S = librosa.amplitude_to_db(np.abs(D), ref=np.max)
75
- fig, ax = plt.subplots(figsize=(6,3))
76
- librosa.display.specshow(S, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=ax)
77
- ax.set_title("Enhanced Spectrogram")
78
- plt.colorbar(format="%+2.0f dB", ax=ax)
 
 
 
 
 
 
 
 
79
 
80
- return "enhanced.wav"#, fig
81
 
 
82
 
83
  with gr.Blocks() as demo:
84
  gr.Markdown(ABOUT)
85
- input_audio = gr.Audio(label="Input Audio", type="filepath")
86
  enhance_btn = gr.Button("Enhance")
87
  output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
 
88
 
89
- enhance_btn.click(fn=enhance, inputs=input_audio, outputs=output_audio)
90
 
91
  demo.queue().launch()
 
44
  model.load_state_dict(sdict["generator"])
45
  model.eval()
46
 
 
47
  @spaces.GPU
48
  def enhance(filepath):
49
  with torch.no_grad():
50
+ # load & resample
51
  wav, orig_sr = librosa.load(filepath, sr=None)
 
52
  if orig_sr != 16000:
53
  wav = librosa.resample(wav, orig_sr, 16000)
 
54
  x = torch.from_numpy(wav).float().to(device)
55
  norm = torch.sqrt(len(x)/torch.sum(x**2))
56
+ x = (x * norm).unsqueeze(0)
57
+
58
  # STFT β†’ model β†’ ISTFT
59
+ amp, pha, _ = mag_phase_stft(x, 400, 100, 400, 0.3)
60
+ amp2, pha2, _ = model(amp, pha)
 
61
  out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
62
+ out = (out / norm).squeeze().cpu().numpy()
63
+
64
  # back to original rate
65
  if orig_sr != 16000:
66
+ out = librosa.resample(out, 16000, orig_sr)
67
+
68
  # write file
69
  sf.write("enhanced.wav", out, orig_sr)
 
70
 
71
+ # spectrograms
72
+ fig, axs = plt.subplots(1, 2, figsize=(10, 4))
73
+
74
+ # noisy
75
+ D_noisy = librosa.stft(wav, n_fft=1024, hop_length=512)
76
+ S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
77
+ librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=axs[0])
78
+ axs[0].set_title("Noisy Spectrogram")
79
+
80
+ # enhanced
81
+ D_clean = librosa.stft(out, n_fft=1024, hop_length=512)
82
+ S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
83
+ librosa.display.specshow(S_clean, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1])
84
+ axs[1].set_title("Enhanced Spectrogram")
85
 
86
+ plt.tight_layout()
87
 
88
+ return "enhanced.wav", fig
89
 
90
  with gr.Blocks() as demo:
91
  gr.Markdown(ABOUT)
92
+ input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
93
  enhance_btn = gr.Button("Enhance")
94
  output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
95
+ plot_output = gr.Plot(label="Spectrograms")
96
 
97
+ enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
98
 
99
  demo.queue().launch()