Essi committed
Commit e773514 · Parent(s): 463e805

feat: implement multi-backend search and YouTube transcript retrieval
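
Usage note (illustrative, not part of the diff below): both new tools are LangChain @tool wrappers, so they can be invoked the same way _route() calls them. A minimal sketch, with made-up query and URL values:

    # hypothetical inputs for illustration only
    results_json = web_multi_search.invoke({"query": "some factual question"})
    transcript = youtube_transcript.invoke({"url": "https://www.youtube.com/watch?v=<VIDEO_ID>"})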

Files changed (1)
  1. app.py +81 -63
app.py CHANGED
@@ -6,6 +6,7 @@ import re
 from functools import lru_cache
 from io import BytesIO
 from typing import TypedDict
+from urllib import parse
 
 import gradio as gr
 import pandas as pd
@@ -15,6 +16,8 @@ from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_openai import ChatOpenAI
 from langgraph.graph import END, StateGraph
+from wikipedia import summary as wiki_summary
+from youtube_transcript_api import YouTubeTranscriptApi
 
 # --- Constants ---
 DEFAULT_API_URL: str = "https://agents-course-unit4-scoring.hf.space"
@@ -82,12 +85,51 @@ def _search_duckduckgo(query: str, k: int = 5) -> list[dict[str, str]]:
 
 
 @tool
-def web_search(query: str) -> str:
-    """DuckDuckGo search. Returns compact JSON (max 5 hits)."""
+def web_multi_search(query: str, k: int = 5) -> str:
+    """
+    Multi-backend search. JSON list of {title,snippet,link}.
+    Order: DuckDuckGo → Wikipedia → Google Lite JSON.
+    """
     try:
-        return json.dumps(_search_duckduckgo(query), ensure_ascii=False)
-    except Exception as exc:
-        return f"search_error:{exc}"
+        hits = _search_duckduckgo(query, k)
+        if hits:
+            return json.dumps(hits, ensure_ascii=False)
+    except Exception:
+        pass
+
+    # Fallback 2: Wikipedia single-article summary
+    try:
+        page = wiki_summary(query, sentences=2, auto_suggest=False)
+        return json.dumps([{"title": "Wikipedia", "snippet": page, "link": ""}])
+    except Exception:
+        pass
+
+    # Fallback 3: simple Google (no key) – tiny quota but better than nothing
+    try:
+        url = "https://r.jina.ai/http://api.allorigins.win/raw?url=" + parse.quote(
+            "https://lite.duckduckgo.com/lite/?q=" + query
+        )
+        txt = requests.get(url, timeout=10).text[:600]
+        return json.dumps(
+            [{"title": "Google-lite", "snippet": re.sub(r"<.*?>", "", txt), "link": ""}]
+        )
+    except Exception as e:
+        return f"search_error:{e}"
+
+
+@tool
+def youtube_transcript(url: str, num_first_chars: int = 10_000) -> str:
+    """Returns the YouTube transcript (first `num_first_chars` characters only)."""
+    video_id = re.search(r"v=([A-Za-z0-9_\-]{11})", url)
+    if not video_id:
+        return "yt_error: id"
+    try:
+        txt = " ".join(
+            [x["text"] for x in YouTubeTranscriptApi.get_transcript(video_id.group(1))]
+        )
+        return txt[:num_first_chars]
+    except Exception as e:
+        return f"yt_error: {e}"
 
 
 # --------------------------------------------------------------------------- #
@@ -100,37 +142,18 @@ def _needs_calc(q: str) -> bool:
 
 
 def _extract_search_terms(question: str) -> str:
-    """Extract key search terms from question."""
-    stops = {
-        "what",
-        "who",
-        "where",
-        "when",
-        "how",
-        "why",
-        "is",
-        "are",
-        "was",
-        "were",
-        "the",
-        "and",
-        "or",
-    }
-    tokens = re.findall(r"[A-Za-z0-9]+", question.lower())
-    key_terms = []
-
-    for tok in tokens:
-        if (
-            tok.lower() not in stops or len(tok) > 6
-        ):  # Keep longer words even if they're stop words
-            key_terms.append(tok)
-
-    # Limit to avoid overly long queries
-    return " ".join(key_terms[:8])
+    key = re.findall(r"[A-Za-z0-9']+", question.lower())
+    phrase = " ".join(key)
+    # if we lost critical tokens (length diff > 40 %), fallback to full q
+    return phrase if len(phrase) > 0.6 * len(question) else question
 
 
 def _summarize_results(results_json: str, max_hits: int = 3) -> str:
     """Turn JSON list of hits into a compact text context for the LLM."""
+    if not results_json or not results_json.lstrip().startswith("["):
+        # Not JSON or empty → return raw text
+        return results_json
+
     try:
         hits = json.loads(results_json)[:max_hits]
         context_parts = []
@@ -145,25 +168,6 @@ def _summarize_results(results_json: str, max_hits: int = 3) -> str:
         return ""
 
 
-def _contains_file_reference(question: str) -> bool:
-    """Check if question references attached files."""
-    file_indicators = [
-        "attached",
-        "attachment",
-        "file",
-        "excel",
-        "spreadsheet",
-        "xls",
-        "csv",
-        "document",
-        "image",
-        "video",
-        "audio",
-        "recording",
-    ]
-    return any(indicator in question.lower() for indicator in file_indicators)
-
-
 # --------------------------------------------------------------------------- #
 # ------------------------------- AGENT STATE ----------------------------- #
 # --------------------------------------------------------------------------- #
@@ -191,8 +195,7 @@ class GAIAAgent:
 3. For names, provide just the name(s)
 4. For yes/no questions, answer "Yes" or "No"
 5. Use the provided context carefully to find the exact answer
-6. If you need to make calculations, show your work briefly
-7. Be precise and factual
+6. Be precise and factual
 
 Return ONLY the final answer."""
 
@@ -210,7 +213,7 @@ class GAIAAgent:
         self.llm = None
 
         # Following is defined for book-keeping purposes
-        self.tools = [web_search, calculator]
+        self.tools = [web_multi_search, calculator, youtube_transcript]
 
         self.graph = self._build_graph()
 
@@ -243,12 +246,12 @@ class GAIAAgent:
         return state
 
     def _route(self, state: AgentState) -> AgentState:
-        q = state["question"]
+        question = state["question"]
 
         # 1️⃣ Calculator path
-        if _needs_calc(q):
+        if _needs_calc(question):
             # 1) strip all whitespace
-            expr = re.sub(r"\s+", "", q)
+            expr = re.sub(r"\s+", "", question)
 
             # 2) remove ANY character that isn’t digit, dot, operator, or parenthesis
             # (kills “USD”, “kg”, YouTube IDs, etc.)
@@ -263,7 +266,7 @@ class GAIAAgent:
             return state
 
         # 2️⃣ Attachment (Excel file)
-        if "attached" in q.lower() and "excel" in q.lower():
+        if "attached" in question.lower() and "excel" in question.lower():
             try:
                 task_id = state.get("task_id")
                 file_url = f"{DEFAULT_API_URL}/files/{task_id}"
@@ -278,18 +281,32 @@ class GAIAAgent:
             except Exception as e:
                 state["reasoning_steps"].append(f"xlsx_error:{e}")
 
-        # 3️⃣ Web search path
-        query = _extract_search_terms(q)
-        results_json = web_search.invoke({"query": query})
+        # 3️⃣ YouTube search path
+        youtube_url = re.search(r"https?://www\.youtube\.com/\S+", question)
+        if youtube_url:
+            transcript = youtube_transcript.invoke({"url": youtube_url.group(0)})
+            state["context"] = transcript
+            state["tools_used"].append("youtube_transcript")
+            state["reasoning_steps"].append("YouTube")
+            return state
+
+        # 4️⃣ Web search path
+        query = _extract_search_terms(question)
+        results_json = web_multi_search.invoke({"query": query})
 
         state["search_results"] = results_json
-        state["tools_used"].append("web_search")
+        state["tools_used"].append("web_multi_search")
        state["reasoning_steps"].append(f"SEARCH: {query}")
         state["answer"] = ""
 
         return state
 
     def _process_info(self, state: AgentState) -> AgentState:
+        if state["context"]:
+            # ✅ If context already populated (e.g. YouTube transcript), keep it.
+            state["reasoning_steps"].append("PROCESS(skip)")
+            return state
+
         if state["answer"]:
             # If calc already produced an answer, just pass through
             state["context"] = ""
@@ -321,7 +338,7 @@ class GAIAAgent:
             ),
         ]
         response = self.llm.invoke(prompt)
-        print(f">>> Raw response from LLM:\n{response}\n\n")
+        print(f">>> Raw response from LLM:\n{response}\n")
         state["answer"] = response.content.strip()
         state["reasoning_steps"].append("GENERATE ANSWER (LLM)")
         return state
@@ -372,6 +389,7 @@ class GAIAAgent:
 
         answer = final_state["answer"]
         print(f"Agent reasoning: {' ==> '.join(final_state['reasoning_steps'])}")
+        print(f"Agent's context {final_state['context']}")
         print(f"Tools used: {final_state['tools_used']}")
         print(f"Final answer: {answer}")
 