{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8c3e7a12",
   "metadata": {},
   "source": [
    "# LangGraph Agent\n",
    "\n",
    "Defines arithmetic and search tools, a Supabase-backed question retriever, and `build_graph()`, which compiles the agent graph.\n",
    "\n",
    "By default `build_graph()` compiles a retriever-only graph: it answers with the stored answer of the most similar question in the vector store. Pass `use_tools=True` to compile an assistant <-> tools loop driven by the selected LLM provider instead.\n",
    "\n",
    "**Configuration** (environment variables, e.g. in a `.env` file loaded by `load_dotenv()`):\n",
    "\n",
    "- `SUPABASE_KEY` (required) and optionally `SUPABASE_URL`\n",
    "- `GOOGLE_API_KEY` for `provider=\"google\"`, `GROQ_API_KEY` for `provider=\"groq\"`, `HUGGINGFACEHUB_API_TOKEN` for `provider=\"huggingface\"`\n",
    "- `TAVILY_API_KEY` for `web_search`\n",
    "\n",
    "**Security note:** an earlier revision of this notebook hardcoded a Supabase key and a Groq API key. Treat both as leaked and rotate them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1e79cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"LangGraph Agent: tools, question retriever and graph builder.\"\"\"\n",
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from langchain.tools.retriever import create_retriever_tool\n",
    "from langchain_community.document_loaders import ArxivLoader, WikipediaLoader\n",
    "from langchain_community.tools.tavily_search import TavilySearchResults\n",
    "from langchain_community.vectorstores import SupabaseVectorStore\n",
    "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage\n",
    "from langchain_core.tools import tool\n",
    "from langchain_google_genai import ChatGoogleGenerativeAI\n",
    "from langchain_groq import ChatGroq\n",
    "from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFaceEndpoint\n",
    "from langgraph.graph import END, START, MessagesState, StateGraph\n",
    "from langgraph.prebuilt import ToolNode, tools_condition\n",
    "from supabase.client import Client, create_client\n",
    "\n",
    "# Load API keys and Supabase settings from a local .env file, if present.\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b9f0d41",
   "metadata": {},
   "source": [
    "## Tools\n",
    "\n",
    "Arithmetic tools and search tools (Wikipedia, Tavily web search, arXiv) that the tool-calling assistant can use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d2c5e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separator placed between individual search results in a tool's output.\n",
    "RESULT_SEPARATOR = \"\\n\\n---\\n\\n\"\n",
    "\n",
    "\n",
    "def _join_results(texts) -> str:\n",
    "    \"\"\"Wrap each result text in newlines and join them with RESULT_SEPARATOR.\"\"\"\n",
    "    return RESULT_SEPARATOR.join(f\"\\n{text}\\n\" for text in texts)\n",
    "\n",
    "\n",
    "@tool\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiply two numbers.\n",
    "\n",
    "    Args:\n",
    "        a: first int\n",
    "        b: second int\n",
    "    \"\"\"\n",
    "    return a * b\n",
    "\n",
    "\n",
    "@tool\n",
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Add two numbers.\n",
    "\n",
    "    Args:\n",
    "        a: first int\n",
    "        b: second int\n",
    "    \"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "@tool\n",
    "def subtract(a: int, b: int) -> int:\n",
    "    \"\"\"Subtract two numbers.\n",
    "\n",
    "    Args:\n",
    "        a: first int\n",
    "        b: second int\n",
    "    \"\"\"\n",
    "    return a - b\n",
    "\n",
    "\n",
    "@tool\n",
    "def divide(a: int, b: int) -> float:\n",
    "    \"\"\"Divide two numbers.\n",
    "\n",
    "    Args:\n",
    "        a: first int\n",
    "        b: second int\n",
    "    \"\"\"\n",
    "    if b == 0:\n",
    "        raise ValueError(\"Cannot divide by zero.\")\n",
    "    # True division: the result is a float even for int inputs.\n",
    "    return a / b\n",
    "\n",
    "\n",
    "@tool\n",
    "def modulus(a: int, b: int) -> int:\n",
    "    \"\"\"Get the modulus of two numbers.\n",
    "\n",
    "    Args:\n",
    "        a: first int\n",
    "        b: second int\n",
    "    \"\"\"\n",
    "    if b == 0:\n",
    "        raise ValueError(\"Cannot take the modulus by zero.\")\n",
    "    return a % b\n",
    "\n",
    "\n",
    "@tool\n",
    "def wiki_search(query: str) -> dict:\n",
    "    \"\"\"Search Wikipedia for a query and return maximum 2 results.\n",
    "\n",
    "    Args:\n",
    "        query: The search query.\"\"\"\n",
    "    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()\n",
    "    return {\"wiki_results\": _join_results(doc.page_content for doc in search_docs)}\n",
    "\n",
    "\n",
    "@tool\n",
    "def web_search(query: str) -> dict:\n",
    "    \"\"\"Search Tavily for a query and return maximum 3 results.\n",
    "\n",
    "    Args:\n",
    "        query: The search query.\"\"\"\n",
    "    search_results = TavilySearchResults(max_results=3).invoke({\"query\": query})\n",
    "    # Tavily yields a list of dicts with a \"content\" key; on failure the tool\n",
    "    # may return an error message string instead of raising.\n",
    "    if isinstance(search_results, str):\n",
    "        return {\"web_results\": search_results}\n",
    "    return {\"web_results\": _join_results(result.get(\"content\", \"\") for result in search_results)}\n",
    "\n",
    "\n",
    "@tool\n",
    "def arvix_search(query: str) -> dict:\n",
    "    \"\"\"Search Arxiv for a query and return maximum 3 results.\n",
    "\n",
    "    Args:\n",
    "        query: The search query.\"\"\"\n",
    "    search_docs = ArxivLoader(query=query, load_max_docs=3).load()\n",
    "    # Keep only the first 1000 characters of each paper to bound the output size.\n",
    "    return {\"arvix_results\": _join_results(doc.page_content[:1000] for doc in search_docs)}\n",
    "\n",
    "\n",
    "# Tools bound to the LLM when build_graph(use_tools=True) is used.\n",
    "tools = [\n",
    "    multiply,\n",
    "    add,\n",
    "    subtract,\n",
    "    divide,\n",
    "    modulus,\n",
    "    wiki_search,\n",
    "    web_search,\n",
    "    arvix_search,\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e61b0f3",
   "metadata": {},
   "source": [
    "## System prompt and question retriever\n",
    "\n",
    "Loads `system_prompt.txt` and connects to the Supabase vector store that holds reference questions with their answers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0f8a9d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the system prompt from the file.\n",
    "with open(\"system_prompt.txt\", \"r\", encoding=\"utf-8\") as f:\n",
    "    system_prompt = f.read()\n",
    "\n",
    "# System message prepended to the conversation by the tool-calling assistant.\n",
    "sys_msg = SystemMessage(content=system_prompt)\n",
    "\n",
    "# Supabase credentials come from the environment; never commit keys to the notebook.\n",
    "supabase_url = os.environ.get(\"SUPABASE_URL\", \"https://ajnakgegqblhwltzkzbz.supabase.co\")\n",
    "supabase_key = os.environ.get(\"SUPABASE_KEY\")\n",
    "if not supabase_key:\n",
    "    raise RuntimeError(\"SUPABASE_KEY is not set; add it to your environment or .env file.\")\n",
    "\n",
    "# Build a retriever over the Supabase documents table.\n",
    "embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-mpnet-base-v2\")  # dim=768\n",
    "supabase: Client = create_client(supabase_url, supabase_key)\n",
    "vector_store = SupabaseVectorStore(\n",
    "    client=supabase,\n",
    "    embedding=embeddings,\n",
    "    table_name=\"documents\",\n",
    "    query_name=\"match_documents_langchain\",\n",
    ")\n",
    "\n",
    "# Exposes the vector store as a tool; it is currently not part of `tools`.\n",
    "retriever_tool = create_retriever_tool(\n",
    "    retriever=vector_store.as_retriever(),\n",
    "    # Function-calling APIs generally reject spaces in tool names.\n",
    "    name=\"question_search\",\n",
    "    description=\"A tool to retrieve similar questions from a vector store.\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a2e4d17",
   "metadata": {},
   "source": [
    "## Graph\n",
    "\n",
    "`build_graph()` compiles either the retriever-only graph (default) or the assistant <-> tools loop (`use_tools=True`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5b3c7f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chat-model backends accepted by build_graph().\n",
    "SUPPORTED_PROVIDERS = (\"google\", \"groq\", \"huggingface\")\n",
    "\n",
    "# Text after this marker in a stored document is returned as the answer.\n",
    "FINAL_ANSWER_MARKER = \"Final answer :\"\n",
    "\n",
    "\n",
    "def _build_llm(provider: str):\n",
    "    \"\"\"Create the chat model for ``provider``; API keys come from the environment.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``provider`` is not one of SUPPORTED_PROVIDERS.\n",
    "    \"\"\"\n",
    "    if provider == \"google\":\n",
    "        # Google Gemini (reads GOOGLE_API_KEY).\n",
    "        return ChatGoogleGenerativeAI(model=\"gemini-2.0-flash\", temperature=0)\n",
    "    if provider == \"groq\":\n",
    "        # Groq https://console.groq.com/docs/models (reads GROQ_API_KEY).\n",
    "        return ChatGroq(model=\"qwen-qwq-32b\", temperature=0)  # optional: gemma2-9b-it\n",
    "    if provider == \"huggingface\":\n",
    "        # NOTE(review): confirm this model repo exists on the Hugging Face Hub.\n",
    "        return ChatHuggingFace(\n",
    "            llm=HuggingFaceEndpoint(\n",
    "                endpoint_url=\"https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf\",\n",
    "                temperature=0,\n",
    "            ),\n",
    "        )\n",
    "    raise ValueError(\"Invalid provider. Choose 'google', 'groq' or 'huggingface'.\")\n",
    "\n",
    "\n",
    "def _retriever_node(state: MessagesState) -> dict:\n",
    "    \"\"\"Reply with the stored answer of the question most similar to the last message.\"\"\"\n",
    "    query = state[\"messages\"][-1].content\n",
    "    matches = vector_store.similarity_search(query, k=1)\n",
    "    if not matches:\n",
    "        # No match (e.g. empty table): reply instead of raising IndexError.\n",
    "        return {\"messages\": [AIMessage(content=\"No similar question found.\")]}\n",
    "    # split(...)[-1] is the whole document when the marker is absent.\n",
    "    answer = matches[0].page_content.split(FINAL_ANSWER_MARKER)[-1].strip()\n",
    "    return {\"messages\": [AIMessage(content=answer)]}\n",
    "\n",
    "\n",
    "def build_graph(provider: str = \"google\", use_tools: bool = False):\n",
    "    \"\"\"Build and compile the agent graph.\n",
    "\n",
    "    Args:\n",
    "        provider: LLM backend for the tool-calling assistant: 'google', 'groq'\n",
    "            or 'huggingface'. Validated even when ``use_tools`` is False.\n",
    "        use_tools: If False (default), compile a retriever-only graph that\n",
    "            answers with the stored answer of the most similar question. If\n",
    "            True, compile an assistant <-> tools loop driven by the LLM.\n",
    "\n",
    "    Returns:\n",
    "        The compiled graph; invoke it with ``{\"messages\": [...]}``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``provider`` is not supported.\n",
    "    \"\"\"\n",
    "    if provider not in SUPPORTED_PROVIDERS:\n",
    "        raise ValueError(\"Invalid provider. Choose 'google', 'groq' or 'huggingface'.\")\n",
    "\n",
    "    builder = StateGraph(MessagesState)\n",
    "\n",
    "    if not use_tools:\n",
    "        # The retriever is both the entry and the finish point.\n",
    "        builder.add_node(\"retriever\", _retriever_node)\n",
    "        builder.add_edge(START, \"retriever\")\n",
    "        builder.add_edge(\"retriever\", END)\n",
    "        return builder.compile()\n",
    "\n",
    "    llm_with_tools = _build_llm(provider).bind_tools(tools)\n",
    "\n",
    "    def assistant(state: MessagesState) -> dict:\n",
    "        \"\"\"Call the LLM on the system prompt plus the conversation so far.\"\"\"\n",
    "        return {\"messages\": [llm_with_tools.invoke([sys_msg] + state[\"messages\"])]}\n",
    "\n",
    "    builder.add_node(\"assistant\", assistant)\n",
    "    builder.add_node(\"tools\", ToolNode(tools))\n",
    "    builder.add_edge(START, \"assistant\")\n",
    "    # Route to \"tools\" when the LLM requested tool calls, otherwise end.\n",
    "    builder.add_conditional_edges(\"assistant\", tools_condition)\n",
    "    builder.add_edge(\"tools\", \"assistant\")\n",
    "    return builder.compile()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}