File size: 1,936 Bytes
69374eb
 
 
34a9313
69374eb
34a9313
69374eb
db7ceef
 
 
69374eb
d5a33e6
69374eb
34a9313
db7ceef
 
34a9313
 
 
 
 
 
 
 
 
 
 
 
 
 
d5a33e6
34a9313
 
 
 
 
 
 
 
 
db7ceef
 
34a9313
 
 
 
 
69374eb
34a9313
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import hashlib

# Shared sentence-embedding model, loaded once at import time and used by
# VectorStore for both indexing (add_texts) and querying (retrieve).
# NOTE(review): constructing SentenceTransformer by model name presumably
# fetches the weights from the model hub when they are not cached locally —
# confirm network/cache availability wherever this module is imported.
embedder = SentenceTransformer('all-MiniLM-L6-v2')

class VectorStore:
    """In-memory semantic store: texts, their embeddings, and a FAISS L2 index.

    Texts are embedded with the module-level ``embedder`` and indexed in an
    exact ``faiss.IndexFlatL2``. Exact duplicates are skipped; they are
    detected by the MD5 hex digest of each text's encoded bytes.

    Attributes:
        texts: Stored texts in insertion order. ``texts[i]`` corresponds to
            row ``i`` of ``index`` and to ``embeddings[i]``.
        embeddings: One embedding vector per entry of ``texts``.
        index: FAISS index over ``embeddings``. It is ``None`` until the first
            successful ``add_texts`` call, and again after ``clear``.
        text_hashes: MD5 hex digests of every stored text, used for
            de-duplication.
    """

    def __init__(self):
        self.texts = []
        self.embeddings = []
        self.index = None
        self.text_hashes = set()

    def add_texts(self, texts):
        """Embed and index *texts*, skipping any already in the store.

        Duplicates are detected against previously stored texts and within
        *texts* itself; only the first occurrence of each text is kept.
        Store state changes only after embedding and indexing succeed. If
        either step raises, the store is unchanged and the same texts can be
        added again later.

        Args:
            texts: Iterable of strings to add.
        """
        new_texts = []
        batch_hashes = set()
        for text in texts:
            text_hash = hashlib.md5(text.encode()).hexdigest()
            if text_hash in self.text_hashes or text_hash in batch_hashes:
                continue
            batch_hashes.add(text_hash)
            new_texts.append(text)

        if not new_texts:
            return

        # Embed before touching any state. This way a failure cannot leave
        # hashes recorded for texts that were never actually stored.
        new_embeds = embedder.encode(new_texts)
        embeds_array = np.asarray(new_embeds, dtype='float32')

        if self.index is None:
            self.index = faiss.IndexFlatL2(embeds_array.shape[1])
        # IndexFlatL2 appends rows in order. Adding only the new vectors keeps
        # index row i aligned with self.texts[i] without an O(N) rebuild on
        # every call.
        self.index.add(embeds_array)

        self.texts.extend(new_texts)
        self.embeddings.extend(new_embeds)
        self.text_hashes.update(batch_hashes)

    def retrieve(self, query, top_k=3):
        """Return the stored texts closest to *query* and their store indices.

        Args:
            query: Text to search for.
            top_k: Maximum number of results. It is capped at the number of
                stored texts.

        Returns:
            A tuple ``(texts, indices)`` of two parallel lists, in the order
            FAISS returns the neighbours (nearest first by L2 distance). Both
            lists are empty when the store is empty or ``top_k`` is not
            positive.
        """
        if self.index is None or not self.texts or top_k <= 0:
            return [], []

        query_array = np.asarray(embedder.encode([query]), dtype='float32')
        k = min(top_k, len(self.texts))
        _distances, indices = self.index.search(query_array, k=k)

        # FAISS pads missing neighbours with -1. Drop them rather than letting
        # them wrap around to self.texts[-1].
        hits = [int(i) for i in indices[0] if i >= 0]
        return [self.texts[i] for i in hits], hits

    def clear(self):
        """Remove every text, embedding, hash and the index from the store."""
        self.texts = []
        self.embeddings = []
        self.index = None
        self.text_hashes = set()