Yongkang ZOU
commited on
Commit
·
8e78869
1
Parent(s):
a16bb01
update toolkit
Browse files
agent.py
CHANGED
@@ -29,7 +29,12 @@ import tempfile
|
|
29 |
import requests
|
30 |
from urllib.parse import urlparse
|
31 |
from typing import Optional
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
load_dotenv()
|
35 |
|
@@ -290,9 +295,137 @@ def analyze_csv_file(file_path: str, query: str) -> str:
|
|
290 |
|
291 |
except Exception as e:
|
292 |
return f"Error analyzing CSV file: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
tools = [multiply, add, subtract, divide, modulus,
|
294 |
wiki_search, web_search, arvix_search, read_excel_file, extract_text_from_pdf,
|
295 |
-
blip_image_caption, save_and_read_file, download_file_from_url, analyze_csv_file
|
|
|
296 |
|
297 |
# ------------------- SYSTEM PROMPT -------------------
|
298 |
system_prompt_path = "system_prompt.txt"
|
|
|
29 |
import requests
|
30 |
from urllib.parse import urlparse
|
31 |
from typing import Optional
|
32 |
+
import io
|
33 |
+
import contextlib
|
34 |
+
import base64
|
35 |
+
import subprocess
|
36 |
+
import sqlite3
|
37 |
+
import traceback
|
38 |
|
39 |
load_dotenv()
|
40 |
|
|
|
295 |
|
296 |
except Exception as e:
|
297 |
return f"Error analyzing CSV file: {str(e)}"
|
298 |
+
|
299 |
+
|
300 |
+
def execute_code_multilang(code: str, language: str = "python") -> str:
|
301 |
+
"""
|
302 |
+
Execute code in Python, Bash, SQL, C, or Java and return formatted results.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
code (str): Source code.
|
306 |
+
language (str): Language of the code. One of: 'python', 'bash', 'sql', 'c', 'java'.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
str: Human-readable execution result.
|
310 |
+
"""
|
311 |
+
language = language.lower()
|
312 |
+
exec_id = str(uuid.uuid4())
|
313 |
+
result = {
|
314 |
+
"stdout": "",
|
315 |
+
"stderr": "",
|
316 |
+
"status": "error",
|
317 |
+
"plots": [],
|
318 |
+
"dataframes": [],
|
319 |
+
}
|
320 |
+
|
321 |
+
try:
|
322 |
+
if language == "python":
|
323 |
+
plt.switch_backend("Agg")
|
324 |
+
stdout_buffer = io.StringIO()
|
325 |
+
stderr_buffer = io.StringIO()
|
326 |
+
globals_dict = {"pd": pd, "plt": plt, "Image": Image}
|
327 |
+
|
328 |
+
with contextlib.redirect_stdout(stdout_buffer), contextlib.redirect_stderr(stderr_buffer):
|
329 |
+
exec(code, globals_dict)
|
330 |
+
|
331 |
+
# Save plots
|
332 |
+
if plt.get_fignums():
|
333 |
+
for i, fig_num in enumerate(plt.get_fignums()):
|
334 |
+
fig = plt.figure(fig_num)
|
335 |
+
img_path = os.path.join(tempfile.gettempdir(), f"{exec_id}_plot_{i}.png")
|
336 |
+
fig.savefig(img_path)
|
337 |
+
with open(img_path, "rb") as f:
|
338 |
+
img_data = base64.b64encode(f.read()).decode()
|
339 |
+
result["plots"].append(img_data)
|
340 |
+
|
341 |
+
# Check for dataframes
|
342 |
+
for var_name, var_val in globals_dict.items():
|
343 |
+
if isinstance(var_val, pd.DataFrame):
|
344 |
+
result["dataframes"].append((var_name, var_val.head().to_string()))
|
345 |
+
|
346 |
+
result["stdout"] = stdout_buffer.getvalue()
|
347 |
+
result["stderr"] = stderr_buffer.getvalue()
|
348 |
+
result["status"] = "success"
|
349 |
+
|
350 |
+
elif language == "bash":
|
351 |
+
completed = subprocess.run(code, shell=True, capture_output=True, text=True, timeout=30)
|
352 |
+
result["stdout"] = completed.stdout
|
353 |
+
result["stderr"] = completed.stderr
|
354 |
+
result["status"] = "success" if completed.returncode == 0 else "error"
|
355 |
+
|
356 |
+
elif language == "sql":
|
357 |
+
conn = sqlite3.connect(":memory:")
|
358 |
+
cur = conn.cursor()
|
359 |
+
cur.execute(code)
|
360 |
+
if code.strip().lower().startswith("select"):
|
361 |
+
cols = [desc[0] for desc in cur.description]
|
362 |
+
rows = cur.fetchall()
|
363 |
+
df = pd.DataFrame(rows, columns=cols)
|
364 |
+
result["dataframes"].append(("query_result", df.head().to_string()))
|
365 |
+
conn.commit()
|
366 |
+
conn.close()
|
367 |
+
result["status"] = "success"
|
368 |
+
result["stdout"] = "SQL executed successfully."
|
369 |
+
|
370 |
+
elif language == "c":
|
371 |
+
with tempfile.TemporaryDirectory() as tmp:
|
372 |
+
src = os.path.join(tmp, "main.c")
|
373 |
+
bin_path = os.path.join(tmp, "main")
|
374 |
+
with open(src, "w") as f:
|
375 |
+
f.write(code)
|
376 |
+
comp = subprocess.run(["gcc", src, "-o", bin_path], capture_output=True, text=True)
|
377 |
+
if comp.returncode != 0:
|
378 |
+
result["stderr"] = comp.stderr
|
379 |
+
else:
|
380 |
+
run = subprocess.run([bin_path], capture_output=True, text=True, timeout=30)
|
381 |
+
result["stdout"] = run.stdout
|
382 |
+
result["stderr"] = run.stderr
|
383 |
+
result["status"] = "success" if run.returncode == 0 else "error"
|
384 |
+
|
385 |
+
elif language == "java":
|
386 |
+
with tempfile.TemporaryDirectory() as tmp:
|
387 |
+
src = os.path.join(tmp, "Main.java")
|
388 |
+
with open(src, "w") as f:
|
389 |
+
f.write(code)
|
390 |
+
comp = subprocess.run(["javac", src], capture_output=True, text=True)
|
391 |
+
if comp.returncode != 0:
|
392 |
+
result["stderr"] = comp.stderr
|
393 |
+
else:
|
394 |
+
run = subprocess.run(["java", "-cp", tmp, "Main"], capture_output=True, text=True, timeout=30)
|
395 |
+
result["stdout"] = run.stdout
|
396 |
+
result["stderr"] = run.stderr
|
397 |
+
result["status"] = "success" if run.returncode == 0 else "error"
|
398 |
+
|
399 |
+
else:
|
400 |
+
return f"❌ Unsupported language: {language}."
|
401 |
+
|
402 |
+
except Exception as e:
|
403 |
+
result["stderr"] = traceback.format_exc()
|
404 |
+
|
405 |
+
# Format response
|
406 |
+
summary = []
|
407 |
+
if result["status"] == "success":
|
408 |
+
summary.append(f"✅ Code executed successfully in **{language.upper()}**")
|
409 |
+
if result["stdout"]:
|
410 |
+
summary.append(f"\n**Output:**\n```\n{result['stdout'].strip()}\n```")
|
411 |
+
if result["stderr"]:
|
412 |
+
summary.append(f"\n**Warnings/Errors:**\n```\n{result['stderr'].strip()}\n```")
|
413 |
+
for name, df in result["dataframes"]:
|
414 |
+
summary.append(f"\n**DataFrame `{name}` Preview:**\n```\n{df}\n```")
|
415 |
+
if result["plots"]:
|
416 |
+
summary.append(f"\n📊 {len(result['plots'])} plot(s) generated (base64-encoded).")
|
417 |
+
else:
|
418 |
+
summary.append(f"❌ Execution failed for **{language.upper()}**")
|
419 |
+
if result["stderr"]:
|
420 |
+
summary.append(f"\n**Error:**\n```\n{result['stderr'].strip()}\n```")
|
421 |
+
|
422 |
+
return "\n".join(summary)
|
423 |
+
|
424 |
+
|
425 |
tools = [multiply, add, subtract, divide, modulus,
|
426 |
wiki_search, web_search, arvix_search, read_excel_file, extract_text_from_pdf,
|
427 |
+
blip_image_caption, save_and_read_file, download_file_from_url, analyze_csv_file,
|
428 |
+
execute_code_multilang]
|
429 |
|
430 |
# ------------------- SYSTEM PROMPT -------------------
|
431 |
system_prompt_path = "system_prompt.txt"
|