from .base_benchmark import BaseBenchmark
from typing import Dict, Any, Optional, Tuple
from datasets import load_dataset
import os
import random
from .evaluation_utils import extract_answer_mmlu

class GPQABenchmark(BaseBenchmark):
    """GPQA (Graduate-Level Google-Proof Q&A) benchmark"""
    
    def __init__(self):
        super().__init__(name="GPQA", dataset_name="Idavidrein/gpqa")
        
    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load GPQA dataset"""
        # GPQA has different subsets: gpqa_main, gpqa_diamond, gpqa_extended
        subset = kwargs.get('subset', 'gpqa_main')
        
        try:
            # GPQA is gated on the Hub, so pass an HF token when one is set
            hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGING_FACE_HUB_TOKEN')
            if hf_token:
                dataset = load_dataset(self.dataset_name, subset, split='train', token=hf_token)
            else:
                dataset = load_dataset(self.dataset_name, subset, split='train')
        except Exception as e:
            if "gated dataset" in str(e) or "authentication" in str(e).lower():
                raise Exception(
                    "GPQA dataset requires authentication. Please:\n"
                    "1. Request access at https://huggingface.co/datasets/Idavidrein/gpqa\n"
                    "2. Set the HF_TOKEN environment variable\n"
                    f"Original error: {e}"
                ) from e
            # Fall back to gpqa_main in case the requested subset does not exist
            try:
                dataset = load_dataset(self.dataset_name, 'gpqa_main', split='train', token=hf_token)
            except Exception:
                raise e
        
        self.dataset = []
        for sample in dataset:
            # GPQA has these fields: Question, Correct Answer, Incorrect Answer 1-3
            choices = [
                sample.get('Correct Answer', ''),
                sample.get('Incorrect Answer 1', ''),
                sample.get('Incorrect Answer 2', ''),
                sample.get('Incorrect Answer 3', '')
            ]
            
            # Shuffle the choices and track where the correct answer lands.
            # `indices` is a random permutation of [0, 1, 2, 3]; the correct
            # answer starts at index 0 in `choices`, so its shuffled position
            # is wherever the value 0 ended up in the permutation.
            indices = list(range(4))
            random.shuffle(indices)
            shuffled_choices = [choices[i] for i in indices]
            correct_index = indices.index(0)
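            # Worked example: if indices == [2, 0, 3, 1], then
            # shuffled_choices == [choices[2], choices[0], choices[3], choices[1]],
            # so the correct answer sits at position 1 and
            # indices.index(0) == 1, i.e. it is presented as choice (B).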
            
            self.dataset.append({
                'question': sample['Question'],
                'choices': shuffled_choices,
                'correct_index': correct_index,
                'subject': sample.get('Subdomain', 'Unknown'),
                'raw_sample': sample
            })
        
        # Shuffle so that truncating to sample_size yields a random subset
        random.shuffle(self.dataset)
        
        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]
    
    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format GPQA question as prompt matching official format"""
        question = sample['question']
        choices = sample['choices']
        
        # Match the simple zero-shot format used by lm-evaluation-harness for GPQA
        prompt = f"""What is the correct answer to this question: {question}

Choices:
(A) {choices[0]}
(B) {choices[1]}
(C) {choices[2]}
(D) {choices[3]}

Answer:"""
        return prompt
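    # For illustration, a rendered prompt looks like:
    #
    #   What is the correct answer to this question: <question text>
    #
    #   Choices:
    #   (A) <choice 0>
    #   (B) <choice 1>
    #   (C) <choice 2>
    #   (D) <choice 3>
    #
    #   Answer: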
    
    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single GPQA sample"""
        prompt = self.format_prompt(sample)
        
        try:
            response = await api.generate_with_retry(prompt, **kwargs)
            
            # Extract the answer letter with the shared MMLU-style extractor
            predicted_letter = extract_answer_mmlu(response)
            
            if predicted_letter:
                # Map the letter 'A'-'D' onto choice index 0-3
                predicted_index = ord(predicted_letter) - ord('A')
            else:
                # No parseable answer in the response; count it as incorrect
                predicted_index = -1
            
            correct_index = sample['correct_index']
            is_correct = predicted_index == correct_index
            
            result = {
                'question': sample['question'],
                'choices': sample['choices'],
                'correct_answer': correct_index,
                'predicted_answer': predicted_index,
                'model_response': response,
                'is_correct': is_correct,
                'subject': sample['subject']
            }
            
            return is_correct, result
            
        except Exception as e:
            result = {
                'question': sample['question'],
                'error': str(e),
                'is_correct': False
            }
            return False, result
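

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the benchmark). It
# assumes a hypothetical API client exposing an async
# `generate_with_retry(prompt, **kwargs)` coroutine, which is the only call
# this benchmark makes; substitute any real client with that interface. Note
# that `load_dataset` still requires access to the gated GPQA dataset, and the
# relative imports mean this file must be run as a module inside its package.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    class StubAPI:
        """Hypothetical client that always answers A; replace with a real one."""
        async def generate_with_retry(self, prompt: str, **kwargs) -> str:
            return "Answer: A"

    async def main():
        benchmark = GPQABenchmark()
        await benchmark.load_dataset(sample_size=5, subset='gpqa_diamond')
        api = StubAPI()
        correct = 0
        for sample in benchmark.dataset:
            is_correct, _ = await benchmark.evaluate_sample(api, sample)
            correct += int(is_correct)
        print(f"Accuracy: {correct}/{len(benchmark.dataset)}")

    asyncio.run(main())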