fffiloni committed (verified) · Commit 2abb303 · Parent(s): b3908ae

Update gradio_app.py

Files changed (1):
  1. gradio_app.py +8 -22
gradio_app.py CHANGED

```diff
@@ -104,13 +104,9 @@ def separate_dnr_video(video_path):
 
 @spaces.GPU()
 def separate_speakers_video(video_path):
-    # Extract audio
-    video = VideoFileClip(video_path)
-    audio_path = f"/tmp/{uuid.uuid4().hex}_audio.wav"
-    video.audio.write_audiofile(audio_path, fps=TARGET_SR, verbose=False, logger=None)
-
-    # Load and resample
+    audio_path, video = extract_audio_from_video(video_path)
     waveform, original_sr = torchaudio.load(audio_path)
+
     if original_sr != TARGET_SR:
         waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)
 
@@ -118,34 +114,23 @@ def separate_speakers_video(video_path):
     waveform = waveform.unsqueeze(0)
     audio_input = waveform.unsqueeze(0).to(device)
 
-    # Inference
     with torch.no_grad():
         ests_speech = sep_model(audio_input).squeeze(0)
 
-    # Output directory
     session_id = uuid.uuid4().hex[:8]
     output_dir = os.path.join("output_sep_video", session_id)
     os.makedirs(output_dir, exist_ok=True)
 
     output_videos = []
     for i in range(ests_speech.shape[0]):
-        audio_np = ests_speech[i].cpu().numpy()
-        if audio_np.ndim == 1:
-            audio_np = audio_np[:, None]  # Ensure shape [samples, 1]
-
-        # Save separated audio
         separated_audio_path = os.path.join(output_dir, f"speaker_{i+1}.wav")
-        sf.write(separated_audio_path, audio_np, TARGET_SR)
-
-        # Combine with original video (no original audio)
-        output_video_path = os.path.join(output_dir, f"speaker_{i+1}_video.mp4")
-        new_audio = AudioFileClip(separated_audio_path)
-        new_video = video.set_audio(new_audio)
-        new_video.write_videofile(output_video_path, audio_codec="aac", verbose=False, logger=None)
+        audio_np = ests_speech[i].cpu().numpy()
+        sf.write(separated_audio_path, audio_np, TARGET_SR, format='WAV', subtype='PCM_16')
 
-        output_videos.append(output_video_path)
+        speaker_video_path = os.path.join(output_dir, f"speaker_{i+1}_video.mp4")
+        final_video = attach_audio_to_video(video, separated_audio_path, speaker_video_path)
+        output_videos.append(final_video)
 
-    # Pad with empty videos if less than MAX_SPEAKERS
     updates = []
     for i in range(MAX_SPEAKERS):
         if i < len(output_videos):
@@ -155,6 +140,7 @@ def separate_speakers_video(video_path):
     return updates
 
 
+
 # --- Gradio UI ---
 with gr.Blocks() as demo:
     gr.Markdown("# TIGER: Time-frequency Interleaved Gain Extraction and Reconstruction for Efficient Speech Separation")
```