import os
from dotenv import load_dotenv
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# Load environment variables from .env
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY is not set. Please add it via Hugging Face Secrets.")
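# For local runs the key can come from a .env file next to this script, e.g.:
#   OPENAI_API_KEY=<your-key-here>
# On Hugging Face Spaces it is injected through the Space's Secrets settings.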
# Define the persistent directory for the vector store
current_dir = os.path.dirname(os.path.abspath(__file__))
persistent_directory = os.path.join(current_dir, ".")  # the store lives alongside this script
# Define the embedding model
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
# Load the existing vector store with the embedding function
db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings)
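# Note (assumption, not stated in the original script): this expects the Chroma
# store to have been built and persisted beforehand by a separate ingestion
# step, using the same embedding model; loading an empty directory yields a
# retriever that returns no documents.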
# Create a retriever for querying the vector store
# `search_type` specifies the type of search (e.g., similarity)
# `search_kwargs` contains additional arguments for the search (e.g., number of results to return)
# Plain similarity-search alternative:
# retriever = db.as_retriever(
#     search_type="similarity",
#     search_kwargs={"k": 4},
# )
retriever = db.as_retriever(
    search_type="mmr",  # Maximal Marginal Relevance (MMR) trades pure relevance for diversity
    search_kwargs={"k": 4, "fetch_k": 10},  # fetch 10 candidates, return the 4 most diverse
)
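# Quick sanity check for the retriever (a sketch with a made-up query; left
# commented out because it performs a live embedding call):
# docs = retriever.invoke("What are the admission requirements?")
# for doc in docs:
#     print(doc.page_content[:200])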
# Create a ChatOpenAI model
llm = ChatOpenAI(model="gpt-4o", temperature=0.2)
# Contextualize question prompt
# This system prompt helps the AI understand that it should reformulate the question
# based on the chat history to make it a standalone question
contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, "
    "formulate a standalone question which can be understood "
    "without the chat history. Do NOT answer the question, just "
    "reformulate it if needed and otherwise return it as is."
)
# Create a prompt template for contextualizing questions
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# Create a history-aware retriever
# This uses the LLM to reformulate the question based on the chat history
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)
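# Illustration (not from the original script): given earlier turns about the
# Computer Science department, a follow-up like "Who chairs it?" is rewritten
# to a standalone query such as "Who chairs the Computer Science department
# at Binghamton University?" before being sent to the retriever.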
# Answer question prompt
# This system prompt tells the AI to answer from the retrieved context
# and indicates what to do if the answer is unknown
qa_system_prompt = (
    "You are an assistant for answering questions at Binghamton University. "
    "Use the retrieved context to generate a structured response with bullet points where appropriate."
    "\n\n{context}"
    "\n\nIf you don't know the answer, simply state that fact."
)
# Create a prompt template for answering questions
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# Create a chain to combine documents for question answering
# `create_stuff_documents_chain` feeds all retrieved context into the LLM
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
# Create a retrieval chain that combines the history-aware retriever and the question-answering chain
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
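# One-off invocation sketch (hypothetical query; commented out to avoid an
# API call at import time). The result dict carries the generated "answer"
# plus the retrieved Documents under "context":
# result = rag_chain.invoke({"input": "Where is the main library?", "chat_history": []})
# print(result["answer"])
# print(len(result["context"]))  # documents that were stuffed into the prompt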
# Function to simulate a continual chat
def continual_chat():
    print("Start chatting with the AI! Type 'exit' to end the conversation.")
    chat_history = []  # Collect chat history here (a sequence of messages)
    while True:
        query = input("You: ")
        if query.lower() == "exit":
            break
        # Process the user's query through the retrieval chain
        result = rag_chain.invoke({"input": query, "chat_history": chat_history})
        # Display the AI's response
        print(f"AI: {result['answer']}")
        # Update the chat history; the model's reply is an AIMessage,
        # not a SystemMessage, so the roles line up on the next turn
        chat_history.append(HumanMessage(content=query))
        chat_history.append(AIMessage(content=result["answer"]))
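        # Note: chat_history grows without bound; for long sessions it could be
        # trimmed (a sketch, not in the original script):
        # chat_history = chat_history[-20:]  # keep only the most recent turns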
# Entry point: start the continual chat
if __name__ == "__main__":
    continual_chat()