""" |
|
GAIA Agent Enhanced Implementation |
|
|
|
This module provides an enhanced implementation of the GAIA agent |
|
that uses specialized components instead of hardcoded responses. |
|
""" |
|
|
|

import os
import re
import logging
import time
from typing import Dict, Any, List, Optional, Union, Callable
import traceback
import sys
import json

from src.gaia.agent.answer_formatter import format_answer_by_type

# LangGraph is optional; the agent falls back to simpler processing when it is not installed.
try:
    from langgraph.graph import END, StateGraph
    from langgraph.prebuilt import ToolNode
    LANGGRAPH_AVAILABLE = True
except ImportError:
    LANGGRAPH_AVAILABLE = False

logger = logging.getLogger("gaia_agent")

from src.gaia.agent.components import TextAnalyzer, VideoAnalyzer, SearchManager, MemoryManager
from src.gaia.agent.tool_registry import get_tools, create_tools_registry
from src.gaia.agent.config import VERBOSE, DEFAULT_CHECKPOINT_PATH


class GAIAAgent:
    """
    Enhanced GAIA Agent implementation.

    This agent uses specialized components to handle different types of questions
    without hardcoded responses.
    """
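
    # The agent delegates work to four specialized components (text, video,
    # search, and memory), an optional LangGraph workflow, and a tool
    # registry; see _initialize_components() and _build_langgraph() below.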

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the GAIA Agent with configuration.

        Args:
            config: Configuration dictionary
        """
        self.config = config or {}
        self.verbose = self.config.get("verbose", VERBOSE)

        self._initialize_components()

        self.state = {
            "initialized": True,
            "last_question": None,
            "last_answer": None,
            "last_execution_time": None,
        }

        if LANGGRAPH_AVAILABLE:
            self.graph = self._build_langgraph()
            logger.info("LangGraph workflow initialized")
        else:
            self.graph = None
            logger.warning("LangGraph not available, using fallback processing")

        self.tools_registry = create_tools_registry()

        logger.info("GAIA Agent initialized successfully")

    def _initialize_components(self):
        """Initialize specialized components."""
        logger.info("Initializing components")

        try:
            self.text_analyzer = TextAnalyzer()
            self.video_analyzer = VideoAnalyzer()
            self.search_manager = SearchManager(self.config.get("search", {}))
            self.memory_manager = MemoryManager(self.config.get("memory", {
                "use_supabase": bool(os.getenv("SUPABASE_URL", "")),
                "cache_enabled": True,
            }))

            logger.info("All components initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing components: {str(e)}")
            logger.debug(traceback.format_exc())
            raise RuntimeError(f"Failed to initialize GAIA Agent components: {str(e)}") from e

    # The return annotation is a string so the class body still evaluates
    # when LangGraph is not installed and StateGraph was never imported.
    def _build_langgraph(self) -> Optional["StateGraph"]:
        """
        Build and return the LangGraph workflow.

        Returns:
            StateGraph or None if LangGraph is unavailable
        """
        if not LANGGRAPH_AVAILABLE:
            return None

        try:
            from src.gaia.agent.graph import build_agent_graph
            return build_agent_graph()
        except Exception as e:
            logger.error(f"Error building LangGraph: {str(e)}")
            logger.debug(traceback.format_exc())
            return None

    def _detect_question_type(self, question: str) -> str:
        """
        Detect the type of question to determine appropriate handling.

        Args:
            question: The question to analyze

        Returns:
            str: Question type identifier
        """
        question_lower = question.lower()
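
        # The checks below are ordered heuristics: the first pattern that
        # matches decides the question type, and more specific formats
        # (reversed text, scrambled words, media URLs) are tested before the
        # general-knowledge fallback. For example, "What is 12 * 7?" matches
        # the arithmetic pattern and is classified as "math_question", while
        # a question with no special markers falls through to
        # "general_knowledge".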

        # Reversed text detected by the text analyzer
        if self.text_analyzer.is_reversed_text(question):
            return "reversed_text"

        # A word in all capitals (four or more letters) suggests unscrambling
        if re.search(r'\b[A-Z]{4,}\b', question):
            return "unscramble_word"

        # YouTube links
        if "youtube.com/watch" in question_lower or "youtu.be/" in question_lower:
            return "youtube_video"

        # Image analysis requests
        if "image" in question_lower and ("analyze" in question_lower or "what" in question_lower or "describe" in question_lower):
            return "image_analysis"

        # Audio files or recordings
        if ".mp3" in question_lower or "audio" in question_lower or "recording" in question_lower:
            return "audio_analysis"

        # Chess position questions
        if "chess" in question_lower and "position" in question_lower:
            return "chess_analysis"

        # Arithmetic expressions or explicit calculation requests
        if re.search(r'(\d+\s*[\+\-\*\/]\s*\d+)', question_lower) or "calculate" in question_lower:
            return "math_question"

        # Everything else is treated as general knowledge
        return "general_knowledge"

    def process_question(self, question: str) -> str:
        """
        Process a question using appropriate components based on question type.

        Args:
            question: The question to process

        Returns:
            str: The generated answer
        """
        start_time = time.time()
        logger.info(f"Processing question: {question}")
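
        # Overall flow: check the answer cache, classify the question,
        # dispatch to a specialized component, fall back to the LangGraph
        # workflow (or web search), then cache and format the result.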

        try:
            # Return a cached answer if this exact question was seen before
            cached_answer = self.memory_manager.get_cached_answer(question)
            if cached_answer:
                logger.info("Retrieved answer from cache")

                self.state["last_question"] = question
                self.state["last_answer"] = cached_answer
                self.state["last_execution_time"] = time.time() - start_time

                return cached_answer

            question_type = self._detect_question_type(question)
            logger.info(f"Detected question type: {question_type}")

            answer = None
if question_type == "reversed_text": |
|
logger.info("Processing reversed text question") |
|
text_analysis = self.text_analyzer.process_text_question(question) |
|
|
|
if text_analysis.get("answer"): |
|
answer = text_analysis["answer"] |
|
else: |
|
|
|
logger.info("Specialized handling failed, trying general processing") |
|
answer = self._process_with_langgraph(question) |
|
|
|
|
|
elif question_type == "unscramble_word": |
|
logger.info("Processing word unscrambling question") |
|
text_analysis = self.text_analyzer.process_text_question(question) |
|
|
|
if text_analysis.get("answer"): |
|
answer = text_analysis["answer"] |
|
else: |
|
|
|
logger.info("Specialized handling failed, trying general processing") |
|
answer = self._process_with_langgraph(question) |
|
|
|
|
|
elif question_type == "youtube_video": |
|
logger.info("Processing YouTube video question") |
|
|
|
|
|
video_url_match = re.search(r'((?:https?:\/\/)?(?:www\.)?(?:youtube\.com\/watch\?v=|youtu\.be\/)[a-zA-Z0-9_-]+)', question) |
|
|
|
if video_url_match: |
|
video_url = video_url_match.group(1) |
|
video_analysis = self.video_analyzer.analyze_video_content(video_url, question) |
|
|
|
if video_analysis.get("answer"): |
|
answer = video_analysis["answer"] |
|
else: |
|
|
|
logger.info("Video analysis failed, trying general processing") |
|
answer = self._process_with_langgraph(question) |
|
else: |
|
|
|
logger.warning("No YouTube URL found in question") |
|
answer = "I couldn't find a YouTube video URL in your question. Please provide a valid YouTube link for analysis." |
|
|
|
|
|
elif question_type == "audio_analysis": |
|
logger.info("Processing audio analysis question") |
|
|
|
|
|
|
|
answer = self._process_with_langgraph(question) |
|
|
|
|
|
elif question_type in ["image_analysis", "chess_analysis"]: |
|
logger.info(f"Processing {question_type} question") |
|
|
|
|
|
answer = self._process_with_langgraph(question) |
|
|
|
|
|
elif question_type == "math_question": |
|
logger.info("Processing math question") |
|
|
|
|
|
|
|
expression_match = re.search(r'(\d+)\s*([\+\-\*\/])\s*(\d+)', question) |
|
if expression_match: |
|
try: |
|
num1 = int(expression_match.group(1)) |
|
operator = expression_match.group(2) |
|
num2 = int(expression_match.group(3)) |
|
|
|
result = None |
|
if operator == '+': |
|
result = num1 + num2 |
|
elif operator == '-': |
|
result = num1 - num2 |
|
elif operator == '*': |
|
result = num1 * num2 |
|
elif operator == '/' and num2 != 0: |
|
result = num1 / num2 |
|
|
|
if result is not None: |
|
answer = f"The result of {num1} {operator} {num2} is {result}." |
|
else: |
|
|
|
answer = self._process_with_langgraph(question) |
|
except Exception: |
|
|
|
answer = self._process_with_langgraph(question) |
|
else: |
|
|
|
answer = self._process_with_langgraph(question) |
|
|
|
|
|

            else:
                logger.info("Processing general knowledge question")
                answer = self._process_with_langgraph(question)

            # Fall back to web search if no component produced an answer
            if not answer:
                logger.warning("LangGraph processing failed, using search fallback")
                search_result = self.search_manager.search(question)
                answer = search_result.get("answer", "I couldn't find a specific answer to your question.")

            # Cache the raw answer, then format it for the expected answer type
            self.memory_manager.cache_question_answer(question, answer)

            formatted_answer = format_answer_by_type(answer, question)

            self.state["last_question"] = question
            self.state["last_answer"] = formatted_answer
            self.state["last_execution_time"] = time.time() - start_time

            logger.info(f"Question processed in {time.time() - start_time:.2f} seconds")
            logger.debug(f"Original answer: {answer}")
            logger.debug(f"Formatted answer: {formatted_answer}")
            return formatted_answer

        except Exception as e:
            logger.error(f"Error processing question: {str(e)}")
            logger.debug(traceback.format_exc())

            error_msg = "Error processing the question. Please try rephrasing it."

            # In verbose mode, surface the underlying exception message
            if self.verbose:
                error_msg = f"Error: {str(e)}"

            return error_msg

    def _process_with_langgraph(self, question: str) -> Optional[str]:
        """
        Process a question using the LangGraph workflow.

        Args:
            question: The question to process

        Returns:
            str or None: Generated answer or None if processing failed
        """
        if not self.graph:
            logger.warning("LangGraph not available, using search fallback")
            search_result = self.search_manager.search(question)
            return search_result.get("answer")
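
        # The graph consumes a state dict carrying the question, the available
        # tools, and empty slots for thoughts, messages, tool results, and the
        # final answer.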

        try:
            logger.info("Processing with LangGraph workflow")

            input_state = {
                "question": question,
                "tools": get_tools(),
                "thoughts": [],
                "messages": [],
                "answer": None,
                "tool_results": {},
            }

            result = self.graph.invoke(input_state)

            if result and "answer" in result:
                answer = result["answer"]

                formatted_answer = format_answer_by_type(answer, question)
                logger.info("Successfully processed with LangGraph")
                logger.debug(f"Original LangGraph answer: {answer}")
                logger.debug(f"Formatted LangGraph answer: {formatted_answer}")
                return formatted_answer
            else:
                logger.warning("LangGraph processing did not produce an answer")
                return None

        except Exception as e:
            logger.error(f"Error in LangGraph processing: {str(e)}")
            logger.debug(traceback.format_exc())
            return None

    def run(self, input_data: Union[Dict[str, Any], str]) -> str:
        """
        Run the agent on the provided input data.

        Args:
            input_data: Either a dictionary containing the question or the question string directly

        Returns:
            str: Generated answer
        """
        if isinstance(input_data, str):
            question = input_data
        else:
            question = input_data.get("question", "")

        if not question:
            return "No question provided. Please provide a question to get a response."

        return self.process_question(question)

    def get_state(self) -> Dict[str, Any]:
        """
        Get the current state of the agent.

        Returns:
            dict: Current agent state
        """
        return self.state.copy()

    def reset(self) -> None:
        """Reset the agent state."""
        logger.info("Resetting agent state")

        self.state = {
            "initialized": True,
            "last_question": None,
            "last_answer": None,
            "last_execution_time": None,
        }

        # Optionally clear the answer cache as well
        if self.config.get("clear_cache_on_reset", False):
            self.memory_manager.clear_cache()
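

# Minimal usage sketch (not part of the module's public API): build an agent
# and answer a couple of questions. The "verbose" key is read by __init__;
# "search" and "memory" keys, if provided, are passed to the components.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    agent = GAIAAgent(config={"verbose": True})

    # A plain string question goes straight to process_question().
    print(agent.run("What is 12 * 7?"))

    # A dict payload with a "question" key is also accepted.
    print(agent.run({"question": "Who wrote Don Quixote?"}))

    # Inspect the last question, answer, and execution time.
    print(agent.get_state())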