# Provenance (from the hosting page, kept as comments so the file parses):
# CrimsonElephant's picture
# updated RAG and APP to clean up previous errors
# commit 263b899
# rag.py
import os
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFaceHub
import torch
# This function is responsible for loading all text files from the directory
def load_documents():
    """Load every bundled knowledge-base text file that exists on disk.

    Missing files are skipped silently; returns a (possibly empty) list
    of langchain Document objects.
    """
    sources = (
        "milindgatha.txt",
        "bhaktas.txt",
        "apologetics.txt",
        "poc_questions.txt",
        "satire_offerings.txt",
    )
    loaded = []
    for name in sources:
        if not os.path.exists(name):
            continue  # tolerate absent data files rather than erroring
        loaded.extend(TextLoader(name).load())
    return loaded
# This function initializes and returns the RAG chain
def get_rag_chain():
    """Build and return a RetrievalQA chain over the local text corpus.

    Returns:
        A langchain RetrievalQA chain, or None (after printing a
        diagnostic) when no source documents could be loaded — callers
        must handle the None case.
    """
    docs = load_documents()
    if not docs:
        print("No documents found. Please check your data files.")
        return None

    # Prefer the GPU when one is present; otherwise stay on CPU so the
    # embedding model fits within free-tier memory limits.
    # (The original comment claimed CPU-only, but the code has always
    # switched to CUDA when available — the comment was wrong, not the code.)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": device},
    )
    db = FAISS.from_documents(docs, embeddings)

    # HuggingFaceHub reads HUGGINGFACEHUB_API_TOKEN from the environment;
    # warn early so a missing token is easy to diagnose instead of failing
    # deep inside the first query.
    if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
        print("Warning: HUGGINGFACEHUB_API_TOKEN is not set; the hosted LLM call will fail.")

    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        model_kwargs={"temperature": 0.2, "max_new_tokens": 500},
    )
    # Retrieve only the top-2 chunks so the "stuff" chain's prompt stays
    # within the model's context budget.
    retriever = db.as_retriever(search_kwargs={"k": 2})
    return RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)