import re
from typing import Any, Optional

from smolagents.tools import Tool


class FinalAnswerTool(Tool):
    """Tool that normalises an agent's final answer into a short plain string.

    Structured answers (dicts, lists, tuples) are flattened into a single
    value or a comma-separated list, then the text is cleaned according to
    the GAIA-style rules implemented in ``_clean_gaia_format``.

    Examples of the resulting behaviour::

        {'answer': [1, 2, 3]}        -> "1, 2, 3"
        "FINAL ANSWER: $1,234.50"    -> "1234.50"
        '"the Eiffel Tower"'         -> "Eiffel Tower"
    """

    name = "final_answer"
    description = "Provides a final answer to the given problem in GAIA benchmark format."
    inputs = {'answer': {'type': 'any', 'description': 'The final answer to the problem'}}
    output_type = "any"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``Tool`` and flag the tool as initialised."""
        super().__init__(*args, **kwargs)
        # NOTE(review): presumably tells the base class that no lazy setup is
        # still pending -- confirm against the smolagents ``Tool`` contract.
        self.is_initialized = True

    def forward(self, answer: Any) -> str:
        """Return ``answer`` as a cleaned string.

        Args:
            answer: The raw answer. May be a scalar, a string, or an
                arbitrarily nested dict/list/tuple.

        Returns:
            The flattened, stripped and cleaned textual answer.
        """
        flattened = self._flatten(answer)
        return self._clean_gaia_format(str(flattened).strip())

    def _flatten(self, value: Any, depth: int = 0) -> Any:
        """Reduce nested containers to a scalar or a comma-separated string.

        Dict handling, in priority order: a single-entry dict yields its only
        value; otherwise the first present key among ``answer``, ``result``
        and ``value`` is used; otherwise all non-``None`` values are joined
        with ``", "``. Lists and tuples are joined with ``", "``, skipping
        ``None`` items. Anything else is returned unchanged.

        Args:
            value: The object to flatten.
            depth: Current recursion depth (internal use).

        Returns:
            A non-container value, or a string built from the container.
        """
        # Stop descending on pathological (e.g. self-referential) containers;
        # the caller's str() copes with those without recursing forever.
        if depth > 20:
            return value

        if isinstance(value, dict):
            if len(value) == 1:
                return self._flatten(next(iter(value.values())), depth + 1)
            for key in ('answer', 'result', 'value'):
                if key in value:
                    return self._flatten(value[key], depth + 1)
            return ", ".join(
                str(self._flatten(item, depth + 1))
                for item in value.values()
                if item is not None
            )

        if isinstance(value, (list, tuple)):
            return ", ".join(
                str(self._flatten(item, depth + 1))
                for item in value
                if item is not None
            )

        return value

    def _clean_gaia_format(self, text: str) -> str:
        """Apply the GAIA-style clean-up rules to ``text``.

        Steps, in order:
            1. Drop a leading ``FINAL ANSWER:`` prefix (case-insensitive).
            2. Remove one pair of matching quotes wrapping the whole text.
            3. Drop a leading article (``a``, ``an``, ``the``).
            4. If the text has no alphabetic characters, remove ``$`` and ``%``.
            5. If the text is a single number with thousands separators
               (optionally signed, optional decimals), remove the commas.

        Args:
            text: The answer text to clean.

        Returns:
            The cleaned text with surrounding whitespace removed.
        """
        # Strip first so the ^-anchored patterns work on padded input too.
        text = re.sub(
            r'^(FINAL\s*ANSWER\s*:\s*)', '', text.strip(), flags=re.IGNORECASE
        ).strip()

        # Require two characters so a lone quote mark is not mistaken for a
        # wrapped (empty) answer; re-strip so padding inside the quotes does
        # not hide a leading article from the next rule.
        if len(text) >= 2 and text[0] == text[-1] and text[0] in ('"', "'"):
            text = text[1:-1].strip()

        text = re.sub(r'^(a|an|the)\s+', '', text, flags=re.IGNORECASE)

        # Be conservative: only purely non-alphabetic answers lose their
        # currency/percent symbols, since letters suggest the unit matters.
        if not any(char.isalpha() for char in text):
            text = text.replace('$', '').replace('%', '').strip()

        # Spaces after commas mean a list, which must keep its separators;
        # only a single grouped number (e.g. "-1,234,567.89") matches here.
        if re.match(r'^[+-]?\d{1,3}(,\d{3})+(\.\d+)?$', text):
            text = text.replace(',', '')

        return text.strip()