"""
Enhanced GAIA-Ready AI Agent with integrated memory and reasoning systems.

This is the main integration file that combines the agent, memory system,
and reasoning system into a complete solution for the Hugging Face Agents Course.
"""

import os
import sys
import json
import traceback
from typing import Any, Dict, List, Union
from datetime import datetime

try:
    from memory_system import EnhancedMemoryManager
    from reasoning_system import ReasoningSystem
except ImportError:
    print("Error: Could not import memory_system or reasoning_system modules.")
    print("Make sure memory_system.py and reasoning_system.py are in the same directory.")
    sys.exit(1)

try:
    from smolagents import Agent, InferenceClientModel, Tool, LiteLLMModel
except ImportError:
    # Install smolagents with the current interpreter's own pip, then retry.
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "smolagents"])
    from smolagents import Agent, InferenceClientModel, Tool

    try:
        from smolagents import LiteLLMModel
    except ImportError:
        print("Warning: LiteLLMModel not available, will use InferenceClientModel only.")
        LiteLLMModel = None  # Sentinel so a later reference fails cleanly instead of raising NameError.

from agent import (
    web_search_function,
    web_page_content_function,
    calculator_function,
    python_executor_function,
    image_analyzer_function,
    text_processor_function,
    file_manager_function,
)


class EnhancedGAIAAgent:
    """
    Enhanced AI Agent designed to perform well on the GAIA benchmark.

    Integrates memory and reasoning systems with the Think-Act-Observe workflow.
    """

    def __init__(self, api_key=None, use_local_model=False, use_semantic_memory=True):
        """
        Initialize the enhanced GAIA agent.

        Args:
            api_key: API key for the Hugging Face Inference API.
            use_local_model: Whether to use a local model via Ollama.
            use_semantic_memory: Whether to use semantic search for memory retrieval.
        """
        self.memory_manager = EnhancedMemoryManager(use_semantic_search=use_semantic_memory)

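        # Model selection: prefer a local Ollama model when requested, and fall
        # back to the hosted Hugging Face Inference API if initialization fails
        # (or if LiteLLMModel could not be imported above).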
        if use_local_model:
            try:
                self.model = LiteLLMModel(
                    model_id="ollama_chat/qwen2:7b",
                    api_base="http://127.0.0.1:11434",
                    num_ctx=8192,
                )
                print("Using local Ollama model: qwen2:7b")
            except Exception as e:
                print(f"Error initializing local model: {str(e)}")
                print("Falling back to Hugging Face Inference API")
                self.model = InferenceClientModel(
                    model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                    api_key=api_key or os.environ.get("HF_API_KEY", ""),
                )
                print("Using Hugging Face Inference API model: Mixtral-8x7B")
        else:
            self.model = InferenceClientModel(
                model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                api_key=api_key or os.environ.get("HF_API_KEY", ""),
            )
            print("Using Hugging Face Inference API model: Mixtral-8x7B")

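        # Wrap the tool functions imported from agent.py so the agent can
        # discover and invoke them by name.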
        self.tools = [
            Tool(
                name="web_search",
                description="Search the web for information",
                function=web_search_function,
            ),
            Tool(
                name="web_page_content",
                description="Fetch and extract content from a web page",
                function=web_page_content_function,
            ),
            Tool(
                name="calculator",
                description="Perform mathematical calculations",
                function=calculator_function,
            ),
            Tool(
                name="image_analyzer",
                description="Analyze image content",
                function=image_analyzer_function,
            ),
            Tool(
                name="python_executor",
                description="Execute Python code",
                function=python_executor_function,
            ),
            Tool(
                name="text_processor",
                description="Process and analyze text",
                function=text_processor_function,
            ),
            Tool(
                name="file_manager",
                description="Save and load data from files",
                function=file_manager_function,
            ),
        ]

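        # System prompt encoding the Think-Act-Observe workflow and the
        # available tools; this is what steers the agent on GAIA-style tasks.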
        self.system_prompt = """
You are an advanced AI assistant designed to solve complex tasks from the GAIA benchmark.
You have access to various tools that can help you solve these tasks.

Always follow the Think-Act-Observe workflow:

1. Think: Carefully analyze the task and plan your approach
   - Break down complex tasks into smaller steps
   - Consider what information you need and how to get it
   - Plan your approach before taking action

2. Act: Use appropriate tools to gather information or perform actions
   - web_search: Search the web for information
   - web_page_content: Extract content from specific web pages
   - calculator: Perform mathematical calculations
   - image_analyzer: Analyze image content
   - python_executor: Run Python code for complex operations
   - text_processor: Process and analyze text (summarize, analyze_sentiment, extract_keywords)
   - file_manager: Save and load data from files (save, load)

3. Observe: Analyze the results of your actions and adjust your approach
   - Verify that the information answers the original question
   - Identify any gaps or inconsistencies
   - Determine whether additional actions are needed

For complex tasks:
- Break them down into smaller, manageable steps
- Keep track of your progress and intermediate results
- Verify each step before moving to the next
- Always double-check your final answer

When reasoning:
- Be thorough and methodical
- Consider multiple perspectives
- Explain your thought process clearly
- Cite sources when providing factual information

Remember that the GAIA benchmark tests your ability to:
- Reason effectively about complex problems
- Understand and process multimodal information
- Navigate the web to find information
- Use tools appropriately to solve tasks

Always verify your answers before submitting them.
"""

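        # Wire everything together: the base agent executes individual tool
        # calls, while the reasoning system drives the overall cycle.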
        self.base_agent = Agent(
            model=self.model,
            tools=self.tools,
            system_prompt=self.system_prompt,
        )

        self.reasoning_system = ReasoningSystem(self.base_agent, self.memory_manager)

        self.max_retries = 3
        self.error_log = []

    def solve(self, query: str, max_iterations: int = 5, verbose: bool = True) -> Dict[str, Any]:
        """
        Solve a task using the enhanced Think-Act-Observe workflow.

        Args:
            query: The user's query or task.
            max_iterations: Maximum number of reasoning iterations.
            verbose: Whether to print detailed progress.

        Returns:
            Dictionary containing the final answer and metadata.
        """
        start_time = datetime.now()

        if verbose:
            print(f"\n{'=' * 50}")
            print(f"Starting to solve: {query}")
            print(f"{'=' * 50}\n")

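        # Delegate the full reasoning cycle to the reasoning system; any
        # exception falls through to the recovery path below.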
        try:
            final_answer = self.reasoning_system.execute_reasoning_cycle(query, max_iterations)

            execution_time = (datetime.now() - start_time).total_seconds()

            if verbose:
                print(f"\n{'=' * 50}")
                print(f"Task completed in {execution_time:.2f} seconds")
                print(f"{'=' * 50}\n")

            memory_summary = self.memory_manager.get_memory_summary()

            return {
                "query": query,
                "answer": final_answer,
                "execution_time": execution_time,
                "iterations": max_iterations,  # Iteration budget, not the count actually used.
                "memory_summary": memory_summary,
                "success": True,
                "error": None,
            }
        except Exception as e:
            error_msg = f"Error solving task: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)

            self.error_log.append({
                "timestamp": datetime.now().isoformat(),
                "query": query,
                "error": str(e),
                "traceback": traceback.format_exc(),
            })

            # Recovery path: ask the base agent for a best-effort answer
            # without the full reasoning cycle.
            try:
                recovery_prompt = f"""
I encountered an error while trying to solve this task: {query}

The error was: {str(e)}

Based on what I know so far, please provide the best possible answer or explanation.
If you can't provide a complete answer, explain what you do know and what information is missing.
"""
                recovery_answer = self.base_agent.chat(recovery_prompt)

                execution_time = (datetime.now() - start_time).total_seconds()

                if verbose:
                    print(f"\n{'=' * 50}")
                    print(f"Task completed with recovery in {execution_time:.2f} seconds")
                    print(f"{'=' * 50}\n")

                return {
                    "query": query,
                    "answer": recovery_answer,
                    "execution_time": execution_time,
                    "iterations": 0,
                    "success": False,
                    "error": str(e),
                    "recovery": True,
                }
            except Exception as recovery_error:
                return {
                    "query": query,
                    "answer": f"I'm sorry, I encountered an error while solving this task and couldn't recover: {str(e)}",
                    "execution_time": (datetime.now() - start_time).total_seconds(),
                    "iterations": 0,
                    "success": False,
                    "error": str(e),
                    "recovery_error": str(recovery_error),
                    "recovery": False,
                }

    def batch_solve(self, queries: List[str], max_iterations: int = 5, verbose: bool = True) -> List[Dict[str, Any]]:
        """
        Solve multiple tasks in batch.

        Args:
            queries: List of user queries or tasks.
            max_iterations: Maximum number of iterations per query.
            verbose: Whether to print detailed progress.

        Returns:
            List of results, one per query.
        """
        results = []

        for i, query in enumerate(queries):
            if verbose:
                print(f"\n{'=' * 50}")
                print(f"Processing task {i + 1}/{len(queries)}: {query}")
                print(f"{'=' * 50}\n")

            result = self.solve(query, max_iterations, verbose)
            results.append(result)

            # Reset task-scoped memory so context doesn't leak between queries.
            self.memory_manager.clear_working_memory()

        return results

    def save_results(self, results: Union[Dict[str, Any], List[Dict[str, Any]]], filename: str = "gaia_results.json") -> None:
        """
        Save results to a JSON file.

        Args:
            results: Results from solve() or batch_solve().
            filename: Name of the file to save results to.
        """
        try:
            with open(filename, "w") as f:
                json.dump(results, f, indent=2)
            print(f"Results saved to {filename}")
        except Exception as e:
            print(f"Error saving results: {str(e)}")

    def load_results(self, filename: str = "gaia_results.json") -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Load results from a JSON file.

        Args:
            filename: Name of the file to load results from.

        Returns:
            The loaded results, or an empty list on failure.
        """
        try:
            with open(filename, "r") as f:
                results = json.load(f)
            print(f"Results loaded from {filename}")
            return results
        except Exception as e:
            print(f"Error loading results: {str(e)}")
            return []

    def evaluate_performance(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Evaluate performance metrics from batch results.

        Args:
            results: Results from batch_solve().

        Returns:
            Dictionary of performance metrics.
        """
        if not results:
            return {"error": "No results to evaluate"}

        total_queries = len(results)
        successful_queries = sum(1 for r in results if r.get("success", False))
        recovery_queries = sum(1 for r in results if not r.get("success", False) and r.get("recovery", False))
        failed_queries = total_queries - successful_queries - recovery_queries

        avg_execution_time = sum(r.get("execution_time", 0) for r in results) / total_queries

        return {
            "total_queries": total_queries,
            "successful_queries": successful_queries,
            "recovery_queries": recovery_queries,
            "failed_queries": failed_queries,
            "success_rate": successful_queries / total_queries,
            "recovery_rate": recovery_queries / total_queries,
            "failure_rate": failed_queries / total_queries,
            "avg_execution_time": avg_execution_time,
        }


if __name__ == "__main__":
    agent = EnhancedGAIAAgent(use_local_model=False, use_semantic_memory=True)

    sample_queries = [
        "What is the capital of France and what is its population? Also, calculate 15% of this population.",
        "Who was the first person to walk on the moon? What year did this happen?",
        "Explain the concept of photosynthesis in simple terms.",
    ]

    print("\nSolving single query...")
    result = agent.solve(sample_queries[0])
    print("\nFinal Answer:")
    print(result["answer"])
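
    # Optional: a sketch of the batch workflow using the helpers defined above
    # (batch_solve, evaluate_performance, save_results). It is commented out
    # because it issues one model call per query; uncomment to run the
    # remaining sample queries and persist the aggregate metrics.
    #
    # print("\nSolving remaining queries in batch...")
    # batch_results = agent.batch_solve(sample_queries[1:])
    # metrics = agent.evaluate_performance(batch_results)
    # print(f"Success rate: {metrics['success_rate']:.0%}, "
    #       f"avg execution time: {metrics['avg_execution_time']:.2f}s")
    # agent.save_results(batch_results, "gaia_batch_results.json")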