import random

import torch
from transformers import BertTokenizer, BertForMaskedLM

import nltk
from nltk.corpus import stopwords

# ---------------------------------------------------------------------------
# NOTE: the (start, end) coordinates in `common_ngrams` refer to word positions
# in the ORIGINAL sentences, so masking only lines up when stopwords are NOT
# removed (see the remapping sketch at the bottom of the file).
# ---------------------------------------------------------------------------

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one word between each pair of
        consecutive n-grams, and one word after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their (start, end) word indices
            remove_stopwords (bool): Whether to drop stopwords before masking

        Returns:
            str: Masked sentence
        """
        words = original_sentence.split()
        if remove_stopwords:
            words = [word for word in words if word.lower() not in self.stop_words]

        # Word positions covered by a common n-gram; these must never be masked.
        ngram_indices = {
            index
            for spans in common_ngrams.values()
            for start, end in spans
            for index in range(start, end + 1)
        }

        mask_indices = []
        if common_ngrams:
            ngram_positions = list(common_ngrams.values())

            # Candidate before the first common n-gram
            first_ngram_start = ngram_positions[0][0][0]
            if first_ngram_start > 0:
                mask_indices.append(random.randint(0, first_ngram_start - 1))

            # Candidate between each pair of consecutive common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                if start_next > end_prev + 1:
                    mask_indices.append(random.randint(end_prev + 1, start_next - 1))

            # Candidate after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(words) - 1))

        # Mask the chosen indices, skipping anything inside a common n-gram
        for idx in mask_indices:
            if idx not in ngram_indices:
                words[idx] = self.tokenizer.mask_token

        return " ".join(words)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one word between each pair of
        consecutive n-grams, and one word after the last common n-gram, choosing the
        candidate with the highest prediction entropy under BERT.

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their (start, end) word indices
            remove_stopwords (bool): Whether to drop stopwords before masking

        Returns:
            str: Masked sentence
        """
        words = original_sentence.split()
        if remove_stopwords:
            words = [word for word in words if word.lower() not in self.stop_words]

        # Word positions covered by a common n-gram; these are never candidates.
        ngram_indices = {
            index
            for spans in common_ngrams.values()
            for start, end in spans
            for index in range(start, end + 1)
        }

        # Entropy of BERT's predictive distribution at each maskable position.
        entropy_scores = {}
        for idx in range(len(words)):
            if idx in ngram_indices:
                continue  # Skip words in common n-grams

            masked_sentence = " ".join(words[:idx] + [self.tokenizer.mask_token] + words[idx + 1:])
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            # Add epsilon to prevent log(0)
            entropy_scores[idx] = -torch.sum(probs * torch.log(probs + 1e-10)).item()

        mask_indices = []
        if common_ngrams:
            ngram_positions = list(common_ngrams.values())

            # Highest-entropy candidate before the first common n-gram
            first_ngram_start = ngram_positions[0][0][0]
            candidates = [j for j in range(first_ngram_start) if j in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Highest-entropy candidate between each pair of consecutive n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [j for j in range(end_prev + 1, start_next) if j in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Highest-entropy candidate after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [j for j in range(last_ngram_end + 1, len(words)) if j in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices
        for idx in mask_indices:
            words[idx] = self.tokenizer.mask_token

        return " ".join(words)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence containing [MASK] tokens

        Returns:
            dict: Masked token positions (in the tokenized input) mapped to their logits
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        return {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}

    def process_sentences(self, original_sentences, result_dict, remove_stopwords=False, method="random"):
        """
        Mask a batch of sentences and calculate logits for the masked tokens using the
        specified selection method.

        Args:
            original_sentences (list): List of original sentences
            result_dict (dict): Common n-grams and their indices for each sentence
            remove_stopwords (bool): Whether to drop stopwords before masking
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and their logits for each sentence
        """
        results = {}
        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits,
            }
        return results


# Example usage
if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown animals leap over lazy obstacles."
    ]

    # These coordinates index words in the ORIGINAL sentences, so the example
    # runs with remove_stopwords=False.
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
        "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
        "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
    }
    # Equivalent coordinates after stopword removal (use with remove_stopwords=True):
    # result_dict = {
    #     "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
    # }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, remove_stopwords=False, method="random")
    results_entropy = processor.process_sentences(sentences, result_dict, remove_stopwords=False, method="entropy")

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")

    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Entropy): {sentence}")
        print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
        # print(f"Mask Logits (Entropy): {output['mask_logits']}")
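

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original class): if remove_stopwords=True is
# used, the (start, end) coordinates in `common_ngrams` must be expressed in
# terms of the stopword-filtered word list. The helper below is one possible
# way to remap them; the name `remap_ngram_indices` and its exact behaviour
# are illustrative assumptions, not an established API.
# ---------------------------------------------------------------------------
def remap_ngram_indices(original_sentence, common_ngrams, stop_words):
    """Remap (start, end) word indices from the original sentence to the
    stopword-filtered sentence. Spans whose words are all stopwords are dropped."""
    words = original_sentence.split()

    # new_position[i] = index of original word i after stopword removal
    new_position = {}
    kept = 0
    for i, word in enumerate(words):
        if word.lower() not in stop_words:
            new_position[i] = kept
            kept += 1

    remapped = {}
    for ngram, spans in common_ngrams.items():
        new_spans = []
        for start, end in spans:
            kept_positions = [new_position[i] for i in range(start, end + 1) if i in new_position]
            if kept_positions:
                new_spans.append((kept_positions[0], kept_positions[-1]))
        if new_spans:
            remapped[ngram] = new_spans
    return remapped

# Assumed usage, remapping the coordinates of the first example sentence above:
#     remap_ngram_indices(
#         "The quick brown fox jumps over the lazy dog.",
#         {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
#         set(stopwords.words("english")),
#     )
#     # -> {"quick brown": [(0, 1)], "lazy": [(4, 4)]}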