Essi committed on
Commit
21394c0
·
1 Parent(s): edd4fd8

feat: enhance attachment handling in `gather_context` with improved file type detection and routing

Browse files
Files changed (1) hide show
  1. app.py +83 -44
app.py CHANGED
@@ -9,12 +9,13 @@ from langchain_core.messages import HumanMessage, SystemMessage
9
  from langchain_openai import ChatOpenAI
10
  from langgraph.graph import END, StateGraph
11
 
 
12
  from tools import (
13
  analyze_excel_file,
14
  calculator,
15
- image_describe,
16
  run_py,
17
  transcribe_via_whisper,
 
18
  web_multi_search,
19
  wiki_search,
20
  youtube_transcript,
@@ -36,15 +37,46 @@ If the answer is numeric, output digits only (no commas, units, or words).
36
  # QUESTION CLASSIFIER #
37
  # --------------------------------------------------------------------------- #
38
 
39
- _LABELS = Literal["math", "youtube", "image", "code", "excel", "audio", "general"]
40
-
41
- _CLASSIFY_PROMPT = """You are a router that labels the user question with exactly one of the following categories:
42
- {labels}.
43
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  User question:
45
  {question}
 
46
 
47
- Label:
48
  """
49
 
50
 
@@ -56,7 +88,6 @@ class AgentState(TypedDict):
56
  label: str
57
  context: str
58
  answer: str
59
- confidence: float
60
  task_id: str | None = None
61
 
62
 
@@ -93,28 +124,49 @@ def gather_context(state: AgentState) -> AgentState:
93
 
94
  # ---- attachment detection ------------------------------------------------
95
  if task_id:
96
- file_url = f"{DEFAULT_API_URL}/files/{task_id}"
97
- head = requests.head(file_url, timeout=10)
98
- ctype = head.headers.get("content-type", "")
99
-
100
- print(f"[DEBUG] attachment type={ctype} | url={file_url}")
101
- if "python" in ctype or file_url.endswith(".py"):
102
- code = requests.get(file_url, timeout=10).text
103
- state["answer"] = run_py.invoke({"code": code})
104
- state["label"] = "code"
105
- return state
106
- if "excel" in ctype or file_url.endswith((".xlsx", ".csv")):
107
- blob = requests.get(file_url, timeout=10).content
108
- state["context"] = analyze_excel_file.invoke(
109
- {"xls_bytes": blob, "question": question}
110
- )
111
- state["label"] = "excel"
112
- return state
113
- if "audio" in ctype or file_url.endswith(".mp3"):
114
- blob = requests.get(file_url, timeout=10).content
115
- state["context"] = transcribe_via_whisper.invoke({"mp3_bytes": blob})
116
- state["label"] = "audio"
117
- return state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  if label == "math":
120
  print("[TOOL] calculator")
@@ -125,11 +177,6 @@ def gather_context(state: AgentState) -> AgentState:
125
  if matched_obj:
126
  url = matched_obj[0]
127
  state["context"] = youtube_transcript.invoke({"url": url})
128
- elif label == "image" and matched_obj:
129
- print("[TOOL] image")
130
- if matched_obj:
131
- url = matched_obj[0]
132
- state["context"] = image_describe.invoke({"image_url": url})
133
  else: # general
134
  print("[TOOL] general")
135
  search_json = web_multi_search.invoke({"query": question})
@@ -142,7 +189,6 @@ def gather_context(state: AgentState) -> AgentState:
142
  def generate_answer(state: AgentState) -> AgentState:
143
  # Skip LLM for deterministic labels
144
  if state["label"] in {"math", "code", "excel"}:
145
- state["confidence"] = 0.9
146
  return state
147
 
148
  prompt = [
@@ -153,23 +199,17 @@ def generate_answer(state: AgentState) -> AgentState:
153
  ]
154
  raw = _llm_answer.invoke(prompt).content.strip()
155
  state["answer"] = raw
156
- state["confidence"] = 0.5
157
  return state
158
 
159
 
160
  def validate(state: AgentState) -> AgentState:
161
- """Simple format + confidence gate."""
162
  txt = re.sub(r"^(final answer:?\s*)", "", state["answer"], flags=re.I).strip()
163
 
164
  # If question demands a single token (first name / one word), enforce it
165
  if any(kw in state["question"].lower() for kw in ["first name", "single word"]):
166
  txt = txt.split(" ")[0]
167
 
168
- txt = txt.rstrip(".")
169
- if not txt or len(txt.split()) > 6 or state["confidence"] < 0.2:
170
- txt = "I don’t know"
171
-
172
- state["answer"] = txt
173
  return state
174
 
175
 
@@ -210,7 +250,6 @@ class GAIAAgent:
210
  "label": "general",
211
  "context": "",
212
  "answer": "",
213
- "confidence": 0.0,
214
  "task_id": task_id,
215
  }
216
  final = self.graph.invoke(state)
 
9
  from langchain_openai import ChatOpenAI
10
  from langgraph.graph import END, StateGraph
11
 
12
+ from helpers import fetch_task_file, sniff_excel_type
13
  from tools import (
14
  analyze_excel_file,
15
  calculator,
 
16
  run_py,
17
  transcribe_via_whisper,
18
+ vision_task,
19
  web_multi_search,
20
  wiki_search,
21
  youtube_transcript,
 
37
  # QUESTION CLASSIFIER #
38
  # --------------------------------------------------------------------------- #
39
 
40
+ _LABELS = Literal[
41
+ "math",
42
+ "youtube",
43
+ "image_generic",
44
+ "image_puzzle",
45
+ "code",
46
+ "excel",
47
+ "audio",
48
+ "general",
49
+ ]
50
+
51
+ _CLASSIFY_PROMPT = """You are a *routing* assistant.
52
+ Your ONLY job is to print **one** of the allowed labels - nothing else.
53
+
54
+ Allowed labels
55
+ ==============
56
+ {labels}
57
+
58
+ Guidelines
59
+ ----------
60
+ • **math**: the question is a pure arithmetic/numeric expression.
61
+ • **youtube**: the question contains a YouTube URL and asks about its content.
62
+ • **code**: the task references attached Python code; caller wants its output.
63
+ • **excel**: the task references an attached .xlsx/.xls/.csv and asks for a sum, average, etc.
64
+ • **audio**: the task references an attached audio file and asks for its transcript or facts in it.
65
+ • **image_generic**: the question asks only *what* is in the picture (e.g. “Which animal is shown?”).
66
+ • **image_puzzle**: the question asks for a *move, count, coordinate,* or other board-game tactic that needs an exact piece layout (e.g. "What is Black's winning move?").
67
+ • **general**: anything else (fallback).
68
+
69
+ Example for the two image labels
70
+ --------------------------------
71
+ 1. "Identify the landmark in this photo." --> **image_generic**
72
+ 2. "It's Black to move in the attached chess position; give the winning line." --> **image_puzzle**
73
+
74
+ ~~~
75
  User question:
76
  {question}
77
+ ~~~
78
 
79
+ IMPORTANT: Respond with **one label exactly**, no punctuation, no explanation.
80
  """
81
 
82
 
 
88
  label: str
89
  context: str
90
  answer: str
 
91
  task_id: str | None = None
92
 
93
 
 
124
 
125
  # ---- attachment detection ------------------------------------------------
126
  if task_id:
127
+ blob, ctype = fetch_task_file(api_url=DEFAULT_API_URL, task_id=task_id)
128
+
129
+ if any([blob, ctype]):
130
+ print(f"[DEBUG] attachment type={ctype} ")
131
+ # ── Python code ------------------------------------------------------
132
+ if "python" in ctype:
133
+ print("[DEBUG] Working with a Python attachment file")
134
+ state["answer"] = run_py.invoke({"code": blob.decode("utf-8")})
135
+ state["label"] = "code"
136
+ return state
137
+
138
+ # ── Excel / CSV ------------------------------------------------------
139
+ # 1) Header hints
140
+ header_says_sheet = any(key in ctype for key in ("excel", "sheet", "csv"))
141
+ # 2) Magic-number sniff (works when ctype is application/octet-stream)
142
+ blob_says_sheet = sniff_excel_type(blob) in {"xlsx", "xls", "csv"}
143
+
144
+ if header_says_sheet or blob_says_sheet:
145
+ if blob_says_sheet:
146
+ print(f"[DEBUG] octet-stream sniffed as {sniff_excel_type(blob)}")
147
+
148
+ print("[DEBUG] Working with a Excel/CSV attachment file")
149
+ state["context"] = analyze_excel_file.invoke(
150
+ {"xls_bytes": blob, "question": question}
151
+ )
152
+ state["label"] = "excel"
153
+ return state
154
+
155
+ # ── Audio --------------------------------------------------------
156
+ if "audio" in ctype:
157
+ print("[DEBUG] Working with an audio attachment file")
158
+ state["context"] = transcribe_via_whisper.invoke({"mp3_bytes": blob})
159
+ state["label"] = "audio"
160
+ return state
161
+
162
+ # ── Image --------------------------------------------------------
163
+ if "image" in ctype:
164
+ print("[DEBUG] Working with an image attachment file")
165
+ state["context"] = vision_task.invoke(
166
+ {"img_bytes": blob, "question": question}
167
+ )
168
+ state["label"] = "image"
169
+ return state
170
 
171
  if label == "math":
172
  print("[TOOL] calculator")
 
177
  if matched_obj:
178
  url = matched_obj[0]
179
  state["context"] = youtube_transcript.invoke({"url": url})
 
 
 
 
 
180
  else: # general
181
  print("[TOOL] general")
182
  search_json = web_multi_search.invoke({"query": question})
 
189
  def generate_answer(state: AgentState) -> AgentState:
190
  # Skip LLM for deterministic labels
191
  if state["label"] in {"math", "code", "excel"}:
 
192
  return state
193
 
194
  prompt = [
 
199
  ]
200
  raw = _llm_answer.invoke(prompt).content.strip()
201
  state["answer"] = raw
 
202
  return state
203
 
204
 
205
  def validate(state: AgentState) -> AgentState:
 
206
  txt = re.sub(r"^(final answer:?\s*)", "", state["answer"], flags=re.I).strip()
207
 
208
  # If question demands a single token (first name / one word), enforce it
209
  if any(kw in state["question"].lower() for kw in ["first name", "single word"]):
210
  txt = txt.split(" ")[0]
211
 
212
+ state["answer"] = txt.rstrip(".")
 
 
 
 
213
  return state
214
 
215
 
 
250
  "label": "general",
251
  "context": "",
252
  "answer": "",
 
253
  "task_id": task_id,
254
  }
255
  final = self.graph.invoke(state)