Update agent.py
agent.py
CHANGED
@@ -64,15 +64,41 @@ logging.basicConfig(level=logging.INFO)
 logging.getLogger("llama_index.core.agent").setLevel(logging.DEBUG)
 logging.getLogger("llama_index.llms").setLevel(logging.DEBUG)

-
+def get_max_memory_config(max_memory_per_gpu):
+    """Generate max_memory config for available GPUs"""
+    if torch.cuda.is_available():
+        num_gpus = torch.cuda.device_count()
+        max_memory = {}
+        for i in range(num_gpus):
+            max_memory[i] = max_memory_per_gpu
+        return max_memory
+    return None
+
+model_id = "google/gemma-3-12b-it"
 proj_llm = HuggingFaceLLM(
     model_name=model_id,
     tokenizer_name=model_id,
-    device_map="auto",
-    model_kwargs={
+    device_map="auto",
+    model_kwargs={
+        "torch_dtype": "auto",
+        "max_memory": get_max_memory_config("10GB")
+    },
     generate_kwargs={"temperature": 0.1, "top_p": 0.3}  # More focused
 )

+code_llm = HuggingFaceLLM(
+    model_name="Qwen/Qwen2.5-Coder-3B",
+    tokenizer_name="Qwen/Qwen2.5-Coder-3B",
+    device_map="auto",
+    model_kwargs={
+        "torch_dtype": "auto",
+        "max_memory": get_max_memory_config("3GB")
+    },
+    # Set generation parameters for precise, non-creative code output
+    generate_kwargs={"temperature": 0.0, "do_sample": False}
+)
+
+
 embed_model = HuggingFaceEmbedding("BAAI/bge-small-en-v1.5")

 wandb.init(project="gaia-llamaindex-agents")  # Choose your project name
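Note: a minimal sketch (not part of the commit) of what the new helper feeds to device_map="auto"; the two-GPU machine is an assumption for illustration.

# Sketch: assuming two visible CUDA devices, get_max_memory_config("10GB")
# evaluates to {0: "10GB", 1: "10GB"}; transformers/accelerate read this dict
# from model_kwargs["max_memory"] and cap each GPU's share of the weights.
import torch

def get_max_memory_config(max_memory_per_gpu):
    if torch.cuda.is_available():
        return {i: max_memory_per_gpu for i in range(torch.cuda.device_count())}
    return None

print(get_max_memory_config("10GB"))  # {0: '10GB', 1: '10GB'} on 2 GPUs, None on CPU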
@@ -285,6 +311,19 @@ def search_and_extract_top_url(query: str) -> str:
     else:
         return "No URL could be extracted from the search results."

+
+# Create external_knowledge_agent - a ReAct agent with extract_url_tool and information_retrieval_tool
+external_knowledge_agent = ReActAgent(
+    name="external_knowledge_agent",
+    description="Retrieves information from external sources and documents",
+    system_prompt="You are an information retrieval specialist. You find and extract relevant information from external sources, URLs, and documents to answer queries.",
+    tools=[extract_url_tool, information_retrieval_tool],
+    llm=proj_llm,
+    max_steps=6,
+    verbose=True,
+    callback_manager=callback_manager,
+)
+
 # 3. Create the final, customized FunctionTool for the agent.
 # This is the tool you will actually give to your agent.
 extract_url_tool = FunctionTool.from_defaults(
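Note: a minimal sketch of the wiring pattern this hunk depends on - a plain function wrapped as a FunctionTool and handed to a named ReActAgent. The fake_search tool is hypothetical; the imports assume the AgentWorkflow-style API used elsewhere in this commit.

from llama_index.core.tools import FunctionTool
from llama_index.core.agent.workflow import ReActAgent

def fake_search(query: str) -> str:
    """Hypothetical stand-in for extract_url_tool / information_retrieval_tool."""
    return f"result for: {query}"

demo_agent = ReActAgent(
    name="demo_agent",
    description="Demonstrates FunctionTool-to-agent wiring",
    system_prompt="Answer using the provided tool.",
    tools=[FunctionTool.from_defaults(fn=fake_search)],
    llm=proj_llm,  # the HuggingFaceLLM configured above
)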
@@ -295,6 +334,16 @@ extract_url_tool = FunctionTool.from_defaults(
     )
 )

+import importlib.util
+import sys
+
+def safe_import(module_name):
+    """Safely import a module, return None if not available"""
+    try:
+        return importlib.import_module(module_name)
+    except ImportError:
+        return None
+
 safe_globals = {
     "__builtins__": {
         "len": len, "str": str, "int": int, "float": float,
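Note: importlib.import_module is used above (rather than bare __import__) because __import__ returns the top-level package for dotted names, which matters for entries like "matplotlib.pyplot" and "PIL.Image" below. A quick illustration, assuming matplotlib is installed:

import importlib

top = __import__("matplotlib.pyplot")               # binds the matplotlib package
sub = importlib.import_module("matplotlib.pyplot")  # binds pyplot itself
print(top.__name__, sub.__name__)                   # matplotlib matplotlib.pyplot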
@@ -303,65 +352,65 @@ safe_globals = {
     "range": range, "zip": zip, "map": map, "filter": filter,
     "any": any, "all": all, "type": type, "isinstance": isinstance,
     "print": print, "open": open, "bool": bool, "set": set, "tuple": tuple
-    },
-
-    "math": __import__("math"),
-    "datetime": __import__("datetime"),
-    "re": __import__("re"),
-    "os": __import__("os"),
-    "sys": __import__("sys"),
-    "json": __import__("json"),
-    "csv": __import__("csv"),
-    "random": __import__("random"),
-    "itertools": __import__("itertools"),
-    "collections": __import__("collections"),
-    "functools": __import__("functools"),
-
-    # Data Science and Numerical Computing
-    "numpy": __import__("numpy"),
-    "np": __import__("numpy"),
-    "pandas": __import__("pandas"),
-    "pd": __import__("pandas"),
-    "scipy": __import__("scipy"),
-
-    # Visualization
-    "matplotlib": __import__("matplotlib"),
-    "plt": __import__("matplotlib.pyplot"),
-    "seaborn": __import__("seaborn"),
-    "sns": __import__("seaborn"),
-    "plotly": __import__("plotly"),
-
-    # Machine Learning
-    "sklearn": __import__("sklearn"),
-    "xgboost": __import__("xgboost"),
-    "lightgbm": __import__("lightgbm"),
-
-    # Statistics
-    "statistics": __import__("statistics"),
-    "statsmodels": __import__("statsmodels"),
-
-    # Image Processing
-    "PIL": __import__("PIL"),
-    "cv2": __import__("cv2"),
-    "skimage": __import__("skimage"),
-
-    # Time Series
-    "pytz": __import__("pytz"),
-
-    # Utilities
-    "tqdm": __import__("tqdm"),
-    "pickle": __import__("pickle"),
-    "gzip": __import__("gzip"),
-    "base64": __import__("base64"),
-    "hashlib": __import__("hashlib"),
-
-    # Scientific Computing
-    "sympy": __import__("sympy"),
+    }
+}

-
-
+# Core modules (always available)
+core_modules = [
+    "math", "datetime", "re", "os", "sys", "json", "csv", "random",
+    "itertools", "collections", "functools", "operator", "copy",
+    "decimal", "fractions", "uuid", "typing", "statistics", "pathlib",
+    "glob", "shutil", "tempfile", "pickle", "gzip", "zipfile", "tarfile",
+    "base64", "hashlib", "secrets", "hmac", "textwrap", "string",
+    "difflib", "socket", "ipaddress", "logging", "warnings", "traceback",
+    "pprint", "threading", "queue", "sqlite3", "urllib", "html", "xml",
+    "configparser"
+]
+
+for module in core_modules:
+    imported = safe_import(module)
+    if imported:
+        safe_globals[module] = imported
+
+# Data science modules (may not be available)
+optional_modules = {
+    "numpy": "numpy",
+    "np": "numpy",
+    "pandas": "pandas",
+    "pd": "pandas",
+    "scipy": "scipy",
+    "matplotlib": "matplotlib",
+    "plt": "matplotlib.pyplot",
+    "seaborn": "seaborn",
+    "sns": "seaborn",
+    "plotly": "plotly",
+    "sklearn": "sklearn",
+    "statsmodels": "statsmodels",
+    "PIL": "PIL",
+    "skimage": "skimage",
+    "pytz": "pytz",
+    "requests": "requests",
+    "bs4": "bs4",
+    "sympy": "sympy",
+    "tqdm": "tqdm",
+    "yaml": "yaml",
+    "toml": "toml"
 }

+for alias, module_name in optional_modules.items():
+    imported = safe_import(module_name)
+    if imported:
+        safe_globals[alias] = imported
+
+# Special cases
+if safe_globals.get("bs4"):
+    safe_globals["BeautifulSoup"] = safe_globals["bs4"].BeautifulSoup
+
+if safe_globals.get("PIL"):
+    image_module = safe_import("PIL.Image")
+    if image_module:
+        safe_globals["Image"] = image_module
+
 def execute_python_code(code: str) -> str:
     try:
         exec_locals = {}
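Note: a minimal sketch of the restricted-exec pattern execute_python_code builds on (illustrative, not the committed body): user code runs with the curated safe_globals as its global namespace, so only whitelisted names resolve.

exec_locals = {}
exec("result = math.sqrt(16)", safe_globals, exec_locals)  # math came from core_modules
print(exec_locals["result"])  # 4.0; an unlisted name would raise NameError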
@@ -376,83 +425,20 @@ def execute_python_code(code: str) -> str:
         return f"Code execution failed: {str(e)}"

 code_execution_tool = FunctionTool.from_defaults(
-    fn=execute_python_code,
-    name="Python Code Execution",
-    description="
+    fn=execute_python_code,
+    name="Python Code Execution",
+    description="Executes Python code safely for calculations and data processing"
 )

-
-
-
-
-
-
-
-
-
-    tokenizer_name="Qwen/Qwen2.5-Coder-3B",
-    device_map="auto",
-    model_kwargs={"torch_dtype": "auto"},
-    # Set generation parameters for precise, non-creative code output
-    generate_kwargs={"temperature": 0.0, "do_sample": False}
-)
-
-def generate_python_code(query: str) -> str:
-    """
-    Generates executable Python code based on a natural language query.
-
-    Args:
-        query: A detailed description of the desired functionality for the Python code.
-
-    Returns:
-        A string containing only the generated Python code, ready for execution.
-    """
-    if not code_llm:
-        return "Error: Code generation model is not available."
-
-    # --- 2. Create a precise prompt for the code model ---
-    # This prompt explicitly asks for only code, no explanations.
-    prompt = f"""
-    Your task is to generate ONLY the Python code for the following request.
-    Do not include any explanations, introductory text, or markdown formatting like '```python'.
-    The output must be a single, clean block of Python code.
-
-    IMPORTANT LIMITATIONS:
-    Your code will be executed in a restricted environment with limited functions and modules.
-    {str(safe_globals)}
-    Only use the functions and modules listed above. Do not use imports or other built-in functions.
-
-    Request: "{query}"
-
-    Python Code:
-    """
-
-    # --- 3. Generate the response and post-process it ---
-    response = code_llm.complete(prompt)
-    raw_code = str(response)
-
-    # --- 4. Clean the output to ensure it's pure code ---
-    # Models often wrap code in markdown fences, this removes them.
-    code_match = re.search(r"```(?:python)?\n(.*)```", raw_code, re.DOTALL)
-    if code_match:
-        # Extract the code from within the markdown block
-        return code_match.group(1).strip()
-    else:
-        # If no markdown, assume the model followed instructions and return the text directly
-        return raw_code.strip()
-
-
-# --- 5. Create the LlamaIndex Tool from the function ---
-generate_code_tool = FunctionTool.from_defaults(
-    fn=generate_python_code,
-    name="generate_python_code_tool",
-    description=(
-        "Use this tool to generate executable Python code ONLY for mathematical calculations and problem solving. "
-        "This tool is specifically designed for numerical computations, statistical analysis, algebraic operations, "
-        "mathematical modeling, and scientific calculations. "
-        "DO NOT use this tool for document processing, text manipulation, or data parsing - use appropriate specialized tools instead. "
-        "The tool returns a string containing only the Python code for mathematical operations."
-    )
+code_agent = ReActAgent(
+    name="code_agent",
+    description="Handles Python code for calculations and data processing",
+    system_prompt="You are a Python programming specialist. You work with Python code to perform calculations, data analysis, and mathematical operations.",
+    tools=[code_execution_tool],
+    llm=code_llm,
+    max_steps=6,
+    verbose=True,
+    callback_manager=callback_manager,
 )

 def clean_response(response: str) -> str:
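Note: a quick smoke test for the wrapper (a sketch; FunctionTool.call returning a ToolOutput is standard llama_index behavior):

output = code_execution_tool.call("result = 2 + 2")
print(output.content)  # whatever execute_python_code returns for this snippet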
@@ -540,15 +526,6 @@ def final_answer_tool(agent_response: str, question: str) -> str:

     return formatted_answer

-# Create the simplified final answer tool
-final_answer_function_tool = FunctionTool.from_defaults(
-    fn=final_answer_tool,
-    name="final_answer_tool",
-    description=(
-        "Use this tool to format the final answer according to GAIA requirements. "
-        "Input the agent's response and the original question to get properly formatted output."
-    )
-)

 class EnhancedGAIAAgent:
     def __init__(self):
@@ -559,28 +536,9 @@ class EnhancedGAIAAgent:
         if not hf_token:
             print("Warning: HUGGINGFACEHUB_API_TOKEN not found, some features may not work")

-
-
-        self.available_tools = [
-            read_and_parse_tool,
-            information_retrieval_tool,
-            code_execution_tool,
-            generate_code_tool,
-        ]
-
-        # Create main coordinator using only defined tools
-        self.coordinator = ReActAgent(
-            name="GAIACoordinator",
-            system_prompt="""
-You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
-""",
-            llm=proj_llm,
-            tools=self.available_tools,
-            max_steps=15,
-            verbose=True,
-            callback_manager=callback_manager,
-        )
-
+        self.coordinator = AgentWorkflow(
+            agents=[external_knowledge_agent, code_agent],
+            root_agent="external_knowledge_agent")

     def download_gaia_file(self, task_id: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> str:
         """Download file associated with task_id"""
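Note: a minimal sketch of driving the new coordinator; AgentWorkflow.run is async and accepts a user_msg plus a Context, matching the ctx = Context(self.coordinator) call in solve() below.

from llama_index.core.workflow import Context

async def run_coordinator(coordinator, question: str) -> str:
    ctx = Context(coordinator)  # shared state; lets the root agent hand off to code_agent
    result = await coordinator.run(user_msg=question, ctx=ctx)
    return str(result)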
@@ -618,8 +576,7 @@ You are a general AI assistant. I will ask you a question. Report your thoughts,
         GAIA Task ID: {task_id}
         Question: {question}
         {f'File available: {file_path}' if file_path else 'No additional files'}
-        """
-
+        You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
         try:
             ctx = Context(self.coordinator)
             print("=== AGENT REASONING STEPS ===")
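Note: with the answer template now embedded in the task prompt, the model's completion should end in "FINAL ANSWER: ..."; a sketch of the extraction that clean_response presumably performs (the regex is an assumption, not the committed code):

import re

def extract_final_answer(text: str) -> str:
    match = re.search(r"FINAL ANSWER:\s*(.+)", text, re.DOTALL)
    return match.group(1).strip() if match else text.strip()

print(extract_final_answer("reasoning...\nFINAL ANSWER: 42"))  # -> 42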