AttnTrace / src /load_dataset.py
SecureLLMSys's picture
init
f214f36
'''
Easily process & load LongBench, PoisonedRAG and NeedleInHaystack datasets.
'''
from src.utils import load_json
from datasets import load_dataset
import random
import json
from src.utils import contexts_to_sentences
def load_poison(dataset_name='nq-poison', retriever='contriever', top_k=5, num_poison=5):
    '''Load a PoisonedRAG result file and truncate every record's retrieval lists.

    The file ``datasets/PoisonedRAG/{dataset_name}-{retriever}-{num_poison}.json``
    is read with ``load_json``. Its i-th element is indexed with the key
    ``'iter_{i}'`` to obtain that iteration's records, and all records are
    flattened into a single list in file order.

    Args:
        dataset_name: PoisonedRAG dataset identifier (part of the file name).
        retriever: retriever name (part of the file name).
        top_k: number of leading entries kept in each record's
            ``'topk_contents'`` and ``'topk_results'`` lists.
        num_poison: poison count (part of the file name).

    Returns:
        The flattened list of records. Records are truncated in place, i.e. the
        loaded objects themselves are modified rather than copied.
    '''
    path = f"datasets/PoisonedRAG/{dataset_name}-{retriever}-{num_poison}.json"
    raw_iterations = load_json(path)
    # Each iteration stores its records under a key that encodes its own index.
    records = [
        record
        for idx, iteration in enumerate(raw_iterations)
        for record in iteration[f'iter_{idx}']
    ]
    for record in records:
        record['topk_contents'] = record['topk_contents'][:top_k]
        record['topk_results'] = record['topk_results'][:top_k]
    print("Processed result size: ", len(records))
    return records
def insert_needle(dataset_name, haystack, needles, context_length, inject_times=3):
    '''Build a haystack text with needle strings inserted at random sentence positions.

    The haystack documents are joined with newlines and truncated to the first
    ``context_length`` space-separated pieces. This is a ``str.split(' ')``
    count, not a model-token count. The result is split into sentences via
    ``contexts_to_sentences``, and every needle is then inserted at random
    positions in that sentence list.

    Args:
        dataset_name: ``"mrt"`` inserts each needle exactly once; any other
            value (e.g. ``"srt"``) inserts each needle ``inject_times`` times.
        haystack: list of document strings; the caller's list is not modified.
        needles: iterable of strings to insert.
        context_length: number of space-separated pieces kept from the haystack.
        inject_times: copies of each needle for non-``"mrt"`` datasets.

    Returns:
        The sentence list (original plus inserted needles) concatenated with
        ``''.join``. NOTE(review): no separator is added between sentences, so
        spacing depends on what ``contexts_to_sentences`` returns. Confirm.
    '''
    text = '\n'.join(haystack)
    text = ' '.join(text.split(' ')[:context_length])
    # Copy so in-place insertion never touches an object owned by the helper.
    sentences = list(contexts_to_sentences([text]))
    num_sentences = len(sentences)
    # The number of copies depends only on the dataset, so decide it once.
    copies = 1 if dataset_name == "mrt" else inject_times
    for needle in needles:
        for _ in range(copies):
            # NOTE(review): the upper bound is the ORIGINAL sentence count, not
            # the current (growing) list length. Once anything has been
            # inserted, a needle can no longer land at the very end of the
            # list. This is kept unchanged so results under a fixed random seed
            # stay reproducible.
            position = random.randint(0, num_sentences)
            sentences.insert(position, needle)
    return ''.join(sentences)
def load_needle(dataset_name, context_length, inject_times=3):
    '''Load a needle-in-a-haystack dataset and build each sample's haystack text.

    The haystack is the ``'text'`` field of the first 20 lines of
    ``datasets/NeedleInHaystack/PaulGrahamEssays.jsonl``. An ``IndexError`` is
    raised if the file has fewer than 20 lines. Samples come from
    ``datasets/NeedleInHaystack/subjective_{dataset_name}.json``. Each sample
    must provide ``'needles'`` and gains a ``'needle_in_haystack'`` entry
    produced by ``insert_needle``.

    Args:
        dataset_name: dataset identifier (``"srt"`` or ``"mrt"`` in the
            dispatcher); selects the sample file and the insertion mode.
        context_length: number of space-separated pieces kept from the haystack.
        inject_times: needle copies per sample (ignored by ``insert_needle``
            for ``"mrt"``).

    Returns:
        The loaded samples, each augmented in place with ``'needle_in_haystack'``.
    '''
    haystack_path = "datasets/NeedleInHaystack/PaulGrahamEssays.jsonl"
    # JSON text is UTF-8; an explicit encoding avoids depending on the
    # platform's locale default (e.g. cp1252 on Windows).
    with open(haystack_path, 'r', encoding='utf-8') as file:
        essays = [json.loads(line) for line in file]
    haystack = [essays[i]['text'] for i in range(20)]
    dataset = load_json(f"datasets/NeedleInHaystack/subjective_{dataset_name}.json")
    for data in dataset:
        data['needle_in_haystack'] = insert_needle(
            dataset_name, haystack, data['needles'], context_length,
            inject_times=inject_times,
        )
    return dataset
def _load_dataset(dataset_name='nq-poison', retriever='contriever', retrieval_k=5, **kwargs):
    '''Dispatch to the loader that handles ``dataset_name``.

    Supported names:
      * LongBench: ``narrativeqa``, ``musique``, ``qmsum``. These load the
        ``test`` split of ``THUDM/LongBench`` via ``datasets.load_dataset``.
      * PoisonedRAG: ``nq-poison`` (and its ``-combinatorial``,
        ``-insufficient``, ``-correctness``, ``-hotflip``, ``-safety``
        variants), ``hotpotqa-poison``, ``msmarco-poison``. ``retriever`` and
        ``retrieval_k`` are forwarded to ``load_poison``.
      * Needle-in-a-haystack: ``srt``, ``mrt``.

    Keyword Args:
        num_poison (default 5): poison count for PoisonedRAG files; also reused
            as ``inject_times`` for the needle datasets.
        context_length (default 10000): haystack length for the needle datasets.

    Raises:
        NotImplementedError: if ``dataset_name`` is not one of the names above.
    '''
    longbench_names = ("narrativeqa", "musique", "qmsum")
    poisoned_rag_names = (
        'nq-poison', 'hotpotqa-poison', 'msmarco-poison',
        'nq-poison-combinatorial', 'nq-poison-insufficient',
        'nq-poison-correctness', 'nq-poison-hotflip', 'nq-poison-safety',
    )
    needle_names = ('srt', 'mrt')

    num_poison = kwargs.get('num_poison', 5)
    print("Load dataset: ", dataset_name)
    if dataset_name in longbench_names:
        print("datset_name: ", dataset_name)
        dataset = load_dataset('THUDM/LongBench', dataset_name, split='test')
    elif dataset_name in poisoned_rag_names:
        dataset = load_poison(dataset_name, retriever, retrieval_k, num_poison=num_poison)
    elif dataset_name in needle_names:
        context_length = kwargs.get('context_length', 10000)
        # num_poison doubles as the number of needle insertions here.
        dataset = load_needle(dataset_name, context_length, inject_times=num_poison)
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")
    return dataset