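"""Minimal RAG pipeline: load local text files, index them with FAISS, and
answer questions with a Hugging Face-hosted LLM via a LangChain RetrievalQA
chain."""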
import os

import torch
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS


def load_documents():
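    """Load whichever of the expected source text files exist on disk.

    Missing files are skipped silently, so the returned list may be empty.
    """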
    docs = []
    file_list = [
        "milindgatha.txt",
        "bhaktas.txt",
        "apologetics.txt",
        "poc_questions.txt",
        "satire_offerings.txt",
    ]
    for filename in file_list:
        if os.path.exists(filename):
            loader = TextLoader(filename)
            docs.extend(loader.load())
    return docs


def get_rag_chain():
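    """Build a RetrievalQA chain over a FAISS index of the loaded documents.

    Returns None when no source documents are available.
    """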
    docs = load_documents()
    if not docs:
        print("No documents found. Please check your data files.")
        return None
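
    # Run the embedding model on GPU when one is available; CPU otherwise.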
    device = "cuda" if torch.cuda.is_available() else "cpu"
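
    # Compact sentence-transformers model for embedding both documents and queries.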
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": device},
    )
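
    # Embed every document and build an in-memory FAISS similarity index.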
    db = FAISS.from_documents(docs, embeddings)
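
    # Hosted inference through the Hugging Face Hub; this requires the
    # HUGGINGFACEHUB_API_TOKEN environment variable to be set.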
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        model_kwargs={"temperature": 0.2, "max_new_tokens": 500},
    )
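
    # Fetch the two most similar chunks per query and "stuff" them verbatim
    # into the prompt for the LLM.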
    retriever = db.as_retriever(search_kwargs={"k": 2})
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
    )
    return qa_chain
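

if __name__ == "__main__":
    # Minimal smoke test, a sketch assuming HUGGINGFACEHUB_API_TOKEN is set
    # and at least one of the source .txt files exists. The query string is
    # only a placeholder; invoke() assumes a reasonably recent LangChain.
    chain = get_rag_chain()
    if chain is not None:
        result = chain.invoke({"query": "What is this corpus about?"})
        print(result["result"])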