|
import os
from typing import Annotated, Optional, TypedDict

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from tools import agent_tools
from utils import analyze_question_type, create_execution_plan, format_gaia_answer, log_agent_step
|
|
|
|
|
# Backing text-generation endpoint for the agent.
# NOTE(review): requires HUGGINGFACE_API_TOKEN in the environment —
# os.environ.get returns None otherwise and the first endpoint call will
# fail at request time, not here. Low temperature keeps answers
# near-deterministic for factual QA.
llm = HuggingFaceEndpoint(
    repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
    temperature=0.1,
    max_new_tokens=1024,
)

# Chat-style wrapper over the raw endpoint.
chat = ChatHuggingFace(llm=llm, verbose=True)

# Tool-calling model used by the assistant node; the bound tools are the
# same list handed to the graph's ToolNode so requested calls can execute.
chat_with_tools = chat.bind_tools(agent_tools)
|
|
|
|
|
# System prompt injected once per conversation by the assistant node.
# NOTE(review): this string is sent to the model verbatim (runtime
# behavior) — its exact content, including blank lines, is intentional.
SYSTEM_PROMPT = """You are a highly capable AI assistant designed to answer questions accurately and helpfully.



Your approach should include:

- Multi-step reasoning and planning for complex questions

- Intelligent tool usage when needed for web search, file processing, calculations, and analysis

- Precise, factual answers based on reliable information

- Breaking down complex questions into manageable steps



IMPORTANT GUIDELINES:

1. Think step-by-step and use available tools when they can help provide better answers

2. For current information: Search the web for up-to-date facts

3. For files: Process associated files when task_id is provided

4. For visual content: Analyze images carefully when present

5. For calculations: Use computational tools for accuracy

6. Provide concise, direct answers without unnecessary prefixes

7. Focus on accuracy and helpfulness

8. Be factual and avoid speculation



Your goal is to be as helpful and accurate as possible while using the right tools for each task."""
|
|
|
|
|
class AgentState(TypedDict):
    """State carried between nodes of the LangGraph graph."""

    # Conversation history; add_messages merges each node's returned
    # messages into this list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
    # GAIA task identifier; SmartAgent passes "" when no task is associated.
    task_id: str
    # Result of utils.analyze_question_type for the current question.
    question_analysis: dict
|
|
|
def assistant(state: AgentState):
    """Assistant node: run the tool-bound chat model over the conversation.

    Ensures the system prompt is present exactly once, invokes the model,
    and returns the reply for ``add_messages`` to merge into the state.
    """
    convo = state["messages"]

    # Prepend the system prompt only when no SystemMessage exists yet,
    # so repeated turns on the same thread don't duplicate it.
    has_system = any(isinstance(m, SystemMessage) for m in convo)
    if not has_system:
        convo = [SystemMessage(content=SYSTEM_PROMPT), *convo]

    reply = chat_with_tools.invoke(convo)
    return {"messages": [reply]}
|
|
|
def create_smart_agent():
    """Create and return the smart agent graph.

    Builds a two-node LangGraph state machine — the LLM assistant and a
    ToolNode executing its tool calls — with an in-memory checkpointer,
    and returns the compiled agent.
    """
    graph = StateGraph(AgentState)

    # Nodes: the model-driven assistant and the tool executor.
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(agent_tools))

    # Start at the assistant; route to tools whenever the model requests
    # a tool call, and feed tool results back to the assistant.
    graph.add_edge(START, "assistant")
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")

    # MemorySaver gives per-thread conversation state across invocations.
    return graph.compile(checkpointer=MemorySaver())
|
|
|
class SmartAgent:
    """High-level intelligent agent class that wraps the LangGraph agent.

    Wraps the compiled graph from ``create_smart_agent`` behind a simple
    callable interface: pass a question (and an optional task id) and get
    back a formatted answer string.
    """

    def __init__(self):
        # Build the graph once; reused for every call.
        self.agent = create_smart_agent()
        print("π€ Smart Agent initialized with LangGraph and tools")

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Process a question and return the formatted answer.

        Args:
            question: Natural-language question to answer.
            task_id: Optional task identifier. When given, the prompt tells
                the model how to fetch files attached to the task and the
                checkpointer thread is scoped to that task.

        Returns:
            The formatted answer, or an error description string if
            processing failed — this method never raises.
        """
        try:
            print(f"\nπ― Processing question: {question[:100]}...")

            # Pre-flight analysis and planning; logged for observability,
            # and the analysis is also stored in the agent state.
            analysis = analyze_question_type(question)
            print(f"π Question analysis: {analysis}")

            plan = create_execution_plan(question, task_id)
            print(f"π Execution plan: {plan}")

            # When a task_id exists, explicitly instruct the model how to
            # retrieve any files associated with the task.
            enhanced_question = question
            if task_id:
                enhanced_question = f"Task ID: {task_id}\n\nQuestion: {question}\n\nNote: If this question involves files, use the file_download tool with task_id '{task_id}' to access associated files."

            # One checkpointer thread per task so conversations don't bleed
            # into each other; task-less calls share the "general" thread.
            thread_id = f"task-{task_id}" if task_id else "general"
            config = {"configurable": {"thread_id": thread_id}}

            initial_state = {
                "messages": [HumanMessage(content=enhanced_question)],
                "task_id": task_id or "",
                "question_analysis": analysis,
            }

            result = self.agent.invoke(initial_state, config=config)

            # The final message in the state is the model's answer.
            if result and 'messages' in result and result['messages']:
                final_message = result['messages'][-1]
                raw_answer = final_message.content
            else:
                raw_answer = "No response generated"

            formatted_answer = format_gaia_answer(raw_answer)

            # BUG FIX: this f-string was split across two physical source
            # lines (a literal newline inside a single-quoted f-string),
            # which is a SyntaxError; the message now lives on one line.
            print(f"β Raw answer: {raw_answer}")
            print(f"π― Formatted answer: {formatted_answer}")

            return formatted_answer

        except Exception as e:
            # Deliberate best-effort boundary: callers always receive a
            # string rather than an exception.
            error_msg = f"Error processing question: {str(e)}"
            print(f"β {error_msg}")
            return error_msg
|
|
|
# Module-level singleton: importing this module builds the full agent graph
# immediately (and prints a startup banner) as a side effect.
smart_agent = SmartAgent()
|
|