"""Knowledge base management using FAISS and HuggingFace embeddings."""

import os
import json
import pickle
import hashlib
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
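
# The ``config`` object handed to KnowledgeBase only needs to expose the
# attributes this module actually reads (summarized here for convenience):
#
#   config.models.embedding_model      -> SentenceTransformer model name
#   config.models.tinygpt2_model       -> tokenizer used for token-based chunking
#   config.vector_store.dimension      -> embedding dimension
#   config.vector_store.metric         -> "cosine" (inner product) or L2 otherwise
#   config.vector_store.chunk_size     -> chunk length in tokens
#   config.vector_store.chunk_overlap  -> token overlap between consecutive chunks
#   config.vector_store.n_results      -> default number of search results
#   config.INDEX_DIR                   -> directory where the index is persisted
#   config.BOOKS_DIR                   -> directory containing *.txt books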


class KnowledgeBase:
    """Manages the vector store for knowledge retrieval."""

    def __init__(self, config):
        self.config = config
        self.embedding_model = SentenceTransformer(config.models.embedding_model)
        self.dimension = config.vector_store.dimension
        self.index = None
        self.metadata = []
        self.chunks = []
        self.index_path = config.INDEX_DIR
        self.books_path = config.BOOKS_DIR

        # Tokenizer is only used to split book text into token-based chunks.
        self.tokenizer = AutoTokenizer.from_pretrained(config.models.tinygpt2_model)

        self._initialize_index()

    def _initialize_index(self):
        """Initialize a new FAISS index or load an existing one from disk."""
        index_file = os.path.join(self.index_path, "knowledge.index")
        metadata_file = os.path.join(self.index_path, "metadata.pkl")
        chunks_file = os.path.join(self.index_path, "chunks.pkl")

        if all(os.path.exists(p) for p in (index_file, metadata_file, chunks_file)):
            self.index = faiss.read_index(index_file)
            with open(metadata_file, 'rb') as f:
                self.metadata = pickle.load(f)
            with open(chunks_file, 'rb') as f:
                self.chunks = pickle.load(f)
            print(f"Loaded existing index with {self.index.ntotal} vectors")
        else:
            # Inner product over L2-normalized vectors is equivalent to cosine similarity.
            if self.config.vector_store.metric == "cosine":
                self.index = faiss.IndexFlatIP(self.dimension)
            else:
                self.index = faiss.IndexFlatL2(self.dimension)
            print("Created new index")

    def process_books(self, force_rebuild: bool = False):
        """Process all books in the books directory."""
        if self.index.ntotal > 0 and not force_rebuild:
            print(f"Index already contains {self.index.ntotal} vectors. Use force_rebuild=True to rebuild.")
            return

        if force_rebuild:
            self.index = faiss.IndexFlatIP(self.dimension) if self.config.vector_store.metric == "cosine" else faiss.IndexFlatL2(self.dimension)
            self.metadata = []
            self.chunks = []

        book_files = list(Path(self.books_path).glob("*.txt"))
        print(f"Found {len(book_files)} books to process")

        for book_file in book_files:
            print(f"Processing {book_file.name}...")
            self._process_single_book(book_file)

        self._save_index()
        print(f"Processing complete. Index contains {self.index.ntotal} vectors")

    def _process_single_book(self, book_path: Path):
        """Process a single book file."""
        try:
            with open(book_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Derive a human-readable title from the file name.
            book_name = book_path.stem.replace('_', ' ').title()

            chunks = self._create_chunks(content)

            for i, chunk in enumerate(chunks):
                if not chunk.strip():
                    continue

                embedding = self._create_embedding(chunk)
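
                # For the "cosine" metric the vector is L2-normalized below so that
                # the inner-product index built in _initialize_index returns cosine
                # similarity. FAISS expects a 2-D float32 array, hence the wrapping
                # in np.array([...]) before index.add().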
                if self.config.vector_store.metric == "cosine":
                    embedding = embedding / np.linalg.norm(embedding)

                self.index.add(np.array([embedding]))

                metadata = {
                    "book": book_name,
                    "chunk_id": i,
                    "timestamp": datetime.now().isoformat(),
                    "char_count": len(chunk),
                    "checksum": hashlib.md5(chunk.encode()).hexdigest()
                }
                self.metadata.append(metadata)
                self.chunks.append(chunk)

        except Exception as e:
            print(f"Error processing {book_path}: {str(e)}")

    def _create_chunks(self, text: str) -> List[str]:
        """Split text into token-based chunks using a sliding window."""
        text = text.strip()
        if not text:
            return []

        tokens = self.tokenizer.encode(text, add_special_tokens=False)

        chunks = []
        chunk_size = self.config.vector_store.chunk_size
        overlap = self.config.vector_store.chunk_overlap
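
        # Worked example (the real values come from the config; these numbers are
        # only illustrative): with chunk_size=512 and chunk_overlap=64 the window
        # starts are 0, 448, 896, ..., so consecutive chunks share 64 tokens.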
        step = max(1, chunk_size - overlap)  # guard against overlap >= chunk_size
        for i in range(0, len(tokens), step):
            chunk_tokens = tokens[i:i + chunk_size]
            chunk_text = self.tokenizer.decode(chunk_tokens, skip_special_tokens=True)
            chunks.append(chunk_text)

        return chunks

    def _create_embedding(self, text: str) -> np.ndarray:
        """Create an embedding vector for the given text."""
        embedding = self.embedding_model.encode(text, convert_to_numpy=True)
        return embedding.astype('float32')  # FAISS requires float32

    def search(self, query: str, k: Optional[int] = None, filter_books: Optional[List[str]] = None) -> List[Dict]:
        """Search for similar chunks in the knowledge base."""
        if self.index.ntotal == 0:
            return []

        k = k or self.config.vector_store.n_results

        query_embedding = self._create_embedding(query)

        if self.config.vector_store.metric == "cosine":
            query_embedding = query_embedding / np.linalg.norm(query_embedding)
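
        # FAISS returns two (1, k) arrays: scores and row indices into the index.
        # For IndexFlatIP the score is the inner product (higher is better); for
        # IndexFlatL2 it is the squared L2 distance (lower is better).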
        distances, indices = self.index.search(
            np.array([query_embedding]),
            min(k, self.index.ntotal)
        )

        results = []
        for i, (dist, idx) in enumerate(zip(distances[0], indices[0])):
            if idx < 0:
                continue

            metadata = self.metadata[idx]

            if filter_books and metadata["book"] not in filter_books:
                continue

            result = {
                "text": self.chunks[idx],
                "book": metadata["book"],
                "score": float(dist),
                "rank": i + 1,
                "metadata": metadata
            }
            results.append(result)

        # Higher scores are better only for the cosine/inner-product metric;
        # for L2 smaller distances are better, so sort accordingly.
        results.sort(key=lambda x: x["score"], reverse=(self.config.vector_store.metric == "cosine"))

        return results[:k]

    def search_with_context(self, query: str, k: Optional[int] = None, context_window: int = 1) -> List[Dict]:
        """Search and include surrounding context chunks."""
        results = self.search(query, k)
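
        # Each hit is expanded with up to ``context_window`` neighbouring chunks on
        # either side of it (located via _find_chunk_index by book and chunk_id);
        # e.g. with context_window=1 a hit is wrapped by the chunks directly
        # before and after it, when they exist.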
        expanded_results = []
        for result in results:
            chunk_idx = result["metadata"]["chunk_id"]
            book = result["book"]

            context_chunks = []

            # Chunks before the hit, in document order.
            for i in range(context_window, 0, -1):
                prev_idx = self._find_chunk_index(book, chunk_idx - i)
                if prev_idx is not None:
                    context_chunks.append(self.chunks[prev_idx])

            context_chunks.append(result["text"])

            # Chunks after the hit.
            for i in range(1, context_window + 1):
                next_idx = self._find_chunk_index(book, chunk_idx + i)
                if next_idx is not None:
                    context_chunks.append(self.chunks[next_idx])

            expanded_result = result.copy()
            expanded_result["context"] = "\n\n".join(context_chunks)
            expanded_result["context_size"] = len(context_chunks)
            expanded_results.append(expanded_result)

        return expanded_results

    def _find_chunk_index(self, book: str, chunk_id: int) -> Optional[int]:
        """Find the position of a specific chunk in the metadata/chunk lists."""
        for i, metadata in enumerate(self.metadata):
            if metadata["book"] == book and metadata["chunk_id"] == chunk_id:
                return i
        return None

    def add_text(self, text: str, source: str, metadata: Optional[Dict] = None):
        """Add a single text to the knowledge base."""
        chunks = self._create_chunks(text)

        for i, chunk in enumerate(chunks):
            if not chunk.strip():
                continue

            embedding = self._create_embedding(chunk)

            if self.config.vector_store.metric == "cosine":
                embedding = embedding / np.linalg.norm(embedding)

            self.index.add(np.array([embedding]))

            chunk_metadata = {
                "book": source,
                "chunk_id": i,
                "timestamp": datetime.now().isoformat(),
                "char_count": len(chunk),
                "checksum": hashlib.md5(chunk.encode()).hexdigest()
            }

            if metadata:
                chunk_metadata.update(metadata)

            self.metadata.append(chunk_metadata)
            self.chunks.append(chunk)

        self._save_index()

    def _save_index(self):
        """Save the index, metadata, chunks, and a config summary to disk."""
        os.makedirs(self.index_path, exist_ok=True)

        index_file = os.path.join(self.index_path, "knowledge.index")
        faiss.write_index(self.index, index_file)

        metadata_file = os.path.join(self.index_path, "metadata.pkl")
        with open(metadata_file, 'wb') as f:
            pickle.dump(self.metadata, f)

        chunks_file = os.path.join(self.index_path, "chunks.pkl")
        with open(chunks_file, 'wb') as f:
            pickle.dump(self.chunks, f)

        config_file = os.path.join(self.index_path, "config.json")
        with open(config_file, 'w') as f:
            json.dump({
                "dimension": self.dimension,
                "metric": self.config.vector_store.metric,
                "total_chunks": len(self.chunks),
                "books": list(set(m["book"] for m in self.metadata)),
                "last_updated": datetime.now().isoformat()
            }, f, indent=2)

    def get_stats(self) -> Dict:
        """Get statistics about the knowledge base."""
        if not self.metadata:
            return {"status": "empty"}

        books = {}
        for metadata in self.metadata:
            book = metadata["book"]
            if book not in books:
                books[book] = {"chunks": 0, "chars": 0}
            books[book]["chunks"] += 1
            books[book]["chars"] += metadata["char_count"]

        return {
            "total_chunks": len(self.chunks),
            "total_books": len(books),
            "books": books,
            "index_size": self.index.ntotal,
            "dimension": self.dimension,
            "metric": self.config.vector_store.metric
        }

    def clear(self):
        """Clear the entire knowledge base."""
        self.index = faiss.IndexFlatIP(self.dimension) if self.config.vector_store.metric == "cosine" else faiss.IndexFlatL2(self.dimension)
        self.metadata = []
        self.chunks = []
        self._save_index()
        print("Knowledge base cleared")