roychao19477 committed on
Commit
3c23ad1
·
1 Parent(s): 9d66cc0
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -53,20 +53,34 @@ def enhance(filepath):
53
  wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
54
  x = torch.from_numpy(wav).float().to(device)
55
  norm = torch.sqrt(len(x)/torch.sum(x**2))
56
- x = (x * norm).unsqueeze(0)
 
57
 
58
- # STFT model ISTFT
59
- amp, pha, _ = mag_phase_stft(x, 400, 100, 400, 0.3)
60
- amp2, pha2, _ = model(amp, pha)
61
- out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
62
- out = (out / norm).squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # back to original rate
65
  if orig_sr != 16000:
66
  out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)
67
 
68
  # write file
69
- sf.write("enhanced.wav", out, orig_sr)
70
 
71
  # spectrograms
72
  fig, axs = plt.subplots(1, 2, figsize=(10, 4))
 
53
  wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
54
  x = torch.from_numpy(wav).float().to(device)
55
  norm = torch.sqrt(len(x)/torch.sum(x**2))
56
+ #x = (x * norm).unsqueeze(0)
57
+ x = (x * norm)
58
 
59
+ # split into 4s segments (64000 samples)
60
+ segment_len = 4 * 16000
61
+ chunks = x.split(segment_len)
62
+ enhanced_chunks = []
63
+
64
+ for chunk in chunks:
65
+ if len(chunk) < segment_len:
66
+ pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
67
+ chunk = torch.cat([chunk, pad])
68
+ chunk = chunk.unsqueeze(0)
69
+
70
+ amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
71
+ amp2, pha2, _ = model(amp, pha)
72
+ out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
73
+ out = (out / norm).squeeze(0)
74
+ enhanced_chunks.append(out)
75
+
76
+ out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy() # trim padding
77
 
78
  # back to original rate
79
  if orig_sr != 16000:
80
  out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)
81
 
82
  # write file
83
+ sf.write("enhanced.wav", out, sr=orig_sr)
84
 
85
  # spectrograms
86
  fig, axs = plt.subplots(1, 2, figsize=(10, 4))