from typing import List

import torch
import torch.nn as nn
from wtpsplit import SaT

# Module-level cache so the SaT model is only loaded once per process.
_sat_model = None


def get_sat_model(model_name: str = "sat-3l", device: str = "cuda") -> SaT:
    """
    Get or create the global SaT model instance.

    Args:
        model_name: Model name from segment-any-text (wtpsplit)
        device: Device to run the model on ("cuda" or "cpu")

    Returns:
        SaT model instance
    """
    global _sat_model

    if _sat_model is None:
        print(f"Loading SaT 3l full fine-tuned model: {model_name}")

        # Start from the base SaT model, then overwrite its weights with the
        # fine-tuned checkpoint below.
        _sat_model = SaT(model_name)

        model_path = "models/SaT_cunit_with_maze/model_finetuned/sat-3l_full_ENNI/pytorch_model.bin"
        state_dict = torch.load(model_path, map_location="cpu")

        # The fine-tuned checkpoint prefixes parameter names with "backbone.";
        # strip that prefix so the keys match the wtpsplit model.
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("backbone."):
                new_state_dict[key[len("backbone."):]] = value
            else:
                new_state_dict[key] = value

        # Resize the word embeddings if the fine-tuned vocabulary differs from
        # the base model's (e.g. tokens were added during fine-tuning).
        if "roberta.embeddings.word_embeddings.weight" in new_state_dict:
            fine_tuned_vocab_size = new_state_dict["roberta.embeddings.word_embeddings.weight"].shape[0]
            current_vocab_size = _sat_model.model.roberta.embeddings.word_embeddings.weight.shape[0]
            if fine_tuned_vocab_size != current_vocab_size:
                print(f"Resizing word embeddings from {current_vocab_size} to {fine_tuned_vocab_size}")
                _sat_model.model.resize_token_embeddings(fine_tuned_vocab_size)

        # Rebuild the classifier head if the fine-tuned checkpoint uses a
        # different number of labels than the base model.
        if "classifier.weight" in new_state_dict:
            fine_tuned_num_labels = new_state_dict["classifier.weight"].shape[0]
            current_num_labels = _sat_model.model.classifier.weight.shape[0]
            if fine_tuned_num_labels != current_num_labels:
                print(f"Resizing classifier from {current_num_labels} to {fine_tuned_num_labels}")
                _sat_model.model.classifier = nn.Linear(
                    _sat_model.model.classifier.in_features,
                    fine_tuned_num_labels
                )
                _sat_model.model.num_labels = fine_tuned_num_labels

        _sat_model.model.load_state_dict(new_state_dict, strict=False)

        if device == "cuda" and torch.cuda.is_available():
            _sat_model.half().to("cuda")
            print("SaT 3l full model loaded on GPU")
        else:
            print("SaT 3l full model loaded on CPU")

    return _sat_model
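# Minimal usage sketch (illustrative only; it assumes the fine-tuned checkpoint
# path above exists locally and that CPU inference is acceptable):
#
#     model = get_sat_model("sat-3l", device="cpu")
#     sentences = model.split("the horse ran and then it jumped in the pool")
#     print(sentences)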
def segment_SaT(text: str) -> List[int]:
    """
    Segment text using the wtpsplit SaT 3l full fine-tuned model.

    Args:
        text: Input text to segment

    Returns:
        List of labels, one per word:
        0 = word is not the last word of a c-unit,
        1 = word is the last word of a c-unit
    """
    if not text.strip():
        return []

    # Lowercase the text and strip periods/commas before segmentation.
    cleaned_text = text.lower().replace(".", "").replace(",", "")
    words = cleaned_text.strip().split()
    if not words:
        return []

    sat_model = get_sat_model()

    try:
        # SaT returns a list of sentence strings; map each predicted sentence
        # boundary back onto word positions to produce per-word labels.
        sentences = sat_model.split(cleaned_text)

        word_labels = [0] * len(words)
        word_idx = 0

        for sentence in sentences:
            sentence_words = sentence.strip().split()

            if sentence_words:
                # The last word of each predicted sentence marks a c-unit boundary.
                sentence_end_idx = word_idx + len(sentence_words) - 1

                if sentence_end_idx < len(words):
                    word_labels[sentence_end_idx] = 1

                word_idx += len(sentence_words)

        return word_labels

    except Exception as e:
        print(f"Error in SaT 3l full segmentation: {e}")
        return [0] * len(words)
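# Illustrative example (the exact labels depend on the loaded model, so the
# output shown here is hypothetical):
#
#     segment_SaT("the horse ran and then it jumped")
#     # -> e.g. [0, 0, 1, 0, 0, 0, 1]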
def reorganize_transcription_c_unit(session_id, base_dir="session_data"):
    # Placeholder: reorganizing stored transcriptions into c-units is not implemented yet.
    return
if __name__ == "__main__":
    test_text = "once a horse met elephant and then they saw a ball in a pool and then the horse tried to swim and get the ball they might be the same but they are doing something what do you think they are doing"

    print(f"Input text: {test_text}")
    print(f"Words: {test_text.split()}")

    labels = segment_SaT(test_text)
    print(f"Segment labels: {labels}")

    # Rebuild the c-unit segments from the word-level labels: a label of 1
    # marks the last word of a segment.
    words = test_text.split()
    segments = []
    current_segment = []

    for word, label in zip(words, labels):
        current_segment.append(word)
        if label == 1:
            segments.append(" ".join(current_segment))
            current_segment = []

    # Any trailing words without a closing boundary form the final segment.
    if current_segment:
        segments.append(" ".join(current_segment))

    print("\nSegmented text:")
    for i, segment in enumerate(segments, 1):
        print(f"Segment {i}: {segment}")