# graph.py — JARVIS GAIA-benchmark agent: builds the LangGraph workflow that
# parses a question, fans out to tool agents, and synthesizes an exact answer.
import json
import os
import re

from dotenv import load_dotenv
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langfuse.callback import LangfuseCallbackHandler
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END

from state import JARVISState
from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool
# Load environment variables from .env before reading any API keys.
load_dotenv()

# Debug: confirm required keys are present without printing their values.
print(f"OPENAI_API_KEY loaded in graph.py: {'set' if os.getenv('OPENAI_API_KEY') else 'not set'}")
print(f"LANGFUSE_PUBLIC_KEY loaded in graph.py: {'set' if os.getenv('LANGFUSE_PUBLIC_KEY') else 'not set'}")

# Initialize the shared LLM and the Langfuse tracing callback used by every node.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")
llm = ChatOpenAI(model="gpt-4o", api_key=api_key)
langfuse = LangfuseCallbackHandler(
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    host=os.getenv("LANGFUSE_HOST"),
)
# In-memory checkpointer so graph runs can be checkpointed/resumed by thread id.
memory = MemorySaver()
# Question Parser Node | |
async def parse_question(state: JARVISState) -> JARVISState:
    """Ask the LLM which tools the question needs and record them in state.

    Returns a partial state update: the LLM response appended to ``messages``
    and ``tools_needed`` set to a list of tool-name strings.
    """
    question = state["question"]
    prompt = f"""Analyze this GAIA question: {question}
Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
Return a JSON list of tool names."""
    response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
    # The model frequently wraps JSON in ```json ... ``` fences; strip them,
    # and fall back to "no tools" on malformed output instead of crashing the
    # whole graph run with a JSONDecodeError.
    raw = re.sub(r"^```(?:json)?\s*|\s*```$", "", response.content.strip())
    try:
        tools_needed = json.loads(raw)
    except json.JSONDecodeError:
        tools_needed = []
    if not isinstance(tools_needed, list):
        tools_needed = []
    return {"messages": state["messages"] + [response], "tools_needed": tools_needed}
# Web Search Agent Node | |
async def web_search_agent(state: JARVISState) -> JARVISState:
    """Run single-hop and/or multi-hop web search if the parser requested them.

    Returns a partial state update with ``web_results`` (empty list when
    neither search tool was requested).
    """
    results = []
    if "web_search" in state["tools_needed"]:
        result = await search_tool.arun(state["question"])
        results.append(result)
    if "multi_hop_search" in state["tools_needed"]:
        # NOTE(review): search_tool is invoked via .arun but this tool via
        # .aparse — confirm multi_hop_search_tool really exposes aparse.
        result = await multi_hop_search_tool.aparse(state["question"], steps=3)
        results.append(result)
    return {"web_results": results}
# File Parser Agent Node | |
async def file_parser_agent(state: JARVISState) -> JARVISState:
    """Parse the task's attached file when requested.

    Returns ``file_results`` as the tool output, or "" when the file parser
    was not among the tools the parser node selected.
    """
    if "file_parser" in state["tools_needed"]:
        result = await file_parser_tool.aparse(state["task_id"])
        return {"file_results": result}
    return {"file_results": ""}
# Image Parser Agent Node | |
async def image_parser_agent(state: JARVISState) -> JARVISState:
    """Describe or match the task image when the image parser was requested.

    Uses a "match" task (against the query "fruits") when the question mentions
    fruits, otherwise a generic "describe" task. Returns ``image_results`` or
    "" when the tool was not requested.
    """
    if "image_parser" in state["tools_needed"]:
        # Heuristic keyed to GAIA's fruit-matching questions — anything else
        # gets a plain description.
        task = "match" if "fruits" in state["question"].lower() else "describe"
        match_query = "fruits" if task == "match" else ""
        # assumes the task image was saved as temp_<task_id>.jpg by an earlier
        # step — TODO confirm against the file-fetching code.
        result = await image_parser_tool.aparse(
            f"temp_{state['task_id']}.jpg", task=task, match_query=match_query
        )
        return {"image_results": result}
    return {"image_results": ""}
# Calculator Agent Node | |
async def calculator_agent(state: JARVISState) -> JARVISState:
    """Extract a math expression via the LLM and evaluate it when requested.

    Returns ``calculation_results`` from the calculator tool, or "" when the
    calculator was not among the selected tools.
    """
    if "calculator" in state["tools_needed"]:
        # Let the LLM pull a bare expression out of the question plus any
        # parsed file content, then hand it to the calculator tool.
        prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
        response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
        expression = response.content
        result = await calculator_tool.aparse(expression)
        return {"calculation_results": result}
    return {"calculation_results": ""}
# Document Retriever Agent Node | |
async def document_retriever_agent(state: JARVISState) -> JARVISState:
    """Retrieve relevant document content for the task when requested.

    Guesses the file type from question keywords (menu→txt, report/document→pdf,
    otherwise csv) and returns ``document_results``, or "" when the retriever
    was not among the selected tools.
    """
    if "document_retriever" in state["tools_needed"]:
        # Keyword heuristic for the attachment's file type; "report"/"document"
        # overrides the menu/csv guess.
        file_type = "txt" if "menu" in state["question"].lower() else "csv"
        if "report" in state["question"].lower() or "document" in state["question"].lower():
            file_type = "pdf"
        result = await document_retriever_tool.aparse(
            state["task_id"], state["question"], file_type=file_type
        )
        return {"document_results": result}
    return {"document_results": ""}
# Reasoning Agent Node | |
async def reasoning_agent(state: JARVISState) -> JARVISState:
    """Synthesize all collected tool results into a single exact-match answer.

    Feeds every intermediate result to the LLM with a strict system prompt and
    returns a partial state update with ``answer`` and the response appended
    to ``messages``.
    """
    prompt = f"""Question: {state['question']}
Web Results: {state['web_results']}
File Results: {state['file_results']}
Image Results: {state['image_results']}
Calculation Results: {state['calculation_results']}
Document Results: {state['document_results']}
Synthesize an exact-match answer for the GAIA benchmark.
Output only the answer (e.g., '90', 'White;5876')."""
    response = await llm.ainvoke(
        [
            SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
            HumanMessage(content=prompt)
        ],
        config={"callbacks": [langfuse]}
    )
    return {"answer": response.content, "messages": state["messages"] + [response]}
# Conditional Edge Router | |
def router(state: JARVISState) -> str:
    """Route to the tool fan-out when any tools were requested, else reasoning.

    Uses ``.get`` so a missing or empty ``tools_needed`` falls through to the
    reasoning node instead of raising ``KeyError``.
    """
    if state.get("tools_needed"):
        return "tools"
    return "reasoning"
# Build Graph | |
# ---- Build the graph: parse -> (tool agents | reasoning) -> reasoning -> END.
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("web_search", web_search_agent)
workflow.add_node("file_parser", file_parser_agent)
workflow.add_node("image_parser", image_parser_agent)
workflow.add_node("calculator", calculator_agent)
workflow.add_node("document_retriever", document_retriever_agent)
workflow.add_node("reasoning", reasoning_agent)
workflow.set_entry_point("parse")
workflow.add_conditional_edges(
    "parse",
    router,
    {
        # NOTE(review): mapping the "tools" label to a *list* fans out to all
        # tool agents in parallel — confirm the installed langgraph version
        # accepts list targets in a conditional-edge path map.
        "tools": ["web_search", "file_parser", "image_parser", "calculator", "document_retriever"],
        "reasoning": "reasoning",
    },
)
# All tool agents funnel into the reasoning node, which emits the final answer.
workflow.add_edge("web_search", "reasoning")
workflow.add_edge("file_parser", "reasoning")
workflow.add_edge("image_parser", "reasoning")
workflow.add_edge("calculator", "reasoning")
workflow.add_edge("document_retriever", "reasoning")
workflow.add_edge("reasoning", END)
# Compile with the in-memory checkpointer so runs are resumable per thread_id.
graph = workflow.compile(checkpointer=memory)