onisj's picture
Use free tools only, remove OpenAI dependency
488dc3e
raw
history blame
19.7 kB
import os
import gradio as gr
import requests
import aiohttp
import asyncio
import json
import nest_asyncio
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
from langchain_core.messages import SystemMessage, HumanMessage
from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool
from tools.search import initialize_search_tools
from state import JARVISState
import pandas as pd
from dotenv import load_dotenv
import logging
from langfuse.callback import CallbackHandler
# Logging: configure the root logger once, at import time, with timestamped
# INFO-level output.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Allow an event loop to be re-entered: BasicAgent.__call__ drives coroutines
# with loop.run_until_complete().
nest_asyncio.apply()
# Load variables from a .env file into os.environ.
load_dotenv()
# Fail fast at import time when a mandatory setting is missing (the first
# missing name, in list order, is reported).
required_env_vars = ["SPACE_ID", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
_first_missing_var = next((name for name in required_env_vars if not os.getenv(name)), None)
if _first_missing_var is not None:
    raise ValueError(f"Environment variable {_first_missing_var} is not set")
# Log only a 10-character prefix of SPACE_ID, plus the Langfuse host (or its default).
_space_id_prefix = os.getenv("SPACE_ID")[:10]
_langfuse_host_for_log = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
logger.info(f"Environment variables loaded: SPACE_ID={_space_id_prefix}..., LANGFUSE_HOST={_langfuse_host_for_log}")
# Initialize Hugging Face model.
# The previous id, "mistralai/Mixtral-7B-Instruct-v0.1", is not a published
# checkpoint: Mixtral is released only as 8x7B / 8x22B. Loading therefore
# always failed and llm silently fell back to None. Mistral-7B-Instruct matches
# the intended "7B" instruct model.
HF_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"
try:
    hf_pipeline = pipeline(
        "text-generation",
        model=HF_MODEL_ID,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        # Return only the completion. By default text-generation echoes the
        # prompt, which then leaked into the parsed tool lists and the answers.
        return_full_text=False,
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    logger.info(f"HuggingFace model initialized: {HF_MODEL_ID}")
except Exception as e:
    # Model loading is best-effort: the nodes check `if llm:` and fall back to
    # defaults when no model is available.
    logger.error(f"Failed to initialize HuggingFace model: {e}")
    llm = None
# Initialize search tools with LLM.
# The search tools receive the shared LLM (None when model loading failed).
# An initialization failure is logged and startup continues.
try:
    initialize_search_tools(llm)
except Exception as e:
    logger.error(f"Failed to initialize search tools: {e}")
else:
    logger.info("Search tools initialized")
# Initialize Langfuse.
# Tracing is optional. If the callback handler cannot be built, log a warning
# and fall back to langfuse = None; the nodes then pass an empty callback list.
try:
    langfuse = CallbackHandler(
        host=os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com"),
        public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
        secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    )
except Exception as e:
    logger.warning(f"Failed to initialize Langfuse: {e}")
    langfuse = None
else:
    logger.info("Langfuse initialized successfully")
# Checkpointing: LangGraph's in-memory checkpointer. BasicAgent keys each run by
# task_id (passed as the thread_id).
use_checkpointing = True
memory = MemorySaver()
# --- Constants ---
# Base URL of the course scoring API (questions, submission, task files).
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space/api"
# Task attachments are served by the same scoring API at
# GET {DEFAULT_API_URL}/files/{task_id}. The previous host,
# api.gaia-benchmark.com, does not serve them, so every download failed.
GAIA_FILE_URL = f"{DEFAULT_API_URL}/files/"
# --- Helper Functions ---
def log_state(task_id: str, state: JARVISState):
    """Append a snapshot of the agent state for ``task_id`` to ``state_log.json``.

    Each call appends one pretty-printed (indent=2) JSON object followed by a
    newline. Any error is logged and swallowed, so logging never aborts a
    graph node.

    Args:
        task_id: Identifier of the GAIA task the snapshot belongs to.
        state: Current agent state. The question, the tool list, every result
            field and the answer are recorded.
    """
    try:
        log_entry = {
            "task_id": task_id,
            "question": state["question"],
            "tools_needed": state["tools_needed"],
            "web_results": state["web_results"],
            "file_results": state["file_results"],
            "image_results": state["image_results"],
            "calculation_results": state["calculation_results"],
            "document_results": state["document_results"],
            "answer": state["answer"],
        }
        # Serialize fully before touching the file: json.dump() writes chunks
        # while it encodes, so an error midway used to leave a truncated object
        # in the log. default=str keeps entries whose tool results are not
        # natively JSON-serializable instead of dropping them.
        serialized = json.dumps(log_entry, indent=2, default=str)
        with open("state_log.json", "a", encoding="utf-8") as f:
            f.write(serialized + "\n")
    except Exception as e:
        logger.error(f"Error logging state for task {task_id}: {e}")
async def test_gaia_api(task_id: str) -> bool:
    """Check whether the GAIA file endpoint for ``task_id`` is reachable.

    Sends a HEAD request to ``GAIA_FILE_URL + task_id`` with a 5-second total
    timeout. A 200, 403, 404 or 405 response counts as reachable: the server
    answered, even if there is no file or HEAD is not routed. Any other status,
    or a network error (which is logged), returns False.
    """
    # 405: servers that route only GET on an endpoint (FastAPI apps, for
    # example) reject HEAD even though GET of the same URL works. Treating it
    # as unreachable disabled every file-based tool.
    reachable_statuses = {200, 403, 404, 405}
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.head(f"{GAIA_FILE_URL}{task_id}") as resp:
                return resp.status in reachable_statuses
    except Exception as e:
        logger.warning(f"GAIA API test failed: {e}")
        return False
# --- Node Functions ---
async def parse_question(state: JARVISState) -> JARVISState:
    """Ask the LLM which tools the question needs and store them in ``state["tools_needed"]``.

    The LLM is asked for a JSON list of tool names, and the first ``[...]`` span
    of its reply is parsed. Non-string entries, unknown names and duplicates are
    dropped; order is kept. Fallbacks:

    - An unparseable or non-list reply, or no configured LLM, uses
      ``["web_search"]``.
    - A list that names only unknown tools also uses ``["web_search"]``.
    - An explicit empty list yields ``[]``.
    - Any other error sets ``[]``, so the router goes straight to reasoning.

    The state is logged via ``log_state`` and returned (mutated in place).
    """
    import re  # local import keeps this node self-contained

    known_tools = (
        "web_search",
        "multi_hop_search",
        "file_parser",
        "image_parser",
        "calculator",
        "document_retriever",
    )
    default_tools = ["web_search"]
    try:
        question = state["question"]
        prompt = f"""Analyze this GAIA question: {question}
Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
Return a JSON list of tool names."""
        if llm:
            response = await llm.ainvoke(prompt, config={"callbacks": [langfuse] if langfuse else []})
            # HuggingFacePipeline is a completion LLM whose ainvoke returns a
            # plain str; chat models return a message with .content. Reading
            # .content unconditionally raised AttributeError, so the tool list
            # was always empty.
            text = str(getattr(response, "content", response))
            # Models often wrap the list in prose or a ```json fence, so parse
            # only the first [...] span when there is one.
            match = re.search(r"\[.*?\]", text, re.DOTALL)
            try:
                parsed = json.loads(match.group(0) if match else text)
            except json.JSONDecodeError as je:
                logger.warning(f"Invalid JSON in LLM response for task {state['task_id']}: {je}")
                parsed = default_tools
            if not isinstance(parsed, list):
                logger.warning(f"LLM response for task {state['task_id']} is not a JSON list; using default tools")
                parsed = default_tools
            tools_needed = list(
                dict.fromkeys(t for t in parsed if isinstance(t, str) and t in known_tools)
            )
            if parsed and not tools_needed:
                # The model named only tools that do not exist here; fall back
                # instead of silently running none.
                tools_needed = list(default_tools)
        else:
            logger.warning("No LLM available, using default tools")
            tools_needed = list(default_tools)
        state["tools_needed"] = tools_needed
        log_state(state["task_id"], state)
        return state
    except Exception as e:
        logger.error(f"Error parsing question for task {state['task_id']}: {e}")
        state["tools_needed"] = []
        log_state(state["task_id"], state)
        return state
async def tool_dispatcher(state: JARVISState) -> JARVISState:
    """Run every tool listed in ``state["tools_needed"]`` and merge the results.

    Works on a shallow copy of ``state`` that has its own ``web_results`` list.

    - Search results are appended to ``web_results``.
    - file_parser, image_parser and document_retriever run only when
      ``test_gaia_api`` reports the file endpoint reachable.
    - The file, image, calculator and document results replace their fields.
    - Each tool runs at most once, and the two search tools share one
      ``web_search_agent`` call.
    - A failing tool is logged and skipped.
    - If the dispatcher itself fails, the original state is returned.

    The resulting state is logged via ``log_state``.
    """
    search_tool_names = ("web_search", "multi_hop_search")
    try:
        tools_needed = state["tools_needed"]
        updated_state = state.copy()
        # state.copy() is shallow. Give the copy its own list so extend() below
        # cannot mutate the caller's state, which is what the error path returns.
        updated_state["web_results"] = list(updated_state.get("web_results") or [])
        can_download_files = await test_gaia_api(updated_state["task_id"])
        dispatched = set()
        for tool in tools_needed:
            if not isinstance(tool, str):
                continue  # malformed entry; no branch below could match it
            # web_search_agent runs every requested search kind in one call, so
            # both search names share one key. A repeated tool name runs once.
            dispatch_key = "search" if tool in search_tool_names else tool
            if dispatch_key in dispatched:
                continue
            dispatched.add(dispatch_key)
            try:
                if dispatch_key == "search":
                    result = await web_search_agent(updated_state)
                    updated_state["web_results"].extend(result["web_results"])
                elif tool == "file_parser" and can_download_files:
                    result = await file_parser_agent(updated_state)
                    updated_state["file_results"] = result["file_results"]
                elif tool == "image_parser" and can_download_files:
                    result = await image_parser_agent(updated_state)
                    updated_state["image_results"] = result["image_results"]
                elif tool == "calculator":
                    result = await calculator_agent(updated_state)
                    updated_state["calculation_results"] = result["calculation_results"]
                elif tool == "document_retriever" and can_download_files:
                    result = await document_retriever_agent(updated_state)
                    updated_state["document_results"] = result["document_results"]
            except Exception as e:
                logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {e}")
        log_state(updated_state["task_id"], updated_state)
        return updated_state
    except Exception as e:
        logger.error(f"Error in tool dispatcher for task {state['task_id']}: {e}")
        log_state(state["task_id"], state)
        return state
async def web_search_agent(state: JARVISState) -> JARVISState:
    """Run the search tools requested in ``state["tools_needed"]``.

    - ``web_search`` calls ``search_tool`` with ``{"query": question}``.
    - ``multi_hop_search`` calls ``multi_hop_search_tool`` with
      ``{"query": question, "steps": 3}``.

    Returns the partial update ``{"web_results": [...]}`` in that order. Any
    error is logged and yields an empty list.
    """
    import inspect  # local import keeps this block self-contained

    async def _invoke(tool, payload):
        # NOTE(review): the tool classes live in the ``tools`` package and are
        # not visible here. A LangChain-style tool has a synchronous
        # ``invoke`` (awaiting its plain result raises TypeError), while a
        # custom async tool returns a coroutine. Await only when needed, so the
        # original async path is unchanged.
        outcome = tool.invoke(payload)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome

    try:
        results = []
        if "web_search" in state["tools_needed"]:
            results.append(await _invoke(search_tool, {"query": state["question"]}))
        if "multi_hop_search" in state["tools_needed"]:
            results.append(
                await _invoke(multi_hop_search_tool, {"query": state["question"], "steps": 3})
            )
        return {"web_results": results}
    except Exception as e:
        logger.error(f"Error in web search for task {state['task_id']}: {e}")
        return {"web_results": []}
async def file_parser_agent(state: JARVISState) -> JARVISState:
    """Parse the task's attached file when ``file_parser`` was requested.

    The file type passed to ``file_parser_tool.aparse`` is guessed from the
    question: ``"csv"`` if it mentions "data", otherwise ``"txt"``.

    Returns ``{"file_results": ...}`` holding one of:

    - the tool's result;
    - ``""`` when the tool was not requested;
    - ``"File parsing failed"`` after a logged error.
    """
    try:
        if "file_parser" not in state["tools_needed"]:
            return {"file_results": ""}
        mentions_data = "data" in state["question"].lower()
        parsed = await file_parser_tool.aparse(
            state["task_id"],
            file_type="csv" if mentions_data else "txt",
        )
        return {"file_results": parsed}
    except Exception as e:
        logger.error(f"Error in file parser for task {state['task_id']}: {e}")
        return {"file_results": "File parsing failed"}
async def image_parser_agent(state: JARVISState) -> JARVISState:
    """Analyse ``temp_<task_id>.jpg`` when ``image_parser`` was requested.

    If the question mentions "fruits", the tool runs with ``task="match"`` and
    ``match_query="fruits"``; otherwise it runs with ``task="describe"`` and an
    empty query.

    Returns ``{"image_results": ...}`` holding one of:

    - the tool's result;
    - ``""`` when the tool was not requested;
    - ``"Image file not found"`` when the file is absent;
    - ``"Image parsing failed"`` after a logged error.
    """
    try:
        if "image_parser" not in state["tools_needed"]:
            return {"image_results": ""}
        wants_match = "fruits" in state["question"].lower()
        image_path = f"temp_{state['task_id']}.jpg"
        if not os.path.exists(image_path):
            logger.warning(f"Image file not found for task {state['task_id']}")
            return {"image_results": "Image file not found"}
        parsed = await image_parser_tool.aparse(
            image_path,
            task="match" if wants_match else "describe",
            match_query="fruits" if wants_match else "",
        )
        return {"image_results": parsed}
    except Exception as e:
        logger.error(f"Error in image parser for task {state['task_id']}: {e}")
        return {"image_results": "Image parsing failed"}
async def calculator_agent(state: JARVISState) -> JARVISState:
    """Evaluate a math expression extracted from the question when ``calculator`` was requested.

    The LLM is asked to extract an expression from the question plus
    ``file_results``; without an LLM the expression is ``"0"``. The expression
    is passed to ``calculator_tool.aparse``.

    Returns ``{"calculation_results": ...}`` holding one of:

    - the tool's result;
    - ``""`` when the tool was not requested;
    - ``"Calculation failed"`` after a logged error.
    """
    try:
        if "calculator" not in state["tools_needed"]:
            return {"calculation_results": ""}
        prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
        if llm:
            response = await llm.ainvoke(prompt, config={"callbacks": [langfuse] if langfuse else []})
            # HuggingFacePipeline returns a plain str from ainvoke; chat models
            # return a message with .content. Reading .content unconditionally
            # raised AttributeError, so the calculator never ran.
            expression = str(getattr(response, "content", response)).strip()
        else:
            expression = "0"
        result = await calculator_tool.aparse(expression)
        return {"calculation_results": result}
    except Exception as e:
        logger.error(f"Error in calculator for task {state['task_id']}: {e}")
        return {"calculation_results": "Calculation failed"}
async def document_retriever_agent(state: JARVISState) -> JARVISState:
    """Retrieve from the task's document when ``document_retriever`` was requested.

    The file type is guessed from the question, first match wins:

    - ``"pdf"`` if it mentions "report" or "document";
    - ``"txt"`` if it mentions "menu";
    - ``"csv"`` otherwise.

    Returns ``{"document_results": ...}`` holding one of:

    - the tool's result;
    - ``""`` when the tool was not requested;
    - ``"Document retrieval failed"`` after a logged error.
    """
    try:
        if "document_retriever" not in state["tools_needed"]:
            return {"document_results": ""}
        question_lower = state["question"].lower()
        if "report" in question_lower or "document" in question_lower:
            file_type = "pdf"
        elif "menu" in question_lower:
            file_type = "txt"
        else:
            file_type = "csv"
        retrieved = await document_retriever_tool.aparse(
            state["task_id"], state["question"], file_type=file_type
        )
        return {"document_results": retrieved}
    except Exception as e:
        logger.error(f"Error in document retriever for task {state['task_id']}: {e}")
        return {"document_results": "Document retrieval failed"}
async def reasoning_agent(state: JARVISState) -> JARVISState:
    """Synthesize the final exact-match answer from the question and all tool results.

    The LLM gets a system instruction plus a prompt containing every result
    field, and its stripped reply becomes ``state["answer"]``. The answer is
    ``"Unknown"`` when no LLM is configured, and ``"Error in reasoning"`` after
    a logged error. The state is logged via ``log_state`` and returned
    (mutated in place).
    """
    try:
        prompt = f"""Question: {state['question']}
Web Results: {state['web_results']}
File Results: {state['file_results']}
Image Results: {state['image_results']}
Calculation Results: {state['calculation_results']}
Document Results: {state['document_results']}
Synthesize an exact-match answer for the GAIA benchmark.
Output only the answer (e.g., '90', 'White;5876')."""
        if llm:
            response = await llm.ainvoke(
                [
                    SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
                    HumanMessage(content=prompt),
                ],
                config={"callbacks": [langfuse] if langfuse else []},
            )
            # HuggingFacePipeline (a completion LLM) returns a plain str from
            # ainvoke; chat models return a message with .content. Reading
            # .content unconditionally raised AttributeError, so every run
            # ended as "Error in reasoning".
            answer = str(getattr(response, "content", response)).strip()
        else:
            answer = "Unknown"
        state["answer"] = answer
        log_state(state["task_id"], state)
        return state
    except Exception as e:
        logger.error(f"Error in reasoning for task {state['task_id']}: {e}")
        state["answer"] = "Error in reasoning"
        log_state(state["task_id"], state)
        return state
def router(state: JARVISState) -> str:
    """Choose the node after ``parse``: the dispatcher if any tool is needed, else reasoning."""
    return "tool_dispatcher" if state["tools_needed"] else "reasoning"
# --- Define StateGraph ---
# Flow: parse -> router -> (tool_dispatcher ->) reasoning -> END
workflow = StateGraph(JARVISState)
_GRAPH_NODES = {
    "parse": parse_question,
    "tool_dispatcher": tool_dispatcher,
    "reasoning": reasoning_agent,
}
for _node_name, _node_fn in _GRAPH_NODES.items():
    workflow.add_node(_node_name, _node_fn)
workflow.set_entry_point("parse")
# router() returns a node name directly, so the path map is the identity.
workflow.add_conditional_edges(
    "parse",
    router,
    {target: target for target in ("tool_dispatcher", "reasoning")},
)
workflow.add_edge("tool_dispatcher", "reasoning")
workflow.add_edge("reasoning", END)
# Compile graph; attach the checkpointer only when checkpointing is enabled.
graph = workflow.compile(checkpointer=memory if use_checkpointing else None)
# --- Basic Agent Definition ---
class BasicAgent:
    """Run the compiled JARVIS graph on GAIA questions.

    ``__call__`` is a synchronous wrapper around the async ``process_question``,
    usable from plain (non-async) callers such as ``run_and_submit_all``.
    """

    # Total seconds allowed for downloading one task attachment. The original
    # download used aiohttp's default session timeout.
    download_timeout_s = 60

    def __init__(self):
        logger.info("BasicAgent initialized.")

    @staticmethod
    def _guess_file_type(question: str) -> str:
        """Guess the attachment's file extension from keywords in ``question``."""
        lowered = question.lower()
        if "menu" in lowered or "report" in lowered or "document" in lowered:
            return "pdf"
        if "data" in lowered:
            return "csv"
        return "jpg" if "image" in lowered else "txt"

    async def _download_task_file(self, task_id: str, file_path: str) -> None:
        """Download the attachment for ``task_id`` to ``file_path``.

        Nothing is downloaded when ``test_gaia_api`` reports the endpoint
        unreachable. Non-200 responses and network errors are logged, never
        raised.
        """
        if not await test_gaia_api(task_id):
            return
        try:
            timeout = aiohttp.ClientTimeout(total=self.download_timeout_s)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(f"{GAIA_FILE_URL}{task_id}") as resp:
                    if resp.status != 200:
                        logger.warning(f"Failed to download file for task {task_id}: HTTP {resp.status}")
                        return
                    # Read the whole body before creating the file, so a failed
                    # read cannot leave an empty file behind for the parsers.
                    payload = await resp.read()
            with open(file_path, "wb") as f:
                f.write(payload)
        except Exception as e:
            logger.error(f"Error downloading file for task {task_id}: {e}")

    async def process_question(self, task_id: str, question: str) -> str:
        """Answer one GAIA task.

        Steps:

        1. Download the task attachment (when reachable) to
           ``temp_<task_id>.<ext>``.
        2. Run ``graph`` on a fresh state.
        3. Return the answer, ``"No answer generated"`` for an empty answer,
           or ``"Error: ..."`` if the graph raises.

        The temporary file is always removed afterwards.
        """
        file_path = f"temp_{task_id}.{self._guess_file_type(question)}"
        try:
            # Inside the try so the finally below also cleans up after a download.
            await self._download_task_file(task_id, file_path)
            state = JARVISState(
                task_id=task_id,
                question=question,
                tools_needed=[],
                web_results=[],
                file_results="",
                image_results="",
                calculation_results="",
                document_results="",
                messages=[],
                answer="",
            )
            config = {"configurable": {"thread_id": task_id}} if use_checkpointing else {}
            result = await graph.ainvoke(state, config=config)
            return result["answer"] or "No answer generated"
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {e}")
            return f"Error: {str(e)}"
        finally:
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    logger.error(f"Error removing file {file_path}: {e}")

    async def async_call(self, question: str, task_id: str) -> str:
        """Async entry point with the ``(question, task_id)`` argument order of ``__call__``."""
        return await self.process_question(task_id, question)

    def __call__(self, question: str, task_id: str | None = None) -> str:
        """Synchronously answer ``question``; a placeholder id is used if ``task_id`` is None."""
        logger.info(f"Agent received question (first 50 chars): {question[:50]}...")
        if task_id is None:
            logger.warning("task_id not provided, using placeholder")
            task_id = "placeholder_task_id"
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current event loop in this thread (e.g. a non-main thread):
            # create one and register it.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        # nest_asyncio (applied at import) lets this run even if the loop is
        # already running.
        return loop.run_until_complete(self.async_call(question, task_id))
# --- Main Function ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, answer each with ``BasicAgent``, and submit the answers.

    Args:
        profile: Hugging Face OAuth profile of the logged-in user, or None.

    Returns:
        A ``(status_message, results)`` tuple. ``results`` is a DataFrame of
        per-task answers, or None when the run stops before any question is
        attempted.
    """
    space_id = os.getenv("SPACE_ID")
    if not profile:
        logger.error("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"
    logger.info(f"User logged in: {username}")

    questions_url = f"{DEFAULT_API_URL}/questions"
    submit_url = f"{DEFAULT_API_URL}/submit"
    # Link to this Space's source, sent along with the submission.
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

    try:
        agent = BasicAgent()
    except Exception as e:
        logger.error(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    logger.info(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if questions_data:
            logger.info(f"Fetched {len(questions_data)} questions.")
    except Exception as e:
        logger.error(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    if not questions_data:
        logger.error("Fetched questions list is empty.")
        return "Fetched questions list is empty or invalid format.", None

    results_log = []
    answers_payload = []
    logger.info(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id, question_text = item.get("task_id"), item.get("question")
        if not task_id or question_text is None:
            logger.warning(f"Skipping item with missing task_id or question: {item}")
            continue
        try:
            submitted_answer = agent(question_text, task_id)
        except Exception as e:
            logger.error(f"Error running agent on task {task_id}: {e}")
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
            continue
        answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
        results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})

    if not answers_payload:
        logger.error("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=120)
        response.raise_for_status()
        result_data = response.json()
        logger.info(f"Server response: {result_data}")
        final_status = "".join(
            [
                "Submission Successful!\n",
                f"User: {result_data.get('username')}\n",
                f"Overall Score: {result_data.get('score', 'N/A')}% ",
                f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n",
                f"Message: {result_data.get('message', 'No message received.')}",
            ]
        )
        return final_status, pd.DataFrame(results_log)
    except Exception as e:
        logger.error(f"Submission failed: {e}")
        return f"Submission Failed: {e}", pd.DataFrame(results_log)
# --- Build Gradio Interface ---
# Instructions shown under the title (Markdown).
_INSTRUCTIONS_MD = """
**Instructions:**
1. Log in to your Hugging Face account using the button below.
2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run the JARVIS agent, and submit answers.
---
**Disclaimers:**
The agent uses a local Hugging Face model (Mixtral-7B) and async tools for the GAIA benchmark.
"""

# Components are laid out in creation order: title, instructions, login,
# run button, status box, results table.
with gr.Blocks() as demo:
    gr.Markdown("# JARVIS Agent Evaluation Runner")
    gr.Markdown(_INSTRUCTIONS_MD)
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
    # run_and_submit_all returns (status text, DataFrame) for these two outputs.
    run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
if __name__ == "__main__":
    # Startup banner, the configured Space id, then launch the UI.
    _banner = f"\n{'-' * 30} App Starting {'-' * 30}"
    logger.info(_banner)
    logger.info(f"SPACE_ID: {os.getenv('SPACE_ID')}")
    logger.info("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)