|
|
|
|
|
from wtpsplit import SaT |
|
from typing import List |
|
import torch |
|
|
|
|
|
|
|
# Process-wide SaT model cache: starts as None and is populated on the first
# call to get_sat_model(), which returns this same object on later calls.
_sat_model = None
|
|
|
|
|
def get_sat_model(
    model_name: str = "sat-3l",
    device: str = "cuda",
    lora_path: str = "models/SaT_cunit_with_maze/model_finetuned/sat-3l_r64a128_lora_ENNI/enni-salt/en",
) -> SaT:
    """Get or create the global SaT model instance.

    The model is loaded once and cached in the module-level ``_sat_model``.
    Subsequent calls return the cached instance and ignore their arguments.

    Args:
        model_name: Base model name from segment-any-text, passed to ``SaT``.
        device: Requested device. Any value starting with ``"cuda"`` (e.g.
            ``"cuda"`` or ``"cuda:1"``) moves the model to that device in half
            precision when CUDA is available; otherwise the model stays on CPU.
        lora_path: Path of the LoRA adapter loaded on top of the base model.
            The default is a relative path, so it is resolved against the
            current working directory. Its directory name suggests it was
            trained for ``sat-3l`` -- presumably it must match ``model_name``;
            confirm before changing one without the other.

    Returns:
        SaT model instance
    """
    global _sat_model

    if _sat_model is None:
        print(f"Loading SaT model: {model_name}")

        # Configure a local instance first and publish it to the cache only
        # after device placement succeeds, so a failure cannot leave a
        # half-configured model cached for every later caller.
        model = SaT(model_name, lora_path=lora_path)

        if str(device).startswith("cuda") and torch.cuda.is_available():
            model.half().to(device)
            print("SaT model loaded on GPU")
        else:
            print("SaT model loaded on CPU")

        _sat_model = model

    return _sat_model
|
|
|
|
|
|
|
|
|
|
|
def segment_SaT(text: str, sat_model: "SaT | None" = None) -> List[int]:
    """Segment text using the wtpsplit SaT model, labelling c-unit ends per word.

    The text is lower-cased and stripped of ``.`` and ``,`` before both
    tokenisation and inference. The labels therefore align with
    ``text.lower().replace(".", "").replace(",", "").split()``. That list can
    be shorter than ``text.split()`` when a token consists only of ``.``/``,``.

    Args:
        text: Input text to segment.
        sat_model: Optional model to use. Any object with a
            ``split(str) -> list[str]`` method works. Defaults to the shared
            instance from ``get_sat_model()``.

    Returns:
        List of labels, one per cleaned word:
        0 = word is not the last word of a c-unit,
        1 = word is the last word of a c-unit.
        If the model call raises, the error is printed and all labels are 0.
    """
    from bisect import bisect_left

    cleaned_text = text.lower().replace(".", "").replace(",", "")
    words = cleaned_text.split()
    if not words:
        return []

    # Character offset at which each word starts in cleaned_text; used to map
    # segment end positions back onto word indices.
    word_starts = []
    cursor = 0
    for word in words:
        cursor = cleaned_text.find(word, cursor)
        word_starts.append(cursor)
        cursor += len(word)

    if sat_model is None:
        sat_model = get_sat_model()

    try:
        segments = sat_model.split(cleaned_text)
    except Exception as e:
        # Deliberate best-effort: a model failure yields "no boundaries"
        # rather than crashing the caller.
        print(f"Error in SaT segmentation: {e}")
        return [0] * len(words)

    word_labels = [0] * len(words)
    cursor = 0
    for segment in segments:
        piece = segment.strip()
        if not piece:
            continue
        # Locate the segment in the text we sent to the model. Aligning by
        # character position (rather than counting whitespace tokens per
        # segment) keeps labels correct even when a boundary falls inside a
        # word; counting would double-count that word and shift every later
        # label.
        start = cleaned_text.find(piece, cursor)
        if start == -1:
            # Segment text not found verbatim; skip it rather than guess.
            continue
        end = start + len(piece)
        # Mark the last word that starts before the segment's end. If the
        # boundary falls mid-word, the word containing it is marked.
        word_labels[bisect_left(word_starts, end) - 1] = 1
        cursor = end

    return word_labels
|
|
|
|
|
|
|
|
|
def reorganize_transcription_c_unit(session_id, base_dir="session_data"):
    """Stub with no implementation yet.

    Judging by its name, this is intended to regroup a session's transcription
    by c-unit. Currently both arguments are accepted but unused, and the
    function does nothing and returns ``None``.
    """
    return None
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: segment a sample narrative and print the resulting
    # c-units. Running it loads the SaT model via segment_SaT().
    sample_text = "once a horse met elephant and then they saw a ball in a pool and then the horse tried to swim and get the ball they might be the same but they are doing something what do you think they are doing"

    print(f"Input text: {sample_text}")
    print(f"Words: {sample_text.split()}")

    boundary_labels = segment_SaT(sample_text)
    print(f"Segment labels: {boundary_labels}")

    # Rebuild c-units from the labels. As with zip(), only the first
    # min(len(tokens), len(boundary_labels)) words take part.
    tokens = sample_text.split()
    usable = min(len(tokens), len(boundary_labels))
    cut_points = [
        idx + 1
        for idx, (_, label) in enumerate(zip(tokens, boundary_labels))
        if label == 1
    ]

    pieces = []
    begin = 0
    for end in cut_points:
        pieces.append(" ".join(tokens[begin:end]))
        begin = end
    if begin < usable:
        # Trailing words after the last boundary form a final, open c-unit.
        pieces.append(" ".join(tokens[begin:usable]))

    print("\nSegmented text:")
    for number, piece in enumerate(pieces, start=1):
        print(f"Segment {number}: {piece}")