File size: 18,341 Bytes
f214f36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383cea5
f214f36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
import os
import json
import numpy as np
import random
import torch
import re
import torch
from pynvml import *
import time
class NpEncoder(json.JSONEncoder):
    """JSON encoder that understands NumPy scalars and arrays."""

    def default(self, obj):
        """Return a JSON-serialisable equivalent of ``obj``.

        NumPy arrays become (nested) lists, NumPy integer scalars become
        ``int`` and NumPy floating scalars become ``float``. Anything else is
        delegated to ``json.JSONEncoder.default``, which raises ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)
        
def load_results(file_name):
    """Parse and return the JSON file ``results/<file_name>``.

    The path is relative to the current working directory.
    """
    results_path = os.path.join('results', file_name)
    with open(results_path) as handle:
        return json.load(handle)
def save_json(results, file_path="debug.json"):
    """Write ``results`` to ``file_path`` as JSON.

    The data is first serialised with ``NpEncoder`` and parsed back, so NumPy
    values are turned into plain Python types before the file is opened.
    """
    # Serialise before opening the file so a failure does not truncate it.
    plain_results = json.loads(json.dumps(results, cls=NpEncoder))
    with open(file_path, 'w', encoding='utf-8') as out_file:
        json.dump(plain_results, out_file)

def load_json(file_path):
    """Parse the JSON file at ``file_path`` and return its content."""
    with open(file_path) as handle:
        return json.load(handle)


def save_results(results, dir, file_name="debug"):
    """Write ``results`` to ``results/<dir>/<file_name>.json``.

    NumPy values are converted to plain Python types via ``NpEncoder``; the
    target directory is created when it does not exist yet.
    """
    # Round-trip through a JSON string so only plain Python types remain.
    plain_results = json.loads(json.dumps(results, cls=NpEncoder))
    target_dir = f'results/{dir}'
    if not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)
    target_file = os.path.join(target_dir, f'{file_name}.json')
    with open(target_file, 'w', encoding='utf-8') as out_file:
        json.dump(plain_results, out_file)
def read_results(dir, file_name="debug"):
    """Load and return ``results/<dir>/<file_name>.json``.

    Raises:
        FileNotFoundError: if that file does not exist.
    """
    file_path = os.path.join(f'results/{dir}', f'{file_name}.json')
    if os.path.exists(file_path):
        with open(file_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    raise FileNotFoundError(f"No such file: '{file_path}'")
def _save_results(args,attr_results, pred_results_path):
    if args.dataset_name in ['musique', 'narrativeqa', 'qmsum']:
        name = f"{args.prompt_injection_attack}"
    elif args.dataset_name in ['nq-poison','hotpotqa-poison','msmarco-poison','nq-poison-combinatorial','nq-poison-insufficient','nq-poison-correctness','nq-poison-hotflip','nq-poison-safety']:
        name = "PoisonedRag"
    elif args.dataset_name in ['srt','mrt']:
        name = "needle_in_haystack"
    else:
        raise ValueError("Unsupported dataset_name.")
    if args.attr_type in ["vanilla_perturb","tracllm"]:
        save_results(attr_results, pred_results_path, name+f"_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{'_'.join(args.score_funcs)}_{args.avg_k}_{args.K}")
    elif args.attr_type == "attntrace":
        save_results(attr_results, pred_results_path, name+f'_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{args.avg_k}_{args.q}_{args.B}_{args.K}')
    elif args.attr_type == "self_citation" or args.attr_type == "context_cite" or "attention" in args.attr_type:
        save_results(attr_results, pred_results_path, name+f'_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{args.K}')
    else:
        raise ValueError("Unsupported attr_type.")

def _read_results(args, pred_results_path):
    if args.dataset_name in ['musique', 'narrativeqa', 'qmsum']:
        name = f"{args.prompt_injection_attack}"
    elif args.dataset_name in ['nq-poison','hotpotqa-poison','msmarco-poison','nq-poison-combinatorial','nq-poison-insufficient','nq-poison-correctness','nq-poison-hotflip', 'nq-poison-safety']:
        name = "PoisonedRag"
    elif args.dataset_name in ['srt','mrt']:
        name = "needle_in_haystack"
    else:
        raise ValueError("Unsupported dataset_name.")
    if args.attr_type in ["vanilla_perturb","tracllm"]:
        return read_results( pred_results_path, name+f"_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{'_'.join(args.score_funcs)}_{args.avg_k}_{args.K}")
    elif args.attr_type == "attntrace":
        return read_results( pred_results_path, name+f'_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{args.avg_k}_{args.q}_{args.B}_{args.K}')
    elif args.attr_type == "self_citation" or "attention" in args.attr_type:
        return read_results( pred_results_path, name+f'_{args.dataset_name}_{args.inject_times}_{args.model_name}_{args.attr_type}_{args.K}')
    else:
        raise ValueError("Unsupported attr_type.")


def setup_seeds(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same ``seed``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

def clean_str(s):
    """Normalise ``s`` for comparison.

    Converts ``s`` to ``str``, strips surrounding whitespace, drops one
    trailing period (only when more than one character remains) and
    lower-cases the result.

    Note: an identical ``clean_str`` is defined again later in this module;
    the later definition is the one bound at import time.
    """
    try:
        s = str(s)
    except Exception:
        # Best effort, as before: report and carry on with the original
        # object. ``Exception`` (rather than a bare ``except``) keeps
        # KeyboardInterrupt / SystemExit from being swallowed.
        print('Error: the output cannot be converted to a string')
    s = s.strip()
    if len(s) > 1 and s.endswith("."):
        s = s[:-1]
    return s.lower()
def newline_pad_contexts(contexts):
    """Prefix every context except the first with a blank line (``'\\n\\n'``).

    Returns a new list and leaves the input untouched. An empty input yields
    an empty list instead of raising ``IndexError``.
    """
    return [
        context if position == 0 else '\n\n' + context
        for position, context in enumerate(contexts)
    ]
def f1_score(precision, recall):
    """
    Calculate the F1 score given precision and recall arrays.

    Entries where ``precision + recall`` is zero get an F1 of 0.

    Args:
    precision (np.array): Array of precision values.
    recall (np.array): Array of recall values (broadcastable with precision).

    Returns:
    np.array: Float array of F1 scores, ``2 * p * r / (p + r)``.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    numerator = 2 * precision * recall
    denominator = precision + recall
    # ``where=`` skips the masked positions entirely; without an explicit
    # ``out`` buffer those positions would hold uninitialised memory.
    f1_scores = np.divide(
        numerator,
        denominator,
        out=np.zeros_like(denominator),
        where=denominator != 0,
    )

    return f1_scores

def remove_citations(sent):
    """Strip bracketed numeric citation markers such as ``[1]`` from ``sent``.

    Removes ``[<digits>`` openers (with or without a leading space), every
    `` |`` separator and every remaining ``]``.
    """
    without_spaced_openers = re.sub(r" \[\d+", "", sent)
    without_openers = re.sub(r"\[\d+", "", without_spaced_openers)
    without_separators = without_openers.replace(" |", "")
    return without_separators.replace("]", "")


def find_indices(list1: list, list2: list):
    """Return, for each element of ``list1`` found in ``list2``, its index there.

    Elements missing from ``list2`` are skipped. The index is that of the
    first occurrence in ``list2``, and the output follows ``list1`` order.
    """
    return [list2.index(item) for item in list1 if item in list2]
def contexts_to_paragraphs(contexts):
    """Split ``contexts[0]`` on blank lines into paragraphs.

    Every paragraph after the first keeps its leading ``'\\n\\n'`` so that
    concatenating the result reproduces the original text.
    """
    first, *rest = contexts[0].split('\n\n')
    return [first] + ['\n\n' + paragraph for paragraph in rest]
def contexts_to_segments(contexts, segment_size=100):
    """Split ``contexts[0]`` into consecutive segments of ``segment_size`` words.

    Words are obtained by splitting on single spaces. Each segment is the
    words joined by a space plus one trailing space.

    Args:
        contexts: list whose first element is the text to segment.
        segment_size: number of words per segment (default 100, the value
            that used to be hard-coded).

    Returns:
        list[str]: the segments in order; the last one may be shorter.

    Raises:
        ValueError: if ``segment_size`` is smaller than 1.
    """
    if segment_size < 1:
        raise ValueError("segment_size must be a positive integer")
    words = contexts[0].split(' ')
    return [
        ' '.join(words[start:start + segment_size]) + ' '
        for start in range(0, len(words), segment_size)
    ]

    

def paragraphs_to_sentences(paragraphs):
    """Split every paragraph into sentences and return them as one flat list.

    Sentence splitting is delegated to ``split_into_sentences``; paragraph
    order is preserved.
    """
    return [
        sentence
        for paragraph in paragraphs
        for sentence in split_into_sentences(paragraph)
    ]
def contexts_to_sentences(contexts):
    """Split ``contexts`` into paragraphs, then into a flat list of sentences."""
    return paragraphs_to_sentences(contexts_to_paragraphs(contexts))

import re

# Regex fragments used by split_into_sentences. They are raw strings so that
# ``\s`` and ``\.`` reach the regex engine verbatim instead of being treated
# as (invalid) Python string escapes.
alphabets = r"([A-Za-z])"
# Titles followed by a period that does not end a sentence.
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
# Company/name suffixes that are commonly followed by a period.
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
# Words that typically begin a new sentence.
starters = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
# Two- or three-letter dotted acronyms such as "U.S." or "U.S.A.".
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
# Common top-level domains preceded by a dot.
websites = r"[.](com|net|org|io|gov|edu|me)"
digits = r"([0-9])"
# Runs of two or more dots (ellipses).
multiple_dots = r'\.{2,}'
def split_into_phrases(text: str) -> list[str]:
    """Split ``text`` into sentences, then split each sentence on commas.

    Returns the comma-separated pieces as one flat list, in order.
    """
    return [
        phrase
        for sentence in split_into_sentences(text)
        for phrase in sentence.split(',')
    ]
def split_into_sentences(text: str) -> list[str]:
    """
    Split the text into sentences.

    The splitter is rule based: periods that do not end a sentence (titles,
    web domains, decimals, initials, acronyms, ellipses) are temporarily
    replaced by "<prd>", real sentence ends are tagged with "<stop>", and the
    text is finally split on "<stop>". Newlines are carried through as
    "<newline>" and restored at the end.

    If the text contains substrings "<prd>", "<stop>" or "<newline>", they
    would lead to incorrect splitting because they are used as markers.

    :param text: text to be split into sentences
    :type text: str

    :return: list of sentences (each stripped of surrounding whitespace)
    :rtype: list[str]
    """
    # Treat the CJK full stop like an ASCII period.
    text = text.replace("。", ".")
    text = " " + text + "  "
    text = text.replace("\n","<newline>")
    # Protect periods that do not terminate a sentence.
    text = re.sub(prefixes,"\\1<prd>",text)
    text = re.sub(websites,"<prd>\\1",text)
    text = re.sub(digits + "[.]" + digits,"\\1<prd>\\2",text)
    # An ellipsis keeps its dots but still ends the sentence.
    text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
    if "Ph.D" in text: text = text.replace("Ph.D.","Ph<prd>D<prd>")
    # Raw string: "\s" in a plain literal is an invalid escape sequence.
    text = re.sub(r"\s" + alphabets + "[.] "," \\1<prd> ",text)
    text = re.sub(acronyms+" "+starters,"\\1<stop> \\2",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>\\3<prd>",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>",text)
    text = re.sub(" "+suffixes+"[.] "+starters," \\1<stop> \\2",text)
    text = re.sub(" "+suffixes+"[.]"," \\1<prd>",text)
    text = re.sub(" " + alphabets + "[.]"," \\1<prd>",text)
    # Move terminal punctuation outside closing quotes so the quote stays
    # attached to its sentence.
    if "”" in text: text = text.replace(".”","”.")
    if "\"" in text: text = text.replace(".\"","\".")
    if "!" in text: text = text.replace("!\"","\"!")
    if "?" in text: text = text.replace("?\"","\"?")
    # Every remaining terminator marks a sentence boundary.
    text = text.replace(".",".<stop>")
    text = text.replace("?","?<stop>")
    text = text.replace("!","!<stop>")
    text = text.replace("<prd>",".")
    sentences = text.split("<stop>")
    sentences = [s.strip() for s in sentences]
    # Drop the empty remainder after the final terminator.
    if sentences and not sentences[-1]: sentences = sentences[:-1]
    sentences = [s.replace("<newline>", "\n") for s in sentences]
    return sentences
def get_previous_answer(answer, explained_answer):
    """Return the part of ``answer`` before the first ``explained_answer``.

    When ``explained_answer`` does not occur, the whole ``answer`` is returned.
    """
    return answer.split(explained_answer, 1)[0]
def plot_sentence_importance(question, sentences_list, important_ids, importance_values, answer, explained_answer = "", width = 200):
    """Print the context, query and response to the terminal with ``rich``.

    Each entry of ``sentences_list`` is drawn on a background colour derived
    from its importance: entries with a score of zero or less are white, the
    others shade from yellow towards red as the normalised score grows.
    Scores come from ``importance_values`` placed at ``important_ids``
    (all other entries are 0) and are divided by the maximum score + 0.0001.

    The query, the capitalised answer and, when non-empty, the explained part
    are appended on a white background.
    """
    from rich.console import Console
    from rich.text import Text

    assert len(important_ids) == len(importance_values), "Mismatch between number of words and importance values."

    heading_style = "black bold"
    plain_style = "on #ffffff black"

    # Scatter the given scores into a dense per-sentence vector.
    scores = np.zeros(len(sentences_list))
    scores[important_ids] = importance_values
    console = Console(width=width)
    text = Text()
    scores = scores / (np.max(scores) + 0.0001)

    text.append("Context:\n", style=heading_style)
    for sentence, score in zip(sentences_list, scores):
        if score <= 0:
            # Unimportant: plain white background.
            green, blue = 255, 255
        else:
            # Red channel stays at full intensity; less green means redder.
            green, blue = int(255 * (1 - score)), 0
        text.append(sentence, style=f"on #ff{green:02x}{blue:02x} black")

    text.append("\nQuery: \n", style=heading_style)
    text.append(question, style=plain_style)

    text.append("\nLLM_response:\n", style=heading_style)
    text.append(answer.capitalize(), style=plain_style)

    if explained_answer != "":
        text.append("\nExplained part:", style=heading_style)
        text.append(explained_answer, style=plain_style)

    console.print(text)

def unzip_tuples(tuple_list):
    """Split a list of pairs into two lists: first items and second items."""
    firsts = list(map(lambda pair: pair[0], tuple_list))
    seconds = list(map(lambda pair: pair[1], tuple_list))
    return firsts, seconds
def manual_zip(list1, list2):
    """Pair up the elements of two equally long lists.

    Returns:
        list of ``(list1[i], list2[i])`` tuples.

    Raises:
        ValueError: if the lists differ in length.
    """
    if len(list1) != len(list2):
        raise ValueError("Both lists must have the same length")
    return list(zip(list1, list2))
def check_cannot_answer(answer):
    """Report whether ``answer`` contains a known refusal phrase.

    The only phrase currently checked is "I don't know". The outcome is also
    printed to stdout.
    """
    refusal_markers = ("I don't know",)
    do_not_know = any(marker in answer for marker in refusal_markers)
    print("DO NOT KNOW: ", do_not_know)
    return do_not_know

def top_k_indexes(lst, k):
    """Return the indexes of the ``k`` largest values of ``lst``.

    Indexes are ordered by descending value; equal values keep their original
    relative order. ``k`` is capped at ``len(lst)``.
    """
    size = len(lst)
    limit = min(k, size)
    ranked = list(range(size))
    ranked.sort(key=lst.__getitem__, reverse=True)
    return ranked[:limit]

def get_top_k(important_ids, importance_scores, k):
    """Select the ``k`` highest-scoring entries.

    Returns:
        ``(topk_ids, topk_scores)``: the ids and scores at the positions
        chosen by ``top_k_indexes(importance_scores, k)``, in that order.
    """
    chosen = top_k_indexes(importance_scores, k)
    topk_ids = [important_ids[position] for position in chosen]
    topk_scores = [importance_scores[position] for position in chosen]
    return topk_ids, topk_scores
def add_specific_indexes(lst, indexes_to_add):
    """Return the items of ``lst`` whose position is listed in ``indexes_to_add``.

    Items come back in their original ``lst`` order.
    """
    wanted = sorted(indexes_to_add)
    kept = []
    for position, item in enumerate(lst):
        if position in wanted:
            kept.append(item)
    return kept
def remove_specific_indexes(lst, indexes_to_remove):
    """Return ``lst`` without the items at the positions in ``indexes_to_remove``."""
    kept = []
    for position, item in enumerate(lst):
        if position in indexes_to_remove:
            continue
        kept.append(item)
    return kept
def clean_str(s):
    """Normalise ``s`` for comparison.

    Converts ``s`` to ``str``, strips surrounding whitespace, drops one
    trailing period (only when more than one character remains) and
    lower-cases the result.

    Note: this repeats an identical definition earlier in this module; being
    the later one, it is the definition bound at import time.
    """
    try:
        s = str(s)
    except Exception:
        # Best effort, as before: report and carry on with the original
        # object. ``Exception`` (rather than a bare ``except``) keeps
        # KeyboardInterrupt / SystemExit from being swallowed.
        print('Error: the output cannot be converted to a string')
    s = s.strip()
    if len(s) > 1 and s.endswith("."):
        s = s[:-1]
    return s.lower()
def split_context(level, contexts):
    """Split ``contexts`` into texts at the requested explanation ``level``.

    A list with more than one element is treated as already segmented and is
    returned unchanged. Otherwise the list is split into sentences, segments
    or paragraphs according to ``level``.

    Raises:
        ValueError: if ``level`` is not "sentence", "segment" or "paragraph".
    """
    assert isinstance(contexts, list)
    if len(contexts) > 1:
        # The context is already segmented.
        return contexts
    if level == "sentence":
        return contexts_to_sentences(contexts)
    if level == "segment":
        return contexts_to_segments(contexts)
    if level == "paragraph":
        return contexts_to_paragraphs(contexts)
    raise ValueError("Invalid explanation level.")

def check_overlap(str1, str2, n):
    """Tell whether ``str1`` and ``str2`` overlap.

    They overlap when one contains the other, or when a suffix of one equals
    a prefix of the other and that shared part is longer than ``n``
    characters.
    """
    if str1 in str2 or str2 in str1:
        return True

    longest = min(len(str1), len(str2))
    for size in range(1, longest + 1):
        if size <= n:
            continue
        # End of str1 against start of str2, and the mirrored case.
        if str1[-size:] == str2[:size] or str1[:size] == str2[-size:]:
            return True
    return False

def get_gt_ids(all_texts, injected_adv):
    """Locate the texts that overlap with injected malicious texts.

    Every (text, malicious text) pair for which ``check_overlap(..., 10)``
    holds contributes one entry, so a text overlapping several malicious
    texts is listed several times.

    Returns:
        ``(gt_ids, gt_texts)``: positions in ``all_texts`` and the matching
        texts.
    """
    gt_ids, gt_texts = [], []
    for position, segment in enumerate(all_texts):
        for malicious_text in injected_adv:
            if not check_overlap(segment, malicious_text, 10):
                continue
            gt_ids.append(position)
            gt_texts.append(segment)
    return gt_ids, gt_texts

def min_subset_to_contain(gt_text, texts):
    """Find the shortest run of consecutive ``texts`` that contains ``gt_text``.

    A run ``texts[i:j]`` qualifies when ``gt_text`` occurs in its
    concatenation after double spaces are collapsed to single ones.

    Returns:
        The qualifying run with the fewest elements (the earliest one on
        ties), or ``[]`` when no run contains ``gt_text``.
    """
    candidates = []
    for i in range(len(texts)):
        # ``j`` is an exclusive slice end, so it must reach ``len(texts)``
        # for runs that include the last text to be considered.
        for j in range(i + 1, len(texts) + 1):
            if gt_text in ''.join(texts[i:j]).replace('  ', ' '):
                candidates.append(texts[i:j])
                # Longer runs starting at ``i`` can never be the minimum.
                break
    if candidates:
        return min(candidates, key=len)
    return []

def mean_of_percent(values,percent = 1):
    """Average the largest ``percent`` fraction of ``values``.

    ``percent`` is a fraction (1 means all values). At least one value is
    always used; the number of values used is printed to stdout. Returns 0
    when ``values`` is empty.
    """
    ranked = sorted(values, reverse=True)
    top_percent_count = max(1, int(len(ranked) * percent))
    print("top_percent_count: ", top_percent_count)
    top_values = ranked[:top_percent_count]
    if not top_values:
        return 0
    return sum(top_values) / len(top_values)

def is_value_in_dicts(dictionary, value_to_check):
    """Tell whether any value of ``dictionary`` equals ``value_to_check``.

    List and NumPy-array values are compared as a whole with
    ``np.array_equal``; all other values are compared with ``==``.
    """
    def matches(candidate):
        if isinstance(candidate, (np.ndarray, list)):
            return np.array_equal(candidate, value_to_check)
        return candidate == value_to_check

    return any(matches(candidate) for candidate in dictionary.values())


def wait_for_available_gpu_memory(required_memory_gb, device=0, check_interval=5):
    """
    Waits until the required amount of GPU memory is available.

    Free memory is read through NVML for GPU ``device``. If an NVML call
    fails, the check falls back to PyTorch for the current CUDA device,
    using total memory minus ``torch.cuda.memory_allocated``. The loop
    sleeps ``check_interval`` seconds between attempts and never gives up.

    Args:
    required_memory_gb (float): Required GPU memory in gigabytes (1 GB = 1e9 bytes).
    device (int): GPU device index used for the NVML query (default is 0)
    check_interval (int): Time interval in seconds between memory checks.
    Returns:
    None when the NVML check succeeds; 1 when the PyTorch fallback succeeds
    (kept as-is for backward compatibility).
    """
    required_memory_bytes = required_memory_gb * 1e9  # Convert GB to bytes
    while True:
        try:
            nvmlInit()
            try:
                handle = nvmlDeviceGetHandleByIndex(device)
                available_memory = nvmlDeviceGetMemoryInfo(handle).free
            finally:
                # Release NVML even when the query itself raised.
                nvmlShutdown()
            if available_memory >= required_memory_bytes:
                print(f"Sufficient GPU memory available: {available_memory / 1e9:.2f} GB")
                return
            print(f"Waiting for GPU memory. Available: {available_memory / 1e9:.2f} GB, Required: {required_memory_gb:.2f} GB")
        except NVMLError as error:
            print(f"Error getting GPU memory: {error}")
            # Fallback to PyTorch method
            if torch.cuda.is_available():
                # Use a separate name so the ``device`` argument is not
                # overwritten for the NVML query in later iterations.
                torch_device = torch.cuda.current_device()
                total_memory = torch.cuda.get_device_properties(torch_device).total_memory
                allocated_memory = torch.cuda.memory_allocated(torch_device)
                available_memory = total_memory - allocated_memory
                if available_memory >= required_memory_bytes:
                    print(f"Sufficient GPU memory available (PyTorch): {available_memory / 1e9:.2f} GB")
                    return 1
                else:
                    print(f"Waiting for GPU memory (PyTorch). Available: {available_memory / 1e9:.2f} GB, Required: {required_memory_gb:.2f} GB")
            else:
                print("CUDA is not available")
        time.sleep(check_interval)