# xTwin / knowledge_engine.py
# Hosting-page residue kept for reference: "Update knowledge_engine.py",
# commit 5fce8b9 (verified), 2.59 kB.
import os
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
class KnowledgeManager:
    """Answer questions over the ``.txt`` files found in a directory.

    Construction is eager: it loads the text-generation model, loads the
    embedding model, reads and chunks every ``.txt`` file in ``root_dir``,
    builds a FAISS index over the chunks and wires everything into a
    RetrievalQA chain. ``ask`` then queries that chain.

    Attributes:
        root_dir: Directory scanned (non-recursively) for ``.txt`` files.
        llm: LangChain wrapper around the text-generation pipeline.
        embeddings: Embedding model used to index the text chunks.
        docsearch: FAISS vector store built from the chunks.
        qa_chain: RetrievalQA chain combining ``llm`` and ``docsearch``.
    """

    def __init__(self, root_dir="."):
        """Load the models and index the knowledge base under ``root_dir``.

        Args:
            root_dir: Directory containing the ``.txt`` knowledge files.

        Raises:
            FileNotFoundError: If ``root_dir`` contains no ``.txt`` files.
            ValueError: If the ``.txt`` files yield no text to index.
        """
        self.root_dir = root_dir
        self.docsearch = None
        self.qa_chain = None
        self.llm = None
        self.embeddings = None
        # Order matters: the knowledge base needs both the LLM and the
        # embeddings to exist before the index and chain can be built.
        self._initialize_llm()
        self._initialize_embeddings()
        self._load_knowledge_base()

    def _initialize_llm(self):
        """Load Falcon-7B-Instruct and expose it as ``self.llm``."""
        model_id = "tiiuae/falcon-7b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            # NOTE(review): this allows Python code shipped in the model
            # repository to run locally; only keep it for trusted repos.
            trust_remote_code=True,
            # "auto" takes the dtype from the checkpoint rather than
            # forcing float32; it is not chosen per device.
            torch_dtype="auto",
            device_map="auto",
        )
        falcon_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            # temperature/top_p only take effect in sampling mode; without
            # do_sample=True generation is greedy and they are ignored.
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
        )
        self.llm = HuggingFacePipeline(pipeline=falcon_pipeline)

    def _initialize_embeddings(self):
        """Load the MiniLM sentence-embedding model into ``self.embeddings``."""
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

    def _read_text_files(self):
        """Read every regular ``.txt`` file directly inside ``root_dir``.

        Returns:
            A list of ``(filename, content)`` tuples sorted by filename, so
            the resulting index does not depend on directory listing order.

        Raises:
            FileNotFoundError: If no ``.txt`` files are present.
        """
        names = sorted(
            name
            for name in os.listdir(self.root_dir)
            # Skip directories that merely happen to be named "*.txt".
            if name.endswith(".txt")
            and os.path.isfile(os.path.join(self.root_dir, name))
        )
        if not names:
            raise FileNotFoundError("No .txt files found in root directory.")

        files = []
        for name in names:
            path = os.path.join(self.root_dir, name)
            with open(path, "r", encoding="utf-8") as handle:
                files.append((name, handle.read()))
        return files

    def _load_knowledge_base(self):
        """Chunk the text files, index them in FAISS and build the QA chain.

        Each file is split on its own and every chunk is tagged with a
        ``source`` metadata entry holding the originating filename, so the
        source documents returned by the chain can be traced to a file.

        Raises:
            FileNotFoundError: If no ``.txt`` files are present.
            ValueError: If the files contain no text to index.
        """
        files = self._read_text_files()

        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        docs = splitter.create_documents(
            [content for _, content in files],
            metadatas=[{"source": name} for name, _ in files],
        )
        if not docs:
            # Fail here with a clear message instead of deep inside FAISS.
            raise ValueError("Knowledge base .txt files contain no text to index.")

        self.docsearch = FAISS.from_documents(docs, self.embeddings)
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.docsearch.as_retriever(),
            return_source_documents=True,
        )

    def ask(self, query):
        """Run ``query`` through the retrieval-QA chain and return the answer.

        Args:
            query: The question to answer from the knowledge base.

        Returns:
            The value stored under the chain output's ``"result"`` key.

        Raises:
            ValueError: If the QA chain has not been built.
        """
        if self.qa_chain is None:
            raise ValueError("Knowledge base not initialized.")
        result = self.qa_chain(query)
        return result["result"]