Essi committed · Commit e879014 · 1 Parent(s): 5b89395
feat: enhance file handling capabilities with support for code execution, Excel analysis, and audio transcription
Files changed:
- app.py +38 -8
- requirements.txt +1 -0
- tools.py +64 -15
app.py
CHANGED
@@ -1,6 +1,6 @@
 import os
 import re
-from typing import Literal, TypedDict
+from typing import Literal, TypedDict, get_args
 
 import gradio as gr
 import pandas as pd
@@ -10,8 +10,11 @@ from langchain_openai import ChatOpenAI
 from langgraph.graph import END, StateGraph
 
 from tools import (
+    analyze_excel_file,
     calculator,
     image_describe,
+    run_py,
+    transcribe_via_whisper,
     web_multi_search,
     wiki_search,
     youtube_transcript,
@@ -33,7 +36,7 @@ If the answer is numeric, output digits only (no commas, units, or words).
 # QUESTION CLASSIFIER #
 # --------------------------------------------------------------------------- #
 
-_LABELS = ...
+_LABELS = Literal["math", "youtube", "image", "code", "excel", "audio", "general"]
 
 _CLASSIFY_PROMPT = """You are a router that labels the user question with exactly one of the following categories:
 {labels}.
@@ -50,7 +53,7 @@ Label:
 # --------------------------------------------------------------------------- #
 class AgentState(TypedDict):
     question: str
-    label: ...
+    label: _LABELS
     context: str
     answer: str
     confidence: float
@@ -68,9 +71,12 @@ _llm_answer = ChatOpenAI(model=MODEL_NAME)
 def classify(state: AgentState) -> AgentState:  # noqa: D401
     """Label the task so we know which toolchain to invoke."""
     question = state["question"]
+
+    values = get_args(_LABELS)  # -> ("math", "youtube", ...)
+    parsed_labels = ", ".join(repr(v) for v in values)
     resp = (
         _llm_router.invoke(
-            _CLASSIFY_PROMPT.format(question=question, labels=...)
+            _CLASSIFY_PROMPT.format(question=question, labels=parsed_labels)
         )
         .content.strip()
         .lower()
@@ -80,11 +86,36 @@ def classify(state: AgentState) -> AgentState:  # noqa: D401
 
 
 def gather_context(state: AgentState) -> AgentState:
-    question, label = state["question"], state["label"]
+    question, label, task_id = state["question"], state["label"], state["task_id"]
 
     matched_pattern = r"https?://\S+"
     matched_obj = re.search(matched_pattern, question)
 
+    # ---- attachment detection ------------------------------------------------
+    if task_id:
+        file_url = f"{DEFAULT_API_URL}/files/{task_id}"
+        head = requests.head(file_url, timeout=10)
+        ctype = head.headers.get("content-type", "")
+
+        print(f"[DEBUG] attachment type={ctype} | url={file_url}")
+        if "python" in ctype or file_url.endswith(".py"):
+            code = requests.get(file_url, timeout=10).text
+            state["answer"] = run_py.invoke({"code": code})
+            state["label"] = "code"
+            return state
+        if "excel" in ctype or file_url.endswith((".xlsx", ".csv")):
+            blob = requests.get(file_url, timeout=10).content
+            state["context"] = analyze_excel_file.invoke(
+                {"xls_bytes": blob, "question": question}
+            )
+            state["label"] = "excel"
+            return state
+        if "audio" in ctype or file_url.endswith(".mp3"):
+            blob = requests.get(file_url, timeout=10).content
+            state["context"] = transcribe_via_whisper.invoke({"mp3_bytes": blob})
+            state["label"] = "audio"
+            return state
+
     if label == "math":
         print("[TOOL] calculator")
         expr = re.sub(r"\s+", "", question)
@@ -109,9 +140,8 @@ def gather_context(state: AgentState) -> AgentState:
 
 
 def generate_answer(state: AgentState) -> AgentState:
-    # ...
-    if state["label"] ...
-        state["answer"] = state["context"].strip()
+    # Skip LLM for deterministic labels
+    if state["label"] in {"math", "code", "excel"}:
         state["confidence"] = 0.9
         return state
 
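In the new gather_context, attachment handling short-circuits the keyword router: a HEAD request on the task's file URL sniffs the content type, and the first matching branch (Python code, then Excel/CSV, then MP3 audio) claims the task before the label-based logic runs. A minimal, self-contained sketch of that dispatch order, using only the standard library; pick_label is a hypothetical helper for illustration and does not exist in app.py:

# Standalone sketch of the attachment-dispatch order in gather_context.
def pick_label(ctype: str, url: str) -> str | None:
    if "python" in ctype or url.endswith(".py"):
        return "code"
    if "excel" in ctype or url.endswith((".xlsx", ".csv")):
        return "excel"
    if "audio" in ctype or url.endswith(".mp3"):
        return "audio"
    return None  # no attachment match: fall through to the keyword router

assert pick_label("text/x-python", "files/42") == "code"
assert pick_label("application/octet-stream", "files/42.xlsx") == "excel"
assert pick_label("audio/mpeg", "files/42") == "audio"
assert pick_label("text/plain", "files/42") is None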
requirements.txt
CHANGED
@@ -20,6 +20,7 @@ wikipedia==1.4.0 # WikipediaLoader
 youtube-transcript-api==1.0.3  # YouTube transcripts
 openpyxl==3.1.5                # Excel parsing when GAIA attaches .xlsx
 Pillow>=10.2.0                 # image handling for transformers
+openai-whisper==20240930
 
 # ── Lightweight vision model
 transformers>=4.41.2
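Note that openai-whisper is versioned by release date rather than semver, and it shells out to a system ffmpeg binary to decode MP3 input, so the pin alone is not enough on a bare image. A quick post-install sanity check, assuming ffmpeg is on PATH and allowing for the first call to download the model weights (roughly 140 MB for "base"):

# Import the package and load the same model tools.py uses.
import whisper

model = whisper.load_model("base")  # downloads weights on first run
print(type(model).__name__)  # expected: Whisper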
tools.py
CHANGED
@@ -2,8 +2,10 @@ import ast
 import json
 import operator
 import re
+import subprocess
 from functools import lru_cache
 from io import BytesIO
+from tempfile import NamedTemporaryFile
 
 import requests
 from langchain_community.document_loaders import WikipediaLoader
@@ -46,7 +48,7 @@ def calculator(expression: str) -> str:
         tree = ast.parse(expression, mode="eval")
         value = _safe_eval(tree.body)
         return str(value)
-    except Exception as exc:
+    except Exception as exc:
         return f"calc_error:{exc}"
 
 
@@ -62,9 +64,9 @@ def _ddg_search(query: str, k: int = 6) -> list[dict[str, str]]:
     hits = wrapper.results(query)
     return [
         {
-            "title": hit.get("title", "")[:...],
-            "snippet": hit.get("snippet", "")[:...],
-            "link": hit.get("link", "")[:...],
+            "title": hit.get("title", "")[:500],
+            "snippet": hit.get("snippet", "")[:750],
+            "link": hit.get("link", "")[:300],
         }
         for hit in hits[:k]
     ]
@@ -87,9 +89,9 @@ def web_multi_search(query: str, k: int = 6) -> str:
     )
     formatted = [
         {
-            "title": d.metadata.get("title", "")[:...],
-            "snippet": d.page_content[:...],
-            "link": d.metadata.get("source", "")[:...],
+            "title": d.metadata.get("title", "")[:500],
+            "snippet": d.page_content[:750],
+            "link": d.metadata.get("source", "")[:300],
         }
         for d in tavily_hits
     ]
@@ -156,16 +158,61 @@ def image_describe(image_url: str, top_k: int = 3) -> str:
 
 
 @tool
-def ...
-    """..."""
+def run_py(code: str) -> str:
+    """Execute Python code in a sandboxed subprocess and return last stdout line."""
+    try:
+        with NamedTemporaryFile(delete=False, suffix=".py", mode="w") as f:
+            f.write(code)
+            path = f.name
+        proc = subprocess.run(
+            ["python", path], capture_output=True, text=True, timeout=4
+        )
+        out = proc.stdout.strip().splitlines()
+        return out[-1] if out else ""
+    except Exception as exc:
+        return f"py_error:{exc}"
+
+
+@tool
+def transcribe_via_whisper(mp3_bytes: bytes) -> str:
+    """Transcribe MP3 bytes with Whisper (CPU)."""
+    with NamedTemporaryFile(suffix=".mp3", delete=False) as f:
+        f.write(mp3_bytes)
+        path = f.name
     try:
-        import ...
+        import whisper  # openai-whisper
 
-        ...
-        ...
-        ...
+        model = whisper.load_model("base")
+        output = model.transcribe(path)["text"].strip()
+        print(f"[DEBUG] Whisper transcript (first 200 chars): {output[:200]}")
+        return output
     except Exception as exc:
-        return f"..."
+        return f"asr_error:{exc}"
+
+
+@tool
+def analyze_excel_file(xls_bytes: bytes, question: str) -> str:
+    """Generic Excel/CSV aggregation handler."""
+    import pandas as pd
+
+    # Try both Excel and CSV loaders
+    try:
+        df = pd.read_excel(BytesIO(xls_bytes))
+    except Exception:
+        df = pd.read_csv(BytesIO(xls_bytes))
+
+    numeric = df.select_dtypes("number")
+    if numeric.empty:
+        return "No numeric data"
+
+    q = question.lower()
+    if any(term in q for term in ["total", "sum", "aggregate"]):
+        return f"{numeric.sum().sum():.2f}"
+    if any(term in q for term in ["average", "mean"]):
+        return f"{numeric.mean().mean():.2f}"
+
+    # Fallback: return first 10 rows as csv for LLM to reason on
+    return df.head(10).to_csv(index=False)
 
 
 __all__ = [
@@ -174,5 +221,7 @@ __all__ = [
     "wiki_search",
     "youtube_transcript",
     "image_describe",
-    "...",
+    "run_py",
+    "transcribe_via_whisper",
+    "analyze_excel_file",
 ]
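A minimal smoke test for the new tools, assuming tools.py and its dependencies (including openpyxl for the Excel round-trip) are installed; since these are LangChain @tool objects, they are invoked with a dict of arguments, exactly as app.py calls them:

from io import BytesIO

import pandas as pd

from tools import analyze_excel_file, run_py

# run_py executes the snippet in a subprocess and returns its last stdout line.
print(run_py.invoke({"code": "print(2 + 3)"}))  # -> 5

# analyze_excel_file answers sum/mean questions over the numeric columns.
buf = BytesIO()
pd.DataFrame({"sales": [10.0, 15.5]}).to_excel(buf, index=False)
print(analyze_excel_file.invoke(
    {"xls_bytes": buf.getvalue(), "question": "What is the total sales?"}
))  # -> 25.50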