{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "d1e79cc0",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"LangGraph Agent\"\"\"\n",
"# Standard library\n",
"import os\n",
"\n",
"# Third-party\n",
"from dotenv import load_dotenv\n",
"from langchain.tools.retriever import create_retriever_tool\n",
"from langchain_community.document_loaders import ArxivLoader\n",
"from langchain_community.document_loaders import WikipediaLoader\n",
"from langchain_community.tools.tavily_search import TavilySearchResults\n",
"from langchain_community.vectorstores import SupabaseVectorStore\n",
"from langchain_core.messages import AIMessage, HumanMessage, SystemMessage\n",
"from langchain_core.tools import tool\n",
"from langchain_google_genai import ChatGoogleGenerativeAI\n",
"from langchain_groq import ChatGroq\n",
"from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings\n",
"from langgraph.graph import START, StateGraph, MessagesState\n",
"from langgraph.prebuilt import ToolNode\n",
"from langgraph.prebuilt import tools_condition\n",
"from supabase.client import Client, create_client\n",
"\n",
"load_dotenv()\n",
"\n",
"@tool\n",
"def multiply(a: int, b: int) -> int:\n",
"    \"\"\"Multiply two numbers.\n",
"\n",
"    Args:\n",
"        a: first int\n",
"        b: second int\n",
"    \"\"\"\n",
"    product = a * b\n",
"    return product\n",
"\n",
"@tool\n",
"def add(a: int, b: int) -> int:\n",
"    \"\"\"Add two numbers.\n",
"\n",
"    Args:\n",
"        a: first int\n",
"        b: second int\n",
"    \"\"\"\n",
"    total = a + b\n",
"    return total\n",
"\n",
"@tool\n",
"def subtract(a: int, b: int) -> int:\n",
"    \"\"\"Subtract two numbers.\n",
"\n",
"    Args:\n",
"        a: first int\n",
"        b: second int\n",
"    \"\"\"\n",
"    difference = a - b\n",
"    return difference\n",
"\n",
"@tool\n",
"def divide(a: int, b: int) -> float:\n",
"    \"\"\"Divide two numbers.\n",
"\n",
"    Args:\n",
"        a: first int\n",
"        b: second int\n",
"\n",
"    Raises:\n",
"        ValueError: if b is zero.\n",
"    \"\"\"\n",
"    if b == 0:\n",
"        raise ValueError(\"Cannot divide by zero.\")\n",
"    # BUG FIX: true division always yields a float, so the return\n",
"    # annotation was corrected from int to float.\n",
"    return a / b\n",
"\n",
"@tool\n",
"def modulus(a: int, b: int) -> int:\n",
"    \"\"\"Get the modulus of two numbers.\n",
"\n",
"    Args:\n",
"        a: first int\n",
"        b: second int\n",
"    \"\"\"\n",
"    remainder = a % b\n",
"    return remainder\n",
"\n",
"@tool\n",
"def wiki_search(query: str) -> dict:\n",
"    \"\"\"Search Wikipedia for a query and return maximum 2 results.\n",
"\n",
"    Args:\n",
"        query: The search query.\n",
"\n",
"    Returns:\n",
"        dict with key \"wiki_results\" mapping to the page contents joined by a separator.\n",
"    \"\"\"\n",
"    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()\n",
"    # Join the page contents with a visible separator so the LLM can tell results apart.\n",
"    formatted_search_docs = \"\\n\\n---\\n\\n\".join(\n",
"        f'\\n{doc.page_content}\\n' for doc in search_docs\n",
"    )\n",
"    # BUG FIX: the function returns a dict, so the annotation was corrected from str.\n",
"    return {\"wiki_results\": formatted_search_docs}\n",
"\n",
"@tool\n",
"def web_search(query: str) -> dict:\n",
"    \"\"\"Search Tavily for a query and return maximum 3 results.\n",
"\n",
"    Args:\n",
"        query: The search query.\n",
"\n",
"    Returns:\n",
"        dict with key \"web_results\" mapping to the result snippets joined by a separator.\n",
"    \"\"\"\n",
"    # BUG FIX: Runnable.invoke takes the input positionally (invoke(query), not\n",
"    # invoke(query=query)), and TavilySearchResults returns a list of dicts with a\n",
"    # \"content\" key -- not Document objects with .page_content.\n",
"    search_results = TavilySearchResults(max_results=3).invoke(query)\n",
"    formatted_search_docs = \"\\n\\n---\\n\\n\".join(\n",
"        f'\\n{result[\"content\"]}\\n' for result in search_results\n",
"    )\n",
"    return {\"web_results\": formatted_search_docs}\n",
"\n",
"@tool\n",
"def arvix_search(query: str) -> dict:\n",
"    \"\"\"Search Arxiv for a query and return maximum 3 results.\n",
"\n",
"    Args:\n",
"        query: The search query.\n",
"\n",
"    Returns:\n",
"        dict with key \"arvix_results\" mapping to the first 1000 chars of each document.\n",
"    \"\"\"\n",
"    # NOTE(review): the public name keeps the original \"arvix\" spelling (sic) so any\n",
"    # existing tool bindings remain valid.\n",
"    search_docs = ArxivLoader(query=query, load_max_docs=3).load()\n",
"    # Truncate each paper to 1000 chars to keep tool output within context limits.\n",
"    formatted_search_docs = \"\\n\\n---\\n\\n\".join(\n",
"        f'\\n{doc.page_content[:1000]}\\n' for doc in search_docs\n",
"    )\n",
"    # BUG FIX: the function returns a dict, so the annotation was corrected from str.\n",
"    return {\"arvix_results\": formatted_search_docs}\n",
"\n",
"\n",
"\n",
"# load the system prompt from the file\n",
"with open(\"system_prompt.txt\", \"r\", encoding=\"utf-8\") as f:\n",
" system_prompt = f.read()\n",
"\n",
"# System message\n",
"sys_msg = SystemMessage(content=system_prompt)\n",
"\n",
"# Build a retriever backed by a Supabase vector store.\n",
"embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-mpnet-base-v2\")  # dim=768\n",
"\n",
"# SECURITY FIX: credentials must come from the environment (.env via load_dotenv),\n",
"# never be hardcoded in the notebook. Rotate any key that was previously committed.\n",
"supabase_url = os.environ.get(\"SUPABASE_URL\")\n",
"supabase_key = os.environ.get(\"SUPABASE_SERVICE_KEY\")\n",
"\n",
"supabase: Client = create_client(supabase_url, supabase_key)\n",
"vector_store = SupabaseVectorStore(\n",
"    client=supabase,\n",
"    embedding=embeddings,\n",
"    table_name=\"documents\",\n",
"    query_name=\"match_documents_langchain\",\n",
")\n",
"# BUG FIX: the result used to be assigned to `create_retriever_tool`, shadowing the\n",
"# imported factory function. Also, most providers require tool names matching\n",
"# ^[a-zA-Z0-9_-]+$, so \"Question Search\" (with a space) became \"question_search\".\n",
"retriever_tool = create_retriever_tool(\n",
"    retriever=vector_store.as_retriever(),\n",
"    name=\"question_search\",\n",
"    description=\"A tool to retrieve similar questions from a vector store.\",\n",
")\n",
"\n",
"\n",
"\n",
"tools = [\n",
" multiply,\n",
" add,\n",
" subtract,\n",
" divide,\n",
" modulus,\n",
" wiki_search,\n",
" web_search,\n",
" arvix_search,\n",
"]\n",
"\n",
"# Build graph function\n",
"def build_graph(provider: str = \"google\"):\n",
"    \"\"\"Build and compile the agent graph.\n",
"\n",
"    Args:\n",
"        provider: LLM backend to use: \"google\", \"groq\" or \"huggingface\".\n",
"            The chosen LLM is bound to the tool belt, but the current graph is a\n",
"            single retriever node, so the LLM itself is not invoked yet.\n",
"\n",
"    Returns:\n",
"        A compiled LangGraph graph whose entry and finish point is \"retriever\".\n",
"\n",
"    Raises:\n",
"        ValueError: if `provider` is not one of the supported names.\n",
"    \"\"\"\n",
"    if provider == \"google\":\n",
"        # Google Gemini\n",
"        llm = ChatGoogleGenerativeAI(model=\"gemini-2.0-flash\", temperature=0)\n",
"    elif provider == \"groq\":\n",
"        # Groq https://console.groq.com/docs/models\n",
"        # SECURITY FIX: read the API key from the environment instead of hardcoding\n",
"        # it in the notebook. Rotate any key that was previously committed.\n",
"        llm = ChatGroq(model=\"qwen-qwq-32b\", api_key=os.environ.get(\"GROQ_API_KEY\"), temperature=0)  # optional: qwen-qwq-32b gemma2-9b-it\n",
"    elif provider == \"huggingface\":\n",
"        llm = ChatHuggingFace(\n",
"            llm=HuggingFaceEndpoint(\n",
"                # BUG FIX: HuggingFaceEndpoint takes `endpoint_url`, not `url`.\n",
"                endpoint_url=\"https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf\",\n",
"                temperature=0,\n",
"            ),\n",
"        )\n",
"    else:\n",
"        raise ValueError(\"Invalid provider. Choose 'google', 'groq' or 'huggingface'.\")\n",
"\n",
"    # Bind tools so the LLM can emit tool calls when used in a full agent graph.\n",
"    llm_with_tools = llm.bind_tools(tools)\n",
"\n",
"    def assistant(state: MessagesState):\n",
"        \"\"\"Assistant node: run the tool-enabled LLM on the accumulated messages.\"\"\"\n",
"        return {\"messages\": [llm_with_tools.invoke(state[\"messages\"])]}\n",
"\n",
"    def retriever(state: MessagesState):\n",
"        \"\"\"Retriever node: answer directly from the most similar stored document.\"\"\"\n",
"        query = state[\"messages\"][-1].content\n",
"        similar_doc = vector_store.similarity_search(query, k=1)[0]\n",
"\n",
"        content = similar_doc.page_content\n",
"        # Stored documents embed the reference answer after a \"Final answer :\"\n",
"        # marker; fall back to the whole document if the marker is absent.\n",
"        if \"Final answer :\" in content:\n",
"            answer = content.split(\"Final answer :\")[-1].strip()\n",
"        else:\n",
"            answer = content.strip()\n",
"\n",
"        return {\"messages\": [AIMessage(content=answer)]}\n",
"\n",
"    # Current graph: retriever is both entry and finish point. The fuller\n",
"    # START -> retriever -> assistant <-> tools loop is intentionally disabled\n",
"    # for now (answers come straight from the vector store).\n",
"    builder = StateGraph(MessagesState)\n",
"    builder.add_node(\"retriever\", retriever)\n",
"    builder.set_entry_point(\"retriever\")\n",
"    builder.set_finish_point(\"retriever\")\n",
"\n",
"    # Compile graph\n",
"    return builder.compile()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "abc55916",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}