# GAIA-Agent / tools.py
import ast
import json
import operator
import re
import subprocess
from base64 import b64encode
from functools import lru_cache
from io import BytesIO
from tempfile import NamedTemporaryFile
import numpy as np
import pandas as pd
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from youtube_transcript_api import YouTubeTranscriptApi
from helpers import get_prompt, print_debug_trace
# --------------------------------------------------------------------------- #
# ARITHMETIC (SAFE CALCULATOR) #
# --------------------------------------------------------------------------- #
_ALLOWED_AST_OPS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.Pow: operator.pow,
ast.USub: operator.neg,
}
def _safe_eval(node: ast.AST) -> float | int | complex:
"""Recursively evaluate a *restricted* AST expression tree."""
    if isinstance(node, ast.Constant) and isinstance(
        node.value, (int, float, complex)
    ):
        return node.value  # numeric literal
if isinstance(node, ast.UnaryOp) and type(node.op) in _ALLOWED_AST_OPS:
return _ALLOWED_AST_OPS[type(node.op)](_safe_eval(node.operand))
if isinstance(node, ast.BinOp) and type(node.op) in _ALLOWED_AST_OPS:
return _ALLOWED_AST_OPS[type(node.op)](
_safe_eval(node.left), _safe_eval(node.right)
)
raise ValueError("Unsafe or unsupported expression")
@tool
def calculator(expression: str) -> str:
"""Safely evaluate basic arithmetic expressions (no variables, functions)."""
try:
tree = ast.parse(expression, mode="eval")
value = _safe_eval(tree.body)
return str(value)
except Exception as exc:
print_debug_trace(exc, "Calculator")
return f"calc_error:{exc}"
# --------------------------------------------------------------------------- #
# WEB & WIKI SEARCH #
# --------------------------------------------------------------------------- #
@lru_cache(maxsize=256)
def _ddg_search(query: str, k: int = 6) -> list[dict[str, str]]:
"""Cached DuckDuckGo JSON search."""
wrapper = DuckDuckGoSearchAPIWrapper(max_results=k)
hits = wrapper.results(query)
return [
{
"title": hit.get("title", "")[:500],
"snippet": hit.get("snippet", "")[:750],
"link": hit.get("link", "")[:300],
}
for hit in hits[:k]
]
@tool
def web_multi_search(query: str, k: int = 6) -> str:
"""Run DuckDuckGo → Tavily fallback search. Returns JSON list[dict]."""
try:
hits = _ddg_search(query, k)
if hits:
return json.dumps(hits, ensure_ascii=False)
except Exception: # fall through to Tavily
pass
try:
tavily_results = TavilySearchResults(
max_results=5,
# include_answer=True,
# search_depth="advanced",
)
search_result = tavily_results.invoke({"query": query})
print(
f"[TOOL] TAVILY search is triggered with following response: {search_result}"
)
formatted = [
{
"title": d.get("title", "")[:500],
"snippet": d.get("content", "")[:750],
"link": d.get("url", "")[:300],
}
for d in search_result
]
return json.dumps(formatted, ensure_ascii=False)
except Exception as exc:
print_debug_trace(exc, "Multi Search")
return f"search_error:{exc}"
@tool
def wiki_search(query: str, max_pages: int = 2) -> str:
"""Lightweight wrapper on WikipediaLoader; returns concatenated page texts."""
print(f"[TOOL] wiki_search called with query: {query}")
docs = WikipediaLoader(query=query, load_max_docs=max_pages).load()
joined = "\n\n---\n\n".join(d.page_content for d in docs)
return joined[:8_000] # simple guardrail – stay within context window
# --------------------------------------------------------------------------- #
# YOUTUBE TRANSCRIPT #
# --------------------------------------------------------------------------- #
@tool
def youtube_transcript(url: str, chars: int = 10_000) -> str:
"""Fetch full YouTube transcript (first *chars* characters)."""
video_id_match = re.search(r"[?&]v=([A-Za-z0-9_\-]{11})", url)
if not video_id_match:
return "yt_error:id_not_found"
try:
transcript = YouTubeTranscriptApi.get_transcript(video_id_match.group(1))
text = " ".join(piece["text"] for piece in transcript)
return text[:chars]
except Exception as exc:
print_debug_trace(exc, "YouTube")
return f"yt_error:{exc}"
# --------------------------------------------------------------------------- #
# IMAGE DESCRIPTION #
# --------------------------------------------------------------------------- #
# Optional: a lightweight CLIP-based image classifier (runs on CPU), kept
# commented out below in favour of the multimodal vision_task tool that follows.
### The model 'openai/clip-vit-base-patch32' is a vision transformer (ViT) trained
### as part of OpenAI's CLIP project; it maps images and labels into the same
### embedding space, which enables zero-shot image classification.
# _image_pipe = pipeline(
# "image-classification", model="openai/clip-vit-base-patch32", device="cpu"
# )
# @tool
# def image_describe(img_bytes: bytes, top_k: int = 3) -> str:
# """Return the top-k CLIP labels for an image supplied as raw bytes.
# typical result for a random cat photo can be:
# [
# {'label': 'tabby, tabby cat', 'score': 0.41},
# {'label': 'tiger cat', 'score': 0.24},
# {'label': 'Egyptian cat', 'score': 0.22}
# ]
# """
# try:
# labels = _image_pipe(BytesIO(img_bytes))[:top_k]
# return ", ".join(f"{d['label']} (score={d['score']:.2f})" for d in labels)
# except Exception as exc:
# return f"img_error:{exc}"
@tool
def vision_task(img_bytes: bytes, question: str) -> str:
"""
Pass the user's question AND the referenced image to a multimodal LLM and
return its first line of text as the answer. No domain assumptions made.
"""
vision_llm = ChatOpenAI(
model="gpt-4o-mini", # set OPENAI_API_KEY in env
temperature=0,
max_tokens=64,
)
try:
b64 = b64encode(img_bytes).decode()
messages = [
SystemMessage(content=get_prompt(prompt_key="vision_system")),
HumanMessage(
content=[
{"type": "text", "text": question.strip()},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{b64}"},
},
]
),
]
reply = vision_llm.invoke(messages).content.strip()
return reply
except Exception as exc:
print_debug_trace(exc, "vision")
return f"img_error:{exc}"
# --------------------------------------------------------------------------- #
# FILE UTILS #
# --------------------------------------------------------------------------- #
@tool
def run_py(code: str) -> str:
"""Execute Python code in a sandboxed subprocess and return last stdout line."""
try:
with NamedTemporaryFile(delete=False, suffix=".py", mode="w") as f:
f.write(code)
path = f.name
proc = subprocess.run(
["python", path], capture_output=True, text=True, timeout=45
)
out = proc.stdout.strip().splitlines()
return out[-1] if out else ""
except Exception as exc:
print_debug_trace(exc, "run_py")
return f"py_error:{exc}"
@tool
def transcribe_via_whisper(audio_bytes: bytes) -> str:
"""Transcribe audio with Whisper (CPU)."""
with NamedTemporaryFile(suffix=".mp3", delete=False) as f:
f.write(audio_bytes)
path = f.name
try:
import whisper # openai-whisper
model = whisper.load_model("base")
output = model.transcribe(path)["text"].strip()
print(f"[DEBUG] Whisper transcript (first 200 chars): {output[:200]}")
return output
except Exception as exc:
print_debug_trace(exc, "Whisper")
return f"asr_error:{exc}"
@tool
def analyze_excel_file(xls_bytes: bytes, question: str) -> str:
"Analyze Excel or CSV file by passing the data preview to LLM and getting the Python Pandas operation to run"
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=64)
try:
df = pd.read_excel(BytesIO(xls_bytes))
except Exception:
df = pd.read_csv(BytesIO(xls_bytes))
for col in df.select_dtypes(include="number").columns:
df[col] = df[col].astype(float)
# Ask the LLM for a single expression
prompt = get_prompt(
prompt_key="excel_system",
question=question,
preview=df.head(5).to_dict(orient="list"),
)
expr = llm.invoke(prompt).content.strip()
    # Evaluate the LLM-generated one-line pandas expression (builtins disabled)
try:
result = eval(expr, {"df": df, "pd": pd, "__builtins__": {}})
# Normalize scalars to string
if isinstance(result, np.generic):
result = float(result) # → plain Python float
return f"{result:.2f}" # or str(result) if no decimals needed
# DataFrame / Series → single-line string
return (
result.to_string(index=False)
if hasattr(result, "to_string")
else str(result)
)
except Exception as e:
print_debug_trace(e, "Excel")
return f"eval_error:{e}"
__all__ = [
"calculator",
"web_multi_search",
"wiki_search",
"youtube_transcript",
"vision_task",
"run_py",
"transcribe_via_whisper",
"analyze_excel_file",
]