"""Tool definitions and LangGraph wiring for a tool-using question-answering agent."""

import base64
import cmath
import contextlib
import io
import json
import os
import sqlite3
import subprocess
import tempfile
import traceback
import uuid
from typing import Optional
from urllib.parse import urlparse

import matplotlib.pyplot as plt
import pandas as pd
import pdfplumber
import requests
import torch
from dotenv import load_dotenv
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import (
    ChatHuggingFace,
    HuggingFaceEmbeddings,
    HuggingFaceEndpoint,
)
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from PIL import Image
from supabase import create_client
from transformers import BlipForConditionalGeneration, BlipProcessor

# Load API keys and other settings from a local .env file, if present.
load_dotenv()


# ------------------- TOOL DEFINITIONS -------------------

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide a by b. Raise error if b is zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Get remainder of a divided by b."""
    return a % b


@tool
def square_root(a: float) -> float | complex:
    """
    Get the square root of a number.

    Args:
        a (float): the number to get the square root of
    """
    if a >= 0:
        return a**0.5
    # Negative input: return the principal complex root instead of failing.
    return cmath.sqrt(a)


@tool
def power(a: float, b: float) -> float:
    """
    Get the power of two numbers.

    Args:
        a (float): the first number
        b (float): the second number
    """
    return a**b


@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query (max 2 results)."""
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n".join(doc.page_content for doc in docs)


@tool
def web_search(query: str) -> str:
    """Search the web using Tavily (max 3 results)."""
    results = TavilySearchResults(max_results=3).invoke(query)
    # Keep only dict results; prefer their "content" field and fall back to "text".
    texts = [
        doc.get("content", "") or doc.get("text", "")
        for doc in results
        if isinstance(doc, dict)
    ]
    return "\n\n".join(texts)


@tool
def arvix_search(query: str) -> str:
    """Search Arxiv for academic papers (max 3 results, truncated to 1000 characters each)."""
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n".join(doc.page_content[:1000] for doc in docs)


@tool
def read_excel_file(path: str) -> str:
    """Read an Excel file and return the first few rows of each sheet as text."""
    try:
        # The context manager closes the workbook handle even if parsing fails.
        with pd.ExcelFile(path) as xls:
            sections = [
                f"Sheet: {sheet}\n{xls.parse(sheet).head(5).to_string(index=False)}"
                for sheet in xls.sheet_names
            ]
        return "\n\n".join(sections).strip()
    except Exception as e:
        # Tool boundary: report the failure as text so the agent can react to it.
        return f"Error reading Excel file: {str(e)}"


@tool
def extract_text_from_pdf(path: str) -> str:
    """Extract text from a PDF file given its local path."""
    try:
        with pdfplumber.open(path) as pdf:
            # Only the first 5 pages are read, to keep the output small.
            page_texts = [page.extract_text() for page in pdf.pages[:5]]
        text = "\n\n".join(t for t in page_texts if t).strip()
        return text or "No text extracted from PDF."
    except Exception as e:
        return f"Error reading PDF: {str(e)}"


# Initialise the BLIP captioning model at import time (the first load may be slow).
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")


@tool
def blip_image_caption(image_path: str) -> str:
    """Generate a description for an image using BLIP."""
    try:
        # Open inside a context manager so the file handle is released;
        # convert() returns an independent, fully loaded RGB copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(image, return_tensors="pt")
        with torch.no_grad():
            out = model.generate(**inputs)
        return processor.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        return f"Failed to process image with BLIP: {str(e)}"


@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file and return the path.

    Args:
        content (str): the content to save to the file
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    temp_dir = tempfile.gettempdir()
    # Reduce the requested name to its final path component so that values such
    # as "../x" or an absolute path cannot write outside the temp directory.
    safe_name = os.path.basename(filename) if filename else ""
    if safe_name:
        filepath = os.path.join(temp_dir, safe_name)
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(content)
    else:
        # Write through the temporary file's own handle so it is closed on exit
        # instead of leaking an open descriptor.
        with tempfile.NamedTemporaryFile(
            "w", delete=False, dir=temp_dir, encoding="utf-8"
        ) as f:
            f.write(content)
            filepath = f.name
    return f"File saved to {filepath}. You can read this file to process its contents."
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.

    Args:
        url (str): the URL of the file to download.
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    try:
        # Use only the final path component so neither the caller-supplied name
        # nor the URL path can place the file outside the temp directory.
        safe_name = os.path.basename(filename) if filename else ""
        if not safe_name:
            safe_name = os.path.basename(urlparse(url).path)
        if not safe_name:
            safe_name = f"downloaded_{uuid.uuid4().hex[:8]}"

        filepath = os.path.join(tempfile.gettempdir(), safe_name)

        # A timeout keeps a stalled server from hanging the agent; the context
        # manager releases the streamed connection when the download ends.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        return f"File downloaded to {filepath}. You can read this file to process its contents."
    except Exception as e:
        # Tool boundary: report the failure as text so the agent can react to it.
        return f"Error downloading file: {str(e)}"


@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Analyze a CSV file using pandas and answer a question about it.

    Args:
        file_path (str): the path to the CSV file.
        query (str): Question about the data
    """
    # NOTE: `query` is accepted for the tool schema but is not used; the same
    # summary is returned whatever the question is.
    try:
        df = pd.read_csv(file_path)

        # Column labels are not guaranteed to be strings, so convert before joining.
        column_names = ", ".join(str(col) for col in df.columns)
        report = [
            f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n",
            f"Columns: {column_names}\n\n",
            "Summary statistics:\n",
            str(df.describe()),
        ]
        return "".join(report)
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"


# Upper bound, in seconds, for every compile or run subprocess started below.
_EXEC_TIMEOUT_SECONDS = 30


def _record_process(completed, result: dict) -> None:
    """Copy a finished subprocess's output and exit status into `result`."""
    result["stdout"] = completed.stdout
    result["stderr"] = completed.stderr
    result["status"] = "success" if completed.returncode == 0 else "error"


def _run_python(code: str, result: dict) -> None:
    """Run Python code in-process, capturing output, figures and DataFrames."""
    plt.switch_backend("Agg")  # non-interactive backend: render without a display
    stdout_buffer = io.StringIO()
    stderr_buffer = io.StringIO()
    globals_dict = {"pd": pd, "plt": plt, "Image": Image}
    try:
        with contextlib.redirect_stdout(stdout_buffer), contextlib.redirect_stderr(stderr_buffer):
            # SECURITY: exec() runs the supplied code with this process's full
            # privileges. Only use this with trusted input or inside a sandbox.
            exec(code, globals_dict)

        # Render each open figure to an in-memory PNG and base64-encode it.
        for fig_num in plt.get_fignums():
            png_buffer = io.BytesIO()
            plt.figure(fig_num).savefig(png_buffer, format="png")
            result["plots"].append(base64.b64encode(png_buffer.getvalue()).decode())

        # Report a preview of every DataFrame the code left in its namespace.
        for var_name, var_val in globals_dict.items():
            if isinstance(var_val, pd.DataFrame):
                result["dataframes"].append((var_name, var_val.head().to_string()))

        result["status"] = "success"
    finally:
        # Close figures so they are not reported again by the next call, and
        # keep whatever was printed even when the code raised.
        plt.close("all")
        result["stdout"] = stdout_buffer.getvalue()
        result["stderr"] = stderr_buffer.getvalue()


def _run_bash(code: str, result: dict) -> None:
    """Run a script with bash and record its output."""
    # Invoke bash explicitly: shell=True would use /bin/sh, which is not bash
    # on many systems and rejects bash-only syntax.
    completed = subprocess.run(
        ["bash", "-c", code],
        capture_output=True,
        text=True,
        timeout=_EXEC_TIMEOUT_SECONDS,
    )
    _record_process(completed, result)


def _run_sql(code: str, result: dict) -> None:
    """Run one SQL statement against a fresh in-memory SQLite database."""
    conn = sqlite3.connect(":memory:")
    try:
        cur = conn.cursor()
        cur.execute(code)
        # `description` is set only for statements that produce rows, which
        # also covers WITH/PRAGMA/VALUES queries, not just SELECT.
        if cur.description is not None:
            cols = [desc[0] for desc in cur.description]
            df = pd.DataFrame(cur.fetchall(), columns=cols)
            result["dataframes"].append(("query_result", df.head().to_string()))
        conn.commit()
    finally:
        conn.close()
    result["status"] = "success"
    result["stdout"] = "SQL executed successfully."


def _run_c(code: str, result: dict) -> None:
    """Compile C source with gcc and run the resulting binary."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "main.c")
        bin_path = os.path.join(tmp, "main")
        with open(src, "w") as f:
            f.write(code)
        comp = subprocess.run(
            ["gcc", src, "-o", bin_path],
            capture_output=True,
            text=True,
            timeout=_EXEC_TIMEOUT_SECONDS,
        )
        if comp.returncode != 0:
            # Compilation failed: status stays "error"; surface compiler output.
            result["stderr"] = comp.stderr
            return
        run = subprocess.run(
            [bin_path], capture_output=True, text=True, timeout=_EXEC_TIMEOUT_SECONDS
        )
        _record_process(run, result)


def _run_java(code: str, result: dict) -> None:
    """Compile Java source with javac and run its `Main` class."""
    with tempfile.TemporaryDirectory() as tmp:
        # The file is named Main.java and the class launched is `Main`.
        src = os.path.join(tmp, "Main.java")
        with open(src, "w") as f:
            f.write(code)
        comp = subprocess.run(
            ["javac", src],
            capture_output=True,
            text=True,
            timeout=_EXEC_TIMEOUT_SECONDS,
        )
        if comp.returncode != 0:
            result["stderr"] = comp.stderr
            return
        run = subprocess.run(
            ["java", "-cp", tmp, "Main"],
            capture_output=True,
            text=True,
            timeout=_EXEC_TIMEOUT_SECONDS,
        )
        _record_process(run, result)


def _format_execution_result(language: str, result: dict) -> str:
    """Render an execution result dict as a Markdown summary."""
    summary = []
    if result["status"] == "success":
        summary.append(f"✅ Code executed successfully in **{language.upper()}**")
        if result["stdout"]:
            summary.append(f"\n**Output:**\n```\n{result['stdout'].strip()}\n```")
        if result["stderr"]:
            summary.append(f"\n**Warnings/Errors:**\n```\n{result['stderr'].strip()}\n```")
        for name, df in result["dataframes"]:
            summary.append(f"\n**DataFrame `{name}` Preview:**\n```\n{df}\n```")
        if result["plots"]:
            summary.append(f"\n📊 {len(result['plots'])} plot(s) generated (base64-encoded).")
    else:
        summary.append(f"❌ Execution failed for **{language.upper()}**")
        if result["stderr"]:
            summary.append(f"\n**Error:**\n```\n{result['stderr'].strip()}\n```")
    return "\n".join(summary)


def execute_code_multilang(code: str, language: str = "python") -> str:
    """
    Execute code in Python, Bash, SQL, C, or Java and return formatted results.

    Args:
        code (str): Source code.
        language (str): Language of the code. One of: 'python', 'bash', 'sql', 'c', 'java'.

    Returns:
        str: Human-readable execution result.
    """
    language = language.strip().lower()
    runners = {
        "python": _run_python,
        "bash": _run_bash,
        "sql": _run_sql,
        "c": _run_c,
        "java": _run_java,
    }
    runner = runners.get(language)
    if runner is None:
        return f"❌ Unsupported language: {language}."

    result = {
        "stdout": "",
        "stderr": "",
        "status": "error",
        "plots": [],
        "dataframes": [],
    }
    try:
        runner(code, result)
    except Exception:
        # Includes subprocess timeouts and missing compilers. Keep any stderr
        # already captured and append the traceback for the report.
        result["status"] = "error"
        result["stderr"] += traceback.format_exc()

    return _format_execution_result(language, result)


# Every tool the agent may call. `execute_code_multilang` is a plain function.
# square_root and power are registered here so the model can actually use them.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    square_root,
    power,
    wiki_search,
    web_search,
    arvix_search,
    read_excel_file,
    extract_text_from_pdf,
    blip_image_caption,
    save_and_read_file,
    download_file_from_url,
    analyze_csv_file,
    execute_code_multilang,
]

# ------------------- SYSTEM PROMPT -------------------

# Prefer a prompt file next to the working directory; fall back to a built-in one.
system_prompt_path = "system_prompt.txt"
if os.path.exists(system_prompt_path):
    with open(system_prompt_path, "r", encoding="utf-8") as f:
        system_prompt = f.read()
else:
    system_prompt = (
        "You are an intelligent AI agent who can solve math, science, factual, and research-based problems. "
        "You can use tools like Wikipedia, Web search, or Arxiv when needed. Always give precise and helpful answers."
    )

sys_msg = SystemMessage(content=system_prompt)


# ------------------- GRAPH CONSTRUCTION -------------------

def _build_llm(provider: str):
    """Create the chat model for `provider`; raise ValueError if it is unknown or unconfigured."""
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    if provider == "groq":
        groq_key = os.getenv("GROQ_API_KEY")
        if not groq_key:
            raise ValueError("GROQ_API_KEY is not set.")
        return ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=groq_key)
    if provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # `endpoint_url` is the field HuggingFaceEndpoint defines for a
                # full inference URL; `url` is not one of its parameters.
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            )
        )
    if provider == "openai":
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            raise ValueError("OPENAI_API_KEY is not set.")
        return ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=openai_key)
    raise ValueError("Invalid provider")


def build_graph(provider: str = "groq"):
    """
    Build and compile the agent graph.

    The graph first asks a Supabase-backed retriever for a stored answer. If
    one is found the run ends with it; otherwise control goes to the
    tool-calling assistant, which loops through the tool node until the model
    stops requesting tools.

    Args:
        provider (str): LLM backend: "google", "groq", "huggingface" or "openai".

    Returns:
        The compiled LangGraph graph.

    Raises:
        ValueError: if the provider is unknown or a required environment
            variable (API key or Supabase credentials) is missing.
    """
    llm_with_tools = _build_llm(provider).bind_tools(tools)

    def assistant(state: MessagesState):
        # The system prompt is sent to the model on every call, but only the
        # model's reply is written back to the conversation state.
        response = llm_with_tools.invoke([sys_msg] + state["messages"])
        return {"messages": [response]}

    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_SERVICE_KEY")
    if not supabase_url or not supabase_key:
        raise ValueError("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set.")
    supabase = create_client(supabase_url, supabase_key)

    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    vectorstore = SupabaseVectorStore(
        client=supabase,
        embedding=embedding_model,
        table_name="QA_db",
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 1})

    def patched_fn(embedding, k=4, filter=None, **kwargs):
        """Similarity search that calls the `match_documents` RPC directly.

        `filter` and extra keyword arguments are accepted for signature
        compatibility but are not forwarded to the RPC.
        """
        response = supabase.rpc(
            "match_documents",
            {"query_embedding": embedding, "match_count": k},
        ).execute()

        documents = []
        for row in response.data or []:
            metadata = row.get("metadata") or {}
            if isinstance(metadata, str):
                # Metadata may come back as a JSON string; decode it.
                try:
                    metadata = json.loads(metadata)
                except ValueError:
                    metadata = {}
            if not isinstance(metadata, dict):
                metadata = {}
            doc = Document(page_content=row["content"], metadata=metadata)
            documents.append((doc, row["similarity"]))
        return documents

    # Override the vector store's search method on this instance so the
    # retriever goes through the RPC call above.
    vectorstore.similarity_search_by_vector_with_relevance_scores = patched_fn

    def qa_retriever_node(state: MessagesState):
        user_question = state["messages"][-1].content
        docs = retriever.invoke(user_question)
        if docs:
            # Append the stored answer; the router below then ends the run.
            return {"messages": [AIMessage(content=docs[0].page_content)]}
        return {"messages": []}

    def route_after_retriever(state: MessagesState):
        # The retriever appends an AIMessage only when it found an answer.
        if isinstance(state["messages"][-1], AIMessage):
            return END
        return "assistant"

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", qa_retriever_node)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "retriever")
    # add_conditional_edges takes a routing callable, plus an optional mapping
    # from its return values to node names.
    builder.add_conditional_edges(
        "retriever",
        route_after_retriever,
        {"assistant": "assistant", END: END},
    )
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()


# ------------------- LOCAL TEST -------------------

if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph(provider="openai")
    messages = graph.invoke({"messages": [HumanMessage(content=question)]})
    print("=== AI Agent Response ===")
    for m in messages["messages"]:
        m.pretty_print()