Spaces:
Sleeping
Sleeping
import functools
import os
from typing import List, TypedDict, Annotated, Optional, Dict, Union

from dotenv import load_dotenv
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage, AnyMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langgraph.graph.message import add_messages
from serpapi import GoogleSearch
from supabase.client import create_client
# Populate os.environ from a local .env file (variables already set in the
# environment take precedence) so the SERPAPI_KEY / SUPABASE_* lookups below
# can find them.
load_dotenv(override=False)
class AgentState(TypedDict):
    """State shared between the graph's nodes.

    The ``messages`` field carries the ``add_messages`` reducer, so message
    lists returned by a node are merged into the running history rather than
    replacing it.
    """

    # Conversation history; node updates are combined through add_messages.
    messages: Annotated[List[AnyMessage], add_messages]
def add(a: Union[float, int], b: Union[float, int]) -> Union[float, int]:
    """Return the sum of two numbers.

    Args:
        a: First addend.
        b: Second addend.
    """
    total = a + b
    return total
def subtract(a: Union[float, int], b: Union[float, int]) -> Union[float, int]:
    """Subtract b from a.

    This docstring is also the tool description the LLM sees, so it spells out
    the operand order explicitly.

    Args:
        a: Number to subtract from (minuend).
        b: Number to subtract (subtrahend).

    Returns:
        The difference a - b.
    """
    return a - b
def multiply(a: Union[float, int], b: Union[float, int]) -> Union[float, int]:
    """Return the product of two numbers.

    Args:
        a: First factor.
        b: Second factor.
    """
    product = a * b
    return product
def divide(a: Union[float, int], b: Union[float, int]) -> Optional[float]:
    """Divide a by b (true division).

    This docstring is also the tool description the LLM sees, so it states the
    operand order and the zero-divisor behaviour.

    Args:
        a: Dividend.
        b: Divisor.

    Returns:
        a / b as a float, or None when b is zero.
    """
    if b == 0:
        # Report "undefined" to the caller instead of raising ZeroDivisionError.
        return None
    # `/` is true division, so the result is always a float (never an int).
    return a / b
def web_search(query: str) -> str: | |
"""Perform a web search using SerpAPI.""" | |
params = { | |
"engine": "google", | |
"q": query, | |
"api_key": os.getenv("SERPAPI_KEY"), | |
"num": 5 | |
} | |
search = GoogleSearch(params) | |
results = search.get_dict()["organic_results"] | |
context = "\n---\n".join([ | |
"Title: " + result['title'] + "\nLink: " + result['link'] + "\nSnippet: " + result.get('snippet', 'No snippet available') | |
for result in results if 'title' in result and 'link' in result | |
] | |
) | |
return context if context else "No results found." | |
# Tools exposed to the model. Keep this list in sync with the tool catalogue in
# the assistant's system prompt: `multiply` was previously advertised to the
# model but never bound, so the model could not actually call it.
tools = [add, subtract, multiply, divide, web_search]
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
llm_with_tools = llm.bind_tools(tools)
@functools.lru_cache(maxsize=1)
def _get_vector_store() -> SupabaseVectorStore:
    """Build the Supabase-backed vector store once and reuse it across calls.

    Creating the Supabase client and loading the embedding model are expensive,
    so they are cached after the first successful call. Environment changes
    made after that call are therefore not picked up. A failed build is not
    cached and will be retried.
    """
    supabase = create_client(
        os.environ.get("SUPABASE_URL"),
        os.environ.get("SUPABASE_KEY"),
    )
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )
    return SupabaseVectorStore(
        embedding=embeddings,
        client=supabase,
        table_name="documents",
        query_name="match_documents",
    )


def retriever(state: AgentState) -> Dict:
    """Add the closest stored document to the conversation as a hint.

    Runs a similarity search (k=1) on the content of the latest message. It
    appends the best match as a HumanMessage so later nodes can reuse a
    previously seen question/answer. Presumably the "documents" table holds
    such Q/A pairs; verify against the ingestion script.

    No similarity threshold is applied: whenever any document is returned, the
    top match is injected. When the message list is empty or the search returns
    nothing, no message is added.

    Args:
        state: Graph state whose ``messages`` list ends with the user query.

    Returns:
        A state update ``{"messages": [...]}`` holding zero or one messages.
    """
    if not state["messages"]:
        return {"messages": []}
    query = state["messages"][-1].content
    docs = _get_vector_store().similarity_search(query=query, k=1)
    if not docs:
        # Previously `docs[0]` raised IndexError on an empty result.
        return {"messages": []}
    hint = HumanMessage(
        content=f"Here are some of the questions and answers relevant to user query \n\n {docs[0].page_content}"
    )
    return {"messages": [hint]}
# Prompt text for the assistant node, built once at import time instead of on
# every call.
# NOTE(review): "Report your thoughts" conflicts with the last line's
# "should only start with FINAL ANSWER" - confirm which format the grader expects.
_SYSTEM_PROMPT = """
You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
"""

# Human-readable tool catalogue appended to the system prompt. Tool names here
# must match the Python function names bound via `llm.bind_tools` exactly.
# This previously said `websearch`, a tool that does not exist.
_TOOLS_DESCRIPTION = """
You have the following tools available to perform actions
web_search(query: str) -> str:
    Args:
        query: Search query
    Returns:
        A string containing 5 relevant search results
add(a: Union[float , int], b: Union[float , int]) -> Union[float , int]:
    Add two numbers
subtract(a: Union[float , int], b: Union[float , int]) -> Union[float , int]:
    Subtract two numbers
multiply(a: Union[float , int], b: Union[float , int]) -> Union[float , int]:
    Multiply two numbers
divide(a: Union[float , int], b: Union[float , int]) -> Union[float , int , None]:
    Divide two numbers
"""

_ASSISTANT_SYS_MSG = SystemMessage(content=_SYSTEM_PROMPT + _TOOLS_DESCRIPTION)


def assistant(state: AgentState) -> Dict:
    """Run one LLM turn with tools bound.

    Prepends the fixed system prompt to the conversation history and invokes
    ``llm_with_tools``.

    Args:
        state: Graph state holding the conversation ``messages``.

    Returns:
        A state update ``{"messages": [response]}`` with the model's reply,
        which may contain tool calls.
    """
    response = llm_with_tools.invoke([_ASSISTANT_SYS_MSG] + state["messages"])
    return {"messages": [response]}