# GAIA-Agent / helpers.py
# Author: Essi
# Last change: refactor: rename fetch_task_file to fetch_task_attachment for
# clarity in attachment handling (commit 2d7820e)
import csv
from io import BytesIO
from pathlib import Path
from sys import stderr
from traceback import print_exception
from zipfile import BadZipFile, ZipFile
import requests
from yaml import safe_load
# Directory containing this module; prompts.yaml is resolved relative to it
# rather than to the process working directory.
CURRENT_DIR = Path(__file__).parent

# Prompt templates parsed once at import time and looked up by key in
# get_prompt(). The file is decoded explicitly as UTF-8 so that loading does
# not depend on the platform's locale encoding.
_PROMPTS = safe_load(
    CURRENT_DIR.joinpath("prompts.yaml").read_text(encoding="utf-8")
)
def fetch_task_attachment(api_url: str, task_id: str) -> tuple[bytes, str]:
    """
    Download the attachment of a task from ``{api_url}/files/{task_id}``.

    Redirects issued by the endpoint are followed, and the request is
    abandoned after a 15 second timeout.

    Args:
        api_url: Base URL of the API; it is joined as-is with ``/files/``.
        task_id: Identifier of the task whose attachment is requested.

    Returns:
        ``(file_bytes, content_type)`` with the content type lower-cased, or
        ``(b"", "")`` when the request fails or the status code is not 200.
    """
    url = f"{api_url}/files/{task_id}"
    try:
        r = requests.get(url, timeout=15, allow_redirects=True)
    except requests.RequestException as e:
        # Best-effort: a transport failure is reported as "no attachment".
        print(f"[DEBUG] GET {url} failed → {e}")
        return b"", ""
    if r.status_code != 200:
        # Separate URL and status so the log line stays readable.
        print(f"[DEBUG] GET {url} → {r.status_code}")
        return b"", ""
    return r.content, r.headers.get("content-type", "").lower()
def sniff_excel_type(blob: bytes) -> str:
    """
    Guess the spreadsheet format of raw bytes by inspecting their content.

    Args:
        blob: Raw file content; may be empty.

    Returns:
        'xlsx' for a ZIP container that holds ``xl/workbook.xml`` or
        ``[Content_Types].xml``, 'xls' for the OLE compound-file signature,
        'csv' for delimited text accepted by ``csv.Sniffer``, or '' when the
        format is unknown (including empty input).
    """
    # 1) ZIP container (local file header signature).
    if blob[:4] == b"PK\x03\x04":
        try:
            with ZipFile(BytesIO(blob)) as zf:
                names = set(zf.namelist())
        except BadZipFile:
            names = set()  # not a readable archive: fall through
        # NOTE(review): matching either member is lenient; presumably other
        # Office ZIP formats also carry [Content_Types].xml -- confirm whether
        # xl/workbook.xml alone should be required.
        if {"xl/workbook.xml", "[Content_Types].xml"} & names:
            return "xlsx"

    # 2) Legacy XLS (OLE Compound File signature).
    if blob[:8] == b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1":
        return "xls"

    # 3) Text-like -> CSV/TSV. Undecodable bytes are dropped, so decoding
    #    itself cannot fail here.
    sample = blob[:1024].decode("utf-8", "ignore")
    lines = sample.splitlines()
    if not lines:
        # Empty input (or nothing decodable): there is no first line to check.
        return ""
    if any(sep in lines[0] for sep in (",", ";", "\t")):
        # Confirm via csv.Sniffer to avoid classifying random text as CSV.
        try:
            csv.Sniffer().sniff(sample)
        except csv.Error:
            return ""
        return "csv"
    return ""
def get_prompt(prompt_key: str, **kwargs: str) -> str:
    """
    Look up the template stored under `prompt_key` in the module-level
    `_PROMPTS` and return it with placeholders filled by `.format(**kwargs)`.
    """
    template = _PROMPTS[prompt_key]
    filled = template.format(**kwargs)
    return filled
def print_debug_trace(err: Exception, label: str = "") -> None:
    """
    Write a banner, the full traceback of `err`, and a separator line to
    STDERR. When `label` is non-empty it is included in the banner.
    """
    if label:
        header = f"[TRACE {label}]"
    else:
        header = "[TRACE]"
    separator = "-" * 60

    print(header, file=stderr)
    print_exception(type(err), err, err.__traceback__, file=stderr)
    print(separator, file=stderr)