File size: 6,147 Bytes
5806e12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Segmentation function using wtpsplit SaT 3l full fine-tuned model

from wtpsplit import SaT
from typing import List
import torch


# Global SaT model instance (lazy loading): stays None until get_sat_model()
# builds the model on first use and stores it here for all later calls.
_sat_model: "SaT | None" = None


def get_sat_model(
    model_name: str = "sat-3l",
    device: str = "cuda",
    model_path: str = "models/SaT_cunit_with_maze/model_finetuned/sat-3l_full_ENNI/pytorch_model.bin",
) -> "SaT":
    """
    Get or create the global SaT model with the fine-tuned weights applied.

    The model is built on the first call and cached in the module-level
    ``_sat_model``; later calls return that instance and ignore their
    arguments. The cache is populated only after the fine-tuned weights have
    been loaded, so a failed load (e.g. missing checkpoint) is retried on the
    next call instead of silently leaving the base model cached.

    Args:
        model_name: Base model name from segment-any-text; it provides the
            architecture the fine-tuned weights are loaded into.
        device: "cuda" moves the model to GPU in half precision when CUDA is
            available; any other value (or no CUDA) keeps it on CPU.
        model_path: Path to the fine-tuned state dict (``pytorch_model.bin``).

    Returns:
        SaT model instance

    Raises:
        TypeError: If no torch module can be found inside ``SaT.model``.
    """
    global _sat_model

    if _sat_model is not None:
        return _sat_model

    print(f"Loading SaT 3l full fine-tuned model: {model_name}")
    # First load the base model, then load the fine-tuned weights into it.
    sat = SaT(model_name)
    # wtpsplit may wrap the Hugging Face module; attribute *assignments* on a
    # wrapper would not reach the real network, so operate on the module itself.
    net = _unwrap_torch_module(sat.model)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = _strip_backbone_prefix(torch.load(model_path, map_location="cpu"))
    _resize_to_checkpoint(net, state_dict)

    # strict=False tolerates key differences, but report them so a mismatched
    # checkpoint does not silently leave layers at their base-model values.
    result = net.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(
            f"Warning: {len(result.missing_keys)} model keys not found in checkpoint: "
            f"{result.missing_keys[:5]}"
        )
    if result.unexpected_keys:
        print(
            f"Warning: {len(result.unexpected_keys)} checkpoint keys not used by model: "
            f"{result.unexpected_keys[:5]}"
        )

    # Move to GPU if available and requested
    if device == "cuda" and torch.cuda.is_available():
        sat.half().to("cuda")
        print("SaT 3l full model loaded on GPU")
    else:
        print("SaT 3l full model loaded on CPU")

    # Cache only a fully initialised model.
    _sat_model = sat
    return _sat_model


def _strip_backbone_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading "backbone." removed from keys."""
    prefix = "backbone."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _unwrap_torch_module(model):
    """Return ``model`` if it is a torch module, else the torch module at ``model.model``."""
    if isinstance(model, torch.nn.Module):
        return model
    inner = getattr(model, "model", None)
    if isinstance(inner, torch.nn.Module):
        return inner
    raise TypeError(f"Cannot find a torch.nn.Module inside {type(model).__name__}")


def _resize_to_checkpoint(net, state_dict):
    """Resize ``net``'s word embeddings and classifier to the shapes in ``state_dict``.

    ``load_state_dict`` raises on shape mismatches even with ``strict=False``,
    so layers whose size differs must be resized before loading.
    """
    emb_key = "roberta.embeddings.word_embeddings.weight"
    if emb_key in state_dict:
        fine_tuned_vocab_size = state_dict[emb_key].shape[0]
        current_vocab_size = net.roberta.embeddings.word_embeddings.weight.shape[0]
        if fine_tuned_vocab_size != current_vocab_size:
            print(f"Resizing word embeddings from {current_vocab_size} to {fine_tuned_vocab_size}")
            net.resize_token_embeddings(fine_tuned_vocab_size)

    if "classifier.weight" in state_dict:
        fine_tuned_num_labels = state_dict["classifier.weight"].shape[0]
        current_num_labels = net.classifier.weight.shape[0]
        if fine_tuned_num_labels != current_num_labels:
            print(f"Resizing classifier from {current_num_labels} to {fine_tuned_num_labels}")
            # Fresh layer with the checkpoint's output size; the checkpoint's
            # classifier parameters are loaded into it afterwards.
            net.classifier = torch.nn.Linear(net.classifier.in_features, fine_tuned_num_labels)
            net.num_labels = fine_tuned_num_labels


# Input is a raw transcript string (it is lower-cased and "." / "," are removed).
# Output is one label per whitespace-separated word: 0 means the corresponding
# word is not the last word of a c-unit, 1 means it is the last word of a c-unit.
def segment_SaT(text: str, sat_model=None) -> List[int]:
    """
    Segment text into c-units using the wtpsplit SaT 3l full fine-tuned model.

    Args:
        text: Input text to segment. It is lower-cased and "." / "," are
            removed before segmentation (consistent with segment_batchalign).
        sat_model: Optional object exposing ``split(text) -> list[str]``.
            Defaults to the shared model returned by ``get_sat_model()``.

    Returns:
        List of labels, one per word of the cleaned text:
            0 = word is not the last word of c-unit,
            1 = word is the last word of c-unit.
        Returns [] for blank input, and all zeros if the model raises while
        splitting.
    """
    if not text.strip():
        return []

    # Clean text (consistent with segment_batchalign)
    cleaned_text = text.lower().replace(".", "").replace(",", "")
    words = cleaned_text.strip().split()
    if not words:
        return []

    if sat_model is None:
        sat_model = get_sat_model()

    try:
        sentences = sat_model.split(cleaned_text)
    except Exception as e:
        # Best-effort: report and fall back to "no boundaries" rather than crash.
        print(f"Error in SaT 3l full segmentation: {e}")
        return [0] * len(words)

    return _sentences_to_word_labels(words, sentences)


def _sentences_to_word_labels(words: List[str], sentences: List[str]) -> List[int]:
    """Map the sentence strings returned by SaT back onto ``words`` as end-of-unit labels.

    Alignment counts non-whitespace characters consumed, not words per
    sentence. A boundary SaT places inside a word (subword split), or
    whitespace it drops or moves, therefore cannot shift every later label.
    """
    labels = [0] * len(words)

    # word_ends[i] = number of non-whitespace characters up to the end of words[i]
    word_ends = []
    total = 0
    for word in words:
        total += len(word)
        word_ends.append(total)

    last = len(words) - 1
    consumed = 0
    idx = 0
    for sentence in sentences:
        chunk = "".join(sentence.split())
        if not chunk:
            continue
        consumed += len(chunk)
        # Advance to the word that contains this sentence's last character
        # (clamped to the final word if SaT returned extra characters).
        while idx < last and word_ends[idx] < consumed:
            idx += 1
        labels[idx] = 1

    return labels



# TODO: read the ASR transcription file, segment it into c-units and save the
# result to a new JSON file. Not implemented yet -- currently a no-op.
def reorganize_transcription_c_unit(session_id, base_dir="session_data"):
    """
    Placeholder for c-unit reorganisation of a session's ASR transcription.

    Args:
        session_id: Identifier of the session to process (currently unused).
        base_dir: Root directory holding session data (currently unused).

    Returns:
        None -- the function does not do any work yet.
    """
    return None



if __name__ == "__main__":
    # Manual smoke test: segment a sample narrative and print the c-units.
    sample = (
        "once a horse met elephant and then they saw a ball in a pool "
        "and then the horse tried to swim and get the ball they might be "
        "the same but they are doing something what do you think they are doing"
    )
    tokens = sample.split()

    print(f"Input text: {sample}")
    print(f"Words: {tokens}")

    labels = segment_SaT(sample)
    print(f"Segment labels: {labels}")

    # Pair each word with its label (zip stops at the shorter sequence), cut a
    # segment after every word labelled 1, and keep any trailing words as a
    # final segment.
    paired = list(zip(tokens, labels))
    segments = []
    start = 0
    for end, (_, flag) in enumerate(paired, start=1):
        if flag == 1:
            segments.append(" ".join(word for word, _ in paired[start:end]))
            start = end
    if start < len(paired):
        segments.append(" ".join(word for word, _ in paired[start:]))

    print("\nSegmented text:")
    for number, chunk in enumerate(segments, start=1):
        print(f"Segment {number}: {chunk}")