Kai Jennissen committed
Commit 8102d4b · unverified · 1 Parent(s): 1c0a810

added tools

Files changed (5):
  1. agent.py +120 -20
  2. app.py +20 -5
  3. requirements.in +3 -0
  4. requirements.txt +6 -0
  5. tools.py +672 -0
agent.py CHANGED
@@ -3,10 +3,21 @@ from smolagents import (
     CodeAgent,
     DuckDuckGoSearchTool,
     VisitWebpageTool,
-    InferenceClientModel,
+    # InferenceClientModel,
+    OpenAIServerModel,
+    WikipediaSearchTool,
 )
 from dotenv import load_dotenv
 from tracing import setup_tracing
+from tools import (
+    read_image,
+    transcribe_audio,
+    run_video,
+    read_code,
+    fetch_task_files,
+)
+
+# from tools import go_back, close_popups, search_item_ctrl_f, save_screenshot
 
 load_dotenv()
 
@@ -22,6 +33,59 @@ If you are asked for a string, don't use articles, neither abbreviations (e.g. f
 If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
 Your answer should only start with "FINAL ANSWER: ", then follows with the answer. """
 
+helium_instructions = """
+You can use helium to access websites. Don't bother about the helium driver, it's already managed.
+We've already run "from helium import *".
+Then you can go to pages!
+Code:
+```py
+go_to('github.com/trending')
+```<end_code>
+
+You can directly click clickable elements by inputting the text that appears on them.
+Code:
+```py
+click("Top products")
+```<end_code>
+
+If it's a link:
+Code:
+```py
+click(Link("Top products"))
+```<end_code>
+
+If you try to interact with an element and it's not found, you'll get a LookupError.
+In general, stop your action after each button click to see what happens on your screenshot.
+Never try to log in to a page.
+
+To scroll up or down, use scroll_down or scroll_up with the number of pixels to scroll as an argument.
+Code:
+```py
+scroll_down(num_pixels=1200)  # This will scroll one viewport down
+```<end_code>
+
+When you have pop-ups with a cross icon to close, don't try to click the close icon by finding its element or targeting an 'X' element (this most often fails).
+Just use your built-in tool `close_popups` to close them:
+Code:
+```py
+close_popups()
+```<end_code>
+
+You can use .exists() to check for the existence of an element. For example:
+Code:
+```py
+if Text('Accept cookies?').exists():
+    click('I accept')
+```<end_code>
+"""
+
+add_sys_prompt = """\n\nIf a file_url is available or a url is given in the question statement, then request and use the content to answer the question. \
+If a code file, such as a .py file, is given, do not attempt to execute it but rather open it as a text file and analyze the content. \
+When a tabular file, such as csv, tsv, or xlsx, is given, read it using pandas.
+
+Make sure you provide the answer in accordance with the instructions provided in the question. Do not return the result of a tool as a final_answer.
+Do not add any additional information, explanation, unnecessary words or symbols. The answer is likely as simple as one word."""
+
 
 def initialize_tracing(enabled=True, provider="langfuse"):
     """
@@ -45,39 +109,75 @@ def get_agent():
 
     # SmolagentsInstrumentor will automatically trace agent operations
 
-    llm_qwen = InferenceClientModel(
-        model_id="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together"
-    )
-    llm_deepseek = InferenceClientModel(
-        "deepseek-ai/DeepSeek-R1",
-        provider="together",
-        max_tokens=8096,
-        # "Qwen/Qwen3-235B-A22B-FP8",
-        # provider="together",
-        # max_tokens=8096,
-    )
+    # llm_qwen = InferenceClientModel(
+    #     model_id="Qwen/Qwen2.5-Coder-32B-Instruct", provider="together"
+    # )
+    # llm_deepseek = InferenceClientModel(
+    #     "deepseek-ai/DeepSeek-R1",
+    #     provider="together",
+    #     max_tokens=8096,
+    #     # "Qwen/Qwen3-235B-A22B-FP8",
+    #     # provider="together",
+    #     # max_tokens=8096,
+    # )
 
     # Create web agent
     web_agent = ToolCallingAgent(
-        tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
-        model=llm_qwen,
+        tools=[
+            DuckDuckGoSearchTool(),
+            VisitWebpageTool(),
+            WikipediaSearchTool(),
+        ],
+        model=OpenAIServerModel(model_id="gpt-4.1", temperature=0.1),
         max_steps=3,
         name="Web_Agent",
         description="A web agent that can search the web and visit webpages.",
        verbosity_level=1,
     )
+    mm_agent = CodeAgent(
+        tools=[
+            read_image,
+            transcribe_audio,
+            read_code,
+            run_video,
+        ],
+        model=OpenAIServerModel(model_id="gpt-4.1", temperature=0.1),
+        max_steps=3,
+        name="Multimedia_Agent",
+        description="An agent that can answer questions about all types of images, videos and speech. Needs to be provided with a valid url or an image.",
+        verbosity_level=1,
+    )
 
+    # Initialize the model
+    # vlm = InferenceClientModel(model_id="Qwen/Qwen2.5-Vision-32B", provider="together")
+
+    # # Create the agent
+    # vision_agent = CodeAgent(
+    #     tools=[go_back, close_popups, search_item_ctrl_f],
+    #     model=vlm,
+    #     additional_authorized_imports=["helium", "selenium"],
+    #     step_callbacks=[save_screenshot],
+    #     max_steps=10,
+    #     planning_interval=10,
+    #     verbosity_level=1,
+    #     name="Vision_Agent",
+    #     description="A vision agent that can interact with webpages and take screenshots.",
+    # )
+    # vision_agent.prompt_templates["system_prompt"] += helium_instructions
+
+    # Import helium for the agent
     # Create manager agent
     manager_agent = CodeAgent(
-        tools=[],
-        managed_agents=[web_agent],
-        model=llm_deepseek,
+        tools=[fetch_task_files],
+        managed_agents=[web_agent, mm_agent],
+        model=OpenAIServerModel(model_id="gpt-4.1", temperature=0.1),
         max_steps=5,
         planning_interval=10,
         additional_authorized_imports=["pandas", "numpy"],
         verbosity_level=1,
-        description=MANAGER_PROMPT,
     )
+
+    manager_agent.prompt_templates["system_prompt"] += add_sys_prompt
     return manager_agent
 
 
@@ -88,11 +188,11 @@ if __name__ == "__main__":
 
     # Get agent with tracing already configured
     agent = get_agent()
-
+    agent.visualize()
     # Run agent - SmolagentsInstrumentor will automatically trace the execution
     print("Running agent with tracing enabled...")
     result = agent.run(
-        "What is the latest news about AI? Please search the web and summarize the results."
+        "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
    )
     print(f"Result: {result}")
     print(
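For orientation, here is a minimal smoke-test sketch (not part of the commit) of how the rewired hierarchy is driven. It assumes OPENAI_API_KEY is available in the environment, since all three agents are now backed by OpenAIServerModel; the question is illustrative:

```py
# Hypothetical smoke test for the new agent hierarchy (names from agent.py).
# Assumes OPENAI_API_KEY is exported; load_dotenv() in agent.py can supply it.
from agent import get_agent

manager = get_agent()
# The manager delegates web lookups to Web_Agent and image/audio/video
# questions to Multimedia_Agent; fetch_task_files is its only direct tool.
print(manager.run("What is the capital of France? Answer with one word."))
```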
app.py CHANGED
@@ -24,10 +24,24 @@ class BasicAgent:
         self.agent = get_agent()
         print("BasicAgent initialized.")
 
-    def __call__(self, question: str) -> str:
+    def __call__(self, question: str, task_id: str = None) -> str:
         print(f"Agent received question (first 50 chars): {question[:50]}...")
-        answer = self.agent.run(question)
-        print(f"Agent returning fixed answer: {answer}")
+
+        # If task_id is provided, we'll include context about possible files
+        if task_id:
+            # Add context about files to the question
+            context = f"""Task ID: {task_id}
+
+If you need files for this task, you can use the fetch_task_files tool with the task_id.
+Example: fetch_task_files(task_id="{task_id}")
+
+Question: {question}"""
+
+            answer = self.agent.run(context)
+        else:
+            answer = self.agent.run(question)
+
+        print(f"Agent returning answer: {answer}")
         return answer
 
 
@@ -93,14 +107,15 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
     results_log = []
     answers_payload = []
     print(f"Running agent on {len(questions_data)} questions...")
-    for item in questions_data[:1]:
+    for item in questions_data[3:4]:
         task_id = item.get("task_id")
         question_text = item.get("question")
         if not task_id or question_text is None:
             print(f"Skipping item with missing task_id or question: {item}")
             continue
         try:
-            submitted_answer = agent(question_text)
+            # Pass both question text and task_id to the agent
+            submitted_answer = agent(question_text, task_id)
             answers_payload.append(
                 {"task_id": task_id, "submitted_answer": submitted_answer}
             )
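A hedged sketch of the updated call path, reusing one of the sample task IDs exercised in tools.py's test block:

```py
# Illustrative only: BasicAgent now threads the task_id through to the agent,
# which can call fetch_task_files(task_id=...) to pull any attachment itself.
from app import BasicAgent

agent = BasicAgent()
answer = agent(
    "What is shown in the attached file?",
    task_id="cca530fc-4052-43b2-b130-b30968d8aa44",  # sample ID from tools.py
)
print(answer)
```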
requirements.in CHANGED
@@ -1,5 +1,8 @@
+av
 duckduckgo_search>=7.0.0,<8.0.0
 gradio[oauth]
+pytube
 requests
 smolagents[gradio,litellm,openai,telemetry,toolkit,torch,transformers,vision]
 wikipedia-api
+yt-dlp
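The three new pins back the media tooling introduced in tools.py below. A quick, illustrative check that they resolve after installation (version attributes as exposed by each package):

```py
# Illustrative sanity check for the new media dependencies:
# av decodes video/audio in tools.py; pytube and yt-dlp fetch YouTube streams.
import av
import pytube
import yt_dlp

print(av.__version__, pytube.__version__, yt_dlp.version.__version__)
```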
requirements.txt CHANGED
@@ -47,6 +47,8 @@ authlib==1.5.2
     # via
     #   arize-phoenix
     #   gradio
+av==14.3.0
+    # via -r requirements.in
 beautifulsoup4==4.13.4
     # via markdownify
 cachetools==5.5.2
@@ -353,6 +355,8 @@ python-multipart==0.0.20
     # via
     #   arize-phoenix
     #   gradio
+pytube==15.0.0
+    # via -r requirements.in
 pytz==2025.2
     # via pandas
 pyyaml==6.0.2
@@ -526,5 +530,7 @@ wsproto==1.2.0
     # via trio-websocket
 yarl==1.20.0
     # via aiohttp
+yt-dlp==2025.4.30
+    # via -r requirements.in
 zipp==3.21.0
     # via importlib-metadata
tools.py ADDED
@@ -0,0 +1,672 @@
+import requests
+import io
+import base64
+import openai
+from openai import OpenAI
+from smolagents import tool
+import os
+import pandas as pd
+import functools
+from typing import List, Optional, Dict, Any
+import sys
+
+import av
+from yt_dlp import YoutubeDL
+
+from PIL import Image
+import wikipediaapi
+import tempfile
+
+model_id = "gpt-4.1"
+
+
+@tool
+def read_image(query: str, img_url: str) -> str:
+    """
+    Use a visual question answering (VQA) model to generate a response to a query based on an image.
+
+    Args:
+        query (str): A natural language question about the image.
+        img_url (str): The URL of the image to analyze.
+
+    Returns:
+        str: A response generated by the VQA model based on the provided image and question.
+    """
+    client = OpenAI()
+    response = client.responses.create(
+        model=model_id,
+        input=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "input_text", "text": query},
+                    {
+                        "type": "input_image",
+                        "image_url": img_url,
+                    },
+                ],
+            }
+        ],
+    )
+    return response.output_text
+
+
+@tool
+def read_code(file_url: str) -> str:
+    """
+    Read the contents of a code file, such as a .py file, instead of executing it. Use this tool to analyze a code snippet.
+
+    Args:
+        file_url (str): The URL of the code file to retrieve.
+
+    Returns:
+        str: The content of the file as a string.
+    """
+    response = requests.get(file_url)
+    response.raise_for_status()
+    return response.text
+
+
+@tool
+def transcribe_audio(file_url: str, file_name: str) -> str:
+    """
+    Download and transcribe an audio file using a transcription model.
+
+    Args:
+        file_url (str): Direct URL to the audio file (e.g., .mp3, .wav).
+        file_name (str): Filename including extension, used to determine format.
+
+    Returns:
+        str: The transcribed text from the audio file.
+    """
+    # Download audio content
+    response = requests.get(file_url)
+    response.raise_for_status()
+
+    # Extract extension (fallback to mp3 if missing)
+    extension = file_name.split(".")[-1].lower() or "mp3"
+
+    # Wrap bytes in a file-like object with a valid name
+    audio_file = io.BytesIO(response.content)
+    audio_file.name = f"audio.{extension}"
+
+    # Create OpenAI client and transcribe
+    client = OpenAI()
+    transcription = client.audio.transcriptions.create(
+        model="gpt-4o-transcribe", file=audio_file
+    )
+
+    return transcription.text
+
+
+### Set of functions for YouTube video processing
+def _pytube_buffer(url: str) -> Optional[io.BytesIO]:
+    try:
+        from pytube import YouTube
+
+        yt = YouTube(url)
+        stream = (
+            yt.streams.filter(progressive=True, file_extension="mp4")
+            .order_by("resolution")
+            .desc()
+            .first()
+        )
+        if stream is None:  # no progressive stream
+            raise RuntimeError("No MP4 with audio found")
+        buf = io.BytesIO()
+        stream.stream_to_buffer(buf)  # PyTube's built-in helper
+        buf.seek(0)
+        return buf
+    except Exception as e:
+        print(f"[youtube_to_buffer] PyTube failed → {e}", file=sys.stderr)
+        return None  # trigger fallback
+
+
+def _ytdlp_buffer(url: str) -> io.BytesIO:
+    """
+    Return a BytesIO containing some MP4 video stream for `url`.
+    Works whether YouTube serves a progressive file or separate A/V.
+    """
+    ydl_opts = {
+        "quiet": True,
+        "skip_download": True,
+        "format": "bestvideo[ext=mp4]/best[ext=mp4]/best",
+    }
+    with YoutubeDL(ydl_opts) as ydl:
+        info = ydl.extract_info(url, download=False)
+        if "entries" in info:  # playlists
+            info = info["entries"][0]
+
+    if "url" in info:
+        video_urls = [info["url"]]
+
+    elif "requested_formats" in info:
+        video_urls = [
+            fmt["url"]
+            for fmt in info["requested_formats"]
+            if fmt.get("vcodec") != "none"  # keep only video
+        ]
+        if not video_urls:
+            raise RuntimeError("yt-dlp returned audio-only formats")
+
+    else:
+        raise RuntimeError("yt-dlp could not extract a stream URL")
+
+    buf = io.BytesIO()
+    for direct_url in video_urls:
+        with requests.get(direct_url, stream=True) as r:
+            r.raise_for_status()
+            for chunk in r.iter_content(chunk_size=1 << 16):
+                buf.write(chunk)
+
+    buf.seek(0)
+    return buf
+
+
+@functools.lru_cache(maxsize=8)  # tiny cache so repeat calls are fast
+def youtube_to_buffer(url: str) -> io.BytesIO:
+    """
+    Return a BytesIO containing a single progressive MP4
+    (H.264 + AAC) – the safest thing PyAV can open everywhere.
+    """
+    ydl_opts = {
+        "quiet": True,
+        "skip_download": True,
+        # progressive (has both audio+video) • mp4 • h264
+        "format": (
+            "best[ext=mp4][vcodec^=avc1][acodec!=none]"
+            "/best[ext=mp4][acodec!=none]"  # fallback: any prog-MP4
+        ),
+    }
+
+    with YoutubeDL(ydl_opts) as ydl:
+        info = ydl.extract_info(url, download=False)
+        if "entries" in info:  # playlists → first entry
+            info = info["entries"][0]
+
+    direct_url = info.get("url")
+    if not direct_url:
+        raise RuntimeError("yt-dlp could not find a progressive MP4 track")
+
+    # Stream it straight into RAM
+    buf = io.BytesIO()
+    with requests.get(direct_url, stream=True) as r:
+        r.raise_for_status()
+        for chunk in r.iter_content(chunk_size=1 << 17):  # 128 kB
+            buf.write(chunk)
+
+    buf.seek(0)
+    return buf
+
+
+def sample_frames(video_bytes: io.BytesIO, n_frames: int = 6) -> List[Image.Image]:
+    """Decode `n_frames` uniformly spaced RGB frames as PIL images."""
+    container = av.open(video_bytes, metadata_errors="ignore")
+    video = container.streams.video[0]
+    total = video.frames or 0
+
+    # If PyAV couldn't count frames, fall back to a fixed step (every 30th frame)
+    step = max(1, total // n_frames) if total else 30
+
+    frames: list[Image.Image] = []
+    for i, frame in enumerate(container.decode(video=0)):
+        if i % step == 0:
+            frames.append(frame.to_image())
+        if len(frames) >= n_frames:
+            break
+    container.close()
+    return frames
+
+
+def pil_to_data_url(img: Image.Image, quality: int = 80) -> str:
+    buf = io.BytesIO()
+    img.save(buf, format="JPEG", quality=quality, optimize=True)
+    b64 = base64.b64encode(buf.getvalue()).decode()
+    return f"data:image/jpeg;base64,{b64}"
+
+
+def save_audio_stream_to_temp_wav_file(video_bytes: io.BytesIO) -> Optional[str]:
+    """
+    Extract the audio stream from video_bytes, save it as a temporary WAV file,
+    and return the path to the file.
+    Returns None if no audio stream is found or an error occurs.
+    """
+    try:
+        video_bytes.seek(0)  # Ensure buffer is at the beginning
+        input_container = av.open(video_bytes, metadata_errors="ignore")
+
+        if not input_container.streams.audio:
+            print("No audio streams found in the video.", file=sys.stderr)
+            return None
+        input_audio_stream = input_container.streams.audio[0]
+
+        # Create a temporary file with .wav suffix.
+        # delete=False because we need to pass the path to another process (Whisper)
+        # and we will manually delete it later.
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+            temp_audio_file_path = tmp_file.name
+
+        output_container = av.open(temp_audio_file_path, mode="w", format="wav")
+
+        # For WAV, a common codec is pcm_s16le (16-bit signed PCM).
+        # Use the input stream's sample rate.
+        # Determine channel layout (e.g., 'stereo', 'mono')
+        channel_layout = "stereo"  # Default
+        if (
+            hasattr(input_audio_stream.codec_context, "layout")
+            and input_audio_stream.codec_context.layout
+        ):
+            channel_layout = input_audio_stream.codec_context.layout.name
+        elif (
+            hasattr(input_audio_stream.codec_context, "channels")
+            and input_audio_stream.codec_context.channels == 1
+        ):
+            channel_layout = "mono"
+
+        output_audio_stream = output_container.add_stream(
+            "pcm_s16le",
+            rate=input_audio_stream.codec_context.sample_rate,
+            layout=channel_layout,
+        )
+
+        for frame in input_container.decode(input_audio_stream):
+            # PyAV decodes audio into AudioFrame objects.
+            # These frames need to be encoded by the output stream's codec.
+            for packet in output_audio_stream.encode(frame):
+                output_container.mux(packet)
+
+        # Flush any remaining frames from the encoder
+        for packet in output_audio_stream.encode():
+            output_container.mux(packet)
+
+        output_container.close()
+        input_container.close()
+        return temp_audio_file_path
+
+    except Exception as e:
+        print(f"Error extracting audio to temp WAV file: {e}", file=sys.stderr)
+        # Clean up if temp file path was assigned and file exists
+        if "temp_audio_file_path" in locals() and os.path.exists(temp_audio_file_path):
+            os.remove(temp_audio_file_path)
+        return None
+
+
+@tool
+def run_video(query: str, url: str) -> str:
+    """
+    Get a YouTube video from url and return an answer to a natural-language query using the video.
+
+    Args:
+        query (str): A natural-language question whose answer is expected to be found in the audio or visual content of the video.
+        url (str): Fully qualified URL of the YouTube video to analyze.
+
+    Returns:
+        str: A response generated by the VQA model based on the provided video and question.
+    """
+    n_frames = 4
+    buff = youtube_to_buffer(url)
+    if buff is None:
+        return "Error: Could not download or buffer the video."
+
+    # 1. Sample visual frames
+    frames = sample_frames(buff, n_frames=n_frames)
+    buff.seek(0)  # Reset buffer pointer for audio extraction
+
+    # 2. Extract and transcribe audio
+    transcript = "[Audio could not be processed]"
+    audio_file_path = None
+    try:
+        audio_file_path = save_audio_stream_to_temp_wav_file(buff)
+        if audio_file_path:
+            with open(audio_file_path, "rb") as audio_data:
+                # Make sure you have the OpenAI client initialized, e.g., client = openai.OpenAI()
+                transcription_response = openai.audio.transcriptions.create(
+                    model="gpt-4o-transcribe", file=audio_data
+                )
+                transcript = transcription_response.text
+        else:
+            transcript = "[No audio stream found or error during extraction]"
+            print(
+                "No audio file path returned, skipping transcription.", file=sys.stderr
+            )
+    except Exception as e:
+        print(f"Error during audio transcription: {e}", file=sys.stderr)
+        transcript = f"[Error during audio transcription: {e}]"
+    finally:
+        if audio_file_path and os.path.exists(audio_file_path):
+            os.remove(audio_file_path)  # Clean up the temporary audio file
+
+    # 3. Prepare content for the AI model (text query, transcript, and images)
+    prompt_text = f"Original Query: {query}\n\nVideo Transcript:\n{transcript}\n\nKey Visual Frames (analyze these along with the transcript to answer the query):"
+
+    content = [{"type": "text", "text": prompt_text}]
+
+    for img in frames:
+        content.append(
+            {
+                "type": "image_url",
+                "image_url": {"url": pil_to_data_url(img)},
+            }
+        )
+
+    # 4. Send to AI model
+    try:
+        resp = openai.chat.completions.create(
+            model=model_id,
+            messages=[{"role": "user", "content": content}],
+            temperature=0.1,
+        )
+        result = resp.choices[0].message.content.strip()
+    except Exception as e:
+        print(f"Error calling OpenAI API: {e}", file=sys.stderr)
+        result = f"[Error processing with AI model: {e}]"
+
+    return result
+
+
+## Read video only, ignore audio
+# @tool
+# def run_video(query: str, url: str) -> str:
+#     """
+#     Get a YouTube video from url and return an answer to a natural-language query using the video.
+
+#     Args:
+#         query (str): A natural-language question whose answer is expected to be found in the visual content of the video.
+#         url (str): Fully qualified URL of the YouTube video to analyze.
+
+#     Returns:
+#         str: A response generated by the VQA model based on the provided video and question.
+#     """
+#     buff = youtube_to_buffer(url)
+#     n_frames = 8
+#     frames = sample_frames(buff, n_frames=n_frames)
+
+#     content = [{"type": "text", "text": query}] + [
+#         {
+#             "type": "image_url",
+#             "image_url": {"url": pil_to_data_url(img)},
+#         }
+#         for img in frames
+#     ]
+
+#     resp = openai.chat.completions.create(
+#         model="gpt-4.1-mini",
+#         messages=[{"role": "user", "content": content}],
+#         temperature=0.1,
+#     )
+#     return resp.choices[0].message.content.strip()
+
+
+# Helper functions for processing different file types
+def process_image(response, filename, content_type):
+    """Process image files - convert to base64 data URL for vision models"""
+    img_data = base64.b64encode(response.content).decode("utf-8")
+    data_url = f"data:{content_type};base64,{img_data}"
+
+    return {
+        "file_type": "image",
+        "filename": filename,
+        "content_type": content_type,
+        "data_url": data_url,
+    }
+
+
+def process_audio(response, filename, content_type):
+    """Process audio files - either return data URL or save to temp file for processing"""
+    audio_data = base64.b64encode(response.content).decode("utf-8")
+    data_url = f"data:{content_type};base64,{audio_data}"
+
+    # For compatibility with audio processing tools, save to temp file
+    audio_file = io.BytesIO(response.content)
+    extension = os.path.splitext(filename)[1].lower() or ".mp3"
+    audio_file.name = f"audio{extension}"  # Some libraries need filename
+
+    return {
+        "file_type": "audio",
+        "filename": filename,
+        "content_type": content_type,
+        "data_url": data_url,
+        "audio_buffer": audio_file,  # Include buffer for processing
+    }
+
+
+def process_video(response, filename, content_type):
+    """Process video files - save to buffer and extract frames"""
+    video_buffer = io.BytesIO(response.content)
+
+    # Option to extract frames - similar to what run_video does
+    try:
+        frames = sample_frames(video_buffer, n_frames=4)  # Reuse existing function
+        frame_urls = [pil_to_data_url(img) for img in frames]
+        frame_extraction_success = True
+    except Exception:
+        frame_urls = []
+        frame_extraction_success = False
+
+    return {
+        "file_type": "video",
+        "filename": filename,
+        "content_type": content_type,
+        "video_buffer": video_buffer,
+        "frame_urls": frame_urls,
+        "frames_extracted": frame_extraction_success,
+    }
+
+
+def process_tabular(response, filename, content_type):
+    """Process spreadsheet files using pandas"""
+    excel_buffer = io.BytesIO(response.content)
+
+    try:
+        # Determine format based on extension
+        if filename.lower().endswith(".csv"):
+            df = pd.read_csv(excel_buffer)
+        else:  # Excel formats
+            df = pd.read_excel(excel_buffer)
+
+        return {
+            "file_type": "tabular",
+            "filename": filename,
+            "content_type": content_type,
+            "data": df.to_dict(orient="records"),
+            "columns": df.columns.tolist(),
+            "shape": df.shape,
+        }
+    except Exception as e:
+        # Fallback if parsing fails
+        return {
+            "file_type": "tabular",
+            "filename": filename,
+            "content_type": content_type,
+            "error": f"Failed to parse tabular data: {e}",
+            "raw_data": base64.b64encode(response.content).decode("utf-8"),
+        }
+
+
+def process_text(response, filename, content_type):
+    """Process text files (code, plain text, etc.)"""
+    try:
+        text_content = response.text
+        return {
+            "file_type": "text",
+            "filename": filename,
+            "content_type": content_type,
+            "content": text_content,
+            "extension": os.path.splitext(filename)[1],  # Useful for syntax highlighting
+        }
+    except Exception as e:
+        return {
+            "file_type": "text",
+            "filename": filename,
+            "content_type": content_type,
+            "error": f"Failed to decode text: {e}",
+            "raw_data": base64.b64encode(response.content).decode("utf-8"),
+        }
+
+
+def process_json(response, filename, content_type):
+    """Process JSON data"""
+    try:
+        json_data = response.json()
+        return {
+            "file_type": "json",
+            "filename": filename,
+            "content_type": content_type,
+            "data": json_data,
+        }
+    except Exception:
+        # Try as text if JSON parsing fails
+        return process_text(response, filename, content_type)
+
+
+def process_pdf(response, filename, content_type):
+    """Process PDF files - return as binary with metadata"""
+    # Simple version - just return binary for now.
+    # Could be enhanced with PDF text extraction libraries.
+    pdf_data = base64.b64encode(response.content).decode("utf-8")
+
+    return {
+        "file_type": "pdf",
+        "filename": filename,
+        "content_type": content_type,
+        "data": pdf_data,
+    }
+
+
+def process_binary(response, filename, content_type):
+    """Process other binary files (fallback handler)"""
+    binary_data = base64.b64encode(response.content).decode("utf-8")
+
+    return {
+        "file_type": "binary",
+        "filename": filename,
+        "content_type": content_type,
+        "data": binary_data,
+    }
+
+
+@tool
+def fetch_task_files(task_id: str) -> Dict[str, Any]:
+    """
+    Download files associated with a specific task from the API.
+
+    Args:
+        task_id (str): The Task-ID of the task to download files for.
+
+    Returns:
+        dict: A dictionary containing file information and data in the appropriate format for the file type.
+    """
+    api_base_url: str = "https://agents-course-unit4-scoring.hf.space"
+    files_url = f"{api_base_url}/files/{task_id}"
+
+    try:
+        response = requests.get(files_url, timeout=15)
+        response.raise_for_status()
+
+        # Extract metadata
+        content_type = response.headers.get("Content-Type", "").lower()
+        filename = response.headers.get("content-disposition", "")
+        if "filename=" in filename:
+            filename = filename.split("filename=")[-1].strip('"')
+        else:
+            filename = f"{task_id}.bin"  # Default filename
+
+        print(f"Received file: {filename}, type: {content_type}")
+
+        # Route to appropriate helper based on content type or file extension
+        if "image/" in content_type or any(
+            filename.lower().endswith(ext) for ext in [".png", ".jpg", ".jpeg", ".gif"]
+        ):
+            return process_image(response, filename, content_type)
+
+        elif "audio/" in content_type or any(
+            filename.lower().endswith(ext) for ext in [".mp3", ".wav", ".ogg"]
+        ):
+            return process_audio(response, filename, content_type)
+
+        elif "video/" in content_type or any(
+            filename.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov"]
+        ):
+            return process_video(response, filename, content_type)
+
+        elif (
+            "spreadsheet" in content_type
+            or "excel" in content_type
+            or any(filename.lower().endswith(ext) for ext in [".xlsx", ".xls", ".csv"])
+        ):
+            return process_tabular(response, filename, content_type)
+
+        elif (
+            "text/" in content_type
+            or "code" in content_type
+            or any(
+                filename.lower().endswith(ext)
+                for ext in [".txt", ".py", ".js", ".html", ".md"]
+            )
+        ):
+            return process_text(response, filename, content_type)
+
+        elif "application/json" in content_type or filename.lower().endswith(".json"):
+            return process_json(response, filename, content_type)
+
+        elif "application/pdf" in content_type or filename.lower().endswith(".pdf"):
+            return process_pdf(response, filename, content_type)
+
+        else:
+            # Default fallback for binary files
+            return process_binary(response, filename, content_type)
+
+    except requests.exceptions.RequestException as e:
+        print(f"Error fetching files for task {task_id}: {e}")
+        return {"error": f"Error fetching files: {e}"}
+    except Exception as e:
+        print(f"An unexpected error occurred fetching files for task {task_id}: {e}")
+        return {"error": f"An unexpected error occurred: {e}"}
+
+
+@tool
+def search_wikipedia(query: str) -> str:
+    """
+    Get the contents of the Wikipedia page retrieved by a search query.
+
+    Args:
+        query (str): A search term to search within Wikipedia. Ideally it should be one word or a group of a few words.
+
+    Returns:
+        str: The text content of the Wikipedia page.
+    """
+    get_wiki = wikipediaapi.Wikipedia(
+        language="en",
+        user_agent="test_tokki",
+        extract_format=wikipediaapi.ExtractFormat.WIKI,
+    )
+    page_content = get_wiki.page(query)
+    text_content = page_content.text
+
+    # Truncate to roughly the first 25,000 whitespace-separated tokens
+    cutoff = 25000
+    text_content = " ".join(text_content.split(" ")[:cutoff])
+    return text_content
+
+
+if __name__ == "__main__":
+    # Simple test for fetch_task_files
+    task_ids = [
+        "cca530fc-4052-43b2-b130-b30968d8aa44",
+        "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
+        "7bd855d8-463d-4ed5-93ca-5fe35145f733",
+    ]
+    for task_id in task_ids:
+        print(
+            "=" * 20
+            + " "
+            + f"Testing fetch_task_files with task_id: {task_id}"
+            + " "
+            + "=" * 20
+        )
+
+        result = fetch_task_files(task_id)
+        print(f"File type: {result.get('file_type')}")
+        print(f"Filename: {result.get('filename')}")