File size: 2,657 Bytes
3758a5b
bc2a697
9608ecf
bc2a697
3758a5b
9608ecf
3758a5b
 
 
 
 
9608ecf
bc2a697
9608ecf
bc2a697
9608ecf
bc2a697
3758a5b
bc2a697
3758a5b
9608ecf
3758a5b
bc2a697
 
 
 
3758a5b
bc2a697
3758a5b
bc2a697
3758a5b
bc2a697
3758a5b
 
bc2a697
3758a5b
bc2a697
3758a5b
 
bc2a697
3758a5b
 
bc2a697
3758a5b
 
bc2a697
 
9608ecf
bc2a697
 
3758a5b
bc2a697
 
3758a5b
bc2a697
 
 
 
 
 
 
3758a5b
bc2a697
 
 
 
 
3758a5b
bc2a697
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
from typing import Dict, List
from smolagents.tools import Tool
from langchain_community.tools import TavilySearchResults
import logging

# Setup logging
# NOTE: basicConfig runs at import time and configures the root logger at INFO
# level; per the stdlib docs it does nothing if the root logger already has
# handlers, so an application that configured logging earlier keeps its setup.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class WebSearchTool(Tool):
    """Tool that runs a web search through Tavily.

    The search is delegated to LangChain's ``TavilySearchResults``. ``forward``
    reports every outcome (results, no results, failure) as a dict with the
    single key ``"web_results"`` instead of raising.
    """

    name = "web_search"
    description = "Search the web for a query using Tavily API and return the top 3 most relevant results."
    inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
    output_type = "object"

    def __init__(self, max_results: int = 3, **kwargs):
        """
        Initialize the web search tool with Tavily API.

        Args:
            max_results (int): Maximum number of results requested from Tavily
                per search. Defaults to 3.
            **kwargs: Accepted for call-site compatibility; currently unused.

        Raises:
            ValueError: If the ``TAVILY_API_KEY`` environment variable is unset
                or empty.
            Exception: Whatever ``TavilySearchResults`` raises during
                construction is logged and re-raised unchanged.
        """
        super().__init__()

        # Verify Tavily API key is set
        if not os.getenv('TAVILY_API_KEY'):
            logger.error("TAVILY_API_KEY environment variable not set")
            raise ValueError("TAVILY_API_KEY environment variable must be set")

        # Keep the advertised result count in sync with the configured limit.
        # With the default of 3 this equals the class-level description.
        self.description = (
            "Search the web for a query using Tavily API and return the top "
            f"{max_results} most relevant results."
        )

        # Initialize Tavily search
        try:
            self.search_tool = TavilySearchResults(max_results=max_results)
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Failed to initialize Tavily search: %s", e)
            raise

    def forward(self, query: str) -> Dict[str, str]:
        """
        Search Tavily for a query and return the formatted results.

        The number of results is bounded by the ``max_results`` value passed
        to the constructor.

        Args:
            query (str): The search query to perform. Empty or whitespace-only
                queries are rejected without calling the API.

        Returns:
            Dict[str, str]: ``{"web_results": text}`` where ``text`` is either
            the results joined by ``---`` separators, ``"No search results
            found"``, ``"Unexpected search results format"``, or a
            ``"Search failed: ..."`` message describing the error.
        """
        try:
            # Reject blank queries up front; the message matches what callers
            # previously received for an empty query.
            if not query or (isinstance(query, str) and not query.strip()):
                error_msg = "Search failed: Search query cannot be empty"
                logger.error(error_msg)
                return {"web_results": error_msg}

            # Perform search using Tavily
            search_results = self.search_tool.invoke({"query": query})

            if not search_results:
                return {"web_results": "No search results found"}

            if not isinstance(search_results, list):
                logger.warning(f"Unexpected search results format: {type(search_results)}")
                return {"web_results": "Unexpected search results format"}

            # Format the results.
            # NOTE(review): the opening tag is self-closing (`/>`) yet a
            # `</Document>` follows; kept byte-for-byte because downstream
            # consumers may rely on the exact text.
            formatted_results = "\n\n---\n\n".join(
                f'<Document source="{result.get("url", "")}" page=""/>\n{result.get("content", "")}\n</Document>'
                for result in search_results
            )
            return {"web_results": formatted_results}

        except Exception as e:
            # Best-effort contract: report the failure in the result instead
            # of propagating it, but keep the traceback in the log.
            error_msg = f"Search failed: {e}"
            logger.exception(error_msg)
            return {"web_results": error_msg}