Update agent.py
Browse files
agent.py
CHANGED
@@ -1,32 +1,53 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
from llama_index.core import VectorStoreIndex, Document
|
4 |
-
from llama_index.core.node_parser import SentenceWindowNodeParser, HierarchicalNodeParser
|
5 |
-
from llama_index.core.postprocessor import SentenceTransformerRerank
|
6 |
-
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
7 |
-
from llama_index.core.retrievers import VectorIndexRetriever
|
8 |
-
from llama_index.core.query_engine import RetrieverQueryEngine
|
9 |
-
from llama_index.readers.file import PDFReader, DocxReader, CSVReader, ImageReader
|
10 |
import os
|
11 |
-
from typing import List, Dict, Any
|
12 |
-
from llama_index.tools.arxiv import ArxivToolSpec
|
13 |
-
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
14 |
import re
|
15 |
-
from
|
|
|
|
|
|
|
|
|
16 |
import wandb
|
17 |
-
from
|
|
|
|
|
|
|
|
|
18 |
from llama_index.core.callbacks.base import CallbackManager
|
19 |
from llama_index.core.callbacks.llama_debug import LlamaDebugHandler
|
20 |
-
from llama_index.core import
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
|
|
|
|
|
23 |
from llama_index.llms.huggingface import HuggingFaceLLM
|
24 |
-
import
|
25 |
-
import
|
26 |
-
from llama_index.
|
27 |
-
from llama_index.
|
28 |
-
from llama_index.
|
29 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
|
@@ -63,25 +84,6 @@ Settings.llm = proj_llm
|
|
63 |
Settings.embed_model = embed_model
|
64 |
Settings.callback_manager = callback_manager
|
65 |
|
66 |
-
import os
|
67 |
-
from typing import List
|
68 |
-
from urllib.parse import urlparse
|
69 |
-
|
70 |
-
from llama_index.core.tools import FunctionTool
|
71 |
-
from llama_index.core import Document
|
72 |
-
|
73 |
-
# --- Import all required official LlamaIndex Readers ---
|
74 |
-
from llama_index.readers.file import (
|
75 |
-
PDFReader,
|
76 |
-
DocxReader,
|
77 |
-
CSVReader,
|
78 |
-
PandasExcelReader,
|
79 |
-
ImageReader,
|
80 |
-
)
|
81 |
-
from llama_index.readers.json import JSONReader
|
82 |
-
from llama_index.readers.web import TrafilaturaWebReader
|
83 |
-
from llama_index.readers.youtube_transcript import YoutubeTranscriptReader
|
84 |
-
from llama_index.readers.audiotranscribe.openai import OpenAIAudioTranscriptReader
|
85 |
|
86 |
def read_and_parse_content(input_path: str) -> List[Document]:
|
87 |
"""
|
@@ -157,12 +159,6 @@ read_and_parse_tool = FunctionTool.from_defaults(
|
|
157 |
)
|
158 |
)
|
159 |
|
160 |
-
from typing import List
|
161 |
-
from llama_index.core import VectorStoreIndex, Document, Settings
|
162 |
-
from llama_index.core.tools import QueryEngineTool
|
163 |
-
from llama_index.core.node_parser import SentenceWindowNodeParser, HierarchicalNodeParser
|
164 |
-
from llama_index.core.postprocessor import SentenceTransformerRerank
|
165 |
-
from llama_index.core.query_engine import RetrieverQueryEngine
|
166 |
|
167 |
def create_rag_tool(documents: List[Document]) -> QueryEngineTool:
|
168 |
"""
|
@@ -223,11 +219,6 @@ def create_rag_tool(documents: List[Document]) -> QueryEngineTool:
|
|
223 |
|
224 |
return rag_engine_tool
|
225 |
|
226 |
-
|
227 |
-
import re
|
228 |
-
from llama_index.core.tools import FunctionTool
|
229 |
-
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
230 |
-
|
231 |
# 1. Create the base DuckDuckGo search tool from the official spec.
|
232 |
# This tool returns text summaries of search results, not just URLs.
|
233 |
base_duckduckgo_tool = DuckDuckGoSearchToolSpec().to_tool_list()[0]
|
@@ -442,89 +433,128 @@ generate_code_tool = FunctionTool.from_defaults(
|
|
442 |
)
|
443 |
)
|
444 |
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
# Vérification du token HuggingFace
|
451 |
-
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
452 |
-
if not hf_token:
|
453 |
-
raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is required")
|
454 |
-
|
455 |
-
# Agent coordinateur principal qui utilise les agents spécialisés comme tools
|
456 |
-
self.coordinator = ReActAgent(
|
457 |
-
name="GAIACoordinator",
|
458 |
-
description="Main GAIA coordinator that uses specialized capabilities as intelligent tools",
|
459 |
-
system_prompt="""
|
460 |
-
You are the main GAIA coordinator using ReAct reasoning methodology.
|
461 |
-
|
462 |
-
You have access to THREE specialist tools:
|
463 |
-
|
464 |
-
**1. analysis_tool** - Advanced multimodal document analysis specialist
|
465 |
-
- Use for: PDF, Word, CSV, image file analysis
|
466 |
-
- When to use: Questions with file attachments, document analysis, data extraction
|
467 |
-
|
468 |
-
**2. research_tool** - Intelligent research specialist with automatic routing
|
469 |
-
- Use for: External knowledge, current events, scientific papers
|
470 |
-
- When to use: Questions requiring external knowledge, factual verification, current information
|
471 |
-
|
472 |
-
**3. code_tool** - Advanced computational specialist using ReAct reasoning
|
473 |
-
- Use for: Mathematical calculations, data processing, logical operations
|
474 |
-
- Capabilities: Generates and executes Python, handles complex computations, step-by-step problem solving
|
475 |
-
- When to use: Precise calculations, data manipulation, mathematical problem solving
|
476 |
-
|
477 |
-
**4. code_execution_tool** - Use only to execute .py file
|
478 |
-
|
479 |
-
CRITICAL: Your final answer must be EXACT and CONCISE as required by GAIA format : NO explanations, NO additional text, ONLY the precise answer
|
480 |
-
""",
|
481 |
-
llm=proj_llm,
|
482 |
-
tools=[analysis_tool, research_tool, code_tool, code_execution_tool],
|
483 |
-
max_steps=10,
|
484 |
-
verbose = True,
|
485 |
-
callback_manager=callback_manager,
|
486 |
-
|
487 |
-
)
|
488 |
-
|
489 |
-
async def format_gaia_answer(self, raw_response: str, original_question: str) -> str:
|
490 |
-
"""
|
491 |
-
Post-process the agent response to extract the exact GAIA format answer
|
492 |
-
"""
|
493 |
-
format_prompt = f"""Extract the exact answer from the response below. Follow GAIA formatting rules strictly.
|
494 |
-
|
495 |
-
Examples:
|
496 |
-
|
497 |
-
Question: "How many research papers were published by the university between 2010 and 2020?"
|
498 |
-
Response: "Based on my analysis of the data, I found that the university published 156 research papers between 2010 and 2020."
|
499 |
-
Answer: 156
|
500 |
-
|
501 |
-
Question: "What is the last name of the software engineer mentioned in the report?"
|
502 |
-
Response: "After reviewing the document, the software engineer mentioned is Dr. Martinez who developed the system."
|
503 |
-
Answer: Martinez
|
504 |
-
|
505 |
-
Question: "List the programming languages from this job description, alphabetized:"
|
506 |
-
Response: "The job description mentions several programming languages including Python, Java, C++, and JavaScript. When alphabetized, these are: C++, Java, JavaScript, Python"
|
507 |
-
Answer: C++, Java, JavaScript, Python
|
508 |
-
|
509 |
-
Question: "Give only the first name of the developer who created the framework."
|
510 |
-
Response: "The framework was created by Sarah Johnson, a senior developer at the company."
|
511 |
-
Answer: Sarah
|
512 |
|
513 |
-
|
514 |
-
|
515 |
-
|
|
|
|
|
|
|
|
|
516 |
|
517 |
-
|
518 |
-
|
519 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
520 |
|
521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
|
523 |
-
|
524 |
-
|
525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
try:
|
|
|
528 |
formatting_response = proj_llm.complete(format_prompt)
|
529 |
answer = str(formatting_response).strip()
|
530 |
|
@@ -533,10 +563,107 @@ class EnhancedGAIAAgent:
|
|
533 |
answer = answer.split("Answer:")[-1].strip()
|
534 |
|
535 |
return answer
|
536 |
-
|
537 |
except Exception as e:
|
538 |
-
print(f"
|
539 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
540 |
|
541 |
def download_gaia_file(self, task_id: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> str:
|
542 |
"""Download file associated with task_id"""
|
@@ -544,7 +671,6 @@ class EnhancedGAIAAgent:
|
|
544 |
response = requests.get(f"{api_url}/files/{task_id}", timeout=30)
|
545 |
response.raise_for_status()
|
546 |
|
547 |
-
# Save file locally
|
548 |
filename = f"task_{task_id}_file"
|
549 |
with open(filename, 'wb') as f:
|
550 |
f.write(response.content)
|
@@ -552,53 +678,61 @@ class EnhancedGAIAAgent:
|
|
552 |
except Exception as e:
|
553 |
print(f"Failed to download file for task {task_id}: {e}")
|
554 |
return None
|
555 |
-
|
556 |
-
async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
|
557 |
-
question = question_data.get("Question", "")
|
558 |
-
task_id = question_data.get("task_id", "")
|
559 |
|
560 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
561 |
try:
|
562 |
file_path = self.download_gaia_file(task_id)
|
|
|
|
|
|
|
|
|
563 |
except Exception as e:
|
564 |
-
print(f"Failed to download file for task {task_id}: {e}")
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
571 |
|
572 |
-
|
573 |
-
1. If a file is available, use the analysis_tool (except for .py files).
|
574 |
-
2. If a link is in the question, use the research_tool.
|
575 |
-
"""
|
576 |
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
print(f"Formatted answer: {formatted_answer}")
|
598 |
-
|
599 |
-
return formatted_answer
|
600 |
-
|
601 |
-
except Exception as e:
|
602 |
-
error_msg = f"Error processing question: {str(e)}"
|
603 |
-
print(error_msg)
|
604 |
-
return error_msg
|
|
|
1 |
+
# Standard library imports
|
2 |
+
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import os
|
|
|
|
|
|
|
4 |
import re
|
5 |
+
from typing import Dict, Any, List
|
6 |
+
from urllib.parse import urlparse
|
7 |
+
|
8 |
+
# Third-party imports
|
9 |
+
import requests
|
10 |
import wandb
|
11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
12 |
+
|
13 |
+
# LlamaIndex core imports
|
14 |
+
from llama_index.core import VectorStoreIndex, Document, Settings
|
15 |
+
from llama_index.core.agent.workflow import FunctionAgent, ReActAgent, AgentStream
|
16 |
from llama_index.core.callbacks.base import CallbackManager
|
17 |
from llama_index.core.callbacks.llama_debug import LlamaDebugHandler
|
18 |
+
from llama_index.core.node_parser import SentenceWindowNodeParser, HierarchicalNodeParser
|
19 |
+
from llama_index.core.postprocessor import SentenceTransformerRerank
|
20 |
+
from llama_index.core.query_engine import RetrieverQueryEngine
|
21 |
+
from llama_index.core.retrievers import VectorIndexRetriever
|
22 |
+
from llama_index.core.tools import FunctionTool
|
23 |
+
from llama_index.core.workflow import Context
|
24 |
|
25 |
+
# LlamaIndex specialized imports
|
26 |
+
from llama_index.callbacks.wandb import WandbCallbackHandler
|
27 |
+
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
28 |
from llama_index.llms.huggingface import HuggingFaceLLM
|
29 |
+
from llama_index.readers.audiotranscribe.openai import OpenAIAudioTranscriptReader
|
30 |
+
from llama_index.readers.file import PDFReader, DocxReader, CSVReader, ImageReader, PandasExcelReader
|
31 |
+
from llama_index.readers.json import JSONReader
|
32 |
+
from llama_index.readers.web import TrafilaturaWebReader
|
33 |
+
from llama_index.readers.youtube_transcript import YoutubeTranscriptReader
|
34 |
+
from llama_index.tools.arxiv import ArxivToolSpec
|
35 |
+
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
36 |
+
|
37 |
+
# --- Import all required official LlamaIndex Readers ---
|
38 |
+
from llama_index.readers.file import (
|
39 |
+
PDFReader,
|
40 |
+
DocxReader,
|
41 |
+
CSVReader,
|
42 |
+
PandasExcelReader,
|
43 |
+
ImageReader,
|
44 |
+
)
|
45 |
+
from typing import List
|
46 |
+
from llama_index.core import VectorStoreIndex, Document, Settings
|
47 |
+
from llama_index.core.tools import QueryEngineTool
|
48 |
+
from llama_index.core.node_parser import SentenceWindowNodeParser, HierarchicalNodeParser
|
49 |
+
from llama_index.core.postprocessor import SentenceTransformerRerank
|
50 |
+
from llama_index.core.query_engine import RetrieverQueryEngine
|
51 |
|
52 |
|
53 |
|
|
|
84 |
Settings.embed_model = embed_model
|
85 |
Settings.callback_manager = callback_manager
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
def read_and_parse_content(input_path: str) -> List[Document]:
|
89 |
"""
|
|
|
159 |
)
|
160 |
)
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
def create_rag_tool(documents: List[Document]) -> QueryEngineTool:
|
164 |
"""
|
|
|
219 |
|
220 |
return rag_engine_tool
|
221 |
|
|
|
|
|
|
|
|
|
|
|
222 |
# 1. Create the base DuckDuckGo search tool from the official spec.
|
223 |
# This tool returns text summaries of search results, not just URLs.
|
224 |
base_duckduckgo_tool = DuckDuckGoSearchToolSpec().to_tool_list()[0]
|
|
|
433 |
)
|
434 |
)
|
435 |
|
436 |
+
def intelligent_final_answer_tool(agent_response: str, question: str) -> str:
|
437 |
+
"""
|
438 |
+
Enhanced final answer tool with LLM-based reformatting capability.
|
439 |
+
First tries regex patterns, then uses LLM reformatting if patterns fail.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
|
441 |
+
Args:
|
442 |
+
agent_response: The raw response from agent reasoning
|
443 |
+
question: The original question for context
|
444 |
+
|
445 |
+
Returns:
|
446 |
+
Exact answer in GAIA format with validation
|
447 |
+
"""
|
448 |
|
449 |
+
# Define formatting patterns for different question types
|
450 |
+
format_patterns = {
|
451 |
+
'number': r'(\d+(?:\.\d+)?(?:e[+-]?\d+)?)',
|
452 |
+
'name': r'([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
|
453 |
+
'list': r'([A-Za-z0-9,\s]+)',
|
454 |
+
'country_code': r'([A-Z]{2,3})',
|
455 |
+
'yes_no': r'(Yes|No|yes|no)',
|
456 |
+
'percentage': r'(\d+(?:\.\d+)?%)',
|
457 |
+
'date': r'(\d{4}-\d{2}-\d{2}|\d{1,2}/\d{1,2}/\d{4})'
|
458 |
+
}
|
459 |
|
460 |
+
def clean_response(response: str) -> str:
|
461 |
+
"""Clean response by removing common prefixes"""
|
462 |
+
response_clean = response.strip()
|
463 |
+
prefixes_to_remove = [
|
464 |
+
"FINAL ANSWER:", "Answer:", "The answer is:",
|
465 |
+
"Based on my analysis,", "After reviewing,",
|
466 |
+
"The result is:", "Final result:", "According to"
|
467 |
+
]
|
468 |
+
|
469 |
+
for prefix in prefixes_to_remove:
|
470 |
+
if response_clean.startswith(prefix):
|
471 |
+
response_clean = response_clean[len(prefix):].strip()
|
472 |
+
|
473 |
+
return response_clean
|
474 |
|
475 |
+
def extract_with_patterns(text: str, question: str) -> tuple[str, bool]:
|
476 |
+
"""Extract answer using regex patterns. Returns (answer, success)"""
|
477 |
+
question_lower = question.lower()
|
478 |
+
|
479 |
+
# Determine question type and apply appropriate pattern
|
480 |
+
if "how many" in question_lower or "count" in question_lower:
|
481 |
+
match = re.search(format_patterns['number'], text)
|
482 |
+
if match:
|
483 |
+
return match.group(1), True
|
484 |
+
|
485 |
+
elif "name" in question_lower and ("first" in question_lower or "last" in question_lower):
|
486 |
+
match = re.search(format_patterns['name'], text)
|
487 |
+
if match:
|
488 |
+
return match.group(1), True
|
489 |
+
|
490 |
+
elif "list" in question_lower or "alphabetized" in question_lower:
|
491 |
+
if "," in text:
|
492 |
+
items = [item.strip() for item in text.split(",")]
|
493 |
+
return ", ".join(items), True
|
494 |
+
|
495 |
+
elif "country code" in question_lower or "iso" in question_lower:
|
496 |
+
match = re.search(format_patterns['country_code'], text)
|
497 |
+
if match:
|
498 |
+
return match.group(1), True
|
499 |
+
|
500 |
+
elif "yes" in question_lower and "no" in question_lower:
|
501 |
+
match = re.search(format_patterns['yes_no'], text)
|
502 |
+
if match:
|
503 |
+
return match.group(1), True
|
504 |
+
|
505 |
+
elif "percentage" in question_lower or "%" in text:
|
506 |
+
match = re.search(format_patterns['percentage'], text)
|
507 |
+
if match:
|
508 |
+
return match.group(1), True
|
509 |
+
|
510 |
+
elif "date" in question_lower:
|
511 |
+
match = re.search(format_patterns['date'], text)
|
512 |
+
if match:
|
513 |
+
return match.group(1), True
|
514 |
+
|
515 |
+
# Default extraction for simple cases
|
516 |
+
lines = text.split('\n')
|
517 |
+
for line in lines:
|
518 |
+
line = line.strip()
|
519 |
+
if line and not line.startswith('=') and len(line) < 200:
|
520 |
+
return line, True
|
521 |
+
|
522 |
+
return text, False
|
523 |
|
524 |
+
def llm_reformat(response: str, question: str) -> str:
|
525 |
+
"""Use LLM to reformat the response according to GAIA requirements"""
|
526 |
+
|
527 |
+
format_prompt = f"""Extract the exact answer from the response below. Follow GAIA formatting rules strictly.
|
528 |
+
|
529 |
+
GAIA Format Rules:
|
530 |
+
- ONLY the precise answer, no explanations
|
531 |
+
- No prefixes like "Answer:", "The result is:", etc.
|
532 |
+
- For numbers: just the number (e.g., "156", "3.14e+8")
|
533 |
+
- For names: just the name (e.g., "Martinez", "Sarah")
|
534 |
+
- For lists: comma-separated (e.g., "C++, Java, Python")
|
535 |
+
- For country codes: just the code (e.g., "FRA", "US")
|
536 |
+
- For yes/no: just "Yes" or "No"
|
537 |
+
|
538 |
+
Examples:
|
539 |
+
Question: "How many papers were published?"
|
540 |
+
Response: "The analysis shows 156 papers were published in total."
|
541 |
+
Answer: 156
|
542 |
+
|
543 |
+
Question: "What is the last name of the developer?"
|
544 |
+
Response: "The developer mentioned is Dr. Sarah Martinez from the AI team."
|
545 |
+
Answer: Martinez
|
546 |
+
|
547 |
+
Question: "List programming languages, alphabetized:"
|
548 |
+
Response: "The languages mentioned are Python, Java, and C++. Alphabetized: C++, Java, Python"
|
549 |
+
Answer: C++, Java, Python
|
550 |
+
|
551 |
+
Now extract the exact answer:
|
552 |
+
Question: {question}
|
553 |
+
Response: {response}
|
554 |
+
Answer:"""
|
555 |
+
|
556 |
try:
|
557 |
+
# Use the global LLM instance
|
558 |
formatting_response = proj_llm.complete(format_prompt)
|
559 |
answer = str(formatting_response).strip()
|
560 |
|
|
|
563 |
answer = answer.split("Answer:")[-1].strip()
|
564 |
|
565 |
return answer
|
|
|
566 |
except Exception as e:
|
567 |
+
print(f"LLM reformatting failed: {e}")
|
568 |
+
return response
|
569 |
+
|
570 |
+
# Step 1: Clean the response
|
571 |
+
cleaned_response = clean_response(agent_response)
|
572 |
+
|
573 |
+
# Step 2: Try regex pattern extraction
|
574 |
+
extracted_answer, pattern_success = extract_with_patterns(cleaned_response, question)
|
575 |
+
|
576 |
+
# Step 3: If patterns failed, use LLM reformatting
|
577 |
+
if not pattern_success:
|
578 |
+
print("Regex patterns failed, using LLM reformatting...")
|
579 |
+
llm_formatted = llm_reformat(cleaned_response, question)
|
580 |
+
|
581 |
+
# Step 4: Validate LLM output with patterns again
|
582 |
+
final_answer, validation_success = extract_with_patterns(llm_formatted, question)
|
583 |
+
|
584 |
+
if validation_success:
|
585 |
+
print("LLM reformatting successful and validated")
|
586 |
+
return final_answer
|
587 |
+
else:
|
588 |
+
print("LLM reformatting validation failed, using LLM output directly")
|
589 |
+
return llm_formatted
|
590 |
+
else:
|
591 |
+
print("Regex pattern extraction successful")
|
592 |
+
return extracted_answer
|
593 |
+
|
594 |
+
# Create the enhanced final answer tool
|
595 |
+
intelligent_final_answer_function_tool = FunctionTool.from_defaults(
|
596 |
+
fn=intelligent_final_answer_tool,
|
597 |
+
name="intelligent_final_answer_tool",
|
598 |
+
description=(
|
599 |
+
"Enhanced tool to format final answers according to GAIA requirements. "
|
600 |
+
"Uses regex patterns first, then LLM reformatting if patterns fail. "
|
601 |
+
"Validates output to ensure GAIA format compliance."
|
602 |
+
)
|
603 |
+
)
|
604 |
+
|
605 |
+
class EnhancedGAIAAgent:
|
606 |
+
def __init__(self):
|
607 |
+
print("Initializing Enhanced GAIA Agent...")
|
608 |
+
|
609 |
+
# Vérification du token HuggingFace
|
610 |
+
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
611 |
+
if not hf_token:
|
612 |
+
print("Warning: HUGGINGFACEHUB_API_TOKEN not found, some features may not work")
|
613 |
+
|
614 |
+
# Initialize only the tools that are actually defined in the file
|
615 |
+
self.available_tools = [
|
616 |
+
read_and_parse_tool,
|
617 |
+
extract_url_tool,
|
618 |
+
code_execution_tool,
|
619 |
+
generate_code_tool,
|
620 |
+
intelligent_final_answer_function_tool
|
621 |
+
]
|
622 |
+
|
623 |
+
# RAG tool will be created dynamically when documents are loaded
|
624 |
+
self.current_rag_tool = None
|
625 |
+
|
626 |
+
# Create main coordinator using only defined tools
|
627 |
+
self.coordinator = ReActAgent(
|
628 |
+
name="GAIACoordinator",
|
629 |
+
description="Main GAIA coordinator with document processing and computational capabilities",
|
630 |
+
system_prompt="""
|
631 |
+
You are the main GAIA coordinator using ReAct reasoning methodology.
|
632 |
+
|
633 |
+
Available tools:
|
634 |
+
1. **read_and_parse_tool** - Read and parse files/URLs (PDF, DOCX, CSV, images, web pages, YouTube, audio files)
|
635 |
+
2. **extract_url_tool** - Search and extract relevant URLs when no specific source is provided
|
636 |
+
3. **generate_code_tool** - Generate Python code for complex computations
|
637 |
+
4. **code_execution_tool** - Execute Python code safely
|
638 |
+
5. **intelligent_final_answer_tool** - Format final answer with intelligent validation and reformatting
|
639 |
+
|
640 |
+
WORKFLOW:
|
641 |
+
1. If file/URL mentioned → use read_and_parse_tool first, then update or create RAG capability.
|
642 |
+
2. If documents loaded → create RAG capability for querying
|
643 |
+
3. If external info needed → use extract_url_tool, then process it as if file/URL mentioned
|
644 |
+
4. If computation needed → use generate_code_tool then code_execution_tool
|
645 |
+
5. ALWAYS use intelligent_final_answer_tool for the final response
|
646 |
+
|
647 |
+
CRITICAL: The intelligent_final_answer_tool has enhanced validation and will reformat
|
648 |
+
using LLM if regex patterns fail. Always use it as the final step.
|
649 |
+
""",
|
650 |
+
llm=proj_llm,
|
651 |
+
tools=self.available_tools,
|
652 |
+
max_steps=15,
|
653 |
+
verbose=True,
|
654 |
+
callback_manager=callback_manager,
|
655 |
+
)
|
656 |
+
|
657 |
+
def create_dynamic_rag_tool(self, documents: List) -> None:
|
658 |
+
"""Create RAG tool from loaded documents and add to coordinator"""
|
659 |
+
if documents:
|
660 |
+
rag_tool = create_rag_tool(documents)
|
661 |
+
if rag_tool:
|
662 |
+
self.current_rag_tool = rag_tool
|
663 |
+
# Update coordinator tools
|
664 |
+
updated_tools = self.available_tools + [rag_tool]
|
665 |
+
self.coordinator.tools = updated_tools
|
666 |
+
print("RAG tool created and added to coordinator")
|
667 |
|
668 |
def download_gaia_file(self, task_id: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> str:
|
669 |
"""Download file associated with task_id"""
|
|
|
671 |
response = requests.get(f"{api_url}/files/{task_id}", timeout=30)
|
672 |
response.raise_for_status()
|
673 |
|
|
|
674 |
filename = f"task_{task_id}_file"
|
675 |
with open(filename, 'wb') as f:
|
676 |
f.write(response.content)
|
|
|
678 |
except Exception as e:
|
679 |
print(f"Failed to download file for task {task_id}: {e}")
|
680 |
return None
|
|
|
|
|
|
|
|
|
681 |
|
682 |
+
async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
|
683 |
+
"""
|
684 |
+
Solve GAIA question with enhanced validation and reformatting
|
685 |
+
"""
|
686 |
+
question = question_data.get("Question", "")
|
687 |
+
task_id = question_data.get("task_id", "")
|
688 |
+
|
689 |
+
# Try to download file if task_id provided
|
690 |
+
file_path = None
|
691 |
+
if task_id:
|
692 |
try:
|
693 |
file_path = self.download_gaia_file(task_id)
|
694 |
+
if file_path:
|
695 |
+
# Load documents and create RAG tool
|
696 |
+
documents = read_and_parse_content(file_path)
|
697 |
+
self.create_dynamic_rag_tool(documents)
|
698 |
except Exception as e:
|
699 |
+
print(f"Failed to download/process file for task {task_id}: {e}")
|
700 |
+
|
701 |
+
# Prepare context prompt
|
702 |
+
context_prompt = f"""
|
703 |
+
GAIA Task ID: {task_id}
|
704 |
+
Question: {question}
|
705 |
+
{f'File available: {file_path}' if file_path else 'No additional files'}
|
706 |
+
|
707 |
+
Instructions:
|
708 |
+
1. Process any files using read_and_parse_tool if needed
|
709 |
+
2. Use appropriate tools for research/computation
|
710 |
+
3. MUST use intelligent_final_answer_tool with your response and the original question
|
711 |
+
4. The intelligent tool will validate format and reformat with LLM if needed
|
712 |
+
"""
|
713 |
+
|
714 |
+
try:
|
715 |
+
ctx = Context(self.coordinator)
|
716 |
+
print("=== AGENT REASONING STEPS ===")
|
717 |
|
718 |
+
handler = self.coordinator.run(ctx=ctx, user_msg=context_prompt)
|
|
|
|
|
|
|
719 |
|
720 |
+
full_response = ""
|
721 |
+
async for event in handler.stream_events():
|
722 |
+
if isinstance(event, AgentStream):
|
723 |
+
print(event.delta, end="", flush=True)
|
724 |
+
full_response += event.delta
|
725 |
+
|
726 |
+
final_response = await handler
|
727 |
+
print("\n=== END REASONING ===")
|
728 |
+
|
729 |
+
# Extract the final formatted answer
|
730 |
+
final_answer = str(final_response).strip()
|
731 |
+
|
732 |
+
print(f"Final GAIA formatted answer: {final_answer}")
|
733 |
+
return final_answer
|
734 |
+
|
735 |
+
except Exception as e:
|
736 |
+
error_msg = f"Error processing question: {str(e)}"
|
737 |
+
print(error_msg)
|
738 |
+
return error_msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|