|
""" |
|
GAIA-Ready AI Agent using smolagents framework |
|
|
|
This agent is designed to meet the requirements of the Hugging Face Agents Course |
|
and perform well on the GAIA benchmark. It implements the Think-Act-Observe workflow |
|
and includes tools for web search, calculation, image analysis, and code execution. |
|
""" |
|
|
|
import os |
|
import json |
|
import base64 |
|
import requests |
|
from typing import List, Dict, Any, Optional, Union, Callable |
|
import re |
|
import time |
|
from datetime import datetime |
|
import traceback |
|
|
|
|
|
# Bootstrap third-party dependencies on first run: try the import and, if it
# is missing, install the package and retry.
# NOTE(review): installing at import time is convenient in notebooks but
# surprising in production; consider a requirements file instead.
try:
    from smolagents import Agent, InferenceClientModel, Tool
    from smolagents.memory import Memory
except ImportError:
    import subprocess
    import sys
    # Use "python -m pip" so the install targets the interpreter actually
    # running this script, not whichever "pip" happens to be first on PATH.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "smolagents"])
    from smolagents import Agent, InferenceClientModel, Tool
    from smolagents.memory import Memory

try:
    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image
    import io
except ImportError:
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy", "matplotlib", "pillow"])
    import numpy as np
    import matplotlib.pyplot as plt
    from PIL import Image
    import io

try:
    import requests
    from bs4 import BeautifulSoup
except ImportError:
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "requests", "beautifulsoup4"])
    import requests
    from bs4 import BeautifulSoup
|
|
|
|
|
class MemoryManager:
    """Layered memory store for the agent.

    Keeps three tiers: a bounded list of recent events (short-term), a
    bounded list of important events (long-term), and an unbounded
    key/value scratch space (working memory).
    """

    def __init__(self):
        self.short_term_memory = []
        self.long_term_memory = []
        self.working_memory = {}
        self.max_short_term_items = 10
        self.max_long_term_items = 50

    def add_to_short_term(self, item: Dict[str, Any]) -> None:
        """Append *item* to short-term memory, evicting the oldest entry if over capacity."""
        self.short_term_memory.append(item)
        while len(self.short_term_memory) > self.max_short_term_items:
            del self.short_term_memory[0]

    def add_to_long_term(self, item: Dict[str, Any]) -> None:
        """Append *item* to long-term memory, evicting the oldest entry if over capacity."""
        self.long_term_memory.append(item)
        while len(self.long_term_memory) > self.max_long_term_items:
            del self.long_term_memory[0]

    def store_in_working_memory(self, key: str, value: Any) -> None:
        """Bind *value* to *key* in working memory, overwriting any prior value."""
        self.working_memory[key] = value

    def get_from_working_memory(self, key: str) -> Optional[Any]:
        """Return the value stored under *key*, or None when absent."""
        return self.working_memory.get(key)

    def clear_working_memory(self) -> None:
        """Drop every key/value pair from working memory."""
        self.working_memory = {}

    def get_relevant_memories(self, query: str) -> List[Dict[str, Any]]:
        """Return stored memories whose content shares a keyword with *query*.

        Matching is a naive substring test on whitespace-split, lower-cased
        query tokens. Long-term matches precede short-term ones.
        """
        keywords = set(query.lower().split())
        matches = []
        for pool in (self.long_term_memory, self.short_term_memory):
            for entry in pool:
                text = entry.get("content", "").lower()
                if any(word in text for word in keywords):
                    matches.append(entry)
        return matches

    def get_memory_summary(self) -> str:
        """Render a human-readable snapshot of all three memory tiers."""
        recent = "\n".join(f"- {entry.get('content', '')}" for entry in self.short_term_memory[-5:])
        important = "\n".join(f"- {entry.get('content', '')}" for entry in self.long_term_memory[-5:])
        scratch = "\n".join(f"- {key}: {value}" for key, value in self.working_memory.items())

        return f"""
MEMORY SUMMARY:
--------------
Recent Short-Term Memory:
{recent}

Important Long-Term Memory:
{important}

Working Memory:
{scratch}
"""
|
|
|
|
|
|
|
|
|
def web_search_function(query: str) -> str:
    """
    Search the web for information using a DuckDuckGo proxy API.

    Args:
        query: The search query

    Returns:
        Up to five formatted results (title, snippet, URL), or an error string.
    """
    try:
        # Let requests URL-encode the query via params= instead of splicing
        # raw user text into the URL, and bound the request so a hung
        # endpoint cannot stall the agent indefinitely.
        url = "https://ddg-api.herokuapp.com/search"
        response = requests.get(url, params={"query": query}, timeout=15)

        if response.status_code == 200:
            results = response.json()
            formatted_results = []

            for i, result in enumerate(results[:5]):
                title = result.get('title', 'No title')
                snippet = result.get('snippet', 'No snippet')
                link = result.get('link', 'No link')
                formatted_results.append(f"{i+1}. {title}\n {snippet}\n URL: {link}\n")

            return "Search Results:\n" + "\n".join(formatted_results)
        else:
            return f"Error: Search request failed with status code {response.status_code}"
    except Exception as e:
        return f"Error performing web search: {str(e)}"
|
|
|
|
|
def web_page_content_function(url: str) -> str:
    """
    Fetch and extract readable text content from a web page.

    Args:
        url: The URL of the web page to fetch

    Returns:
        Extracted, whitespace-normalized text (truncated to ~4000 chars),
        or an error string.
    """
    try:
        # A browser-like User-Agent avoids trivial bot blocking.
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        # Bound the request so a dead host cannot hang the agent.
        response = requests.get(url, headers=headers, timeout=15)

        if response.status_code == 200:
            soup = BeautifulSoup(response.text, 'html.parser')

            # Drop non-content nodes before extracting text.
            for script in soup(["script", "style"]):
                script.extract()

            text = soup.get_text()

            # Collapse ragged whitespace into clean lines.
            lines = (line.strip() for line in text.splitlines())
            chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
            text = '\n'.join(chunk for chunk in chunks if chunk)

            # Cap the payload so it fits comfortably in the model context.
            if len(text) > 4000:
                text = text[:4000] + "...\n[Content truncated due to length]"

            return f"Content from {url}:\n\n{text}"
        else:
            return f"Error: Failed to fetch web page with status code {response.status_code}"
    except Exception as e:
        return f"Error fetching web page content: {str(e)}"
|
|
|
|
|
def calculator_function(expression: str) -> str:
    """
    Safely evaluate a basic arithmetic expression.

    Supports +, -, *, /, parentheses, and ^ (translated to **). The cleaned
    expression is parsed with the ast module and evaluated over a whitelist
    of node types, so no arbitrary code can run -- unlike the previous
    eval()-based implementation.

    Args:
        expression: The mathematical expression to evaluate

    Returns:
        Result of the calculation as a string, or an error message.
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _eval_node(node):
        # Recursively evaluate only whitelisted arithmetic nodes.
        if isinstance(node, ast.Expression):
            return _eval_node(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](_eval_node(node.left), _eval_node(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_eval_node(node.operand))
        raise ValueError("Unsupported expression element")

    try:
        # Strip anything outside the arithmetic alphabet, then map ^ to **.
        clean_expr = re.sub(r'[^0-9+\-*/().^ ]', '', expression)
        clean_expr = clean_expr.replace('^', '**')

        result = _eval_node(ast.parse(clean_expr, mode="eval"))

        return f"Expression: {expression}\nResult: {result}"
    except Exception as e:
        return f"Error calculating result: {str(e)}"
|
|
|
|
|
def python_executor_function(code: str) -> str:
    """
    Execute Python code and capture anything it prints to stdout.

    Args:
        code: The Python code to execute

    Returns:
        The captured stdout on success, or the error message plus traceback
        on failure.

    Warning:
        exec() runs the code with full interpreter privileges; only feed it
        code the agent itself generated, never untrusted external input.
    """
    from io import StringIO
    from contextlib import redirect_stdout

    # Pre-seed the execution namespace with the libraries the agent's
    # generated snippets most commonly need.
    exec_globals = {
        "np": np,
        "plt": plt,
        "requests": requests,
        "BeautifulSoup": BeautifulSoup,
        "Image": Image,
        "io": io,
        "json": json,
        "base64": base64,
        "re": re,
        "time": time,
        "datetime": datetime
    }

    redirected_output = StringIO()
    try:
        # redirect_stdout restores sys.stdout even when exec() raises.
        # (The previous implementation swapped sys.stdout manually and leaked
        # the redirection on any exception, silencing all later output.)
        with redirect_stdout(redirected_output):
            exec(code, exec_globals)

        output = redirected_output.getvalue()
        return f"Code executed successfully:\n\n{output}"
    except Exception as e:
        return f"Error executing Python code: {str(e)}\n{traceback.format_exc()}"
|
|
|
|
|
def image_analyzer_function(image_url: str) -> str:
    """
    Fetch an image and report basic retrieval metadata.

    This is a placeholder: it confirms the image downloads and reports its
    byte size, but performs no actual vision analysis.

    Args:
        image_url: URL of the image to analyze

    Returns:
        A stub analysis string, or an error message.
    """
    try:
        # Bound the download so a dead host cannot hang the agent.
        # (The previous version also base64-encoded the payload into an
        # unused local; that dead work has been removed.)
        response = requests.get(image_url, timeout=15)

        if response.status_code == 200:
            return f"""
Image Analysis:
- Successfully retrieved image from {image_url}
- Image size: {len(response.content)} bytes

[Note: In a production environment, this would use a vision model to analyze the image content]

To properly analyze this image, please describe what you see in the image.
"""
        else:
            return f"Error: Failed to fetch image with status code {response.status_code}"
    except Exception as e:
        return f"Error analyzing image: {str(e)}"
|
|
|
|
|
def text_processor_function(text: str, operation: str) -> str:
    """
    Process and analyze text.

    Args:
        text: The text to process
        operation: One of "summarize", "analyze_sentiment", "extract_keywords"

    Returns:
        Processed text as a string, or an error message for unknown operations.
    """
    try:
        if operation == "summarize":
            # Naive extractive summary: first, middle, and last sentence.
            sentences = text.split('. ')
            if len(sentences) <= 3:
                return f"Summary: {text}"

            summary = f"{sentences[0]}. {sentences[len(sentences)//2]}. {sentences[-1]}"
            return f"Summary: {summary}"

        elif operation == "analyze_sentiment":
            positive_words = ['good', 'great', 'excellent', 'positive', 'happy', 'love', 'like']
            negative_words = ['bad', 'poor', 'negative', 'unhappy', 'hate', 'dislike']

            # Match whole tokens rather than substrings: the old substring
            # test counted "dislike" as the positive word "like" too, which
            # skewed clearly negative text toward neutral.
            tokens = set(re.findall(r"[a-z']+", text.lower()))
            positive_count = sum(1 for word in positive_words if word in tokens)
            negative_count = sum(1 for word in negative_words if word in tokens)

            if positive_count > negative_count:
                sentiment = "positive"
            elif negative_count > positive_count:
                sentiment = "negative"
            else:
                sentiment = "neutral"

            return f"Sentiment Analysis: {sentiment} (positive words: {positive_count}, negative words: {negative_count})"

        elif operation == "extract_keywords":
            from collections import Counter

            # Strip punctuation, drop stop words and very short tokens.
            text_clean = re.sub(r'[^\w\s]', '', text.lower())
            stop_words = ['the', 'a', 'an', 'and', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by']
            words = [word for word in text_clean.split() if word not in stop_words and len(word) > 2]

            # Most frequent tokens serve as keywords.
            word_counts = Counter(words)
            keywords = [word for word, count in word_counts.most_common(10)]

            return f"Keywords: {', '.join(keywords)}"
        else:
            return f"Error: Unknown operation '{operation}'. Supported operations: summarize, analyze_sentiment, extract_keywords"
    except Exception as e:
        return f"Error processing text: {str(e)}"
|
|
|
|
|
def file_manager_function(operation: str, filename: str, content: str = None) -> str:
    """
    Save text to and load text from files.

    Args:
        operation: The operation to perform ("save" or "load")
        filename: The name of the file
        content: The content to save (required for the save operation)

    Returns:
        Result of the operation as a string, or an error message.
    """
    try:
        if operation == "save":
            if content is None:
                return "Error: Content is required for save operation"

            # Pin the encoding so saved text round-trips identically on
            # every platform regardless of the locale default.
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(content)

            # The messages previously printed the literal placeholder
            # "(unknown)" instead of interpolating the actual filename.
            return f"Successfully saved content to {filename}"

        elif operation == "load":
            if not os.path.exists(filename):
                return f"Error: File {filename} does not exist"

            with open(filename, 'r', encoding='utf-8') as f:
                content = f.read()

            return f"Content of {filename}:\n\n{content}"
        else:
            return f"Error: Unknown operation '{operation}'. Supported operations: save, load"
    except Exception as e:
        return f"Error managing file: {str(e)}"
|
|
|
|
|
class GAIAAgent:
    """
    AI Agent designed to perform well on the GAIA benchmark.

    Implements the Think-Act-Observe workflow: solve() repeatedly plans an
    approach (think), runs a tool-assisted step (act), then evaluates the
    result and decides whether to continue (observe). Conversation events
    are mirrored into a custom MemoryManager alongside the underlying
    smolagents agent.
    """
    def __init__(self, api_key=None, use_local_model=False):
        # Layered memory store (short-term / long-term / working).
        self.memory_manager = MemoryManager()

        if use_local_model:
            # Prefer a locally served model via LiteLLM + Ollama; fall back
            # to the hosted Hugging Face Inference API on any failure.
            try:
                from smolagents import LiteLLMModel
                self.model = LiteLLMModel(
                    model_id="ollama_chat/qwen2:7b",
                    api_base="http://127.0.0.1:11434",  # default Ollama endpoint
                    num_ctx=8192,
                )
            except Exception as e:
                print(f"Error initializing local model: {str(e)}")
                print("Falling back to Hugging Face Inference API")
                self.model = InferenceClientModel(
                    model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                    api_key=api_key or os.environ.get("HF_API_KEY", "")
                )
        else:
            # Hosted inference: key comes from the argument or HF_API_KEY.
            self.model = InferenceClientModel(
                model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                api_key=api_key or os.environ.get("HF_API_KEY", "")
            )

        # Tool belt handed to the model; each Tool wraps one of the
        # module-level helper functions defined above.
        # NOTE(review): recent smolagents releases build tools via the @tool
        # decorator or Tool subclasses and expose CodeAgent/ToolCallingAgent
        # rather than Agent -- confirm Tool(function=...), Agent(...), and
        # agent.chat(...) below match the pinned smolagents version.
        self.tools = [
            Tool(
                name="web_search",
                description="Search the web for information",
                function=web_search_function
            ),
            Tool(
                name="web_page_content",
                description="Fetch and extract content from a web page",
                function=web_page_content_function
            ),
            Tool(
                name="calculator",
                description="Perform mathematical calculations",
                function=calculator_function
            ),
            Tool(
                name="image_analyzer",
                description="Analyze image content",
                function=image_analyzer_function
            ),
            Tool(
                name="python_executor",
                description="Execute Python code",
                function=python_executor_function
            ),
            Tool(
                name="text_processor",
                description="Process and analyze text",
                function=text_processor_function
            ),
            Tool(
                name="file_manager",
                description="Save and load data from files",
                function=file_manager_function
            )
        ]

        # Standing instructions describing the Think-Act-Observe protocol
        # and the available tools.
        self.system_prompt = """
You are an advanced AI assistant designed to solve complex tasks from the GAIA benchmark.
You have access to various tools that can help you solve these tasks.

Always follow the Think-Act-Observe workflow:
1. Think: Carefully analyze the task and plan your approach
2. Act: Use appropriate tools to gather information or perform actions
3. Observe: Analyze the results of your actions and adjust your approach if needed

For complex tasks, break them down into smaller steps.
Always verify your answers before submitting them.

When using tools:
- web_search: Use to find information online
- web_page_content: Use to extract content from specific web pages
- calculator: Use for mathematical calculations
- image_analyzer: Use to analyze image content
- python_executor: Use to run Python code for complex operations
- text_processor: Use to process and analyze text (summarize, analyze_sentiment, extract_keywords)
- file_manager: Use to save and load data from files (save, load)

Be thorough, methodical, and precise in your reasoning.
"""

        # The smolagents agent that actually talks to the model.
        self.agent = Agent(
            model=self.model,
            tools=self.tools,
            system_prompt=self.system_prompt
        )

    def think(self, query: str) -> Dict[str, Any]:
        """
        Analyze the task and plan an approach.

        Args:
            query: The user's query or task

        Returns:
            Dictionary containing analysis and plan (currently the same LLM
            response under both keys).
        """
        # Pull in anything previously learned that keyword-matches the query.
        relevant_memories = self.memory_manager.get_relevant_memories(query)

        thinking_prompt = f"""
TASK: {query}

RELEVANT MEMORIES:
{relevant_memories if relevant_memories else "No relevant memories found."}

Please analyze this task and create a plan:
1. What is this task asking for?
2. What information do I need to solve it?
3. What tools would be most helpful?
4. What steps should I take to solve it?

Provide your analysis and plan.
"""

        # NOTE(review): confirm the installed smolagents Agent exposes
        # chat(); current releases use run() for a single-turn request.
        response = self.agent.chat(thinking_prompt)

        # Record the planning step for later recall.
        self.memory_manager.add_to_short_term({
            "type": "thinking",
            "content": response,
            "timestamp": datetime.now().isoformat()
        })

        return {
            "analysis": response,
            "plan": response
        }

    def act(self, plan: Dict[str, Any], query: str) -> str:
        """
        Execute actions based on the plan.

        Makes two LLM calls: one to pick a tool, one to execute it.

        Args:
            plan: The plan generated by the think step
            query: The original query

        Returns:
            Results of the actions
        """
        # First call: ask the model which tool to use and why.
        tool_selection_prompt = f"""
TASK: {query}

MY PLAN:
{plan['plan']}

Based on this plan, which tool should I use first and with what parameters?
Respond in the following format:
TOOL: [tool name]
PARAMETERS: [parameters for the tool]
REASONING: [why this tool is appropriate]
"""

        tool_selection = self.agent.chat(tool_selection_prompt)

        self.memory_manager.add_to_short_term({
            "type": "tool_selection",
            "content": tool_selection,
            "timestamp": datetime.now().isoformat()
        })

        # Second call: have the agent actually run the selected tool.
        action_prompt = f"""
TASK: {query}

MY PLAN:
{plan['plan']}

TOOL SELECTION:
{tool_selection}

Please execute the appropriate tool to help solve this task.
"""

        action_result = self.agent.chat(action_prompt)

        self.memory_manager.add_to_short_term({
            "type": "action_result",
            "content": action_result,
            "timestamp": datetime.now().isoformat()
        })

        return action_result

    def observe(self, action_result: str, plan: Dict[str, Any], query: str) -> Dict[str, Any]:
        """
        Analyze the results of actions and determine next steps.

        Args:
            action_result: Results from the act step
            plan: The original plan
            query: The original query

        Returns:
            Dict with "observation" (the LLM's analysis) and "continue"
            (True when another iteration is needed).
        """
        observation_prompt = f"""
TASK: {query}

MY PLAN:
{plan['plan']}

ACTION RESULT:
{action_result}

Please analyze these results:
1. What did I learn from this action?
2. Does this fully answer the original task?
3. If not, what should I do next?
4. If yes, what is the final answer?

Provide your analysis and next steps or final answer.
"""

        observation = self.agent.chat(observation_prompt)

        self.memory_manager.add_to_short_term({
            "type": "observation",
            "content": observation,
            "timestamp": datetime.now().isoformat()
        })

        # Fragile heuristic: continuation is inferred from phrasing in the
        # free-text response, so an answer that happens to mention "next
        # steps" will trigger another iteration.
        if "next steps" in observation.lower() or "next tool" in observation.lower():
            continue_execution = True
        else:
            # Treat the response as final and promote it to long-term memory.
            self.memory_manager.add_to_long_term({
                "type": "final_answer",
                "query": query,
                "content": observation,
                "timestamp": datetime.now().isoformat()
            })
            continue_execution = False

        return {
            "observation": observation,
            "continue": continue_execution
        }

    def solve(self, query: str, max_iterations: int = 5) -> str:
        """
        Solve a task using the Think-Act-Observe workflow.

        Args:
            query: The user's query or task
            max_iterations: Maximum number of iterations to prevent infinite loops

        Returns:
            Final answer to the query
        """
        # Record the incoming task.
        self.memory_manager.add_to_short_term({
            "type": "query",
            "content": query,
            "timestamp": datetime.now().isoformat()
        })

        iteration = 0
        final_answer = None

        while iteration < max_iterations:
            print(f"Iteration {iteration + 1}/{max_iterations}")

            # Think: analyze and plan.
            print("Thinking...")
            plan = self.think(query)

            # Act: run a tool-assisted step.
            print("Acting...")
            action_result = self.act(plan, query)

            # Observe: evaluate and decide whether to stop.
            print("Observing...")
            observation = self.observe(action_result, plan, query)

            if not observation["continue"]:
                final_answer = observation["observation"]
                break

            # Fold progress into the query for the next iteration.
            # NOTE(review): this re-wraps the already-wrapped query each
            # round, so the prompt nests and grows with every iteration.
            query = f"""
Original task: {query}

Progress so far:
{observation["observation"]}

Please continue solving this task.
"""

            iteration += 1

        # NOTE(review): if max_iterations < 1 the loop never runs and
        # `observation` below raises NameError; the fallback also quotes only
        # the last observation, not an aggregate of all iterations.
        if final_answer is None:
            final_answer = f"""
I've spent {max_iterations} iterations trying to solve this task.
Here's my best answer based on what I've learned:

{observation["observation"]}

Note: This answer may be incomplete as I reached the maximum number of iterations.
"""

        return final_answer
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo entry point: build the agent against the hosted HF Inference API.
    # Requires HF_API_KEY in the environment (or pass api_key=...) and
    # network access; set use_local_model=True to use a local Ollama model.
    agent = GAIAAgent(use_local_model=False)

    # Sample multi-part task mixing retrieval and calculation.
    query = "What is the capital of France and what is its population? Also, calculate 15% of this population."

    answer = agent.solve(query)

    print("\nFinal Answer:")
    print(answer)
|
|