File size: 4,079 Bytes
5806e12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# Segmentation function using wtpsplit SaT model
from typing import List, Optional

import torch
from wtpsplit import SaT
# Default LoRA adapter directory loaded on top of the base SaT model
# (its directory name indicates an adapter for "sat-12l").
_DEFAULT_LORA_PATH = "models/SaT_cunit_with_maze/model_finetuned/sat-12l_lora_ENNI/enni-salt/en"

# Global SaT model instance (lazy loading); shared by every caller of get_sat_model().
_sat_model = None


def get_sat_model(
    model_name: str = "sat-12l",
    device: str = "cuda",
    lora_path: Optional[str] = _DEFAULT_LORA_PATH,
) -> SaT:
    """
    Get or create the global SaT model instance.

    The model is loaded on the first call only; later calls return the cached
    instance and ignore their arguments.

    Args:
        model_name: Base model name from segment-any-text (e.g. "sat-12l").
        device: Device to run the model on. Any "cuda" / "cuda:N" value moves
            the model to that device in half precision when CUDA is available;
            otherwise the model stays on CPU.
        lora_path: Directory of a LoRA adapter loaded on top of the base model,
            or None for the base model only. The default points at a "sat-12l"
            adapter, so pass a matching adapter (or None) for other base models.
    Returns:
        SaT model instance
    """
    global _sat_model
    if _sat_model is None:
        print(f"Loading SaT model: {model_name}")
        model = SaT(model_name, lora_path=lora_path)
        # Move to GPU if available and requested
        if device.startswith("cuda") and torch.cuda.is_available():
            model.half().to(device)
            print("SaT model loaded on GPU")
        else:
            print("SaT model loaded on CPU")
        # Publish only a fully initialised model, so a failed device move is
        # not cached for later calls.
        _sat_model = model
    return _sat_model
def _segments_to_word_labels(words: List[str], segments: List[str]) -> List[int]:
    """
    Convert ordered SaT segments into per-word end-of-c-unit labels.

    Alignment uses non-whitespace character offsets rather than word counts.
    A boundary that falls inside a word, or segments whose whitespace differs
    from the input, therefore cannot shift the labels of every later word.
    A boundary inside a word marks that word as the segment end.

    Args:
        words: Whitespace-separated words of the text that was segmented.
        segments: Segments returned for that same text, in order.
    Returns:
        One label per word: 1 if a segment ends on that word, else 0.
    """
    # owner[k] = index of the word containing the k-th non-whitespace character
    owner = [idx for idx, word in enumerate(words) for _ in word]
    labels = [0] * len(words)
    consumed = 0
    for segment in segments:
        size = sum(1 for ch in segment if not ch.isspace())
        if size == 0:
            continue  # whitespace-only segment: no boundary to place
        consumed += size
        if consumed > len(owner):
            break  # segments cover more text than the input; ignore the excess
        labels[owner[consumed - 1]] = 1
    return labels


def segment_SaT(text: str, *, sat_model=None) -> List[int]:
    """
    Segment text into c-units using the wtpsplit SaT model.

    The text is lower-cased and "." / "," are removed (consistent with
    segment_batchalign) before it is split on whitespace into words.

    Args:
        text: Input text to segment.
        sat_model: Optional object exposing ``split(str) -> list[str]`` (e.g. a
            loaded ``SaT``). Defaults to the shared model from ``get_sat_model()``.
    Returns:
        One label per cleaned word: 0 = word is not the last word of a c-unit,
        1 = word is the last word of a c-unit. Returns an empty list when the
        text has no words. If the model's split call raises, the error is
        printed and every label is 0.
    """
    if not text.strip():
        return []
    # Clean text (consistent with segment_batchalign)
    cleaned_text = text.lower().replace(".", "").replace(",", "")
    words = cleaned_text.split()
    if not words:
        return []
    if sat_model is None:
        sat_model = get_sat_model()
    try:
        segments = sat_model.split(cleaned_text)
    except Exception as e:
        # Best-effort: a segmentation failure yields "no boundaries" rather than a crash.
        print(f"Error in SaT segmentation: {e}")
        return [0] * len(words)
    return _segments_to_word_labels(words, segments)
def reorganize_transcription_c_unit(session_id, base_dir="session_data"):
    """
    Placeholder: not implemented yet.

    Intended to read the ASR transcription file for ``session_id`` (presumably
    located under ``base_dir``), segment it into c-units and save the result
    to a new JSON file. It currently reads and writes nothing and returns None.
    """
    return None
if __name__ == "__main__":
    # Manual smoke test of the segmentation on a sample transcript.
    test_text = "once a horse met elephant and then they saw a ball in a pool and then the horse tried to swim and get the ball they might be the same but they are doing something what do you think they are doing"
    print(f"Input text: {test_text}")
    print(f"Words: {test_text.split()}")
    labels = segment_SaT(test_text)
    print(f"Segment labels: {labels}")

    # Show segmented text: a label of 1 closes the current segment.
    words = test_text.split()
    segments = []
    current_segment = []
    for word, label in zip(words, labels):
        current_segment.append(word)
        if label == 1:
            segments.append(" ".join(current_segment))
            current_segment = []
    # Add remaining words after the last boundary, if any
    if current_segment:
        segments.append(" ".join(current_segment))

    print("\nSegmented text:")
    for i, segment in enumerate(segments, 1):
        print(f"Segment {i}: {segment}")