File size: 2,861 Bytes
133fe7e def985a 64a3746 133fe7e def985a 64a3746 def985a 64a3746 def985a 64a3746 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from typing import Any, Optional
from smolagents.tools import Tool
import re
class FinalAnswerTool(Tool):
    """Tool that submits the agent's final answer, normalized to GAIA format.

    ``forward`` flattens structured answers (dicts, lists, tuples, nested
    arbitrarily) into a plain string and then applies the textual cleanup
    rules implemented in ``_clean_gaia_format``.
    """

    name = "final_answer"
    description = "Provides a final answer to the given problem in GAIA benchmark format."
    inputs = {'answer': {'type': 'any', 'description': 'The final answer to the problem'}}
    output_type = "any"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Mark the tool as ready immediately. NOTE(review): in smolagents this
        # presumably bypasses the lazy ``setup()`` hook; this tool has no setup.
        self.is_initialized = True

    def forward(self, answer: Any) -> str:
        """Normalize ``answer`` into a GAIA-style answer string.

        Args:
            answer: The agent's answer. A dict is reduced to its only value if
                it has exactly one entry, otherwise to its ``'answer'``,
                ``'result'`` or ``'value'`` entry (first match wins), otherwise
                to all of its non-None values. Lists and tuples become a
                ", "-separated string. Nested containers are flattened
                recursively; any other object is converted with ``str()``.

        Returns:
            The cleaned answer string.
        """
        answer_str = self._to_text(answer).strip()
        return self._clean_gaia_format(answer_str)

    def _to_text(self, value: Any) -> str:
        """Recursively flatten ``value`` into a single comma-separated string."""
        if isinstance(value, dict):
            if len(value) == 1:
                return self._to_text(next(iter(value.values())))
            # Prefer a conventionally named entry over joining everything.
            for key in ('answer', 'result', 'value'):
                if key in value:
                    return self._to_text(value[key])
            return self._join(value.values())
        if isinstance(value, (list, tuple)):
            return self._join(value)
        return str(value)

    def _join(self, items) -> str:
        """Flatten each non-None item and join the non-empty results with ", "."""
        parts = (self._to_text(item) for item in items if item is not None)
        # Skip parts that flattened to "" so nested empties don't leave ", , ".
        return ", ".join(part for part in parts if part)

    def _clean_gaia_format(self, text: str) -> str:
        """Apply GAIA benchmark formatting rules to a stringified answer.

        Steps, in order:
          1. drop a leading "FINAL ANSWER:" label (case-insensitive);
          2. unwrap one pair of matching quotes surrounding the whole answer;
          3. drop a leading article ("a", "an", "the");
          4. remove "$" and "%" when the answer contains no letters;
          5. remove thousands separators from a single, optionally signed,
             number such as "-1,234.5".

        Whitespace is re-stripped after every step so that spaces exposed by
        one rule (e.g. after unquoting or removing "$") do not stop a later
        anchored rule from matching.
        """
        text = re.sub(r'^(FINAL\s*ANSWER\s*:\s*)', '', text.strip(), flags=re.IGNORECASE).strip()

        # Require two characters so a lone quote is not mistaken for a pair.
        if len(text) >= 2 and text[0] == text[-1] and text[0] in ('"', "'"):
            text = text[1:-1].strip()

        text = re.sub(r'^(a|an|the)\s+', '', text, flags=re.IGNORECASE).strip()

        # Conservative unit removal: only strip currency/percent symbols when
        # nothing alphabetic remains, since with letters they may be meaningful.
        if not any(char.isalpha() for char in text):
            text = text.replace('$', '').replace('%', '').strip()

        # Only a single well-formed number loses its commas; list answers are
        # joined with ", " and therefore never match this pattern.
        if re.fullmatch(r'[-+]?\d{1,3}(,\d{3})+(\.\d+)?', text):
            text = text.replace(',', '')

        return text.strip()
|