Essi commited on
Commit
e879014
·
1 Parent(s): 5b89395

feat: enhance file handling capabilities with support for code execution, Excel analysis, and audio transcription

Browse files
Files changed (3) hide show
  1. app.py +38 -8
  2. requirements.txt +1 -0
  3. tools.py +64 -15
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import re
3
- from typing import Literal, TypedDict
4
 
5
  import gradio as gr
6
  import pandas as pd
@@ -10,8 +10,11 @@ from langchain_openai import ChatOpenAI
10
  from langgraph.graph import END, StateGraph
11
 
12
  from tools import (
 
13
  calculator,
14
  image_describe,
 
 
15
  web_multi_search,
16
  wiki_search,
17
  youtube_transcript,
@@ -33,7 +36,7 @@ If the answer is numeric, output digits only (no commas, units, or words).
33
  # QUESTION CLASSIFIER #
34
  # --------------------------------------------------------------------------- #
35
 
36
- _LABELS = {"math", "youtube", "image", "general"}
37
 
38
  _CLASSIFY_PROMPT = """You are a router that labels the user question with exactly one of the following categories:
39
  {labels}.
@@ -50,7 +53,7 @@ Label:
50
  # --------------------------------------------------------------------------- #
51
  class AgentState(TypedDict):
52
  question: str
53
- label: Literal["math", "youtube", "image", "general"]
54
  context: str
55
  answer: str
56
  confidence: float
@@ -68,9 +71,12 @@ _llm_answer = ChatOpenAI(model=MODEL_NAME)
68
  def classify(state: AgentState) -> AgentState: # noqa: D401
69
  """Label the task so we know which toolchain to invoke."""
70
  question = state["question"]
 
 
 
71
  resp = (
72
  _llm_router.invoke(
73
- _CLASSIFY_PROMPT.format(question=question, labels=", ".join(_LABELS))
74
  )
75
  .content.strip()
76
  .lower()
@@ -80,11 +86,36 @@ def classify(state: AgentState) -> AgentState: # noqa: D401
80
 
81
 
82
  def gather_context(state: AgentState) -> AgentState:
83
- question, label = state["question"], state["label"]
84
 
85
  matched_pattern = r"https?://\S+"
86
  matched_obj = re.search(matched_pattern, question)
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  if label == "math":
89
  print("[TOOL] calculator")
90
  expr = re.sub(r"\s+", "", question)
@@ -109,9 +140,8 @@ def gather_context(state: AgentState) -> AgentState:
109
 
110
 
111
  def generate_answer(state: AgentState) -> AgentState:
112
- # Deterministic calculator path
113
- if state["label"] == "math":
114
- state["answer"] = state["context"].strip()
115
  state["confidence"] = 0.9
116
  return state
117
 
 
1
  import os
2
  import re
3
+ from typing import Literal, TypedDict, get_args
4
 
5
  import gradio as gr
6
  import pandas as pd
 
10
  from langgraph.graph import END, StateGraph
11
 
12
  from tools import (
13
+ analyze_excel_file,
14
  calculator,
15
  image_describe,
16
+ run_py,
17
+ transcribe_via_whisper,
18
  web_multi_search,
19
  wiki_search,
20
  youtube_transcript,
 
36
  # QUESTION CLASSIFIER #
37
  # --------------------------------------------------------------------------- #
38
 
39
+ _LABELS = Literal["math", "youtube", "image", "code", "excel", "audio", "general"]
40
 
41
  _CLASSIFY_PROMPT = """You are a router that labels the user question with exactly one of the following categories:
42
  {labels}.
 
53
  # --------------------------------------------------------------------------- #
54
  class AgentState(TypedDict):
55
  question: str
56
+ label: _LABELS
57
  context: str
58
  answer: str
59
  confidence: float
 
71
  def classify(state: AgentState) -> AgentState: # noqa: D401
72
  """Label the task so we know which toolchain to invoke."""
73
  question = state["question"]
74
+
75
+ values = get_args(_LABELS) # -> ("math", "youtube", ...)
76
+ parsed_labels = ", ".join(repr(v) for v in values)
77
  resp = (
78
  _llm_router.invoke(
79
+ _CLASSIFY_PROMPT.format(question=question, labels=parsed_labels)
80
  )
81
  .content.strip()
82
  .lower()
 
86
 
87
 
88
  def gather_context(state: AgentState) -> AgentState:
89
+ question, label, task_id = state["question"], state["label"], state["task_id"]
90
 
91
  matched_pattern = r"https?://\S+"
92
  matched_obj = re.search(matched_pattern, question)
93
 
94
+ # ---- attachment detection ------------------------------------------------
95
+ if task_id:
96
+ file_url = f"{DEFAULT_API_URL}/files/{task_id}"
97
+ head = requests.head(file_url, timeout=10)
98
+ ctype = head.headers.get("content-type", "")
99
+
100
+ print(f"[DEBUG] attachment type={ctype} | url={file_url}")
101
+ if "python" in ctype or file_url.endswith(".py"):
102
+ code = requests.get(file_url, timeout=10).text
103
+ state["answer"] = run_py.invoke({"code": code})
104
+ state["label"] = "code"
105
+ return state
106
+ if "excel" in ctype or file_url.endswith((".xlsx", ".csv")):
107
+ blob = requests.get(file_url, timeout=10).content
108
+ state["context"] = analyze_excel_file.invoke(
109
+ {"xls_bytes": blob, "question": question}
110
+ )
111
+ state["label"] = "excel"
112
+ return state
113
+ if "audio" in ctype or file_url.endswith(".mp3"):
114
+ blob = requests.get(file_url, timeout=10).content
115
+ state["context"] = transcribe_via_whisper.invoke({"mp3_bytes": blob})
116
+ state["label"] = "audio"
117
+ return state
118
+
119
  if label == "math":
120
  print("[TOOL] calculator")
121
  expr = re.sub(r"\s+", "", question)
 
140
 
141
 
142
  def generate_answer(state: AgentState) -> AgentState:
143
+ # Skip LLM for deterministic labels
144
+ if state["label"] in {"math", "code", "excel"}:
 
145
  state["confidence"] = 0.9
146
  return state
147
 
requirements.txt CHANGED
@@ -20,6 +20,7 @@ wikipedia==1.4.0 # WikipediaLoader
20
  youtube-transcript-api==1.0.3 # YouTube transcripts
21
  openpyxl==3.1.5 # Excel parsing when GAIA attaches .xlsx
22
  Pillow>=10.2.0 # image handling for transformers
 
23
 
24
  # ── Lightweight vision model
25
  transformers>=4.41.2
 
20
  youtube-transcript-api==1.0.3 # YouTube transcripts
21
  openpyxl==3.1.5 # Excel parsing when GAIA attaches .xlsx
22
  Pillow>=10.2.0 # image handling for transformers
23
+ openai-whisper=20240930
24
 
25
  # ── Lightweight vision model
26
  transformers>=4.41.2
tools.py CHANGED
@@ -2,8 +2,10 @@ import ast
2
  import json
3
  import operator
4
  import re
 
5
  from functools import lru_cache
6
  from io import BytesIO
 
7
 
8
  import requests
9
  from langchain_community.document_loaders import WikipediaLoader
@@ -46,7 +48,7 @@ def calculator(expression: str) -> str:
46
  tree = ast.parse(expression, mode="eval")
47
  value = _safe_eval(tree.body)
48
  return str(value)
49
- except Exception as exc: # pragma: no cover – we surface errors to the agent
50
  return f"calc_error:{exc}"
51
 
52
 
@@ -62,9 +64,9 @@ def _ddg_search(query: str, k: int = 6) -> list[dict[str, str]]:
62
  hits = wrapper.results(query)
63
  return [
64
  {
65
- "title": hit.get("title", "")[:120],
66
- "snippet": hit.get("snippet", "")[:300],
67
- "link": hit.get("link", "")[:200],
68
  }
69
  for hit in hits[:k]
70
  ]
@@ -87,9 +89,9 @@ def web_multi_search(query: str, k: int = 6) -> str:
87
  )
88
  formatted = [
89
  {
90
- "title": d.metadata.get("title", "")[:120],
91
- "snippet": d.page_content[:300],
92
- "link": d.metadata.get("source", "")[:200],
93
  }
94
  for d in tavily_hits
95
  ]
@@ -156,16 +158,61 @@ def image_describe(image_url: str, top_k: int = 3) -> str:
156
 
157
 
158
  @tool
159
- def csv_sum(url: str, column: str) -> str:
160
- """Download a CSV and return the sum of the specified numeric column."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  try:
162
- import pandas as pd # local import to avoid mandatory pandas if unused
163
 
164
- df = pd.read_csv(url)
165
- total = df[column].sum()
166
- return str(total)
 
167
  except Exception as exc:
168
- return f"csv_error:{exc}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  __all__ = [
@@ -174,5 +221,7 @@ __all__ = [
174
  "wiki_search",
175
  "youtube_transcript",
176
  "image_describe",
177
- "csv_sum",
 
 
178
  ]
 
2
  import json
3
  import operator
4
  import re
5
+ import subprocess
6
  from functools import lru_cache
7
  from io import BytesIO
8
+ from tempfile import NamedTemporaryFile
9
 
10
  import requests
11
  from langchain_community.document_loaders import WikipediaLoader
 
48
  tree = ast.parse(expression, mode="eval")
49
  value = _safe_eval(tree.body)
50
  return str(value)
51
+ except Exception as exc:
52
  return f"calc_error:{exc}"
53
 
54
 
 
64
  hits = wrapper.results(query)
65
  return [
66
  {
67
+ "title": hit.get("title", "")[:500],
68
+ "snippet": hit.get("snippet", "")[:750],
69
+ "link": hit.get("link", "")[:300],
70
  }
71
  for hit in hits[:k]
72
  ]
 
89
  )
90
  formatted = [
91
  {
92
+ "title": d.metadata.get("title", "")[:500],
93
+ "snippet": d.page_content[:750],
94
+ "link": d.metadata.get("source", "")[:300],
95
  }
96
  for d in tavily_hits
97
  ]
 
158
 
159
 
160
  @tool
161
+ def run_py(code: str) -> str:
162
+ """Execute Python code in a sandboxed subprocess and return last stdout line."""
163
+ try:
164
+ with NamedTemporaryFile(delete=False, suffix=".py", mode="w") as f:
165
+ f.write(code)
166
+ path = f.name
167
+ proc = subprocess.run(
168
+ ["python", path], capture_output=True, text=True, timeout=4
169
+ )
170
+ out = proc.stdout.strip().splitlines()
171
+ return out[-1] if out else ""
172
+ except Exception as exc:
173
+ return f"py_error:{exc}"
174
+
175
+
176
+ @tool
177
+ def transcribe_via_whisper(mp3_bytes: bytes) -> str:
178
+ """Transcribe MP3 bytes with Whisper (CPU)."""
179
+ with NamedTemporaryFile(suffix=".mp3", delete=False) as f:
180
+ f.write(mp3_bytes)
181
+ path = f.name
182
  try:
183
+ import whisper # openai-whisper
184
 
185
+ model = whisper.load_model("base")
186
+ output = model.transcribe(path)["text"].strip()
187
+ print(f"[DEBUG] Whisper transcript (first 200 chars): {output[:200]}")
188
+ return output
189
  except Exception as exc:
190
+ return f"asr_error:{exc}"
191
+
192
+
193
+ @tool
194
+ def analyze_excel_file(xls_bytes: bytes, question: str) -> str:
195
+ """Generic Excel/CSV aggregation handler."""
196
+ import pandas as pd
197
+
198
+ # Try both Excel and CSV loaders
199
+ try:
200
+ df = pd.read_excel(BytesIO(xls_bytes))
201
+ except Exception:
202
+ df = pd.read_csv(BytesIO(xls_bytes))
203
+
204
+ numeric = df.select_dtypes("number")
205
+ if numeric.empty:
206
+ return "No numeric data"
207
+
208
+ q = question.lower()
209
+ if any(term in q for term in ["total", "sum", "aggregate"]):
210
+ return f"{numeric.sum().sum():.2f}"
211
+ if any(term in q for term in ["average", "mean"]):
212
+ return f"{numeric.mean().mean():.2f}"
213
+
214
+ # Fallback: return first 10 rows as csv for LLM to reason on
215
+ return df.head(10).to_csv(index=False)
216
 
217
 
218
  __all__ = [
 
221
  "wiki_search",
222
  "youtube_transcript",
223
  "image_describe",
224
+ "run_py",
225
+ "transcribe_via_whisper",
226
+ "analyze_excel_file",
227
  ]