from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFaceEndpoint
from langgraph.graph import START, StateGraph, MessagesState
# from langchain_chroma import Chroma
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langgraph.prebuilt import ToolNode, tools_condition
import os
from dotenv import load_dotenv
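# Load provider credentials (e.g. TAVILY_API_KEY, GOOGLE_API_KEY,
# HUGGINGFACEHUB_API_TOKEN) from a local .env file.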
load_dotenv()
@tool
def multiply(a: int, b: int) -> int:
"""Multiply two numbers and return the result.
Args:
a (int): The first number.
b (int): The second number.
Returns:
int: The product of the two numbers.
"""
return a * b
@tool
def add(a: int, b: int) -> int:
"""Add two numbers and return the result.
Args:
a (int): The first number.
b (int): The second number.
Returns:
int: The sum of the two numbers.
"""
return a + b
@tool
def subtract(a: int, b: int) -> int:
"""Subtract two numbers and return the result.
Args:
a (int): The first number.
b (int): The second number.
Returns:
int: The difference between the two numbers.
"""
return a - b
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers and return the result.
    Args:
        a (int): The first number.
        b (int): The second number.
    Returns:
        float: The quotient of the two numbers.
    """
    if b == 0:
        raise ValueError('Cannot divide by zero.')
    return a / b
@tool
def modulus(a: int, b: int) -> int:
"""Calculate the modulus of two numbers and return the result.
Args:
a (int): The first number.
b (int): The second number.
Returns:
int: The remainder when a is divided by b.
"""
return a % b
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a given query and return up to 2 results.
    Args:
        query (str): The search query.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = '\n\n---\n\n'.join(
        doc.page_content for doc in search_docs
    )
    return formatted_search_docs
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return up to 3 results.
    Args:
        query (str): The search query.
    """
    search_docs = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = '\n\n---\n\n'.join(
        doc.get('content', '') for doc in search_docs
    )
    return formatted_search_docs
@tool
def arxiv_search(query: str) -> str:
    """Search arXiv for a query and return up to 3 results.
    Args:
        query (str): The search query.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    formatted_search_docs = '\n\n---\n\n'.join(
        doc.page_content for doc in search_docs
    )
    return formatted_search_docs
# load the system prompt from the file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
system_prompt = f.read()
# System message
sys_msg = SystemMessage(content=system_prompt)
# Retriever
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
# vector_store = Chroma(
# collection_name="demo_collection",
# embedding_function=embeddings,
# persist_directory="./chroma_langchain_db",
# )
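# FAISS needs the embedding dimensionality up front, so probe it with a dummy query.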
embedding_dim = len(embeddings.embed_query("hello world"))
index = faiss.IndexFlatL2(embedding_dim)
vector_store = FAISS(
embedding_function=embeddings,
index=index,
docstore=InMemoryDocstore(),
index_to_docstore_id={},
)
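# The FAISS store starts empty, so similarity_search returns nothing until it
# is seeded. A minimal sketch (the document shown is a hypothetical placeholder):
#
#   from langchain_core.documents import Document
#   vector_store.add_documents([
#       Document(page_content='Question: ... Final answer: ...'),
#   ])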
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name='Question_Search',
    description='A tool to retrieve similar questions from the vector store.',
)
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arxiv_search,
]
# build graph function
def build_graph(tag: str = 'huggingface'):
    """Build the agent graph for the given provider tag: 'local', 'google', or 'huggingface'."""
if tag == 'local':
llm = ChatOllama(model="qwen3")
elif tag == 'google':
# Google Gemini
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
elif tag == "huggingface":
llm = ChatHuggingFace(
llm=HuggingFaceEndpoint(
endpoint_url="https://api-inference.huggingface.co/models/Qwen/Qwen3-14B"),
temperature=0,
)
else:
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
# bind tools to llm
llm_with_tools = llm.bind_tools(tools)
def assistant(state: MessagesState):
return {'messages': [llm_with_tools.invoke(state['messages'])]}
    def retriever(state: MessagesState):
        # Fetch the stored question most similar to the incoming one.
        similar_questions = vector_store.similarity_search(state['messages'][0].content)
        if similar_questions:
            example_msg = HumanMessage(
                content=f'Here I provide a similar question and answer for reference: \n\n{similar_questions[0].page_content}'
            )
            return {'messages': [sys_msg] + state['messages'] + [example_msg]}
        # Nothing similar found (e.g. the store is empty): just prepend the system prompt.
        return {'messages': [sys_msg] + state['messages']}
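    # Wire the graph: START -> retriever -> assistant, with the assistant
    # looping through the tool node whenever it requests tool calls.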
builder = StateGraph(MessagesState)
builder.add_node('retriever', retriever)
builder.add_node('assistant', assistant)
builder.add_node('tools', ToolNode(tools))
builder.add_edge(START, 'retriever')
builder.add_edge('retriever', 'assistant')
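    # tools_condition routes to the 'tools' node when the last AI message
    # contains tool calls, and to END otherwise.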
builder.add_conditional_edges(
'assistant',
tools_condition
)
builder.add_edge('tools', 'assistant')
# builder.set_entry_point("retriever")
# builder.set_finish_point("retriever")
return builder.compile()
# test
if __name__ == "__main__":
question = 'When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?'
# build the graph
graph = build_graph('local')
# run the graph
messages = [HumanMessage(content=question)]
messages = graph.invoke({'messages': messages})
for m in messages['messages']:
m.pretty_print()