import logging
import os
from typing import Any, Dict, List

import requests
from langchain_core.tools import tool
from langchain_huggingface import HuggingFacePipeline
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)
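# Note: this module only creates a named logger. Configure handlers in the
# host application, e.g. with logging.basicConfig(level=logging.INFO).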

# Initialize embedding model (free, open-source)
try:
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
except Exception as e:
    logger.error(f"Failed to initialize embedding model: {e}")
    embedder = None

# Global LLM instance, set by initialize_search_tools()
search_llm = None

def initialize_search_tools(llm: HuggingFacePipeline) -> None:
    """Initialize search tools with the provided LLM."""
    global search_llm
    search_llm = llm
    logger.info("Search tools initialized with HuggingFace LLM")
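
# Example (sketch) of wiring up the tools. The model name and generation
# settings below are assumptions; any transformers text-generation model
# wrapped in HuggingFacePipeline should work:
#
#     from transformers import pipeline
#
#     hf_pipe = pipeline("text-generation", model="gpt2", max_new_tokens=64)
#     initialize_search_tools(HuggingFacePipeline(pipeline=hf_pipe))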

@tool
async def search_tool(query: str) -> List[Dict[str, Any]]:
    """Perform a web search using the query."""
    try:
        if not search_llm:
            logger.warning("Search LLM not initialized")
            return [{"content": "Search unavailable", "url": ""}]

        # Refine the query using the LLM. HuggingFacePipeline returns a plain
        # string, while chat models return a message with .content; handle both.
        prompt = f"Refine this search query for better results: {query}"
        response = await search_llm.ainvoke(prompt)
        refined_query = getattr(response, "content", response).strip()

        # Check for a SerpAPI key (free tier available)
        serpapi_key = os.getenv("SERPAPI_API_KEY")
        if serpapi_key:
            try:
                params = {"q": refined_query, "api_key": serpapi_key}
                response = requests.get("https://serpapi.com/search", params=params, timeout=10)
                response.raise_for_status()
                results = response.json().get("organic_results", [])
                return [{"content": r.get("snippet", ""), "url": r.get("link", "")} for r in results]
            except Exception as e:
                logger.warning(f"SerpAPI failed: {e}, falling back to mock search")

        # Mock search if no API key is set or the API call fails
        if embedder:
            # The embedding is computed to mirror a real semantic-search path,
            # but the mock results below do not actually use it.
            query_embedding = embedder.encode(refined_query)
            results = [
                {"content": f"Mock result for {refined_query}", "url": "https://example.com"},
                {"content": f"Another mock result for {refined_query}", "url": "https://example.org"},
            ]
        else:
            results = [{"content": "Embedding model unavailable", "url": ""}]

        logger.info(f"Search results for query '{refined_query}': {len(results)} items")
        return results
    except Exception as e:
        logger.error(f"Error in search_tool: {e}")
        return [{"content": f"Search failed: {str(e)}", "url": ""}]
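
# Usage sketch: @tool wraps the coroutine in a LangChain StructuredTool, so
# call it through .ainvoke with a dict of arguments rather than directly:
#
#     results = await search_tool.ainvoke({"query": "open-source LLM agents"})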

@tool
async def multi_hop_search_tool(query: str, steps: int = 3) -> List[Dict[str, Any]]:
    """Perform a multi-hop search by iteratively refining the query."""
    try:
        if not search_llm:
            logger.warning("Search LLM not initialized")
            return [{"content": "Multi-hop search unavailable", "url": ""}]

        results = []
        current_query = query
        for step in range(steps):
            prompt = f"Based on the query '{current_query}', generate a follow-up question to deepen the search."
            response = await search_llm.ainvoke(prompt)
            next_query = getattr(response, "content", response).strip()
            # search_tool is a StructuredTool, so invoke it asynchronously
            step_results = await search_tool.ainvoke({"query": next_query})
            results.extend(step_results)
            current_query = next_query
            logger.info(f"Multi-hop step {step + 1}: {next_query}")
        return results
    except Exception as e:
        logger.error(f"Error in multi_hop_search_tool: {e}")
        return [{"content": f"Multi-hop search failed: {str(e)}", "url": ""}]
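
# Minimal smoke test, a sketch assuming this file is run as a script. Without
# an initialized LLM, both tools return their "unavailable" placeholders
# rather than raising, so this exercises the fallback paths only.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        print(await search_tool.ainvoke({"query": "test"}))
        print(await multi_hop_search_tool.ainvoke({"query": "test", "steps": 2}))

    asyncio.run(_demo())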