Yongkang ZOU
update toolkit
8e78869
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState, END
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from supabase import create_client
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
import json
import pdfplumber
import pandas as pd
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import torch
import matplotlib.pyplot as plt
import cmath
# from code_interpreter import CodeInterpreter
import uuid
import tempfile
import requests
from urllib.parse import urlparse
from typing import Optional
import io
import contextlib
import base64
import subprocess
import sqlite3
import traceback
load_dotenv()
# ------------------- TOOL DEFINITIONS -------------------
@tool
def multiply(a: int, b: int) -> int:
    """
    Multiply two numbers.
    Args:
        a (int): the first factor
        b (int): the second factor
    """
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """
    Add two numbers.
    Args:
        a (int): the first addend
        b (int): the second addend
    """
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """
    Subtract b from a.
    Args:
        a (int): the number to subtract from
        b (int): the number to subtract
    """
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """
    Divide a by b. Raise error if b is zero.
    Args:
        a (int): the dividend
        b (int): the divisor; must not be zero
    """
    if b != 0:
        return a / b
    # A zero divisor is rejected explicitly with a ValueError.
    raise ValueError("Cannot divide by zero.")
@tool
def modulus(a: int, b: int) -> int:
    """
    Get remainder of a divided by b.
    Args:
        a (int): the dividend
        b (int): the divisor
    """
    remainder = a % b
    return remainder
@tool
def square_root(a: float) -> float | complex:
    """
    Get the square root of a number.
    The result is a complex number when the input is negative.
    Args:
        a (float): the number to get the square root of
    """
    # Anything that is not >= 0 goes through cmath, which yields a complex root.
    if not a >= 0:
        return cmath.sqrt(a)
    return a**0.5
@tool
def power(a: float, b: float) -> float:
    """
    Raise a number to a power.
    Args:
        a (float): the base
        b (float): the exponent
    """
    raised = a**b
    return raised
@tool
def wiki_search(query: str) -> str:
    """
    Search Wikipedia for a query (max 2 results).
    Args:
        query (str): the search query
    """
    loader = WikipediaLoader(query=query, load_max_docs=2)
    pages = loader.load()
    # Page bodies are concatenated with a blank line between them.
    return "\n\n".join(page.page_content for page in pages)
@tool
def web_search(query: str) -> str:
    """
    Search the web using Tavily (max 3 results).
    Args:
        query (str): the search query
    """
    results = TavilySearchResults(max_results=3).invoke(query)
    # A string result would otherwise be iterated character by character and
    # filtered out entirely, so the model would receive an empty answer.
    # NOTE(review): the Tavily tool appears to return the error text as a
    # string on failure -- confirm against the installed version.
    if isinstance(results, str):
        return results
    texts = [
        # Trailing `or ""` guards against a hit whose fields are both None,
        # which would make str.join raise.
        hit.get("content", "") or hit.get("text", "") or ""
        for hit in results
        if isinstance(hit, dict)
    ]
    return "\n\n".join(texts)
@tool
def arvix_search(query: str) -> str:
    """
    Search Arxiv for academic papers (max 3 results, truncated to 1000 characters each).
    Args:
        query (str): the search query
    """
    loader = ArxivLoader(query=query, load_max_docs=3)
    papers = loader.load()
    # Only the first 1000 characters of each paper are kept.
    excerpts = [paper.page_content[:1000] for paper in papers]
    return "\n\n".join(excerpts)
@tool
def read_excel_file(path: str) -> str:
    """
    Read an Excel file and return the first 5 rows of each sheet as text.
    Args:
        path (str): local path of the Excel file
    """
    try:
        sections = []
        # The context manager closes the workbook handle even if parsing fails.
        with pd.ExcelFile(path) as xls:
            for sheet in xls.sheet_names:
                preview = xls.parse(sheet).head(5).to_string(index=False)
                sections.append(f"Sheet: {sheet}\n{preview}")
        return "\n\n".join(sections).strip()
    except Exception as e:
        # Errors are reported as text so the calling agent can react to them.
        return f"Error reading Excel file: {str(e)}"
@tool
def extract_text_from_pdf(path: str) -> str:
    """
    Extract text from the first 5 pages of a PDF file given its local path.
    Args:
        path (str): local path of the PDF file
    """
    try:
        page_texts = []
        with pdfplumber.open(path) as pdf:
            # Only the first 5 pages are read, to keep the output from getting too large.
            for page in pdf.pages[:5]:
                page_text = page.extract_text()
                if page_text:
                    page_texts.append(page_text)
        # Strip before testing so whitespace-only extractions also get the
        # fallback message instead of an empty string.
        text = "\n\n".join(page_texts).strip()
        return text or "No text extracted from PDF."
    except Exception as e:
        return f"Error reading PDF: {str(e)}"
# Initialise the BLIP captioning model once at import time (the first load may be a bit slow).
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
@tool
def blip_image_caption(image_path: str) -> str:
    """
    Generate a description for an image using BLIP.
    Args:
        image_path (str): local path of the image file
    """
    try:
        # Open inside a context manager so the file handle is released as
        # soon as the RGB copy has been made.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(image, return_tensors="pt")
        # Inference only: no gradients are needed.
        with torch.no_grad():
            out = model.generate(**inputs)
        return processor.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        return f"Failed to process image with BLIP: {str(e)}"
# @tool
# def execute_code_multilang(code: str, language: str = "python") -> str:
# """Execute code in multiple languages (Python, Bash, SQL, C, Java) and return results.
# Args:
# code (str): The source code to execute.
# language (str): The language of the code. Supported: "python", "bash", "sql", "c", "java".
# Returns:
# A string summarizing the execution results (stdout, stderr, errors, plots, dataframes if any).
# """
# supported_languages = ["python", "bash", "sql", "c", "java"]
# language = language.lower()
# interpreter_instance = CodeInterpreter()
# if language not in supported_languages:
# return f"❌ Unsupported language: {language}. Supported languages are: {', '.join(supported_languages)}"
# result = interpreter_instance.execute_code(code, language=language)
# response = []
# if result["status"] == "success":
# response.append(f"✅ Code executed successfully in **{language.upper()}**")
# if result.get("stdout"):
# response.append(
# "\n**Standard Output:**\n```\n" + result["stdout"].strip() + "\n```"
# )
# if result.get("stderr"):
# response.append(
# "\n**Standard Error (if any):**\n```\n"
# + result["stderr"].strip()
# + "\n```"
# )
# if result.get("result") is not None:
# response.append(
# "\n**Execution Result:**\n```\n"
# + str(result["result"]).strip()
# + "\n```"
# )
# if result.get("dataframes"):
# for df_info in result["dataframes"]:
# response.append(
# f"\n**DataFrame `{df_info['name']}` (Shape: {df_info['shape']})**"
# )
# df_preview = pd.DataFrame(df_info["head"])
# response.append("First 5 rows:\n```\n" + str(df_preview) + "\n```")
# if result.get("plots"):
# response.append(
# f"\n**Generated {len(result['plots'])} plot(s)** (Image data returned separately)"
# )
# else:
# response.append(f"❌ Code execution failed in **{language.upper()}**")
# if result.get("stderr"):
# response.append(
# "\n**Error Log:**\n```\n" + result["stderr"].strip() + "\n```"
# )
# return "\n".join(response)
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file in the temporary directory and return the path.
    Args:
        content (str): the content to save to the file
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    temp_dir = tempfile.gettempdir()
    # Keep only the final path component: the name is caller-supplied, and an
    # absolute path or "../" segment would otherwise write outside temp_dir.
    safe_name = os.path.basename(filename) if filename else ""
    if safe_name:
        filepath = os.path.join(temp_dir, safe_name)
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(content)
    else:
        # Write through the temporary file's own handle so it is closed
        # properly instead of being leaked and the path re-opened.
        with tempfile.NamedTemporaryFile(
            "w", delete=False, dir=temp_dir, encoding="utf-8"
        ) as f:
            f.write(content)
            filepath = f.name
    return f"File saved to {filepath}. You can read this file to process its contents."
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.
    Args:
        url (str): the URL of the file to download.
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    try:
        # Keep only the final path component of a caller-supplied name so the
        # download cannot be written outside the temporary directory.
        name = os.path.basename(filename) if filename else ""
        if not name:
            # Fall back to the last segment of the URL path, then to a random name.
            name = os.path.basename(urlparse(url).path)
        if not name:
            name = f"downloaded_{uuid.uuid4().hex[:8]}"
        filepath = os.path.join(tempfile.gettempdir(), name)
        # The timeout keeps the agent from hanging on an unresponsive server;
        # the context manager releases the streamed connection when done.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        return f"File downloaded to {filepath}. You can read this file to process its contents."
    except Exception as e:
        return f"Error downloading file: {str(e)}"
@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Load a CSV file with pandas and return its size, column names and summary statistics.
    Args:
        file_path (str): the path to the CSV file.
        query (str): Question about the data (currently not used by the analysis).
    """
    try:
        df = pd.read_csv(file_path)
        # Column labels are not guaranteed to be strings, so convert them
        # before joining; joining the raw labels raises TypeError otherwise.
        column_names = ", ".join(map(str, df.columns))
        return (
            f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
            f"Columns: {column_names}\n\n"
            "Summary statistics:\n"
            f"{df.describe()}"
        )
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
# Upper bound, in seconds, for every compile or run subprocess started below.
_EXEC_TIMEOUT_SECONDS = 30


def _record_process(result: dict, completed: subprocess.CompletedProcess) -> None:
    """Copy a finished subprocess's streams and exit status into *result*."""
    result["stdout"] = completed.stdout
    result["stderr"] = completed.stderr
    result["status"] = "success" if completed.returncode == 0 else "error"


def _run_python(code: str, result: dict) -> None:
    """Exec *code* in-process, collecting output, figures and DataFrames."""
    plt.switch_backend("Agg")
    stdout_buffer = io.StringIO()
    stderr_buffer = io.StringIO()
    globals_dict = {"pd": pd, "plt": plt, "Image": Image}
    try:
        with contextlib.redirect_stdout(stdout_buffer), contextlib.redirect_stderr(stderr_buffer):
            # SECURITY: exec() runs the submitted code inside this process with
            # this process's privileges; only use in a sandboxed environment.
            exec(code, globals_dict)
        # Render each open figure to an in-memory PNG so no files are left
        # behind in the temp directory.
        for fig_num in plt.get_fignums():
            png_buffer = io.BytesIO()
            plt.figure(fig_num).savefig(png_buffer, format="png")
            result["plots"].append(base64.b64encode(png_buffer.getvalue()).decode())
        # Preview any DataFrames the snippet left in its global namespace.
        for var_name, var_val in globals_dict.items():
            if isinstance(var_val, pd.DataFrame):
                result["dataframes"].append((var_name, var_val.head().to_string()))
    finally:
        # Close figures so they are not reported again by a later call, and
        # keep whatever was printed even when the snippet raised.
        plt.close("all")
        result["stdout"] = stdout_buffer.getvalue()
        result["stderr"] = stderr_buffer.getvalue()
    result["status"] = "success"


def _run_bash(code: str, result: dict) -> None:
    """Run *code* through the system shell."""
    # SECURITY: shell=True executes the submitted text verbatim.
    completed = subprocess.run(
        code, shell=True, capture_output=True, text=True, timeout=_EXEC_TIMEOUT_SECONDS
    )
    _record_process(result, completed)


def _run_sql(code: str, result: dict) -> None:
    """Execute a single SQL statement against a fresh in-memory SQLite database."""
    # closing() guarantees the connection is released even when execute() raises.
    with contextlib.closing(sqlite3.connect(":memory:")) as conn:
        cur = conn.cursor()
        cur.execute(code)
        # cursor.description is set whenever the statement produced columns,
        # which also covers WITH/PRAGMA queries, not only plain SELECTs.
        if cur.description is not None:
            cols = [desc[0] for desc in cur.description]
            df = pd.DataFrame(cur.fetchall(), columns=cols)
            result["dataframes"].append(("query_result", df.head().to_string()))
        conn.commit()
    result["status"] = "success"
    result["stdout"] = "SQL executed successfully."


def _run_c(code: str, result: dict) -> None:
    """Compile *code* with gcc in a scratch directory and run the binary."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "main.c")
        bin_path = os.path.join(tmp, "main")
        with open(src, "w") as f:
            f.write(code)
        comp = subprocess.run(
            ["gcc", src, "-o", bin_path],
            capture_output=True, text=True, timeout=_EXEC_TIMEOUT_SECONDS,
        )
        if comp.returncode != 0:
            # Compilation failed: report the compiler output, status stays "error".
            result["stderr"] = comp.stderr
            return
        run = subprocess.run(
            [bin_path], capture_output=True, text=True, timeout=_EXEC_TIMEOUT_SECONDS
        )
        _record_process(result, run)


def _run_java(code: str, result: dict) -> None:
    """Compile *code* as Main.java with javac and run class ``Main``."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "Main.java")
        with open(src, "w") as f:
            f.write(code)
        comp = subprocess.run(
            ["javac", src], capture_output=True, text=True, timeout=_EXEC_TIMEOUT_SECONDS
        )
        if comp.returncode != 0:
            # Compilation failed: report the compiler output, status stays "error".
            result["stderr"] = comp.stderr
            return
        run = subprocess.run(
            ["java", "-cp", tmp, "Main"],
            capture_output=True, text=True, timeout=_EXEC_TIMEOUT_SECONDS,
        )
        _record_process(result, run)


# Maps each supported (lower-cased) language name to its runner.
_CODE_RUNNERS = {
    "python": _run_python,
    "bash": _run_bash,
    "sql": _run_sql,
    "c": _run_c,
    "java": _run_java,
}


def _format_exec_summary(language: str, result: dict) -> str:
    """Render the collected execution *result* as a Markdown-style report."""
    summary = []
    if result["status"] == "success":
        summary.append(f"✅ Code executed successfully in **{language.upper()}**")
        if result["stdout"]:
            summary.append(f"\n**Output:**\n```\n{result['stdout'].strip()}\n```")
        if result["stderr"]:
            summary.append(f"\n**Warnings/Errors:**\n```\n{result['stderr'].strip()}\n```")
        for name, df in result["dataframes"]:
            summary.append(f"\n**DataFrame `{name}` Preview:**\n```\n{df}\n```")
        if result["plots"]:
            summary.append(f"\n📊 {len(result['plots'])} plot(s) generated (base64-encoded).")
    else:
        summary.append(f"❌ Execution failed for **{language.upper()}**")
        # Output produced before the failure is often what explains it.
        if result["stdout"]:
            summary.append(f"\n**Output:**\n```\n{result['stdout'].strip()}\n```")
        if result["stderr"]:
            summary.append(f"\n**Error:**\n```\n{result['stderr'].strip()}\n```")
    return "\n".join(summary)


def execute_code_multilang(code: str, language: str = "python") -> str:
    """
    Execute code in Python, Bash, SQL, C, or Java and return formatted results.
    Args:
        code (str): Source code.
        language (str): Language of the code. One of: 'python', 'bash', 'sql', 'c', 'java'.
    Returns:
        str: Human-readable execution result.
    """
    language = language.strip().lower()
    runner = _CODE_RUNNERS.get(language)
    if runner is None:
        return f"❌ Unsupported language: {language}."
    result = {
        "stdout": "",
        "stderr": "",
        "status": "error",
        "plots": [],
        "dataframes": [],
    }
    try:
        runner(code, result)
    except Exception:
        # Any failure (bad code, missing compiler, timeout, ...) is reported
        # to the caller as text rather than raised.
        result["status"] = "error"
        result["stderr"] += traceback.format_exc()
    return _format_exec_summary(language, result)
# Every callable listed here is bound to the LLM and served by the graph's
# ToolNode. square_root and power are defined as tools above and are now
# registered as well; previously they were never exposed to the model.
tools = [
    multiply, add, subtract, divide, modulus, square_root, power,
    wiki_search, web_search, arvix_search, read_excel_file, extract_text_from_pdf,
    blip_image_caption, save_and_read_file, download_file_from_url, analyze_csv_file,
    execute_code_multilang,
]
# ------------------- SYSTEM PROMPT -------------------
system_prompt_path = "system_prompt.txt"
# Start from the built-in prompt; it is replaced below when the prompt file exists.
system_prompt = (
    "You are an intelligent AI agent who can solve math, science, factual, and research-based problems. "
    "You can use tools like Wikipedia, Web search, or Arxiv when needed. Always give precise and helpful answers."
)
if os.path.exists(system_prompt_path):
    with open(system_prompt_path, "r", encoding="utf-8") as prompt_file:
        system_prompt = prompt_file.read()
sys_msg = SystemMessage(content=system_prompt)
# ------------------- GRAPH CONSTRUCTION -------------------
def build_graph(provider: str = "groq"):
    """
    Build and compile the agent graph: retriever -> assistant <-> tools.

    The retriever node looks the question up in the Supabase "QA_db" vector
    table. When it finds a document, that document's content is returned as
    the answer and the run ends; otherwise the tool-calling assistant takes over.

    Args:
        provider (str): chat-model backend. One of "google", "groq",
            "huggingface" or "openai".
    Returns:
        The compiled graph (the result of ``StateGraph.compile()``).
    Raises:
        ValueError: if ``provider`` is not recognised, or a required API key
            or Supabase credential is missing from the environment.
    """
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        groq_key = os.getenv("GROQ_API_KEY")
        if not groq_key:
            raise ValueError("GROQ_API_KEY is not set.")
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=groq_key)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # HuggingFaceEndpoint takes the inference URL as `endpoint_url`.
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            )
        )
    elif provider == "openai":
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            raise ValueError("OPENAI_API_KEY is not set.")
        llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=openai_key)
    else:
        raise ValueError("Invalid provider")

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        # The system prompt is prepended to the model input so the model
        # actually sees it; only the model's reply is added to the state.
        response = llm_with_tools.invoke([sys_msg] + state["messages"])
        return {"messages": [response]}

    SUPABASE_URL = os.getenv("SUPABASE_URL")
    SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
    # Fail early with a clear message, like the API-key checks above.
    if not SUPABASE_URL or not SUPABASE_KEY:
        raise ValueError("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set.")
    supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    vectorstore = SupabaseVectorStore(
        client=supabase,
        embedding=embedding_model,
        table_name="QA_db"
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 1})

    # Replacement for similarity_search_by_vector_with_relevance_scores that
    # calls the "match_documents" RPC on the Supabase client directly.
    def patched_fn(embedding, k=4, filter=None, **kwargs):
        # `filter` and any extra keyword arguments are accepted for signature
        # compatibility but are not forwarded to the RPC.
        response = supabase.rpc(
            "match_documents",
            {
                "query_embedding": embedding,
                "match_count": k
            }
        ).execute()
        documents = []
        for r in response.data or []:
            metadata = r["metadata"]
            # Metadata may come back as a JSON string; fall back to an empty
            # dict when it cannot be parsed.
            if isinstance(metadata, str):
                try:
                    metadata = json.loads(metadata)
                except Exception:
                    metadata = {}
            doc = Document(
                page_content=r["content"],
                metadata=metadata
            )
            documents.append((doc, r["similarity"]))
        return documents

    # Override the vector store's method with the patched version.
    vectorstore.similarity_search_by_vector_with_relevance_scores = patched_fn

    def qa_retriever_node(state: MessagesState):
        user_question = state["messages"][-1].content
        docs = retriever.invoke(user_question)
        if docs:
            # Return only the new message and only schema keys; the previous
            # "__condition__" key is not part of MessagesState.
            return {"messages": [AIMessage(content=docs[0].page_content)]}
        return {"messages": []}

    def route_after_retriever(state: MessagesState):
        # The retriever appends an AIMessage only when it found an answer.
        if isinstance(state["messages"][-1], AIMessage):
            return END
        return "assistant"

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", qa_retriever_node)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    # add_conditional_edges expects a routing callable, plus an optional map
    # from its return values to node names, rather than a dict of lambdas.
    builder.add_conditional_edges(
        "retriever",
        route_after_retriever,
        {"assistant": "assistant", END: END},
    )
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
# ------------------- LOCAL TEST -------------------
if __name__ == "__main__":
    # Local smoke test: run one question through the OpenAI-backed graph and
    # print every message in the resulting state.
    question = (
        "When was a picture of St. Thomas Aquinas first added to the "
        "Wikipedia page on the Principle of double effect?"
    )
    graph = build_graph(provider="openai")
    initial_state = {"messages": [HumanMessage(content=question)]}
    messages = graph.invoke(initial_state)
    print("=== AI Agent Response ===")
    for message in messages["messages"]:
        message.pretty_print()