import os

from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

class KnowledgeManager:
    """Builds a FAISS-backed RetrievalQA chain over the .txt files in a directory."""

    def __init__(self, root_dir="."):
        self.root_dir = root_dir
        self.docsearch = None
        self.qa_chain = None
        self.llm = None
        self._initialize_llm()
        self._initialize_embeddings()
        self._load_knowledge_base()
    def _initialize_llm(self):
        # Load Falcon-7B-Instruct and wrap it in a LangChain-compatible pipeline.
        model_id = "tiiuae/falcon-7b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype="auto",  # float16 on GPU, float32 on CPU
            device_map="auto",
        )
        falcon_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=True,  # without this, temperature/top_p are ignored (greedy decoding)
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
        )
        self.llm = HuggingFacePipeline(pipeline=falcon_pipeline)
    def _initialize_embeddings(self):
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
    def _load_knowledge_base(self):
        txt_files = [f for f in os.listdir(self.root_dir) if f.endswith(".txt")]
        if not txt_files:
            raise FileNotFoundError("No .txt files found in root directory.")

        all_texts = []
        for filename in txt_files:
            path = os.path.join(self.root_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                all_texts.append(f.read())
        full_text = "\n\n".join(all_texts)

        # Chunk the corpus so each retrieved piece fits comfortably in the prompt.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        docs = text_splitter.create_documents([full_text])
        self.docsearch = FAISS.from_documents(docs, self.embeddings)

        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",  # concatenate retrieved chunks into a single prompt
            retriever=self.docsearch.as_retriever(),
            return_source_documents=True,
        )
    def ask(self, query):
        if not self.qa_chain:
            raise ValueError("Knowledge base not initialized.")
        # With return_source_documents=True the chain returns a dict holding
        # both "result" and "source_documents"; we return only the answer text.
        result = self.qa_chain(query)
        return result["result"]
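

# Minimal usage sketch (an assumption, not part of the original file; the Space
# likely imports this module from its app entry point instead). It builds the
# knowledge base from .txt files in the current directory and asks one question.
# Note: loading Falcon-7B-Instruct requires substantial RAM/VRAM.
if __name__ == "__main__":
    km = KnowledgeManager(root_dir=".")
    answer = km.ask("What topics does the knowledge base cover?")
    print(answer)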