Essi committed on
Commit 2e0b688 · 1 Parent(s): e04c929

feat: refactor prompt handling and improve routing logic for task classification

Files changed (3)
  1. app.py +17 -52
  2. prompts.yaml +44 -1
  3. tools.py +10 -15
app.py CHANGED
@@ -9,7 +9,7 @@ from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_openai import ChatOpenAI
 from langgraph.graph import END, StateGraph
 
-from helpers import fetch_task_file, sniff_excel_type
+from helpers import fetch_task_file, get_prompt, sniff_excel_type
 from tools import (
     analyze_excel_file,
     calculator,
@@ -28,11 +28,6 @@ DEFAULT_API_URL: str = "https://agents-course-unit4-scoring.hf.space"
 MODEL_NAME: str = "o4-mini"  # "gpt-4.1-mini"
 TEMPERATURE: float = 0.1
 
-_SYSTEM_PROMPT = """You are a precise research assistant. Return ONLY the literal answer - no preamble.
-If the question asks for a *first name*, output the first given name only.
-If the answer is numeric, output digits only (no commas, units, or words).
-"""
-
 # --------------------------------------------------------------------------- #
 # QUESTION CLASSIFIER #
 # --------------------------------------------------------------------------- #
@@ -40,45 +35,13 @@ If the answer is numeric, output digits only (no commas, units, or words).
 _LABELS = Literal[
     "math",
     "youtube",
-    "image_generic",
-    "image_puzzle",
+    "image",
     "code",
     "excel",
     "audio",
     "general",
 ]
 
-_CLASSIFY_PROMPT = """You are a *routing* assistant.
-Your ONLY job is to print **one** of the allowed labels - nothing else.
-
-Allowed labels
-==============
-{labels}
-
-Guidelines
-----------
-• **math**: the question is a pure arithmetic/numeric expression.
-• **youtube**: the question contains a YouTube URL and asks about its content.
-• **code**: the task references attached Python code; caller wants its output.
-• **excel**: the task references an attached .xlsx/.xls/.csv and asks for a sum, average, etc.
-• **audio**: the task references an attached audio file and asks for its transcript or facts in it.
-• **image_generic**: the question asks only *what* is in the picture (e.g. “Which animal is shown?”).
-• **image_puzzle**: the question asks for a *move, count, coordinate,* or other board-game tactic that needs an exact piece layout (e.g. "What is Black's winning move?").
-• **general**: anything else (fallback).
-
-Example for the two image labels
---------------------------------
-1. "Identify the landmark in this photo." --> **image_generic**
-2. "It's Black to move in the attached chess position; give the winning line." --> **image_puzzle**
-
-~~~
-User question:
-{question}
-~~~
-
-IMPORTANT: Respond with **one label exactly**, no punctuation, no explanation.
-"""
-
 
 # --------------------------------------------------------------------------- #
 # ------------------------------- AGENT STATE ----------------------------- #
@@ -104,14 +67,12 @@ def classify(state: AgentState) -> AgentState:  # noqa: D401
     question = state["question"]
 
     label_values = set(get_args(_LABELS))  # -> ("math", "youtube", ...)
-    parsed_labels = ", ".join(repr(v) for v in label_values)
-    resp = (
-        _llm_router.invoke(
-            _CLASSIFY_PROMPT.format(question=question, labels=parsed_labels)
-        )
-        .content.strip()
-        .lower()
+    prompt = get_prompt(
+        prompt_key="router",
+        question=question,
+        labels=", ".join(repr(v) for v in label_values),
     )
+    resp = _llm_router.invoke(prompt).content.strip().lower()
     state["label"] = resp if resp in label_values else "general"
     return state
 
@@ -146,7 +107,7 @@ def gather_context(state: AgentState) -> AgentState:
         print(f"[DEBUG] octet-stream sniffed as {sniff_excel_type(blob)}")
 
         print("[DEBUG] Working with a Excel/CSV attachment file")
-        state["context"] = analyze_excel_file.invoke(
+        state["answer"] = analyze_excel_file.invoke(
            {"xls_bytes": blob, "question": question}
        )
        state["label"] = "excel"
@@ -162,7 +123,7 @@ def gather_context(state: AgentState) -> AgentState:
    # ── Image --------------------------------------------------------
    if "image" in ctype:
        print("[DEBUG] Working with an image attachment file")
-        state["context"] = vision_task.invoke(
+        state["answer"] = vision_task.invoke(
            {"img_bytes": blob, "question": question}
        )
        state["label"] = "image"
@@ -187,14 +148,18 @@ def gather_context(state: AgentState) -> AgentState:
 
 
 def generate_answer(state: AgentState) -> AgentState:
-    # Skip LLM for deterministic labels
-    if state["label"] in {"math", "code", "excel"}:
+    # Skip LLM for deterministic labels or tasks that already used LLMs
+    if state["label"] in {"code", "excel", "image", "math"}:
        return state
 
    prompt = [
-        SystemMessage(content=_SYSTEM_PROMPT),
+        SystemMessage(content=get_prompt("final_llm_system")),
        HumanMessage(
-            content=f"Question: {state['question']}\n\nContext:\n{state['context']}\n\nAnswer:"
+            content=get_prompt(
+                prompt_key="final_llm_user",
+                question=state["question"],
+                context=state["context"],
+            )
        ),
    ]
    raw = _llm_answer.invoke(prompt).content.strip()
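
Note: this commit collapses the old image_generic/image_puzzle labels into a single image label and moves every prompt string into prompts.yaml behind a shared get_prompt helper. The graph wiring itself is not touched by the diff; for orientation only, here is a minimal sketch of how a classify -> gather_context -> generate_answer pipeline is typically assembled with langgraph's StateGraph (node bodies are stubs and the actual wiring in app.py may differ):

# Hypothetical wiring sketch - the real graph construction is not part of this diff.
from typing import TypedDict
from langgraph.graph import END, StateGraph


class AgentState(TypedDict, total=False):
    question: str
    label: str
    context: str
    answer: str


def classify(state: AgentState) -> AgentState:        # stub for illustration
    state["label"] = "general"
    return state


def gather_context(state: AgentState) -> AgentState:  # stub for illustration
    state["context"] = ""
    return state


def generate_answer(state: AgentState) -> AgentState:  # stub for illustration
    state["answer"] = "stub answer"
    return state


graph = StateGraph(AgentState)
graph.add_node("classify", classify)
graph.add_node("gather_context", gather_context)
graph.add_node("generate_answer", generate_answer)
graph.set_entry_point("classify")
graph.add_edge("classify", "gather_context")
graph.add_edge("gather_context", "generate_answer")
graph.add_edge("generate_answer", END)

app = graph.compile()
print(app.invoke({"question": "What is 2 + 2?"}))

Because gather_context now writes state["answer"] directly for excel and image attachments, generate_answer only calls the LLM for labels that still need synthesis, which is what the expanded skip set {"code", "excel", "image", "math"} encodes.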
prompts.yaml CHANGED
@@ -1,4 +1,47 @@
-excel_analysis_one_liner: |
+router: |
+  You are a *routing* assistant.
+  Your ONLY job is to print **one** of the allowed labels - nothing else.
+
+  Allowed labels
+  ==============
+  {labels}
+
+  Guidelines
+  ----------
+  • **math**: the question is a pure arithmetic/numeric expression.
+  • **youtube**: the question contains a YouTube URL and asks about its content.
+  • **code**: the task references attached Python code; caller wants its output.
+  • **excel**: the task references an attached .xlsx/.xls/.csv and asks for a sum, average, etc.
+  • **image**: the task is either generic ("What is in the picture?", e.g. "Which animal is shown?") or a puzzle asking for a *move, count, coordinate,* or other board-game tactic that needs an exact piece layout (e.g. "What is Black's winning move?").
+  • **audio**: the task references an attached audio file and asks for its transcript or facts in it.
+  • **general**: anything else (fallback).
+
+  ~~~
+  User question:
+  {question}
+  ~~~
+
+  IMPORTANT: Respond with **one label exactly**, no punctuation, no explanation.
+
+final_llm_system: |
+  You are a precise research assistant.
+  Return ONLY the literal answer - no preamble.
+  If the question asks for a *first name*, output the first given name only.
+  If the answer is numeric, output digits only (no commas, units, or words).
+
+final_llm_user: |
+  Question: {question}
+
+  Context: {context}
+
+  Answer:
+
+vision_system: |
+  You are a terse assistant. Respond with ONLY the answer to the user's question—no explanations, no punctuation except what the answer itself requires.
+  If the answer is a chess move, output it in algebraic notation.
+  IMPORTANT: Only respond with the final answer with no extra text.
+
+excel_system: |
   You are a **pandas one-liner generator**.
 
   Context
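
Note: these keys are consumed through helpers.get_prompt, which app.py and tools.py now import but which is not part of this commit. A minimal sketch of what such a helper could look like, assuming prompts.yaml sits next to helpers.py and that templates use str.format-style placeholders such as {question} and {labels} (the file location, caching, and names below are illustrative assumptions, not the repository's actual code):

# Hypothetical helpers.get_prompt - the real implementation is not shown in this diff.
from functools import lru_cache
from pathlib import Path

import yaml

_PROMPTS_PATH = Path(__file__).with_name("prompts.yaml")  # assumed location


@lru_cache(maxsize=1)
def _load_prompts() -> dict:
    # Parse prompts.yaml once and reuse the dict on later calls.
    return yaml.safe_load(_PROMPTS_PATH.read_text(encoding="utf-8"))


def get_prompt(prompt_key: str, **kwargs) -> str:
    # Look up a template by key and fill its {placeholders}, if any were passed.
    template = _load_prompts()[prompt_key]
    return template.format(**kwargs) if kwargs else template

With a helper shaped like this, get_prompt(prompt_key="router", question=q, labels=parsed_labels) returns the filled router template, while get_prompt("final_llm_system") returns its template verbatim, which matches both call styles seen in the diff.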
tools.py CHANGED
@@ -165,11 +165,6 @@ def vision_task(img_bytes: bytes, question: str) -> str:
     Pass the user's question AND the referenced image to a multimodal LLM and
     return its first line of text as the answer. No domain assumptions made.
     """
-    sys_prompt = (
-        "You are a terse assistant. Respond with ONLY the answer to the user's "
-        "question—no explanations, no punctuation except what the answer itself "
-        "requires. If the answer is a chess move, output it in algebraic notation."
-    )
     vision_llm = ChatOpenAI(
         model="gpt-4o-mini",  # set OPENAI_API_KEY in env
         temperature=0,
@@ -178,7 +173,7 @@ def vision_task(img_bytes: bytes, question: str) -> str:
     try:
         b64 = b64encode(img_bytes).decode()
         messages = [
-            SystemMessage(content=sys_prompt),
+            SystemMessage(content=get_prompt(prompt_key="vision_system")),
             HumanMessage(
                 content=[
                     {"type": "text", "text": question.strip()},
@@ -215,10 +210,10 @@ def run_py(code: str) -> str:
 
 
 @tool
-def transcribe_via_whisper(mp3_bytes: bytes) -> str:
-    """Transcribe MP3 bytes with Whisper (CPU)."""
+def transcribe_via_whisper(audio_bytes: bytes) -> str:
+    """Transcribe audio with Whisper (CPU)."""
     with NamedTemporaryFile(suffix=".mp3", delete=False) as f:
-        f.write(mp3_bytes)
+        f.write(audio_bytes)
         path = f.name
     try:
         import whisper  # openai-whisper
@@ -236,7 +231,6 @@ def analyze_excel_file(xls_bytes: bytes, question: str) -> str:
     "Analyze Excel or CSV file by passing the data preview to LLM and getting the Python Pandas operation to run"
     llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, max_tokens=64)
 
-    # 1. full dataframe
     try:
         df = pd.read_excel(BytesIO(xls_bytes))
     except Exception:
@@ -245,18 +239,19 @@ def analyze_excel_file(xls_bytes: bytes, question: str) -> str:
     for col in df.select_dtypes(include="number").columns:
         df[col] = df[col].astype(float)
 
-    # 2. ask the LLM for a single expression
+    # Ask the LLM for a single expression
     prompt = get_prompt(
-        prompt_key="excel_analysis_one_liner", preview=df.head(5).to_dict(orient="list")
+        prompt_key="excel_system",
+        question=question,
+        preview=df.head(5).to_dict(orient="list"),
     )
     expr = llm.invoke(prompt).content.strip()
 
-    # 3. run it on the FULL df
+    # Run the generated pandas one-line expression
     try:
         result = eval(expr, {"df": df, "pd": pd, "__builtins__": {}})
-        # ── normalize scalars to string -------------------------------------------
+        # Normalize scalars to string
         if isinstance(result, np.generic):
-            # keep existing LLM formatting (e.g. {:.2f}) if it's already a str
             result = float(result)  # → plain Python float
         return f"{result:.2f}"  # or str(result) if no decimals needed
 
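
Note: analyze_excel_file executes the model's reply with eval, exposing only df, pd, and an emptied __builtins__. A toy, self-contained illustration of that evaluation and normalization step (the sample frame and expression are invented for the example):

# Toy illustration of the restricted-eval step used by analyze_excel_file.
import numpy as np
import pandas as pd

df = pd.DataFrame({"category": ["food", "drink", "food"], "sales": [10.0, 4.5, 2.5]})

# Imagine the LLM returned this one-liner for "What are total food sales?":
expr = 'df.loc[df["category"] == "food", "sales"].sum()'

result = eval(expr, {"df": df, "pd": pd, "__builtins__": {}})
if isinstance(result, np.generic):
    result = float(result)  # numpy scalar -> plain Python float
print(f"{result:.2f}")  # -> 12.50

Emptying __builtins__ blocks casual misuse but is not a real sandbox, since attribute access on df or pd can still reach arbitrary objects; constraining the model to a single short expression (max_tokens=64) is what keeps this approach workable in practice.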