|
""" |
|
Utilities module for GAIA implementation. |
|
Contains helper functions for question parsing, answer formatting, logging, and CLI utilities. |
|
""" |
|
|
|
|
|
import logging as python_logging
import re
from typing import Dict, Any, List, Optional, Union
|
|
|
|
|
from src.gaia.utils.parsing.parsing import ( |
|
extract_key_information, |
|
identify_required_tools, |
|
determine_question_complexity, |
|
parse_question, |
|
COMPLEXITY_LEVELS |
|
) |
|
|
|
from src.gaia.utils.formatting.formatting import ( |
|
extract_answer, |
|
format_answer, |
|
validate_answer_format, |
|
process_answer, |
|
FORMAT_TYPES |
|
) |
|
|
|
from src.gaia.utils.logging.logging_framework import ( |
|
initialize_logging, |
|
log_info, |
|
log_warning, |
|
log_error, |
|
log_api_request, |
|
log_api_response, |
|
log_tool_selection, |
|
log_tool_execution, |
|
log_workflow_step, |
|
log_memory_operation, |
|
TimingContext, |
|
get_trace_id, |
|
set_trace_id, |
|
generate_trace_id |
|
) |
|
|
|
|
|
# Module-level logger. It uses the fixed name "gaia_agent.utils" (not
# __name__), so any other code calling getLogger with that same name gets
# this same logger object.
logger = python_logging.getLogger(name="gaia_agent.utils")
|
|
|
def clean_text(text: str) -> str:
    """
    Normalize whitespace in a piece of text.

    Every run of whitespace (spaces, tabs, newlines, ...) becomes a single
    space, and leading/trailing whitespace disappears.

    Args:
        text: The text to normalize; falsy values (``""``/``None``) are allowed.

    Returns:
        The normalized text, or ``""`` when ``text`` is falsy.
    """
    if not text:
        return ""
    # str.split() with no separator already discards leading/trailing
    # whitespace, so joining the pieces needs no extra strip().
    words = text.split()
    return " ".join(words)
|
|
|
def truncate_text(text: str, max_length: int = 1000, add_ellipsis: bool = True) -> str:
    """
    Truncate text so that it is at most ``max_length`` characters long.

    When truncation happens and ``add_ellipsis`` is true, the result ends with
    ``"..."`` and the ellipsis counts toward ``max_length``. If ``max_length``
    is too small to hold any text plus the ellipsis, the text is cut without
    one.

    Args:
        text: The text to truncate. Falsy values are returned unchanged.
        max_length: Maximum length of the returned string in characters.
            Negative values are treated as 0.
        add_ellipsis: Whether to end a truncated result with ``"..."``.

    Returns:
        ``text`` unchanged if it already fits, otherwise a truncated copy of
        length at most ``max_length``.
    """
    if not text or len(text) <= max_length:
        return text

    # Guard against negative limits: text[:-n] would keep almost all the text
    # instead of producing an empty string.
    limit = max(max_length, 0)

    ellipsis = "..."
    if add_ellipsis and limit > len(ellipsis):
        # Reserve room for the ellipsis so the result honours the limit.
        return text[: limit - len(ellipsis)] + ellipsis

    return text[:limit]
|
|
|
def merge_dicts(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]: |
|
""" |
|
Merge two dictionaries, with dict2 values taking precedence. |
|
|
|
Args: |
|
dict1: First dictionary |
|
dict2: Second dictionary (takes precedence) |
|
|
|
Returns: |
|
Merged dictionary |
|
""" |
|
result = dict1.copy() |
|
result.update(dict2) |
|
return result |
|
|
|
# Optional inflectional endings allowed after a keyword, so that e.g. "count"
# still matches "counts"/"counted"/"counting" but no longer matches "country".
_KEYWORD_SUFFIX = r"(?:s|es|d|ed|ing)?"


def _keyword_pattern(*terms: str) -> "re.Pattern":
    """Compile ``terms`` into one regex matching any of them as whole words.

    Words inside a multi-word term may be separated by any run of whitespace.
    """
    alternatives = "|".join(
        r"\s+".join(re.escape(word) for word in term.split()) for term in terms
    )
    return re.compile(r"\b(?:" + alternatives + r")" + _KEYWORD_SUFFIX + r"\b")


# Keyword patterns are compiled once at import time rather than per call.
_NUMBER_KEYWORDS = _keyword_pattern(
    "how many", "how much", "count", "sum", "summed", "summing", "total", "calculate"
)
_DATE_KEYWORDS = _keyword_pattern("when", "what date", "what time", "what year")
_BOOLEAN_KEYWORDS = _keyword_pattern(
    "is it", "are there", "does it", "can it", "will it", "should it"
)
_BOOLEAN_PREFIXES = ("is ", "are ", "does ", "do ", "can ", "will ", "should ")
_LIST_KEYWORDS = _keyword_pattern(
    "list", "enumerate", "what are", "examples of", "types of"
)
_ENTITY_KEYWORDS = _keyword_pattern(
    "who", "whom", "whose", "person", "which company", "which organization"
)


def get_answer_format(question: str) -> str:
    """
    Determine the expected answer format based on the question.

    Keywords are matched as whole words (allowing simple endings such as
    "-s"/"-ed"/"-ing"), so words that merely contain a keyword, such as
    "country" (count), "summary" (sum), "whole" (who) or "specialist" (list),
    no longer trigger a format. Checks run in priority order: NUMBER, DATE,
    BOOLEAN, LIST, ENTITY, falling back to TEXT.

    Args:
        question: The question to analyze. Falsy values (``""``/``None``)
            yield the TEXT format.

    Returns:
        Expected answer format type (a value from ``FORMAT_TYPES``).
    """
    if not question:
        return FORMAT_TYPES["TEXT"]

    # Strip so that the startswith() prefix check below is not defeated by
    # leading whitespace.
    question_lower = question.strip().lower()

    if _NUMBER_KEYWORDS.search(question_lower):
        return FORMAT_TYPES["NUMBER"]

    if _DATE_KEYWORDS.search(question_lower):
        return FORMAT_TYPES["DATE"]

    # NOTE(review): a question is classified as boolean only when it BOTH
    # starts with an auxiliary verb AND contains one of the probe phrases, so
    # plain yes/no questions such as "Is Paris in France?" fall through.
    # Kept as-is to preserve the existing classification rules.
    if _BOOLEAN_KEYWORDS.search(question_lower) and question_lower.startswith(
        _BOOLEAN_PREFIXES
    ):
        return FORMAT_TYPES["BOOLEAN"]

    if _LIST_KEYWORDS.search(question_lower):
        return FORMAT_TYPES["LIST"]

    if _ENTITY_KEYWORDS.search(question_lower):
        return FORMAT_TYPES["ENTITY"]

    return FORMAT_TYPES["TEXT"]
|
|
|
def analyze_question(question: str) -> Dict[str, Any]:
    """
    Run the full question analysis and collect the results in one dict.

    The question is parsed with ``parse_question`` and its expected answer
    format is derived with ``get_answer_format``.

    Args:
        question: The question to analyze.

    Returns:
        A dict with the keys ``question`` (the input), ``parsed_info`` (the
        full ``parse_question`` result), ``expected_format``,
        ``required_tools`` (copied from ``parsed_info["required_tools"]``) and
        ``complexity`` (``parsed_info["complexity"]["complexity_level"]``).
    """
    parsed_info = parse_question(question)
    analysis: Dict[str, Any] = {"question": question, "parsed_info": parsed_info}
    analysis["expected_format"] = get_answer_format(question)
    analysis["required_tools"] = parsed_info["required_tools"]
    analysis["complexity"] = parsed_info["complexity"]["complexity_level"]
    return analysis
|
|
|
def format_response(response: str, expected_format: Optional[str] = None) -> Dict[str, Any]:
    """
    Post-process a raw response via ``process_answer``.

    This is a thin pass-through: both arguments are forwarded unchanged and
    the result of ``process_answer`` is returned as-is.

    Args:
        response: Raw response text to process.
        expected_format: Target format type, forwarded as given (may be None).

    Returns:
        Whatever ``process_answer`` returns (annotated as a dict).
    """
    processed = process_answer(response, expected_format)
    return processed
|
|
|
|
|
# Public API of this module: re-exported helpers first, then local utilities.
__all__ = [
    # Question parsing (re-exported)
    'extract_key_information',
    'identify_required_tools',
    'determine_question_complexity',
    'parse_question',
    'COMPLEXITY_LEVELS',
    # Answer formatting (re-exported)
    'extract_answer',
    'format_answer',
    'validate_answer_format',
    'process_answer',
    'FORMAT_TYPES',
    # Logging framework (re-exported)
    'initialize_logging',
    'log_info',
    'log_warning',
    'log_error',
    'log_api_request',
    'log_api_response',
    'log_tool_selection',
    'log_tool_execution',
    'log_workflow_step',
    'log_memory_operation',
    'TimingContext',
    'get_trace_id',
    'set_trace_id',
    'generate_trace_id',
    # Utilities defined in this module
    'clean_text',
    'truncate_text',
    'merge_dicts',
    'get_answer_format',
    'analyze_question',
    'format_response',
]