File size: 1,628 Bytes
7ee5751
 
 
 
 
 
 
263b899
86f2022
263b899
86f2022
 
7ee5751
 
 
 
 
263b899
7ee5751
 
86f2022
7ee5751
 
86f2022
 
263b899
 
7ee5751
 
263b899
7ee5751
86f2022
263b899
 
 
 
86f2022
263b899
 
 
 
 
7ee5751
86f2022
7ee5751
 
 
 
263b899
7ee5751
 
263b899
7ee5751
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# rag.py
import os
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFaceHub
import torch

# This function is responsible for loading all text files from the directory
def load_documents(file_list=None):
    """Load the project's text files into LangChain documents.

    Args:
        file_list: Optional iterable of file paths to load. When None,
            falls back to the project's default data files.

    Returns:
        list: Documents loaded from every file that exists on disk.
        Missing files are skipped silently (best-effort), so the result
        may be empty.
    """
    if file_list is None:
        file_list = [
            "milindgatha.txt",
            "bhaktas.txt",
            "apologetics.txt",
            "poc_questions.txt",
            "satire_offerings.txt",
        ]
    docs = []
    for filename in file_list:
        # Skip absent files instead of raising so a partial data set still works.
        if os.path.exists(filename):
            loader = TextLoader(filename)
            docs.extend(loader.load())
    return docs

# This function initializes and returns the RAG chain
def get_rag_chain():
    """Build and return a RetrievalQA chain over the local text corpus.

    Returns:
        A LangChain RetrievalQA chain ready to answer queries, or None
        when no source documents could be loaded from disk.
    """
    docs = load_documents()
    if not docs:
        print("No documents found. Please check your data files.")
        return None

    # Prefer the GPU when one is available; otherwise fall back to CPU so
    # the app still runs on memory-constrained (e.g. free-tier) hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": device},
    )

    # In-memory FAISS index over the corpus; rebuilt on every call.
    db = FAISS.from_documents(docs, embeddings)

    # Hosted inference endpoint; low temperature keeps answers grounded.
    # NOTE(review): HuggingFaceHub requires HUGGINGFACEHUB_API_TOKEN in the
    # environment — it will raise at call time if the token is missing.
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        model_kwargs={"temperature": 0.2, "max_new_tokens": 500},
    )

    # Fetch the top-2 most similar chunks per query; "stuff" concatenates
    # them all into a single prompt for the LLM.
    retriever = db.as_retriever(search_kwargs={"k": 2})
    qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)

    return qa_chain