qqwjq1981 committed
Commit 7add5f8 · verified · 1 Parent(s): c6f940f

Update app.py

Files changed (1):
  1. app.py +124 -91
app.py CHANGED
@@ -133,7 +133,7 @@ def handle_feedback(feedback):
     conn.commit()
     return "Thank you for your feedback!", None
 
-def segment_background_audio(audio_path, background_audio_path="background_segments.wav"):
+def segment_background_audio(audio_path, background_audio_path="background_segments.wav", speech_audio_path="speech_segment.wav"):
     """
     Uses Demucs to separate audio and extract background (non-vocal) parts.
     Merges drums, bass, and other stems into a single background track.
@@ -150,6 +150,7 @@ def segment_background_audio(audio_path, background_audio_path="background_segme
     stem_dir = os.path.join("separated", "htdemucs", filename)
 
     # Step 3: Load and merge background stems
+    vocals = AudioSegment.from_wav(os.path.join(stem_dir, "vocals.wav"))
     drums = AudioSegment.from_wav(os.path.join(stem_dir, "drums.wav"))
     bass = AudioSegment.from_wav(os.path.join(stem_dir, "bass.wav"))
     other = AudioSegment.from_wav(os.path.join(stem_dir, "other.wav"))
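This hunk reads the four stems from separated/htdemucs/<filename>/, the default output layout of the Demucs CLI with the htdemucs model. The separation call itself sits outside the hunk; a minimal sketch of that step, assuming the demucs command-line tool is installed (the helper name run_demucs is illustrative, not from app.py):

    import os
    import subprocess

    def run_demucs(audio_path):
        # The default htdemucs model writes separated/htdemucs/<track>/{vocals,drums,bass,other}.wav
        subprocess.run(["demucs", "-n", "htdemucs", audio_path], check=True)
        filename = os.path.splitext(os.path.basename(audio_path))[0]
        return os.path.join("separated", "htdemucs", filename)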
@@ -158,7 +159,8 @@ def segment_background_audio(audio_path, background_audio_path="background_segme
 
     # Step 4: Export the merged background
     background.export(background_audio_path, format="wav")
-    return background_audio_path
+    vocals.export(speech_audio_path, format="wav")
+    return background_audio_path, speech_audio_path
 
 def transcribe_video_with_speakers(video_path):
     # Extract audio from video
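With this change segment_background_audio exports the vocal stem alongside the merged background and returns both paths. A minimal usage sketch (the input file name is a placeholder):

    background_path, speech_path = segment_background_audio("input_audio.wav")
    # background_segments.wav: drums + bass + other, kept for later use
    # speech_segment.wav: isolated vocals, fed to WhisperX below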
@@ -167,7 +169,7 @@ def transcribe_video_with_speakers(video_path):
     video.audio.write_audiofile(audio_path)
     logger.info(f"Audio extracted from video: {audio_path}")
 
-    segment_result = segment_background_audio(audio_path)
+    segment_result, speech_audio_path = segment_background_audio(audio_path)
     print(f"Saved non-speech (background) audio to local")
 
     # Set up device
@@ -180,7 +182,7 @@ def transcribe_video_with_speakers(video_path):
     logger.info("WhisperX model loaded")
 
     # Transcribe
-    result = model.transcribe(audio_path, chunk_size=6, print_progress = True)
+    result = model.transcribe(speech_audio_path, chunk_size=6, print_progress = True)
     logger.info("Audio transcription completed")
 
     # Get the detected language
@@ -188,12 +190,12 @@ def transcribe_video_with_speakers(video_path):
     logger.debug(f"Detected language: {detected_language}")
     # Alignment
     # model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-    # result = whisperx.align(result["segments"], model_a, metadata, audio_path, device)
+    # result = whisperx.align(result["segments"], model_a, metadata, speech_audio_path, device)
     # logger.info("Transcription alignment completed")
 
     # Diarization (works independently of Whisper model size)
     diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_api_key, device=device)
-    diarize_segments = diarize_model(audio_path)
+    diarize_segments = diarize_model(speech_audio_path)
     logger.info("Speaker diarization completed")
 
     # Assign speakers
@@ -213,31 +215,62 @@ def transcribe_video_with_speakers(video_path):
         }
         for segment in result["segments"]
     ]
-
+
     # Collect audio for each speaker
     speaker_audio = {}
-    for segment in result["segments"]:
-        speaker = segment["speaker"]
-        if speaker not in speaker_audio:
-            speaker_audio[speaker] = []
-        speaker_audio[speaker].append((segment["start"], segment["end"]))
-
+    logger.info("🔎 Start collecting valid audio segments per speaker...")
+
+    for idx, segment in enumerate(result["segments"]):
+        speaker = segment.get("speaker", "SPEAKER_00")
+        start = segment["start"]
+        end = segment["end"]
+
+        if end > start and (end - start) > 0.05:  # Require >50ms duration
+            if speaker not in speaker_audio:
+                speaker_audio[speaker] = [(start, end)]
+            else:
+                speaker_audio[speaker].append((start, end))
+
+            logger.debug(f"Segment {idx}: Added to speaker {speaker} [{start:.2f}s → {end:.2f}s]")
+        else:
+            logger.warning(f"⚠️ Segment {idx} discarded: invalid duration ({start:.2f}s → {end:.2f}s)")
+
     # Collapse and truncate speaker audio
     speaker_sample_paths = {}
-    audio_clip = AudioFileClip(audio_path)
+    audio_clip = AudioFileClip(speech_audio_path)
+
+    logger.info(f"🔎 Found {len(speaker_audio)} speakers with valid segments. Start creating speaker samples...")
+
     for speaker, segments in speaker_audio.items():
+        logger.info(f"🔹 Speaker {speaker}: {len(segments)} valid segments")
+
         speaker_clips = [audio_clip.subclip(start, end) for start, end in segments]
-        combined_clip = concatenate_audioclips(speaker_clips)
+        if not speaker_clips:
+            logger.warning(f"⚠️ No valid audio clips for speaker {speaker}. Skipping sample creation.")
+            continue
+
+        if len(speaker_clips) == 1:
+            logger.debug(f"Speaker {speaker}: Only one clip, skipping concatenation.")
+            combined_clip = speaker_clips[0]
+        else:
+            logger.debug(f"Speaker {speaker}: Concatenating {len(speaker_clips)} clips.")
+            combined_clip = concatenate_audioclips(speaker_clips)
+
         truncated_clip = combined_clip.subclip(0, min(30, combined_clip.duration))
+        logger.debug(f"Speaker {speaker}: Truncated to {truncated_clip.duration:.2f} seconds.")
+
+        # Step 4: Save the final result
         sample_path = f"speaker_{speaker}_sample.wav"
         truncated_clip.write_audiofile(sample_path)
         speaker_sample_paths[speaker] = sample_path
-        logger.info(f"Created sample for {speaker}: {sample_path}")
+        logger.info(f"Created and saved sample for {speaker}: {sample_path}")
 
-    # Clean up
+    # Cleanup
+    logger.info("🧹 Closing audio clip and removing temporary files...")
     video.close()
     audio_clip.close()
-    os.remove(audio_path)
+    os.remove(speech_audio_path)
+    logger.info("✅ Finished processing all speaker samples.")
 
     return transcript_with_speakers, detected_language
 
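The rewritten collection loop now guards against missing speaker labels and drops segments of 50 ms or less before grouping. A small, self-contained illustration of that filter on made-up diarization output:

    segments = [
        {"speaker": "SPEAKER_00", "start": 0.00, "end": 2.40},
        {"speaker": "SPEAKER_01", "start": 2.40, "end": 2.43},  # 30 ms, discarded
        {"start": 2.43, "end": 5.10},                           # no label, falls back to SPEAKER_00
    ]

    speaker_audio = {}
    for seg in segments:
        if seg["end"] > seg["start"] and (seg["end"] - seg["start"]) > 0.05:
            speaker_audio.setdefault(seg.get("speaker", "SPEAKER_00"), []).append((seg["start"], seg["end"]))

    print(speaker_audio)  # {'SPEAKER_00': [(0.0, 2.4), (2.43, 5.1)]}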
@@ -637,6 +670,74 @@ def collapse_ocr_subtitles(ocr_json, text_similarity_threshold=90):
 
     return collapsed
 
+def merge_speaker_and_time_from_whisperx(ocr_json, whisperx_json, text_sim_threshold=80, replace_threshold=90):
+    """
+    Given OCR and WhisperX segments, merge speaker ID and optionally replace time.
+    """
+    merged = []
+
+    for ocr in ocr_json:
+        ocr_start = ocr["start"]
+        ocr_end = ocr["end"]
+        ocr_text = ocr["text"]
+
+        best_match = None
+        best_score = -1
+
+        for wx in whisperx_json:
+            wx_start, wx_end = wx["start"], wx["end"]
+            wx_text = wx["text"]
+
+            # Time overlap (soft constraint)
+            time_center_diff = abs((ocr_start + ocr_end)/2 - (wx_start + wx_end)/2)
+            if time_center_diff > 3: # skip if too far
+                continue
+
+            # Text similarity
+            sim = fuzz.ratio(ocr_text, wx_text)
+            if sim > best_score:
+                best_score = sim
+                best_match = wx
+
+        new_entry = copy.deepcopy(ocr)
+        if best_match:
+            new_entry["speaker"] = best_match.get("speaker", "UNKNOWN")
+            new_entry["ocr_similarity"] = best_score
+
+            if best_score >= replace_threshold:
+                new_entry["start"] = best_match["start"]
+                new_entry["end"] = best_match["end"]
+
+        else:
+            new_entry["speaker"] = "UNKNOWN"
+            new_entry["ocr_similarity"] = None
+
+        merged.append(new_entry)
+
+    return merged
+
+
+def realign_ocr_segments(merged_ocr_json, min_gap=0.2):
+    """
+    Realign OCR segments to avoid overlaps using midpoint-based adjustment.
+    """
+    merged_ocr_json = sorted(merged_ocr_json, key=lambda x: x["start"])
+
+    for i in range(1, len(merged_ocr_json)):
+        prev = merged_ocr_json[i - 1]
+        curr = merged_ocr_json[i]
+
+        # If current overlaps with previous, adjust
+        if curr["start"] < prev["end"] + min_gap:
+            midpoint = (prev["end"] + curr["start"]) / 2
+            prev["end"] = round(midpoint - min_gap / 2, 3)
+            curr["start"] = round(midpoint + min_gap / 2, 3)
+
+        # Prevent negative durations
+        if curr["start"] >= curr["end"]:
+            curr["end"] = round(curr["start"] + 0.3, 3)
+
+    return merged_ocr_json
 def post_edit_transcribed_segments(transcription_json, video_path,
                                    interval_sec=0.5,
                                    text_similarity_threshold=80,
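A minimal usage sketch of the two new helpers on hand-made OCR and WhisperX segments, assuming the fuzz and copy imports the functions rely on are present as elsewhere in app.py:

    ocr_json = [
        {"start": 1.0, "end": 2.1, "text": "Hello there"},
        {"start": 2.0, "end": 3.4, "text": "How are you"},
    ]
    whisperx_json = [
        {"start": 0.9, "end": 2.0, "text": "Hello there", "speaker": "SPEAKER_00"},
        {"start": 2.1, "end": 3.5, "text": "How are you?", "speaker": "SPEAKER_01"},
    ]

    merged = merge_speaker_and_time_from_whisperx(ocr_json, whisperx_json)
    realigned = realign_ocr_segments(merged, min_gap=0.2)
    # Each OCR entry now carries "speaker" and "ocr_similarity"; matches at or above
    # replace_threshold also take the WhisperX timestamps, and any remaining overlap
    # is split at the midpoint with a min_gap buffer.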
@@ -658,80 +759,12 @@ def post_edit_transcribed_segments(transcription_json, video_path,
     # Step 2: Collapse repetitive OCR
     collapsed_ocr = collapse_ocr_subtitles(ocr_json, text_similarity_threshold=90)
 
-    # Step 3: Refine existing WhisperX segments (Phase 1)
-    merged_segments = []
-    used_ocr_indices = set()
-
-    for entry_idx, entry in enumerate(transcription_json):
-        start = entry.get("start", 0)
-        end = entry.get("end", 0)
-        base_text = entry.get("text", "")
-
-        best_match_idx = None
-        best_score = -1
-
-        for ocr_idx, ocr in enumerate(collapsed_ocr):
-            time_overlap = not (ocr["end"] < start - time_tolerance or ocr["start"] > end + time_tolerance)
-            if not time_overlap:
-                continue
-
-            sim = fuzz.ratio(ocr["text"], base_text)
-            if sim > best_score:
-                best_score = sim
-                best_match_idx = ocr_idx
-
-        updated_entry = entry.copy()
-        if best_match_idx is not None and best_score >= text_similarity_threshold:
-            updated_entry["text"] = collapsed_ocr[best_match_idx]["text"]
-            updated_entry["ocr_matched"] = True
-            updated_entry["ocr_similarity"] = best_score
-            used_ocr_indices.add(best_match_idx)
-        else:
-            updated_entry["ocr_matched"] = False
-            updated_entry["ocr_similarity"] = best_score if best_score >= 0 else None
-
-        merged_segments.append(updated_entry)
-
-    # Step 4: Insert unused OCR segments (Phase 2)
-    inserted_segments = []
-    for ocr_idx, ocr in enumerate(collapsed_ocr):
-        if ocr_idx in used_ocr_indices:
-            continue
-
-        # Check for fuzzy duplicates in WhisperX
-        duplicate = False
-        for whisper_seg in transcription_json:
-            if abs(ocr["start"] - whisper_seg["start"]) < time_tolerance or abs(ocr["end"] - whisper_seg["end"]) < time_tolerance:
-                sim = fuzz.ratio(ocr["text"], whisper_seg["text"])
-                if sim >= text_similarity_threshold:
-                    duplicate = True
-                    break
-
-        if duplicate:
-            logger.debug(f"🟡 Skipping near-duplicate OCR: '{ocr['text']}'")
-            continue
-
-        # Infer speaker from nearest WhisperX entry
-        nearby = sorted(transcription_json, key=lambda x: abs(x["start"] - ocr["start"]))
-        speaker_guess = nearby[0].get("speaker", "unknown") if nearby else "unknown"
-
-        inserted_segment = {
-            "start": ocr["start"],
-            "end": ocr["end"],
-            "text": ocr["text"],
-            "speaker": speaker_guess
-        }
-        inserted_segments.append(inserted_segment)
-
-    # Step 5: Combine and sort
-    final_segments = merged_segments + inserted_segments
-    final_segments = sorted(final_segments, key=lambda x: x["start"])
-
-    print(f"✅ Post-editing completed: {len(final_segments)} total segments "
-          f"({len(inserted_segments)} OCR-inserted segments)")
-
-    return final_segments
+    # Step 3: Merge and realign OCR segments.
+    ocr_merged = merge_speaker_and_time_from_whisperx(ocr_json, whisperx_json)
+    ocr_realigned = realign_ocr_segments(ocr_merged)
 
+    logger.info(f"✅ Final merged and realigned OCR: {len(ocr_realigned)} segments")
+    return ocr_realigned
 
 def process_entry(entry, i, tts_model, video_width, video_height, process_mode, target_language, font_path, speaker_sample_paths=None):
     logger.debug(f"Processing entry {i}: {entry}")
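Taken together, the commit transcribes and diarizes the separated vocal stem instead of the raw audio track and replaces the old phase-1/phase-2 merge with the OCR-centric merge-and-realign path above. A rough end-to-end sketch of the updated flow, assuming the surrounding app.py glue ("input.mp4" is a placeholder):

    transcript, language = transcribe_video_with_speakers("input.mp4")
    final_segments = post_edit_transcribed_segments(transcript, "input.mp4")
    # final_segments are the realigned OCR entries, each tagged with a speaker and,
    # where the text match is strong enough, WhisperX timestamps.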
 