File size: 2,861 Bytes
133fe7e
def985a
64a3746
133fe7e
def985a
 
64a3746
def985a
 
 
64a3746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def985a
 
64a3746
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Any, Optional
from smolagents.tools import Tool
import re

class FinalAnswerTool(Tool):
    """Return the agent's final answer as a string cleaned up for GAIA scoring.

    ``forward`` flattens structured values (dicts, lists, tuples) into plain
    text. It then strips common decorations:

    - a leading "FINAL ANSWER:" label;
    - quotes wrapping the whole answer;
    - a leading article (a/an/the);
    - ``$``/``%`` on answers that contain no letters;
    - thousands separators in a single (optionally signed) number.
    """

    name = "final_answer"
    description = "Provides a final answer to the given problem in GAIA benchmark format."
    inputs = {'answer': {'type': 'any', 'description': 'The final answer to the problem'}}
    output_type = "any"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Mark the tool as ready immediately. Presumably the smolagents base
        # class otherwise runs a lazy setup step on first call; confirm against
        # the installed smolagents version.
        self.is_initialized = True

    def forward(self, answer: Any) -> str:
        """Normalize *answer* into a GAIA-formatted string.

        Args:
            answer: The final answer produced by the agent.
                - A dict is reduced to its only value. Failing that, to its
                  ``answer``/``result``/``value`` entry (first match wins).
                  Failing that, to all non-None values joined with ", ".
                  The extracted value is normalized again, so a nested list
                  or dict is flattened too.
                - Lists and tuples are joined with ", " (None items skipped).
                - Anything else is converted with ``str``.

        Returns:
            The cleaned answer string.
        """
        return self._clean_gaia_format(self._to_text(answer).strip())

    def _to_text(self, value: Any) -> str:
        """Flatten a possibly structured answer into a single string."""
        if isinstance(value, dict):
            if len(value) == 1:
                # The extracted value may itself be a list or dict. Normalize
                # it again rather than falling back to its Python repr.
                return self._to_text(next(iter(value.values())))
            for key in ('answer', 'result', 'value'):
                if key in value:
                    return self._to_text(value[key])
            # No recognised key: join every non-None value.
            return ", ".join(str(v) for v in value.values() if v is not None)
        if isinstance(value, (list, tuple)):
            return ", ".join(str(item) for item in value if item is not None)
        return str(value)

    def _clean_gaia_format(self, text: str) -> str:
        """Apply GAIA formatting rules to an already-stringified answer."""
        # Drop a leading "FINAL ANSWER:" label the model may have echoed.
        text = re.sub(r'^(FINAL\s*ANSWER\s*:\s*)', '', text, flags=re.IGNORECASE).strip()

        # Unwrap matching quotes around the whole answer. The length check
        # keeps a lone quote character from collapsing to "". Re-strip so the
        # article rule below sees the first real word.
        if len(text) >= 2 and text[0] == text[-1] and text[0] in ('"', "'"):
            text = text[1:-1].strip()

        # Remove a leading article (a, an, the).
        text = re.sub(r'^(a|an|the)\s+', '', text, flags=re.IGNORECASE)

        # Be conservative with unit symbols: drop only '$' and '%', and only
        # when nothing alphabetic remains (i.e. the answer looks numeric).
        # Strip afterwards so "$ 1,234" still reaches the separator rule.
        if not any(ch.isalpha() for ch in text.replace('$', '').replace('%', '')):
            text = text.replace('$', '').replace('%', '').strip()

        # Remove thousands separators from a single, optionally signed number.
        # Comma-separated lists such as "1, 2, 3" do not match and are kept.
        if re.fullmatch(r'[-+]?\d{1,3}(,\d{3})+(\.\d+)?', text):
            text = text.replace(',', '')

        return text.strip()