onisj's picture
Rewrite app.py and search.py with multi-hop LLM refinement
c6951f4
raw
history blame
4.27 kB
import asyncio
import json
import os
import time
from typing import Any, Dict, List

from langchain.tools import Tool
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from serpapi import GoogleSearch
def search_tool(query: str) -> List[str]:
    """
    Perform a web search using SERPAPI with retries.

    Args:
        query: Search query string.

    Returns:
        List of search result snippets (empty strings for results with no
        snippet field).

    Raises:
        Exception: If search fails after 3 attempts; chained from the last
            underlying error.
    """
    params = {
        "q": query,
        "api_key": os.getenv("SERPAPI_API_KEY"),
        "num": 5,
    }
    last_error = None
    for attempt in range(3):
        try:
            search = GoogleSearch(params, timeout=30)
            results = search.get_dict()
            organic_results = results.get("organic_results", [])
            return [r.get("snippet", "") for r in organic_results]
        except Exception as e:
            last_error = e
            print(f"INFO - SERPAPI retry {attempt + 1}/3 due to: {e}")
            # BUG FIX: the original called asyncio.sleep(2) in this *sync*
            # function, which returns an un-awaited coroutine and never
            # actually pauses. Use a real blocking sleep, and only between
            # attempts (no point delaying after the final failure).
            if attempt < 2:
                time.sleep(2)
    # Chain the last underlying error so callers can see the root cause.
    raise Exception("SERPAPI failed after retries") from last_error
async def multi_hop_search_tool(query: str, steps: int = 3, llm_client: Any = None, llm_type: str = None) -> List[Dict[str, str]]:
    """
    Perform iterative web searches for complex queries, refining the query using an LLM.

    Args:
        query: Initial search query.
        steps: Number of search iterations.
        llm_client: LLM client for query refinement; for llm_type "hf_local"
            this is a (model, tokenizer) pair, otherwise an OpenAI-style
            chat-completions client.
        llm_type: Type of LLM client ("together", "hf_api", or "hf_local").

    Returns:
        List of dictionaries containing search result content, accumulated
        across all hops.
    """
    results: List[Dict[str, str]] = []
    current_query = query
    for step in range(steps):
        try:
            # Perform the search for this hop and collect its snippets.
            search_results = search_tool(current_query)
            results.extend([{"content": str(r)} for r in search_results])
            # Refine the query with the LLM on every hop except the last.
            if llm_client and step < steps - 1:
                # BUG FIX: the original built a ChatPromptTemplate only to
                # index it like a list (prompt[0].content), and called
                # json.dumps without importing json — refinement always
                # failed with a NameError that the broad except swallowed.
                # Build the plain chat messages directly instead.
                messages = [
                    {"role": "system", "content": """Refine the following query to dig deeper into the topic, focusing on missing details or related aspects. Return ONLY the refined query as plain text, no explanations."""},
                    {"role": "user", "content": f"Original query: {current_query}\nPrevious results: {json.dumps(search_results[:2], indent=2)}"},
                ]
                try:
                    if llm_type == "hf_local":
                        # Local HF model: (model, tokenizer) pair on Apple MPS.
                        model, tokenizer = llm_client
                        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("mps")
                        outputs = model.generate(inputs, max_new_tokens=100, temperature=0.7)
                        refined_query = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
                    else:
                        # OpenAI-compatible chat API (Together or HF Inference).
                        response = llm_client.chat.completions.create(
                            model=llm_client.model if llm_type == "together" else "meta-llama/Llama-3.2-1B-Instruct",
                            messages=messages,
                            max_tokens=100,
                            temperature=0.7
                        )
                        refined_query = response.choices[0].message.content.strip()
                    # Fall back to a generic deepening query if the LLM
                    # returned an empty string.
                    current_query = refined_query if refined_query else f"more details on {current_query}"
                except Exception as e:
                    print(f"INFO - Query refinement failed at step {step + 1}: {e}")
                    current_query = f"more details on {current_query}"
                await asyncio.sleep(1)  # Rate limit between LLM calls
        except Exception as e:
            print(f"INFO - Multi-hop search step {step + 1} failed: {e}")
            break
    return results
# Expose the coroutine as a LangChain Tool.
# BUG FIX: the original passed the async function via `func=`, so a
# synchronous tool invocation would return an un-awaited coroutine object
# instead of results. Async callables must go through `coroutine=`.
_multi_hop_search_impl = multi_hop_search_tool
multi_hop_search_tool = Tool.from_function(
    func=None,
    coroutine=_multi_hop_search_impl,
    name="multi_hop_search_tool",
    description="Performs iterative web searches for complex queries, refining the query with an LLM."
)