Spaces:
Running
Running
""" | |
Enhanced tools for GAIA Agent including Wikipedia search, file processing, and web browsing. | |
""" | |
import os | |
import json | |
import requests | |
import wikipedia | |
from typing import Optional, Dict, Any, List | |
from langchain.tools import Tool | |
from langchain_community.tools import DuckDuckGoSearchRun | |
import pandas as pd | |
from PIL import Image | |
import PyPDF2 | |
from bs4 import BeautifulSoup | |
from io import BytesIO | |
class WikipediaSearchTool:
    """Tool for searching Wikipedia with better error handling and content extraction."""

    def __init__(self):
        # The ``wikipedia`` package keeps its language setting as module-level
        # state, so this call affects every user of that package in the process.
        wikipedia.set_lang("en")

    @staticmethod
    def _fetch_entry(title: str) -> Dict[str, str]:
        """Return title, URL and a 3-sentence summary for an exact page title.

        ``auto_suggest=False`` is passed because ``title`` is already an exact
        title (from ``wikipedia.search`` or a disambiguation option); letting
        the library "suggest" a replacement title may resolve to a different,
        wrong page.

        Raises:
            wikipedia.exceptions.DisambiguationError: if ``title`` is ambiguous.
            wikipedia.exceptions.PageError: if the page cannot be found.
        """
        page = wikipedia.page(title, auto_suggest=False)
        summary = wikipedia.summary(title, sentences=3, auto_suggest=False)
        return {"title": page.title, "url": page.url, "summary": summary}

    def search_wikipedia(self, query: str, max_results: int = 3) -> str:
        """
        Search Wikipedia for information and return a summary.

        Args:
            query: Search query string
            max_results: Maximum number of results to return

        Returns:
            Formatted string with search results, or a message explaining why
            nothing was returned (no hits, every lookup failed, or an error).
        """
        try:
            search_results = wikipedia.search(query, results=max_results)
            if not search_results:
                return f"No Wikipedia articles found for query: '{query}'"

            results = []
            seen_titles = set()
            for title in search_results[:max_results]:
                try:
                    entry = self._fetch_entry(title)
                except wikipedia.exceptions.DisambiguationError as e:
                    # Ambiguous title: best-effort fallback to the first option.
                    if not e.options:
                        continue
                    try:
                        entry = self._fetch_entry(e.options[0])
                    except Exception:
                        continue
                except Exception:
                    # PageError, network hiccups, etc.: skip this hit only.
                    continue
                # A disambiguation fallback may land on a page already collected.
                if entry["title"] in seen_titles:
                    continue
                seen_titles.add(entry["title"])
                results.append(entry)

            if not results:
                return f"Could not retrieve information for query: '{query}'"

            parts = [f"Wikipedia search results for '{query}':\n\n"]
            for i, result in enumerate(results, 1):
                parts.append(
                    f"{i}. **{result['title']}**\n"
                    f" URL: {result['url']}\n"
                    f" Summary: {result['summary']}\n\n"
                )
            return "".join(parts)
        except Exception as e:
            return f"Error searching Wikipedia: {str(e)}"
class FileProcessorTool:
    """Tool for processing various file formats (PDF, Excel, images, etc.)."""

    # Maximum number of characters of extracted PDF/text content returned;
    # longer content is cut to this length and "..." is appended.
    max_chars = 2000

    def process_file(self, file_path: str) -> str:
        """
        Process different file types and extract content/information.

        Dispatch is by file extension, compared case-insensitively.

        Args:
            file_path: Path to the file to process

        Returns:
            Extracted content or file information, or an error/"unsupported"
            message (this method does not raise for processing failures).
        """
        try:
            if not os.path.exists(file_path):
                return f"File not found: {file_path}"
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.pdf':
                return self._process_pdf(file_path)
            elif file_extension in ('.xlsx', '.xls', '.csv'):
                return self._process_spreadsheet(file_path)
            elif file_extension in ('.jpg', '.jpeg', '.png', '.gif', '.bmp'):
                return self._process_image(file_path)
            elif file_extension == '.txt':
                return self._process_text(file_path)
            else:
                return f"Unsupported file type: {file_extension}"
        except Exception as e:
            return f"Error processing file {file_path}: {str(e)}"

    def _truncate(self, header: str, content: str) -> str:
        """Join header and content, cutting content to ``max_chars`` (+ "...")."""
        if len(content) > self.max_chars:
            return f"{header}\n{content[:self.max_chars]}..."
        return f"{header}\n{content}"

    def _process_pdf(self, file_path: str) -> str:
        """Extract text from PDF file."""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # extract_text() can yield None for pages without a text layer
                # in some PyPDF2 versions; treat that as an empty page.
                text_content = "".join(
                    (page.extract_text() or "") + "\n" for page in pdf_reader.pages
                )
            return self._truncate(f"PDF Content from {file_path}:", text_content)
        except Exception as e:
            return f"Error reading PDF: {str(e)}"

    def _process_spreadsheet(self, file_path: str) -> str:
        """Process Excel/CSV files and summarize shape, head, stats and totals."""
        try:
            # Case-insensitive, matching how process_file dispatches here.
            if os.path.splitext(file_path)[1].lower() == '.csv':
                df = pd.read_csv(file_path)
            else:
                df = pd.read_excel(file_path)

            parts = [
                f"Spreadsheet Analysis for {file_path}:\n",
                f"Shape: {df.shape[0]} rows, {df.shape[1]} columns\n",
                # Column labels are not always strings (e.g. numeric Excel headers).
                f"Columns: {', '.join(str(col) for col in df.columns)}\n\n",
                "First 5 rows:\n",
                df.head().to_string() + "\n\n",
            ]

            numeric_cols = df.select_dtypes(include=['number']).columns
            if len(numeric_cols) > 0:
                parts.append("Numeric column statistics:\n")
                parts.append(df[numeric_cols].describe().to_string() + "\n\n")
                parts.append("Column totals:\n")
                parts.extend(f"{col}: {df[col].sum()}\n" for col in numeric_cols)
            return "".join(parts)
        except Exception as e:
            return f"Error reading spreadsheet: {str(e)}"

    def _process_image(self, file_path: str) -> str:
        """Process image files and return basic information (no OCR)."""
        try:
            with Image.open(file_path) as img:
                info = f"Image Analysis for {file_path}:\n"
                info += f"Size: {img.size[0]} x {img.size[1]} pixels\n"
                info += f"Mode: {img.mode}\n"
                info += f"Format: {img.format}\n"
                # Note: For GAIA tasks, you might need OCR or more advanced image analysis
                info += "\nNote: For text extraction from images, OCR tools would be needed.\n"
                return info
        except Exception as e:
            return f"Error processing image: {str(e)}"

    def _process_text(self, file_path: str) -> str:
        """Read a UTF-8 text file, truncated to ``max_chars`` characters."""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()
            return self._truncate(f"Text file content from {file_path}:", content)
        except Exception as e:
            return f"Error reading text file: {str(e)}"
class EnhancedWebSearchTool:
    """Web search tool backed by LangChain's DuckDuckGo integration."""

    def __init__(self):
        # One search runner, reused for every query.
        self.search_tool = DuckDuckGoSearchRun()

    def search_web(self, query: str, max_results: int = 5) -> str:
        """
        Run a DuckDuckGo web search and label the raw result text.

        Args:
            query: Search query string
            max_results: Accepted for interface compatibility; it is not
                passed to the underlying search runner and has no effect.

        Returns:
            A header line naming the query followed by the runner's result
            text, or an error message if the search fails.
        """
        try:
            raw_text = self.search_tool.invoke(query)
            return f"Web search results for '{query}':\n\n" + raw_text
        except Exception as exc:
            return f"Error performing web search: {str(exc)}"
class CalculationTool:
    """Tool for performing calculations and data analysis."""

    # Characters permitted in an expression; checked before parsing.
    _ALLOWED_CHARS = frozenset('0123456789+-*/().% ')
    # Cap on the estimated bit length of an integer power result. Every
    # character of "9**9**9**9" is whitelisted, and evaluating it would hang.
    _MAX_POWER_BITS = 100_000

    def calculate(self, expression: str) -> str:
        """
        Safely evaluate mathematical expressions.

        The expression is parsed with ``ast`` and only numeric literals,
        unary ``+``/``-`` and the binary operators ``+ - * / // % **`` are
        evaluated; ``eval`` is never called on the input.

        Args:
            expression: Mathematical expression to evaluate

        Returns:
            "Calculation result: <expression> = <value>" on success,
            "Invalid characters in expression: ..." if a character outside
            the whitelist appears, otherwise "Error in calculation: ...".
        """
        import ast
        import operator

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

        def evaluate(node):
            # Recursively evaluate the restricted expression tree.
            if isinstance(node, ast.Constant) and type(node.value) in (int, float):
                return node.value
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](evaluate(node.operand))
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                left = evaluate(node.left)
                right = evaluate(node.right)
                if (
                    isinstance(node.op, ast.Pow)
                    and isinstance(left, int)
                    and isinstance(right, int)
                    and right > 0
                    and abs(left) > 1
                    and abs(left).bit_length() * right > self._MAX_POWER_BITS
                ):
                    raise ValueError("Exponent too large")
                return binary_ops[type(node.op)](left, right)
            raise ValueError(f"Unsupported expression element: {type(node).__name__}")

        try:
            if not set(expression) <= self._ALLOWED_CHARS:
                return f"Invalid characters in expression: {expression}"
            tree = ast.parse(expression, mode="eval")
            result = evaluate(tree.body)
            return f"Calculation result: {expression} = {result}"
        except Exception as e:
            return f"Error in calculation: {str(e)}"
def create_gaia_tools() -> List[Tool]:
    """Build the LangChain ``Tool`` objects used for GAIA benchmark tasks.

    Returns:
        Four tools, always in this order: wikipedia_search, file_processor,
        web_search, calculator.
    """
    # (name, callable, description) for each tool; the backing helper objects
    # are instantiated here in the same order the tools are listed.
    tool_specs = [
        (
            "wikipedia_search",
            WikipediaSearchTool().search_wikipedia,
            "Search Wikipedia for information about any topic. Use this for factual information, historical data, scientific concepts, etc. Input should be a clear search query.",
        ),
        (
            "file_processor",
            FileProcessorTool().process_file,
            "Process and analyze files (PDF, Excel, images, text). Input should be the file path. Returns file content, data analysis, or file information.",
        ),
        (
            "web_search",
            EnhancedWebSearchTool().search_web,
            "Search the web for current information, news, or specific websites. Use this when you need up-to-date information not available in Wikipedia.",
        ),
        (
            "calculator",
            CalculationTool().calculate,
            "Perform mathematical calculations. Input should be a mathematical expression using +, -, *, /, (), %. Example: '(100 + 50) * 0.15'",
        ),
    ]
    return [
        Tool(name=tool_name, func=handler, description=summary)
        for tool_name, handler, summary in tool_specs
    ]
if __name__ == "__main__":
    # Smoke-test a few tools against live services and print their output
    # (long outputs are truncated to 500 characters).
    toolbox = {tool.name: tool for tool in create_gaia_tools()}

    smoke_tests = [
        ("Wikipedia test:", "wikipedia_search", "Artificial Intelligence", True),
        ("Web search test:", "web_search", "latest AI news 2024", True),
        ("Calculator test:", "calculator", "(100 + 50) * 2", False),
    ]
    for label, tool_name, tool_input, truncate in smoke_tests:
        output = toolbox[tool_name].func(tool_input)
        print(label)
        print(output[:500] + "...\n" if truncate else output)