# Segmentation function using wtpsplit SaT 3l full fine-tuned model

from wtpsplit import SaT
from typing import List
import torch
import torch.nn as nn

# Global SaT model instance (lazy loading)
_sat_model = None


def get_sat_model(model_name: str = "sat-3l", device: str = "cuda") -> SaT:
    """
    Get or create the global SaT model instance.

    Args:
        model_name: Model name from segment-any-text
        device: Device to run the model on

    Returns:
        SaT model instance
    """
    global _sat_model
    if _sat_model is None:
        print(f"Loading SaT 3l full fine-tuned model: {model_name}")

        # Load model with full fine-tuned weights:
        # first load the base model, then load the fine-tuned weights
        _sat_model = SaT(model_name)

        # Load the fine-tuned weights
        model_path = "models/SaT_cunit_with_maze/model_finetuned/sat-3l_full_ENNI/pytorch_model.bin"
        state_dict = torch.load(model_path, map_location="cpu")

        # Remove "backbone." prefix from keys to match the expected model structure
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("backbone."):
                new_key = key[len("backbone."):]  # Remove "backbone." prefix
                new_state_dict[new_key] = value
            else:
                new_state_dict[key] = value

        # Adjust model sizes to match the fine-tuned model.
        # Check word embeddings size
        if "roberta.embeddings.word_embeddings.weight" in new_state_dict:
            fine_tuned_vocab_size = new_state_dict["roberta.embeddings.word_embeddings.weight"].shape[0]
            current_vocab_size = _sat_model.model.roberta.embeddings.word_embeddings.weight.shape[0]
            if fine_tuned_vocab_size != current_vocab_size:
                print(f"Resizing word embeddings from {current_vocab_size} to {fine_tuned_vocab_size}")
                _sat_model.model.resize_token_embeddings(fine_tuned_vocab_size)

        # Check classifier size
        if "classifier.weight" in new_state_dict:
            fine_tuned_num_labels = new_state_dict["classifier.weight"].shape[0]
            current_num_labels = _sat_model.model.classifier.weight.shape[0]
            if fine_tuned_num_labels != current_num_labels:
                print(f"Resizing classifier from {current_num_labels} to {fine_tuned_num_labels}")
                # Resize classifier
                _sat_model.model.classifier = nn.Linear(
                    _sat_model.model.classifier.in_features,
                    fine_tuned_num_labels,
                )
                _sat_model.model.num_labels = fine_tuned_num_labels

        _sat_model.model.load_state_dict(new_state_dict, strict=False)

        # Move to GPU if available and requested
        if device == "cuda" and torch.cuda.is_available():
            _sat_model.half().to("cuda")
            print("SaT 3l full model loaded on GPU")
        else:
            print("SaT 3l full model loaded on CPU")
    return _sat_model


# Input is a string of words with no punctuation, all lower case.
# Output is one label per word: 0 means the corresponding word is not the last
# word of a c-unit, 1 means it is. For example, labels [0, 0, 1, 0, 0, 1] for
# "he ran home then he ate" would mark "home" and "ate" as c-unit-final,
# i.e. c-units "he ran home" and "then he ate".
def segment_SaT(text: str) -> List[int]:
    """
    Segment text using the wtpsplit SaT 3l full fine-tuned model.

    Args:
        text: Input text to segment

    Returns:
        List of labels: 0 = word is not the last word of a c-unit,
        1 = word is the last word of a c-unit
    """
    if not text.strip():
        return []

    # Clean text (consistent with segment_batchalign)
    cleaned_text = text.lower().replace(".", "").replace(",", "")
    words = cleaned_text.strip().split()
    if not words:
        return []

    # Get SaT model
    sat_model = get_sat_model()

    # Use SaT to split the text into sentences
    try:
        sentences = sat_model.split(cleaned_text)

        # Convert sentence boundaries to word-level labels
        word_labels = [0] * len(words)

        # Track position in original text
        word_idx = 0
        for sentence in sentences:
            sentence_words = sentence.strip().split()
            # Mark the last word of each sentence as a segment boundary
            if sentence_words:
                # Find the last word of this sentence in the original word list
                sentence_end_idx = word_idx + len(sentence_words) - 1
                # Ensure we don't go out of bounds
                if sentence_end_idx < len(words):
                    word_labels[sentence_end_idx] = 1
                word_idx += len(sentence_words)
        return word_labels
    except Exception as e:
        print(f"Error in SaT 3l full segmentation: {e}")
        return [0] * len(words)


# Read an ASR transcription file, segment it into c-units, and save the result
# to a new JSON file (not yet implemented; see the hedged sketch below).
def reorganize_transcription_c_unit(session_id, base_dir="session_data"):
    return
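

# --- Hedged sketch -----------------------------------------------------------
# A minimal sketch of how reorganize_transcription_c_unit could be implemented.
# The file layout assumed here (an input "{session_id}.json" holding a single
# "transcription" string, output written to "{session_id}_c_unit.json") is an
# assumption for illustration only; adapt the paths and keys to the real ASR
# output format before using this.
def reorganize_transcription_c_unit_sketch(session_id, base_dir="session_data"):
    import json
    import os

    in_path = os.path.join(base_dir, f"{session_id}.json")          # assumed input path
    out_path = os.path.join(base_dir, f"{session_id}_c_unit.json")  # assumed output path

    with open(in_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # "transcription" is an assumed key for the raw ASR text
    text = data.get("transcription", "")

    # Clean the text the same way segment_SaT does, so words and labels align
    words = text.lower().replace(".", "").replace(",", "").split()
    labels = segment_SaT(text)

    # Group words into c-units at label == 1 boundaries
    c_units, current = [], []
    for word, label in zip(words, labels):
        current.append(word)
        if label == 1:
            c_units.append(" ".join(current))
            current = []
    if current:  # keep any trailing words without a closing boundary
        c_units.append(" ".join(current))

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"session_id": session_id, "c_units": c_units}, f, indent=2)
    return c_units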
if __name__ == "__main__":
    # Test the segmentation
    test_text = (
        "once a horse met elephant and then they saw a ball in a pool "
        "and then the horse tried to swim and get the ball they might be "
        "the same but they are doing something what do you think they are doing"
    )
    print(f"Input text: {test_text}")
    print(f"Words: {test_text.split()}")

    labels = segment_SaT(test_text)
    print(f"Segment labels: {labels}")

    # Show segmented text
    words = test_text.split()
    segments = []
    current_segment = []
    for word, label in zip(words, labels):
        current_segment.append(word)
        if label == 1:
            segments.append(" ".join(current_segment))
            current_segment = []
    # Add remaining words if any
    if current_segment:
        segments.append(" ".join(current_segment))

    print("\nSegmented text:")
    for i, segment in enumerate(segments, 1):
        print(f"Segment {i}: {segment}")