# jarvis_gaia_agent/graph.py
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from state import JARVISState
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool
from langfuse.callback import LangfuseCallbackHandler
import json
import os
from dotenv import load_dotenv
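
# For reference: JARVISState is defined in state.py. Based on how its keys are
# read and written by the nodes below, it is assumed to be a TypedDict roughly
# like the following (field names inferred from usage here, not from state.py):
#
#     class JARVISState(TypedDict):
#         question: str
#         task_id: str
#         messages: list
#         tools_needed: list[str]
#         web_results: list[str]
#         file_results: str
#         image_results: str
#         calculation_results: str
#         document_results: str
#         answer: str
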
# Load environment variables
load_dotenv()
# Debug: Verify environment variables
print(f"OPENAI_API_KEY loaded in graph.py: {'set' if os.getenv('OPENAI_API_KEY') else 'not set'}")
print(f"LANGFUSE_PUBLIC_KEY loaded in graph.py: {'set' if os.getenv('LANGFUSE_PUBLIC_KEY') else 'not set'}")
# Initialize LLM and Langfuse
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")
llm = ChatOpenAI(model="gpt-4o", api_key=api_key)
langfuse = LangfuseCallbackHandler(
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    host=os.getenv("LANGFUSE_HOST")
)
memory = MemorySaver()

# Question Parser Node
async def parse_question(state: JARVISState) -> JARVISState:
    question = state["question"]
    prompt = f"""Analyze this GAIA question: {question}
Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
Return a JSON list of tool names."""
    response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
    # Assumes the model returns a bare JSON list (e.g. ["web_search"]); json.loads
    # will raise if the reply is wrapped in a Markdown code fence or extra prose.
    tools_needed = json.loads(response.content)
    return {"messages": state["messages"] + [response], "tools_needed": tools_needed}

# Web Search Agent Node
async def web_search_agent(state: JARVISState) -> JARVISState:
    results = []
    if "web_search" in state["tools_needed"]:
        result = await search_tool.arun(state["question"])
        results.append(result)
    if "multi_hop_search" in state["tools_needed"]:
        result = await multi_hop_search_tool.aparse(state["question"], steps=3)
        results.append(result)
    return {"web_results": results}

# File Parser Agent Node
async def file_parser_agent(state: JARVISState) -> JARVISState:
    if "file_parser" in state["tools_needed"]:
        result = await file_parser_tool.aparse(state["task_id"])
        return {"file_results": result}
    return {"file_results": ""}

# Image Parser Agent Node
async def image_parser_agent(state: JARVISState) -> JARVISState:
    if "image_parser" in state["tools_needed"]:
        task = "match" if "fruits" in state["question"].lower() else "describe"
        match_query = "fruits" if task == "match" else ""
        result = await image_parser_tool.aparse(
            f"temp_{state['task_id']}.jpg", task=task, match_query=match_query
        )
        return {"image_results": result}
    return {"image_results": ""}

# Calculator Agent Node
async def calculator_agent(state: JARVISState) -> JARVISState:
    if "calculator" in state["tools_needed"]:
        prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
        response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
        expression = response.content
        result = await calculator_tool.aparse(expression)
        return {"calculation_results": result}
    return {"calculation_results": ""}

# Document Retriever Agent Node
async def document_retriever_agent(state: JARVISState) -> JARVISState:
    if "document_retriever" in state["tools_needed"]:
        # Heuristic file-type guess based on keywords in the question.
        file_type = "txt" if "menu" in state["question"].lower() else "csv"
        if "report" in state["question"].lower() or "document" in state["question"].lower():
            file_type = "pdf"
        result = await document_retriever_tool.aparse(
            state["task_id"], state["question"], file_type=file_type
        )
        return {"document_results": result}
    return {"document_results": ""}

# Reasoning Agent Node
async def reasoning_agent(state: JARVISState) -> JARVISState:
    prompt = f"""Question: {state['question']}
Web Results: {state['web_results']}
File Results: {state['file_results']}
Image Results: {state['image_results']}
Calculation Results: {state['calculation_results']}
Document Results: {state['document_results']}
Synthesize an exact-match answer for the GAIA benchmark.
Output only the answer (e.g., '90', 'White;5876')."""
    response = await llm.ainvoke(
        [
            SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
            HumanMessage(content=prompt)
        ],
        config={"callbacks": [langfuse]}
    )
    return {"answer": response.content, "messages": state["messages"] + [response]}

# Conditional Edge Router
def router(state: JARVISState) -> str:
    # Fan out to the tool nodes when the parser selected any tools; otherwise
    # go straight to reasoning.
    if state["tools_needed"]:
        return "tools"
    return "reasoning"

# Build Graph
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("web_search", web_search_agent)
workflow.add_node("file_parser", file_parser_agent)
workflow.add_node("image_parser", image_parser_agent)
workflow.add_node("calculator", calculator_agent)
workflow.add_node("document_retriever", document_retriever_agent)
workflow.add_node("reasoning", reasoning_agent)
workflow.set_entry_point("parse")
workflow.add_conditional_edges(
    "parse",
    router,
    {
        # "tools" fans out to every tool node in parallel; each tool node then
        # feeds into the reasoning node via the unconditional edges below.
        "tools": ["web_search", "file_parser", "image_parser", "calculator", "document_retriever"],
        "reasoning": "reasoning"
    }
)
workflow.add_edge("web_search", "reasoning")
workflow.add_edge("file_parser", "reasoning")
workflow.add_edge("image_parser", "reasoning")
workflow.add_edge("calculator", "reasoning")
workflow.add_edge("document_retriever", "reasoning")
workflow.add_edge("reasoning", END)
graph = workflow.compile(checkpointer=memory)
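
# Example usage (a minimal sketch, not part of the original app wiring):
# MemorySaver requires a thread_id in the run config, and the initial state
# mirrors the JARVISState keys used above. The question and task_id values
# here are purely illustrative.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        initial_state = {
            "question": "What is 2 + 2?",
            "task_id": "demo-task",
            "messages": [],
            "tools_needed": [],
            "web_results": [],
            "file_results": "",
            "image_results": "",
            "calculation_results": "",
            "document_results": "",
            "answer": "",
        }
        config = {"configurable": {"thread_id": "demo-thread"}, "callbacks": [langfuse]}
        result = await graph.ainvoke(initial_state, config=config)
        print(result["answer"])

    asyncio.run(_demo())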