import os
from dotenv import load_dotenv
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# Load environment variables from .env
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY is not set. Please add it via Hugging Face Secrets.")
# Define the persistent directory
current_dir = os.path.dirname(os.path.abspath(__file__))
persistent_directory = os.path.join(current_dir, ".") # Use root directory
# Define the embedding model
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
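# The embedding model must match the one used when the index was built;
# querying with a different model yields meaningless similarity scores.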
# Load the existing vector store with the embedding function
db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings)
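# NOTE: this assumes a Chroma index was already built and persisted in this
# directory by a separate ingestion step; loading an empty directory gives a
# retriever that returns no documents.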
# Create a retriever for querying the vector store
# `search_type` specifies the type of search (e.g., similarity)
# `search_kwargs` contains additional arguments for the search (e.g., number of results to return)
# Alternative: plain similarity search
# retriever = db.as_retriever(
#     search_type="similarity",
#     search_kwargs={"k": 4},
# )
retriever = db.as_retriever(
    search_type="mmr",  # Maximal Marginal Relevance (MMR) for diversity
    search_kwargs={"k": 4, "fetch_k": 10},  # Fetch more candidates, keep the best 4
)
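# With MMR, the retriever first fetches `fetch_k` candidates by similarity,
# then selects the `k` results that balance relevance against diversity.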
# Create a ChatOpenAI model
llm = ChatOpenAI(model="gpt-4o", temperature=0.2)
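# Optional sketch: to stream tokens to stdout instead of printing the full
# answer at once, the imported StreamingStdOutCallbackHandler could be wired
# in along these lines:
# llm = ChatOpenAI(
#     model="gpt-4o",
#     temperature=0.2,
#     streaming=True,
#     callbacks=[StreamingStdOutCallbackHandler()],
# )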
# Contextualize question prompt
# This system prompt helps the AI understand that it should reformulate the question
# based on the chat history to make it a standalone question
contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, "
    "formulate a standalone question which can be understood "
    "without the chat history. Do NOT answer the question, just "
    "reformulate it if needed and otherwise return it as is."
)
# Create a prompt template for contextualizing questions
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# Create a history-aware retriever
# This uses the LLM to help reformulate the question based on chat history
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)
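# At query time this runs in two steps: the LLM first rewrites the latest
# question into a standalone one using `chat_history`, then the rewritten
# question is sent to the retriever. With an empty history, the input is
# passed to the retriever as-is.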
# Answer question prompt
# This system prompt helps the AI understand that it should provide concise answers
# based on the retrieved context and indicates what to do if the answer is unknown
qa_system_prompt = (
    "You are an assistant for answering questions at Binghamton University. "
    "Use the retrieved context to generate a structured response with bullet points where appropriate."
    "\n\n{context}"
    "\n\nIf you don't know the answer, simply state that fact."
)
# Create a prompt template for answering questions
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# Create a chain to combine documents for question answering
# `create_stuff_documents_chain` feeds all retrieved context into the LLM
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
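# "Stuff" means all retrieved documents are concatenated into the prompt's
# {context} placeholder in a single LLM call, so very large `k` values can
# exceed the model's context window.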
# Create a retrieval chain that combines the history-aware retriever and the question answering chain
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
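# The chain's output is a dict: alongside the original input keys it includes
# "context" (the retrieved Documents) and "answer" (the generated response).
# For example (illustrative question):
# result = rag_chain.invoke({"input": "What majors are offered?", "chat_history": []})
# result["answer"], result["context"]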
# Function to simulate a continual chat
def continual_chat():
    print("Start chatting with the AI! Type 'exit' to end the conversation.")
    chat_history = []  # Collect chat history here (a sequence of messages)
    while True:
        query = input("You: ")
        if query.lower() == "exit":
            break
        # Process the user's query through the retrieval chain
        result = rag_chain.invoke({"input": query, "chat_history": chat_history})
        # Display the AI's response
        print(f"AI: {result['answer']}")
        # Update the chat history (the assistant's reply is an AIMessage, not a SystemMessage)
        chat_history.append(HumanMessage(content=query))
        chat_history.append(AIMessage(content=result["answer"]))
# Entry point: start the continual chat
if __name__ == "__main__":
    continual_chat()