qqwjq1981 committed
Commit f57819a · verified · 1 Parent(s): 9f5dde4

Update app.py

Files changed (1): app.py +108 −13
app.py CHANGED
@@ -731,9 +731,32 @@ def solve_optimal_alignment(original_segments, generated_durations, total_duration):
     # merged = sorted(merged, key=lambda x: x["start"])

     # return merged
+# --- Function Definitions ---
+
+def process_segment_with_gpt(segment, source_lang, target_lang, model="gpt-4", openai_client=None):
+    """
+    Processes a single text segment: restores punctuation and translates using an OpenAI GPT model.
+    """
+    # Essential check: Ensure the OpenAI client is provided
+    if openai_client is None:
+        segment_identifier = f"{segment.get('start', 'N/A')}-{segment.get('end', 'N/A')}"
+        logger.error(f"❌ OpenAI client was not provided for segment {segment_identifier}. Cannot process.")
+        return {
+            "start": segment.get("start"),
+            "end": segment.get("end"),
+            "speaker": segment.get("speaker", "SPEAKER_00"),
+            "original": segment["text"],
+            "translated": "[ERROR: OpenAI client not provided]"
+        }

-def process_segment_with_gpt(segment, source_lang, target_lang, model="gpt-4"):
     original_text = segment["text"]
+    segment_id = f"{segment['start']}-{segment['end']}" # Create a unique ID for this segment for easier log tracking
+
+    logger.debug(
+        f"Starting processing for segment {segment_id}. "
+        f"Original text preview: '{original_text[:100]}{'...' if len(original_text) > 100 else ''}'"
+    )
+
     prompt = (
         f"You are a multilingual assistant. Given the following text in {source_lang}, "
         f"1) restore punctuation, and 2) translate it into {target_lang}.\n\n"
@@ -743,39 +766,111 @@ def process_segment_with_gpt(segment, source_lang, target_lang, model="gpt-4"):
     )

     try:
-        response = client.chat.completions.create(
+        logger.debug(f"Sending request to OpenAI model '{model}' for segment {segment_id}...")
+        response = openai_client.chat.completions.create( # Using the passed 'openai_client'
             model=model,
             messages=[{"role": "user", "content": prompt}],
             temperature=0.3
         )
         content = response.choices[0].message.content.strip()
-        result_json = eval(content) if content.startswith("{") else {}
+        logger.debug(
+            f"Received raw response from model for segment {segment_id}: "
+            f"'{content[:200]}{'...' if len(content) > 200 else ''}'" # Truncate for log readability
+        )
+
+        result_json = {}
+        try:
+            # Use json.loads for safer and standard JSON parsing compared to eval()
+            result_json = json.loads(content)
+        except json.JSONDecodeError as e:
+            logger.warning(
+                f"⚠️ Failed to parse JSON response for segment {segment_id}. Error: {e}. "
+                f"Raw content received: '{content}'"
+            )
+            # Fallback behavior if JSON parsing fails: use original text, empty translation
+            punctuated_text = original_text
+            translated_text = ""
+        else:
+            # If JSON parsing was successful
+            punctuated_text = result_json.get("punctuated", original_text)
+            translated_text = result_json.get("translated", "")

+        logger.info(
+            f"✅ Successfully processed segment {segment_id}. "
+            f"Punctuated preview: '{punctuated_text[:50]}{'...' if len(punctuated_text) > 50 else ''}', "
+            f"Translated preview: '{translated_text[:50]}{'...' if len(translated_text) > 50 else ''}'"
+        )
         return {
             "start": segment["start"],
             "end": segment["end"],
             "speaker": segment.get("speaker", "SPEAKER_00"),
-            "original": result_json.get("punctuated", original_text),
-            "translated": result_json.get("translated", "")
+            "original": punctuated_text,
+            "translated": translated_text
         }
-
     except Exception as e:
-        print(f"❌ Error for segment {segment['start']}-{segment['end']}: {e}")
+        # Log the full traceback using exc_info=True for better debugging
+        logger.error(
+            f"❌ An unexpected error occurred while processing segment {segment_id}: {e}",
+            exc_info=True
+        )
+        # Return the original segment with an empty translated text on error
         return {
             "start": segment["start"],
             "end": segment["end"],
             "speaker": segment.get("speaker", "SPEAKER_00"),
             "original": original_text,
-            "translated": ""
+            "translated": "[ERROR: Processing failed]"
         }

-def punctuate_and_translate_parallel(transcription_json, source_lang="zh", target_lang="en", model="gpt-4o", max_workers=5):
+def punctuate_and_translate_parallel(transcription_json, source_lang="zh", target_lang="en", model="gpt-4o", max_workers=5, openai_client=None):
+    """
+    Orchestrates parallel punctuation restoration and translation of multiple segments
+    using a ThreadPoolExecutor.
+    """
+    if not transcription_json:
+        logger.warning("No segments provided in transcription_json for parallel processing. Returning an empty list.")
+        return []
+
+    logger.info(f"Starting parallel punctuation and translation for {len(transcription_json)} segments.")
+    logger.info(
+        f"Configuration: Model='{model}', Source Language='{source_lang}', "
+        f"Target Language='{target_lang}', Max Workers={max_workers}."
+    )
+
+    results = []
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [
-            executor.submit(process_segment_with_gpt, seg, source_lang, target_lang, model)
+        # Submit each segment for processing, ensuring the openai_client is passed to each worker
+        futures = {
+            executor.submit(process_segment_with_gpt, seg, source_lang, target_lang, model, openai_client): seg
             for seg in transcription_json
-        ]
-        return [f.result() for f in concurrent.futures.as_completed(futures)]
+        }
+        logger.info(f"All {len(futures)} segments have been submitted to the thread pool.")
+
+        # Asynchronously collect results as they complete
+        for i, future in enumerate(concurrent.futures.as_completed(futures)):
+            segment = futures[future] # Retrieve the original segment data for logging context
+            segment_id = f"{segment['start']}-{segment['end']}"
+            try:
+                result = future.result() # This will re-raise any exception from the worker thread
+                results.append(result)
+                logger.debug(f"Collected result for segment {segment_id}. ({i + 1}/{len(futures)} completed)")
+            except Exception as exc:
+                # This catch block is for rare cases where the future itself fails to yield a result,
+                # or an exception was not caught within `process_segment_with_gpt`.
+                logger.error(
+                    f"Unhandled exception encountered while retrieving result for segment {segment_id}: {exc}",
+                    exc_info=True
+                )
+                # Ensure a placeholder result is added even if future retrieval fails
+                results.append({
+                    "start": segment.get("start"),
+                    "end": segment.get("end"),
+                    "speaker": segment.get("speaker", "SPEAKER_00"),
+                    "original": segment["text"],
+                    "translated": "[ERROR: Unhandled exception in parallel processing]"
+                })
+    logger.info("🎉 Parallel processing complete. All results collected.")
+    return results

 # def merge_speaker_and_time_from_whisperx(ocr_json, whisperx_json, text_sim_threshold=80, replace_threshold=90):
 #     merged = []
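Note (not part of the commit): the updated signatures require callers to pass the OpenAI client explicitly instead of relying on a module-level `client`. The sketch below shows one plausible way to wire this up. It assumes the openai>=1.x SDK, an OPENAI_API_KEY environment variable, and that the functions in app.py can be imported as a module; the `segments` sample data is purely illustrative.

# Hypothetical caller sketch for the new function signatures (illustrative only).
import json
import logging

from openai import OpenAI

from app import punctuate_and_translate_parallel  # assumption: app.py is importable as a module

logging.basicConfig(level=logging.INFO)  # make app.py's logger output visible

# Example WhisperX-style segments; timings and text are made up for illustration.
segments = [
    {"start": 0.0, "end": 2.5, "speaker": "SPEAKER_00", "text": "你好 今天我们介绍视频翻译流程"},
    {"start": 2.5, "end": 5.1, "speaker": "SPEAKER_01", "text": "先做标点恢复 然后翻译成英文"},
]

client = OpenAI()  # reads OPENAI_API_KEY from the environment

results = punctuate_and_translate_parallel(
    segments,
    source_lang="zh",
    target_lang="en",
    model="gpt-4o",
    max_workers=5,
    openai_client=client,  # the explicit dependency introduced by this commit
)

# Results arrive in completion order (as_completed), so sort by start time before use.
results.sort(key=lambda seg: seg["start"])
print(json.dumps(results, ensure_ascii=False, indent=2))

For each segment the model is expected to reply with a JSON object of the form {"punctuated": "...", "translated": "..."}; any other reply falls back to the original text with an empty translation, as handled inside process_segment_with_gpt.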