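"""FAISS-backed vector store for PDF document chunks.

Text is embedded with a sentence-transformers model and indexed with
IndexFlatIP over L2-normalized vectors, so inner-product scores are
cosine similarities. The index and its metadata are persisted to disk
so the store survives restarts.
"""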
import faiss
import numpy as np
import pickle
import os
from typing import List, Dict, Any
from sentence_transformers import SentenceTransformer
from pdf_processor import DocumentChunk

class VectorStore:
    def __init__(self, model_name: str, vector_db_path: str):
        self.model = SentenceTransformer(model_name)
        self.vector_db_path = vector_db_path
        os.makedirs(vector_db_path, exist_ok=True)  # save_index() needs the directory to exist
        self.index_path = os.path.join(vector_db_path, 'faiss_index.bin')
        self.metadata_path = os.path.join(vector_db_path, 'metadata.pkl')

        self.index = None
        self.metadata = []
        self.load_index()

    def load_index(self):
        """Load existing FAISS index and metadata."""
        try:
            if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
                self.index = faiss.read_index(self.index_path)
                with open(self.metadata_path, 'rb') as f:
                    self.metadata = pickle.load(f)

                print(f"Loaded existing index with {len(self.metadata)} documents")
            else:
                print("No existing index found. Will create new one.")
        except Exception as e:
            print(f"Error loading index: {e}")
            self.index = None
            self.metadata = []

    def add_documents(self, chunks: List[DocumentChunk]):
        """Add document chunks to the vector store."""
        if not chunks:
            return

        texts = [chunk.content for chunk in chunks]
        embeddings = self.model.encode(texts, convert_to_tensor=False)
        embeddings = np.array(embeddings).astype('float32')

        # Normalize every batch, not only the first: IndexFlatIP computes raw
        # inner products, so scores are cosine similarities only if all stored
        # vectors are L2-normalized.
        faiss.normalize_L2(embeddings)

        if self.index is None:
            dimension = embeddings.shape[1]
            self.index = faiss.IndexFlatIP(dimension)

        self.index.add(embeddings)  # type: ignore

        for chunk in chunks:
            self.metadata.append({
                'content': chunk.content,
                'metadata': chunk.metadata,
                'page_number': chunk.page_number,
                'source_file': chunk.source_file
            })

        self.save_index()
        print(f"Added {len(chunks)} chunks to vector store")

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar documents."""
        if self.index is None or len(self.metadata) == 0:
            return []

        query_embedding = self.model.encode([query], convert_to_tensor=False)
        query_embedding = np.array(query_embedding).astype('float32')
        faiss.normalize_L2(query_embedding)
        scores, indices = self.index.search(query_embedding, min(top_k, len(self.metadata)))  # type: ignore

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx != -1:  # FAISS pads with -1 when it returns fewer than k hits
                result = self.metadata[idx].copy()
                result['similarity_score'] = float(score)
                results.append(result)

        return results

    def save_index(self):
        """Save FAISS index and metadata to disk."""
        try:
            if self.index is not None:
                faiss.write_index(self.index, self.index_path)

            with open(self.metadata_path, 'wb') as f:
                pickle.dump(self.metadata, f)

        except Exception as e:
            print(f"Error saving index: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the vector store."""
        if self.index is None:
            return {'total_documents': 0, 'index_size': 0}

        return {
            'total_documents': len(self.metadata),
            'index_size': self.index.ntotal,
            'dimension': self.index.d
        }

    def clear_index(self):
        """Clear the entire index."""
        self.index = None
        self.metadata = []
        if os.path.exists(self.index_path):
            os.remove(self.index_path)
        if os.path.exists(self.metadata_path):
            os.remove(self.metadata_path)

        print("Index cleared successfully")