"""Shared utilities: JSON result I/O, context splitting, attribution helpers."""

import os
import json
import numpy as np
import random
import torch
import re

# pynvml is only needed by wait_for_available_gpu_memory(). Keeping the import
# optional lets the text/JSON helpers be used on machines without the NVML
# bindings; when the package is present the star import behaves as before.
try:
    from pynvml import *  # noqa: F401,F403
    _NVML_IMPORTED = True
except ImportError:
    _NVML_IMPORTED = False
import time


class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to plain Python."""

    def default(self, obj):
        """Return a JSON-serialisable equivalent of ``obj``.

        NumPy integers, floats, booleans and arrays are converted; anything
        else is delegated to ``json.JSONEncoder.default`` (which raises
        ``TypeError``).
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def _to_jsonable(results):
    """Round-trip ``results`` through JSON so only plain Python types remain.

    Serialising up front also means a non-serialisable object fails before any
    output file is opened (and truncated).
    """
    return json.loads(json.dumps(results, cls=NpEncoder))


def load_results(file_name):
    """Load and return the JSON document stored at ``results/<file_name>``."""
    with open(os.path.join('results', file_name), encoding='utf-8') as file:
        return json.load(file)


def save_json(results, file_path="debug.json"):
    """Write ``results`` (NumPy values allowed) as JSON to ``file_path``."""
    payload = _to_jsonable(results)
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f)


def load_json(file_path):
    """Load and return the JSON document stored at ``file_path``."""
    with open(file_path, encoding='utf-8') as file:
        return json.load(file)


def save_results(results, dir, file_name="debug"):
    """Write ``results`` as JSON to ``results/<dir>/<file_name>.json``.

    The target directory is created if it does not exist.
    """
    payload = _to_jsonable(results)
    target_dir = f'results/{dir}'
    os.makedirs(target_dir, exist_ok=True)
    with open(os.path.join(target_dir, f'{file_name}.json'), 'w', encoding='utf-8') as f:
        json.dump(payload, f)


def read_results(dir, file_name="debug"):
    """Load ``results/<dir>/<file_name>.json``.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    file_path = os.path.join(f'results/{dir}', f'{file_name}.json')
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"No such file: '{file_path}'")
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


_PROMPT_INJECTION_DATASETS = ('musique', 'narrativeqa', 'qmsum')
_POISONED_RAG_DATASETS = (
    'nq-poison', 'hotpotqa-poison', 'msmarco-poison', 'nq-poison-combinatorial',
    'nq-poison-insufficient', 'nq-poison-correctness', 'nq-poison-hotflip',
    'nq-poison-safety',
)
_NEEDLE_DATASETS = ('srt', 'mrt')


def _results_file_name(args):
    """Build the result file stem shared by _save_results and _read_results.

    Keeping one builder guarantees that whatever was saved for a given
    ``args`` can be read back with the same ``args``.

    Raises:
        ValueError: for an unsupported ``dataset_name`` or ``attr_type``.
    """
    if args.dataset_name in _PROMPT_INJECTION_DATASETS:
        name = f"{args.prompt_injection_attack}"
    elif args.dataset_name in _POISONED_RAG_DATASETS:
        name = "PoisonedRag"
    elif args.dataset_name in _NEEDLE_DATASETS:
        name = "needle_in_haystack"
    else:
        raise ValueError("Unsupported dataset_name.")

    stem = name + f"_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}"
    if args.attr_type in ("vanilla_perturb", "tracllm"):
        return stem + f"_{'_'.join(args.score_funcs)}_{args.avg_k}_{args.K}"
    if args.attr_type == "attntrace":
        return stem + f"_{args.avg_k}_{args.q}_{args.B}_{args.K}"
    if args.attr_type in ("self_citation", "context_cite") or "attention" in args.attr_type:
        return stem + f"_{args.K}"
    raise ValueError("Unsupported attr_type.")


def _save_results(args, attr_results, pred_results_path):
    """Save ``attr_results`` under ``results/<pred_results_path>/``.

    The file name is derived from the experiment settings in ``args``.

    Raises:
        ValueError: for an unsupported ``dataset_name`` or ``attr_type``.
    """
    save_results(attr_results, pred_results_path, _results_file_name(args))


def _read_results(args, pred_results_path):
    """Load the results previously written by ``_save_results`` for ``args``.

    Raises:
        ValueError: for an unsupported ``dataset_name`` or ``attr_type``.
        FileNotFoundError: if no result file exists for these settings.
    """
    return read_results(pred_results_path, _results_file_name(args))


def setup_seeds(seed):
    """Seed Python's ``random``, NumPy and PyTorch with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def clean_str(s):
    """Normalise a model output: stringify, strip, drop one trailing '.', lowercase."""
    try:
        s = str(s)
    except Exception:
        print('Error: the output cannot be converted to a string')
    s = s.strip()
    if len(s) > 1 and s[-1] == ".":
        s = s[:-1]
    return s.lower()


def newline_pad_contexts(contexts):
    """Prefix every context except the first with a blank line (``'\\n\\n'``).

    An empty list is returned unchanged.
    """
    return contexts[:1] + ['\n\n' + context for context in contexts[1:]]


def f1_score(precision, recall):
    """
    Calculate the F1 score given precision and recall arrays.

    Args:
        precision (np.array): A 2D array of precision values.
        recall (np.array): A 2D array of recall values.

    Returns:
        np.array: F1 scores (float64) with the broadcast shape of the inputs.
        Entries where ``precision + recall == 0`` are 0. Scalar inputs yield
        a NumPy scalar.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    numerator = 2 * precision * recall
    denominator = precision + recall
    # Without an explicit ``out`` the positions masked by ``where`` would be
    # left uninitialised, so start from zeros.
    f1_scores = np.zeros(np.broadcast(numerator, denominator).shape, dtype=float)
    np.divide(numerator, denominator, out=f1_scores, where=denominator != 0)
    return f1_scores[()]


def remove_citations(sent):
    """Strip bracketed numeric citation markers such as ``[1]`` from ``sent``."""
    without_spaced = re.sub(r" \[\d+", "", sent)
    without_open = re.sub(r"\[\d+", "", without_spaced)
    return without_open.replace(" |", "").replace("]", "")


def find_indices(list1: list, list2: list):
    """Return, for each element of ``list1`` found in ``list2``, its first index there.

    Elements of ``list1`` that do not occur in ``list2`` are skipped.
    """
    indices = []
    for element in list1:
        try:
            indices.append(list2.index(element))
        except ValueError:
            # Element is absent from list2: skip it.
            continue
    return indices


def contexts_to_paragraphs(contexts):
    """Split ``contexts[0]`` on blank lines, keeping ``'\\n\\n'`` on later paragraphs."""
    paragraphs = contexts[0].split('\n\n')
    return [
        paragraph if i == 0 else '\n\n' + paragraph
        for i, paragraph in enumerate(paragraphs)
    ]


def contexts_to_segments(contexts, segment_size=100):
    """Split ``contexts[0]`` into space-joined chunks of ``segment_size`` words.

    Every segment (including the last) ends with a single trailing space.
    """
    words = contexts[0].split(' ')
    return [
        ' '.join(words[start:start + segment_size]) + ' '
        for start in range(0, len(words), segment_size)
    ]


def paragraphs_to_sentences(paragraphs):
    """Split each paragraph into sentences and return them as one flat list."""
    all_sentences = []
    for paragraph in paragraphs:
        all_sentences.extend(split_into_sentences(paragraph))
    return all_sentences


def contexts_to_sentences(contexts):
    """Split ``contexts[0]`` into paragraphs, then into sentences."""
    return paragraphs_to_sentences(contexts_to_paragraphs(contexts))


alphabets = r"([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
starters = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = r"[.](com|net|org|io|gov|edu|me)"
digits = r"([0-9])"
multiple_dots = r'\.{2,}'

# Temporary markers used while splitting. They must be non-empty and must not
# occur in real text: an empty separator makes str.split() raise ValueError.
_PRD = "<prd>"          # a period that does NOT end a sentence
_STOP = "<stop>"        # a sentence boundary
_NEWLINE = "<newline>"  # stands in for "\n" so line breaks survive strip()


def split_into_phrases(text: str) -> list[str]:
    """Split ``text`` into sentences, then split each sentence on commas."""
    phrases = []
    for sent in split_into_sentences(text):
        phrases += sent.split(',')
    return phrases


def split_into_sentences(text: str) -> list[str]:
    """
    Split the text into sentences.

    If the text contains substrings "<prd>", "<stop>" or "<newline>", they
    would lead to incorrect splitting because they are used as markers for
    splitting.

    :param text: text to be split into sentences
    :type text: str

    :return: list of sentences
    :rtype: list[str]
    """
    text = text.replace("。", ".")
    # NOTE(review): the padding is a single space on each side in the text
    # under review; confirm it was not originally wider.
    text = " " + text + " "
    text = text.replace("\n", _NEWLINE)

    # Protect periods that are not sentence terminators.
    text = re.sub(prefixes, "\\1" + _PRD, text)
    text = re.sub(websites, _PRD + "\\1", text)
    text = re.sub(digits + "[.]" + digits, "\\1" + _PRD + "\\2", text)
    text = re.sub(multiple_dots, lambda match: _PRD * len(match.group(0)) + _STOP, text)
    if "Ph.D" in text:
        text = text.replace("Ph.D.", "Ph" + _PRD + "D" + _PRD)
    text = re.sub(r"\s" + alphabets + "[.] ", " \\1" + _PRD + " ", text)
    text = re.sub(acronyms + " " + starters, "\\1" + _STOP + " \\2", text)
    text = re.sub(
        alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]",
        "\\1" + _PRD + "\\2" + _PRD + "\\3" + _PRD,
        text,
    )
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", "\\1" + _PRD + "\\2" + _PRD, text)
    text = re.sub(" " + suffixes + "[.] " + starters, " \\1" + _STOP + " \\2", text)
    text = re.sub(" " + suffixes + "[.]", " \\1" + _PRD, text)
    text = re.sub(" " + alphabets + "[.]", " \\1" + _PRD, text)

    # Move terminators outside closing quotes so the quote stays attached.
    if "”" in text:
        text = text.replace(".”", "”.")
    if "\"" in text:
        text = text.replace(".\"", "\".")
    if "!" in text:
        text = text.replace("!\"", "\"!")
    if "?" in text:
        text = text.replace("?\"", "\"?")

    # Every remaining terminator is a sentence boundary.
    text = text.replace(".", "." + _STOP)
    text = text.replace("?", "?" + _STOP)
    text = text.replace("!", "!" + _STOP)
    text = text.replace(_PRD, ".")

    sentences = [s.strip() for s in text.split(_STOP)]
    if sentences and not sentences[-1]:
        sentences = sentences[:-1]
    # Restore line breaks only after strip() so leading newlines are kept.
    return [s.replace(_NEWLINE, "\n") for s in sentences]


def get_previous_answer(answer, explained_answer):
    """Return the part of ``answer`` that precedes the first ``explained_answer``."""
    return answer.split(explained_answer)[0]


def plot_sentence_importance(question, sentences_list, important_ids, importance_values, answer, explained_answer="", width=200):
    """Print context, query and response to the terminal with importance shading.

    Sentences with positive (max-normalised) importance get a background that
    shifts from yellow towards red as importance grows; the rest are white.
    Requires the ``rich`` package (imported lazily).

    Args:
        question: query text.
        sentences_list: context pieces, printed in order.
        important_ids: indices into ``sentences_list`` that have a score.
        importance_values: scores aligned with ``important_ids``.
        answer: model response (printed capitalised).
        explained_answer: optional part of the answer being explained.
        width: console width passed to ``rich``.
    """
    from rich.console import Console
    from rich.text import Text

    assert len(important_ids) == len(importance_values), "Mismatch between number of words and importance values."

    white_style = "on #ffffff black"
    header_style = "black bold"

    all_importance_values = np.zeros(len(sentences_list))
    all_importance_values[important_ids] = importance_values
    # The small constant avoids division by zero when every score is 0.
    all_importance_values = all_importance_values / (np.max(all_importance_values) + 0.0001)

    console = Console(width=width)
    text = Text()

    text.append("Context:\n", style=header_style)
    for sentence, imp in zip(sentences_list, all_importance_values):
        if imp < 0 or imp == 0:
            style = white_style
        else:
            # Clamp so an out-of-range score cannot yield a malformed colour.
            green_intensity = max(0, min(255, int(255 * (1 - imp))))
            style = f"on #ff{green_intensity:02x}00 black"
        text.append(sentence, style=style)

    text.append("\nQuery: \n", style=header_style)
    text.append(question, style=white_style)

    text.append("\nLLM_response:\n", style=header_style)
    text.append(answer.capitalize(), style=white_style)

    if explained_answer != "":
        text.append("\nExplained part:", style=header_style)
        text.append(explained_answer, style=white_style)

    console.print(text)


def unzip_tuples(tuple_list):
    """Split a list of pairs into two lists (first items, second items)."""
    list1 = [t[0] for t in tuple_list]
    list2 = [t[1] for t in tuple_list]
    return list1, list2


def manual_zip(list1, list2):
    """Pair up two equal-length lists.

    Raises:
        ValueError: if the lists differ in length.
    """
    if len(list1) != len(list2):
        raise ValueError("Both lists must have the same length")
    return list(zip(list1, list2))


def check_cannot_answer(answer):
    """Return True if ``answer`` contains a known refusal phrase (also prints the result)."""
    refusal_phrases = ["I don't know"]
    do_not_know = any(phrase in answer for phrase in refusal_phrases)
    print("DO NOT KNOW: ", do_not_know)
    return do_not_know


def top_k_indexes(lst, k):
    """Return the indexes of the ``k`` largest values of ``lst``, largest first.

    ``k`` larger than ``len(lst)`` is treated as ``len(lst)``.
    """
    if k > len(lst):
        k = len(lst)
    sorted_indexes = sorted(range(len(lst)), key=lambda i: lst[i], reverse=True)
    return sorted_indexes[:k]


def get_top_k(important_ids, importance_scores, k):
    """Return the ids and scores of the ``k`` highest-scoring entries."""
    top_k = top_k_indexes(importance_scores, k)
    topk_ids = [important_ids[j] for j in top_k]
    topk_scores = [importance_scores[j] for j in top_k]
    return topk_ids, topk_scores


def add_specific_indexes(lst, indexes_to_add):
    """Return the items of ``lst`` whose position is in ``indexes_to_add`` (in ``lst`` order)."""
    wanted = set(indexes_to_add)
    return [item for idx, item in enumerate(lst) if idx in wanted]


def remove_specific_indexes(lst, indexes_to_remove):
    """Return the items of ``lst`` whose position is NOT in ``indexes_to_remove``."""
    unwanted = set(indexes_to_remove)
    return [item for idx, item in enumerate(lst) if idx not in unwanted]


def split_context(level, contexts):
    """Split a single-element ``contexts`` list at the requested granularity.

    A list with more than one element is treated as already segmented and is
    returned unchanged.

    Args:
        level: ``"sentence"``, ``"segment"`` or ``"paragraph"``.
        contexts: list of context strings.

    Raises:
        ValueError: for an unknown ``level``.
    """
    assert isinstance(contexts, list)
    if len(contexts) > 1:
        return contexts
    if level == "sentence":
        return contexts_to_sentences(contexts)
    if level == "segment":
        return contexts_to_segments(contexts)
    if level == "paragraph":
        return contexts_to_paragraphs(contexts)
    raise ValueError("Invalid explanation level.")


def check_overlap(str1, str2, n):
    """Return True if the strings overlap.

    They overlap when one contains the other, or when a suffix of one equals a
    prefix of the other and that shared part is longer than ``n`` characters.
    """
    if str1 in str2 or str2 in str1:
        return True
    longest = min(len(str1), len(str2))
    for size in range(max(1, n + 1), longest + 1):
        if str1[-size:] == str2[:size] or str1[:size] == str2[-size:]:
            return True
    return False


def get_gt_ids(all_texts, injected_adv):
    """Find the texts that overlap with any injected (malicious) text.

    Returns:
        (gt_ids, gt_texts): indexes into ``all_texts`` and the matching texts.
        Each index appears once, even if it overlaps several injected texts.
    """
    gt_ids = []
    gt_texts = []
    for j, segment in enumerate(all_texts):
        if any(check_overlap(segment, malicious_text, 10) for malicious_text in injected_adv):
            gt_ids.append(j)
            gt_texts.append(segment)
    return gt_ids, gt_texts


def min_subset_to_contain(gt_text, texts):
    """Return the shortest run of consecutive ``texts`` whose concatenation contains ``gt_text``.

    Returns an empty list when no run contains it.
    """
    candidates = []
    for i in range(len(texts)):
        # The upper bound is len(texts) + 1 so windows can include the last text.
        for j in range(i + 1, len(texts) + 1):
            # NOTE(review): this replace is a no-op as written; it may have
            # been meant to collapse double spaces. Confirm the intended literal.
            if gt_text in ''.join(texts[i:j]).replace(' ', ' '):
                candidates.append(texts[i:j])
    if len(candidates) > 0:
        return min(candidates, key=len)
    return []


def mean_of_percent(values, percent=1):
    """Return the mean of the largest ``percent`` fraction of ``values``.

    At least one value is always used; an empty input yields 0. The number of
    values used is printed.
    """
    sorted_values = sorted(values, reverse=True)
    top_percent_count = max(1, int(len(sorted_values) * percent))
    print("top_percent_count: ", top_percent_count)
    top_values = sorted_values[:top_percent_count]
    if len(top_values) == 0:
        return 0
    return sum(top_values) / len(top_values)


def is_value_in_dicts(dictionary, value_to_check):
    """Return True if any value of ``dictionary`` equals ``value_to_check``.

    Arrays and lists are compared with ``np.array_equal`` whichever side they
    are on, so comparing a scalar entry against an array never raises.
    """
    sequence_types = (np.ndarray, list)
    for value in dictionary.values():
        if isinstance(value, sequence_types) or isinstance(value_to_check, sequence_types):
            if np.array_equal(value, value_to_check):
                return True
        elif value == value_to_check:
            return True
    return False


class _NvmlUnavailableError(RuntimeError):
    """Raised internally when the pynvml package could not be imported."""


_NVML_ERRORS = (NVMLError, _NvmlUnavailableError) if _NVML_IMPORTED else (_NvmlUnavailableError,)


def _nvml_free_bytes(device):
    """Return the free memory in bytes that NVML reports for GPU ``device``."""
    if not _NVML_IMPORTED:
        raise _NvmlUnavailableError("pynvml is not installed")
    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(device)
        return nvmlDeviceGetMemoryInfo(handle).free
    finally:
        # Always pair nvmlInit() with nvmlShutdown(), even if a query fails.
        nvmlShutdown()


def wait_for_available_gpu_memory(required_memory_gb, device=0, check_interval=5):
    """
    Waits until the required amount of GPU memory is available.

    NVML is queried first. If that fails, PyTorch's view of the current CUDA
    device (total minus memory allocated by this process) is used instead.
    The function polls forever until enough memory is reported.

    Args:
        required_memory_gb (float): Required GPU memory in gigabytes.
        device (int): GPU device index for the NVML query (default is 0)
        check_interval (int): Time interval in seconds between memory checks.

    Returns:
        None when NVML reports enough memory; 1 when the PyTorch fallback
        does (kept as-is for backward compatibility).
    """
    required_memory_bytes = required_memory_gb * 1e9  # Convert GB to bytes

    while True:
        try:
            available_memory = _nvml_free_bytes(device)
        except _NVML_ERRORS as error:
            print(f"Error getting GPU memory: {error}")
            # Fallback to PyTorch method
            if torch.cuda.is_available():
                # Use a separate name so the NVML device index is not changed
                # for later polls.
                cuda_device = torch.cuda.current_device()
                total_memory = torch.cuda.get_device_properties(cuda_device).total_memory
                allocated_memory = torch.cuda.memory_allocated(cuda_device)
                available_memory = total_memory - allocated_memory
                if available_memory >= required_memory_bytes:
                    print(f"Sufficient GPU memory available (PyTorch): {available_memory / 1e9:.2f} GB")
                    return 1
                print(f"Waiting for GPU memory (PyTorch). Available: {available_memory / 1e9:.2f} GB, Required: {required_memory_gb:.2f} GB")
            else:
                print("CUDA is not available")
        else:
            if available_memory >= required_memory_bytes:
                print(f"Sufficient GPU memory available: {available_memory / 1e9:.2f} GB")
                return
            print(f"Waiting for GPU memory. Available: {available_memory / 1e9:.2f} GB, Required: {required_memory_gb:.2f} GB")

        time.sleep(check_interval)