# Segmentation function from Batchalign
import bisect
import functools
import itertools
import json
import os
import re

import nltk
import torch
from nltk.tokenize import sent_tokenize
from transformers import AutoModelForTokenClassification, AutoTokenizer

nltk.download('punkt_tab')
nltk.download('punkt')

# Hugging Face Hub id of the Batchalign utterance-segmentation model.
_MODEL_PATH = "talkbank/CHATUtterance-en"

# Model class id -> punctuation appended to the word. Class 1 capitalizes the
# word instead (see _tag_words); any other class leaves the word unchanged.
_PUNCT_FOR_CLASS = {2: ".", 3: "?", 4: "!", 5: ","}


@functools.lru_cache(maxsize=None)
def _load_model(model_path: str = _MODEL_PATH):
    """Load the tokenizer and model once per path; return (tokenizer, model, device).

    The model is moved to CUDA when available and put in eval mode. Caching
    means repeated calls to segment_batchalign reuse the loaded weights
    instead of reloading them on every call.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return tokenizer, model, device


def _tag_words(words: list[str]) -> list[str]:
    """Return *words* with the model's predicted punctuation/capitalization applied.

    Each word's class is the prediction on its first sub-word token. A word
    that receives no token (e.g. it consists only of characters the tokenizer
    discards) keeps class 0, so the result always has ``len(words)`` entries.
    """
    tokenizer, model, device = _load_model()
    encoding = tokenizer([words], return_tensors="pt", is_split_into_words=True)
    word_ids = encoding.word_ids(0)
    encoding = encoding.to(device)
    with torch.no_grad():
        logits = model(**encoding).logits
    token_classes = torch.argmax(logits, dim=2)[0].cpu().tolist()

    word_classes = [0] * len(words)
    seen = set()
    for token_idx, word_idx in enumerate(word_ids):
        # Skip special tokens and every sub-word after the first of a word.
        if word_idx is None or word_idx in seen:
            continue
        seen.add(word_idx)
        word_classes[word_idx] = token_classes[token_idx]

    tagged = []
    for word, cls in zip(words, word_classes):
        if cls == 1:
            word = word[0].upper() + word[1:]
        else:
            word += _PUNCT_FOR_CLASS.get(cls, "")
        tagged.append(word)
    return tagged


def _boundary_labels(tagged_words: list[str]) -> list[int]:
    """Sentence-split the punctuated words with NLTK; label each word 1 if a sentence ends on it."""
    text = " ".join(tagged_words)
    try:
        sentences = sent_tokenize(text)
    except LookupError:
        # Newer NLTK releases look up 'punkt_tab', older ones 'punkt'.
        nltk.download('punkt_tab')
        nltk.download('punkt')
        sentences = sent_tokenize(text)

    # Exclusive character offset in ``text`` at which each word ends.
    word_ends = [
        offset - 1
        for offset in itertools.accumulate(len(w) + 1 for w in tagged_words)
    ]
    labels = [0] * len(tagged_words)
    cursor = 0
    for sent in sentences:
        # Punkt returns in-order substrings of ``text``; locate each by
        # character offset instead of re-counting words, so the labels stay
        # aligned with the words even if NLTK splits inside a token.
        start = text.find(sent, cursor)
        if start == -1:  # defensive: should not happen for Punkt output
            continue
        cursor = start + len(sent)
        labels[bisect.bisect_left(word_ends, cursor)] = 1
    return labels


def segment_batchalign(text: str) -> list[int]:
    """Label each word of *text* as the last word of a C-unit (1) or not (0).

    The Batchalign utterance model predicts punctuation/capitalization for the
    words, and NLTK's ``sent_tokenize`` then splits the punctuated text into
    sentences. The model is loaded on the first call and reused afterwards.

    Args:
        text: Transcript text. It is lower-cased and "." / "," are removed
            before segmentation, so it need not be pre-normalized.

    Returns:
        One label per word of ``text.lower().replace(".", "").replace(",", "").split()``:
        1 if the word ends a C-unit, else 0. The last word is always labelled 1.
        When two consecutive words would both be labelled 1, only the later one
        is kept, so a one-word sentence is merged into the unit before it.
        Empty input returns ``[]``.

    NOTE(review): input longer than the model's maximum sequence length is
    not chunked and will likely fail inside the model — TODO confirm/handle.
    """
    words = text.lower().replace(".", "").replace(",", "").split()
    if not words:
        return []
    labels = _boundary_labels(_tag_words(words))
    # Keep a boundary only if the next word is not also a boundary.
    return [
        int(cur == 1 and nxt == 0)
        for cur, nxt in zip(labels, labels[1:] + [0])
    ]


if __name__ == "__main__":
    # Test the segmentation
    # Alternative sample:
    # test_text = "once a horse met elephant and then they saw a ball in a pool and then the horse tried to swim and get the ball they might be the same but they are doing something what do you think they are doing"
    test_text = "sir can I have balloon and the sir say yes you can and he said five dollars that xxx and and he is like where is that they his tether is right there and and he said and the bunny said oopsies I do not have money and the doc and the and the and the bunny runned for the doctor an and he says doctor doctor I want a balloon here is the money and you can have the balloons both of them now they are happy the end"
    print(f"Input text: {test_text}")
    print(f"Words: {test_text.split()}")
    labels = segment_batchalign(test_text)
    print(f"Segment labels: {labels}")
    # Show segmented text. Strip "." and "," exactly as segment_batchalign
    # does so the words line up with the labels (original case is kept).
    words = test_text.replace(".", "").replace(",", "").split()
    segments = []
    current_segment = []
    for word, label in zip(words, labels):
        current_segment.append(word)
        if label == 1:
            segments.append(" ".join(current_segment))
            current_segment = []
    # Add remaining words if any
    if current_segment:
        segments.append(" ".join(current_segment))
    print("\nSegmented text:")
    for i, segment in enumerate(segments, 1):
        print(f"Segment {i}: {segment}")