File size: 6,630 Bytes
c922f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
"""
arXiv search tool for the GAIA agent.
This module provides a tool for searching academic papers on arXiv.org.
It uses the arxiv Python package to access the arXiv library and retrieve
paper information.
The tool handles API responses and errors, and formats results in a
consistent way for use by the GAIA agent.
"""
import logging
import time
import traceback
from typing import Dict, Any, List, Optional, Union, Tuple
from datetime import datetime
try:
import arxiv
except ImportError:
arxiv = None
from src.gaia.agent.config import get_tool_config
logger = logging.getLogger("gaia_agent.tools.arxiv")
class ArxivSearchTool:
    """Tool for searching academic papers on arXiv.

    Wraps the optional ``arxiv`` package: queries are executed through
    ``arxiv.Client`` and every returned paper is flattened into a plain
    dictionary (see ``_format_paper``) so callers never touch library
    objects directly.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the arXiv search tool.

        Args:
            config: Optional configuration dictionary. When it is omitted
                (or falsy, e.g. an empty dict) the "arxiv" section of
                ``get_tool_config()`` is used instead. The only key read
                here is "max_results" (default 3).
        """
        self.config = config or get_tool_config().get("arxiv", {})
        self.max_results = self.config.get("max_results", 3)
        # Only warn at construction time; the hard failure is deferred to
        # the first call that actually needs the package.
        if arxiv is None:
            logger.warning("arXiv package not installed. Install with: pip install arxiv")

    @staticmethod
    def _require_arxiv() -> None:
        """Raise ImportError when the optional ``arxiv`` package is missing."""
        if arxiv is None:
            raise ImportError("arXiv package not installed. Install with: pip install arxiv")

    @staticmethod
    def _format_paper(paper: Any) -> Dict[str, Any]:
        """
        Convert one result object into the tool's dictionary format.

        Args:
            paper: An object exposing ``title``, ``authors`` (an iterable of
                objects with a ``name`` attribute), ``summary``,
                ``published``, ``entry_id``, ``pdf_url``, ``categories``,
                ``comment``, ``journal_ref`` and ``doi`` -- presumably an
                ``arxiv.Result``.

        Returns:
            Dictionary with the keys "title", "authors" (comma-separated
            names), "summary", "published" ("YYYY-MM-DD", or "Unknown" when
            the paper has no publication date), "url", "pdf_url",
            "categories", "comment", "journal_ref" and "doi".
        """
        published = paper.published
        return {
            "title": paper.title,
            "authors": ", ".join(author.name for author in paper.authors),
            "summary": paper.summary,
            "published": published.strftime("%Y-%m-%d") if published else "Unknown",
            "url": paper.entry_id,
            "pdf_url": paper.pdf_url,
            "categories": paper.categories,
            "comment": paper.comment,
            "journal_ref": paper.journal_ref,
            "doi": paper.doi,
        }

    def _run_search(
        self,
        query: str,
        max_results: Optional[int],
        *,
        newest_first: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Execute a query and format every hit.

        Args:
            query: The arXiv query string.
            max_results: Result cap; a falsy value (None or 0) falls back
                to the configured ``self.max_results``.
            newest_first: Sort by submission date instead of relevance.

        Returns:
            List of paper information dictionaries.

        Raises:
            ImportError: If the ``arxiv`` package is not installed.
            Exception: Any other failure, wrapped with an
                "arXiv search failed: ..." message and chained to its cause.
        """
        self._require_arxiv()
        limit = max_results or self.max_results
        try:
            criterion = (
                arxiv.SortCriterion.SubmittedDate
                if newest_first
                else arxiv.SortCriterion.Relevance
            )
            client = arxiv.Client()
            search = arxiv.Search(query=query, max_results=limit, sort_by=criterion)
            return [self._format_paper(paper) for paper in client.results(search)]
        except Exception as e:
            # logger.exception records the message and the traceback together.
            logger.exception("Error searching arXiv: %s", e)
            raise Exception(f"arXiv search failed: {str(e)}") from e

    def search(self, query: str, max_results: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Search arXiv for papers matching the query, sorted by relevance.

        Args:
            query: The search query
            max_results: Maximum number of results to return (overrides
                config); a falsy value falls back to the configured default

        Returns:
            List of paper information dictionaries

        Raises:
            ImportError: If the ``arxiv`` package is not installed
            Exception: If an error occurs during the search
        """
        return self._run_search(query, max_results)

    def get_paper_by_id(self, paper_id: str) -> Dict[str, Any]:
        """
        Get a specific paper by its arXiv ID.

        Args:
            paper_id: The arXiv ID (e.g., "2103.00020")

        Returns:
            Dictionary containing paper information

        Raises:
            ImportError: If the ``arxiv`` package is not installed
            Exception: If an error occurs during the retrieval. This
                includes the "No paper found" case: the ValueError raised
                for it is caught below and re-raised as a generic
                Exception, chained to the original error.
        """
        self._require_arxiv()
        try:
            client = arxiv.Client()
            search = arxiv.Search(id_list=[paper_id])
            results = list(client.results(search))
            if not results:
                raise ValueError(f"No paper found with ID: {paper_id}")
            return self._format_paper(results[0])
        except Exception as e:
            logger.exception("Error getting arXiv paper %s: %s", paper_id, e)
            raise Exception(f"arXiv paper retrieval failed: {str(e)}") from e

    def search_by_category(self, category: str, max_results: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Search for recent papers in a specific category.

        The query is ``cat:<category>`` and results are sorted by submission
        date rather than relevance, so the listing reflects recency
        (ordering direction is the ``arxiv`` package's default for
        ``arxiv.Search``).

        Args:
            category: The arXiv category (e.g., "cs.AI")
            max_results: Maximum number of results to return (overrides
                config); a falsy value falls back to the configured default

        Returns:
            List of paper information dictionaries

        Raises:
            ImportError: If the ``arxiv`` package is not installed
            Exception: If an error occurs during the search
        """
        return self._run_search(f"cat:{category}", max_results, newest_first=True)
def create_arxiv_search() -> ArxivSearchTool:
    """
    Build a ready-to-use arXiv search tool.

    Returns:
        A new ArxivSearchTool constructed without an explicit config, so
        the tool resolves its own settings.
    """
    tool = ArxivSearchTool()
    return tool
if __name__ == "__main__":
    # Manual smoke test: run one query and print a numbered listing.
    logging.basicConfig(level=logging.INFO)
    arxiv_tool = create_arxiv_search()
    try:
        results = arxiv_tool.search("large language models")
        # Number the hits from 1 for display.
        for position, entry in enumerate(results, start=1):
            print(f"{position}. {entry.get('title')}")
            print(f" URL: {entry.get('url')}")
            print(f" Published: {entry.get('published')}")
            print()
    except Exception as e:
        # Report the failure on stdout instead of letting the script crash.
        print(f"Error searching arXiv: {e}")
|