onisj's picture
docs(readme): readme updated
f95c630
raw
history blame
25.9 kB
import os
import json
import logging
import asyncio
import aiohttp
import nest_asyncio
import requests
import pandas as pd
from typing import Dict, Any, List
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, END
from sentence_transformers import SentenceTransformer
import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import together
from state import JARVISState
from tools import (
search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool,
calculator_tool, document_retriever_tool, duckduckgo_search_tool,
weather_info_tool, hub_stats_tool, guest_info_retriever_tool
)
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Apply nest_asyncio
nest_asyncio.apply()
# Load environment variables
load_dotenv()
SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
GAIA_API_URL = "https://agents-course-unit4-scoring.hf.space"
GAIA_FILE_URL = f"{GAIA_API_URL}/files/"
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
# Verify environment variables
if not SPACE_ID:
raise ValueError("SPACE_ID not set")
if not HF_API_TOKEN:
raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
if not TOGETHER_API_KEY:
raise ValueError("TOGETHER_API_KEY not set")
logger.info(f"SPACE_ID: {SPACE_ID}")
# Model configuration
TOGETHER_MODELS = [
"meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
]
HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# Initialize LLM clients
def initialize_llm():
# Try Together AI models
for model in TOGETHER_MODELS:
try:
together.api_key = TOGETHER_API_KEY
client = together.Together()
# Test the model
response = client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": "Test"}],
max_tokens=10
)
logger.info(f"Initialized Together AI model: {model}")
return client, "together"
except Exception as e:
logger.warning(f"Failed to initialize Together AI model {model}: {e}")
# Fallback to Hugging Face Inference API
try:
client = InferenceClient(
model=HF_MODEL,
token=HF_API_TOKEN,
timeout=30
)
logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
return client, "hf_api"
except Exception as e:
logger.warning(f"Failed to initialize HF Inference API: {e}")
# Fallback to local Hugging Face model
try:
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
return (model, tokenizer), "hf_local"
except Exception as e:
logger.error(f"Failed to initialize local HF model: {e}")
raise Exception("No LLM could be initialized")
llm_client, llm_type = initialize_llm()
# Initialize embedder
try:
embedder = SentenceTransformer("all-MiniLM-L6-v2")
logger.info("Sentence transformer initialized")
except Exception as e:
logger.error(f"Failed to initialize embedder: {e}")
embedder = None
# Download file with local fallback
async def download_file(task_id: str, ext: str) -> str | None:
try:
url = f"{GAIA_FILE_URL}{task_id}.{ext}"
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=10) as resp:
logger.info(f"GAIA API test for task {task_id} with .{ext}: HTTP {resp.status}")
if resp.status == 200:
os.makedirs("temp", exist_ok=True)
file_path = f"temp/{task_id}.{ext}"
with open(file_path, "wb") as f:
f.write(await resp.read())
return file_path
except Exception as e:
logger.warning(f"File download failed for {task_id}.{ext}: {e}")
local_path = f"temp/{task_id}.{ext}"
if os.path.exists(local_path):
logger.info(f"Using local file: {local_path}")
return local_path
return None
# Parse question to select tools
async def parse_question(state: JARVISState) -> JARVISState:
try:
question = state["question"]
task_id = state["task_id"]
tools_needed = ["search_tool"]
if llm_client:
prompt = ChatPromptTemplate.from_messages([
SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
Return JSON list, e.g., ["search_tool", "file_parser_tool"].
Rules:
- Always include "search_tool" unless purely computational.
- Use "multi_hop_search_tool" for complex queries (over 20 words or requiring multiple steps).
- Use "file_parser_tool" for data, tables, or Excel.
- Use "image_parser_tool" for images/videos.
- Use "calculator_tool" for math calculations.
- Use "document_retriever_tool" for documents/PDFs.
- Use "duckduckgo_search_tool" for additional search capability.
- Use "weather_info_tool" for weather-related queries.
- Use "hub_stats_tool" for Hugging Face Hub queries.
- Use "guest_info_retriever_tool" for guest-related queries.
- Output ONLY valid JSON."""),
HumanMessage(content=f"Query: {question}")
])
try:
if llm_type == "hf_local":
model, tokenizer = llm_client
inputs = tokenizer.apply_chat_template(
[{"role": "system", "content": prompt[0].content}, {"role": "user", "content": prompt[1].content}],
return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
tools_needed = json.loads(response.strip())
elif llm_type == "together":
response = llm_client.chat.completions.create(
model=llm_client.model,
messages=[
{"role": "system", "content": prompt[0].content},
{"role": "user", "content": prompt[1].content}
],
max_tokens=512,
temperature=0.7
)
tools_needed = json.loads(response.choices[0].message.content.strip())
else: # hf_api
response = llm_client.chat.completions.create(
model=HF_MODEL,
messages=[
{"role": "system", "content": prompt[0].content},
{"role": "user", "content": prompt[1].content}
],
max_tokens=512,
temperature=0.7
)
tools_needed = json.loads(response.choices[0].message.content.strip())
valid_tools = {
"search_tool", "multi_hop_search_tool", "file_parser_tool", "image_parser_tool",
"calculator_tool", "document_retriever_tool", "duckduckgo_search_tool",
"weather_info_tool", "hub_stats_tool", "guest_info_retriever_tool"
}
tools_needed = [tool for tool in tools_needed if tool in valid_tools]
except Exception as e:
logger.warning(f"Task {task_id} tool selection failed: {e}")
state["error"] = f"Tool selection failed: {str(e)}"
# Keyword-based fallback
question_lower = question.lower()
if any(word in question_lower for word in ["image", "video", "picture"]):
tools_needed.append("image_parser_tool")
if any(word in question_lower for word in ["data", "table", "excel", ".txt", ".csv", ".xlsx"]):
tools_needed.append("file_parser_tool")
if any(word in question_lower for word in ["calculate", "math", "sum", "average", "total"]):
tools_needed.append("calculator_tool")
if any(word in question_lower for word in ["document", "pdf", "report", "menu"]):
tools_needed.append("document_retriever_tool")
if any(word in question_lower for word in ["weather", "temperature"]):
tools_needed.append("weather_info_tool")
if any(word in question_lower for word in ["model", "huggingface", "dataset"]):
tools_needed.append("hub_stats_tool")
if any(word in question_lower for word in ["guest", "name", "relation", "person"]):
tools_needed.append("guest_info_retriever_tool")
if len(question.split()) > 20 or "multiple" in question_lower:
tools_needed.append("multi_hop_search_tool")
if any(word in question_lower for word in ["search", "wikipedia", "online"]):
tools_needed.append("duckduckgo_search_tool")
# Check file availability
for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
file_path = await download_file(task_id, ext)
if file_path:
if ext in ["txt", "csv", "xlsx"] and "file_parser_tool" not in tools_needed:
tools_needed.append("file_parser_tool")
if ext == "jpg" and "image_parser_tool" not in tools_needed:
tools_needed.append("image_parser_tool")
if ext == "pdf" and "document_retriever_tool" not in tools_needed:
tools_needed.append("document_retriever_tool")
state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
break
state["tools_needed"] = list(set(tools_needed))
logger.info(f"Task {task_id}: Selected tools: {tools_needed}")
return state
except Exception as e:
logger.error(f"Error parsing task {task_id}: {e}")
state["error"] = f"Parse question failed: {str(e)}"
state["tools_needed"] = ["search_tool"]
return state
# Tool dispatcher
async def tool_dispatcher(state: JARVISState) -> JARVISState:
try:
updated_state = state.copy()
file_type = "jpg" if "image" in state["question"].lower() else "txt"
if any(word in state["question"].lower() for word in ["menu", "report"]):
file_type = "pdf"
elif "data" in state["question"].lower():
file_type = "xlsx"
for tool in updated_state["tools_needed"]:
try:
if tool == "search_tool":
result = search_tool(updated_state["question"])
updated_state["web_results"].extend([str(r) for r in result])
elif tool == "multi_hop_search_tool":
result = await multi_hop_search_tool.ainvoke({"query": updated_state["question"], "steps": 3, "llm_client": llm_client, "llm_type": llm_type})
updated_state["multi_hop_results"].extend([r["content"] for r in result])
await asyncio.sleep(2)
elif tool == "file_parser_tool":
for ext in ["txt", "csv", "xlsx"]:
file_path = await download_file(updated_state["task_id"], ext)
if file_path:
result = file_parser_tool(file_path)
updated_state["file_results"] = str(result)
break
elif tool == "image_parser_tool":
file_path = await download_file(updated_state["task_id"], "jpg")
if file_path:
result = image_parser_tool(file_path)
updated_state["image_results"] = str(result)
elif tool == "calculator_tool":
result = calculator_tool(updated_state["question"])
updated_state["calculation_results"] = str(result)
elif tool == "document_retriever_tool":
file_path = await download_file(updated_state["task_id"], "pdf")
if file_path:
result = document_retriever_tool({"task_id": updated_state["task_id"], "query": updated_state["question"], "file_type": "pdf"})
updated_state["document_results"] = str(result)
elif tool == "duckduckgo_search_tool":
result = duckduckgo_search_tool(updated_state["question"])
updated_state["web_results"].append(str(result))
elif tool == "weather_info_tool":
location = updated_state["question"].split("weather in ")[1].split()[0] if "weather in" in updated_state["question"].lower() else "Unknown"
result = weather_info_tool({"location": location})
updated_state["web_results"].append(str(result))
elif tool == "hub_stats_tool":
author = updated_state["question"].split("by ")[1].split()[0] if "by" in updated_state["question"].lower() else "Unknown"
result = hub_stats_tool({"author": author})
updated_state["web_results"].append(str(result))
elif tool == "guest_info_retriever_tool":
query = updated_state["question"].split("about ")[1] if "about" in updated_state["question"].lower() else updated_state["question"]
result = guest_info_retriever_tool({"query": query})
updated_state["web_results"].append(str(result))
updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_executed": True}
except Exception as e:
logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {str(e)}")
updated_state["error"] = f"Tool {tool} failed: {str(e)}"
updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_error": str(e)}
logger.info(f"Task {updated_state['task_id']}: Tool results: {updated_state}")
return updated_state
except Exception as e:
logger.error(f"Tool dispatch failed for task {state['task_id']}: {e}")
updated_state["error"] = f"Tool dispatch failed: {str(e)}"
return updated_state
# Reasoning
async def reasoning(state: JARVISState) -> Dict[str, Any]:
try:
prompt = ChatPromptTemplate.from_messages([
SystemMessage(content="""Provide ONLY the exact answer (e.g., '90', 'HUE'). For USD, use two decimal places (e.g., '1234.00'). For lists, use comma-separated values (e.g., 'Smith, Lee'). For IOC codes, use three-letter codes (e.g., 'ARG'). No explanations or conversational text."""),
HumanMessage(content="""Task: {task_id}
Question: {question}
Web results: {web_results}
Multi-hop results: {multi_hop_results}
File results: {file_results}
Image results: {image_results}
Calculation results: {calculation_results}
Document results: {document_results}""")
])
messages = [
{"role": "system", "content": prompt[0].content},
{"role": "user", "content": prompt[1].content.format(
task_id=state["task_id"],
question=state["question"],
web_results="\n".join(state["web_results"]),
multi_hop_results="\n".join(state["multi_hop_results"]),
file_results=state["file_results"],
image_results=state["image_results"],
calculation_results=state["calculation_results"],
document_results=state["document_results"]
)}
]
for attempt in range(3):
try:
if llm_type == "hf_local":
model, tokenizer = llm_client
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
elif llm_type == "together":
response = llm_client.chat.completions.create(
model=llm_client.model,
messages=messages,
max_tokens=512,
temperature=0.7
)
answer = response.choices[0].message.content.strip()
else: # hf_api
response = llm_client.chat.completions.create(
model=HF_MODEL,
messages=messages,
max_tokens=512,
temperature=0.7
)
answer = response.choices[0].message.content.strip()
# Format answer
if "USD" in state["question"].lower():
try:
answer = f"{float(answer):.2f}"
except ValueError:
pass
if "before and after" in state["question"].lower():
answer = answer.replace(" and ", ", ")
if "IOC code" in state["question"].lower():
answer = answer.upper()[:3]
logger.info(f"Task {state['task_id']}: Answer: {answer}")
return {"answer": answer}
except Exception as e:
logger.warning(f"LLM retry {attempt + 1}/3 for task {state['task_id']}: {e}")
await asyncio.sleep(2)
state["error"] = "LLM failed after retries"
return {"answer": "Error: LLM failed after retries"}
except Exception as e:
logger.error(f"Reasoning failed for task {state['task_id']}: {e}")
state["error"] = f"Reasoning failed: {str(e)}"
return {"answer": f"Error: {str(e)}"}
# Router
def router(state: JARVISState) -> str:
if state["tools_needed"]:
return "tool_dispatcher"
return "reasoning"
# Define StateGraph
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("tool_dispatcher", tool_dispatcher)
workflow.add_node("reasoning", reasoning)
workflow.set_entry_point("parse")
workflow.add_conditional_edges(
"parse",
router,
{
"tool_dispatcher": "tool_dispatcher",
"reasoning": "reasoning"
}
)
workflow.add_edge("tool_dispatcher", "reasoning")
workflow.add_edge("reasoning", END)
graph = workflow.compile()
# Agent class
class JARVISAgent:
def __init__(self):
self.state = JARVISState(
task_id="",
question="",
tools_needed=[],
web_results=[],
file_results="",
image_results="",
calculation_results="",
document_results="",
multi_hop_results=[],
messages=[],
answer="",
results_table=[],
status_output="",
error=None,
metadata={}
)
logger.info("JARVISAgent initialized.")
async def process_question(self, task_id: str, question: str) -> str:
state = JARVISState(
task_id=task_id,
question=question,
tools_needed=["search_tool"],
web_results=[],
file_results="",
image_results="",
calculation_results="",
document_results="",
multi_hop_results=[],
messages=[HumanMessage(content=question)],
answer="",
results_table=[],
status_output="",
error=None,
metadata={}
)
try:
result = await graph.ainvoke(state)
answer = result["answer"] or "Unknown"
logger.info(f"Task {task_id}: Final answer: {answer}")
self.state.results_table.append({"Task ID": task_id, "Question": question, "Answer": answer})
self.state.metadata = self.state.get("metadata", {}) | {"last_task": task_id, "answer": answer}
return answer
except Exception as e:
logger.error(f"Error processing task {task_id}: {e}")
self.state.results_table.append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
self.state.error = f"Task {task_id} failed: {str(e)}"
return f"Error: {str(e)}"
finally:
for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
file_path = f"temp/{task_id}.{ext}"
if os.path.exists(file_path):
try:
os.remove(file_path)
logger.info(f"Removed temp file: {file_path}")
except Exception as e:
logger.error(f"Error removing file {file_path}: {e}")
async def process_all_questions(self, profile: gr.OAuthProfile | None):
if not profile:
logger.error("User not logged in.")
self.state.status_output = "Please Login to Hugging Face."
return pd.DataFrame(self.state.results_table), self.state.status_output
username = f"{profile.username}"
logger.info(f"User logged in: {username}")
questions_url = f"{GAIA_API_URL}/questions"
submit_url = f"{GAIA_API_URL}/submit"
agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
try:
response = requests.get(questions_url, timeout=15)
response.raise_for_status()
questions = response.json()
logger.info(f"Fetched {len(questions)} questions.")
except Exception as e:
logger.error(f"Error fetching questions: {e}")
self.state.status_output = f"Error fetching questions: {e}"
self.state.error = f"Fetch questions failed: {str(e)}"
return pd.DataFrame(self.state.results_table), self.state.status_output
answers_payload = []
for item in questions:
task_id = item.get("task_id")
question = item.get("question")
if not task_id or not question:
logger.warning(f"Skipping invalid item: {item}")
continue
answer = await self.process_question(task_id, question)
answers_payload.append({"task_id": task_id, "submitted_answer": answer})
if not answers_payload:
logger.error("No answers generated.")
self.state.status_output = "No answers to submit."
self.state.error = "No answers generated"
return pd.DataFrame(self.state.results_table), self.state.status_output
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
try:
response = requests.post(submit_url, json=submission_data, timeout=120)
response.raise_for_status()
result_data = response.json()
self.state.status_output = (
f"Submission Successful!\n"
f"User: {result_data.get('username')}\n"
f"Overall Score: {result_data.get('score', 'N/A')}% "
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
f"Message: {result_data.get('message', 'No message received.')}"
)
self.state.metadata = self.state.get("metadata", {}) | {"submission_score": result_data.get('score', 'N/A')}
except Exception as e:
logger.error(f"Submission failed: {e}")
self.state.status_output = f"Submission Failed: {e}"
self.state.error = f"Submission failed: {str(e)}"
return pd.DataFrame(self.state.results_table), self.state.status_output
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Evolved JARVIS GAIA Agent")
gr.Markdown(
"""
**Instructions:**
1. Log in to Hugging Face using the button below.
2. Click 'Run Evaluation & Submit All Answers' to process GAIA questions and submit.
---
**Disclaimers:**
Uses Hugging Face Inference, Together AI, SERPAPI, and OpenWeatherMap for GAIA benchmark.
"""
)
with gr.Row():
gr.LoginButton(value="Login to Hugging Face")
# Removed gr.LogoutButton due to deprecation
run_button = gr.Button("Run Evaluation & Submit All Answers")
status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
agent = JARVISAgent()
run_button.click(
fn=agent.process_all_questions,
outputs=[results_table, status_output]
)
if __name__ == "__main__":
logger.info("\n" + "-"*30 + " App Starting " + "-"*30)
logger.info(f"SPACE_ID: {SPACE_ID}")
logger.info("Launching Gradio Interface...")
demo.launch(debug=True, share=False)