File size: 4,907 Bytes
460ec88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
GAIA Answer Formatter

This module handles reformatting of agent responses to meet the GAIA benchmark requirements.
It removes prefixes, citations, and other metadata to provide direct, concise answers.
"""

import re
import logging
from typing import Optional, Union, Dict, Any, List

# Set up logging
logger = logging.getLogger("gaia_agent.answer_formatter")

def format_answer(answer: str) -> str:
    """
    Format an answer according to GAIA benchmark requirements.

    Removes:
    - "Based on my search..." prefixes
    - "Here's what I found about '...'" prefixes
    - Question repetition ("Regarding '...'", "About '...'", "You asked about '...'", ...)
    - Trailing citation sections ("This information comes from...",
      "This information is compiled from multiple sources...", "Source:...")
    - Trailing "Additionally:" metadata sections
    - Inline numeric citation markers such as "[1]"

    Args:
        answer: The original answer text. A falsy value (e.g. "" or None)
            yields an empty string.

    Returns:
        str: A clean, direct answer with surrounding whitespace stripped.
    """
    if not answer:
        return ""

    # Log original answer for debugging (lazy %-args: only rendered at DEBUG level)
    logger.debug("Original answer: %s", answer)

    # The prefix patterns below are anchored with ^, so leading whitespace
    # would stop them from matching; drop it before anything else.
    answer = answer.lstrip()

    # Remove "Based on my search" prefixes
    answer = re.sub(r'^Based on my search[,:]?\s*', '', answer, flags=re.IGNORECASE)

    # Remove "Here's what I found about..." prefixes
    answer = re.sub(r'^Here\'s what I found about [\'"].*?[\'"]:?\s*', '', answer, flags=re.IGNORECASE)

    # Remove question repetition patterns
    answer = re.sub(r'^(Regarding|About|Concerning|On) [\'"].*?[\'"]:?\s*', '', answer, flags=re.IGNORECASE)
    answer = re.sub(r'^You asked about [\'"].*?[\'"]:?\s*', '', answer, flags=re.IGNORECASE)

    # Remove citation information: each section starts after a blank line and
    # runs to the end of the text (DOTALL lets .* cross newlines).
    answer = re.sub(r'\n\nThis information comes from.*$', '', answer, flags=re.DOTALL)
    answer = re.sub(r'\n\nThis information is compiled from multiple sources.*$', '', answer, flags=re.DOTALL)
    answer = re.sub(r'\n\nSource:.*$', '', answer, flags=re.DOTALL)

    # Remove additional metadata sections
    answer = re.sub(r'\n\nAdditionally:.*$', '', answer, flags=re.DOTALL)

    # Remove inline citation markers together with the spaces/tabs before
    # them, so "Paris [1] is" becomes "Paris is" rather than "Paris  is".
    answer = re.sub(r'[ \t]*\[\d+\]', '', answer)

    # Trim whitespace
    answer = answer.strip()

    # Log the formatted answer
    logger.debug("Formatted answer: %s", answer)

    return answer

def format_numerical_answer(answer: str) -> str:
    """
    Format a numerical answer to extract just the number.

    The answer is first cleaned with ``format_answer``. The first number in
    the cleaned text is then returned. Comma thousands separators and a
    decimal part are kept as written. A leading minus sign is kept when it
    is not attached to a preceding word character.

    Examples of the extracted value:
    - "-5" -> "-5"
    - "is -3.5 degrees" -> "-3.5"
    - "COVID-19" -> "19"

    Args:
        answer: The original answer text

    Returns:
        str: Just the numerical value if one can be extracted, otherwise the formatted answer
    """
    # First apply general formatting
    cleaned_answer = format_answer(answer)

    # The optional sign group only accepts "-" when it is not preceded by a
    # word character, so hyphenated tokens ("COVID-19", "2-3") are not read
    # as negative numbers.
    numerical_match = re.search(r'(?:(?<!\w)-)?\d+(?:,\d+)*(?:\.\d+)?', cleaned_answer)
    if numerical_match:
        return numerical_match.group(0)

    return cleaned_answer

def format_list_answer(answer: str) -> str:
    """
    Format a list-type answer to maintain the list structure but remove unnecessary text.

    The answer is first cleaned with ``format_answer``. A list item is a line
    that starts (after optional indentation) with a marker followed by a
    space or tab. The markers are "N.", "*", "•" or "-". Markers in the
    middle of a line are ignored, so prose such as "It ended in 1998. Then
    ..." or "Paris - the capital" is not mistaken for a list.

    Each item extends until one of:
    - the next marker line,
    - a blank line,
    - the end of the text.

    Any preamble before the first item is dropped.

    Args:
        answer: The original answer text

    Returns:
        str: The items re-emitted one per line as "- item", or the generally
        formatted answer when no list items are found.
    """
    # First apply general formatting
    cleaned_answer = format_answer(answer)

    # Item marker: optional indentation, "N." / "*" / "•" / "-", then at least
    # one space or tab on the same line.
    marker = r'[ \t]*(?:\d+\.|\*|•|-)[ \t]+'
    # MULTILINE makes ^ match at every line start; DOTALL lets an item continue
    # onto following lines. \Z (not $) marks the true end of text, because
    # under MULTILINE $ would match at every line end and truncate items.
    item_pattern = re.compile(
        rf'^{marker}(.+?)(?=\n{marker}|\n[ \t]*\n|\Z)',
        re.MULTILINE | re.DOTALL,
    )

    list_items = item_pattern.findall(cleaned_answer)
    if list_items:
        return '\n'.join(f"- {item.strip()}" for item in list_items)

    return cleaned_answer

def detect_answer_type(question: str) -> str:
    """
    Detect the type of answer expected based on the question.

    Matching is case-insensitive. Multi-word phrases ("how many", "what are
    the", ...) must start at a word boundary. Single-word keywords must match
    a whole word, optionally plural ("heights", "names"). Words that merely
    contain a keyword therefore do not trigger a classification:
    - "country" and "account" (count)
    - "language" and "page" (age)
    - "username" (name)
    - "specialist" (list)

    Numerical keywords are checked before list keywords.

    Args:
        question: The question text

    Returns:
        str: The detected answer type ('numerical', 'list', or 'text')
    """
    question_lower = question.lower()

    # Check for numerical questions
    numerical_pattern = (
        r'\b(?:how many|how much|number of|total of)'
        r'|\b(?:count|population|percentage|age|height|weight|distance'
        r'|length|width|depth|area|volume)s?\b'
    )
    if re.search(numerical_pattern, question_lower):
        return 'numerical'

    # Check for list questions
    list_pattern = r'\b(?:what are the|examples of)|\b(?:list|name|enumerate)s?\b'
    if re.search(list_pattern, question_lower):
        return 'list'

    # Default to text
    return 'text'

def format_answer_by_type(answer: str, question: Optional[str] = None) -> str:
    """
    Clean an answer with the formatter suited to the question's expected answer type.

    If no question is given (or it is empty), the answer cannot be
    classified, so the general ``format_answer`` is used. Otherwise
    ``detect_answer_type`` classifies the question and the answer is routed
    as follows:
    - 'numerical' -> ``format_numerical_answer``
    - 'list' -> ``format_list_answer``
    - anything else -> ``format_answer``

    Args:
        answer: The original answer text
        question: The original question, if known

    Returns:
        str: The answer cleaned by the formatter chosen for the question type
    """
    formatters_by_type = {
        'numerical': format_numerical_answer,
        'list': format_list_answer,
    }

    formatter = format_answer
    if question:
        formatter = formatters_by_type.get(detect_answer_type(question), format_answer)

    return formatter(answer)