Spaces:
Sleeping
Sleeping
Essi
commited on
Commit
·
21394c0
1
Parent(s):
edd4fd8
feat: enhance attachment handling in `gather_context` with improved file type detection and routing
Browse files
app.py
CHANGED
@@ -9,12 +9,13 @@ from langchain_core.messages import HumanMessage, SystemMessage
|
|
9 |
from langchain_openai import ChatOpenAI
|
10 |
from langgraph.graph import END, StateGraph
|
11 |
|
|
|
12 |
from tools import (
|
13 |
analyze_excel_file,
|
14 |
calculator,
|
15 |
-
image_describe,
|
16 |
run_py,
|
17 |
transcribe_via_whisper,
|
|
|
18 |
web_multi_search,
|
19 |
wiki_search,
|
20 |
youtube_transcript,
|
@@ -36,15 +37,46 @@ If the answer is numeric, output digits only (no commas, units, or words).
|
|
36 |
# QUESTION CLASSIFIER #
|
37 |
# --------------------------------------------------------------------------- #
|
38 |
|
39 |
-
_LABELS = Literal[
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
User question:
|
45 |
{question}
|
|
|
46 |
|
47 |
-
|
48 |
"""
|
49 |
|
50 |
|
@@ -56,7 +88,6 @@ class AgentState(TypedDict):
|
|
56 |
label: str
|
57 |
context: str
|
58 |
answer: str
|
59 |
-
confidence: float
|
60 |
task_id: str | None = None
|
61 |
|
62 |
|
@@ -93,28 +124,49 @@ def gather_context(state: AgentState) -> AgentState:
|
|
93 |
|
94 |
# ---- attachment detection ------------------------------------------------
|
95 |
if task_id:
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
)
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
if label == "math":
|
120 |
print("[TOOL] calculator")
|
@@ -125,11 +177,6 @@ def gather_context(state: AgentState) -> AgentState:
|
|
125 |
if matched_obj:
|
126 |
url = matched_obj[0]
|
127 |
state["context"] = youtube_transcript.invoke({"url": url})
|
128 |
-
elif label == "image" and matched_obj:
|
129 |
-
print("[TOOL] image")
|
130 |
-
if matched_obj:
|
131 |
-
url = matched_obj[0]
|
132 |
-
state["context"] = image_describe.invoke({"image_url": url})
|
133 |
else: # general
|
134 |
print("[TOOL] general")
|
135 |
search_json = web_multi_search.invoke({"query": question})
|
@@ -142,7 +189,6 @@ def gather_context(state: AgentState) -> AgentState:
|
|
142 |
def generate_answer(state: AgentState) -> AgentState:
|
143 |
# Skip LLM for deterministic labels
|
144 |
if state["label"] in {"math", "code", "excel"}:
|
145 |
-
state["confidence"] = 0.9
|
146 |
return state
|
147 |
|
148 |
prompt = [
|
@@ -153,23 +199,17 @@ def generate_answer(state: AgentState) -> AgentState:
|
|
153 |
]
|
154 |
raw = _llm_answer.invoke(prompt).content.strip()
|
155 |
state["answer"] = raw
|
156 |
-
state["confidence"] = 0.5
|
157 |
return state
|
158 |
|
159 |
|
160 |
def validate(state: AgentState) -> AgentState:
|
161 |
-
"""Simple format + confidence gate."""
|
162 |
txt = re.sub(r"^(final answer:?\s*)", "", state["answer"], flags=re.I).strip()
|
163 |
|
164 |
# If question demands a single token (first name / one word), enforce it
|
165 |
if any(kw in state["question"].lower() for kw in ["first name", "single word"]):
|
166 |
txt = txt.split(" ")[0]
|
167 |
|
168 |
-
|
169 |
-
if not txt or len(txt.split()) > 6 or state["confidence"] < 0.2:
|
170 |
-
txt = "I don’t know"
|
171 |
-
|
172 |
-
state["answer"] = txt
|
173 |
return state
|
174 |
|
175 |
|
@@ -210,7 +250,6 @@ class GAIAAgent:
|
|
210 |
"label": "general",
|
211 |
"context": "",
|
212 |
"answer": "",
|
213 |
-
"confidence": 0.0,
|
214 |
"task_id": task_id,
|
215 |
}
|
216 |
final = self.graph.invoke(state)
|
|
|
9 |
from langchain_openai import ChatOpenAI
|
10 |
from langgraph.graph import END, StateGraph
|
11 |
|
12 |
+
from helpers import fetch_task_file, sniff_excel_type
|
13 |
from tools import (
|
14 |
analyze_excel_file,
|
15 |
calculator,
|
|
|
16 |
run_py,
|
17 |
transcribe_via_whisper,
|
18 |
+
vision_task,
|
19 |
web_multi_search,
|
20 |
wiki_search,
|
21 |
youtube_transcript,
|
|
|
37 |
# QUESTION CLASSIFIER #
|
38 |
# --------------------------------------------------------------------------- #
|
39 |
|
40 |
+
_LABELS = Literal[
|
41 |
+
"math",
|
42 |
+
"youtube",
|
43 |
+
"image_generic",
|
44 |
+
"image_puzzle",
|
45 |
+
"code",
|
46 |
+
"excel",
|
47 |
+
"audio",
|
48 |
+
"general",
|
49 |
+
]
|
50 |
+
|
51 |
+
_CLASSIFY_PROMPT = """You are a *routing* assistant.
|
52 |
+
Your ONLY job is to print **one** of the allowed labels - nothing else.
|
53 |
+
|
54 |
+
Allowed labels
|
55 |
+
==============
|
56 |
+
{labels}
|
57 |
+
|
58 |
+
Guidelines
|
59 |
+
----------
|
60 |
+
• **math**: the question is a pure arithmetic/numeric expression.
|
61 |
+
• **youtube**: the question contains a YouTube URL and asks about its content.
|
62 |
+
• **code**: the task references attached Python code; caller wants its output.
|
63 |
+
• **excel**: the task references an attached .xlsx/.xls/.csv and asks for a sum, average, etc.
|
64 |
+
• **audio**: the task references an attached audio file and asks for its transcript or facts in it.
|
65 |
+
• **image_generic**: the question asks only *what* is in the picture (e.g. “Which animal is shown?”).
|
66 |
+
• **image_puzzle**: the question asks for a *move, count, coordinate,* or other board-game tactic that needs an exact piece layout (e.g. "What is Black's winning move?").
|
67 |
+
• **general**: anything else (fallback).
|
68 |
+
|
69 |
+
Example for the two image labels
|
70 |
+
--------------------------------
|
71 |
+
1. "Identify the landmark in this photo." --> **image_generic**
|
72 |
+
2. "It's Black to move in the attached chess position; give the winning line." --> **image_puzzle**
|
73 |
+
|
74 |
+
~~~
|
75 |
User question:
|
76 |
{question}
|
77 |
+
~~~
|
78 |
|
79 |
+
IMPORTANT: Respond with **one label exactly**, no punctuation, no explanation.
|
80 |
"""
|
81 |
|
82 |
|
|
|
88 |
label: str
|
89 |
context: str
|
90 |
answer: str
|
|
|
91 |
task_id: str | None = None
|
92 |
|
93 |
|
|
|
124 |
|
125 |
# ---- attachment detection ------------------------------------------------
|
126 |
if task_id:
|
127 |
+
blob, ctype = fetch_task_file(api_url=DEFAULT_API_URL, task_id=task_id)
|
128 |
+
|
129 |
+
if any([blob, ctype]):
|
130 |
+
print(f"[DEBUG] attachment type={ctype} ")
|
131 |
+
# ── Python code ------------------------------------------------------
|
132 |
+
if "python" in ctype:
|
133 |
+
print("[DEBUG] Working with a Python attachment file")
|
134 |
+
state["answer"] = run_py.invoke({"code": blob.decode("utf-8")})
|
135 |
+
state["label"] = "code"
|
136 |
+
return state
|
137 |
+
|
138 |
+
# ── Excel / CSV ------------------------------------------------------
|
139 |
+
# 1) Header hints
|
140 |
+
header_says_sheet = any(key in ctype for key in ("excel", "sheet", "csv"))
|
141 |
+
# 2) Magic-number sniff (works when ctype is application/octet-stream)
|
142 |
+
blob_says_sheet = sniff_excel_type(blob) in {"xlsx", "xls", "csv"}
|
143 |
+
|
144 |
+
if header_says_sheet or blob_says_sheet:
|
145 |
+
if blob_says_sheet:
|
146 |
+
print(f"[DEBUG] octet-stream sniffed as {sniff_excel_type(blob)}")
|
147 |
+
|
148 |
+
print("[DEBUG] Working with a Excel/CSV attachment file")
|
149 |
+
state["context"] = analyze_excel_file.invoke(
|
150 |
+
{"xls_bytes": blob, "question": question}
|
151 |
+
)
|
152 |
+
state["label"] = "excel"
|
153 |
+
return state
|
154 |
+
|
155 |
+
# ── Audio --------------------------------------------------------
|
156 |
+
if "audio" in ctype:
|
157 |
+
print("[DEBUG] Working with an audio attachment file")
|
158 |
+
state["context"] = transcribe_via_whisper.invoke({"mp3_bytes": blob})
|
159 |
+
state["label"] = "audio"
|
160 |
+
return state
|
161 |
+
|
162 |
+
# ── Image --------------------------------------------------------
|
163 |
+
if "image" in ctype:
|
164 |
+
print("[DEBUG] Working with an image attachment file")
|
165 |
+
state["context"] = vision_task.invoke(
|
166 |
+
{"img_bytes": blob, "question": question}
|
167 |
+
)
|
168 |
+
state["label"] = "image"
|
169 |
+
return state
|
170 |
|
171 |
if label == "math":
|
172 |
print("[TOOL] calculator")
|
|
|
177 |
if matched_obj:
|
178 |
url = matched_obj[0]
|
179 |
state["context"] = youtube_transcript.invoke({"url": url})
|
|
|
|
|
|
|
|
|
|
|
180 |
else: # general
|
181 |
print("[TOOL] general")
|
182 |
search_json = web_multi_search.invoke({"query": question})
|
|
|
189 |
def generate_answer(state: AgentState) -> AgentState:
|
190 |
# Skip LLM for deterministic labels
|
191 |
if state["label"] in {"math", "code", "excel"}:
|
|
|
192 |
return state
|
193 |
|
194 |
prompt = [
|
|
|
199 |
]
|
200 |
raw = _llm_answer.invoke(prompt).content.strip()
|
201 |
state["answer"] = raw
|
|
|
202 |
return state
|
203 |
|
204 |
|
205 |
def validate(state: AgentState) -> AgentState:
|
|
|
206 |
txt = re.sub(r"^(final answer:?\s*)", "", state["answer"], flags=re.I).strip()
|
207 |
|
208 |
# If question demands a single token (first name / one word), enforce it
|
209 |
if any(kw in state["question"].lower() for kw in ["first name", "single word"]):
|
210 |
txt = txt.split(" ")[0]
|
211 |
|
212 |
+
state["answer"] = txt.rstrip(".")
|
|
|
|
|
|
|
|
|
213 |
return state
|
214 |
|
215 |
|
|
|
250 |
"label": "general",
|
251 |
"context": "",
|
252 |
"answer": "",
|
|
|
253 |
"task_id": task_id,
|
254 |
}
|
255 |
final = self.graph.invoke(state)
|