# Author: JoachimVC
# Commit c922f8b: Upload GAIA agent implementation files for assessment
"""
arXiv search tool for the GAIA agent.
This module provides a tool for searching academic papers on arXiv.org.
It uses the arxiv Python package to access the arXiv library and retrieve
paper information.
The tool handles API responses and errors, and formats results in a
consistent way for use by the GAIA agent.
"""
import logging
import time
import traceback
from typing import Dict, Any, List, Optional, Union, Tuple
from datetime import datetime
try:
import arxiv
except ImportError:
arxiv = None
from src.gaia.agent.config import get_tool_config
# Module logger, registered under a fixed dotted name rather than __name__.
logger = logging.getLogger(name="gaia_agent.tools.arxiv")
class ArxivSearchTool:
    """Tool for searching academic papers on arXiv.

    Wraps the ``arxiv`` package's ``Client``/``Search`` API and converts every
    result into a plain dictionary with the keys ``title``, ``authors``
    (comma-separated author names), ``summary``, ``published`` (``YYYY-MM-DD``
    or ``"Unknown"``), ``url`` (the result's ``entry_id``), ``pdf_url``,
    ``categories``, ``comment``, ``journal_ref`` and ``doi``.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the arXiv search tool.

        Args:
            config: Optional configuration dictionary. When omitted (or empty),
                the ``"arxiv"`` section of the global tool config is used.
                Recognised key: ``max_results`` (defaults to 3).
        """
        self.config = config or get_tool_config().get("arxiv", {})
        self.max_results = self.config.get("max_results", 3)
        if arxiv is None:
            # Only warn here; the query methods raise ImportError when called.
            logger.warning("arXiv package not installed. Install with: pip install arxiv")

    @staticmethod
    def _format_paper(paper: Any) -> Dict[str, Any]:
        """Convert one ``arxiv`` result object into this tool's result dictionary."""
        published = paper.published
        return {
            "title": paper.title,
            "authors": ", ".join(author.name for author in paper.authors),
            "summary": paper.summary,
            "published": published.strftime("%Y-%m-%d") if published else "Unknown",
            "url": paper.entry_id,
            "pdf_url": paper.pdf_url,
            "categories": paper.categories,
            "comment": paper.comment,
            "journal_ref": paper.journal_ref,
            "doi": paper.doi,
        }

    def search(
        self,
        query: str,
        max_results: Optional[int] = None,
        sort_by: Any = "relevance",
    ) -> List[Dict[str, Any]]:
        """
        Search arXiv for papers matching the query.

        Args:
            query: The search query (arXiv query syntax, e.g. ``"cat:cs.AI"``).
            max_results: Maximum number of results to return (overrides config).
                ``None`` means "use the configured default".
            sort_by: Sort criterion, given either as an ``arxiv.SortCriterion``
                member or as its string value (``"relevance"``,
                ``"lastUpdatedDate"``, ``"submittedDate"``). Defaults to
                relevance.

        Returns:
            List of paper information dictionaries.

        Raises:
            ImportError: If the ``arxiv`` package is not installed.
            Exception: If an error occurs during the search. The original
                error is chained as ``__cause__``.
        """
        if arxiv is None:
            raise ImportError("arXiv package not installed. Install with: pip install arxiv")
        # An explicit override (even a falsy one) wins over the config value.
        if max_results is None:
            max_results = self.max_results
        try:
            search = arxiv.Search(
                query=query,
                max_results=max_results,
                # Enum lookup accepts a member or its string value; an invalid
                # value raises ValueError, which is wrapped below.
                sort_by=arxiv.SortCriterion(sort_by),
            )
            client = arxiv.Client()
            # client.results() is lazy, so iterating it here keeps network
            # errors inside this try block.
            return [self._format_paper(paper) for paper in client.results(search)]
        except Exception as e:
            logger.exception(f"Error searching arXiv: {str(e)}")
            raise Exception(f"arXiv search failed: {str(e)}") from e

    def get_paper_by_id(self, paper_id: str) -> Dict[str, Any]:
        """
        Get a specific paper by its arXiv ID.

        Args:
            paper_id: The arXiv ID (e.g., "2103.00020")

        Returns:
            Dictionary containing paper information.

        Raises:
            ImportError: If the ``arxiv`` package is not installed.
            Exception: If no paper matches the ID or the retrieval fails. The
                original error is chained as ``__cause__``.
        """
        if arxiv is None:
            raise ImportError("arXiv package not installed. Install with: pip install arxiv")
        try:
            client = arxiv.Client()
            search = arxiv.Search(id_list=[paper_id])
            # Only the first match is needed; don't materialise the rest.
            paper = next(client.results(search), None)
            if paper is None:
                raise ValueError(f"No paper found with ID: {paper_id}")
            return self._format_paper(paper)
        except Exception as e:
            logger.exception(f"Error getting arXiv paper {paper_id}: {str(e)}")
            raise Exception(f"arXiv paper retrieval failed: {str(e)}") from e

    def search_by_category(self, category: str, max_results: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Search for recent papers in a specific category.

        Args:
            category: The arXiv category (e.g., "cs.AI")
            max_results: Maximum number of results to return (overrides config)

        Returns:
            List of paper information dictionaries, ordered by submission date.

        Raises:
            ImportError: If the ``arxiv`` package is not installed.
            Exception: If an error occurs during the search.
        """
        # Relevance ranking does not surface *recent* papers for a bare
        # "cat:" query, so sort by submission date instead. arxiv.Search
        # defaults to descending order, which puts the newest papers first.
        return self.search(f"cat:{category}", max_results, sort_by="submittedDate")
def create_arxiv_search() -> ArxivSearchTool:
    """Return a new ArxivSearchTool configured from the global tool config.

    No explicit config is passed, so the tool reads the ``"arxiv"`` section
    of ``get_tool_config()``.
    """
    tool = ArxivSearchTool()
    return tool
if __name__ == "__main__":
    # Manual smoke test: run one sample query and print a short listing.
    logging.basicConfig(level=logging.INFO)
    tool = create_arxiv_search()
    try:
        papers = tool.search("large language models")
        for position, entry in enumerate(papers, start=1):
            print(f"{position}. {entry.get('title')}")
            print(f" URL: {entry.get('url')}")
            print(f" Published: {entry.get('published')}")
            print()
    except Exception as exc:
        print(f"Error searching arXiv: {str(exc)}")