# graph.py — LangGraph workflow wiring the JARVIS tool agents for the GAIA benchmark.
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from state import JARVISState
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool
from langfuse.callback import LangfuseCallbackHandler
import json
import os
from dotenv import load_dotenv
# Load environment variables from a local .env file before reading any keys.
load_dotenv()

# Debug: verify that the required environment variables were picked up.
print(f"OPENAI_API_KEY loaded in graph.py: {'set' if os.getenv('OPENAI_API_KEY') else 'not set'}")
print(f"LANGFUSE_PUBLIC_KEY loaded in graph.py: {'set' if os.getenv('LANGFUSE_PUBLIC_KEY') else 'not set'}")

# Initialize the LLM and the Langfuse tracing callback.
# Fail fast with a clear message if the OpenAI key is missing.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")
llm = ChatOpenAI(model="gpt-4o", api_key=api_key)
# NOTE(review): depending on the installed langfuse version this class may be
# exported as CallbackHandler rather than LangfuseCallbackHandler — confirm
# against the langfuse SDK in use.
langfuse = LangfuseCallbackHandler(
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    host=os.getenv("LANGFUSE_HOST"),
)
# In-memory checkpointer used when compiling the graph below.
memory = MemorySaver()
# Question Parser Node
async def parse_question(state: JARVISState) -> JARVISState:
    """Ask the LLM which tools the question needs.

    Returns a partial state update: the raw LLM message appended to
    ``messages`` and ``tools_needed`` set to the parsed JSON list of tool
    names (empty list when the response cannot be parsed).
    """
    question = state["question"]
    prompt = f"""Analyze this GAIA question: {question}
Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
Return a JSON list of tool names."""
    response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
    # LLMs frequently wrap JSON in markdown code fences; strip them so
    # json.loads does not raise on an otherwise valid answer.
    content = response.content.strip()
    if content.startswith("```"):
        content = content.strip("`").strip()
        if content.lower().startswith("json"):
            content = content[4:].strip()
    try:
        tools_needed = json.loads(content)
    except json.JSONDecodeError:
        # Degrade to "no tools" rather than crashing the whole graph run;
        # the router then sends the question straight to reasoning.
        tools_needed = []
    return {"messages": state["messages"] + [response], "tools_needed": tools_needed}
# Web Search Agent Node
async def web_search_agent(state: JARVISState) -> JARVISState:
    """Run the web-search tools requested in ``tools_needed``.

    Returns a partial state update with ``web_results`` holding the raw
    tool outputs in order (empty list when neither search was requested).
    """
    gathered = []
    requested = state["tools_needed"]
    if "web_search" in requested:
        gathered.append(await search_tool.arun(state["question"]))
    if "multi_hop_search" in requested:
        # NOTE(review): this tool is invoked via .aparse while search_tool
        # uses .arun — confirm against tools.py that both methods exist.
        gathered.append(await multi_hop_search_tool.aparse(state["question"], steps=3))
    return {"web_results": gathered}
# File Parser Agent Node
async def file_parser_agent(state: JARVISState) -> JARVISState:
    """Parse the task's attached file when the file parser was requested.

    Returns ``file_results`` — the tool output, or "" when skipped.
    """
    if "file_parser" not in state["tools_needed"]:
        return {"file_results": ""}
    parsed = await file_parser_tool.aparse(state["task_id"])
    return {"file_results": parsed}
# Image Parser Agent Node
async def image_parser_agent(state: JARVISState) -> JARVISState:
    """Analyze the task image when the image parser was requested.

    "fruits" questions use match mode with a "fruits" query; every other
    question asks for a description. Returns ``image_results`` ("" when
    skipped).
    """
    if "image_parser" not in state["tools_needed"]:
        return {"image_results": ""}
    wants_match = "fruits" in state["question"].lower()
    result = await image_parser_tool.aparse(
        f"temp_{state['task_id']}.jpg",
        task="match" if wants_match else "describe",
        match_query="fruits" if wants_match else "",
    )
    return {"image_results": result}
# Calculator Agent Node
async def calculator_agent(state: JARVISState) -> JARVISState:
    """Extract a math expression via the LLM and evaluate it.

    Returns ``calculation_results`` — the calculator output, or "" when
    the calculator was not requested.
    """
    if "calculator" not in state["tools_needed"]:
        return {"calculation_results": ""}
    # file_results may be absent: this node runs in parallel with the file
    # parser node, so index access could raise KeyError — use .get().
    file_context = state.get("file_results", "")
    prompt = f"Extract a mathematical expression from: {state['question']}\n{file_context}"
    response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
    expression = response.content
    result = await calculator_tool.aparse(expression)
    return {"calculation_results": result}
# Document Retriever Agent Node
async def document_retriever_agent(state: JARVISState) -> JARVISState:
    """Retrieve document content when the retriever was requested.

    The file type is guessed from the question wording: "report"/"document"
    → pdf, otherwise "menu" → txt, otherwise csv. Returns
    ``document_results`` ("" when skipped).
    """
    if "document_retriever" not in state["tools_needed"]:
        return {"document_results": ""}
    question_lc = state["question"].lower()
    if "report" in question_lc or "document" in question_lc:
        file_type = "pdf"
    elif "menu" in question_lc:
        file_type = "txt"
    else:
        file_type = "csv"
    result = await document_retriever_tool.aparse(
        state["task_id"], state["question"], file_type=file_type
    )
    return {"document_results": result}
# Reasoning Agent Node
async def reasoning_agent(state: JARVISState) -> JARVISState:
    """Synthesize the final exact-match GAIA answer from gathered evidence.

    Returns a partial state update with ``answer`` set to the LLM's reply
    and the reply appended to ``messages``.
    """
    # Tool-result keys are never written when the router sends the question
    # straight to this node, so read them with .get() to avoid KeyError.
    prompt = f"""Question: {state['question']}
Web Results: {state.get('web_results', [])}
File Results: {state.get('file_results', '')}
Image Results: {state.get('image_results', '')}
Calculation Results: {state.get('calculation_results', '')}
Document Results: {state.get('document_results', '')}
Synthesize an exact-match answer for the GAIA benchmark.
Output only the answer (e.g., '90', 'White;5876')."""
    response = await llm.ainvoke(
        [
            SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
            HumanMessage(content=prompt)
        ],
        config={"callbacks": [langfuse]}
    )
    return {"answer": response.content, "messages": state["messages"] + [response]}
# Conditional Edge Router
def router(state: JARVISState) -> str:
    """Return "tools" when any tool was requested, else "reasoning"."""
    return "tools" if state["tools_needed"] else "reasoning"
# Build Graph: parse -> (all tool nodes in parallel | reasoning) -> reasoning -> END
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("web_search", web_search_agent)
workflow.add_node("file_parser", file_parser_agent)
workflow.add_node("image_parser", image_parser_agent)
workflow.add_node("calculator", calculator_agent)
workflow.add_node("document_retriever", document_retriever_agent)
workflow.add_node("reasoning", reasoning_agent)
workflow.set_entry_point("parse")
workflow.add_conditional_edges(
    "parse",
    router,
    {
        # "tools" fans out to every tool node; each node no-ops (returns its
        # empty default) when its tool is not in tools_needed.
        # NOTE(review): a list as a mapping value requires a LangGraph
        # version that supports fan-out targets — confirm against the
        # installed langgraph release.
        "tools": ["web_search", "file_parser", "image_parser", "calculator", "document_retriever"],
        "reasoning": "reasoning",
    },
)
workflow.add_edge("web_search", "reasoning")
workflow.add_edge("file_parser", "reasoning")
workflow.add_edge("image_parser", "reasoning")
workflow.add_edge("calculator", "reasoning")
workflow.add_edge("document_retriever", "reasoning")
workflow.add_edge("reasoning", END)

# Compile with the in-memory checkpointer so runs can resume by thread id.
graph = workflow.compile(checkpointer=memory)