fffiloni committed on
Commit
6d25e94
·
verified ·
1 Parent(s): b090534

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +12 -6
gradio_app.py CHANGED
@@ -40,11 +40,11 @@ def separate_speakers_core(audio_path):
40
  waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)
41
 
42
  if waveform.dim() == 1:
43
- waveform = waveform.unsqueeze(0)
44
- audio_input = waveform.unsqueeze(0).to(device)
45
 
46
  with torch.no_grad():
47
- ests_speech = sep_model(audio_input).squeeze(0)
48
 
49
  session_id = uuid.uuid4().hex[:8]
50
  output_dir = os.path.join("output_sep", session_id)
@@ -53,15 +53,21 @@ def separate_speakers_core(audio_path):
53
  output_files = []
54
  for i in range(ests_speech.shape[0]):
55
  path = os.path.join(output_dir, f"speaker_{i+1}.wav")
56
- waveform = ests_speech[i].cpu().unsqueeze(0) # (1, samples)
57
- torchaudio.save(path, waveform, TARGET_SR)
58
- output_files.append(path)
 
59
 
 
 
 
 
60
 
61
  return output_files
62
 
63
 
64
 
 
65
  @spaces.GPU()
66
  def separate_dnr(audio_file):
67
  audio, sr = torchaudio.load(audio_file)
 
40
  waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)
41
 
42
  if waveform.dim() == 1:
43
+ waveform = waveform.unsqueeze(0) # Ensure shape is (1, samples)
44
+ audio_input = waveform.unsqueeze(0).to(device) # Shape: (1, 1, samples)
45
 
46
  with torch.no_grad():
47
+ ests_speech = sep_model(audio_input).squeeze(0) # Shape: (num_speakers, samples)
48
 
49
  session_id = uuid.uuid4().hex[:8]
50
  output_dir = os.path.join("output_sep", session_id)
 
53
  output_files = []
54
  for i in range(ests_speech.shape[0]):
55
  path = os.path.join(output_dir, f"speaker_{i+1}.wav")
56
+ speaker_waveform = ests_speech[i].cpu()
57
+
58
+ if speaker_waveform.dim() == 1:
59
+ speaker_waveform = speaker_waveform.unsqueeze(0) # (1, samples)
60
 
61
+ # Ensure correct dtype and save in a widely compatible format
62
+ speaker_waveform = speaker_waveform.to(torch.float32)
63
+ torchaudio.save(path, speaker_waveform, TARGET_SR, format="wav", encoding="PCM_S", bits_per_sample=16)
64
+ output_files.append(path)
65
 
66
  return output_files
67
 
68
 
69
 
70
+
71
  @spaces.GPU()
72
  def separate_dnr(audio_file):
73
  audio, sr = torchaudio.load(audio_file)