"""LangGraph agent that answers GAIA-style questions with an OpenAI chat model.

The module wires a tool-enabled chat model into a two-node graph
("assistant" <-> "tools") and exposes it through the callable ``SmartAgent``
class. A ready-to-use instance is created at import time as ``smart_agent``.
"""

from typing import TypedDict, Annotated, Optional
import os

from dotenv import load_dotenv
from langgraph.graph.message import add_messages

# Load environment variables from .env file.
# Kept ahead of the remaining imports (as in the original ordering) in case
# any of them read the environment at import time.
load_dotenv()

from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import tools_condition
from langchain_openai import ChatOpenAI

from tools import agent_tools
from utils import format_gaia_answer, log_agent_step

# Initialize OpenAI LLM with GPT-4o (most capable model)
chat = ChatOpenAI(
    model="gpt-4o",
    temperature=0.1,
    max_tokens=1024,
    api_key=os.environ.get("OPENAI_API_KEY"),
)
chat_with_tools = chat.bind_tools(agent_tools)

# System prompt for GAIA evaluation (exact format required by HF).
# The wording is intentionally left untouched, including its grammar quirks.
SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.

You have access to tools that can help you:
- Search the web for current information
- Download and process files associated with task IDs
- Analyze images
- Perform calculations
- Process text

IMPORTANT: You must provide a specific answer in the FINAL ANSWER format. Do not say you cannot find information or provide general approaches.
Use web search to find the information you need, but limit yourself to 2-3 search attempts maximum. If you cannot find perfect information, make your best determination based on what you found and provide a concrete FINAL ANSWER. Always end with a specific FINAL ANSWER, never with explanations about not finding information."""


# Generate the AgentState
class AgentState(TypedDict):
    """Graph state shared by every node."""

    # Conversation history; `add_messages` is the reducer attached to this key.
    messages: Annotated[list[AnyMessage], add_messages]
    # Task identifier supplied by the caller ("" when none was given).
    task_id: str


def _content_to_text(content) -> str:
    """Flatten a message ``content`` value into plain text.

    Message content is usually a string, but it may also be a list of content
    blocks (plain strings and/or dicts carrying a ``"text"`` entry). Blocks
    without text are skipped so the result can always be joined, printed and
    handed to string-only helpers.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                pieces.append(block["text"])
        return "".join(pieces)
    return "" if content is None else str(content)


def assistant(state: AgentState):
    """Main assistant function that processes messages and calls tools.

    Prepends ``SYSTEM_PROMPT`` when the history holds no ``SystemMessage``,
    invokes the tool-bound chat model, and returns the model response as a
    state update for the ``messages`` key.
    """
    messages = state["messages"]

    # Add system prompt if not already present. A new list is built so the
    # list stored in the state is not mutated.
    if not any(isinstance(msg, SystemMessage) for msg in messages):
        messages = [SystemMessage(content=SYSTEM_PROMPT), *messages]

    # Get the response from the LLM
    response = chat_with_tools.invoke(messages)

    return {"messages": [response]}


def create_smart_agent():
    """Create and return the smart agent graph.

    The graph starts at "assistant"; ``tools_condition`` decides after each
    assistant turn where to go next, and "tools" always hands control back to
    "assistant".
    """
    # Build the graph
    builder = StateGraph(AgentState)

    # Define nodes
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(agent_tools))

    # Define edges
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()


class SmartAgent:
    """High-level intelligent agent class that wraps the LangGraph agent."""

    def __init__(self, recursion_limit: int = 15):
        """Build the underlying graph.

        Args:
            recursion_limit: Value passed to the graph as ``recursion_limit``
                on every call. Defaults to 15.
        """
        self.agent = create_smart_agent()
        self.recursion_limit = recursion_limit
        print("🤖 Smart Agent initialized with OpenAI GPT-4o and tools")

    def __call__(self, question: str, task_id: Optional[str] = None) -> tuple[str, str]:
        """Process a question and return the formatted answer and reasoning trace.

        Args:
            question: The question text.
            task_id: Optional task identifier. When given it is prefixed to the
                question sent to the model and stored in the graph state.

        Returns:
            ``(formatted_answer, reasoning_text)``. Any exception raised while
            running the agent is caught and reported through both elements
            instead of propagating.
        """
        try:
            print(f"\n🎯 Processing question: {question[:100]}...")

            enhanced_question = question
            if task_id:
                enhanced_question = f"Task ID: {task_id}\n\nQuestion: {question}"

            config = {"recursion_limit": self.recursion_limit}
            initial_state = {
                "messages": [HumanMessage(content=enhanced_question)],
                "task_id": task_id or "",
            }

            result = self.agent.invoke(initial_state, config=config)

            history = result.get("messages") if result else None
            if history:
                # Content may be a list of blocks; normalize to text before use.
                raw_answer = _content_to_text(history[-1].content)

                # Every non-empty message (human, AI, tool) becomes one entry.
                trace_parts = []
                for msg in history:
                    text = _content_to_text(getattr(msg, "content", ""))
                    if text:
                        trace_parts.append(text)
                reasoning_text = "\n---\n".join(trace_parts)
            else:
                raw_answer = "No response generated"
                reasoning_text = "No reasoning trace available"

            # Format the answer for submission
            formatted_answer = format_gaia_answer(raw_answer)

            print(f"✅ Raw answer: {raw_answer}")
            print(f"🎯 Formatted answer: {formatted_answer}")

            # Validate the formatted answer (None, "" and whitespace-only all count as empty)
            if not formatted_answer or not formatted_answer.strip():
                print("⚠️ WARNING: Empty formatted answer!")
                formatted_answer = "ERROR: No valid answer extracted"

            return formatted_answer, reasoning_text

        except Exception as e:
            # Deliberate best-effort boundary: callers always get a tuple back.
            error_msg = f"Error processing question: {str(e)}"
            print(f"❌ {error_msg}")
            return error_msg, f"Error occurred: {str(e)}"


# Shared instance created at import time.
smart_agent = SmartAgent()