import os
import json
import random
import re
import time

import numpy as np
import torch
from pynvml import (
    nvmlInit,
    nvmlShutdown,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
    NVMLError,
)
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python types."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super().default(obj)
def load_results(file_name):
    with open(os.path.join('results', file_name)) as file:
        results = json.load(file)
    return results


def save_json(results, file_path="debug.json"):
    # Round-trip through json.dumps/json.loads so NumPy types are converted
    # to native Python types before writing.
    json_dict = json.dumps(results, cls=NpEncoder)
    dict_from_str = json.loads(json_dict)
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(dict_from_str, f)


def load_json(file_path):
    with open(file_path) as file:
        results = json.load(file)
    return results


def save_results(results, dir, file_name="debug"):
    json_dict = json.dumps(results, cls=NpEncoder)
    dict_from_str = json.loads(json_dict)
    os.makedirs(f'results/{dir}', exist_ok=True)
    with open(os.path.join(f'results/{dir}', f'{file_name}.json'), 'w', encoding='utf-8') as f:
        json.dump(dict_from_str, f)


def read_results(dir, file_name="debug"):
    file_path = os.path.join(f'results/{dir}', f'{file_name}.json')
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"No such file: '{file_path}'")
    with open(file_path, 'r', encoding='utf-8') as f:
        results = json.load(f)
    return results
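
# Illustrative round-trip (a sketch; the directory name and values below are
# hypothetical). NpEncoder converts the NumPy scalar to a plain float on the
# way to disk, so the value reads back as a native Python type:
# >>> save_results({"score": np.float32(0.5)}, "demo_dir", "debug")
# >>> read_results("demo_dir", "debug")
# {'score': 0.5}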
def _save_results(args, attr_results, pred_results_path):
    if args.dataset_name in ['musique', 'narrativeqa', 'qmsum']:
        name = f"{args.prompt_injection_attack}"
    elif args.dataset_name in ['nq-poison', 'hotpotqa-poison', 'msmarco-poison', 'nq-poison-combinatorial', 'nq-poison-insufficient', 'nq-poison-correctness', 'nq-poison-hotflip', 'nq-poison-safety']:
        name = "PoisonedRag"
    elif args.dataset_name in ['srt', 'mrt']:
        name = "needle_in_haystack"
    else:
        raise ValueError("Unsupported dataset_name.")
    if args.attr_type in ["vanilla_perturb", "tracllm"]:
        save_results(attr_results, pred_results_path, name + f"_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{'_'.join(args.score_funcs)}_{args.avg_k}_{args.K}")
    elif args.attr_type == "attntrace":
        save_results(attr_results, pred_results_path, name + f'_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{args.avg_k}_{args.q}_{args.B}_{args.K}')
    elif args.attr_type in ("self_citation", "context_cite") or "attention" in args.attr_type:
        save_results(attr_results, pred_results_path, name + f'_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{args.K}')
    else:
        raise ValueError("Unsupported attr_type.")
def _read_results(args, pred_results_path):
    if args.dataset_name in ['musique', 'narrativeqa', 'qmsum']:
        name = f"{args.prompt_injection_attack}"
    elif args.dataset_name in ['nq-poison', 'hotpotqa-poison', 'msmarco-poison', 'nq-poison-combinatorial', 'nq-poison-insufficient', 'nq-poison-correctness', 'nq-poison-hotflip', 'nq-poison-safety']:
        name = "PoisonedRag"
    elif args.dataset_name in ['srt', 'mrt']:
        name = "needle_in_haystack"
    else:
        raise ValueError("Unsupported dataset_name.")
    if args.attr_type in ["vanilla_perturb", "tracllm"]:
        return read_results(pred_results_path, name + f"_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{'_'.join(args.score_funcs)}_{args.avg_k}_{args.K}")
    elif args.attr_type == "attntrace":
        return read_results(pred_results_path, name + f'_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{args.avg_k}_{args.q}_{args.B}_{args.K}')
    elif args.attr_type in ("self_citation", "context_cite") or "attention" in args.attr_type:
        # "context_cite" is accepted here to mirror the branch in _save_results,
        # which saves under this same file-name pattern.
        return read_results(pred_results_path, name + f'_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{args.K}')
    else:
        raise ValueError("Unsupported attr_type.")
def setup_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
def clean_str(s):
    """Normalize an output string: strip whitespace, drop a trailing period, lowercase."""
    try:
        s = str(s)
    except Exception:
        print('Error: the output cannot be converted to a string')
        return ''
    s = s.strip()
    if len(s) > 1 and s[-1] == ".":
        s = s[:-1]
    return s.lower()
def newline_pad_contexts(contexts):
    # Prefix every context after the first with a paragraph break
    return [contexts[0]] + ['\n\n' + context for context in contexts[1:]]
def f1_score(precision, recall):
    """
    Calculate the F1 score given precision and recall arrays.

    Args:
        precision (np.array): A 2D array of precision values.
        recall (np.array): A 2D array of recall values.

    Returns:
        np.array: A 2D array of F1 scores (0 where precision + recall == 0).
    """
    denom = precision + recall
    # An explicit `out` array is required: with `where=...` alone, entries where
    # the condition is False would be left uninitialized.
    f1_scores = np.divide(2 * precision * recall, denom, out=np.zeros_like(denom, dtype=float), where=denom != 0)
    return f1_scores
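
# Worked example (illustrative):
# >>> f1_score(np.array([[0.5, 0.0]]), np.array([[1.0, 0.0]]))
# array([[0.66666667, 0.        ]])
# 2 * 0.5 * 1.0 / (0.5 + 1.0) = 2/3; the 0/0 entry is defined as 0.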
def remove_citations(sent):
    # Strip inline citation markers like " [3]" and stray pipes/brackets
    return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "")
def find_indices(list1: list, list2: list):
    """Return the indices in list2 of each element of list1 that appears in list2."""
    indices = []
    # Iterate over each element of list1
    for element in list1:
        try:
            # Look up the element's index in list2
            index = list2.index(element)
            indices.append(index)
        except ValueError:
            # Skip elements that are not present in list2
            continue
    return indices
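
# Illustrative example:
# >>> find_indices(['b', 'x', 'a'], ['a', 'b', 'c'])
# [1, 0]
# 'x' is absent from the second list, so it is skipped.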
def contexts_to_paragraphs(contexts):
    paragraphs = contexts[0].split('\n\n')
    # Re-attach the '\n\n' separator to every paragraph after the first
    paragraphs = [paragraph if i == 0 else '\n\n' + paragraph for i, paragraph in enumerate(paragraphs)]
    return paragraphs
def contexts_to_segments(contexts):
    segment_size = 100
    context = contexts[0]
    words = context.split(' ')
    # Create a list to hold segments
    segments = []
    # Iterate over the words and group them into fixed-size segments
    for i in range(0, len(words), segment_size):
        # Join a segment of `segment_size` words and add it to the list
        segment = ' '.join(words[i:i + segment_size]) + ' '
        segments.append(segment)
    return segments
def paragraphs_to_sentences(paragraphs):
    all_sentences = []
    for paragraph in paragraphs:
        sentences = split_into_sentences(paragraph)
        all_sentences.extend(sentences)
    return all_sentences
def contexts_to_sentences(contexts):
    paragraphs = contexts_to_paragraphs(contexts)
    all_sentences = paragraphs_to_sentences(paragraphs)
    return all_sentences
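
# Illustrative pipeline (hypothetical input): a single retrieved context is
# split into paragraphs on '\n\n' and each paragraph into sentences.
# >>> contexts_to_sentences(["First fact. Second fact.\n\nAnother paragraph."])
# ['First fact.', 'Second fact.', '\n\nAnother paragraph.']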
# Sentence-splitting helper patterns (used by split_into_sentences below)
alphabets = r"([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
starters = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = r"[.](com|net|org|io|gov|edu|me)"
digits = r"([0-9])"
multiple_dots = r"\.{2,}"
def split_into_phrases(text: str) -> list[str]:
    sentences = split_into_sentences(text)
    phrases = []
    for sent in sentences:
        phrases += sent.split(',')
    return phrases
def split_into_sentences(text: str) -> list[str]:
    """
    Split the text into sentences.

    If the text contains substrings "<prd>" or "<stop>", they would lead
    to incorrect splitting because they are used as markers for splitting.

    :param text: text to be split into sentences
    :type text: str
    :return: list of sentences
    :rtype: list[str]
    """
    text = text.replace("。", ".")
    text = " " + text + " "
    text = text.replace("\n", "<newline>")
    text = re.sub(prefixes, r"\1<prd>", text)
    text = re.sub(websites, r"<prd>\1", text)
    text = re.sub(digits + "[.]" + digits, r"\1<prd>\2", text)
    text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
    if "Ph.D" in text: text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    text = re.sub(r"\s" + alphabets + "[.] ", r" \1<prd> ", text)
    text = re.sub(acronyms + " " + starters, r"\1<stop> \2", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>\3<prd>", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>", text)
    text = re.sub(" " + suffixes + "[.] " + starters, r" \1<stop> \2", text)
    text = re.sub(" " + suffixes + "[.]", r" \1<prd>", text)
    text = re.sub(" " + alphabets + "[.]", r" \1<prd>", text)
    if "”" in text: text = text.replace(".”", "”.")
    if "\"" in text: text = text.replace(".\"", "\".")
    if "!" in text: text = text.replace("!\"", "\"!")
    if "?" in text: text = text.replace("?\"", "\"?")
    text = text.replace(".", ".<stop>")
    text = text.replace("?", "?<stop>")
    text = text.replace("!", "!<stop>")
    text = text.replace("<prd>", ".")
    sentences = text.split("<stop>")
    sentences = [s.strip() for s in sentences]
    if sentences and not sentences[-1]: sentences = sentences[:-1]
    sentences = [s.replace("<newline>", "\n") for s in sentences]
    return sentences
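
# Illustrative example: abbreviations are protected via the <prd> marker,
# so "Dr." does not end a sentence.
# >>> split_into_sentences("Dr. Smith went home. It was late!")
# ['Dr. Smith went home.', 'It was late!']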
def get_previous_answer(answer, explained_answer):
    # Everything in the answer that precedes the part being explained
    previous_answer = answer.split(explained_answer)[0]
    return previous_answer
def plot_sentence_importance(question, sentences_list, important_ids, importance_values, answer, explained_answer = "", width = 200): | |
from rich.console import Console | |
from rich.text import Text | |
assert len(important_ids) == len(importance_values), "Mismatch between number of words and importance values." | |
all_importance_values =np.zeros(len(sentences_list)) | |
all_importance_values[important_ids] = importance_values | |
#print("sentences list: ", sentences_list) | |
console = Console(width =width) | |
text = Text() | |
#print("MIN:",np.min(all_importance_values)) | |
#print(all_importance_values) | |
#all_importance_values = (all_importance_values - np.min(all_importance_values)) / (np.max(all_importance_values) - np.min(all_importance_values)+0.0001) | |
all_importance_values = (all_importance_values ) / (np.max(all_importance_values) +0.0001) | |
text.append("Context:\n", style=f"black bold") | |
for i,(sentence, imp) in enumerate(zip(sentences_list, all_importance_values)): | |
#sentence = sentence.capitalize() | |
red_intensity = 255 | |
blue_intensity=0 | |
#print(imp) | |
if imp < 0 or imp ==0: | |
green_intensity=255 | |
blue_intensity=255 | |
else: | |
green_intensity = int(255* (1 - imp)) | |
bg_color = f"{red_intensity:02x}{green_intensity:02x}{blue_intensity:02x}" | |
text.append(sentence, style=f"on #{bg_color} black") | |
text.append("\nQuery: \n", style=f"black bold") | |
red_intensity = 255 | |
green_intensity=255 | |
blue_intensity=255 | |
bg_color = f"{red_intensity:02x}{green_intensity:02x}{blue_intensity:02x}" | |
text.append(question, style=f"on #{bg_color} black") | |
text.append("\nLLM_response:\n", style=f"black bold") | |
answer = answer.capitalize() | |
red_intensity = 255 | |
green_intensity=255 | |
blue_intensity=255 | |
bg_color = f"{red_intensity:02x}{green_intensity:02x}{blue_intensity:02x}" | |
text.append(answer, style=f"on #{bg_color} black") | |
if explained_answer!="": | |
text.append("\nExplained part:", style=f"black bold") | |
red_intensity = 255 | |
green_intensity=255 | |
blue_intensity=255 | |
bg_color = f"{red_intensity:02x}{green_intensity:02x}{blue_intensity:02x}" | |
text.append(explained_answer, style=f"on #{bg_color} black") | |
console.print(text) | |
def unzip_tuples(tuple_list):
    list1 = [t[0] for t in tuple_list]
    list2 = [t[1] for t in tuple_list]
    return list1, list2


def manual_zip(list1, list2):
    # Ensure both lists have the same length
    if len(list1) != len(list2):
        raise ValueError("Both lists must have the same length")
    combined_list = []
    for i in range(len(list1)):
        combined_list.append((list1[i], list2[i]))
    return combined_list
def check_cannot_answer(answer):
    # Renamed from `prefixes` to avoid shadowing the module-level regex pattern
    refusal_prefixes = ["I don't know"]
    do_not_know = any([prefix in answer for prefix in refusal_prefixes])
    print("DO NOT KNOW: ", do_not_know)
    return do_not_know
def top_k_indexes(lst, k):
    # Clamp k to the length of the list
    if k > len(lst):
        k = len(lst)
    # Indexes sorted by their values in descending order
    sorted_indexes = sorted(range(len(lst)), key=lambda i: lst[i], reverse=True)
    # Return the first k indexes
    return sorted_indexes[:k]


def get_top_k(important_ids, importance_scores, k):
    top_k = top_k_indexes(importance_scores, k)
    topk_ids = [important_ids[j] for j in top_k]
    topk_scores = [importance_scores[j] for j in top_k]
    return topk_ids, topk_scores
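
# Illustrative example (hypothetical ids and scores):
# >>> get_top_k([10, 20, 30], [0.1, 0.9, 0.5], k=2)
# ([20, 30], [0.9, 0.5])
# Ids 20 and 30 carry the two highest importance scores.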
def add_specific_indexes(lst, indexes_to_add):
    # Despite the name, this keeps only the items at the given indexes
    indexes_to_add = sorted(indexes_to_add)
    return [item for idx, item in enumerate(lst) if idx in indexes_to_add]


def remove_specific_indexes(lst, indexes_to_remove):
    return [item for idx, item in enumerate(lst) if idx not in indexes_to_remove]
def split_context(level, contexts):
    assert isinstance(contexts, list)
    if len(contexts) > 1:  # the context is already segmented
        return contexts
    else:
        if level == "sentence":
            all_texts = contexts_to_sentences(contexts)
        elif level == "segment":
            all_texts = contexts_to_segments(contexts)
        elif level == "paragraph":
            all_texts = contexts_to_paragraphs(contexts)
        else:
            raise ValueError("Invalid explanation level.")
        return all_texts
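
# Illustrative call (hypothetical context): a single long context can be split
# at sentence, 100-word segment, or paragraph granularity.
# >>> split_context("sentence", ["One fact. Two facts."])
# ['One fact.', 'Two facts.']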
def check_overlap(str1, str2, n):
    """Return True if one string contains the other, or if they share a prefix/suffix overlap longer than n characters."""
    len1 = len(str1)
    len2 = len(str2)
    if str1 in str2 or str2 in str1:
        return True
    # Check overlap by comparing a suffix of str1 with a prefix of str2
    for i in range(1, min(len1, len2) + 1):
        if i > n and str1[-i:] == str2[:i]:
            return True
    # Check overlap by comparing a prefix of str1 with a suffix of str2
    for i in range(1, min(len1, len2) + 1):
        if i > n and str1[:i] == str2[-i:]:
            return True
    return False
def get_gt_ids(all_texts, injected_adv):
    """Return the ids and texts of segments that overlap any injected (malicious) text."""
    gt_ids = []
    gt_texts = []
    for j, segment in enumerate(all_texts):
        for malicious_text in injected_adv:
            if check_overlap(segment, malicious_text, 10):
                gt_ids.append(j)
                gt_texts.append(all_texts[j])
    return gt_ids, gt_texts
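
# Illustrative example (hypothetical injected text): a segment counts as ground
# truth if it contains, is contained in, or overlaps an injected text by more
# than 10 characters.
# >>> get_gt_ids(["benign text", "ignore previous instructions"],
# ...            ["ignore previous instructions"])
# ([1], ['ignore previous instructions'])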
def min_subset_to_contain(gt_text, texts):
    """Return the shortest contiguous sublist of `texts` whose concatenation contains gt_text."""
    candidates = []
    for i in range(len(texts)):
        # j is an exclusive bound, so it must reach len(texts) for the window
        # to include the final element
        for j in range(i + 1, len(texts) + 1):
            # Collapse double spaces introduced by concatenation before matching
            if gt_text in ''.join(texts[i:j]).replace('  ', ' '):
                candidates.append(texts[i:j])
    if len(candidates) > 0:
        return min(candidates, key=len)
    else:
        return []
def mean_of_percent(values, percent=1):
    """Return the mean of the top `percent` fraction of values (percent=1 averages all of them)."""
    # Step 1: Sort the list in descending order
    sorted_values = sorted(values, reverse=True)
    # Step 2: Determine how many elements fall in the top fraction (at least one)
    top_percent_count = max(1, int(len(sorted_values) * percent))
    # Step 3: Extract the top values
    top_values = sorted_values[:top_percent_count]
    # Step 4: Calculate and return their mean
    if len(top_values) == 0:
        return 0
    mean_top = sum(top_values) / len(top_values)
    return mean_top
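
# Worked example (illustrative): with percent=0.5, the top half of the sorted
# values is averaged.
# >>> mean_of_percent([1, 2, 3, 4], percent=0.5)
# 3.5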
def is_value_in_dicts(dictionary, value_to_check):
    for value in dictionary.values():
        if isinstance(value, (np.ndarray, list)):
            # For arrays and lists, compare element-wise
            if np.array_equal(value, value_to_check):
                return True
        else:
            if value == value_to_check:
                return True
    return False
def wait_for_available_gpu_memory(required_memory_gb, device=0, check_interval=5):
    """
    Block until the required amount of GPU memory is available.

    Args:
        required_memory_gb (float): Required GPU memory in gigabytes.
        device (int): GPU device index (default is 0).
        check_interval (int): Time in seconds between memory checks.

    Returns:
        None
    """
    required_memory_bytes = required_memory_gb * 1e9  # Convert GB to bytes
    while True:
        try:
            nvmlInit()
            handle = nvmlDeviceGetHandleByIndex(device)
            info = nvmlDeviceGetMemoryInfo(handle)
            available_memory = info.free
            if available_memory >= required_memory_bytes:
                print(f"Sufficient GPU memory available: {available_memory / 1e9:.2f} GB")
                nvmlShutdown()
                return
            else:
                print(f"Waiting for GPU memory. Available: {available_memory / 1e9:.2f} GB, Required: {required_memory_gb:.2f} GB")
                nvmlShutdown()
        except NVMLError as error:
            print(f"Error getting GPU memory: {error}")
            # Fall back to PyTorch's allocator statistics
            if torch.cuda.is_available():
                device = torch.cuda.current_device()
                total_memory = torch.cuda.get_device_properties(device).total_memory
                allocated_memory = torch.cuda.memory_allocated(device)
                available_memory = total_memory - allocated_memory
                if available_memory >= required_memory_bytes:
                    print(f"Sufficient GPU memory available (PyTorch): {available_memory / 1e9:.2f} GB")
                    return
                else:
                    print(f"Waiting for GPU memory (PyTorch). Available: {available_memory / 1e9:.2f} GB, Required: {required_memory_gb:.2f} GB")
            else:
                print("CUDA is not available")
        time.sleep(check_interval)
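
# Illustrative usage (a sketch; the 20 GB figure is hypothetical): block until
# roughly 20 GB is free on GPU 0 before loading a large model.
#
# if __name__ == "__main__":
#     wait_for_available_gpu_memory(20, device=0, check_interval=10)
#     # ... load the model here ...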