|
"""
|
|
Simplified LangGraph-based GAIA Agent Implementation
|
|
|
|
This module provides a streamlined implementation of the GAIA agent using LangGraph
|
|
for workflow management. It has been designed to be robust, maintainable, and
|
|
directly usable in the Hugging Face Spaces environment.
|
|
|
|
Key features:
|
|
- Direct tool integration
|
|
- Simplified prompt construction
|
|
- Clear execution flow
|
|
- Robust error handling
|
|
- Fallback mechanisms for critical components
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
import os
|
|
import json
|
|
import re
|
|
import traceback
|
|
import hashlib
|
|
from typing import Dict, Any, List, Optional, Union, Tuple, Literal, TypedDict
|
|
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
from langgraph.graph import StateGraph, END
|
|
from langchain_openai import ChatOpenAI
|
|
|
|
|
|
# Configure the root logger at import time, then expose this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("gaia_agent")
|
|
|
|
|
|
# Credentials and endpoints are read from the environment once, at import time.
# Unset variables fall back to an empty string or to the default shown.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
SERPER_API_KEY = os.environ.get("SERPER_API_KEY", "")
PERPLEXITY_API_KEY = os.environ.get("PERPLEXITY_API_KEY", "")
SERPER_API_URL = os.environ.get("SERPER_API_URL", "https://google.serper.dev/search")
SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY", "")
USER_AGENT = os.environ.get(
    "USER_AGENT",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/96.0.4664.110 Safari/537.36",
)

# Report missing configuration up front; nothing here aborts the import.
if not OPENAI_API_KEY:
    logger.warning("OPENAI_API_KEY is not set. Agent will use fallback mode with limited capabilities.")

if not (SERPER_API_KEY or PERPLEXITY_API_KEY):
    logger.warning("Neither SERPER_API_KEY nor PERPLEXITY_API_KEY is set. Web search capabilities will be limited.")
elif not SERPER_API_KEY:
    logger.warning("SERPER_API_KEY is not set. Will attempt to use Perplexity for search if available.")
elif not PERPLEXITY_API_KEY:
    logger.warning("PERPLEXITY_API_KEY is not set. Will use Serper for search capabilities.")

if not (SUPABASE_URL and SUPABASE_KEY):
    logger.warning("SUPABASE_URL or SUPABASE_KEY is not set. Memory persistence will be limited to in-memory storage.")
|
|
|
|
|
|
# Optional dependency: DuckDuckGo search client.
try:
    from duckduckgo_search import DDGS
except ImportError:
    DDGS_AVAILABLE = False
    logger.warning("DuckDuckGo search package not available. Some search features will be limited.")
else:
    DDGS_AVAILABLE = True

# Optional dependencies: HTTP client and HTML parser used by the web tools.
try:
    import requests
    from bs4 import BeautifulSoup
except ImportError:
    WEB_TOOLS_AVAILABLE = False
    logger.warning("Web tools dependencies not available. Web content extraction will be limited.")
else:
    WEB_TOOLS_AVAILABLE = True
|
|
|
|
|
|
class AgentState(TypedDict):
    """State dictionary handed from one workflow node function to the next."""

    # The user's question; read by every node.
    question: str
    # Parsed question analysis, written by analyze_question.
    analysis: Optional[Dict[str, Any]]
    # Ordered list of step dicts, written by create_plan.
    plan: Optional[List[Dict[str, Any]]]
    # Index into ``plan`` of the next step; advanced by execute_tool.
    current_step: Optional[int]
    # One result record per executed tool step, appended by execute_tool.
    tool_results: List[Dict[str, Any]]
    reasoning: Optional[str]
    # Final answer text; read back after the workflow finishes.
    answer: Optional[str]
    # Description of the most recent failure recorded by a node, if any.
    error: Optional[str]
|
|
|
|
|
|
class SimpleMemory:
    """In-process store for per-session conversation history and a result cache.

    Both stores are plain dictionaries held on the instance; nothing is
    written anywhere else.
    """

    def __init__(self):
        # session_id -> list of {"role", "content", "timestamp"}, oldest first
        self.conversations = {}
        # cache key -> {"value", "timestamp"}
        self.result_cache = {}

    def add_conversation(self, session_id: str, role: str, content: str):
        """Append a message, stamped with the current time, to a session's history."""
        self.conversations.setdefault(session_id, []).append({
            "role": role,
            "content": content,
            "timestamp": time.time(),
        })

    def get_conversation(self, session_id: str, max_messages: int = 10) -> List[Dict[str, Any]]:
        """Return the most recent messages of a session, oldest first.

        Args:
            session_id: Session whose history is requested.
            max_messages: Upper bound on the number of messages returned.
                Zero or a negative value yields an empty list.

        Returns:
            A new list with at most ``max_messages`` message dicts; empty when
            the session is unknown.
        """
        if max_messages <= 0:
            # ``history[-0:]`` is the *whole* list, so zero needs its own branch.
            return []
        return self.conversations.get(session_id, [])[-max_messages:]

    def cache_result(self, key: str, value: Any):
        """Store ``value`` under ``key`` together with the current time."""
        self.result_cache[key] = {
            "value": value,
            "timestamp": time.time(),
        }

    def get_cached_result(self, key: str, max_age_seconds: int = 3600) -> Optional[Any]:
        """Return the cached value for ``key`` unless it is missing or too old.

        Entries older than ``max_age_seconds`` are removed from the cache so
        stale results do not accumulate.

        Returns:
            The cached value, or ``None`` when there is no fresh entry.
        """
        entry = self.result_cache.get(key)
        if entry is None:
            return None

        if time.time() - entry["timestamp"] > max_age_seconds:
            self.result_cache.pop(key, None)
            return None

        return entry["value"]

    def clear(self, session_id: Optional[str] = None):
        """Clear one session's history, or everything when no session is given.

        Args:
            session_id: Session to forget. When ``None`` (the default), all
                conversations and the whole result cache are dropped. An empty
                string is treated as a session id, not as "all sessions".
        """
        if session_id is not None:
            self.conversations.pop(session_id, None)
            return

        self.conversations = {}
        self.result_cache = {}
|
|
|
|
|
|
class WebSearchTool:
    """Search the web through whichever backend is available.

    DuckDuckGo is tried first (when the ``duckduckgo_search`` import
    succeeded), then the Serper API (when an API key is configured and the
    web-tool dependencies were imported).
    """

    def __init__(self):
        self.result_count = 5  # maximum number of results requested per search
        self.timeout = 10  # network timeout handed to the search backends

    def search(self, query: str) -> List[Dict[str, Any]]:
        """Return search results for ``query`` from the first backend that yields any.

        Returns:
            A list of ``{"title", "link", "snippet"}`` dicts. When no backend
            is available or all of them return nothing, a single placeholder
            entry with an empty ``link`` describes the problem.
        """
        if DDGS_AVAILABLE:
            results = self._search_duckduckgo(query)
            if results:
                return results

        if SERPER_API_KEY and WEB_TOOLS_AVAILABLE:
            results = self._search_serper(query)
            if results:
                return results

        logger.warning("All search methods failed or unavailable")
        return [
            {
                "title": "Search Unavailable",
                "link": "",
                "snippet": "Search functionality is currently unavailable. Please ensure that either DuckDuckGo package is installed or SERPER_API_KEY is set."
            }
        ]

    def _search_duckduckgo(self, query: str) -> List[Dict[str, Any]]:
        """Search with DuckDuckGo; return an empty list on any failure."""
        if not DDGS_AVAILABLE:
            logger.warning("DuckDuckGo package not available")
            return []

        try:
            # ``timelimit`` on ``DDGS.text`` is a date-range filter, not a
            # timeout, so the request timeout is set on the client instead.
            with DDGS(timeout=self.timeout) as ddgs:
                hits = list(ddgs.text(query, max_results=self.result_count))

            return [
                {
                    "title": hit.get("title", ""),
                    "link": hit.get("href", ""),
                    "snippet": hit.get("body", ""),
                }
                for hit in hits
            ]
        except Exception as e:
            # Best effort: the caller falls through to the next backend.
            logger.error(f"Error searching DuckDuckGo: {str(e)}")
            return []

    def _search_serper(self, query: str) -> List[Dict[str, Any]]:
        """Search with the Serper API; return an empty list on any failure."""
        if not SERPER_API_KEY:
            logger.warning("Serper API key not set")
            return []

        if not WEB_TOOLS_AVAILABLE:
            logger.warning("Web tools not available")
            return []

        try:
            response = requests.post(
                SERPER_API_URL,
                headers={
                    "X-API-KEY": SERPER_API_KEY,
                    "Content-Type": "application/json",
                },
                json={"q": query, "num": self.result_count},
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()

            return [
                {
                    "title": hit.get("title", ""),
                    "link": hit.get("link", ""),
                    "snippet": hit.get("snippet", ""),
                }
                for hit in data.get("organic", [])
            ]
        except Exception as e:
            # Best effort: the caller reports that search is unavailable.
            logger.error(f"Error searching with Serper: {str(e)}")
            return []
|
|
|
|
|
|
class ContentExtractor:
    """Fetch a web page and reduce it to plain text."""

    def __init__(self):
        self.timeout = 10  # timeout passed to ``requests.get``
        self.max_content_length = 8000  # extracted text is cut to this many characters

    def extract_content(self, url: str) -> Dict[str, Any]:
        """Download ``url`` and return its title and visible text.

        Args:
            url: Address of the page to fetch.

        Returns:
            A dict with ``url``, ``title``, ``content`` and ``success``. On
            failure ``success`` is ``False`` and ``error`` holds the reason;
            ``title`` and ``content`` are then empty unless the web-tool
            dependencies are missing, in which case they carry an explanatory
            message. Exceptions are logged, never raised.
        """
        if not WEB_TOOLS_AVAILABLE:
            logger.warning("Web tools not available for content extraction")
            return {
                "url": url,
                "title": "Content Extraction Unavailable",
                "content": "Web content extraction is currently unavailable. Please ensure that requests and BeautifulSoup packages are installed.",
                "success": False,
                "error": "Web tools dependencies not available"
            }

        try:
            response = requests.get(
                url,
                headers={"User-Agent": USER_AGENT},
                timeout=self.timeout,
            )
            response.raise_for_status()

            soup = BeautifulSoup(response.text, "html.parser")

            # ``.string`` is None when <title> is empty or contains nested tags.
            title = (soup.title.string or "") if soup.title else ""

            # Drop non-visible content before collecting the text.
            for tag in soup(["script", "style"]):
                tag.extract()

            # Trim each line, break lines on runs of two spaces (not on every
            # single space, which would put each word on its own line), and
            # discard the empty pieces.
            lines = (line.strip() for line in soup.get_text().splitlines())
            chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
            text = "\n".join(chunk for chunk in chunks if chunk)

            if len(text) > self.max_content_length:
                text = text[:self.max_content_length] + "..."

            return {
                "url": url,
                "title": title,
                "content": text,
                "success": True
            }
        except Exception as e:
            logger.error(f"Error extracting content from {url}: {str(e)}")
            # Same keys as the other payloads so callers can read them blindly.
            return {
                "url": url,
                "title": "",
                "content": "",
                "error": str(e),
                "success": False
            }
|
|
|
|
|
|
_ANALYSIS_TEMPLATE = """
You're an expert at analyzing questions. Examine this question and provide an analysis in JSON format:

Question: {question}

Your analysis should include:
1. question_type: The type of question (factual, how-to, analytical, etc.)
2. complexity: A rating from 1-5 of how complex the question is
3. required_tools: List of tools that would help answer this question (web_search, content_extraction, etc.)
4. information_sources: Likely sources of information for this answer (web, academic papers, etc.)

Format your response as valid JSON.
"""


def _default_analysis() -> Dict[str, Any]:
    """Return a fresh copy of the analysis used when the LLM reply is unusable."""
    return {
        "question_type": "factual",
        "complexity": 3,
        "required_tools": ["web_search"],
        "information_sources": ["web"]
    }


def _parse_analysis(analysis_text: str) -> Dict[str, Any]:
    """Extract a JSON object from the LLM reply, or return the default analysis."""
    candidates = [analysis_text]
    # The model may wrap the JSON in prose or code fences; also try the
    # outermost ``{...}`` span.
    match = re.search(r'\{.*\}', analysis_text, re.DOTALL)
    if match:
        candidates.append(match.group(0))

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Later steps call ``.get`` on the analysis, so only a dict is usable.
        if isinstance(parsed, dict):
            return parsed

    return _default_analysis()


def analyze_question(state: AgentState, llm) -> AgentState:
    """Ask the LLM to classify the question and store the result under ``analysis``.

    Args:
        state: Current agent state; ``state["question"]`` is read.
        llm: Chat model that can be piped after a ``ChatPromptTemplate``.

    Returns:
        A copy of ``state`` with ``analysis`` set. If the LLM call fails, the
        default analysis is used and ``error`` describes the failure. This
        function does not raise.
    """
    try:
        prompt = ChatPromptTemplate.from_messages([
            ("system", _ANALYSIS_TEMPLATE),
            ("human", "{question}")
        ])
        chain = prompt | llm | StrOutputParser()
        analysis_text = chain.invoke({"question": state["question"]})

        return {
            **state,
            "analysis": _parse_analysis(analysis_text)
        }
    except Exception as e:
        # Node boundary: record the failure and let the workflow continue.
        logger.error(f"Error analyzing question: {str(e)}")
        return {
            **state,
            "analysis": _default_analysis(),
            "error": f"Error during question analysis: {str(e)}"
        }
|
|
|
|
_PLAN_TEMPLATE = """
You're an expert planner for answering questions. Based on the analysis, create a step-by-step plan for answering this question.

Question: {question}
Analysis: {analysis}

Format your response as a JSON list of steps, where each step has:
1. step_number: Sequential number of the step
2. description: What should be done
3. tool: Tool to use (web_search, content_extraction, or null if no tool needed)
4. tool_input: Parameters for the tool (e.g., search query or URL)

Ensure your response is valid JSON.
"""


def _parse_plan(plan_text: str) -> Optional[List[Any]]:
    """Extract a JSON list from the LLM reply, or return None if there is none."""
    candidates = [plan_text]
    # The model may wrap the list in prose or in an object; also try the
    # outermost ``[...]`` span.
    match = re.search(r'\[.*\]', plan_text, re.DOTALL)
    if match:
        candidates.append(match.group(0))

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, list):
            return parsed

    return None


def create_plan(state: AgentState, llm) -> AgentState:
    """Ask the LLM for a step-by-step plan and reset the execution cursor.

    Args:
        state: Current agent state; ``question`` and ``analysis`` are read.
        llm: Chat model that can be piped after a ``ChatPromptTemplate``.

    Returns:
        A copy of ``state`` with ``plan`` set, ``current_step`` reset to 0 and
        ``tool_results`` emptied. When the reply holds no JSON list, or the
        LLM call fails, the plan comes from ``create_fallback_plan``; in the
        failure case ``error`` is set as well.
    """
    try:
        prompt = ChatPromptTemplate.from_messages([
            ("system", _PLAN_TEMPLATE),
            ("human", "Create a plan for answering this question.")
        ])
        chain = prompt | llm | StrOutputParser()
        plan_text = chain.invoke({
            "question": state["question"],
            "analysis": json.dumps(state.get("analysis"))
        })

        plan = _parse_plan(plan_text)
        if plan is None:
            plan = create_fallback_plan(state)

        return {
            **state,
            "plan": plan,
            "current_step": 0,
            "tool_results": []
        }
    except Exception as e:
        # Node boundary: record the failure and continue with the fallback plan.
        logger.error(f"Error creating plan: {str(e)}")
        return {
            **state,
            "plan": create_fallback_plan(state),
            "current_step": 0,
            "tool_results": [],
            "error": f"Error during plan creation: {str(e)}"
        }
|
|
|
|
def create_fallback_plan(state: "AgentState") -> List[Dict[str, Any]]:
    """Build a fixed search-then-answer plan for when LLM planning fails.

    Args:
        state: Current agent state. ``question`` becomes the search query;
            ``analysis["required_tools"]`` decides whether a content
            extraction step is included. A missing or ``None`` analysis, or a
            missing or unusable tool list, is treated as ``["web_search"]``.

    Returns:
        Two or three step dicts with consecutive ``step_number`` values
        starting at 1: a web search, optionally a content extraction of the
        top search result, and a final tool-less answer step.
    """
    # ``analysis`` is Optional in the state, so the key may exist with None.
    analysis = state.get("analysis") or {}
    tools = analysis.get("required_tools") if isinstance(analysis, dict) else None
    if not tools or not isinstance(tools, (list, tuple, set, str)):
        tools = ["web_search"]

    steps = [
        {
            "description": "Search for information about the question",
            "tool": "web_search",
            "tool_input": {"query": state["question"]}
        }
    ]

    if "web_search" in tools:
        steps.append({
            "description": "Extract content from the most relevant search result URL if available",
            "tool": "content_extraction",
            "tool_input": {"url_from_search_results": True}
        })

    steps.append({
        "description": "Formulate an answer based on search results",
        "tool": None,
        "tool_input": None
    })

    # Number the steps last so the sequence is right whichever steps were added.
    return [
        {"step_number": number, **step}
        for number, step in enumerate(steps, start=1)
    ]
|
|
|
|
def _normalize_tool_input(tool_name, raw_input):
    """Coerce a plan step's ``tool_input`` into a dict.

    Plans come from an LLM, so the value may be a dict, a bare string or
    null. A bare string is taken as the URL for ``content_extraction`` and as
    the query for any other tool; anything else becomes an empty dict.
    """
    if isinstance(raw_input, dict):
        return raw_input
    if isinstance(raw_input, str) and raw_input.strip():
        key = "url" if tool_name == "content_extraction" else "query"
        return {key: raw_input.strip()}
    return {}


def _latest_search_url(tool_results):
    """Return the top link of the most recent successful web search, or None."""
    for past_result in reversed(tool_results or []):
        if past_result.get("tool_name") == "web_search" and past_result.get("success"):
            hits = past_result.get("results", [])
            if hits:
                return hits[0].get("link")
    return None


def _run_web_search(tool_input, question):
    """Run one ``web_search`` step and return its result record (never raises)."""
    try:
        query = tool_input.get("query", question)
        if not query or not isinstance(query, str):
            raise ValueError("Invalid search query: must be a non-empty string")

        if len(query) > 500:
            logger.warning(f"Search query too long ({len(query)} chars), truncating to 500 chars")
            query = query[:497] + "..."

        search_results = WebSearchTool().search(query)
        if not isinstance(search_results, list):
            logger.warning(f"Invalid search results type: {type(search_results)}")
            search_results = []

        record = {
            "tool_name": "web_search",
            "success": len(search_results) > 0,
            "query": query,
            "results": search_results,
            "error": None if search_results else "No search results found"
        }

        # A backend may report credential problems inside an ordinary result.
        if any(
            "API key" in (hit.get("title") or "") or "API key" in (hit.get("snippet") or "")
            for hit in search_results
        ):
            logger.error("Search results indicate API key issue")
            record["error"] = "Search API key error detected in results"
            record["success"] = False

        return record
    except Exception as search_err:
        if isinstance(search_err, ConnectionError):
            message = f"Connection error: {str(search_err)}"
        elif isinstance(search_err, TimeoutError):
            message = f"Search timed out: {str(search_err)}"
        else:
            message = f"Search error: {str(search_err)}"
        logger.error(f"Error in web search: {str(search_err)}")
        return {
            "tool_name": "web_search",
            "success": False,
            "query": tool_input.get("query", question),
            "results": [],
            "error": message
        }


def _run_content_extraction(tool_input, tool_results):
    """Run one ``content_extraction`` step and return its result record (never raises)."""
    try:
        url = tool_input.get("url")

        # No explicit URL: optionally reuse the top hit of an earlier search.
        if not url and tool_input.get("url_from_search_results", False):
            url = _latest_search_url(tool_results)

        if not url or not isinstance(url, str):
            logger.warning("No valid URL found for content extraction")
            return {
                "tool_name": "content_extraction",
                "success": False,
                "error": "No valid URL provided or found in search results"
            }

        if not url.startswith(("http://", "https://")):
            logger.warning(f"Invalid URL format: {url}")
            return {
                "tool_name": "content_extraction",
                "success": False,
                "url": url,
                "error": "Invalid URL format: URL must start with http:// or https://"
            }

        content = ContentExtractor().extract_content(url)
        return {
            "tool_name": "content_extraction",
            "success": content.get("success", False),
            "url": url,
            "content": content,
            "error": content.get("error")
        }
    except Exception as extract_err:
        if isinstance(extract_err, ConnectionError):
            message = f"Connection error during content extraction: {str(extract_err)}"
        elif isinstance(extract_err, TimeoutError):
            message = f"Content extraction timed out: {str(extract_err)}"
        else:
            message = f"Content extraction error: {str(extract_err)}"
        logger.error(f"Error in content extraction: {str(extract_err)}")
        return {
            "tool_name": "content_extraction",
            "success": False,
            "url": tool_input.get("url", "unknown"),
            "error": message
        }


def execute_tool(state: AgentState) -> AgentState:
    """Execute the plan step at ``current_step`` and advance the cursor.

    Args:
        state: Current agent state; ``plan``, ``current_step``, ``question``
            and ``tool_results`` are read.

    Returns:
        A copy of ``state`` with ``current_step`` incremented by one. When the
        step names a tool, one result record is appended to ``tool_results``
        (``success`` is False and ``error`` is set for unknown tools and for
        failures). Steps without a tool, and calls made past the end of the
        plan, only advance the cursor. Unexpected exceptions are caught,
        recorded as a failed result and copied into ``error``.
    """
    # Both keys are Optional in the state, so they may be present but None.
    current_step = state.get("current_step") or 0
    if not isinstance(current_step, int):
        current_step = 0
    plan = state.get("plan") or []
    tool_name = "unknown"

    try:
        if not isinstance(state, dict):
            raise ValueError(f"Invalid state type: {type(state)}. Expected dict.")

        if current_step >= len(plan):
            logger.info("Execute tool: reached end of plan")
            return {
                **state,
                "current_step": current_step + 1
            }

        step = plan[current_step]
        if not isinstance(step, dict):
            logger.error(f"Invalid step format at position {current_step}: {type(step)}")
            raise ValueError(f"Invalid step format at position {current_step}")

        tool_name = step.get("tool")
        if not tool_name:
            logger.info(f"No tool specified for step {current_step}, skipping")
            return {
                **state,
                "current_step": current_step + 1
            }

        tool_input = _normalize_tool_input(tool_name, step.get("tool_input"))
        tool_results = state.get("tool_results", []) or []

        logger.info(f"Executing tool '{tool_name}' for step {current_step}")

        if tool_name == "web_search":
            result = _run_web_search(tool_input, state.get("question"))
        elif tool_name == "content_extraction":
            result = _run_content_extraction(tool_input, tool_results)
        else:
            logger.warning(f"Unknown tool requested: {tool_name}")
            result = {
                "tool_name": tool_name,
                "success": False,
                "error": f"Unknown tool: {tool_name}"
            }

        if result.get("success"):
            logger.info(f"Tool '{tool_name}' executed successfully")
        else:
            logger.warning(f"Tool '{tool_name}' execution failed: {result.get('error')}")

        return {
            **state,
            "tool_results": tool_results + [result],
            "current_step": current_step + 1
        }

    except Exception as e:
        error_type = type(e).__name__
        # The module path is needed to recognise ``requests.exceptions.*``;
        # the bare class name never contains it.
        qualified_type = f"{type(e).__module__}.{error_type}"
        logger.error(f"Error executing tool '{tool_name}': {error_type}: {str(e)}")
        logger.error(traceback.format_exc())

        tool_results = state.get("tool_results", []) or []

        error_message = str(e)
        if "ConnectionError" in error_type or "requests.exceptions" in qualified_type:
            error_message = f"Connection error during tool execution: {str(e)}. This might be due to network issues or the service being unavailable."
        elif "TimeoutError" in error_type:
            error_message = f"Tool execution timed out: {str(e)}. The operation took too long to complete."
        elif "JSONDecodeError" in error_type:
            error_message = f"Error parsing response data: {str(e)}. The service returned an unexpected format."
        elif "KeyError" in error_type or "AttributeError" in error_type:
            error_message = f"Missing or invalid data during tool execution: {str(e)}. This might be due to incomplete or malformed data."
        elif "AuthenticationError" in error_type or "api key" in str(e).lower():
            error_message = f"Authentication error during tool execution: {str(e)}. This might be due to invalid API credentials."

        return {
            **state,
            "tool_results": tool_results + [{
                "tool_name": tool_name,
                "success": False,
                "error": error_message,
                "error_type": error_type
            }],
            "current_step": current_step + 1,
            "error": f"Error during tool execution: {error_message}"
        }
|
|
"""
|
|
GAIA (Grounded AI Assistant) agent with web search and content extraction capabilities.
|
|
This class provides a simplified interface for the app.py file to interact with.
|
|
"""
|
|
|
|
def __init__(self):
    """Create the agent and attach a new, empty ``SimpleMemory`` store."""
    store = SimpleMemory()
    self.memory = store
    logger.info("GAIA Agent initialized")
|
|
|
|
def __call__(self, question: str) -> str:
|
|
"""
|
|
Process a question and generate an answer.
|
|
Compatible with the interface expected by app.py.
|
|
|
|
Args:
|
|
question (str): The question to process
|
|
|
|
Returns:
|
|
str: The answer to the question
|
|
"""
|
|
return self.process_question(question)
|
|
|
|
def process_question(self, question: str) -> str:
    """
    Answer a question, reusing a cached answer when one is available.

    With an OpenAI API key configured the question goes through the
    LangGraph workflow; otherwise the fallback path is used. The answer is
    cached under a key derived from the question's MD5 digest.

    Args:
        question (str): The question to process

    Returns:
        str: The answer, or an apology containing the error text if
        processing raised.
    """
    digest = hashlib.md5(question.encode()).hexdigest()
    cache_key = f"question_{digest}"

    # Serve a previously computed answer when the cache still holds one.
    cached_answer = self.memory.get_cached_result(cache_key)
    if cached_answer:
        logger.info(f"Using cached answer for question: {question[:50]}...")
        return cached_answer

    try:
        if not OPENAI_API_KEY:
            logger.warning("Using fallback mode (no OpenAI API key provided)")
            answer = self._fallback_processing(question)
        else:
            llm = ChatOpenAI(
                temperature=0,
                model="gpt-3.5-turbo",
                api_key=OPENAI_API_KEY
            )
            answer = self._process_with_langgraph(question, llm)

        self.memory.cache_result(cache_key, answer)
        return answer

    except Exception as e:
        logger.error(f"Error processing question: {str(e)}")
        logger.error(traceback.format_exc())
        return f"I apologize, but I encountered an error while processing your question: {str(e)}"
|
|
|
|
def _process_with_langgraph(self, question: str, llm) -> str:
    """Run the analyze -> plan -> execute -> answer workflow for one question.

    Args:
        question: The question to answer.
        llm: Chat model handed to the node functions that declare an ``llm``
            parameter.

    Returns:
        The ``answer`` from the final state; otherwise a message built from
        the state's ``error``, or a generic "unable to answer" message. Any
        exception raised while building or running the graph is logged and
        turned into an error message string.
    """
    import inspect

    try:
        def should_continue(state: AgentState) -> Literal["continue", "complete"]:
            """Route back to tool execution while plan steps remain."""
            current_step = state.get("current_step", 0)
            plan = state.get("plan", [])

            if current_step is None or plan is None:
                return "complete"

            return "continue" if current_step < len(plan) else "complete"

        def bind_llm(node):
            """Wrap nodes that declare ``llm`` so they can be called with state only."""
            if "llm" not in inspect.signature(node).parameters:
                return node

            def bound_node(state):
                return node(state, llm)

            return bound_node

        workflow = StateGraph(AgentState)

        # LangGraph invokes a node with the state, so the model is bound here
        # rather than carried inside the state.
        workflow.add_node("analyze", bind_llm(analyze_question))
        workflow.add_node("create_plan", bind_llm(create_plan))
        workflow.add_node("execute_tool", bind_llm(execute_tool))
        # NOTE(review): formulate_answer is not defined in the visible part of
        # this file; confirm that it exists at module level.
        workflow.add_node("formulate_answer", bind_llm(formulate_answer))

        workflow.add_edge("analyze", "create_plan")
        workflow.add_edge("create_plan", "execute_tool")
        # Leaving "execute_tool" is decided solely by the conditional edge: it
        # loops until the plan is exhausted, then hands over to the answer node.
        workflow.add_conditional_edges(
            "execute_tool",
            should_continue,
            {
                "continue": "execute_tool",
                "complete": "formulate_answer"
            }
        )
        workflow.add_edge("formulate_answer", END)

        workflow.set_entry_point("analyze")

        app = workflow.compile()

        result = app.invoke({
            "question": question,
            "tool_results": []
        })

        if result.get("answer"):
            return result["answer"]
        if result.get("error"):
            return f"I encountered an error: {result['error']}"
        return "I was unable to generate an answer based on the available information."

    except Exception as e:
        logger.error(f"Error in LangGraph processing: {str(e)}")
        logger.error(traceback.format_exc())
        return f"I encountered an error while processing your question: {str(e)}"
|
|
|
|
def _fallback_processing(self, question: str) -> str:
    """Produce a canned or search-based answer without using an LLM.

    Questions containing the word "how", "what" or "why" (checked in that
    order, case-insensitively, as whole words) get a generic template
    answer. Any other question is sent to ``WebSearchTool`` and the snippet
    of the first real result is returned.

    Args:
        question: The question to answer.

    Returns:
        The answer text. Failures are logged and replaced by an apology; this
        method does not raise.
    """
    try:
        lowered = question.lower()
        topic = question.strip('?')

        def mentions(word):
            # Whole-word match, so "show" or "somewhat" do not count as
            # "how" / "what".
            return re.search(rf"\b{word}\b", lowered) is not None

        if mentions("how"):
            return f"To address '{topic}', I would recommend following these steps: 1) Understand the core concepts, 2) Apply a structured approach, 3) Evaluate results, and 4) Refine as needed. Without being able to access external knowledge at the moment, this is a general framework for addressing how-to questions."
        if mentions("what"):
            return f"Regarding '{topic}', this typically involves understanding several key factors. While I don't have access to external knowledge at the moment, this type of question usually requires defining terms, establishing context, and examining relevant concepts."
        if mentions("why"):
            return f"The question '{topic}' relates to causality and explanation. Such questions typically involve understanding underlying mechanisms, historical context, and logical relationships between factors."

        unavailable = f"I'm sorry, but I cannot provide a comprehensive answer to '{question}' at this moment due to limited access to external knowledge and language model capabilities."

        try:
            search_results = WebSearchTool().search(question)
            top = search_results[0] if search_results else {}
            # The "Search Unavailable" placeholder has an empty link; its
            # snippet is an error notice, not an answer.
            if top.get("link") and top.get("snippet"):
                return f"Based on available information: {top['snippet']}\n\nPlease note that without access to a language model, I can only provide this basic search result."
            return unavailable
        except Exception as search_err:
            logger.error(f"Error in fallback search: {str(search_err)}")
            return unavailable

    except Exception as e:
        logger.error(f"Error in fallback processing: {str(e)}")
        return "I apologize, but I'm currently unable to process your question due to system limitations."
|
|
|
|
def query(self, question: str) -> Dict[str, Any]:
    """
    Answer a question and return the answer together with timing metadata.

    Args:
        question (str): The question to answer

    Returns:
        Dict[str, Any]: ``question``, ``answer``, ``processing_time``,
        ``timestamp`` and ``status`` ("success" or "error"); an ``error``
        entry is added when processing raised.
    """
    try:
        started = time.time()
        answer = self.process_question(question)
        elapsed = time.time() - started

        response = {
            "question": question,
            "answer": answer,
            "processing_time": elapsed,
            "timestamp": time.time(),
            "status": "success"
        }
        return response
    except Exception as exc:
        reason = str(exc)
        logger.error(f"Error in query: {reason}")

        return {
            "question": question,
            "answer": f"Error processing query: {reason}",
            "processing_time": time.time() - started,
            "timestamp": time.time(),
            "status": "error",
            "error": reason
        }
|
|
|
|
def clear_memory(self):
    """Clear the agent's memory by calling ``clear()`` on the store with no session id."""
    store = self.memory
    store.clear()
    logger.info("Agent memory cleared")