|
import os |
|
from dotenv import load_dotenv |
|
from langgraph.graph import START, StateGraph, MessagesState, END |
|
from langgraph.prebuilt import tools_condition, ToolNode |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
from langchain_groq import ChatGroq |
|
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader |
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage |
|
from langchain_core.tools import tool |
|
from langchain_groq import ChatGroq |
|
from supabase import create_client |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain_community.vectorstores import SupabaseVectorStore |
|
from langchain_openai import ChatOpenAI |
|
from langchain_core.documents import Document |
|
import json |
|
import pdfplumber |
|
import pandas as pd |
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
from PIL import Image |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import cmath |
|
|
|
import uuid |
|
import tempfile |
|
import requests |
|
from urllib.parse import urlparse |
|
from typing import Optional |
|
import io |
|
import contextlib |
|
import base64 |
|
import subprocess |
|
import sqlite3 |
|
import traceback |
|
|
|
load_dotenv() |
|
|
|
|
|
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    # Docstring wording is left untouched: it is the text the tool wrapper
    # publishes to the model.
    product = a * b
    return product
|
|
|
@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    # Plain arithmetic; the result is returned unchanged.
    total = a + b
    return total
|
|
|
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    # Operand order matters: the second argument is taken away from the first.
    difference = a - b
    return difference
|
|
|
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b. Raise error if b is zero."""
    # Happy path first; a zero divisor falls through to the explicit
    # ValueError rather than Python's ZeroDivisionError.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
|
|
|
@tool
def modulus(a: int, b: int) -> int:
    """Get remainder of a divided by b."""
    # No zero guard here: a zero divisor surfaces as Python's own
    # ZeroDivisionError.
    remainder = a % b
    return remainder
|
|
|
@tool
def square_root(a: float) -> float | complex:
    """
    Get the square root of a number.
    Args:
        a (float): the number to get the square root of
    """
    # Anything that does not compare ``>= 0`` is routed through cmath, which
    # yields a complex result instead of raising.
    return a**0.5 if a >= 0 else cmath.sqrt(a)
|
|
|
@tool
def power(a: float, b: float) -> float:
    """
    Get the power of two numbers.
    Args:
        a (float): the first number
        b (float): the second number
    """
    # Two-argument pow() is the function form of the ``**`` operator.
    return pow(a, b)
|
|
|
|
|
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query (max 2 results)."""
    loader = WikipediaLoader(query=query, load_max_docs=2)
    pages = loader.load()
    # Page bodies are concatenated with a blank line between them.
    return "\n\n".join(page.page_content for page in pages)
|
|
|
@tool
def web_search(query: str) -> str:
    """Search the web using Tavily (max 3 results)."""
    results = TavilySearchResults(max_results=3).invoke(query)

    # The search tool can hand back a plain string instead of a list of hits.
    # NOTE(review): this looks like how it reports a failed request - confirm
    # against the installed langchain_community version. Iterating a string
    # would discard it character by character, so pass it through verbatim.
    if isinstance(results, str):
        return results

    # Each hit is expected to be a dict; prefer "content" and fall back to
    # "text". Anything that is not a dict, or has neither field, is skipped.
    snippets = []
    for hit in results:
        if not isinstance(hit, dict):
            continue
        snippet = hit.get("content", "") or hit.get("text", "")
        if snippet:
            snippets.append(snippet)

    return "\n\n".join(snippets)
|
|
|
@tool
def arvix_search(query: str) -> str:
    """Search Arxiv for academic papers (max 3 results, truncated to 1000 characters each)."""
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    # Keep only the leading 1000 characters of every paper.
    excerpts = [paper.page_content[:1000] for paper in papers]
    return "\n\n".join(excerpts)
|
|
|
@tool
def read_excel_file(path: str) -> str:
    """Read an Excel file and return the first few rows of each sheet as text.

    Args:
        path (str): local path of the workbook to read.

    Returns:
        One "Sheet: <name>" section per sheet followed by its first five rows,
        or a message starting with "Error reading Excel file:" on failure.
    """
    try:
        # The context manager closes the workbook handle even when parsing a
        # sheet fails part-way through.
        with pd.ExcelFile(path) as xls:
            sections = [
                f"Sheet: {sheet}\n{xls.parse(sheet).head(5).to_string(index=False)}"
                for sheet in xls.sheet_names
            ]
        return "\n\n".join(sections).strip()
    except Exception as e:
        # Failures are reported as text rather than raised.
        return f"Error reading Excel file: {str(e)}"
|
|
|
@tool
def extract_text_from_pdf(path: str) -> str:
    """Extract text from a PDF file given its local path."""
    try:
        with pdfplumber.open(path) as pdf:
            # Only the first five pages are read; pages that yield no text
            # are skipped.
            page_texts = [
                extracted
                for page in pdf.pages[:5]
                if (extracted := page.extract_text())
            ]
        combined = "".join(chunk + "\n\n" for chunk in page_texts)
        if not combined:
            return "No text extracted from PDF."
        return combined.strip()
    except Exception as exc:
        return f"Error reading PDF: {str(exc)}"
|
|
|
|
|
# The BLIP processor and model are created once, when this module is
# imported, from the same pretrained checkpoint.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
|
|
|
@tool
def blip_image_caption(image_path: str) -> str:
    """Generate a description for an image using BLIP."""
    try:
        # Convert to RGB before handing the picture to the processor.
        rgb_image = Image.open(image_path).convert("RGB")
        model_inputs = processor(rgb_image, return_tensors="pt")
        # Generation runs with gradient tracking switched off.
        with torch.no_grad():
            generated = model.generate(**model_inputs)
        return processor.decode(generated[0], skip_special_tokens=True)
    except Exception as exc:
        # Any failure is reported as text instead of being raised.
        return f"Failed to process image with BLIP: {str(exc)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file and return the path.
    Args:
        content (str): the content to save to the file
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    temp_dir = tempfile.gettempdir()

    # Only the final path component of ``filename`` is honoured so that an
    # absolute path or "../" segments cannot place the file outside the
    # temporary directory.
    safe_name = os.path.basename(filename) if filename else ""

    if safe_name in ("", ".", ".."):
        # mkstemp returns an already-open descriptor; close it straight away
        # so no handle is leaked, then write through a normal file object.
        fd, filepath = tempfile.mkstemp(dir=temp_dir)
        os.close(fd)
    else:
        filepath = os.path.join(temp_dir, safe_name)

    # Explicit encoding keeps the output independent of the platform default.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
|
|
|
|
|
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.
    Args:
        url (str): the URL of the file to download.
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    try:
        # Reduce the requested name to its last path component so the download
        # always lands inside the temporary directory.
        safe_name = os.path.basename(filename) if filename else ""
        if safe_name in ("", ".", ".."):
            # No usable name from the caller: derive one from the URL path.
            safe_name = os.path.basename(urlparse(url).path)
        if safe_name in ("", ".", ".."):
            safe_name = f"downloaded_{uuid.uuid4().hex[:8]}"

        filepath = os.path.join(tempfile.gettempdir(), safe_name)

        # The timeout stops a stalled server from blocking forever; the
        # context manager releases the streamed connection when done.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        return f"File downloaded to {filepath}. You can read this file to process its contents."
    except Exception as e:
        # Failures (network, HTTP status, file system) are reported as text.
        return f"Error downloading file: {str(e)}"
|
|
|
@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Analyze a CSV file using pandas and answer a question about it.
    Args:
        file_path (str): the path to the CSV file.
        query (str): Question about the data
    """
    # NOTE: ``query`` is accepted but not consulted; the report below is the
    # same for every question.
    try:
        frame = pd.read_csv(file_path)
        row_count, column_count = frame.shape
        report_parts = [
            f"CSV file loaded with {row_count} rows and {column_count} columns.\n",
            f"Columns: {', '.join(frame.columns)}\n\n",
            "Summary statistics:\n",
            str(frame.describe()),
        ]
        return "".join(report_parts)
    except Exception as exc:
        return f"Error analyzing CSV file: {str(exc)}"
|
|
|
|
|
# Upper bound, in seconds, for every child process started for a snippet.
_EXEC_TIMEOUT_SECONDS = 30


def _run_python(code: str, result: dict) -> None:
    """Run Python source in-process; record output, new figures and DataFrames."""
    plt.switch_backend("Agg")
    # Figures that already exist were not drawn by this snippet, so they are
    # neither reported nor closed here.
    preexisting = set(plt.get_fignums())
    stdout_buffer = io.StringIO()
    stderr_buffer = io.StringIO()
    globals_dict = {"pd": pd, "plt": plt, "Image": Image}
    try:
        with contextlib.redirect_stdout(stdout_buffer), contextlib.redirect_stderr(stderr_buffer):
            # SECURITY: exec() runs the snippet inside this interpreter with
            # no sandboxing. Only pass code you are willing to run locally.
            exec(code, globals_dict)

        # Render each new figure to PNG in memory (no temp files left behind).
        for fig_num in plt.get_fignums():
            if fig_num in preexisting:
                continue
            png = io.BytesIO()
            plt.figure(fig_num).savefig(png, format="png")
            result["plots"].append(base64.b64encode(png.getvalue()).decode())

        # Preview every DataFrame the snippet left in its global namespace.
        for var_name, var_val in globals_dict.items():
            if isinstance(var_val, pd.DataFrame):
                result["dataframes"].append((var_name, var_val.head().to_string()))

        result["status"] = "success"
    finally:
        # Keep whatever was printed even when the snippet raised, and close
        # the snippet's figures so they are not reported again on a later run.
        result["stdout"] = stdout_buffer.getvalue()
        result["stderr"] = stderr_buffer.getvalue()
        for fig_num in plt.get_fignums():
            if fig_num not in preexisting:
                plt.close(fig_num)


def _run_bash(code: str, result: dict) -> None:
    """Run the snippet through the system shell."""
    # shell=True is deliberate: the snippet is a shell script, not an argv list.
    completed = subprocess.run(
        code, shell=True, capture_output=True, text=True, timeout=_EXEC_TIMEOUT_SECONDS
    )
    result["stdout"] = completed.stdout
    result["stderr"] = completed.stderr
    result["status"] = "success" if completed.returncode == 0 else "error"


def _run_sql(code: str, result: dict) -> None:
    """Run one SQL statement against a fresh in-memory SQLite database."""
    # closing() guarantees the connection is released even if execute() raises.
    with contextlib.closing(sqlite3.connect(":memory:")) as conn:
        cur = conn.cursor()
        cur.execute(code)
        # ``description`` is set for any statement that produces rows, which
        # also covers WITH/PRAGMA queries, not only ones starting with SELECT.
        if cur.description is not None:
            cols = [desc[0] for desc in cur.description]
            df = pd.DataFrame(cur.fetchall(), columns=cols)
            result["dataframes"].append(("query_result", df.head().to_string()))
        conn.commit()
    result["status"] = "success"
    result["stdout"] = "SQL executed successfully."


def _compile_and_run(compile_cmd: list, run_cmd: list, result: dict) -> None:
    """Compile with ``compile_cmd``; if that succeeds, run ``run_cmd``."""
    comp = subprocess.run(
        compile_cmd, capture_output=True, text=True, timeout=_EXEC_TIMEOUT_SECONDS
    )
    if comp.returncode != 0:
        # Compilation failed: status stays "error" and the compiler output is
        # the error text.
        result["stderr"] = comp.stderr
        return
    run = subprocess.run(
        run_cmd, capture_output=True, text=True, timeout=_EXEC_TIMEOUT_SECONDS
    )
    result["stdout"] = run.stdout
    result["stderr"] = run.stderr
    result["status"] = "success" if run.returncode == 0 else "error"


def _run_c(code: str, result: dict) -> None:
    """Compile the snippet with gcc in a scratch directory and run the binary."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "main.c")
        bin_path = os.path.join(tmp, "main")
        with open(src, "w") as f:
            f.write(code)
        _compile_and_run(["gcc", src, "-o", bin_path], [bin_path], result)


def _run_java(code: str, result: dict) -> None:
    """Compile the snippet as Main.java with javac and run class ``Main``."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "Main.java")
        with open(src, "w") as f:
            f.write(code)
        _compile_and_run(["javac", src], ["java", "-cp", tmp, "Main"], result)


# Language name -> runner. Each runner fills in the shared result dict.
_RUNNERS = {
    "python": _run_python,
    "bash": _run_bash,
    "sql": _run_sql,
    "c": _run_c,
    "java": _run_java,
}


def _format_result(language: str, result: dict) -> str:
    """Render the result dict as the Markdown-style summary returned to callers."""
    label = language.upper()
    if result["status"] != "success":
        summary = [f"❌ Execution failed for **{label}**"]
        # Output produced before the failure is often the best debugging clue.
        if result["stdout"]:
            summary.append(f"\n**Output:**\n```\n{result['stdout'].strip()}\n```")
        if result["stderr"]:
            summary.append(f"\n**Error:**\n```\n{result['stderr'].strip()}\n```")
        return "\n".join(summary)

    summary = [f"✅ Code executed successfully in **{label}**"]
    if result["stdout"]:
        summary.append(f"\n**Output:**\n```\n{result['stdout'].strip()}\n```")
    if result["stderr"]:
        summary.append(f"\n**Warnings/Errors:**\n```\n{result['stderr'].strip()}\n```")
    for name, preview in result["dataframes"]:
        summary.append(f"\n**DataFrame `{name}` Preview:**\n```\n{preview}\n```")
    if result["plots"]:
        summary.append(f"\n📊 {len(result['plots'])} plot(s) generated (base64-encoded).")
    return "\n".join(summary)


def execute_code_multilang(code: str, language: str = "python") -> str:
    """
    Execute code in Python, Bash, SQL, C, or Java and return formatted results.

    Python runs inside this process with ``pd``, ``plt`` and ``Image``
    predefined. Bash runs through the system shell. SQL runs as a single
    statement against a fresh in-memory SQLite database. C is compiled with
    gcc. Java is compiled with javac and must define a class named ``Main``.
    Child processes are limited to 30 seconds each.

    Args:
        code (str): Source code.
        language (str): Language of the code. One of: 'python', 'bash', 'sql', 'c', 'java'.

    Returns:
        str: Human-readable execution result. Failures, including exceptions
        raised while running the code, are described in the text rather than
        raised.
    """
    language = language.lower()
    runner = _RUNNERS.get(language)
    if runner is None:
        return f"❌ Unsupported language: {language}."

    result = {
        "stdout": "",
        "stderr": "",
        "status": "error",
        "plots": [],
        "dataframes": [],
    }

    try:
        runner(code, result)
    except Exception:
        # Includes snippet errors, timeouts and missing compilers; the
        # traceback becomes the error text of the summary.
        result["status"] = "error"
        result["stderr"] = traceback.format_exc()

    return _format_result(language, result)
|
|
|
|
|
# Every tool offered to the model. ``square_root`` and ``power`` are declared
# as tools in this module, so they are registered alongside the other
# arithmetic helpers.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    square_root,
    power,
    wiki_search,
    web_search,
    arvix_search,
    read_excel_file,
    extract_text_from_pdf,
    blip_image_caption,
    save_and_read_file,
    download_file_from_url,
    analyze_csv_file,
    execute_code_multilang,
]
|
|
|
|
|
# The system prompt comes from ``system_prompt.txt`` when that file exists;
# otherwise the built-in default text is used.
system_prompt_path = "system_prompt.txt"
if not os.path.exists(system_prompt_path):
    system_prompt = (
        "You are an intelligent AI agent who can solve math, science, factual, and research-based problems. "
        "You can use tools like Wikipedia, Web search, or Arxiv when needed. Always give precise and helpful answers."
    )
else:
    with open(system_prompt_path, "r", encoding="utf-8") as f:
        system_prompt = f.read()

sys_msg = SystemMessage(content=system_prompt)
|
|
|
|
|
def build_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    The graph has three nodes:

    * ``retriever`` looks the latest message up in the Supabase ``QA_db``
      vector store (top 1 match). If a document comes back, its content is
      returned as the answer and the run ends.
    * ``assistant`` calls the chat model (with all tools bound) on the system
      prompt plus the conversation so far.
    * ``tools`` executes any tool calls the model requested, then control
      returns to ``assistant``.

    Args:
        provider: Which chat model backend to use: "google", "groq",
            "huggingface" or "openai".

    Returns:
        The compiled graph; call ``.invoke({"messages": [...]})`` on it.

    Raises:
        ValueError: If ``provider`` is not recognised, or a required
            environment variable (GROQ_API_KEY, OPENAI_API_KEY, SUPABASE_URL,
            SUPABASE_SERVICE_KEY) is not set.
    """
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        groq_key = os.getenv("GROQ_API_KEY")
        if not groq_key:
            raise ValueError("GROQ_API_KEY is not set.")
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=groq_key)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            )
        )
    elif provider == "openai":
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            raise ValueError("OPENAI_API_KEY is not set.")
        llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=openai_key)
    else:
        raise ValueError("Invalid provider")

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        # The system prompt must be part of what the model sees. Only the
        # model's reply is written back, so the prompt is not stored in the
        # conversation history.
        response = llm_with_tools.invoke([sys_msg] + state["messages"])
        return {"messages": [response]}

    # Fail early with a clear message, matching the API-key checks above.
    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_SERVICE_KEY")
    if not supabase_url or not supabase_key:
        raise ValueError("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set.")
    supabase = create_client(supabase_url, supabase_key)

    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    vectorstore = SupabaseVectorStore(
        client=supabase,
        embedding=embedding_model,
        table_name="QA_db",
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 1})

    def patched_fn(embedding, k=4, filter=None, **kwargs):
        # Calls the ``match_documents`` RPC directly with only the embedding
        # and the match count; ``filter`` and extra keyword arguments are
        # accepted for signature compatibility but not forwarded.
        response = supabase.rpc(
            "match_documents",
            {
                "query_embedding": embedding,
                "match_count": k,
            },
        ).execute()

        documents = []
        for row in response.data or []:
            metadata = row.get("metadata")
            if isinstance(metadata, str):
                # Metadata may arrive JSON-encoded; unparseable text is dropped.
                try:
                    metadata = json.loads(metadata)
                except ValueError:
                    metadata = {}
            if not isinstance(metadata, dict):
                metadata = {}
            doc = Document(page_content=row["content"], metadata=metadata)
            documents.append((doc, row["similarity"]))
        return documents

    # Replace the store's scored-search method on this instance so the
    # retriever above goes through ``patched_fn``.
    vectorstore.similarity_search_by_vector_with_relevance_scores = patched_fn

    def qa_retriever_node(state: MessagesState):
        user_question = state["messages"][-1].content
        docs = retriever.invoke(user_question)
        if docs:
            # Returning just the new message lets the messages reducer append it.
            return {"messages": [AIMessage(content=docs[0].page_content)]}
        return {"messages": []}

    def route_after_retriever(state: MessagesState):
        # A stored answer shows up as a trailing AIMessage; in that case the
        # run is finished. Otherwise hand the question to the assistant.
        # Assumes the incoming conversation does not itself end with an
        # AIMessage.
        if isinstance(state["messages"][-1], AIMessage):
            return END
        return "assistant"

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", qa_retriever_node)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "retriever")
    # Routing is done by a callable that inspects the state; the mapping lists
    # the destinations it can return.
    builder.add_conditional_edges(
        "retriever",
        route_after_retriever,
        {"assistant": "assistant", END: END},
    )
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: build the graph with the OpenAI backend, ask one
    # question and print every message of the resulting conversation.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph(provider="openai")
    initial_state = {"messages": [HumanMessage(content=question)]}
    messages = graph.invoke(initial_state)
    print("=== AI Agent Response ===")
    for message in messages["messages"]:
        message.pretty_print()
|
|