|
""" |
|
arXiv search tool for the GAIA agent. |
|
|
|
This module provides a tool for searching academic papers on arXiv.org. |
|
It uses the arxiv Python package to access the arXiv library and retrieve |
|
paper information. |
|
|
|
The tool handles API responses and errors, and formats results in a |
|
consistent way for use by the GAIA agent. |
|
""" |
|
|
|
import logging |
|
import time |
|
import traceback |
|
from typing import Dict, Any, List, Optional, Union, Tuple |
|
from datetime import datetime |
|
|
|
try: |
|
import arxiv |
|
except ImportError: |
|
arxiv = None |
|
|
|
from src.gaia.agent.config import get_tool_config |
|
|
|
# Module-level logger; all records from this file are emitted under the
# "gaia_agent.tools.arxiv" logger name.
logger = logging.getLogger("gaia_agent.tools.arxiv")
|
|
|
class ArxivSearchTool:
    """Tool for searching academic papers on arXiv.

    Runs queries and single-paper lookups through the ``arxiv`` package and
    converts every result into a plain dictionary (see ``_format_paper``).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the arXiv search tool.

        Args:
            config: Optional configuration dictionary. When it is omitted
                (or empty), the "arxiv" section of ``get_tool_config()`` is
                used instead. The "max_results" key sets the default number
                of results per search (3 when absent).
        """
        self.config = config or get_tool_config().get("arxiv", {})
        self.max_results = self.config.get("max_results", 3)

        # Only warn here; the hard failure is deferred to the first call that
        # actually needs the package.
        if arxiv is None:
            logger.warning("arXiv package not installed. Install with: pip install arxiv")

    @staticmethod
    def _ensure_arxiv_available() -> None:
        """Raise ImportError when the optional ``arxiv`` package is missing."""
        if arxiv is None:
            raise ImportError("arXiv package not installed. Install with: pip install arxiv")

    @staticmethod
    def _format_paper(paper: Any) -> Dict[str, Any]:
        """Convert one arXiv result object into the tool's result dictionary.

        Args:
            paper: A result object exposing ``title``, ``authors`` (each with
                a ``name``), ``summary``, ``published``, ``entry_id``,
                ``pdf_url``, ``categories``, ``comment``, ``journal_ref`` and
                ``doi`` attributes.

        Returns:
            Dictionary with the keys "title", "authors" (comma-separated
            names), "summary", "published" ("YYYY-MM-DD" or "Unknown"), "url",
            "pdf_url", "categories", "comment", "journal_ref" and "doi".
        """
        published = paper.published
        published_str = published.strftime("%Y-%m-%d") if published else "Unknown"

        return {
            "title": paper.title,
            "authors": ", ".join(author.name for author in paper.authors),
            "summary": paper.summary,
            "published": published_str,
            "url": paper.entry_id,
            "pdf_url": paper.pdf_url,
            "categories": paper.categories,
            "comment": paper.comment,
            "journal_ref": paper.journal_ref,
            "doi": paper.doi,
        }

    def search(self, query: str, max_results: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Search arXiv for papers matching the query, ordered by relevance.

        Args:
            query: The search query
            max_results: Maximum number of results to return (overrides
                config). ``None`` selects the configured default; an explicit
                value, including 0, is passed through unchanged.

        Returns:
            List of paper information dictionaries

        Raises:
            ImportError: If the arxiv package is not installed
            Exception: If an error occurs during the search; the original
                error is attached as ``__cause__``
        """
        self._ensure_arxiv_available()

        # Compare against None rather than truthiness so that an explicit 0
        # is not silently replaced by the configured default.
        if max_results is None:
            max_results = self.max_results

        try:
            client = arxiv.Client()
            search = arxiv.Search(
                query=query,
                max_results=max_results,
                sort_by=arxiv.SortCriterion.Relevance,
            )
            return [self._format_paper(paper) for paper in client.results(search)]
        except Exception as e:
            # logger.exception records the message together with the traceback.
            logger.exception("Error searching arXiv: %s", e)
            raise Exception(f"arXiv search failed: {str(e)}") from e

    def get_paper_by_id(self, paper_id: str) -> Dict[str, Any]:
        """
        Get a specific paper by its arXiv ID.

        Args:
            paper_id: The arXiv ID (e.g., "2103.00020")

        Returns:
            Dictionary containing paper information

        Raises:
            ImportError: If the arxiv package is not installed
            Exception: If an error occurs during the retrieval, including
                the case where no paper matches the ID; the original error
                is attached as ``__cause__``
        """
        self._ensure_arxiv_available()

        try:
            client = arxiv.Client()
            search = arxiv.Search(id_list=[paper_id])
            results = list(client.results(search))

            # Raised inside the try block on purpose: callers see a missing
            # paper as the same wrapped Exception as any other failure.
            if not results:
                raise ValueError(f"No paper found with ID: {paper_id}")

            return self._format_paper(results[0])
        except Exception as e:
            logger.exception("Error getting arXiv paper %s: %s", paper_id, e)
            raise Exception(f"arXiv paper retrieval failed: {str(e)}") from e

    def search_by_category(self, category: str, max_results: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Search for papers in a specific category.

        Delegates to ``search`` with the query ``cat:<category>``, so results
        are ordered by relevance rather than by submission date.

        Args:
            category: The arXiv category (e.g., "cs.AI")
            max_results: Maximum number of results to return (overrides config)

        Returns:
            List of paper information dictionaries

        Raises:
            ImportError: If the arxiv package is not installed
            Exception: If an error occurs during the search
        """
        return self.search(f"cat:{category}", max_results)
|
|
|
def create_arxiv_search() -> ArxivSearchTool:
    """Build and return an ArxivSearchTool constructed without an explicit config."""
    tool = ArxivSearchTool()
    return tool
|
|
|
if __name__ == "__main__":
    # Manual smoke test: run one sample query and print a short listing.
    logging.basicConfig(level=logging.INFO)

    arxiv_tool = create_arxiv_search()

    try:
        results = arxiv_tool.search("large language models")
        # Number the hits from 1 for display.
        for position, entry in enumerate(results, start=1):
            title = entry.get('title')
            link = entry.get('url')
            date = entry.get('published')
            print(f"{position}. {title}")
            print(f" URL: {link}")
            print(f" Published: {date}")
            print()
    except Exception as e:
        print(f"Error searching arXiv: {str(e)}")
|
|