"""
Easily process & load LongBench, PoisonedRAG and NeedleInHaystack datasets.
"""
from src.utils import load_json | |
from datasets import load_dataset | |
import random | |
import json | |
from src.utils import contexts_to_sentences | |
def load_poison(dataset_name='nq-poison', retriever='contriever', top_k=5, num_poison=5):
    """Load a PoisonedRAG attack-result file and truncate retrieval lists.

    Args:
        dataset_name: PoisonedRAG dataset variant (e.g. 'nq-poison').
        retriever: retriever whose results were saved (part of the filename).
        top_k: keep only the first top_k retrieved contexts/results per query.
        num_poison: number of poisoned passages encoded in the result filename.

    Returns:
        A flat list of per-query result dicts whose 'topk_contents' and
        'topk_results' are truncated to top_k entries.
    """
    result_path = f"datasets/PoisonedRAG/{dataset_name}-{retriever}-{num_poison}.json"
    results_list = load_json(result_path)
    processed_results = []
    # Each element i of results_list stores its queries under the key f'iter_{i}'.
    # Loop index renamed from `iter` to avoid shadowing the builtin.
    for idx, iteration_result in enumerate(results_list):
        processed_results.extend(iteration_result[f'iter_{idx}'])
    for result in processed_results:
        result['topk_contents'] = result['topk_contents'][:top_k]
        result['topk_results'] = result['topk_results'][:top_k]
    print("Processed result size: ", len(processed_results))
    return processed_results
def insert_needle(dataset_name, haystack, needles, context_length, inject_times=3):
    """Insert needle strings at random sentence positions inside a haystack.

    Args:
        dataset_name: 'srt' keeps inject_times as given; 'mrt' forces one
            injection per needle. Other names use inject_times unchanged.
        haystack: list of text passages joined to form the background text.
        needles: strings to inject.
        context_length: truncate the joined haystack to this many words.
        inject_times: injections per needle (overridden to 1 for 'mrt').

    Returns:
        The haystack text with needles inserted.
    """
    haystack = '\n'.join(haystack)
    # Truncate the background text to roughly context_length words.
    haystack = ' '.join(haystack.split(' ')[:context_length])
    haystack_sentences = contexts_to_sentences([haystack])
    # Positions are drawn against the ORIGINAL sentence count, not the grown
    # list, matching the previous behavior.
    num_sentences = len(haystack_sentences)
    # Hoisted out of the needle loop: dataset_name is constant per call, and
    # the 'srt' branch was a no-op (inject_times = inject_times).
    if dataset_name == "mrt":
        inject_times = 1
    for needle in needles:
        for _ in range(inject_times):
            # Was random.randint(int(num_sentences*0), num_sentences):
            # the lower bound is always 0.
            position = random.randint(0, num_sentences)
            haystack_sentences = (
                haystack_sentences[:position] + [needle] + haystack_sentences[position:]
            )
    # NOTE(review): sentences are joined with no separator — presumably
    # contexts_to_sentences preserves trailing whitespace; confirm.
    return ''.join(haystack_sentences)
def load_needle(dataset_name, context_length, inject_times=3, num_essays=20):
    """Build NeedleInHaystack examples by injecting needles into essay text.

    Args:
        dataset_name: 'srt' or 'mrt'; selects the subjective_{name}.json file
            and the injection policy used by insert_needle.
        context_length: word budget for the background text.
        inject_times: injections per needle (see insert_needle).
        num_essays: how many Paul Graham essays to use as background
            (previously hard-coded to 20).

    Returns:
        The dataset list with a 'needle_in_haystack' field added to each item.
    """
    haystack_path = "datasets/NeedleInHaystack/PaulGrahamEssays.jsonl"
    # Each JSONL line is one essay record with a 'text' field.
    with open(haystack_path, 'r') as file:
        records = [json.loads(line) for line in file]
    haystack = [record['text'] for record in records[:num_essays]]
    dataset = load_json(f"datasets/NeedleInHaystack/subjective_{dataset_name}.json")
    for data in dataset:
        data['needle_in_haystack'] = insert_needle(
            dataset_name, haystack, data['needles'], context_length,
            inject_times=inject_times,
        )
    return dataset
def _load_dataset(dataset_name='nq-poison', retriever='contriever', retrieval_k=5, **kwargs):
    """Dispatch to the right loader for LongBench / PoisonedRAG / NeedleInHaystack.

    Args:
        dataset_name: dataset identifier; decides which loader runs.
        retriever: retriever name (PoisonedRAG datasets only).
        retrieval_k: number of retrieved contexts kept (PoisonedRAG only).
        **kwargs: 'num_poison' (default 5; also reused as inject_times for
            srt/mrt) and 'context_length' (default 10000, srt/mrt only).

    Returns:
        The loaded dataset (HF Dataset for LongBench, list of dicts otherwise).

    Raises:
        NotImplementedError: for an unrecognized dataset_name.
    """
    num_poison = kwargs.get('num_poison', 5)
    print("Load dataset: ", dataset_name)
    if dataset_name in ["narrativeqa", "musique", "qmsum"]:
        print("dataset_name: ", dataset_name)  # fixed typo: was "datset_name"
        dataset = load_dataset('THUDM/LongBench', dataset_name, split='test')
    elif dataset_name in ['nq-poison', 'hotpotqa-poison', 'msmarco-poison',
                          'nq-poison-combinatorial', 'nq-poison-insufficient',
                          'nq-poison-correctness', 'nq-poison-hotflip',
                          'nq-poison-safety']:
        dataset = load_poison(dataset_name, retriever, retrieval_k, num_poison=num_poison)
    elif dataset_name in ['srt', 'mrt']:
        context_length = kwargs.get('context_length', 10000)
        # num_poison doubles as the injection count for needle datasets.
        dataset = load_needle(dataset_name, context_length, inject_times=num_poison)
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")
    return dataset