Spaces:
Running
Running
Update gradio_app.py
Browse files- gradio_app.py +12 -6
gradio_app.py
CHANGED
@@ -40,11 +40,11 @@ def separate_speakers_core(audio_path):
|
|
40 |
waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)
|
41 |
|
42 |
if waveform.dim() == 1:
|
43 |
-
waveform = waveform.unsqueeze(0)
|
44 |
-
audio_input = waveform.unsqueeze(0).to(device)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
-
ests_speech = sep_model(audio_input).squeeze(0)
|
48 |
|
49 |
session_id = uuid.uuid4().hex[:8]
|
50 |
output_dir = os.path.join("output_sep", session_id)
|
@@ -53,15 +53,21 @@ def separate_speakers_core(audio_path):
|
|
53 |
output_files = []
|
54 |
for i in range(ests_speech.shape[0]):
|
55 |
path = os.path.join(output_dir, f"speaker_{i+1}.wav")
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
59 |
|
|
|
|
|
|
|
|
|
60 |
|
61 |
return output_files
|
62 |
|
63 |
|
64 |
|
|
|
65 |
@spaces.GPU()
|
66 |
def separate_dnr(audio_file):
|
67 |
audio, sr = torchaudio.load(audio_file)
|
|
|
40 |
waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)
|
41 |
|
42 |
if waveform.dim() == 1:
|
43 |
+
waveform = waveform.unsqueeze(0) # Ensure shape is (1, samples)
|
44 |
+
audio_input = waveform.unsqueeze(0).to(device) # Shape: (1, 1, samples)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
+
ests_speech = sep_model(audio_input).squeeze(0) # Shape: (num_speakers, samples)
|
48 |
|
49 |
session_id = uuid.uuid4().hex[:8]
|
50 |
output_dir = os.path.join("output_sep", session_id)
|
|
|
53 |
output_files = []
|
54 |
for i in range(ests_speech.shape[0]):
|
55 |
path = os.path.join(output_dir, f"speaker_{i+1}.wav")
|
56 |
+
speaker_waveform = ests_speech[i].cpu()
|
57 |
+
|
58 |
+
if speaker_waveform.dim() == 1:
|
59 |
+
speaker_waveform = speaker_waveform.unsqueeze(0) # (1, samples)
|
60 |
|
61 |
+
# Ensure correct dtype and save in a widely compatible format
|
62 |
+
speaker_waveform = speaker_waveform.to(torch.float32)
|
63 |
+
torchaudio.save(path, speaker_waveform, TARGET_SR, format="wav", encoding="PCM_S", bits_per_sample=16)
|
64 |
+
output_files.append(path)
|
65 |
|
66 |
return output_files
|
67 |
|
68 |
|
69 |
|
70 |
+
|
71 |
@spaces.GPU()
|
72 |
def separate_dnr(audio_file):
|
73 |
audio, sr = torchaudio.load(audio_file)
|