"""
Main agent implementation for the GAIA benchmark.

This module contains the GAIAAgent class, which is responsible for:

- Processing questions from the GAIA benchmark
- Selecting and executing appropriate tools
- Formulating precise answers
- Handling errors and logging
- Addressing special cases like reversed text and word unscrambling

The agent uses LangGraph for workflow management and Supabase for memory,
with configuration from the config module. It can be extended with
additional capabilities as needed.
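
Typical usage (illustrative; assumes the config module is set up and,
if memory is enabled, Supabase credentials are available):

    agent = GAIAAgent()
    answer = agent.run("What year was the GAIA benchmark released?")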
"""

import hashlib
import logging
import time
import traceback
from typing import Any, Dict, Optional

from src.gaia.agent.config import (
    get_logging_config,
    get_model_config,
    get_tool_config,
    get_memory_config,
    get_agent_config,
    VERBOSE,
)
from src.gaia.agent.graph import run_agent_graph
from src.gaia.memory import SupabaseMemory
from src.gaia.memory.supabase_memory import ConversationMemory, ResultCache, WorkingMemory

# Configure module-level logging from the shared logging configuration.
logging_config = get_logging_config()
logging.basicConfig(
    level=logging_config["level"],
    format=logging_config["format"],
    filename=logging_config["filename"]
)
logger = logging.getLogger("gaia_agent")


class GAIAAgent:
    """
    Agent for answering questions from the GAIA benchmark.

    This agent processes questions, selects appropriate tools,
    executes a reasoning process, and formulates precise answers.
    It includes improved handling for special cases like reversed text,
    direct text manipulation, and word unscrambling.
    """

    def __init__(self, config: Optional[Any] = None):
        """
        Initialize the GAIA agent.

        Args:
            config: Optional configuration (dictionary, string, Config object, etc.)
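
        Example (illustrative; "enabled" is a key the memory config already
        uses, and unspecified sections fall back to module defaults):

            agent = GAIAAgent({"memory": {"enabled": False}})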
        """
        default_config = {
            "model": get_model_config(),
            "tools": get_tool_config(),
            "memory": get_memory_config(),
            "agent": get_agent_config()
        }

        # Keep the original config object around for the test-compatibility
        # get()/set() accessors below.
        self._original_config = config

        self.supabase_memory = None
        self.conversation_memory = None
        self.result_cache = None
        self.working_memory = None

        if config is None:
            self.config = default_config
        elif isinstance(config, str):
            # String configs are not interpreted; fall back to the defaults.
            self.config = default_config
        elif isinstance(config, dict):
            self.config = config
            # Fill in any missing sections from the module defaults.
            if "model" not in self.config:
                self.config["model"] = get_model_config()
            if "tools" not in self.config:
                self.config["tools"] = get_tool_config()
            if "memory" not in self.config:
                self.config["memory"] = get_memory_config()
            if "agent" not in self.config:
                self.config["agent"] = get_agent_config()
        else:
            # Any other config type also falls back to the defaults.
            self.config = default_config

        if "memory" not in self.config:
            self.config["memory"] = get_memory_config()

        self.memory_config = self.config["memory"]

        # Cache for values stored via set() when the underlying config does
        # not support item assignment.
        self._config_cache = {}

        if self.memory_config.get("enabled", False):
            self._initialize_memory()

        self.verbose = VERBOSE

        logger.info("GAIA Agent initialized")

    def get(self, key, default=None):
        """Get a configuration value (for compatibility with tests)."""
        if hasattr(self._original_config, 'get'):
            # Prefer the original config object when it supports get().
            return self._original_config.get(key, default)
        elif isinstance(self.config, dict) and key in self.config:
            return self.config[key]
        else:
            # Fall back to values stored via set().
            return self._config_cache.get(key, default)

    def set(self, key, value):
        """Set a configuration value (for compatibility with tests)."""
        if hasattr(self._original_config, 'set'):
            # Delegate to the original config object when it supports set().
            self._original_config.set(key, value)
        elif isinstance(self.config, dict):
            self.config[key] = value

        # Mirror the value locally so get() can always find it.
        self._config_cache[key] = value

    def _initialize_memory(self):
        """Initialize memory systems based on configuration."""
        try:
            self.supabase_memory = SupabaseMemory({})
            logger.info("Default memory initialized")

            self.conversation_memory = ConversationMemory(self.supabase_memory, "conversation")
            self.result_cache = ResultCache(self.supabase_memory)
            self.working_memory = WorkingMemory(self.supabase_memory, "working")

            logger.info("All memory systems initialized with defaults")
        except Exception as e:
            logger.error(f"Failed to initialize memory: {str(e)}")
            logger.debug(traceback.format_exc())

            # Disable all memory systems so the agent can still run without them.
            self.supabase_memory = None
            self.conversation_memory = None
            self.result_cache = None
            self.working_memory = None

    def process_question(self, question: str) -> str:
        """
        Process a question and generate an answer.

        Args:
            question (str): The question to process

        Returns:
            str: The answer to the question
        """
        start_time = time.time()

        try:
            logger.info(f"Processing question: {question[:100]}...")

            answer = self._process_question(question)

            processing_time = time.time() - start_time
            logger.info(f"Question processed in {processing_time:.2f} seconds")

            return answer
        except Exception as e:
            logger.error(f"Error processing question: {str(e)}")
            logger.debug(traceback.format_exc())
            return f"Error processing your question: {str(e)}"

    def query(self, question: str) -> Dict[str, Any]:
        """
        Query the agent with a question to get an answer.

        This is the main API method used by test harnesses and applications.
        It returns a dictionary with the answer, reasoning, and other metadata.

        Args:
            question (str): The question to answer

        Returns:
            Dict[str, Any]: Dictionary containing the answer and metadata
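
        Example result shape (values illustrative):

            {
                "answer": "Paris",
                "reasoning": "<plan from working memory, if any>",
                "time_taken": 1.42,
                "tools_used": ["web_search"],
                "success": True
            }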
        """
        try:
            start_time = time.time()
            answer = self.process_question(question)
            processing_time = time.time() - start_time

            # Recover reasoning and tool usage from working memory, if available.
            reasoning = ""
            tools_used = []

            if self.working_memory:
                plan = self.working_memory.get_intermediate_result("plan")
                if plan:
                    reasoning = str(plan)

                tool_results = self.working_memory.get_intermediate_result("tool_results")
                if tool_results:
                    tools_used = [r.get("tool_name") for r in tool_results if r.get("tool_name")]

            return {
                "answer": answer,
                "reasoning": reasoning,
                "time_taken": processing_time,
                "tools_used": tools_used,
                "success": True
            }
        except Exception as e:
            logger.error(f"Error in query: {str(e)}")
            logger.debug(traceback.format_exc())

            return {
                "answer": f"Error: {str(e)}",
                "reasoning": f"An error occurred: {str(e)}",
                "time_taken": 0,
                "tools_used": [],
                "success": False,
                "error": str(e)
            }

    def _process_question(self, question: str) -> str:
        """
        Internal method to process a question using the LangGraph workflow.

        Args:
            question (str): The question to process

        Returns:
            str: The answer to the question
        """
        try:
            # Cache answers keyed by an MD5 hash of the question text.
            cache_key = hashlib.md5(question.encode()).hexdigest()

            if self.result_cache:
                cached_answer = self.result_cache.get_result(cache_key)
                if cached_answer:
                    logger.info("Retrieved answer from cache")
                    return cached_answer

            answer = ""

            if run_agent_graph:
                try:
                    logger.info("Running LangGraph workflow")
                    result = run_agent_graph(
                        {"question": question},
                        self.config
                    )

                    if result and isinstance(result, dict):
                        answer = result.get("answer", "")
                        logger.info(f"Got answer from run_agent_graph: {answer[:100]}...")

                        # Persist the plan and tool results so query() can report them.
                        if self.working_memory:
                            if result.get("plan"):
                                self.working_memory.store_intermediate_result(
                                    "plan", result["plan"]
                                )

                            if result.get("tool_results"):
                                self.working_memory.store_intermediate_result(
                                    "tool_results", result["tool_results"]
                                )
                except Exception as e:
                    logger.error(f"Error running LangGraph workflow: {str(e)}")
                    answer = "I encountered an error while processing your question. Please try rephrasing or asking a different question."

            if not answer:
                answer = "I don't have enough information to answer this question accurately."

            if self.result_cache:
                self.result_cache.cache_result(cache_key, answer)

            if self.conversation_memory:
                self.conversation_memory.add_message("user", question)
                self.conversation_memory.add_message("assistant", answer)

            return answer

        except Exception as e:
            logger.error(f"Error in _process_question: {str(e)}")
            raise

    def get_memory_snapshot(self) -> Dict[str, Any]:
        """
        Get a snapshot of the agent's memory.

        Returns:
            Dict containing memory contents
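
        Example shape (illustrative; the exact entries depend on the
        ConversationMemory and WorkingMemory implementations):

            {
                "conversation": [...],
                "working": {"plan": ..., "tool_results": [...]}
            }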
        """
        snapshot = {}

        if hasattr(self, 'conversation_memory') and self.conversation_memory:
            try:
                snapshot["conversation"] = self.conversation_memory.get_messages()
            except Exception as e:
                logger.warning(f"Error getting conversation memory: {str(e)}")
                snapshot["conversation"] = []

        if hasattr(self, 'working_memory') and self.working_memory:
            try:
                snapshot["working"] = self.working_memory.get_all_results()
            except Exception as e:
                logger.warning(f"Error getting working memory: {str(e)}")
                snapshot["working"] = {}

        return snapshot

    def clear_memory(self):
        """Clear all agent memory."""
        if hasattr(self, 'conversation_memory') and self.conversation_memory:
            self.conversation_memory.clear()

        if hasattr(self, 'working_memory') and self.working_memory:
            self.working_memory.clear()

        if hasattr(self, 'result_cache') and self.result_cache:
            self.result_cache.clear()

        logger.info("Agent memory cleared")

    def run(self, question: str) -> str:
        """
        Run the agent on a question and return the answer.

        This method is required by the app.py interface.

        Args:
            question (str): The question to process

        Returns:
            str: The answer to the question
        """
        return self.process_question(question)
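

# Minimal smoke test (illustrative; assumes the config module is importable
# and that any model or Supabase credentials it expects are available in the
# environment; the question is just an example):
if __name__ == "__main__":
    agent = GAIAAgent()
    result = agent.query("What is the capital of France?")
    print(result["answer"])
    print(f"time_taken={result['time_taken']:.2f}s, tools_used={result['tools_used']}")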