File size: 3,359 Bytes
f214f36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
'''
    Easily process & load LongBench, PoisonedRAG and NeedleInHaystack datasets.
'''
from src.utils import load_json
from datasets import load_dataset
import random
import json
from src.utils import contexts_to_sentences
def load_poison(dataset_name='nq-poison',retriever = 'contriever',top_k =5, num_poison = 5):
    """Load a PoisonedRAG attack-result file and flatten it into one list.

    The JSON file holds a list of per-iteration dicts, where iteration ``i``
    stores its queries under the key ``iter_{i}``. Each flattened entry has
    its ``topk_contents`` / ``topk_results`` lists truncated to ``top_k``.

    Args:
        dataset_name: PoisonedRAG variant, e.g. ``'nq-poison'``.
        retriever: retriever name embedded in the result filename.
        top_k: number of retrieved passages to keep per query.
        num_poison: poison count embedded in the result filename.

    Returns:
        A flat list of per-query result dicts.
    """
    path = f"datasets/PoisonedRAG/{dataset_name}-{retriever}-{num_poison}.json"
    raw_iterations = load_json(path)

    # Flatten: iteration i keeps its payload under the key 'iter_i'.
    flattened = []
    for idx, iteration in enumerate(raw_iterations):
        flattened.extend(iteration[f'iter_{idx}'])

    # Keep only the top_k retrieved items for every query.
    for entry in flattened:
        entry['topk_contents'] = entry['topk_contents'][:top_k]
        entry['topk_results'] = entry['topk_results'][:top_k]

    print("Processed result size: ", len(flattened))
    return flattened

    
def insert_needle(dataset_name, haystack, needles, context_length, inject_times=3):
    """Insert each needle at random sentence positions inside a haystack.

    The haystack texts are joined, truncated to ``context_length``
    whitespace-separated tokens, and split into sentences; each needle is
    then spliced in at uniformly random sentence boundaries.

    Args:
        dataset_name: ``'srt'`` repeats every needle ``inject_times`` times;
            ``'mrt'`` inserts each needle exactly once. Any other name falls
            through to ``inject_times`` (matches historical behavior).
        haystack: list of essay/context strings.
        needles: strings to inject.
        context_length: token budget (split on single spaces) for the haystack.
        inject_times: repetitions per needle for the ``'srt'`` case.

    Returns:
        The haystack with needles inserted, re-joined without separators
        (``contexts_to_sentences`` is assumed to keep sentence-ending
        whitespace — TODO confirm).
    """
    text = '\n'.join(haystack)
    text = ' '.join(text.split(' ')[:context_length])
    sentences = contexts_to_sentences([text])
    num_sentences = len(sentences)

    # The repeat count only depends on dataset_name, so hoist it out of the
    # per-needle loop (the original re-evaluated it for every needle and
    # contained a no-op `inject_times = inject_times` for 'srt').
    repeats = 1 if dataset_name == "mrt" else inject_times

    for needle in needles:
        for _ in range(repeats):
            # Random insertion point over the ORIGINAL sentence count,
            # inclusive of both ends (randint is inclusive).
            position = random.randint(0, num_sentences)
            sentences = sentences[:position] + [needle] + sentences[position:]

    return ''.join(sentences)

def load_needle(dataset_name, context_length, inject_times=3, num_essays=20):
    """Build the NeedleInHaystack dataset with needles injected.

    Loads Paul Graham essays as haystack material, then augments every
    record of ``subjective_{dataset_name}.json`` with a
    ``'needle_in_haystack'`` field produced by :func:`insert_needle`.

    Args:
        dataset_name: ``'srt'`` or ``'mrt'`` (selects the subjective file
            and the injection policy).
        context_length: token budget forwarded to ``insert_needle``.
        inject_times: repetitions per needle forwarded to ``insert_needle``.
        num_essays: how many essays to use as haystack (default 20,
            matching the previous hard-coded value).

    Returns:
        The dataset list, mutated in place with the injected contexts.
    """
    haystack_path = "datasets/NeedleInHaystack/PaulGrahamEssays.jsonl"

    # JSONL: one JSON object per line.
    with open(haystack_path, 'r') as file:
        essays = [json.loads(line) for line in file]

    # Slice (rather than index a fixed range) so a shorter file does not
    # raise IndexError.
    haystack = [essay['text'] for essay in essays[:num_essays]]

    dataset = load_json(f"datasets/NeedleInHaystack/subjective_{dataset_name}.json")
    for data in dataset:
        data['needle_in_haystack'] = insert_needle(
            dataset_name, haystack, data['needles'], context_length,
            inject_times=inject_times)
    return dataset

def _load_dataset(dataset_name='nq-poison', retriever='contriever', retrieval_k=5, **kwargs):
    num_poison = kwargs.get('num_poison', 5)
    print("Load dataset: ",dataset_name)
    if dataset_name in ["narrativeqa","musique","qmsum"]:
        print("datset_name: ",dataset_name)
        dataset = load_dataset('THUDM/LongBench', dataset_name, split='test')
    elif dataset_name in ['nq-poison', 'hotpotqa-poison', 'msmarco-poison','nq-poison-combinatorial','nq-poison-insufficient','nq-poison-correctness','nq-poison-hotflip','nq-poison-safety']:
        dataset = load_poison(dataset_name, retriever, retrieval_k,num_poison = num_poison)
    elif dataset_name in ['srt','mrt']:
        context_length = kwargs.get('context_length', 10000)
        dataset = load_needle(dataset_name,context_length,inject_times=num_poison)
    else: 
        raise NotImplementedError
    return dataset