import json import os import re from typing import List, Dict, Any, Callable, Tuple def map_special_tokens_to_word_positions(text: str, word_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: special_token_map: List[Dict[str, Any]] = [] for m in re.finditer(r'<[^>]*?>', text): special_token_map.append({ "token": m.group(), "char_start": m.start(), # index in original text }) if not special_token_map: return [] visible_offset_map = {} visible_idx = 0 i = 0 while i < len(text): if text[i] == '<': j = text.find('>', i) + 1 i = j continue visible_offset_map[i] = visible_idx visible_idx += 1 i += 1 clean_text = re.sub(r'<[^>]*?>', '', text) # locate each word in clean_text word_positions = [] cur = 0 for w in word_list: pos = clean_text.find(w["word"], cur) if pos != -1: word_positions.append({ "word": w["word"], "start": pos, "end": pos + len(w["word"]) }) cur = pos + len(w["word"]) # map each token for sp in special_token_map: # how many visible chars are before this token? raw_idx = sp["char_start"] visible_before = 0 # find largest key <= raw_idx in visible_offset_map keys = [k for k in visible_offset_map.keys() if k < raw_idx] if keys: visible_before = visible_offset_map[max(keys)] + 1 # +1 because map stores idx of char at k insert_after = -1 for i, wp in enumerate(word_positions): if visible_before >= wp["end"]: insert_after = i else: break sp["insert_after_word_idx"] = insert_after return special_token_map def reorganize_transcription_c_unit( session_id: str, segment_func: Callable[[str], List[int]], base_dir: str = "session_data", device: str = "cuda" ) -> Tuple[int, int]: """Segment utterances into C-units with rules: 1. Boundaries inside or are ignored. 2. Trailing moves to next C-unit prefix. Returns (total_cunit_count, ignored_boundary_count). """ session_dir = os.path.join(base_dir, session_id) input_file = os.path.join(session_dir, "transcription.json") output_file = os.path.join(session_dir, "transcription_cunit.json") if not os.path.exists(input_file): raise FileNotFoundError(input_file) with open(input_file, "r", encoding="utf-8") as f: data = json.load(f) # Handle both old and new format if "segments" in data: transcription_data = data["segments"] else: transcription_data = data cunit_data: List[Dict[str, Any]] = [] ignored_boundary_count = 0 for utt in transcription_data: original_text = utt["text"] words_meta = utt.get("words", []) clean_text = re.sub(r'<[^>]*?>', '', original_text).strip() if not clean_text: continue # build word list if words_meta: word_data = [w for w in words_meta if w["word"] not in {"?", ",", ".", "!"}] word_texts = [w["word"] for w in word_data] else: word_texts = re.sub(r'[\?\.,!]', '', clean_text).split() word_data = [{"word": w, "start": utt["start"], "end": utt["end"]} for w in word_texts] if not word_texts: continue # token positions & special ranges special_token_map = map_special_tokens_to_word_positions(original_text, word_data) rep_ranges, rev_ranges = _build_special_ranges(special_token_map) def inside_special(idx: int) -> bool: return any(s <= idx <= e for s, e in rep_ranges) or any(s <= idx <= e for s, e in rev_ranges) # segmentation labels labels = segment_func(' '.join(word_texts)) if len(labels) != len(word_texts): raise ValueError( f"Segmentation length mismatch: {len(word_texts)} words vs {len(labels)} labels" ) current_words: List[str] = [] current_meta: List[Dict[str, Any]] = [] cunit_start_idx = 0 # global word idx of first word in current c‑unit cunit_start_time = word_data[0]["start"] carry_over_tokens: List[str] = [] for i, (word, label) in enumerate(zip(word_texts, labels)): current_words.append(word) current_meta.append(word_data[i]) is_last_word = i == len(word_texts) - 1 boundary_from_model = label == 1 and not inside_special(i) if label == 1 and inside_special(i): ignored_boundary_count += 1 make_boundary = boundary_from_model or is_last_word if not make_boundary: continue # -------- assemble C‑unit -------- text_parts: List[str] = [] # 2a. prefix: carried‑over if carry_over_tokens: text_parts.extend(carry_over_tokens) carry_over_tokens = [] for j, w in enumerate(current_words): global_word_idx = cunit_start_idx + j # sentence‑initial tokens & ‑1 insertion if global_word_idx == 0: text_parts.extend( [sp["token"] for sp in special_token_map if sp["insert_after_word_idx"] == -1] ) text_parts.append(w) # tokens that follow this word text_parts.extend( [sp["token"] for sp in special_token_map if sp["insert_after_word_idx"] == global_word_idx] ) # 2b. move trailing to next c‑unit while text_parts and text_parts[-1].upper() == '': carry_over_tokens.insert(0, text_parts.pop()) # 2c. move trailing or to next c‑unit while text_parts and text_parts[-1].upper() in {'', ''}: carry_over_tokens.insert(0, text_parts.pop()) # Create text_token (with special tokens) and text (only words) text_token = ' '.join(text_parts) text_words_only = ' '.join(current_words) cunit_data.append({ "start": cunit_start_time, "end": current_meta[-1]["end"], "speaker": "", # Initialize as empty "text_token": text_token, "text": text_words_only, "words": [ { "word": word["word"], "start": word["start"], "end": word["end"] } for word in current_meta ] }) # reset for next C‑unit cunit_start_idx = i + 1 current_words, current_meta = [], [] if cunit_start_idx < len(word_data): cunit_start_time = word_data[cunit_start_idx]["start"] # Wrap in segments structure to match original format output_data = { "segments": cunit_data } with open(output_file, "w", encoding="utf-8") as f: json.dump(output_data, f, indent=2, ensure_ascii=False) print(f"C-unit segmentation done → {output_file}") return len(cunit_data), ignored_boundary_count def _build_special_ranges(special_token_map: List[Dict[str, Any]]): rep_ranges, rev_ranges = [], [] rep_start, rev_start = None, None for sp in special_token_map: tok = sp["token"].upper() idx = sp["insert_after_word_idx"] if tok == '': rep_start = idx + 1 elif tok == '' and rep_start is not None: rep_ranges.append((rep_start, idx)) rep_start = None elif tok == '': rev_start = idx + 1 elif tok == '' and rev_start is not None: rev_ranges.append((rev_start, idx)) rev_start = None return rep_ranges, rev_ranges