web-search-improved
#1
by
ddavydov
- opened
- agent.py +51 -44
- app.py +3 -13
- requirements.txt +2 -5
- test_local.py +99 -200
- tools.py +38 -267
- utils.py +65 -32
agent.py
CHANGED
@@ -1,45 +1,52 @@
|
|
1 |
from typing import TypedDict, Annotated
|
2 |
import os
|
3 |
-
from dotenv import load_dotenv
|
4 |
from langgraph.graph.message import add_messages
|
5 |
-
|
6 |
-
# Load environment variables from .env file
|
7 |
-
load_dotenv()
|
8 |
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
|
9 |
from langgraph.prebuilt import ToolNode
|
10 |
from langgraph.graph import START, StateGraph
|
11 |
from langgraph.checkpoint.memory import MemorySaver
|
12 |
from langgraph.prebuilt import tools_condition
|
13 |
-
from
|
14 |
from tools import agent_tools
|
15 |
-
from utils import format_gaia_answer, log_agent_step
|
16 |
|
17 |
-
# Initialize
|
18 |
-
|
19 |
-
|
|
|
20 |
temperature=0.1,
|
21 |
-
|
22 |
-
api_key=os.environ.get("OPENAI_API_KEY")
|
23 |
)
|
24 |
|
|
|
25 |
chat_with_tools = chat.bind_tools(agent_tools)
|
26 |
|
27 |
-
# System prompt for
|
28 |
-
SYSTEM_PROMPT = """You are a
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
36 |
|
37 |
-
|
38 |
|
39 |
# Generate the AgentState
|
40 |
class AgentState(TypedDict):
|
41 |
messages: Annotated[list[AnyMessage], add_messages]
|
42 |
task_id: str
|
|
|
43 |
|
44 |
def assistant(state: AgentState):
|
45 |
"""Main assistant function that processes messages and calls tools."""
|
@@ -73,7 +80,9 @@ def create_smart_agent():
|
|
73 |
)
|
74 |
builder.add_edge("tools", "assistant")
|
75 |
|
76 |
-
|
|
|
|
|
77 |
|
78 |
return agent
|
79 |
|
@@ -82,41 +91,44 @@ class SmartAgent:
|
|
82 |
|
83 |
def __init__(self):
|
84 |
self.agent = create_smart_agent()
|
85 |
-
print("🤖 Smart Agent initialized with
|
86 |
|
87 |
-
def __call__(self, question: str, task_id: str = None) ->
|
88 |
-
"""Process a question and return the formatted answer
|
89 |
try:
|
90 |
print(f"\n🎯 Processing question: {question[:100]}...")
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
enhanced_question = question
|
93 |
if task_id:
|
94 |
-
enhanced_question = f"Task ID: {task_id}\n\nQuestion: {question}"
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
}
|
99 |
|
100 |
initial_state = {
|
101 |
"messages": [HumanMessage(content=enhanced_question)],
|
102 |
-
"task_id": task_id or ""
|
|
|
103 |
}
|
104 |
|
105 |
result = self.agent.invoke(initial_state, config=config)
|
106 |
|
|
|
107 |
if result and 'messages' in result and result['messages']:
|
108 |
final_message = result['messages'][-1]
|
109 |
raw_answer = final_message.content
|
110 |
-
|
111 |
-
reasoning_trace = []
|
112 |
-
for msg in result['messages']:
|
113 |
-
if hasattr(msg, 'content') and msg.content:
|
114 |
-
reasoning_trace.append(msg.content)
|
115 |
-
|
116 |
-
reasoning_text = "\n---\n".join(reasoning_trace)
|
117 |
else:
|
118 |
raw_answer = "No response generated"
|
119 |
-
reasoning_text = "No reasoning trace available"
|
120 |
|
121 |
# Format the answer for submission
|
122 |
formatted_answer = format_gaia_answer(raw_answer)
|
@@ -124,16 +136,11 @@ class SmartAgent:
|
|
124 |
print(f"✅ Raw answer: {raw_answer}")
|
125 |
print(f"🎯 Formatted answer: {formatted_answer}")
|
126 |
|
127 |
-
|
128 |
-
if not formatted_answer or formatted_answer.strip() == "":
|
129 |
-
print("⚠️ WARNING: Empty formatted answer!")
|
130 |
-
formatted_answer = "ERROR: No valid answer extracted"
|
131 |
-
|
132 |
-
return formatted_answer, reasoning_text
|
133 |
|
134 |
except Exception as e:
|
135 |
error_msg = f"Error processing question: {str(e)}"
|
136 |
print(f"❌ {error_msg}")
|
137 |
-
return error_msg
|
138 |
|
139 |
smart_agent = SmartAgent()
|
|
|
1 |
from typing import TypedDict, Annotated
|
2 |
import os
|
|
|
3 |
from langgraph.graph.message import add_messages
|
|
|
|
|
|
|
4 |
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
|
5 |
from langgraph.prebuilt import ToolNode
|
6 |
from langgraph.graph import START, StateGraph
|
7 |
from langgraph.checkpoint.memory import MemorySaver
|
8 |
from langgraph.prebuilt import tools_condition
|
9 |
+
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
10 |
from tools import agent_tools
|
11 |
+
from utils import format_gaia_answer, analyze_question_type, create_execution_plan, log_agent_step
|
12 |
|
13 |
+
# Initialize LLM (same as unit3)
|
14 |
+
llm = HuggingFaceEndpoint(
|
15 |
+
repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
|
16 |
+
huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
|
17 |
temperature=0.1,
|
18 |
+
max_new_tokens=1024,
|
|
|
19 |
)
|
20 |
|
21 |
+
chat = ChatHuggingFace(llm=llm, verbose=True)
|
22 |
chat_with_tools = chat.bind_tools(agent_tools)
|
23 |
|
24 |
+
# System prompt for intelligent question answering
|
25 |
+
SYSTEM_PROMPT = """You are a highly capable AI assistant designed to answer questions accurately and helpfully.
|
26 |
+
|
27 |
+
Your approach should include:
|
28 |
+
- Multi-step reasoning and planning for complex questions
|
29 |
+
- Intelligent tool usage when needed for web search, file processing, calculations, and analysis
|
30 |
+
- Precise, factual answers based on reliable information
|
31 |
+
- Breaking down complex questions into manageable steps
|
32 |
|
33 |
+
IMPORTANT GUIDELINES:
|
34 |
+
1. Think step-by-step and use available tools when they can help provide better answers
|
35 |
+
2. For current information: Search the web for up-to-date facts
|
36 |
+
3. For files: Process associated files when task_id is provided
|
37 |
+
4. For visual content: Analyze images carefully when present
|
38 |
+
5. For calculations: Use computational tools for accuracy
|
39 |
+
6. Provide concise, direct answers without unnecessary prefixes
|
40 |
+
7. Focus on accuracy and helpfulness
|
41 |
+
8. Be factual and avoid speculation
|
42 |
|
43 |
+
Your goal is to be as helpful and accurate as possible while using the right tools for each task."""
|
44 |
|
45 |
# Generate the AgentState
|
46 |
class AgentState(TypedDict):
|
47 |
messages: Annotated[list[AnyMessage], add_messages]
|
48 |
task_id: str
|
49 |
+
question_analysis: dict
|
50 |
|
51 |
def assistant(state: AgentState):
|
52 |
"""Main assistant function that processes messages and calls tools."""
|
|
|
80 |
)
|
81 |
builder.add_edge("tools", "assistant")
|
82 |
|
83 |
+
# Add memory
|
84 |
+
memory = MemorySaver()
|
85 |
+
agent = builder.compile(checkpointer=memory)
|
86 |
|
87 |
return agent
|
88 |
|
|
|
91 |
|
92 |
def __init__(self):
|
93 |
self.agent = create_smart_agent()
|
94 |
+
print("🤖 Smart Agent initialized with LangGraph and tools")
|
95 |
|
96 |
+
def __call__(self, question: str, task_id: str = None) -> str:
|
97 |
+
"""Process a question and return the formatted answer."""
|
98 |
try:
|
99 |
print(f"\n🎯 Processing question: {question[:100]}...")
|
100 |
|
101 |
+
# Analyze the question
|
102 |
+
analysis = analyze_question_type(question)
|
103 |
+
print(f"📊 Question analysis: {analysis}")
|
104 |
+
|
105 |
+
# Create execution plan
|
106 |
+
plan = create_execution_plan(question, task_id)
|
107 |
+
print(f"📋 Execution plan: {plan}")
|
108 |
+
|
109 |
+
# Prepare the question with task_id context if available
|
110 |
enhanced_question = question
|
111 |
if task_id:
|
112 |
+
enhanced_question = f"Task ID: {task_id}\n\nQuestion: {question}\n\nNote: If this question involves files, use the file_download tool with task_id '{task_id}' to access associated files."
|
113 |
|
114 |
+
# Invoke the agent
|
115 |
+
thread_id = f"task-{task_id}" if task_id else "general"
|
116 |
+
config = {"configurable": {"thread_id": thread_id}}
|
117 |
|
118 |
initial_state = {
|
119 |
"messages": [HumanMessage(content=enhanced_question)],
|
120 |
+
"task_id": task_id or "",
|
121 |
+
"question_analysis": analysis
|
122 |
}
|
123 |
|
124 |
result = self.agent.invoke(initial_state, config=config)
|
125 |
|
126 |
+
# Extract the final answer
|
127 |
if result and 'messages' in result and result['messages']:
|
128 |
final_message = result['messages'][-1]
|
129 |
raw_answer = final_message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
else:
|
131 |
raw_answer = "No response generated"
|
|
|
132 |
|
133 |
# Format the answer for submission
|
134 |
formatted_answer = format_gaia_answer(raw_answer)
|
|
|
136 |
print(f"✅ Raw answer: {raw_answer}")
|
137 |
print(f"🎯 Formatted answer: {formatted_answer}")
|
138 |
|
139 |
+
return formatted_answer
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
except Exception as e:
|
142 |
error_msg = f"Error processing question: {str(e)}"
|
143 |
print(f"❌ {error_msg}")
|
144 |
+
return error_msg
|
145 |
|
146 |
smart_agent = SmartAgent()
|
app.py
CHANGED
@@ -71,22 +71,12 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
71 |
print(f"Skipping item with missing task_id or question: {item}")
|
72 |
continue
|
73 |
try:
|
74 |
-
submitted_answer
|
75 |
-
answers_payload.append({
|
76 |
-
"task_id": task_id,
|
77 |
-
"submitted_answer": submitted_answer,
|
78 |
-
"reasoning_trace": reasoning_trace
|
79 |
-
})
|
80 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
81 |
except Exception as e:
|
82 |
print(f"Error running agent on task {task_id}: {e}")
|
83 |
-
|
84 |
-
answers_payload.append({
|
85 |
-
"task_id": task_id,
|
86 |
-
"submitted_answer": error_answer,
|
87 |
-
"reasoning_trace": f"Error occurred: {str(e)}"
|
88 |
-
})
|
89 |
-
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": error_answer})
|
90 |
|
91 |
if not answers_payload:
|
92 |
print("Agent did not produce any answers to submit.")
|
|
|
71 |
print(f"Skipping item with missing task_id or question: {item}")
|
72 |
continue
|
73 |
try:
|
74 |
+
submitted_answer = agent(question_text, task_id)
|
75 |
+
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
|
|
|
|
|
|
|
|
76 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
77 |
except Exception as e:
|
78 |
print(f"Error running agent on task {task_id}: {e}")
|
79 |
+
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
if not answers_payload:
|
82 |
print("Agent did not produce any answers to submit.")
|
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# Core dependencies from unit3
|
2 |
langchain
|
3 |
langchain-community
|
4 |
-
langchain-
|
5 |
langgraph
|
6 |
huggingface_hub
|
7 |
|
@@ -10,11 +10,8 @@ gradio
|
|
10 |
requests
|
11 |
pillow
|
12 |
PyPDF2
|
13 |
-
|
14 |
python-dotenv
|
15 |
-
beautifulsoup4
|
16 |
-
faiss-cpu
|
17 |
-
langchain-text-splitters
|
18 |
|
19 |
# For image processing and multimodal capabilities
|
20 |
transformers
|
|
|
1 |
# Core dependencies from unit3
|
2 |
langchain
|
3 |
langchain-community
|
4 |
+
langchain-huggingface
|
5 |
langgraph
|
6 |
huggingface_hub
|
7 |
|
|
|
10 |
requests
|
11 |
pillow
|
12 |
PyPDF2
|
13 |
+
duckduckgo-search
|
14 |
python-dotenv
|
|
|
|
|
|
|
15 |
|
16 |
# For image processing and multimodal capabilities
|
17 |
transformers
|
test_local.py
CHANGED
@@ -1,238 +1,137 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
"""
|
3 |
-
|
4 |
-
|
5 |
"""
|
6 |
|
7 |
-
import
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from agent import smart_agent
|
10 |
|
11 |
-
def
|
12 |
-
"""Test the
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
#
|
18 |
-
|
19 |
-
|
|
|
20 |
|
21 |
-
|
22 |
-
print()
|
23 |
-
|
24 |
-
# Run the agent
|
25 |
-
print("🤖 Running smart agent on the predefined question...")
|
26 |
try:
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
processing_time = end_time - start_time
|
32 |
-
print(f"✅ Agent completed in {processing_time:.2f} seconds")
|
33 |
-
print()
|
34 |
-
|
35 |
except Exception as e:
|
36 |
-
print(f"
|
37 |
-
return False
|
38 |
|
39 |
-
# Display results
|
40 |
-
print("📊 AGENT RESULTS")
|
41 |
-
print("-" * 40)
|
42 |
-
print(f"🎯 Formatted Answer: '{answer}'")
|
43 |
-
print(f"📝 Reasoning Length: {len(reasoning_trace)} characters")
|
44 |
-
print(f"⏱️ Processing Time: {processing_time:.2f}s")
|
45 |
print()
|
|
|
|
|
|
|
|
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
print("✅ GAIA FORMAT VALIDATION")
|
56 |
-
print("-" * 40)
|
57 |
-
|
58 |
-
# Check if answer is not empty
|
59 |
-
if answer and answer.strip():
|
60 |
-
print("✅ Answer is not empty")
|
61 |
-
else:
|
62 |
-
print("❌ Answer is empty or None")
|
63 |
-
return False
|
64 |
-
|
65 |
-
# Check if answer looks like IOC country code (2-3 uppercase letters)
|
66 |
-
import re
|
67 |
-
if re.match(r'^[A-Z]{2,3}$', answer.strip()):
|
68 |
-
print(f"✅ Answer '{answer}' matches IOC country code format")
|
69 |
-
else:
|
70 |
-
print(f"⚠️ Answer '{answer}' may not be in correct IOC format (should be 2-3 uppercase letters)")
|
71 |
-
|
72 |
-
# Check if web search was used (look for web_search in reasoning)
|
73 |
-
if "web_search" in reasoning_trace.lower() or "search" in reasoning_trace.lower():
|
74 |
-
print("✅ Agent appears to have used web search")
|
75 |
-
else:
|
76 |
-
print("⚠️ No clear evidence of web search usage")
|
77 |
-
|
78 |
-
# Check answer length (should be short for country code)
|
79 |
-
if len(answer.strip()) <= 5:
|
80 |
-
print("✅ Answer length is appropriate for country code")
|
81 |
-
else:
|
82 |
-
print("⚠️ Answer seems too long for a country code")
|
83 |
|
84 |
print()
|
85 |
-
|
86 |
-
# Final validation
|
87 |
-
print("🏁 FINAL VALIDATION")
|
88 |
-
print("-" * 40)
|
89 |
-
|
90 |
-
if answer and answer.strip() and len(answer.strip()) <= 5:
|
91 |
-
print("✅ PREDEFINED TEST PASSED - Answer format suitable for GAIA")
|
92 |
-
print(f"🎯 Agent produced: '{answer}' for 1928 Olympics question")
|
93 |
-
return True
|
94 |
-
else:
|
95 |
-
print("❌ PREDEFINED TEST FAILED - Answer format needs improvement")
|
96 |
-
return False
|
97 |
|
98 |
-
def
|
99 |
-
"""Test the agent with a
|
100 |
-
|
101 |
-
print("🔧 GAIA Random Question Test")
|
102 |
-
print("="*60)
|
103 |
|
104 |
-
# Step 1: Fetch a random question
|
105 |
-
print("📡 Fetching random question from GAIA API...")
|
106 |
try:
|
107 |
question_data = fetch_random_question()
|
108 |
if not question_data:
|
109 |
-
print("❌ Failed to fetch
|
110 |
-
return
|
111 |
|
112 |
-
task_id = question_data.get("task_id"
|
113 |
-
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
return False
|
118 |
-
|
119 |
-
print(f"✅ Successfully fetched question")
|
120 |
-
print(f"📋 Task ID: {task_id}")
|
121 |
-
print(f"❓ Question: {question_text}")
|
122 |
-
print()
|
123 |
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
# Step 2: Run the agent
|
129 |
-
print("🤖 Running smart agent on the question...")
|
130 |
-
try:
|
131 |
-
start_time = time.time()
|
132 |
-
answer, reasoning_trace = smart_agent(question_text, task_id)
|
133 |
-
end_time = time.time()
|
134 |
-
|
135 |
-
processing_time = end_time - start_time
|
136 |
-
print(f"✅ Agent completed in {processing_time:.2f} seconds")
|
137 |
-
print()
|
138 |
|
139 |
except Exception as e:
|
140 |
-
print(f"❌
|
141 |
-
return False
|
142 |
|
143 |
-
# Step 3: Display results
|
144 |
-
print("📊 AGENT RESULTS")
|
145 |
-
print("-" * 40)
|
146 |
-
print(f"🎯 Formatted Answer: '{answer}'")
|
147 |
-
print(f"📝 Reasoning Length: {len(reasoning_trace)} characters")
|
148 |
-
print(f"⏱️ Processing Time: {processing_time:.2f}s")
|
149 |
print()
|
|
|
|
|
|
|
|
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
print("-" * 40)
|
154 |
-
reasoning_preview = reasoning_trace[:300] + "..." if len(reasoning_trace) > 300 else reasoning_trace
|
155 |
-
print(reasoning_preview)
|
156 |
-
print()
|
157 |
|
158 |
-
|
159 |
-
|
160 |
-
|
|
|
|
|
161 |
|
162 |
-
|
163 |
-
|
164 |
-
print("
|
165 |
-
else:
|
166 |
-
print("❌ Answer is empty or None")
|
167 |
return False
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
if len(answer) > 1000:
|
177 |
-
print("⚠️ Answer is very long (>1000 chars) - might need review")
|
178 |
-
else:
|
179 |
-
print("✅ Answer length is reasonable")
|
180 |
-
|
181 |
-
print()
|
182 |
-
|
183 |
-
# Step 6: Show submission format
|
184 |
-
print("📡 SUBMISSION FORMAT PREVIEW")
|
185 |
-
print("-" * 40)
|
186 |
-
|
187 |
-
submission_entry = {
|
188 |
-
"task_id": task_id,
|
189 |
-
"model_answer": answer,
|
190 |
-
"reasoning_trace": reasoning_trace
|
191 |
-
}
|
192 |
-
|
193 |
-
# Validate required fields
|
194 |
-
required_fields = ["task_id", "model_answer"]
|
195 |
-
all_valid = True
|
196 |
-
|
197 |
-
for field in required_fields:
|
198 |
-
if field in submission_entry and submission_entry[field]:
|
199 |
-
print(f"✅ {field}: '{submission_entry[field][:50]}{'...' if len(str(submission_entry[field])) > 50 else ''}'")
|
200 |
-
else:
|
201 |
-
print(f"❌ Missing or empty {field}")
|
202 |
-
all_valid = False
|
203 |
|
204 |
-
# Check
|
205 |
-
if
|
206 |
-
print(
|
207 |
-
|
208 |
-
print("ℹ️ reasoning_trace: Not present (optional)")
|
209 |
|
210 |
print()
|
211 |
|
212 |
-
#
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
print("✅ ALL CHECKS PASSED - Agent is ready for submission!")
|
218 |
-
print("🚀 You can now run the full evaluation with confidence.")
|
219 |
-
return True
|
220 |
-
else:
|
221 |
-
print("❌ SOME CHECKS FAILED - Please review the issues above.")
|
222 |
-
return False
|
223 |
|
224 |
if __name__ == "__main__":
|
225 |
-
|
226 |
-
print("This test validates web search functionality and answer formatting.")
|
227 |
-
print()
|
228 |
-
|
229 |
-
# Test the predefined 1928 Olympics question
|
230 |
-
success = test_predefined_gaia_question()
|
231 |
-
|
232 |
-
print("\n" + "="*60)
|
233 |
-
if success:
|
234 |
-
print("🎉 Predefined test completed successfully! Agent produces well-defined answers.")
|
235 |
-
print("💡 You can also run test_random_gaia_question() for additional testing.")
|
236 |
-
else:
|
237 |
-
print("⚠️ Predefined test revealed issues that need to be addressed.")
|
238 |
-
print("="*60)
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
"""
|
3 |
+
Local testing script for the GAIA agent.
|
4 |
+
Run this to test the agent before deploying to HF Spaces.
|
5 |
"""
|
6 |
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
|
11 |
+
# Load environment variables
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
+
# Add current directory to path for imports
|
15 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
16 |
+
|
17 |
+
from utils import fetch_random_question, analyze_question_type
|
18 |
from agent import smart_agent
|
19 |
|
20 |
+
def test_question_analysis():
|
21 |
+
"""Test the question analysis functionality."""
|
22 |
+
print("🧪 Testing question analysis...")
|
23 |
+
|
24 |
+
test_questions = [
|
25 |
+
"What is the current population of Tokyo?",
|
26 |
+
"Calculate 15 * 23 + 45",
|
27 |
+
"Analyze the image shown in the document",
|
28 |
+
"Extract all dates from the provided text file"
|
29 |
+
]
|
30 |
+
|
31 |
+
for question in test_questions:
|
32 |
+
analysis = analyze_question_type(question)
|
33 |
+
print(f"Question: {question}")
|
34 |
+
print(f"Analysis: {analysis}")
|
35 |
+
print()
|
36 |
+
|
37 |
+
def test_tools():
|
38 |
+
"""Test individual tools."""
|
39 |
+
print("🔧 Testing individual tools...")
|
40 |
|
41 |
+
# Test calculator
|
42 |
+
from tools import calculator_tool
|
43 |
+
calc_result = calculator_tool.func("15 + 27")
|
44 |
+
print(f"Calculator test: {calc_result}")
|
45 |
|
46 |
+
# Test web search (if available)
|
|
|
|
|
|
|
|
|
47 |
try:
|
48 |
+
from tools import web_search_tool
|
49 |
+
search_result = web_search_tool.func("Python programming language")
|
50 |
+
print(f"Web search test: {search_result[:100]}...")
|
|
|
|
|
|
|
|
|
|
|
51 |
except Exception as e:
|
52 |
+
print(f"Web search test failed: {e}")
|
|
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
print()
|
55 |
+
|
56 |
+
def test_agent_simple():
|
57 |
+
"""Test the agent with a simple question."""
|
58 |
+
print("🤖 Testing Smart agent with simple question...")
|
59 |
|
60 |
+
test_question = "What is 25 + 17?"
|
61 |
+
try:
|
62 |
+
result = smart_agent(test_question)
|
63 |
+
print(f"Question: {test_question}")
|
64 |
+
print(f"Answer: {result}")
|
65 |
+
print("✅ Simple test passed!")
|
66 |
+
except Exception as e:
|
67 |
+
print(f"❌ Simple test failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
+
def test_agent_with_api():
|
72 |
+
"""Test the agent with a real GAIA question from the API."""
|
73 |
+
print("🌐 Testing with real GAIA question from API...")
|
|
|
|
|
74 |
|
|
|
|
|
75 |
try:
|
76 |
question_data = fetch_random_question()
|
77 |
if not question_data:
|
78 |
+
print("❌ Failed to fetch question from API")
|
79 |
+
return
|
80 |
|
81 |
+
task_id = question_data.get("task_id")
|
82 |
+
question = question_data.get("question")
|
83 |
|
84 |
+
print(f"Task ID: {task_id}")
|
85 |
+
print(f"Question: {question}")
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
+
# Run the agent
|
88 |
+
answer = smart_agent(question, task_id)
|
89 |
+
print(f"Agent Answer: {answer}")
|
90 |
+
print("✅ API test completed!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
except Exception as e:
|
93 |
+
print(f"❌ API test failed: {e}")
|
|
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
print()
|
96 |
+
|
97 |
+
def check_environment():
|
98 |
+
"""Check if all required environment variables are set."""
|
99 |
+
print("🔍 Checking environment...")
|
100 |
|
101 |
+
required_vars = ["HUGGINGFACE_API_TOKEN"]
|
102 |
+
missing_vars = []
|
|
|
|
|
|
|
|
|
103 |
|
104 |
+
for var in required_vars:
|
105 |
+
if not os.getenv(var):
|
106 |
+
missing_vars.append(var)
|
107 |
+
else:
|
108 |
+
print(f"✅ {var} is set")
|
109 |
|
110 |
+
if missing_vars:
|
111 |
+
print(f"❌ Missing environment variables: {missing_vars}")
|
112 |
+
print("Please set these in your .env file or environment")
|
|
|
|
|
113 |
return False
|
114 |
|
115 |
+
print("✅ All required environment variables are set")
|
116 |
+
return True
|
117 |
+
|
118 |
+
def main():
|
119 |
+
"""Run all tests."""
|
120 |
+
print("🚀 Starting GAIA Agent Local Tests")
|
121 |
+
print("=" * 50)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
+
# Check environment first
|
124 |
+
if not check_environment():
|
125 |
+
print("❌ Environment check failed. Please fix and try again.")
|
126 |
+
return
|
|
|
127 |
|
128 |
print()
|
129 |
|
130 |
+
# Run tests
|
131 |
+
# test_question_analysis()
|
132 |
+
# test_tools()
|
133 |
+
# test_agent_simple()
|
134 |
+
test_agent_with_api()
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
if __name__ == "__main__":
|
137 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools.py
CHANGED
@@ -4,21 +4,28 @@ import os
|
|
4 |
from PIL import Image
|
5 |
import io
|
6 |
import base64
|
7 |
-
from
|
8 |
from typing import Optional
|
9 |
import json
|
10 |
import PyPDF2
|
11 |
import tempfile
|
12 |
-
import requests
|
13 |
-
from bs4 import BeautifulSoup
|
14 |
-
from langchain_community.vectorstores import FAISS
|
15 |
-
from langchain_openai import OpenAIEmbeddings
|
16 |
-
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
17 |
-
from langchain.schema import Document
|
18 |
-
from dotenv import load_dotenv
|
19 |
|
20 |
-
#
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
def file_download_tool_func(task_id: str) -> str:
|
24 |
"""Downloads a file associated with a GAIA task ID."""
|
@@ -101,6 +108,26 @@ image_analysis_tool = Tool(
|
|
101 |
description="Analyzes images to extract information. Use this for questions involving visual content."
|
102 |
)
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
def text_processor_tool_func(text: str, operation: str = "summarize") -> str:
|
105 |
"""Processes text for various operations like summarization, extraction, etc."""
|
106 |
try:
|
@@ -143,267 +170,11 @@ text_processor_tool = Tool(
|
|
143 |
description="Processes text for various operations like summarization, number extraction, date extraction. Specify operation as second parameter."
|
144 |
)
|
145 |
|
146 |
-
def enhanced_web_retrieval_tool_func(query: str, backend: str = "bing") -> str:
|
147 |
-
"""Enhanced web search with cascading fallback: Wikipedia first, then general web search."""
|
148 |
-
try:
|
149 |
-
print(f"🔍 Enhanced web retrieval for: {query}")
|
150 |
-
|
151 |
-
# Step 1: Try Wikipedia search first
|
152 |
-
print("📚 Searching Wikipedia...")
|
153 |
-
wikipedia_results = get_wikipedia_search_urls(query, backend)
|
154 |
-
|
155 |
-
if has_sufficient_results(wikipedia_results):
|
156 |
-
print(f"✅ Found {len(wikipedia_results)} Wikipedia results")
|
157 |
-
documents = fetch_and_process_results(wikipedia_results, "Wikipedia")
|
158 |
-
if documents:
|
159 |
-
return search_documents_with_vector_store(documents, query, "Wikipedia")
|
160 |
-
|
161 |
-
# Step 2: Fallback to general web search
|
162 |
-
print("🌐 Wikipedia results insufficient, searching general web...")
|
163 |
-
web_results = get_general_web_search_urls(query, backend)
|
164 |
-
|
165 |
-
if web_results:
|
166 |
-
print(f"✅ Found {len(web_results)} general web results")
|
167 |
-
documents = fetch_and_process_results(web_results, "General Web")
|
168 |
-
if documents:
|
169 |
-
return search_documents_with_vector_store(documents, query, "General Web")
|
170 |
-
|
171 |
-
return "No sufficient results found in Wikipedia or general web search."
|
172 |
-
|
173 |
-
except Exception as e:
|
174 |
-
return f"Enhanced web retrieval failed: {str(e)}"
|
175 |
-
|
176 |
-
def get_wikipedia_search_urls(query: str, backend: str = "auto") -> list:
|
177 |
-
"""Get search results from English Wikipedia using DDGS."""
|
178 |
-
try:
|
179 |
-
with DDGS() as ddgs:
|
180 |
-
# Create Wikipedia-specific search queries
|
181 |
-
wikipedia_queries = [
|
182 |
-
f"{query} site:en.wikipedia.org"
|
183 |
-
]
|
184 |
-
|
185 |
-
search_results = []
|
186 |
-
seen_urls = set()
|
187 |
-
|
188 |
-
for wiki_query in wikipedia_queries:
|
189 |
-
try:
|
190 |
-
results = list(ddgs.text(
|
191 |
-
wiki_query,
|
192 |
-
max_results=8,
|
193 |
-
region="us-en",
|
194 |
-
backend=backend,
|
195 |
-
safesearch="moderate"
|
196 |
-
))
|
197 |
-
|
198 |
-
for result in results:
|
199 |
-
url = result.get('href', '')
|
200 |
-
|
201 |
-
# Only include Wikipedia URLs and avoid duplicates
|
202 |
-
if 'en.wikipedia.org' in url and url not in seen_urls:
|
203 |
-
search_results.append({
|
204 |
-
'url': url,
|
205 |
-
'title': result.get('title', 'No title'),
|
206 |
-
'snippet': result.get('body', 'No content')
|
207 |
-
})
|
208 |
-
seen_urls.add(url)
|
209 |
-
|
210 |
-
# Limit to 6 unique Wikipedia pages
|
211 |
-
if len(search_results) >= 6:
|
212 |
-
break
|
213 |
-
|
214 |
-
if len(search_results) >= 6:
|
215 |
-
break
|
216 |
-
|
217 |
-
except Exception as e:
|
218 |
-
print(f"Wikipedia search attempt failed: {e}")
|
219 |
-
continue
|
220 |
-
|
221 |
-
return search_results
|
222 |
-
|
223 |
-
except Exception as e:
|
224 |
-
print(f"Wikipedia search URL retrieval failed: {e}")
|
225 |
-
return []
|
226 |
-
|
227 |
-
def get_general_web_search_urls(query: str, backend: str = "auto") -> list:
|
228 |
-
"""Get search results from general web using DDGS."""
|
229 |
-
try:
|
230 |
-
with DDGS() as ddgs:
|
231 |
-
search_results = []
|
232 |
-
seen_urls = set()
|
233 |
-
|
234 |
-
try:
|
235 |
-
# General web search without site restriction
|
236 |
-
results = list(ddgs.text(
|
237 |
-
query,
|
238 |
-
max_results=8,
|
239 |
-
region="us-en",
|
240 |
-
backend=backend,
|
241 |
-
safesearch="moderate"
|
242 |
-
))
|
243 |
-
|
244 |
-
for result in results:
|
245 |
-
url = result.get('href', '')
|
246 |
-
|
247 |
-
# Avoid duplicates and filter out low-quality sources
|
248 |
-
if url not in seen_urls and is_quality_source(url):
|
249 |
-
search_results.append({
|
250 |
-
'url': url,
|
251 |
-
'title': result.get('title', 'No title'),
|
252 |
-
'snippet': result.get('body', 'No content')
|
253 |
-
})
|
254 |
-
seen_urls.add(url)
|
255 |
-
|
256 |
-
# Limit to 6 unique web pages
|
257 |
-
if len(search_results) >= 6:
|
258 |
-
break
|
259 |
-
|
260 |
-
except Exception as e:
|
261 |
-
print(f"General web search attempt failed: {e}")
|
262 |
-
|
263 |
-
return search_results
|
264 |
-
|
265 |
-
except Exception as e:
|
266 |
-
print(f"General web search URL retrieval failed: {e}")
|
267 |
-
return []
|
268 |
-
|
269 |
-
def is_quality_source(url: str) -> bool:
|
270 |
-
"""Filter out low-quality or problematic sources."""
|
271 |
-
low_quality_domains = [
|
272 |
-
'pinterest.com', 'instagram.com', 'facebook.com', 'twitter.com',
|
273 |
-
'tiktok.com', 'youtube.com', 'reddit.com'
|
274 |
-
]
|
275 |
-
|
276 |
-
for domain in low_quality_domains:
|
277 |
-
if domain in url.lower():
|
278 |
-
return False
|
279 |
-
|
280 |
-
return True
|
281 |
-
|
282 |
-
def has_sufficient_results(results: list) -> bool:
|
283 |
-
"""Check if search results are sufficient to proceed."""
|
284 |
-
if not results:
|
285 |
-
return False
|
286 |
-
|
287 |
-
# Check for minimum number of results
|
288 |
-
if len(results) < 2:
|
289 |
-
return False
|
290 |
-
|
291 |
-
# Check if results have meaningful content
|
292 |
-
meaningful_results = 0
|
293 |
-
for result in results:
|
294 |
-
snippet = result.get('snippet', '')
|
295 |
-
title = result.get('title', '')
|
296 |
-
|
297 |
-
# Consider result meaningful if it has substantial content
|
298 |
-
if len(snippet) > 50 or len(title) > 10:
|
299 |
-
meaningful_results += 1
|
300 |
-
|
301 |
-
return meaningful_results >= 2
|
302 |
-
|
303 |
-
def fetch_and_process_results(results: list, source_type: str) -> list:
|
304 |
-
"""Fetch and process webpage content from search results."""
|
305 |
-
documents = []
|
306 |
-
|
307 |
-
for result in results[:4]: # Process top 4 results
|
308 |
-
url = result.get('url', '')
|
309 |
-
title = result.get('title', 'No title')
|
310 |
-
|
311 |
-
print(f"📄 Fetching content from: {title}")
|
312 |
-
content = fetch_webpage_content(url)
|
313 |
-
|
314 |
-
if content and len(content.strip()) > 100: # Ensure meaningful content
|
315 |
-
doc = Document(
|
316 |
-
page_content=content,
|
317 |
-
metadata={
|
318 |
-
"source": url,
|
319 |
-
"title": title,
|
320 |
-
"source_type": source_type
|
321 |
-
}
|
322 |
-
)
|
323 |
-
documents.append(doc)
|
324 |
-
|
325 |
-
return documents
|
326 |
-
|
327 |
-
def fetch_webpage_content(url: str) -> str:
|
328 |
-
"""Fetch and extract clean text content from a webpage."""
|
329 |
-
try:
|
330 |
-
headers = {
|
331 |
-
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
332 |
-
}
|
333 |
-
|
334 |
-
response = requests.get(url, headers=headers, timeout=10)
|
335 |
-
response.raise_for_status()
|
336 |
-
|
337 |
-
# Parse HTML and extract text
|
338 |
-
soup = BeautifulSoup(response.content, 'html.parser')
|
339 |
-
|
340 |
-
# Remove script and style elements
|
341 |
-
for script in soup(["script", "style"]):
|
342 |
-
script.decompose()
|
343 |
-
|
344 |
-
# Get text content
|
345 |
-
text = soup.get_text()
|
346 |
-
|
347 |
-
# Clean up text
|
348 |
-
lines = (line.strip() for line in text.splitlines())
|
349 |
-
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
350 |
-
text = ' '.join(chunk for chunk in chunks if chunk)
|
351 |
-
|
352 |
-
return text[:30000]
|
353 |
-
|
354 |
-
except Exception as e:
|
355 |
-
print(f"Failed to fetch content from {url}: {e}")
|
356 |
-
return ""
|
357 |
-
|
358 |
-
def search_documents_with_vector_store(documents: list, query: str, source_type: str = "Web") -> str:
|
359 |
-
"""Create vector store and search for relevant information."""
|
360 |
-
try:
|
361 |
-
# Split documents into chunks
|
362 |
-
text_splitter = RecursiveCharacterTextSplitter(
|
363 |
-
chunk_size=1000,
|
364 |
-
chunk_overlap=200,
|
365 |
-
length_function=len,
|
366 |
-
)
|
367 |
-
|
368 |
-
splits = text_splitter.split_documents(documents)
|
369 |
-
|
370 |
-
if not splits:
|
371 |
-
return "No content to process after splitting."
|
372 |
-
|
373 |
-
# Create embeddings and vector store
|
374 |
-
embeddings = OpenAIEmbeddings()
|
375 |
-
vectorstore = FAISS.from_documents(splits, embeddings)
|
376 |
-
|
377 |
-
# Search for relevant chunks with the original query
|
378 |
-
relevant_docs = vectorstore.similarity_search(query, k=5)
|
379 |
-
|
380 |
-
# Format results with source type indication
|
381 |
-
results = []
|
382 |
-
results.append(f"🔍 Search Results from {source_type}:\n")
|
383 |
-
|
384 |
-
for i, doc in enumerate(relevant_docs, 1):
|
385 |
-
source = doc.metadata.get('source', 'Unknown source')
|
386 |
-
title = doc.metadata.get('title', 'No title')
|
387 |
-
source_type_meta = doc.metadata.get('source_type', source_type)
|
388 |
-
content = doc.page_content[:2000] # Increased content length
|
389 |
-
|
390 |
-
results.append(f"Result {i} ({source_type_meta}) - {title}:\n{content}\nSource: {source}\n")
|
391 |
-
|
392 |
-
return "\n---\n".join(results)
|
393 |
-
|
394 |
-
except Exception as e:
|
395 |
-
return f"Vector search failed: {str(e)}"
|
396 |
-
|
397 |
-
web_search_tool = Tool(
|
398 |
-
name="enhanced_web_retrieval",
|
399 |
-
func=enhanced_web_retrieval_tool_func,
|
400 |
-
description="Enhanced cascading web search with vector retrieval. First searches Wikipedia for reliable factual information, then falls back to general web search if insufficient results are found. Supports multiple search backends (auto, html, lite, bing) and uses semantic search to find relevant information. Ideal for comprehensive research on any topic."
|
401 |
-
)
|
402 |
-
|
403 |
# List of all tools for easy import
|
404 |
agent_tools = [
|
405 |
web_search_tool,
|
406 |
file_download_tool,
|
407 |
image_analysis_tool,
|
|
|
408 |
text_processor_tool
|
409 |
]
|
|
|
4 |
from PIL import Image
|
5 |
import io
|
6 |
import base64
|
7 |
+
from langchain_community.tools import DuckDuckGoSearchRun
|
8 |
from typing import Optional
|
9 |
import json
|
10 |
import PyPDF2
|
11 |
import tempfile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
# Initialize web search tool
|
14 |
+
search_tool = DuckDuckGoSearchRun()
|
15 |
+
|
16 |
+
def web_search_tool_func(query: str) -> str:
|
17 |
+
"""Searches the web for information using DuckDuckGo."""
|
18 |
+
try:
|
19 |
+
results = search_tool.run(query)
|
20 |
+
return results
|
21 |
+
except Exception as e:
|
22 |
+
return f"Web search failed: {str(e)}"
|
23 |
+
|
24 |
+
web_search_tool = Tool(
|
25 |
+
name="web_search",
|
26 |
+
func=web_search_tool_func,
|
27 |
+
description="Searches the web for current information. Use this for factual questions, recent events, or when you need to find information not in your training data."
|
28 |
+
)
|
29 |
|
30 |
def file_download_tool_func(task_id: str) -> str:
|
31 |
"""Downloads a file associated with a GAIA task ID."""
|
|
|
108 |
description="Analyzes images to extract information. Use this for questions involving visual content."
|
109 |
)
|
110 |
|
111 |
+
def calculator_tool_func(expression: str) -> str:
|
112 |
+
"""Performs mathematical calculations safely."""
|
113 |
+
try:
|
114 |
+
# Basic safety check - only allow certain characters
|
115 |
+
allowed_chars = set('0123456789+-*/().= ')
|
116 |
+
if not all(c in allowed_chars for c in expression):
|
117 |
+
return f"Invalid characters in expression: {expression}"
|
118 |
+
|
119 |
+
# Use eval safely for basic math
|
120 |
+
result = eval(expression)
|
121 |
+
return f"Calculation result: {expression} = {result}"
|
122 |
+
except Exception as e:
|
123 |
+
return f"Calculation failed for '{expression}': {str(e)}"
|
124 |
+
|
125 |
+
calculator_tool = Tool(
|
126 |
+
name="calculator",
|
127 |
+
func=calculator_tool_func,
|
128 |
+
description="Performs mathematical calculations. Use this for numerical computations and math problems."
|
129 |
+
)
|
130 |
+
|
131 |
def text_processor_tool_func(text: str, operation: str = "summarize") -> str:
|
132 |
"""Processes text for various operations like summarization, extraction, etc."""
|
133 |
try:
|
|
|
170 |
description="Processes text for various operations like summarization, number extraction, date extraction. Specify operation as second parameter."
|
171 |
)
|
172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
# List of all tools for easy import
|
174 |
agent_tools = [
|
175 |
web_search_tool,
|
176 |
file_download_tool,
|
177 |
image_analysis_tool,
|
178 |
+
calculator_tool,
|
179 |
text_processor_tool
|
180 |
]
|
utils.py
CHANGED
@@ -41,48 +41,81 @@ def submit_answers(username: str, agent_code: str, answers: List[Dict[str, str]]
|
|
41 |
|
42 |
def format_gaia_answer(raw_answer: str) -> str:
|
43 |
"""Format the agent's raw answer for GAIA submission (exact match)."""
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
-
|
47 |
-
final_answer_pattern = r'FINAL ANSWER:\s*(.+?)(?:\n|$)'
|
48 |
-
match = re.search(final_answer_pattern, raw_answer, re.IGNORECASE | re.DOTALL)
|
49 |
|
50 |
-
|
51 |
-
answer
|
52 |
-
|
53 |
-
# Fallback: try to extract from common patterns
|
54 |
-
fallback_patterns = [
|
55 |
-
r'(?:The\s+)?(?:final\s+)?answer\s+is:?\s*(.+?)(?:\n|$)',
|
56 |
-
r'(?:Answer|Result):\s*(.+?)(?:\n|$)',
|
57 |
-
]
|
58 |
-
|
59 |
-
answer = raw_answer.strip()
|
60 |
-
for pattern in fallback_patterns:
|
61 |
-
match = re.search(pattern, answer, re.IGNORECASE)
|
62 |
-
if match:
|
63 |
-
answer = match.group(1).strip()
|
64 |
-
break
|
65 |
-
|
66 |
-
# Apply GAIA formatting rules
|
67 |
-
answer = answer.strip()
|
68 |
|
69 |
# Remove trailing punctuation that might not be in ground truth
|
70 |
while answer and answer[-1] in '.!?':
|
71 |
answer = answer[:-1].strip()
|
72 |
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
84 |
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
def log_agent_step(step: str, result: str, step_number: int = None):
|
88 |
"""Log agent execution steps for debugging."""
|
|
|
41 |
|
42 |
def format_gaia_answer(raw_answer: str) -> str:
|
43 |
"""Format the agent's raw answer for GAIA submission (exact match)."""
|
44 |
+
# Remove common prefixes that might interfere with exact matching
|
45 |
+
prefixes_to_remove = [
|
46 |
+
"FINAL ANSWER:",
|
47 |
+
"Final Answer:",
|
48 |
+
"Answer:",
|
49 |
+
"The answer is:",
|
50 |
+
"The final answer is:",
|
51 |
+
]
|
52 |
|
53 |
+
answer = raw_answer.strip()
|
|
|
|
|
54 |
|
55 |
+
for prefix in prefixes_to_remove:
|
56 |
+
if answer.startswith(prefix):
|
57 |
+
answer = answer[len(prefix):].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
# Remove trailing punctuation that might not be in ground truth
|
60 |
while answer and answer[-1] in '.!?':
|
61 |
answer = answer[:-1].strip()
|
62 |
|
63 |
+
return answer
|
64 |
+
|
65 |
+
def analyze_question_type(question: str) -> Dict[str, bool]:
|
66 |
+
"""Analyze what capabilities a question might need."""
|
67 |
+
question_lower = question.lower()
|
68 |
|
69 |
+
analysis = {
|
70 |
+
"needs_web_search": any(keyword in question_lower for keyword in [
|
71 |
+
"current", "recent", "latest", "today", "now", "2024", "2023"
|
72 |
+
]),
|
73 |
+
"needs_file_processing": "file" in question_lower or "document" in question_lower,
|
74 |
+
"needs_calculation": any(keyword in question_lower for keyword in [
|
75 |
+
"calculate", "compute", "sum", "total", "average", "percentage", "multiply", "divide"
|
76 |
+
]),
|
77 |
+
"needs_image_analysis": any(keyword in question_lower for keyword in [
|
78 |
+
"image", "picture", "photo", "visual", "shown", "displayed"
|
79 |
+
]),
|
80 |
+
"needs_text_processing": any(keyword in question_lower for keyword in [
|
81 |
+
"extract", "find in", "search for", "list", "count"
|
82 |
+
])
|
83 |
+
}
|
84 |
|
85 |
+
return analysis
|
86 |
+
|
87 |
+
def create_execution_plan(question: str, task_id: str = None) -> List[str]:
|
88 |
+
"""Create a step-by-step execution plan for a GAIA question."""
|
89 |
+
analysis = analyze_question_type(question)
|
90 |
+
plan = []
|
91 |
|
92 |
+
# Always start with understanding the question
|
93 |
+
plan.append("Analyze the question to understand what information is needed")
|
94 |
+
|
95 |
+
# Add file processing if needed
|
96 |
+
if task_id and analysis["needs_file_processing"]:
|
97 |
+
plan.append(f"Download and process any files associated with task {task_id}")
|
98 |
+
|
99 |
+
# Add web search if needed
|
100 |
+
if analysis["needs_web_search"]:
|
101 |
+
plan.append("Search the web for current/recent information")
|
102 |
+
|
103 |
+
# Add image analysis if needed
|
104 |
+
if analysis["needs_image_analysis"]:
|
105 |
+
plan.append("Analyze any images for visual information")
|
106 |
+
|
107 |
+
# Add calculation if needed
|
108 |
+
if analysis["needs_calculation"]:
|
109 |
+
plan.append("Perform necessary calculations")
|
110 |
+
|
111 |
+
# Add text processing if needed
|
112 |
+
if analysis["needs_text_processing"]:
|
113 |
+
plan.append("Process and extract specific information from text")
|
114 |
+
|
115 |
+
# Always end with synthesis
|
116 |
+
plan.append("Synthesize all information to provide the final answer")
|
117 |
+
|
118 |
+
return plan
|
119 |
|
120 |
def log_agent_step(step: str, result: str, step_number: int = None):
|
121 |
"""Log agent execution steps for debugging."""
|