File size: 5,934 Bytes
1bbca12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from state import JARVISState
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool
from langfuse.callback import LangfuseCallbackHandler
import json
import os
from dotenv import load_dotenv

# Load environment variables from a local .env file before reading any keys.
load_dotenv()
# Debug: Verify environment variables are present (values themselves are never printed).
print(f"OPENAI_API_KEY loaded in graph.py: {'set' if os.getenv('OPENAI_API_KEY') else 'not set'}")
print(f"LANGFUSE_PUBLIC_KEY loaded in graph.py: {'set' if os.getenv('LANGFUSE_PUBLIC_KEY') else 'not set'}")

# Initialize LLM and Langfuse tracing at import time.
api_key = os.getenv("OPENAI_API_KEY")
# Fail fast here rather than on the first LLM call deep inside a graph run.
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")
llm = ChatOpenAI(model="gpt-4o", api_key=api_key)
# Langfuse callback handler used for tracing every llm.ainvoke in this module.
# NOTE(review): unlike the OpenAI key, missing Langfuse env vars are not
# validated here — confirm the handler tolerates None credentials.
langfuse = LangfuseCallbackHandler(
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    host=os.getenv("LANGFUSE_HOST")
)
# In-memory checkpointer for the compiled graph; state is lost on process restart.
memory = MemorySaver()

# Question Parser Node
def _extract_tool_names(raw: str) -> list:
    """Recover a JSON list of tool-name strings from an LLM reply.

    LLMs frequently wrap JSON in markdown code fences or surround it with
    prose, so isolate the first ``[...]`` span before parsing instead of
    feeding the raw content to ``json.loads``.

    Raises:
        ValueError: if no JSON array of strings can be recovered.
    """
    start = raw.find("[")
    end = raw.rfind("]")
    if start == -1 or end < start:
        raise ValueError(f"No JSON list found in LLM response: {raw!r}")
    parsed = json.loads(raw[start:end + 1])
    if not isinstance(parsed, list) or not all(isinstance(name, str) for name in parsed):
        raise ValueError(f"Expected a JSON list of tool names, got: {parsed!r}")
    return parsed


async def parse_question(state: JARVISState) -> JARVISState:
    """Ask the LLM which tools the question needs.

    Stores the resulting list under ``tools_needed``; each downstream agent
    node checks membership in that list before doing any work.
    """
    question = state["question"]
    prompt = f"""Analyze this GAIA question: {question}
    Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
    Return a JSON list of tool names."""
    response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
    # json.loads(response.content) directly would crash on fenced/prose-wrapped
    # output; the helper extracts and validates the embedded array.
    tools_needed = _extract_tool_names(response.content)
    return {"messages": state["messages"] + [response], "tools_needed": tools_needed}

# Web Search Agent Node
async def web_search_agent(state: JARVISState) -> JARVISState:
    """Run whichever web-search tools the parser requested.

    Appends the single-shot search result and/or the 3-step multi-hop
    search result, and returns them under ``web_results``.
    """
    requested = state["tools_needed"]
    gathered = []
    if "web_search" in requested:
        gathered.append(await search_tool.arun(state["question"]))
    if "multi_hop_search" in requested:
        gathered.append(await multi_hop_search_tool.aparse(state["question"], steps=3))
    return {"web_results": gathered}

# File Parser Agent Node
async def file_parser_agent(state: JARVISState) -> JARVISState:
    """Parse the task's attached file when the parser asked for it.

    Returns an empty string under ``file_results`` when the tool was
    not requested.
    """
    if "file_parser" not in state["tools_needed"]:
        return {"file_results": ""}
    parsed = await file_parser_tool.aparse(state["task_id"])
    return {"file_results": parsed}

# Image Parser Agent Node
async def image_parser_agent(state: JARVISState) -> JARVISState:
    """Analyze the task's image when requested.

    Heuristic: a question mentioning "fruits" runs the tool in matching
    mode against that keyword; anything else gets a free-form description.
    The image is expected at a temp path derived from the task id.
    """
    if "image_parser" not in state["tools_needed"]:
        return {"image_results": ""}
    wants_match = "fruits" in state["question"].lower()
    mode = "match" if wants_match else "describe"
    query = "fruits" if wants_match else ""
    analyzed = await image_parser_tool.aparse(
        f"temp_{state['task_id']}.jpg", task=mode, match_query=query
    )
    return {"image_results": analyzed}

# Calculator Agent Node
async def calculator_agent(state: JARVISState) -> JARVISState:
    """Extract a math expression via the LLM, then evaluate it.

    The extraction prompt includes both the question and any parsed file
    text, so numbers found in attachments can feed the calculation.
    """
    if "calculator" not in state["tools_needed"]:
        return {"calculation_results": ""}
    prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
    extraction = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
    value = await calculator_tool.aparse(extraction.content)
    return {"calculation_results": value}

# Document Retriever Agent Node
async def document_retriever_agent(state: JARVISState) -> JARVISState:
    """Query the task's document when requested.

    File type is guessed from question keywords: "report"/"document"
    wins with pdf, then "menu" selects txt, and everything else
    defaults to csv.
    """
    if "document_retriever" not in state["tools_needed"]:
        return {"document_results": ""}
    lowered = state["question"].lower()
    if "report" in lowered or "document" in lowered:
        guessed_type = "pdf"
    elif "menu" in lowered:
        guessed_type = "txt"
    else:
        guessed_type = "csv"
    retrieved = await document_retriever_tool.aparse(
        state["task_id"], state["question"], file_type=guessed_type
    )
    return {"document_results": retrieved}

# Reasoning Agent Node
async def reasoning_agent(state: JARVISState) -> JARVISState:
    """Fuse every agent's evidence into one exact-match GAIA answer.

    Sends all gathered results to the LLM with a strict system prompt and
    records both the answer string and the reply message in state.
    """
    prompt = f"""Question: {state['question']}
    Web Results: {state['web_results']}
    File Results: {state['file_results']}
    Image Results: {state['image_results']}
    Calculation Results: {state['calculation_results']}
    Document Results: {state['document_results']}

    Synthesize an exact-match answer for the GAIA benchmark.
    Output only the answer (e.g., '90', 'White;5876')."""
    conversation = [
        SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
        HumanMessage(content=prompt),
    ]
    reply = await llm.ainvoke(conversation, config={"callbacks": [langfuse]})
    return {"answer": reply.content, "messages": state["messages"] + [reply]}

# Conditional Edge Router
def router(state: JARVISState) -> str:
    """Route after parsing: fan out to the tool nodes or skip to reasoning.

    Returns:
        "tools" when the parser selected at least one tool,
        "reasoning" otherwise.
    """
    # .get guards against the key being absent entirely (e.g. a parse
    # failure left state partially populated) instead of raising KeyError.
    return "tools" if state.get("tools_needed") else "reasoning"

# Build Graph: parse -> (fan out to all tool agents | straight to reasoning) -> END
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("web_search", web_search_agent)
workflow.add_node("file_parser", file_parser_agent)
workflow.add_node("image_parser", image_parser_agent)
workflow.add_node("calculator", calculator_agent)
workflow.add_node("document_retriever", document_retriever_agent)
workflow.add_node("reasoning", reasoning_agent)

workflow.set_entry_point("parse")
# The "tools" branch maps to a list of nodes, i.e. every tool agent runs;
# each agent internally no-ops unless its name is in tools_needed.
# NOTE(review): list-valued path maps in add_conditional_edges are
# version-sensitive in langgraph — confirm the pinned version supports this.
workflow.add_conditional_edges(
    "parse",
    router,
    {
        "tools": ["web_search", "file_parser", "image_parser", "calculator", "document_retriever"],
        "reasoning": "reasoning"
    }
)
# All tool agents converge on the reasoning node, which synthesizes the answer.
workflow.add_edge("web_search", "reasoning")
workflow.add_edge("file_parser", "reasoning")
workflow.add_edge("image_parser", "reasoning")
workflow.add_edge("calculator", "reasoning")
workflow.add_edge("document_retriever", "reasoning")
workflow.add_edge("reasoning", END)

# Compile with the in-memory checkpointer so runs are resumable per thread_id.
graph = workflow.compile(checkpointer=memory)