# MaryamKarimi080's picture
# Update scripts/rag_chat.py
# 8a603a3 verified
import os
from pathlib import Path
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain.prompts import PromptTemplate
# Directory two levels above this file's resolved location.
BASE_DIR = Path(__file__).resolve().parent.parent
# Location passed to Chroma as its persist directory.
DB_DIR = BASE_DIR.joinpath("db")
def build_general_qa_chain(model_name=None):
    """Build a RetrievalQA chain over the Chroma store persisted at ``DB_DIR``.

    When ``DB_DIR`` does not exist yet, the ``main()`` entry points of
    ``scripts.load_documents``, ``scripts.chunk_and_embed`` and
    ``scripts.setup_vectorstore`` are run, in that order, before the store
    is opened.

    Args:
        model_name: Chat model identifier handed to ``ChatOpenAI``. Any falsy
            value (including the default ``None``) selects ``"gpt-4o-mini"``.

    Returns:
        The chain produced by ``RetrievalQA.from_chain_type``, configured with
        a retriever using ``k=4``, the custom QA prompt below, and
        ``return_source_documents=True``.
    """
    if not DB_DIR.exists():
        print("📦 No DB found. Building vectorstore...")
        # Imported lazily: these pipeline modules are only needed on a rebuild.
        from scripts import load_documents, chunk_and_embed, setup_vectorstore

        for stage in (load_documents, chunk_and_embed, setup_vectorstore):
            stage.main()

    embedder = OpenAIEmbeddings(model="text-embedding-3-small")
    store = Chroma(persist_directory=str(DB_DIR), embedding_function=embedder)

    prompt_text = (
        "Use the following context to answer the question.\n"
        "If the answer isn't found in the context, use your general knowledge but say so.\n"
        "Do not answer questions that are completely irrelevant to the main points of the context.\n"
        "Always cite your sources at the end with 'Source: <filename>' when using course materials.\n"
        "Context: {context}\n"
        "Question: {question}\n"
        "Helpful Answer:"
    )
    qa_prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["context", "question"],
    )

    chosen_model = model_name if model_name else "gpt-4o-mini"
    # temperature=0.0 is passed through to the chat model unchanged.
    chat_model = ChatOpenAI(model_name=chosen_model, temperature=0.0)

    retriever = store.as_retriever(search_kwargs={"k": 4})
    return RetrievalQA.from_chain_type(
        llm=chat_model,
        retriever=retriever,
        chain_type_kwargs={"prompt": qa_prompt},
        return_source_documents=True,
    )