onisj committed
Commit c6951f4 · 1 Parent(s): 4701375

Rewrite app.py and search.py with multi-hop LLM refinement

Files changed (6)
  1. app.py +310 -251
  2. requirements.txt +6 -1
  3. result.txt +0 -0
  4. state.py +29 -4
  5. test.py +7 -0
  6. tools/search.py +99 -41
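
The change set below is large, so here is a minimal, illustrative sketch of the multi-hop refinement idea the commit adds to tools/search.py (run_search and refine_query are placeholders standing in for the SERPAPI call and the LLM client used in the actual code):

# Illustrative sketch only -- not the committed implementation.
def multi_hop(query: str, steps: int, run_search, refine_query) -> list:
    results = []
    current_query = query
    for step in range(steps):
        hits = run_search(current_query)  # one search "hop"
        results.extend(hits)
        if step < steps - 1:
            # Ask the LLM to sharpen the query based on what was just found;
            # fall back to a generic broadening if refinement returns nothing.
            current_query = refine_query(current_query, hits) or f"more details on {current_query}"
    return results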
app.py CHANGED
@@ -14,6 +14,7 @@ from sentence_transformers import SentenceTransformer
14
  import gradio as gr
15
  from dotenv import load_dotenv
16
  from huggingface_hub import InferenceClient
 
17
  from state import JARVISState
18
  from tools import (
19
  search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool,
@@ -33,27 +34,68 @@ load_dotenv()
33
  SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
34
  GAIA_API_URL = "https://agents-course-unit4-scoring.hf.space"
35
  GAIA_FILE_URL = f"{GAIA_API_URL}/files/"
36
- HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
37
 
38
  # Verify environment variables
39
  if not SPACE_ID:
40
  raise ValueError("SPACE_ID not set")
41
- if not HF_TOKEN:
42
  raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
 
 
43
  logger.info(f"SPACE_ID: {SPACE_ID}")
44
 
45
- # Initialize models
46
- try:
47
- llm = InferenceClient(
48
- model="meta-llama/Meta-Llama-3-8B-Instruct",
49
- token=HF_TOKEN,
50
- timeout=30
51
- )
52
- logger.info("Hugging Face Inference LLM initialized")
53
- except Exception as e:
54
- logger.error(f"Failed to initialize LLM: {e}")
55
- llm = None
 
56
 
 
57
  try:
58
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
59
  logger.info("Sentence transformer initialized")
@@ -61,40 +103,41 @@ except Exception as e:
61
  logger.error(f"Failed to initialize embedder: {e}")
62
  embedder = None
63
 
64
- # --- Helper Functions ---
65
- async def test_gaia_api(task_id: str, file_type: str = "txt") -> tuple[bool, str | None]:
66
- """Test if a file exists for the task ID."""
67
  try:
68
- for ext in [file_type, "txt", "csv", "xlsx", "jpg", "pdf"]:
69
- async with aiohttp.ClientSession() as session:
70
- async with session.get(f"{GAIA_FILE_URL}{task_id}.{ext}", timeout=5) as resp:
71
- logger.info(f"GAIA API test for task {task_id} with .{ext}: HTTP {resp.status}")
72
- if resp.status == 200:
73
- file_path = f"temp_{task_id}.{ext}"
74
- with open(file_path, "wb") as f:
75
- f.write(await resp.read())
76
- return True, ext
77
- logger.info(f"No file found for task {task_id}")
78
- return False, None
79
  except Exception as e:
80
- logger.warning(f"GAIA API test failed: {str(e)}")
81
- return False, None
82
-
83
- # --- Node Functions ---
84
- async def parse_question(state: Dict[str, Any]) -> Dict[str, Any]:
85
- """Parse the question to select appropriate tools."""
86
  try:
87
  question = state["question"]
88
  task_id = state["task_id"]
89
  tools_needed = ["search_tool"]
90
 
91
- if llm:
92
  prompt = ChatPromptTemplate.from_messages([
93
  SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
94
  Return JSON list, e.g., ["search_tool", "file_parser_tool"].
95
  Rules:
96
  - Always include "search_tool" unless purely computational.
97
- - Use "multi_hop_search_tool" for complex queries (over 20 words).
98
  - Use "file_parser_tool" for data, tables, or Excel.
99
  - Use "image_parser_tool" for images/videos.
100
  - Use "calculator_tool" for math calculations.
@@ -107,15 +150,27 @@ async def parse_question(state: Dict[str, Any]) -> Dict[str, Any]:
107
  HumanMessage(content=f"Query: {question}")
108
  ])
109
  try:
110
- response = llm.chat_completion(
111
- messages=[
112
- {"role": "system", "content": prompt[0].content},
113
- {"role": "user", "content": prompt[1].content}
114
- ],
115
- max_tokens=512,
116
- temperature=0.7
117
- )
118
- tools_needed = json.loads(response["choices"][0]["message"]["content"].strip())
119
  valid_tools = {
120
  "search_tool", "multi_hop_search_tool", "file_parser_tool", "image_parser_tool",
121
  "calculator_tool", "document_retriever_tool", "duckduckgo_search_tool",
@@ -123,165 +178,192 @@ async def parse_question(state: Dict[str, Any]) -> Dict[str, Any]:
123
  }
124
  tools_needed = [tool for tool in tools_needed if tool in valid_tools]
125
  except Exception as e:
126
- logger.warning(f"Task {task_id} failed: JSON parse error: {e}")
127
- tools_needed = ["search_tool"]
128
 
129
  # Keyword-based fallback
130
  question_lower = question.lower()
131
- if any(word in question_lower for word in ["image", "video"]):
132
  tools_needed.append("image_parser_tool")
133
- if any(word in question_lower for word in ["data", "table", "excel"]):
134
  tools_needed.append("file_parser_tool")
135
- if any(word in question_lower for word in ["calculate", "math"]):
136
  tools_needed.append("calculator_tool")
137
- if any(word in question_lower for word in ["document", "pdf"]):
138
  tools_needed.append("document_retriever_tool")
139
- if any(word in question_lower for word in ["weather"]):
140
  tools_needed.append("weather_info_tool")
141
- if any(word in question_lower for word in ["model", "huggingface"]):
142
  tools_needed.append("hub_stats_tool")
143
- if any(word in question_lower for word in ["guest", "name", "relation"]):
144
  tools_needed.append("guest_info_retriever_tool")
145
- if len(question.split()) > 20:
146
  tools_needed.append("multi_hop_search_tool")
147
-
148
- file_available, file_ext = await test_gaia_api(task_id)
149
- if file_available:
150
- if "file_parser_tool" not in tools_needed and any(word in question_lower for word in ["data", "table", "excel"]):
151
- tools_needed.append("file_parser_tool")
152
- if "image_parser_tool" not in tools_needed and "image" in question_lower:
153
- tools_needed.append("image_parser_tool")
154
- if "document_retriever_tool" not in tools_needed and file_ext == "pdf":
155
- tools_needed.append("document_retriever_tool")
156
- else:
157
- tools_needed = [tool for tool in tools_needed if tool not in ["file_parser_tool", "image_parser_tool", "document_retriever_tool"]]
158
-
159
- state["tools_needed"] = list(set(tools_needed)) # Remove duplicates
160
  logger.info(f"Task {task_id}: Selected tools: {tools_needed}")
161
  return state
162
  except Exception as e:
163
  logger.error(f"Error parsing task {task_id}: {e}")
 
164
  state["tools_needed"] = ["search_tool"]
165
  return state
166
 
 
167
  async def tool_dispatcher(state: JARVISState) -> JARVISState:
168
- """Dispatch selected tools to process the state."""
169
  try:
170
  updated_state = state.copy()
171
  file_type = "jpg" if "image" in state["question"].lower() else "txt"
172
- if "menu" in state["question"].lower() or "report" in state["question"].lower():
173
  file_type = "pdf"
174
  elif "data" in state["question"].lower():
175
  file_type = "xlsx"
176
 
177
- can_download, file_ext = await test_gaia_api(updated_state["task_id"], file_type)
178
-
179
  for tool in updated_state["tools_needed"]:
180
  try:
181
  if tool == "search_tool":
182
- result = await search_tool.ainvoke({"query": updated_state["question"]})
183
- updated_state["web_results"].extend([r["content"] for r in result])
184
  elif tool == "multi_hop_search_tool":
185
- result = await multi_hop_search_tool.ainvoke({"query": updated_state["question"], "steps": 3})
186
- updated_state["web_results"].extend([r["content"] for r in result])
187
- await asyncio.sleep(2) # Rate limit
188
- elif tool == "file_parser_tool" and can_download:
189
- result = await file_parser_tool.ainvoke({"task_id": updated_state["task_id"], "file_type": file_ext})
190
- updated_state["file_results"] = str(result)
191
- elif tool == "image_parser_tool" and can_download:
192
- result = await image_parser_tool.ainvoke({
193
- "file_path": f"temp_{updated_state['task_id']}.{file_ext}",
194
- "task": "describe"
195
- })
196
- updated_state["image_results"] = str(result)
197
  elif tool == "calculator_tool":
198
- result = await calculator_tool.ainvoke({"expression": updated_state.get("question", "")})
199
  updated_state["calculation_results"] = str(result)
200
- elif tool == "document_retriever_tool" and can_download:
201
- result = await document_retriever_tool.ainvoke({
202
- "task_id": updated_state["task_id"],
203
- "query": updated_state["question"],
204
- "file_type": file_ext
205
- })
206
- updated_state["document_results"] = str(result)
207
  elif tool == "duckduckgo_search_tool":
208
- result = await duckduckgo_search_tool.run(updated_state["question"])
209
  updated_state["web_results"].append(str(result))
210
  elif tool == "weather_info_tool":
211
  location = updated_state["question"].split("weather in ")[1].split()[0] if "weather in" in updated_state["question"].lower() else "Unknown"
212
- result = await weather_info_tool.ainvoke({"location": location})
213
  updated_state["web_results"].append(str(result))
214
  elif tool == "hub_stats_tool":
215
  author = updated_state["question"].split("by ")[1].split()[0] if "by" in updated_state["question"].lower() else "Unknown"
216
- result = await hub_stats_tool.ainvoke({"author": author})
217
  updated_state["web_results"].append(str(result))
218
  elif tool == "guest_info_retriever_tool":
219
  query = updated_state["question"].split("about ")[1] if "about" in updated_state["question"].lower() else updated_state["question"]
220
- result = await guest_info_retriever_tool.ainvoke({"query": query})
221
  updated_state["web_results"].append(str(result))
 
222
  except Exception as e:
223
  logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {str(e)}")
224
- updated_state[f"{tool}_results"] = f"Error: {str(e)}"
 
225
 
226
  logger.info(f"Task {updated_state['task_id']}: Tool results: {updated_state}")
227
  return updated_state
228
  except Exception as e:
229
  logger.error(f"Tool dispatch failed for task {state['task_id']}: {e}")
230
- return state
 
231
 
 
232
  async def reasoning(state: JARVISState) -> Dict[str, Any]:
233
- """Generate exact-match answer with specific formatting."""
234
  try:
235
- if not llm:
236
- return {"answer": "LLM unavailable"}
237
  prompt = ChatPromptTemplate.from_messages([
238
  SystemMessage(content="""Provide ONLY the exact answer (e.g., '90', 'HUE'). For USD, use two decimal places (e.g., '1234.00'). For lists, use comma-separated values (e.g., 'Smith, Lee'). For IOC codes, use three-letter codes (e.g., 'ARG'). No explanations or conversational text."""),
239
- HumanMessage(content="""Question: {question}
 
240
  Web results: {web_results}
 
241
  File results: {file_results}
242
  Image results: {image_results}
243
  Calculation results: {calculation_results}
244
  Document results: {document_results}""")
245
  ])
246
- response = llm.chat_completion(
247
- messages=[
248
- {"role": "system", "content": prompt[0].content},
249
- {"role": "user", "content": prompt[1].content.format(
250
- question=state["question"],
251
- web_results="\n".join(state["web_results"]),
252
- file_results=state["file_results"],
253
- image_results=state["image_results"],
254
- calculation_results=state["calculation_results"],
255
- document_results=state["document_results"]
256
- )}
257
- ],
258
- max_tokens=512,
259
- temperature=0.7
260
- )
261
- answer = response["choices"][0]["message"]["content"].strip()
262
- # Clean answer for specific formats
263
- if "USD" in state["question"].lower():
264
  try:
265
- answer = f"{float(answer):.2f}"
266
- except ValueError:
267
- pass
268
- if "before and after" in state["question"].lower():
269
- answer = answer.replace(" and ", ", ")
270
- elif "IOC code" in state["question"].lower():
271
- answer = answer.upper()[:3]
272
- logger.info(f"Task {state['task_id']}: Answer: {answer}")
273
- return {"answer": answer}
274
  except Exception as e:
275
  logger.error(f"Reasoning failed for task {state['task_id']}: {e}")
 
276
  return {"answer": f"Error: {str(e)}"}
277
 
 
278
  def router(state: JARVISState) -> str:
279
- """Route based on tools needed."""
280
  if state["tools_needed"]:
281
  return "tool_dispatcher"
282
  return "reasoning"
283
 
284
- # --- Define StateGraph ---
285
  workflow = StateGraph(JARVISState)
286
  workflow.add_node("parse", parse_question)
287
  workflow.add_node("tool_dispatcher", tool_dispatcher)
@@ -299,33 +381,29 @@ workflow.add_edge("tool_dispatcher", "reasoning")
299
  workflow.add_edge("reasoning", END)
300
  graph = workflow.compile()
301
 
302
- # --- Basic Agent ---
303
- class BasicAgent:
304
  def __init__(self):
305
- logger.info("BasicAgent initialized.")
306
 
307
  async def process_question(self, task_id: str, question: str) -> str:
308
- """Process a single question with file handling."""
309
- file_type = "jpg" if "image" in question.lower() else "txt"
310
- if "menu" in question.lower() or "report" in question.lower():
311
- file_type = "pdf"
312
- elif "data" in question.lower():
313
- file_type = "xlsx"
314
-
315
- file_path = f"temp_{task_id}.{file_type}"
316
- file_available, file_ext = await test_gaia_api(task_id, file_type)
317
- if file_available:
318
- try:
319
- async with aiohttp.ClientSession() as session:
320
- async with session.get(f"{GAIA_FILE_URL}{task_id}.{file_ext}") as resp:
321
- if resp.status == 200:
322
- with open(file_path, "wb") as f:
323
- f.write(await resp.read())
324
- else:
325
- logger.warning(f"Failed to fetch file for {task_id}: HTTP {resp.status}")
326
- except Exception as e:
327
- logger.error(f"Error downloading file for task {task_id}: {str(e)}")
328
-
329
  state = JARVISState(
330
  task_id=task_id,
331
  question=question,
@@ -335,116 +413,98 @@ class BasicAgent:
335
  image_results="",
336
  calculation_results="",
337
  document_results="",
 
338
  messages=[HumanMessage(content=question)],
339
- answer=""
340
  )
341
  try:
342
  result = await graph.ainvoke(state)
343
  answer = result["answer"] or "Unknown"
344
- logger.info(f"Task {task_id}: Final answer generated: {answer}")
 
 
345
  return answer
346
  except Exception as e:
347
  logger.error(f"Error processing task {task_id}: {e}")
 
 
348
  return f"Error: {str(e)}"
349
  finally:
350
  for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
351
- file_path = f"temp_{task_id}.{ext}"
352
  if os.path.exists(file_path):
353
  try:
354
  os.remove(file_path)
 
355
  except Exception as e:
356
  logger.error(f"Error removing file {file_path}: {e}")
357
 
358
- async def async_call(self, question: str, task_id: str) -> str:
359
- return await self.process_question(question, task_id)
360
-
361
- def __call__(self, question: str, task_id: str = None) -> str:
362
- logger.info(f"Processing question: {question[:50]}...")
363
- if task_id is None:
364
- task_id = "unknown_task_id"
365
- try:
366
- loop = asyncio.get_event_loop()
367
- except RuntimeError:
368
- loop = asyncio.new_event_loop()
369
- asyncio.set_event_loop(loop)
370
- return loop.run_until_complete(self.async_call(question, task_id))
371
-
372
- # --- Evaluation and Submission ---
373
- def run_and_submit_all(profile: gr.OAuthProfile | None):
374
- """Run evaluation and submit answers to GAIA API."""
375
- if not profile:
376
- logger.error("User not logged in.")
377
- return "Please Login to Hugging Face.", None
378
- username = f"{profile.username}"
379
- logger.info(f"User logged in: {username}")
380
-
381
- questions_url = f"{GAIA_API_URL}/questions"
382
- submit_url = f"{GAIA_API_URL}/submit"
383
- agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
384
 
385
- try:
386
- agent = BasicAgent()
387
- except Exception as e:
388
- logger.error(f"Agent initialization failed: {e}")
389
- return f"Error initializing agent: {e}", None
390
 
391
- logger.info(f"Fetching questions from: {questions_url}")
392
- try:
393
- response = requests.get(questions_url, timeout=15)
394
- response.raise_for_status()
395
- questions_data = response.json()
396
- if not questions_data:
397
- logger.error("Empty questions list.")
398
- return "No questions fetched.", None
399
- logger.info(f"Fetched {len(questions_data)} questions.")
400
- except Exception as e:
401
- logger.error(f"Error fetching questions: {e}")
402
- return f"Error fetching questions: {e}", None
403
-
404
- results_log = []
405
- answers_payload = []
406
- logger.info(f"Processing {len(questions_data)} questions...")
407
- for item in questions_data:
408
- task_id = item.get("task_id")
409
- question_text = item.get("question")
410
- if not task_id or question_text is None:
411
- logger.warning(f"Skipping invalid item: {item}")
412
- continue
413
  try:
414
- submitted_answer = agent(question_text, task_id)
415
- answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
416
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
417
  except Exception as e:
418
- logger.error(f"Error for task {task_id}: {e}")
419
- results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
 
420
 
421
- if not answers_payload:
422
- logger.error("No answers generated.")
423
- return "No answers to submit.", pd.DataFrame(results_log)
424
 
425
- submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
426
- logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
427
- try:
428
- response = requests.post(submit_url, json=submission_data, timeout=120)
429
- response.raise_for_status()
430
- result_data = response.json()
431
- final_status = (
432
- f"Submission Successful!\n"
433
- f"User: {result_data.get('username')}\n"
434
- f"Overall Score: {result_data.get('score', 'N/A')}% "
435
- f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
436
- f"Message: {result_data.get('message', 'No message received.')}"
437
- )
438
- results_df = pd.DataFrame(results_log)
439
- return final_status, results_df
440
- except Exception as e:
441
- logger.error(f"Submission failed: {e}")
442
- results_df = pd.DataFrame(results_log)
443
- return f"Submission Failed: {e}", results_df
444
-
445
- # --- Gradio Interface ---
446
  with gr.Blocks() as demo:
447
- gr.Markdown("# Evolved JARVIS Agent Evaluation")
448
  gr.Markdown(
449
  """
450
  **Instructions:**
@@ -454,23 +514,22 @@ with gr.Blocks() as demo:
454
 
455
  ---
456
  **Disclaimers:**
457
- Uses Hugging Face Inference, SERPAPI, and OpenWeatherMap for GAIA benchmark.
458
  """
459
  )
460
-
461
- gr.LoginButton()
462
-
463
  run_button = gr.Button("Run Evaluation & Submit All Answers")
464
-
465
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
466
- results_table = gr.DataFrame(label="Questions and Answers", wrap=True)
467
 
 
468
  run_button.click(
469
- fn=run_and_submit_all,
470
- outputs=[status_output, results_table]
471
  )
472
 
473
- # --- Main ---
474
  if __name__ == "__main__":
475
  logger.info("\n" + "-"*30 + " App Starting " + "-"*30)
476
  logger.info(f"SPACE_ID: {SPACE_ID}")
 
14
  import gradio as gr
15
  from dotenv import load_dotenv
16
  from huggingface_hub import InferenceClient
17
+ from transformers import AutoTokenizer, AutoModelForCausalLM
18
  from state import JARVISState
19
  from tools import (
20
  search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool,
 
34
  SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
35
  GAIA_API_URL = "https://agents-course-unit4-scoring.hf.space"
36
  GAIA_FILE_URL = f"{GAIA_API_URL}/files/"
37
+ TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
38
+ HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
39
 
40
  # Verify environment variables
41
  if not SPACE_ID:
42
  raise ValueError("SPACE_ID not set")
43
+ if not HF_API_TOKEN:
44
  raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
45
+ if not TOGETHER_API_KEY:
46
+ raise ValueError("TOGETHER_API_KEY not set")
47
  logger.info(f"SPACE_ID: {SPACE_ID}")
48
 
49
+ # Model configuration
50
+ TOGETHER_MODELS = [
51
+ "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
52
+ "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
53
+ ]
54
+ HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
55
+
56
+ # Initialize LLM clients
57
+ def initialize_llm():
58
+ for model in TOGETHER_MODELS:
59
+ try:
60
+ client = InferenceClient(
61
+ model=model,
62
+ api_key=TOGETHER_API_KEY,
63
+ base_url="https://api.together.ai/v1",
64
+ timeout=30
65
+ )
66
+ client.chat.completions.create(
67
+ model=model,
68
+ messages=[{"role": "user", "content": "Test"}],
69
+ max_tokens=10,
70
+ )
71
+ logger.info(f"Initialized Together AI model: {model}")
72
+ return client, "together"
73
+ except Exception as e:
74
+ logger.warning(f"Failed to initialize {model}: {e}")
75
+
76
+ try:
77
+ client = InferenceClient(
78
+ model=HF_MODEL,
79
+ token=HF_API_TOKEN,
80
+ timeout=30
81
+ )
82
+ logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
83
+ return client, "hf_api"
84
+ except Exception as e:
85
+ logger.warning(f"Failed to initialize HF Inference API: {e}")
86
+
87
+ try:
88
+ tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
89
+ model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="mps")
90
+ logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
91
+ return (model, tokenizer), "hf_local"
92
+ except Exception as e:
93
+ logger.error(f"Failed to initialize local HF model: {e}")
94
+ raise Exception("No LLM could be initialized")
95
+
96
+ llm_client, llm_type = initialize_llm()
97
 
98
+ # Initialize embedder
99
  try:
100
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
101
  logger.info("Sentence transformer initialized")
 
103
  logger.error(f"Failed to initialize embedder: {e}")
104
  embedder = None
105
 
106
+ # Download file with local fallback
107
+ async def download_file(task_id: str, ext: str) -> str | None:
 
108
  try:
109
+ url = f"{GAIA_FILE_URL}{task_id}.{ext}"
110
+ async with aiohttp.ClientSession() as session:
111
+ async with session.get(url, timeout=10) as resp:
112
+ logger.info(f"GAIA API test for task {task_id} with .{ext}: HTTP {resp.status}")
113
+ if resp.status == 200:
114
+ os.makedirs("temp", exist_ok=True)
115
+ file_path = f"temp/{task_id}.{ext}"
116
+ with open(file_path, "wb") as f:
117
+ f.write(await resp.read())
118
+ return file_path
 
119
  except Exception as e:
120
+ logger.warning(f"File download failed for {task_id}.{ext}: {e}")
121
+ local_path = f"temp/{task_id}.{ext}"
122
+ if os.path.exists(local_path):
123
+ logger.info(f"Using local file: {local_path}")
124
+ return local_path
125
+ return None
126
+
127
+ # Parse question to select tools
128
+ async def parse_question(state: JARVISState) -> JARVISState:
129
  try:
130
  question = state["question"]
131
  task_id = state["task_id"]
132
  tools_needed = ["search_tool"]
133
 
134
+ if llm_client:
135
  prompt = ChatPromptTemplate.from_messages([
136
  SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
137
  Return JSON list, e.g., ["search_tool", "file_parser_tool"].
138
  Rules:
139
  - Always include "search_tool" unless purely computational.
140
+ - Use "multi_hop_search_tool" for complex queries (over 20 words or requiring multiple steps).
141
  - Use "file_parser_tool" for data, tables, or Excel.
142
  - Use "image_parser_tool" for images/videos.
143
  - Use "calculator_tool" for math calculations.
 
150
  HumanMessage(content=f"Query: {question}")
151
  ])
152
  try:
153
+ if llm_type == "hf_local":
154
+ model, tokenizer = llm_client
155
+ inputs = tokenizer.apply_chat_template(
156
+ [{"role": "system", "content": prompt[0].content}, {"role": "user", "content": prompt[1].content}],
157
+ return_tensors="pt"
158
+ ).to("mps")
159
+ outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
160
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
161
+ tools_needed = json.loads(response.strip())
162
+ else:
163
+ response = llm_client.chat.completions.create(
164
+ model=llm_client.model if llm_type == "together" else HF_MODEL,
165
+ messages=[
166
+ {"role": "system", "content": prompt[0].content},
167
+ {"role": "user", "content": prompt[1].content}
168
+ ],
169
+ max_tokens=512,
170
+ temperature=0.7
171
+ )
172
+ tools_needed = json.loads(response.choices[0].message.content.strip())
173
+
174
  valid_tools = {
175
  "search_tool", "multi_hop_search_tool", "file_parser_tool", "image_parser_tool",
176
  "calculator_tool", "document_retriever_tool", "duckduckgo_search_tool",
 
178
  }
179
  tools_needed = [tool for tool in tools_needed if tool in valid_tools]
180
  except Exception as e:
181
+ logger.warning(f"Task {task_id} tool selection failed: {e}")
182
+ state["error"] = f"Tool selection failed: {str(e)}"
183
 
184
  # Keyword-based fallback
185
  question_lower = question.lower()
186
+ if any(word in question_lower for word in ["image", "video", "picture"]):
187
  tools_needed.append("image_parser_tool")
188
+ if any(word in question_lower for word in ["data", "table", "excel", ".txt", ".csv", ".xlsx"]):
189
  tools_needed.append("file_parser_tool")
190
+ if any(word in question_lower for word in ["calculate", "math", "sum", "average", "total"]):
191
  tools_needed.append("calculator_tool")
192
+ if any(word in question_lower for word in ["document", "pdf", "report", "menu"]):
193
  tools_needed.append("document_retriever_tool")
194
+ if any(word in question_lower for word in ["weather", "temperature"]):
195
  tools_needed.append("weather_info_tool")
196
+ if any(word in question_lower for word in ["model", "huggingface", "dataset"]):
197
  tools_needed.append("hub_stats_tool")
198
+ if any(word in question_lower for word in ["guest", "name", "relation", "person"]):
199
  tools_needed.append("guest_info_retriever_tool")
200
+ if len(question.split()) > 20 or "multiple" in question_lower:
201
  tools_needed.append("multi_hop_search_tool")
202
+ if any(word in question_lower for word in ["search", "wikipedia", "online"]):
203
+ tools_needed.append("duckduckgo_search_tool")
204
+
205
+ # Check file availability
206
+ for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
207
+ file_path = await download_file(task_id, ext)
208
+ if file_path:
209
+ if ext in ["txt", "csv", "xlsx"] and "file_parser_tool" not in tools_needed:
210
+ tools_needed.append("file_parser_tool")
211
+ if ext == "jpg" and "image_parser_tool" not in tools_needed:
212
+ tools_needed.append("image_parser_tool")
213
+ if ext == "pdf" and "document_retriever_tool" not in tools_needed:
214
+ tools_needed.append("document_retriever_tool")
215
+ state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
216
+ break
217
+
218
+ state["tools_needed"] = list(set(tools_needed))
219
  logger.info(f"Task {task_id}: Selected tools: {tools_needed}")
220
  return state
221
  except Exception as e:
222
  logger.error(f"Error parsing task {task_id}: {e}")
223
+ state["error"] = f"Parse question failed: {str(e)}"
224
  state["tools_needed"] = ["search_tool"]
225
  return state
226
 
227
+ # Tool dispatcher
228
  async def tool_dispatcher(state: JARVISState) -> JARVISState:
 
229
  try:
230
  updated_state = state.copy()
231
  file_type = "jpg" if "image" in state["question"].lower() else "txt"
232
+ if any(word in state["question"].lower() for word in ["menu", "report"]):
233
  file_type = "pdf"
234
  elif "data" in state["question"].lower():
235
  file_type = "xlsx"
236
237
  for tool in updated_state["tools_needed"]:
238
  try:
239
  if tool == "search_tool":
240
+ result = search_tool(updated_state["question"])
241
+ updated_state["web_results"].extend([str(r) for r in result])
242
  elif tool == "multi_hop_search_tool":
243
+ result = await multi_hop_search_tool.ainvoke({"query": updated_state["question"], "steps": 3, "llm_client": llm_client, "llm_type": llm_type})
244
+ updated_state["multi_hop_results"].extend([r["content"] for r in result])
245
+ await asyncio.sleep(2)
246
+ elif tool == "file_parser_tool":
247
+ for ext in ["txt", "csv", "xlsx"]:
248
+ file_path = await download_file(updated_state["task_id"], ext)
249
+ if file_path:
250
+ result = file_parser_tool(file_path)
251
+ updated_state["file_results"] = str(result)
252
+ break
253
+ elif tool == "image_parser_tool":
254
+ file_path = await download_file(updated_state["task_id"], "jpg")
255
+ if file_path:
256
+ result = image_parser_tool(file_path)
257
+ updated_state["image_results"] = str(result)
258
  elif tool == "calculator_tool":
259
+ result = calculator_tool(updated_state["question"])
260
  updated_state["calculation_results"] = str(result)
261
+ elif tool == "document_retriever_tool":
262
+ file_path = await download_file(updated_state["task_id"], "pdf")
263
+ if file_path:
264
+ result = document_retriever_tool({"task_id": updated_state["task_id"], "query": updated_state["question"], "file_type": "pdf"})
265
+ updated_state["document_results"] = str(result)
 
 
266
  elif tool == "duckduckgo_search_tool":
267
+ result = duckduckgo_search_tool(updated_state["question"])
268
  updated_state["web_results"].append(str(result))
269
  elif tool == "weather_info_tool":
270
  location = updated_state["question"].split("weather in ")[1].split()[0] if "weather in" in updated_state["question"].lower() else "Unknown"
271
+ result = weather_info_tool({"location": location})
272
  updated_state["web_results"].append(str(result))
273
  elif tool == "hub_stats_tool":
274
  author = updated_state["question"].split("by ")[1].split()[0] if "by" in updated_state["question"].lower() else "Unknown"
275
+ result = hub_stats_tool({"author": author})
276
  updated_state["web_results"].append(str(result))
277
  elif tool == "guest_info_retriever_tool":
278
  query = updated_state["question"].split("about ")[1] if "about" in updated_state["question"].lower() else updated_state["question"]
279
+ result = guest_info_retriever_tool({"query": query})
280
  updated_state["web_results"].append(str(result))
281
+ updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_executed": True}
282
  except Exception as e:
283
  logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {str(e)}")
284
+ updated_state["error"] = f"Tool {tool} failed: {str(e)}"
285
+ updated_state["metadata"] = updated_state.get("metadata", {}) | {f"{tool}_error": str(e)}
286
 
287
  logger.info(f"Task {updated_state['task_id']}: Tool results: {updated_state}")
288
  return updated_state
289
  except Exception as e:
290
  logger.error(f"Tool dispatch failed for task {state['task_id']}: {e}")
291
+ updated_state["error"] = f"Tool dispatch failed: {str(e)}"
292
+ return updated_state
293
 
294
+ # Reasoning
295
  async def reasoning(state: JARVISState) -> Dict[str, Any]:
 
296
  try:
 
 
297
  prompt = ChatPromptTemplate.from_messages([
298
  SystemMessage(content="""Provide ONLY the exact answer (e.g., '90', 'HUE'). For USD, use two decimal places (e.g., '1234.00'). For lists, use comma-separated values (e.g., 'Smith, Lee'). For IOC codes, use three-letter codes (e.g., 'ARG'). No explanations or conversational text."""),
299
+ HumanMessage(content="""Task: {task_id}
300
+ Question: {question}
301
  Web results: {web_results}
302
+ Multi-hop results: {multi_hop_results}
303
  File results: {file_results}
304
  Image results: {image_results}
305
  Calculation results: {calculation_results}
306
  Document results: {document_results}""")
307
  ])
308
+ messages = [
309
+ {"role": "system", "content": prompt[0].content},
310
+ {"role": "user", "content": prompt[1].content.format(
311
+ task_id=state["task_id"],
312
+ question=state["question"],
313
+ web_results="\n".join(state["web_results"]),
314
+ multi_hop_results="\n".join(state["multi_hop_results"]),
315
+ file_results=state["file_results"],
316
+ image_results=state["image_results"],
317
+ calculation_results=state["calculation_results"],
318
+ document_results=state["document_results"]
319
+ )}
320
+ ]
321
+ for attempt in range(3):
322
  try:
323
+ if llm_type == "hf_local":
324
+ model, tokenizer = llm_client
325
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("mps")
326
+ outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
327
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
328
+ else:
329
+ response = llm_client.chat.completions.create(
330
+ model=llm_client.model if llm_type == "together" else HF_MODEL,
331
+ messages=messages,
332
+ max_tokens=512,
333
+ temperature=0.7
334
+ )
335
+ answer = response.choices[0].message.content.strip()
336
+
337
+ # Format answer
338
+ if "USD" in state["question"].lower():
339
+ try:
340
+ answer = f"{float(answer):.2f}"
341
+ except ValueError:
342
+ pass
343
+ if "before and after" in state["question"].lower():
344
+ answer = answer.replace(" and ", ", ")
345
+ if "IOC code" in state["question"].lower():
346
+ answer = answer.upper()[:3]
347
+
348
+ logger.info(f"Task {state['task_id']}: Answer: {answer}")
349
+ return {"answer": answer}
350
+ except Exception as e:
351
+ logger.warning(f"LLM retry {attempt + 1}/3 for task {state['task_id']}: {e}")
352
+ await asyncio.sleep(2)
353
+ state["error"] = "LLM failed after retries"
354
+ return {"answer": "Error: LLM failed after retries"}
355
  except Exception as e:
356
  logger.error(f"Reasoning failed for task {state['task_id']}: {e}")
357
+ state["error"] = f"Reasoning failed: {str(e)}"
358
  return {"answer": f"Error: {str(e)}"}
359
 
360
+ # Router
361
  def router(state: JARVISState) -> str:
 
362
  if state["tools_needed"]:
363
  return "tool_dispatcher"
364
  return "reasoning"
365
 
366
+ # Define StateGraph
367
  workflow = StateGraph(JARVISState)
368
  workflow.add_node("parse", parse_question)
369
  workflow.add_node("tool_dispatcher", tool_dispatcher)
 
381
  workflow.add_edge("reasoning", END)
382
  graph = workflow.compile()
383
 
384
+ # Agent class
385
+ class JARVISAgent:
386
  def __init__(self):
387
+ self.state = JARVISState(
388
+ task_id="",
389
+ question="",
390
+ tools_needed=[],
391
+ web_results=[],
392
+ file_results="",
393
+ image_results="",
394
+ calculation_results="",
395
+ document_results="",
396
+ multi_hop_results=[],
397
+ messages=[],
398
+ answer="",
399
+ results_table=[],
400
+ status_output="",
401
+ error=None,
402
+ metadata={}
403
+ )
404
+ logger.info("JARVISAgent initialized.")
405
 
406
  async def process_question(self, task_id: str, question: str) -> str:
 
407
  state = JARVISState(
408
  task_id=task_id,
409
  question=question,
 
413
  image_results="",
414
  calculation_results="",
415
  document_results="",
416
+ multi_hop_results=[],
417
  messages=[HumanMessage(content=question)],
418
+ answer="",
419
+ results_table=[],
420
+ status_output="",
421
+ error=None,
422
+ metadata={}
423
  )
424
  try:
425
  result = await graph.ainvoke(state)
426
  answer = result["answer"] or "Unknown"
427
+ logger.info(f"Task {task_id}: Final answer: {answer}")
428
+ self.state.results_table.append({"Task ID": task_id, "Question": question, "Answer": answer})
429
+ self.state.metadata = self.state.get("metadata", {}) | {"last_task": task_id, "answer": answer}
430
  return answer
431
  except Exception as e:
432
  logger.error(f"Error processing task {task_id}: {e}")
433
+ self.state.results_table.append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
434
+ self.state.error = f"Task {task_id} failed: {str(e)}"
435
  return f"Error: {str(e)}"
436
  finally:
437
  for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
438
+ file_path = f"temp/{task_id}.{ext}"
439
  if os.path.exists(file_path):
440
  try:
441
  os.remove(file_path)
442
+ logger.info(f"Removed temp file: {file_path}")
443
  except Exception as e:
444
  logger.error(f"Error removing file {file_path}: {e}")
445
 
446
+ async def process_all_questions(self, profile: gr.OAuthProfile | None):
447
+ if not profile:
448
+ logger.error("User not logged in.")
449
+ self.state.status_output = "Please Login to Hugging Face."
450
+ return pd.DataFrame(self.state.results_table), self.state.status_output
 
451
 
452
+ username = f"{profile.username}"
453
+ logger.info(f"User logged in: {username}")
454
+ questions_url = f"{GAIA_API_URL}/questions"
455
+ submit_url = f"{GAIA_API_URL}/submit"
456
+ agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
457
 
458
  try:
459
+ response = requests.get(questions_url, timeout=15)
460
+ response.raise_for_status()
461
+ questions = response.json()
462
+ logger.info(f"Fetched {len(questions)} questions.")
463
+ except Exception as e:
464
+ logger.error(f"Error fetching questions: {e}")
465
+ self.state.status_output = f"Error fetching questions: {e}"
466
+ self.state.error = f"Fetch questions failed: {str(e)}"
467
+ return pd.DataFrame(self.state.results_table), self.state.status_output
468
+
469
+ answers_payload = []
470
+ for item in questions:
471
+ task_id = item.get("task_id")
472
+ question = item.get("question")
473
+ if not task_id or not question:
474
+ logger.warning(f"Skipping invalid item: {item}")
475
+ continue
476
+ answer = await self.process_question(task_id, question)
477
+ answers_payload.append({"task_id": task_id, "submitted_answer": answer})
478
+
479
+ if not answers_payload:
480
+ logger.error("No answers generated.")
481
+ self.state.status_output = "No answers to submit."
482
+ self.state.error = "No answers generated"
483
+ return pd.DataFrame(self.state.results_table), self.state.status_output
484
+
485
+ submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
486
+ try:
487
+ response = requests.post(submit_url, json=submission_data, timeout=120)
488
+ response.raise_for_status()
489
+ result_data = response.json()
490
+ self.state.status_output = (
491
+ f"Submission Successful!\n"
492
+ f"User: {result_data.get('username')}\n"
493
+ f"Overall Score: {result_data.get('score', 'N/A')}% "
494
+ f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
495
+ f"Message: {result_data.get('message', 'No message received.')}"
496
+ )
497
+ self.state.metadata = self.state.get("metadata", {}) | {"submission_score": result_data.get('score', 'N/A')}
498
  except Exception as e:
499
+ logger.error(f"Submission failed: {e}")
500
+ self.state.status_output = f"Submission Failed: {e}"
501
+ self.state.error = f"Submission failed: {str(e)}"
502
 
503
+ return pd.DataFrame(self.state.results_table), self.state.status_output
 
 
504
 
505
+ # Gradio interface
 
506
  with gr.Blocks() as demo:
507
+ gr.Markdown("# Evolved JARVIS GAIA Agent")
508
  gr.Markdown(
509
  """
510
  **Instructions:**
 
514
 
515
  ---
516
  **Disclaimers:**
517
+ Uses Hugging Face Inference, Together AI, SERPAPI, and OpenWeatherMap for GAIA benchmark.
518
  """
519
  )
520
+ with gr.Row():
521
+ gr.LoginButton()
522
+ gr.LogoutButton()
523
  run_button = gr.Button("Run Evaluation & Submit All Answers")
 
524
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
525
+ results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
526
 
527
+ agent = JARVISAgent()
528
  run_button.click(
529
+ fn=agent.process_all_questions,
530
+ outputs=[results_table, status_output]
531
  )
532
 
 
533
  if __name__ == "__main__":
534
  logger.info("\n" + "-"*30 + " App Starting " + "-"*30)
535
  logger.info(f"SPACE_ID: {SPACE_ID}")
requirements.txt CHANGED
@@ -4,6 +4,7 @@ pandas
4
  PyPDF2
5
  easyocr
6
  langchain
 
7
  langchain-community
8
  langgraph
9
  sentence-transformers
@@ -15,4 +16,8 @@ sympy
15
  openpyxl
16
  smolagents
17
  datasets
18
- asyncio
4
  PyPDF2
5
  easyocr
6
  langchain
7
+ langchain-core
8
  langchain-community
9
  langgraph
10
  sentence-transformers
 
16
  openpyxl
17
  smolagents
18
  datasets
19
+ transformers
20
+ asyncio
21
+ serpapi
22
+ duckduckgo-search
23
+ torch
result.txt ADDED
The diff for this file is too large to render. See raw diff
 
state.py CHANGED
@@ -1,7 +1,27 @@
1
- from typing import TypedDict, List
2
- from langchain_core.messages import AnyMessage
3
 
4
  class JARVISState(TypedDict):
5
  task_id: str
6
  question: str
7
  tools_needed: List[str]
@@ -10,5 +30,10 @@ class JARVISState(TypedDict):
10
  image_results: str
11
  calculation_results: str
12
  document_results: str
13
- messages: List[AnyMessage]
14
- answer: str
1
+ from typing import TypedDict, List, Dict, Optional, Any
2
+ from langchain_core.messages import BaseMessage
3
 
4
  class JARVISState(TypedDict):
5
+ """
6
+ State dictionary for the JARVIS GAIA Agent, used with LangGraph to manage task processing.
7
+
8
+ Attributes:
9
+ task_id: Unique identifier for the GAIA task.
10
+ question: The question text to be answered.
11
+ tools_needed: List of tool names to be used for the task.
12
+ web_results: List of web search results (e.g., from SERPAPI, DuckDuckGo).
13
+ file_results: Parsed content from text, CSV, or Excel files.
14
+ image_results: OCR or description results from image files.
15
+ calculation_results: Results from mathematical calculations.
16
+ document_results: Extracted content from PDF documents.
17
+ multi_hop_results: Results from iterative multi-hop searches.
18
+ messages: List of messages for LLM context (e.g., user prompts, system instructions).
19
+ answer: Final answer for the task, formatted for GAIA submission.
20
+ results_table: List of task results for Gradio display (Task ID, Question, Answer).
21
+ status_output: Status message for Gradio output (e.g., submission result).
22
+ error: Optional error message if task processing fails.
23
+ metadata: Optional metadata (e.g., timestamps, tool execution status).
24
+ """
25
  task_id: str
26
  question: str
27
  tools_needed: List[str]
 
30
  image_results: str
31
  calculation_results: str
32
  document_results: str
33
+ multi_hop_results: List[str]
34
+ messages: List[BaseMessage]
35
+ answer: str
36
+ results_table: List[Dict[str, str]]
37
+ status_output: str
38
+ error: Optional[str]
39
+ metadata: Optional[Dict[str, Any]]
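
For orientation, a minimal sketch of how this expanded state is filled in per task (field values are placeholders; app.py builds the same structure inside process_question):

# Illustrative initialization of the new JARVISState fields.
from langchain_core.messages import HumanMessage
from state import JARVISState

state = JARVISState(
    task_id="example-task-id",  # placeholder id
    question="What is the IOC code for Argentina?",
    tools_needed=[],
    web_results=[],
    file_results="",
    image_results="",
    calculation_results="",
    document_results="",
    multi_hop_results=[],
    messages=[HumanMessage(content="What is the IOC code for Argentina?")],
    answer="",
    results_table=[],
    status_output="",
    error=None,
    metadata={},
)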
test.py ADDED
@@ -0,0 +1,7 @@
1
+ import os
2
+ import requests
3
+
4
+
5
+ headers = {"Authorization": f"Bearer {os.getenv('TOGETHER_API_KEY')}"}
6
+ response = requests.get("https://api.together.ai/models", headers=headers)
7
+ print(response.json())
tools/search.py CHANGED
@@ -1,46 +1,104 @@
1
- from langchain_core.tools import tool
2
- import logging
3
- import requests
4
  import os
5
  from typing import List, Dict, Any
6
- from dotenv import load_dotenv
 
7
 
8
- logger = logging.getLogger(__name__)
9
- load_dotenv()
10
 
11
- @tool
12
- async def search_tool(query: str) -> List[Dict[str, Any]]:
13
- """Perform a web search using SERPAPI."""
14
- try:
15
- serpapi_key = os.getenv("SERPAPI_API_KEY")
16
- if not serpapi_key:
17
- logger.error("SERPAPI_API_KEY not set")
18
- return [{"content": "Search unavailable: API key missing", "url": ""}]
19
 
20
- params = {"q": query, "api_key": serpapi_key}
21
- response = requests.get("https://serpapi.com/search", params=params, timeout=10)
22
- response.raise_for_status()
23
- results = response.json().get("organic_results", [])
24
- logger.info(f"Search results for query '{query}': {len(results)} items")
25
- search_results = [{"content": r.get("snippet", ""), "url": r.get("link", "")} for r in results]
26
- return search_results or [{"content": "No search results", "url": ""}]
27
- except Exception as e:
28
- logger.error(f"Error in search_tool: {e}")
29
- return [{"content": f"Search failed: {str(e)}", "url": ""}]
30
-
31
- @tool
32
- async def multi_hop_search_tool(query: str, steps: int = 3) -> List[Dict[str, Any]]:
33
- """Perform a multi-hop search."""
34
- try:
35
- results = []
36
- current_query = query
37
- for step in range(steps):
38
- step_results = await search_tool.invoke({"query": current_query})
39
- results.extend(step_results)
40
- current_query = f"{current_query} more details"
41
- logger.info(f"Multi-hop step {step + 1}: {current_query}")
42
- await asyncio.sleep(2) # Avoid rate limits
43
- return results or [{"content": "No multi-hop results", "url": ""}]
44
- except Exception as e:
45
- logger.error(f"Error in multi_hop_search_tool: {e}")
46
- return [{"content": f"Multi-hop search failed: {str(e)}", "url": ""}]
1
  import os
2
+ from serpapi import GoogleSearch
3
+ from langchain.tools import Tool
4
+ import asyncio
5
  from typing import List, Dict, Any
6
+ from langchain_core.prompts import ChatPromptTemplate
7
+ from langchain_core.messages import SystemMessage, HumanMessage
8
 
9
+ def search_tool(query: str) -> List[str]:
10
+ """
11
+ Perform a web search using SERPAPI with retries.
12
+
13
+ Args:
14
+ query: Search query string.
15
+
16
+ Returns:
17
+ List of search result snippets.
18
+
19
+ Raises:
20
+ Exception: If search fails after retries.
21
+ """
22
+ params = {
23
+ "q": query,
24
+ "api_key": os.getenv("SERPAPI_API_KEY"),
25
+ "num": 5,
26
+ }
27
+
28
+ for attempt in range(3):
29
+ try:
30
+ search = GoogleSearch(params, timeout=30)
31
+ results = search.get_dict()
32
+ organic_results = results.get("organic_results", [])
33
+ return [r.get("snippet", "") for r in organic_results]
34
+ except Exception as e:
35
+ print(f"INFO - SERPAPI retry {attempt + 1}/3 due to: {e}")
36
+ asyncio.sleep(2)
37
+
38
+ raise Exception("SERPAPI failed after retries")
39
 
40
+ async def multi_hop_search_tool(query: str, steps: int = 3, llm_client: Any = None, llm_type: str = None) -> List[Dict[str, str]]:
41
+ """
42
+ Perform iterative web searches for complex queries, refining the query using an LLM.
43
+
44
+ Args:
45
+ query: Initial search query.
46
+ steps: Number of search iterations.
47
+ llm_client: LLM client for query refinement.
48
+ llm_type: Type of LLM client ("together", "hf_api", or "hf_local").
49
+
50
+ Returns:
51
+ List of dictionaries containing search result content.
52
+ """
53
+ results = []
54
+ current_query = query
55
+
56
+ for step in range(steps):
57
+ try:
58
+ # Perform search
59
+ search_results = search_tool(current_query)
60
+ results.extend([{"content": str(r)} for r in search_results])
61
+
62
+ # Refine query using LLM if available
63
+ if llm_client and step < steps - 1:
64
+ prompt = ChatPromptTemplate.from_messages([
65
+ SystemMessage(content="""Refine the following query to dig deeper into the topic, focusing on missing details or related aspects. Return ONLY the refined query as plain text, no explanations."""),
66
+ HumanMessage(content=f"Original query: {current_query}\nPrevious results: {json.dumps(search_results[:2], indent=2)}")
67
+ ])
68
+ messages = [
69
+ {"role": "system", "content": prompt[0].content},
70
+ {"role": "user", "content": prompt[1].content}
71
+ ]
72
+
73
+ try:
74
+ if llm_type == "hf_local":
75
+ model, tokenizer = llm_client
76
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("mps")
77
+ outputs = model.generate(inputs, max_new_tokens=100, temperature=0.7)
78
+ refined_query = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
79
+ else:
80
+ response = llm_client.chat.completions.create(
81
+ model=llm_client.model if llm_type == "together" else "meta-llama/Llama-3.2-1B-Instruct",
82
+ messages=messages,
83
+ max_tokens=100,
84
+ temperature=0.7
85
+ )
86
+ refined_query = response.choices[0].message.content.strip()
87
+
88
+ current_query = refined_query if refined_query else f"more details on {current_query}"
89
+ except Exception as e:
90
+ print(f"INFO - Query refinement failed at step {step + 1}: {e}")
91
+ current_query = f"more details on {current_query}"
92
+
93
+ await asyncio.sleep(1) # Rate limit
94
+ except Exception as e:
95
+ print(f"INFO - Multi-hop search step {step + 1} failed: {e}")
96
+ break
97
+
98
+ return results
99
 
100
+ multi_hop_search_tool = Tool.from_function(
101
+ func=multi_hop_search_tool,
102
+ name="multi_hop_search_tool",
103
+ description="Performs iterative web searches for complex queries, refining the query with an LLM."
104
+ )
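
As a quick usage note: search_tool is now a plain synchronous function, so a direct call (assuming SERPAPI_API_KEY is set in the environment; the query string is just an example) looks like:

# Returns up to 5 result snippets from SERPAPI, retrying up to 3 times.
snippets = search_tool("first woman to win two Nobel Prizes")
for snippet in snippets:
    print(snippet)

multi_hop_search_tool, by contrast, stays asynchronous and is wrapped as a LangChain Tool, which is why app.py awaits it via ainvoke.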