import os
import json
import logging
import asyncio
import aiohttp
import nest_asyncio
import requests
import pandas as pd
from typing import Dict, Any, List
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, END
from sentence_transformers import SentenceTransformer
import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import together

from state import JARVISState
from tools import (
    search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool,
    calculator_tool, document_retriever_tool, duckduckgo_search_tool,
    weather_info_tool, hub_stats_tool, guest_info_retriever_tool
)
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Allow nested event loops (needed because Gradio already runs one)
nest_asyncio.apply()

# Load environment variables
load_dotenv()
SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
GAIA_API_URL = "https://agents-course-unit4-scoring.hf.space"
GAIA_FILE_URL = f"{GAIA_API_URL}/files/"
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Verify required environment variables (SPACE_ID always has a default above)
if not HF_API_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
if not TOGETHER_API_KEY:
    raise ValueError("TOGETHER_API_KEY not set")
logger.info(f"SPACE_ID: {SPACE_ID}")
# Model configuration
TOGETHER_MODELS = [
    "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
]
HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# Initialize LLM clients, returning (client, client_type, model_name)
def initialize_llm():
    """Try Together AI models first, then the HF Inference API, then a local HF model."""
    for model in TOGETHER_MODELS:
        try:
            client = together.Together(api_key=TOGETHER_API_KEY)
            # Smoke-test the model with a minimal completion
            client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": "Test"}],
                max_tokens=10
            )
            logger.info(f"Initialized Together AI model: {model}")
            return client, "together", model
        except Exception as e:
            logger.warning(f"Failed to initialize Together AI model {model}: {e}")

    # Fallback to Hugging Face Inference API
    try:
        client = InferenceClient(model=HF_MODEL, token=HF_API_TOKEN, timeout=30)
        logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
        return client, "hf_api", HF_MODEL
    except Exception as e:
        logger.warning(f"Failed to initialize HF Inference API: {e}")

    # Fallback to a local Hugging Face model
    try:
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
        logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
        return (model, tokenizer), "hf_local", HF_MODEL
    except Exception as e:
        logger.error(f"Failed to initialize local HF model: {e}")
        raise Exception("No LLM could be initialized")

llm_client, llm_type, llm_model = initialize_llm()
# Initialize embedder
try:
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    logger.info("Sentence transformer initialized")
except Exception as e:
    logger.error(f"Failed to initialize embedder: {e}")
    embedder = None
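# Note: `embedder` is not used directly in this file; presumably it is consumed by
# the retrieval tools imported from tools.py (e.g., document_retriever_tool). A
# minimal sketch of how it could rank passages, assuming that design:
#   scores = embedder.encode(["query"]) @ embedder.encode(["passage a", "passage b"]).T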
# Download file with local fallback
async def download_file(task_id: str, ext: str) -> str | None:
    """Fetch a task attachment from the GAIA API; fall back to a cached local copy."""
    try:
        url = f"{GAIA_FILE_URL}{task_id}.{ext}"
        async with aiohttp.ClientSession() as session:
            async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
                logger.info(f"GAIA API test for task {task_id} with .{ext}: HTTP {resp.status}")
                if resp.status == 200:
                    os.makedirs("temp", exist_ok=True)
                    file_path = f"temp/{task_id}.{ext}"
                    with open(file_path, "wb") as f:
                        f.write(await resp.read())
                    return file_path
    except Exception as e:
        logger.warning(f"File download failed for {task_id}.{ext}: {e}")
    local_path = f"temp/{task_id}.{ext}"
    if os.path.exists(local_path):
        logger.info(f"Using local file: {local_path}")
        return local_path
    return None
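# Usage sketch (hedged): outside a running event loop you can fetch an attachment
# synchronously; "abc123" is a hypothetical task_id used only for illustration:
#   path = asyncio.run(download_file("abc123", "csv"))  # "temp/abc123.csv" or None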
# Parse question to select tools
async def parse_question(state: JARVISState) -> JARVISState:
    task_id = state.get("task_id", "unknown")
    try:
        question = state["question"]
        tools_needed = ["search_tool"]
        if llm_client:
            prompt = ChatPromptTemplate.from_messages([
                SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
Return a JSON list, e.g., ["search_tool", "file_parser_tool"].
Rules:
- Always include "search_tool" unless the question is purely computational.
- Use "multi_hop_search_tool" for complex queries (over 20 words or requiring multiple steps).
- Use "file_parser_tool" for data, tables, or Excel files.
- Use "image_parser_tool" for images/videos.
- Use "calculator_tool" for math calculations.
- Use "document_retriever_tool" for documents/PDFs.
- Use "duckduckgo_search_tool" for additional search capability.
- Use "weather_info_tool" for weather-related queries.
- Use "hub_stats_tool" for Hugging Face Hub queries.
- Use "guest_info_retriever_tool" for guest-related queries.
- Output ONLY valid JSON."""),
                HumanMessage(content=f"Query: {question}")
            ])
            system_content = prompt.messages[0].content
            user_content = prompt.messages[1].content
            try:
                if llm_type == "hf_local":
                    model, tokenizer = llm_client
                    inputs = tokenizer.apply_chat_template(
                        [{"role": "system", "content": system_content}, {"role": "user", "content": user_content}],
                        return_tensors="pt"
                    ).to(model.device)
                    outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
                    # Decode only the newly generated tokens, not the echoed prompt
                    response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
                    tools_needed = json.loads(response.strip())
                elif llm_type == "together":
                    response = llm_client.chat.completions.create(
                        model=llm_model,
                        messages=[
                            {"role": "system", "content": system_content},
                            {"role": "user", "content": user_content}
                        ],
                        max_tokens=512,
                        temperature=0.7
                    )
                    tools_needed = json.loads(response.choices[0].message.content.strip())
                else:  # hf_api
                    response = llm_client.chat.completions.create(
                        model=HF_MODEL,
                        messages=[
                            {"role": "system", "content": system_content},
                            {"role": "user", "content": user_content}
                        ],
                        max_tokens=512,
                        temperature=0.7
                    )
                    tools_needed = json.loads(response.choices[0].message.content.strip())
                valid_tools = {
                    "search_tool", "multi_hop_search_tool", "file_parser_tool", "image_parser_tool",
                    "calculator_tool", "document_retriever_tool", "duckduckgo_search_tool",
                    "weather_info_tool", "hub_stats_tool", "guest_info_retriever_tool"
                }
                tools_needed = [tool for tool in tools_needed if tool in valid_tools]
            except Exception as e:
                logger.warning(f"Task {task_id} tool selection failed: {e}")
                state["error"] = f"Tool selection failed: {str(e)}"
                # Keyword-based fallback
                question_lower = question.lower()
                if any(word in question_lower for word in ["image", "video", "picture"]):
                    tools_needed.append("image_parser_tool")
                if any(word in question_lower for word in ["data", "table", "excel", ".txt", ".csv", ".xlsx"]):
                    tools_needed.append("file_parser_tool")
                if any(word in question_lower for word in ["calculate", "math", "sum", "average", "total"]):
                    tools_needed.append("calculator_tool")
                if any(word in question_lower for word in ["document", "pdf", "report", "menu"]):
                    tools_needed.append("document_retriever_tool")
                if any(word in question_lower for word in ["weather", "temperature"]):
                    tools_needed.append("weather_info_tool")
                if any(word in question_lower for word in ["model", "huggingface", "dataset"]):
                    tools_needed.append("hub_stats_tool")
                if any(word in question_lower for word in ["guest", "name", "relation", "person"]):
                    tools_needed.append("guest_info_retriever_tool")
                if len(question.split()) > 20 or "multiple" in question_lower:
                    tools_needed.append("multi_hop_search_tool")
                if any(word in question_lower for word in ["search", "wikipedia", "online"]):
                    tools_needed.append("duckduckgo_search_tool")
        # Check file availability and force the matching parser tool
        for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
            file_path = await download_file(task_id, ext)
            if file_path:
                if ext in ["txt", "csv", "xlsx"] and "file_parser_tool" not in tools_needed:
                    tools_needed.append("file_parser_tool")
                if ext == "jpg" and "image_parser_tool" not in tools_needed:
                    tools_needed.append("image_parser_tool")
                if ext == "pdf" and "document_retriever_tool" not in tools_needed:
                    tools_needed.append("document_retriever_tool")
                state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
                break
        state["tools_needed"] = list(set(tools_needed))
        logger.info(f"Task {task_id}: Selected tools: {state['tools_needed']}")
        return state
    except Exception as e:
        logger.error(f"Error parsing task {task_id}: {e}")
        state["error"] = f"Parse question failed: {str(e)}"
        state["tools_needed"] = ["search_tool"]
        return state
# Tool dispatcher
async def tool_dispatcher(state: JARVISState) -> JARVISState:
    updated_state = state.copy()
    try:
        # Lowercase once so guards and splits below are consistently case-insensitive
        question_lower = updated_state["question"].lower()
        for tool in updated_state["tools_needed"]:
            try:
                if tool == "search_tool":
                    result = search_tool(updated_state["question"])
                    updated_state["web_results"].extend([str(r) for r in result])
                elif tool == "multi_hop_search_tool":
                    result = await multi_hop_search_tool.ainvoke({"query": updated_state["question"], "steps": 3, "llm_client": llm_client, "llm_type": llm_type})
                    updated_state["multi_hop_results"].extend([r["content"] for r in result])
                    await asyncio.sleep(2)  # basic rate limiting between searches
                elif tool == "file_parser_tool":
                    for ext in ["txt", "csv", "xlsx"]:
                        file_path = await download_file(updated_state["task_id"], ext)
                        if file_path:
                            result = file_parser_tool(file_path)
                            updated_state["file_results"] = str(result)
                            break
                elif tool == "image_parser_tool":
                    file_path = await download_file(updated_state["task_id"], "jpg")
                    if file_path:
                        result = image_parser_tool(file_path)
                        updated_state["image_results"] = str(result)
                elif tool == "calculator_tool":
                    result = calculator_tool(updated_state["question"])
                    updated_state["calculation_results"] = str(result)
                elif tool == "document_retriever_tool":
                    file_path = await download_file(updated_state["task_id"], "pdf")
                    if file_path:
                        result = document_retriever_tool({"task_id": updated_state["task_id"], "query": updated_state["question"], "file_type": "pdf"})
                        updated_state["document_results"] = str(result)
                elif tool == "duckduckgo_search_tool":
                    result = duckduckgo_search_tool(updated_state["question"])
                    updated_state["web_results"].append(str(result))
                elif tool == "weather_info_tool":
                    location = question_lower.split("weather in ")[1].split()[0] if "weather in " in question_lower else "Unknown"
                    result = weather_info_tool({"location": location})
                    updated_state["web_results"].append(str(result))
                elif tool == "hub_stats_tool":
                    author = question_lower.split("by ")[1].split()[0] if "by " in question_lower else "Unknown"
                    result = hub_stats_tool({"author": author})
                    updated_state["web_results"].append(str(result))
                elif tool == "guest_info_retriever_tool":
                    query = question_lower.split("about ")[1] if "about " in question_lower else updated_state["question"]
                    result = guest_info_retriever_tool({"query": query})
                    updated_state["web_results"].append(str(result))
                updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_executed": True}
            except Exception as e:
                logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {str(e)}")
                updated_state["error"] = f"Tool {tool} failed: {str(e)}"
                updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_error": str(e)}
        logger.info(f"Task {updated_state['task_id']}: Tool results collected")
        return updated_state
    except Exception as e:
        logger.error(f"Tool dispatch failed for task {state['task_id']}: {e}")
        updated_state["error"] = f"Tool dispatch failed: {str(e)}"
        return updated_state
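# Example (hedged sketch): dispatching a single tool against a synthetic state. The
# task_id "demo" and question are hypothetical; this assumes no GAIA file exists:
#   demo_state = JARVISState(
#       task_id="demo", question="What is the weather in Paris?",
#       tools_needed=["weather_info_tool"], web_results=[], file_results="",
#       image_results="", calculation_results="", document_results="",
#       multi_hop_results=[], messages=[], answer="", results_table=[],
#       status_output="", error=None, metadata={})
#   demo_state = asyncio.run(tool_dispatcher(demo_state))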
# Reasoning
async def reasoning(state: JARVISState) -> Dict[str, Any]:
    try:
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content="""Provide ONLY the exact answer (e.g., '90', 'HUE'). For USD, use two decimal places (e.g., '1234.00'). For lists, use comma-separated values (e.g., 'Smith, Lee'). For IOC codes, use three-letter codes (e.g., 'ARG'). No explanations or conversational text."""),
            HumanMessage(content="""Task: {task_id}
Question: {question}
Web results: {web_results}
Multi-hop results: {multi_hop_results}
File results: {file_results}
Image results: {image_results}
Calculation results: {calculation_results}
Document results: {document_results}""")
        ])
        messages = [
            {"role": "system", "content": prompt.messages[0].content},
            {"role": "user", "content": prompt.messages[1].content.format(
                task_id=state["task_id"],
                question=state["question"],
                web_results="\n".join(state["web_results"]),
                multi_hop_results="\n".join(state["multi_hop_results"]),
                file_results=state["file_results"],
                image_results=state["image_results"],
                calculation_results=state["calculation_results"],
                document_results=state["document_results"]
            )}
        ]
        question_lower = state["question"].lower()
        for attempt in range(3):
            try:
                if llm_type == "hf_local":
                    model, tokenizer = llm_client
                    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
                    outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
                    # Decode only the newly generated tokens, not the echoed prompt
                    answer = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True).strip()
                elif llm_type == "together":
                    response = llm_client.chat.completions.create(
                        model=llm_model,
                        messages=messages,
                        max_tokens=512,
                        temperature=0.7
                    )
                    answer = response.choices[0].message.content.strip()
                else:  # hf_api
                    response = llm_client.chat.completions.create(
                        model=HF_MODEL,
                        messages=messages,
                        max_tokens=512,
                        temperature=0.7
                    )
                    answer = response.choices[0].message.content.strip()
                # Format answer (compare against the lowercased question)
                if "usd" in question_lower:
                    try:
                        answer = f"{float(answer):.2f}"
                    except ValueError:
                        pass
                if "before and after" in question_lower:
                    answer = answer.replace(" and ", ", ")
                if "ioc code" in question_lower:
                    answer = answer.upper()[:3]
                logger.info(f"Task {state['task_id']}: Answer: {answer}")
                return {"answer": answer}
            except Exception as e:
                logger.warning(f"LLM retry {attempt + 1}/3 for task {state['task_id']}: {e}")
                await asyncio.sleep(2)
        state["error"] = "LLM failed after retries"
        return {"answer": "Error: LLM failed after retries"}
    except Exception as e:
        logger.error(f"Reasoning failed for task {state['task_id']}: {e}")
        state["error"] = f"Reasoning failed: {str(e)}"
        return {"answer": f"Error: {str(e)}"}
# Router
def router(state: JARVISState) -> str:
    if state["tools_needed"]:
        return "tool_dispatcher"
    return "reasoning"

# Define StateGraph
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("tool_dispatcher", tool_dispatcher)
workflow.add_node("reasoning", reasoning)
workflow.set_entry_point("parse")
workflow.add_conditional_edges(
    "parse",
    router,
    {
        "tool_dispatcher": "tool_dispatcher",
        "reasoning": "reasoning"
    }
)
workflow.add_edge("tool_dispatcher", "reasoning")
workflow.add_edge("reasoning", END)
graph = workflow.compile()
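# Smoke test for the compiled graph (hedged sketch; the task_id and question are
# hypothetical, and parse -> tool_dispatcher -> reasoning runs end to end):
#   result = asyncio.run(graph.ainvoke(JARVISState(
#       task_id="demo", question="What is 2 + 2?", tools_needed=[], web_results=[],
#       file_results="", image_results="", calculation_results="", document_results="",
#       multi_hop_results=[], messages=[], answer="", results_table=[],
#       status_output="", error=None, metadata={})))
#   print(result["answer"])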
# Agent class
class JARVISAgent:
    def __init__(self):
        # JARVISState is dict-like, so fields are accessed by key throughout
        self.state = JARVISState(
            task_id="",
            question="",
            tools_needed=[],
            web_results=[],
            file_results="",
            image_results="",
            calculation_results="",
            document_results="",
            multi_hop_results=[],
            messages=[],
            answer="",
            results_table=[],
            status_output="",
            error=None,
            metadata={}
        )
        logger.info("JARVISAgent initialized.")

    async def process_question(self, task_id: str, question: str) -> str:
        state = JARVISState(
            task_id=task_id,
            question=question,
            tools_needed=["search_tool"],
            web_results=[],
            file_results="",
            image_results="",
            calculation_results="",
            document_results="",
            multi_hop_results=[],
            messages=[HumanMessage(content=question)],
            answer="",
            results_table=[],
            status_output="",
            error=None,
            metadata={}
        )
        try:
            result = await graph.ainvoke(state)
            answer = result["answer"] or "Unknown"
            logger.info(f"Task {task_id}: Final answer: {answer}")
            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": answer})
            self.state["metadata"] = self.state.get("metadata", {}) | {"last_task": task_id, "answer": answer}
            return answer
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {e}")
            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
            self.state["error"] = f"Task {task_id} failed: {str(e)}"
            return f"Error: {str(e)}"
        finally:
            # Clean up any downloaded attachments for this task
            for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
                file_path = f"temp/{task_id}.{ext}"
                if os.path.exists(file_path):
                    try:
                        os.remove(file_path)
                        logger.info(f"Removed temp file: {file_path}")
                    except Exception as e:
                        logger.error(f"Error removing file {file_path}: {e}")
    async def process_all_questions(self, profile: gr.OAuthProfile | None):
        if not profile:
            logger.error("User not logged in.")
            self.state["status_output"] = "Please login to Hugging Face."
            return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
        username = profile.username
        logger.info(f"User logged in: {username}")
        questions_url = f"{GAIA_API_URL}/questions"
        submit_url = f"{GAIA_API_URL}/submit"
        agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
        try:
            response = requests.get(questions_url, timeout=15)
            response.raise_for_status()
            questions = response.json()
            logger.info(f"Fetched {len(questions)} questions.")
        except Exception as e:
            logger.error(f"Error fetching questions: {e}")
            self.state["status_output"] = f"Error fetching questions: {e}"
            self.state["error"] = f"Fetch questions failed: {str(e)}"
            return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
        answers_payload = []
        for item in questions:
            task_id = item.get("task_id")
            question = item.get("question")
            if not task_id or not question:
                logger.warning(f"Skipping invalid item: {item}")
                continue
            answer = await self.process_question(task_id, question)
            answers_payload.append({"task_id": task_id, "submitted_answer": answer})
        if not answers_payload:
            logger.error("No answers generated.")
            self.state["status_output"] = "No answers to submit."
            self.state["error"] = "No answers generated"
            return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
        submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
        try:
            response = requests.post(submit_url, json=submission_data, timeout=120)
            response.raise_for_status()
            result_data = response.json()
            self.state["status_output"] = (
                f"Submission Successful!\n"
                f"User: {result_data.get('username')}\n"
                f"Overall Score: {result_data.get('score', 'N/A')}% "
                f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
                f"Message: {result_data.get('message', 'No message received.')}"
            )
            self.state["metadata"] = self.state.get("metadata", {}) | {"submission_score": result_data.get('score', 'N/A')}
        except Exception as e:
            logger.error(f"Submission failed: {e}")
            self.state["status_output"] = f"Submission Failed: {e}"
            self.state["error"] = f"Submission failed: {str(e)}"
        return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Evolved JARVIS GAIA Agent")
    gr.Markdown(
        """
        **Instructions:**
        1. Log in to Hugging Face using the button below.
        2. Click 'Run Evaluation & Submit All Answers' to process GAIA questions and submit.
        ---
        **Disclaimers:**
        Uses Hugging Face Inference, Together AI, SERPAPI, and OpenWeatherMap for the GAIA benchmark.
        """
    )
    with gr.Row():
        gr.LoginButton(value="Login to Hugging Face")
        # gr.LogoutButton removed due to deprecation
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
    agent = JARVISAgent()
    run_button.click(
        fn=agent.process_all_questions,
        outputs=[results_table, status_output]
    )
if __name__ == "__main__":
    logger.info("\n" + "-" * 30 + " App Starting " + "-" * 30)
    logger.info(f"SPACE_ID: {SPACE_ID}")
    logger.info("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)