qqwjq1981 committed on
Commit f99e269 · verified · 1 Parent(s): 66933cf

Update app.py

Files changed (1)
  1. app.py +307 -262
app.py CHANGED
@@ -275,70 +275,70 @@ def transcribe_video_with_speakers(video_path):
275
  return transcript_with_speakers, detected_language
276
 
277
  # Function to get the appropriate translation model based on target language
278
- def get_translation_model(source_language, target_language):
279
- """
280
- Get the translation model based on the source and target language.
281
 
282
- Parameters:
283
- - target_language (str): The language to translate the content into (e.g., 'es', 'fr').
284
- - source_language (str): The language of the input content (default is 'en' for English).
285
 
286
- Returns:
287
- - str: The translation model identifier.
288
- """
289
- # List of allowable languages
290
- allowable_languages = ["en", "es", "fr", "zh", "de", "it", "pt", "ja", "ko", "ru", "hi", "tr"]
291
-
292
- # Validate source and target languages
293
- if source_language not in allowable_languages:
294
- logger.debug(f"Invalid source language '{source_language}'. Supported languages are: {', '.join(allowable_languages)}")
295
- # Return a default model if source language is invalid
296
- source_language = "en" # Default to 'en'
297
-
298
- if target_language not in allowable_languages:
299
- logger.debug(f"Invalid target language '{target_language}'. Supported languages are: {', '.join(allowable_languages)}")
300
- # Return a default model if target language is invalid
301
- target_language = "zh" # Default to 'zh'
302
-
303
- if source_language == target_language:
304
- source_language = "en" # Default to 'en'
305
- target_language = "zh" # Default to 'zh'
306
-
307
- # Return the model using string concatenation
308
- return f"Helsinki-NLP/opus-mt-{source_language}-{target_language}"
309
-
310
- def translate_single_entry(entry, translator):
311
- original_text = entry["text"]
312
- translated_text = translator(original_text)[0]['translation_text']
313
- return {
314
- "start": entry["start"],
315
- "original": original_text,
316
- "translated": translated_text,
317
- "end": entry["end"],
318
- "speaker": entry["speaker"]
319
- }
320
-
321
- def translate_text(transcription_json, source_language, target_language):
322
- # Load the translation model for the specified target language
323
- translation_model_id = get_translation_model(source_language, target_language)
324
- logger.debug(f"Translation model: {translation_model_id}")
325
- translator = pipeline("translation", model=translation_model_id)
326
-
327
- # Use ThreadPoolExecutor to parallelize translations
328
- with concurrent.futures.ThreadPoolExecutor() as executor:
329
- # Submit all translation tasks and collect results
330
- translate_func = lambda entry: translate_single_entry(entry, translator)
331
- translated_json = list(executor.map(translate_func, transcription_json))
332
-
333
- # Sort the translated_json by start time
334
- translated_json.sort(key=lambda x: x["start"])
335
-
336
- # Log the components being added to translated_json
337
- for entry in translated_json:
338
- logger.debug("Added to translated_json: start=%s, original=%s, translated=%s, end=%s, speaker=%s",
339
- entry["start"], entry["original"], entry["translated"], entry["end"], entry["speaker"])
340
-
341
- return translated_json
342
 
343
  def update_translations(file, edited_table, process_mode):
344
  """
@@ -518,220 +518,265 @@ def solve_optimal_alignment(original_segments, generated_durations, total_durati
518
 
519
  return original_segments
520
 
521
- ocr_model = None
522
- ocr_lock = threading.Lock()
523
-
524
- def init_ocr_model():
525
- global ocr_model
526
- with ocr_lock:
527
- if ocr_model is None:
528
- ocr_model = PaddleOCR(use_angle_cls=True, lang="ch")
529
-
530
- def find_best_subtitle_region(frame, ocr_model, region_height_ratio=0.35, num_strips=5, min_conf=0.5):
531
- """
532
- Automatically identifies the best subtitle region in a video frame using OCR confidence.
533
-
534
- Parameters:
535
- - frame: full video frame (BGR np.ndarray)
536
- - ocr_model: a loaded PaddleOCR model
537
- - region_height_ratio: portion of image height to scan (from bottom up)
538
- - num_strips: how many horizontal strips to evaluate
539
- - min_conf: minimum average confidence to consider a region valid
540
-
541
- Returns:
542
- - crop_region: the cropped image region with highest OCR confidence
543
- - region_box: (y_start, y_end) of the region in the original frame
544
- """
545
- height, width, _ = frame.shape
546
- region_height = int(height * region_height_ratio)
547
- base_y_start = height - region_height
548
- strip_height = region_height // num_strips
549
-
550
- best_score = -1
551
- best_crop = None
552
- best_bounds = (0, height)
553
-
554
- for i in range(num_strips):
555
- y_start = base_y_start + i * strip_height
556
- y_end = y_start + strip_height
557
- strip = frame[y_start:y_end, :]
558
-
559
- try:
560
- result = ocr_model.ocr(strip, cls=True)
561
- if not result or not result[0]:
562
- continue
563
 
564
- total_score = sum(line[1][1] for line in result[0])
565
- avg_score = total_score / len(result[0])
566
 
567
- if avg_score > best_score:
568
- best_score = avg_score
569
- best_crop = strip
570
- best_bounds = (y_start, y_end)
571
 
572
- except Exception as e:
573
- continue # Fail silently on OCR issues
574
 
575
- if best_score >= min_conf and best_crop is not None:
576
- return best_crop, best_bounds
577
- else:
578
- # Fallback to center-bottom strip
579
- fallback_y = height - int(height * 0.2)
580
- return frame[fallback_y:, :], (fallback_y, height)
581
 
582
- def ocr_frame_worker(args, min_confidence=0.7):
583
- frame_idx, frame_time, frame = args
 
 
584
 
585
- init_ocr_model() # Load model in thread-safe way
586
 
587
- if frame is None or frame.size == 0 or not isinstance(frame, np.ndarray):
588
- return {"time": frame_time, "text": ""}
589
 
590
- if frame.dtype != np.uint8:
591
- frame = frame.astype(np.uint8)
592
 
593
- try:
594
- result = ocr_model.ocr(frame, cls=True)
595
- lines = result[0] if result else []
596
- texts = [line[1][0] for line in lines if line[1][1] >= min_confidence]
597
- combined_text = " ".join(texts).strip()
598
- return {"time": frame_time, "text": combined_text}
599
- except Exception as e:
600
- print(f"⚠️ OCR failed at {frame_time:.2f}s: {e}")
601
- return {"time": frame_time, "text": ""}
602
-
603
- def frame_is_in_audio_segments(frame_time, audio_segments, tolerance=0.2):
604
- for segment in audio_segments:
605
- start, end = segment["start"], segment["end"]
606
- if (start - tolerance) <= frame_time <= (end + tolerance):
607
- return True
608
- return False
609
-
610
- def extract_ocr_subtitles_parallel(video_path, transcription_json, interval_sec=0.5, num_workers=4):
611
- cap = cv2.VideoCapture(video_path)
612
- fps = cap.get(cv2.CAP_PROP_FPS)
613
- frames = []
614
- frame_idx = 0
615
- success, frame = cap.read()
616
-
617
- while success:
618
- if frame_idx % int(fps * interval_sec) == 0:
619
- frame_time = frame_idx / fps
620
- if frame_is_in_audio_segments(frame_time, transcription_json):
621
- frames.append((frame_idx, frame_time, frame.copy()))
622
- success, frame = cap.read()
623
- frame_idx += 1
624
- cap.release()
625
-
626
- ocr_results = []
627
- ocr_failures = 0 # Count OCR failures
628
- with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
629
- futures = [executor.submit(ocr_frame_worker, frame) for frame in frames]
630
-
631
- for f in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
632
- try:
633
- result = f.result()
634
- if result["text"]:
635
- ocr_results.append(result)
636
- except Exception as e:
637
- ocr_failures += 1
638
 
639
- logger.info(f"✅ OCR extraction completed: {len(ocr_results)} frames successful, {ocr_failures} frames failed.")
640
- return ocr_results
 
641
 
642
- def collapse_ocr_subtitles(ocr_json, text_similarity_threshold=90):
643
- collapsed = []
644
- current = None
645
- for entry in ocr_json:
646
- time = entry["time"]
647
- text = entry["text"]
 
648
 
649
- if not current:
650
- current = {"start": time, "end": time, "text": text}
651
- continue
652
 
653
- sim = fuzz.ratio(current["text"], text)
654
- if sim >= text_similarity_threshold:
655
- current["end"] = time
656
- logger.debug(f"MERGED: Current end extended to {time:.2f}s for text: '{current['text'][:50]}...' (Similarity: {sim})")
657
- else:
658
- logger.debug(f"NOT MERGING (Similarity: {sim} < Threshold: {text_similarity_threshold}):")
659
- logger.debug(f" Previous segment: {current['start']:.2f}s - {current['end']:.2f}s: '{current['text'][:50]}...'")
660
- logger.debug(f" New segment: {time:.2f}s: '{text[:50]}...'")
661
- collapsed.append(current)
662
- current = {"start": time, "end": time, "text": text}
663
- if current:
664
- collapsed.append(current)
665
-
666
- logger.info(f"✅ OCR subtitles collapsed into {len(collapsed)} segments.")
667
- for idx, seg in enumerate(collapsed):
668
- logger.debug(f"[OCR Collapsed {idx}] {seg['start']:.2f}s - {seg['end']:.2f}s: {seg['text'][:50]}...")
669
- return collapsed
670
-
671
- def merge_speaker_and_time_from_whisperx(
672
- ocr_json,
673
- whisperx_json,
674
- replace_threshold=90,
675
- time_tolerance=1.0
676
- ):
677
- merged = []
678
- used_whisperx = set()
679
- whisperx_used_flags = [False] * len(whisperx_json)
680
-
681
- # Step 1: Attempt to match each OCR entry to a WhisperX entry
682
- for ocr in ocr_json:
683
- ocr_start, ocr_end = ocr["start"], ocr["end"]
684
- ocr_text = ocr["text"]
685
-
686
- best_match = None
687
- best_score = -1
688
- best_idx = None
689
-
690
- for idx, wx in enumerate(whisperx_json):
691
- wx_start, wx_end = wx["start"], wx["end"]
692
- wx_text = wx["text"]
693
-
694
- # Check for time overlap
695
- overlap = not (ocr_end < wx_start - time_tolerance or ocr_start > wx_end + time_tolerance)
696
- if not overlap:
697
- continue
698
-
699
- sim = fuzz.ratio(ocr_text, wx_text)
700
- if sim > best_score:
701
- best_score = sim
702
- best_match = wx
703
- best_idx = idx
704
-
705
- if best_match and best_score >= replace_threshold:
706
- # Replace WhisperX segment with higher quality OCR text
707
- new_segment = copy.deepcopy(best_match)
708
- new_segment["text"] = ocr_text
709
- new_segment["ocr_replaced"] = True
710
- new_segment["ocr_similarity"] = best_score
711
- whisperx_used_flags[best_idx] = True
712
- merged.append(new_segment)
713
- else:
714
- # No replacement, check if this OCR is outside WhisperX time coverage
715
- covered = any(
716
- abs((ocr_start + ocr_end)/2 - (wx["start"] + wx["end"])/2) < time_tolerance
717
- for wx in whisperx_json
718
- )
719
- if not covered:
720
- new_segment = copy.deepcopy(ocr)
721
- new_segment["ocr_added"] = True
722
- new_segment["speaker"] = "UNKNOWN"
723
- merged.append(new_segment)
724
 
725
- # Step 2: Add untouched WhisperX segments
726
- for idx, wx in enumerate(whisperx_json):
727
- if not whisperx_used_flags[idx]:
728
- merged.append(wx)
 
 
 
 
729
 
730
- # Step 3: Sort all merged segments
731
- merged = sorted(merged, key=lambda x: x["start"])
 
 
 
 
 
732
 
733
- return merged
 
 
 
 
 
 
 
 
734
 
 
 
 
 
 
 
 
 
735
  # def merge_speaker_and_time_from_whisperx(ocr_json, whisperx_json, text_sim_threshold=80, replace_threshold=90):
736
  # merged = []
737
  # used_whisperx = set()
@@ -1132,10 +1177,10 @@ def upload_and_manage(file, target_language, process_mode):
1132
  transcription_json, source_language = transcribe_video_with_speakers(file.name)
1133
  logger.info(f"Transcription completed. Detected source language: {source_language}")
1134
 
1135
- transcription_json_merged = post_edit_transcribed_segments(transcription_json, file.name)
1136
  # Step 2: Translate the transcription
1137
- logger.info(f"Translating transcription from {source_language} to {target_language}...")
1138
- translated_json_raw = translate_text(transcription_json_merged, source_language, target_language)
1139
  logger.info(f"Translation completed. Number of translated segments: {len(translated_json_raw)}")
1140
 
1141
  translated_json = apply_adaptive_speed(translated_json_raw, source_language, target_language)
 
275
  return transcript_with_speakers, detected_language
276
 
277
  # Function to get the appropriate translation model based on target language
278
+ # def get_translation_model(source_language, target_language):
279
+ # """
280
+ # Get the translation model based on the source and target language.
281
 
282
+ # Parameters:
283
+ # - target_language (str): The language to translate the content into (e.g., 'es', 'fr').
284
+ # - source_language (str): The language of the input content (default is 'en' for English).
285
 
286
+ # Returns:
287
+ # - str: The translation model identifier.
288
+ # """
289
+ # # List of allowable languages
290
+ # allowable_languages = ["en", "es", "fr", "zh", "de", "it", "pt", "ja", "ko", "ru", "hi", "tr"]
291
+
292
+ # # Validate source and target languages
293
+ # if source_language not in allowable_languages:
294
+ # logger.debug(f"Invalid source language '{source_language}'. Supported languages are: {', '.join(allowable_languages)}")
295
+ # # Return a default model if source language is invalid
296
+ # source_language = "en" # Default to 'en'
297
+
298
+ # if target_language not in allowable_languages:
299
+ # logger.debug(f"Invalid target language '{target_language}'. Supported languages are: {', '.join(allowable_languages)}")
300
+ # # Return a default model if target language is invalid
301
+ # target_language = "zh" # Default to 'zh'
302
+
303
+ # if source_language == target_language:
304
+ # source_language = "en" # Default to 'en'
305
+ # target_language = "zh" # Default to 'zh'
306
+
307
+ # # Return the model using string concatenation
308
+ # return f"Helsinki-NLP/opus-mt-{source_language}-{target_language}"
309
+
310
+ # def translate_single_entry(entry, translator):
311
+ # original_text = entry["text"]
312
+ # translated_text = translator(original_text)[0]['translation_text']
313
+ # return {
314
+ # "start": entry["start"],
315
+ # "original": original_text,
316
+ # "translated": translated_text,
317
+ # "end": entry["end"],
318
+ # "speaker": entry["speaker"]
319
+ # }
320
+
321
+ # def translate_text(transcription_json, source_language, target_language):
322
+ # # Load the translation model for the specified target language
323
+ # translation_model_id = get_translation_model(source_language, target_language)
324
+ # logger.debug(f"Translation model: {translation_model_id}")
325
+ # translator = pipeline("translation", model=translation_model_id)
326
+
327
+ # # Use ThreadPoolExecutor to parallelize translations
328
+ # with concurrent.futures.ThreadPoolExecutor() as executor:
329
+ # # Submit all translation tasks and collect results
330
+ # translate_func = lambda entry: translate_single_entry(entry, translator)
331
+ # translated_json = list(executor.map(translate_func, transcription_json))
332
+
333
+ # # Sort the translated_json by start time
334
+ # translated_json.sort(key=lambda x: x["start"])
335
+
336
+ # # Log the components being added to translated_json
337
+ # for entry in translated_json:
338
+ # logger.debug("Added to translated_json: start=%s, original=%s, translated=%s, end=%s, speaker=%s",
339
+ # entry["start"], entry["original"], entry["translated"], entry["end"], entry["speaker"])
340
+
341
+ # return translated_json
342
 
343
  def update_translations(file, edited_table, process_mode):
344
  """
 
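For reference, the helper disabled above builds model ids with f"Helsinki-NLP/opus-mt-{source_language}-{target_language}", so get_translation_model("en", "zh") resolved to "Helsinki-NLP/opus-mt-en-zh". A minimal sketch of that path in isolation (the sample sentence is hypothetical):

from transformers import pipeline

model_id = "Helsinki-NLP/opus-mt-en-zh"  # template filled in for source "en", target "zh"
translator = pipeline("translation", model=model_id)
print(translator("Good morning")[0]["translation_text"])  # same access pattern as translate_single_entry()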
518
 
519
  return original_segments
520
 
521
+ # ocr_model = None
522
+ # ocr_lock = threading.Lock()
523
+
524
+ # def init_ocr_model():
525
+ # global ocr_model
526
+ # with ocr_lock:
527
+ # if ocr_model is None:
528
+ # ocr_model = PaddleOCR(use_angle_cls=True, lang="ch")
529
+
530
+ # def find_best_subtitle_region(frame, ocr_model, region_height_ratio=0.35, num_strips=5, min_conf=0.5):
531
+ # """
532
+ # Automatically identifies the best subtitle region in a video frame using OCR confidence.
533
+
534
+ # Parameters:
535
+ # - frame: full video frame (BGR np.ndarray)
536
+ # - ocr_model: a loaded PaddleOCR model
537
+ # - region_height_ratio: portion of image height to scan (from bottom up)
538
+ # - num_strips: how many horizontal strips to evaluate
539
+ # - min_conf: minimum average confidence to consider a region valid
540
+
541
+ # Returns:
542
+ # - crop_region: the cropped image region with highest OCR confidence
543
+ # - region_box: (y_start, y_end) of the region in the original frame
544
+ # """
545
+ # height, width, _ = frame.shape
546
+ # region_height = int(height * region_height_ratio)
547
+ # base_y_start = height - region_height
548
+ # strip_height = region_height // num_strips
549
+
550
+ # best_score = -1
551
+ # best_crop = None
552
+ # best_bounds = (0, height)
553
+
554
+ # for i in range(num_strips):
555
+ # y_start = base_y_start + i * strip_height
556
+ # y_end = y_start + strip_height
557
+ # strip = frame[y_start:y_end, :]
558
+
559
+ # try:
560
+ # result = ocr_model.ocr(strip, cls=True)
561
+ # if not result or not result[0]:
562
+ # continue
563
 
564
+ # total_score = sum(line[1][1] for line in result[0])
565
+ # avg_score = total_score / len(result[0])
566
 
567
+ # if avg_score > best_score:
568
+ # best_score = avg_score
569
+ # best_crop = strip
570
+ # best_bounds = (y_start, y_end)
571
 
572
+ # except Exception as e:
573
+ # continue # Fail silently on OCR issues
574
 
575
+ # if best_score >= min_conf and best_crop is not None:
576
+ # return best_crop, best_bounds
577
+ # else:
578
+ # # Fallback to center-bottom strip
579
+ # fallback_y = height - int(height * 0.2)
580
+ # return frame[fallback_y:, :], (fallback_y, height)
581
 
582
+ # def ocr_frame_worker(args, min_confidence=0.7):
583
+ # frame_idx, frame_time, frame = args
584
+
585
+ # init_ocr_model() # Load model in thread-safe way
586
+
587
+ # if frame is None or frame.size == 0 or not isinstance(frame, np.ndarray):
588
+ # return {"time": frame_time, "text": ""}
589
+
590
+ # if frame.dtype != np.uint8:
591
+ # frame = frame.astype(np.uint8)
592
+
593
+ # try:
594
+ # result = ocr_model.ocr(frame, cls=True)
595
+ # lines = result[0] if result else []
596
+ # texts = [line[1][0] for line in lines if line[1][1] >= min_confidence]
597
+ # combined_text = " ".join(texts).strip()
598
+ # return {"time": frame_time, "text": combined_text}
599
+ # except Exception as e:
600
+ # print(f"⚠️ OCR failed at {frame_time:.2f}s: {e}")
601
+ # return {"time": frame_time, "text": ""}
602
+
603
+ # def frame_is_in_audio_segments(frame_time, audio_segments, tolerance=0.2):
604
+ # for segment in audio_segments:
605
+ # start, end = segment["start"], segment["end"]
606
+ # if (start - tolerance) <= frame_time <= (end + tolerance):
607
+ # return True
608
+ # return False
609
+
610
+ # def extract_ocr_subtitles_parallel(video_path, transcription_json, interval_sec=0.5, num_workers=4):
611
+ # cap = cv2.VideoCapture(video_path)
612
+ # fps = cap.get(cv2.CAP_PROP_FPS)
613
+ # frames = []
614
+ # frame_idx = 0
615
+ # success, frame = cap.read()
616
+
617
+ # while success:
618
+ # if frame_idx % int(fps * interval_sec) == 0:
619
+ # frame_time = frame_idx / fps
620
+ # if frame_is_in_audio_segments(frame_time, transcription_json):
621
+ # frames.append((frame_idx, frame_time, frame.copy()))
622
+ # success, frame = cap.read()
623
+ # frame_idx += 1
624
+ # cap.release()
625
+
626
+ # ocr_results = []
627
+ # ocr_failures = 0 # Count OCR failures
628
+ # with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
629
+ # futures = [executor.submit(ocr_frame_worker, frame) for frame in frames]
630
+
631
+ # for f in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
632
+ # try:
633
+ # result = f.result()
634
+ # if result["text"]:
635
+ # ocr_results.append(result)
636
+ # except Exception as e:
637
+ # ocr_failures += 1
638
+
639
+ # logger.info(f"✅ OCR extraction completed: {len(ocr_results)} frames successful, {ocr_failures} frames failed.")
640
+ # return ocr_results
641
+
642
+ # def collapse_ocr_subtitles(ocr_json, text_similarity_threshold=90):
643
+ # collapsed = []
644
+ # current = None
645
+ # for entry in ocr_json:
646
+ # time = entry["time"]
647
+ # text = entry["text"]
648
+
649
+ # if not current:
650
+ # current = {"start": time, "end": time, "text": text}
651
+ # continue
652
+
653
+ # sim = fuzz.ratio(current["text"], text)
654
+ # if sim >= text_similarity_threshold:
655
+ # current["end"] = time
656
+ # logger.debug(f"MERGED: Current end extended to {time:.2f}s for text: '{current['text'][:50]}...' (Similarity: {sim})")
657
+ # else:
658
+ # logger.debug(f"NOT MERGING (Similarity: {sim} < Threshold: {text_similarity_threshold}):")
659
+ # logger.debug(f" Previous segment: {current['start']:.2f}s - {current['end']:.2f}s: '{current['text'][:50]}...'")
660
+ # logger.debug(f" New segment: {time:.2f}s: '{text[:50]}...'")
661
+ # collapsed.append(current)
662
+ # current = {"start": time, "end": time, "text": text}
663
+ # if current:
664
+ # collapsed.append(current)
665
+
666
+ # logger.info(f"✅ OCR subtitles collapsed into {len(collapsed)} segments.")
667
+ # for idx, seg in enumerate(collapsed):
668
+ # logger.debug(f"[OCR Collapsed {idx}] {seg['start']:.2f}s - {seg['end']:.2f}s: {seg['text'][:50]}...")
669
+ # return collapsed
670
+
671
+ # def merge_speaker_and_time_from_whisperx(
672
+ # ocr_json,
673
+ # whisperx_json,
674
+ # replace_threshold=90,
675
+ # time_tolerance=1.0
676
+ # ):
677
+ # merged = []
678
+ # used_whisperx = set()
679
+ # whisperx_used_flags = [False] * len(whisperx_json)
680
 
681
+ # # Step 1: Attempt to match each OCR entry to a WhisperX entry
682
+ # for ocr in ocr_json:
683
+ # ocr_start, ocr_end = ocr["start"], ocr["end"]
684
+ # ocr_text = ocr["text"]
685
 
686
+ # best_match = None
687
+ # best_score = -1
688
+ # best_idx = None
689
 
690
+ # for idx, wx in enumerate(whisperx_json):
691
+ # wx_start, wx_end = wx["start"], wx["end"]
692
+ # wx_text = wx["text"]
693
 
694
+ # # Check for time overlap
695
+ # overlap = not (ocr_end < wx_start - time_tolerance or ocr_start > wx_end + time_tolerance)
696
+ # if not overlap:
697
+ # continue
 
 
698
 
699
+ # sim = fuzz.ratio(ocr_text, wx_text)
700
+ # if sim > best_score:
701
+ # best_score = sim
702
+ # best_match = wx
703
+ # best_idx = idx
704
 
705
+ # if best_match and best_score >= replace_threshold:
706
+ # # Replace WhisperX segment with higher quality OCR text
707
+ # new_segment = copy.deepcopy(best_match)
708
+ # new_segment["text"] = ocr_text
709
+ # new_segment["ocr_replaced"] = True
710
+ # new_segment["ocr_similarity"] = best_score
711
+ # whisperx_used_flags[best_idx] = True
712
+ # merged.append(new_segment)
713
+ # else:
714
+ # # No replacement, check if this OCR is outside WhisperX time coverage
715
+ # covered = any(
716
+ # abs((ocr_start + ocr_end)/2 - (wx["start"] + wx["end"])/2) < time_tolerance
717
+ # for wx in whisperx_json
718
+ # )
719
+ # if not covered:
720
+ # new_segment = copy.deepcopy(ocr)
721
+ # new_segment["ocr_added"] = True
722
+ # new_segment["speaker"] = "UNKNOWN"
723
+ # merged.append(new_segment)
724
+
725
+ # # Step 2: Add untouched WhisperX segments
726
+ # for idx, wx in enumerate(whisperx_json):
727
+ # if not whisperx_used_flags[idx]:
728
+ # merged.append(wx)
729
+
730
+ # # Step 3: Sort all merged segments
731
+ # merged = sorted(merged, key=lambda x: x["start"])
732
 
733
+ # return merged
734
 
735
+ def process_segment_with_gpt(segment, source_lang, target_lang, model="gpt-4"):
736
+ original_text = segment["text"]
737
+ prompt = (
738
+ f"You are a multilingual assistant. Given the following text in {source_lang}, "
739
+ f"1) restore punctuation, and 2) translate it into {target_lang}.\n\n"
740
+ f"Text:\n{original_text}\n\n"
741
+ f"Return in JSON format:\n"
742
+ f'{{"punctuated": "...", "translated": "..."}}'
743
+ )
744
 
745
+ try:
746
+ response = openai.ChatCompletion.create(
747
+ model=model,
748
+ messages=[{"role": "user", "content": prompt}],
749
+ temperature=0.3
750
+ )
751
+ content = response.choices[0].message.content.strip()
752
+ result_json = eval(content) if content.startswith("{") else {}
753
 
754
+ return {
755
+ "start": segment["start"],
756
+ "end": segment["end"],
757
+ "speaker": segment.get("speaker", "SPEAKER_00"),
758
+ "original": result_json.get("punctuated", original_text),
759
+ "translated": result_json.get("translated", "")
760
+ }
761
 
762
+ except Exception as e:
763
+ print(f"❌ Error for segment {segment['start']}-{segment['end']}: {e}")
764
+ return {
765
+ "start": segment["start"],
766
+ "end": segment["end"],
767
+ "speaker": segment.get("speaker", "SPEAKER_00"),
768
+ "original": original_text,
769
+ "translated": ""
770
+ }
771
 
772
+ def punctuate_and_translate_parallel(transcription_json, source_lang="zh", target_lang="en", model="gpt-4o", max_workers=5):
773
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
774
+ futures = [
775
+ executor.submit(process_segment_with_gpt, seg, source_lang, target_lang, model)
776
+ for seg in transcription_json
777
+ ]
778
+ return [f.result() for f in concurrent.futures.as_completed(futures)]
779
+
780
  # def merge_speaker_and_time_from_whisperx(ocr_json, whisperx_json, text_sim_threshold=80, replace_threshold=90):
781
  # merged = []
782
  # used_whisperx = set()
 
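process_segment_with_gpt() above parses the model reply with eval(); a minimal sketch of a stricter parse under the same assumptions (the reply is the {"punctuated": ..., "translated": ...} object requested in the prompt; parse_gpt_reply is a hypothetical helper, not part of this commit):

import json

def parse_gpt_reply(content: str) -> dict:
    # Expect a JSON object such as {"punctuated": "...", "translated": "..."}.
    try:
        parsed = json.loads(content)
        return parsed if isinstance(parsed, dict) else {}
    except json.JSONDecodeError:
        # Malformed reply: return an empty dict so the caller falls back to the original text.
        return {}

With this in place, result_json.get("punctuated", original_text) and result_json.get("translated", "") fall back the same way the committed code does when the reply is not valid JSON.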
1177
  transcription_json, source_language = transcribe_video_with_speakers(file.name)
1178
  logger.info(f"Transcription completed. Detected source language: {source_language}")
1179
 
1180
+ translated_json_raw = punctuate_and_translate_parallel(transcription_json, source_language, target_language)
1181
  # Step 2: Translate the transcription
1182
+ # logger.info(f"Translating transcription from {source_language} to {target_language}...")
1183
+ # translated_json_raw = translate_text(transcription_json_merged, )
1184
  logger.info(f"Translation completed. Number of translated segments: {len(translated_json_raw)}")
1185
 
1186
  translated_json = apply_adaptive_speed(translated_json_raw, source_language, target_language)
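A minimal usage sketch of the new entry point (the two segments and the model choice are made-up values). Note that concurrent.futures.as_completed() yields results in completion order, so the caller may want to restore chronological order, as the removed translate_text() did:

segments = [
    {"start": 0.0, "end": 2.5, "speaker": "SPEAKER_00", "text": "大家好 欢迎收看"},
    {"start": 2.5, "end": 5.0, "speaker": "SPEAKER_01", "text": "今天我们聊聊字幕"},
]
translated = punctuate_and_translate_parallel(segments, source_lang="zh", target_lang="en", model="gpt-4o")
translated.sort(key=lambda seg: seg["start"])  # re-impose segment order after parallel completion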