# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import pickle
from typing import List, Tuple

import faiss
import numpy as np
from tqdm import tqdm


class Indexer(object):
    def __init__(self, vector_sz, device='cpu'):
        # Exact inner-product index; equivalent to cosine similarity if
        # the stored vectors are L2-normalized.
        self.index = faiss.IndexFlatIP(vector_sz)
        self.device = device
        if self.device == 'cuda':
            # Replicate the index across all visible GPUs.
            self.index = faiss.index_cpu_to_all_gpus(self.index)
        # Maps contiguous faiss row ids back to external database ids.
        self.index_id_to_db_id = []

    def index_data(self, ids, embeddings):
        self._update_id_mapping(ids)
        embeddings = embeddings.astype('float32')
        # A flat index is always trained; the guard keeps this method
        # correct if a trainable index type is swapped in.
        if not self.index.is_trained:
            self.index.train(embeddings)
        self.index.add(embeddings)
        print(f'Total data indexed {self.index.ntotal}')

    def search_knn(self, query_vectors: np.ndarray, top_docs: int,
                   index_batch_size: int = 8) -> List[Tuple[List[object], List[float]]]:
        query_vectors = query_vectors.astype('float32')
        result = []
        # Search in batches to bound peak memory on large query sets.
        nbatch = (len(query_vectors) - 1) // index_batch_size + 1
        for k in tqdm(range(nbatch)):
            start_idx = k * index_batch_size
            end_idx = min((k + 1) * index_batch_size, len(query_vectors))
            q = query_vectors[start_idx:end_idx]
            scores, indexes = self.index.search(q, top_docs)
            # Convert internal faiss row ids to external database ids.
            db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs]
                      for query_top_idxs in indexes]
            result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
        return result
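
    # Sketch of the return structure (illustration, not from the original file):
    #   search_knn(queries, top_docs=2) ->
    #       [(['12', '7'], array([0.93, 0.88], dtype=float32)), ...]
    # i.e. one (db_ids, scores) tuple per query, each of length top_docs.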

    def serialize(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Serializing index to {index_file}, meta data to {meta_file}')
        if self.device == 'cuda':
            # faiss can only write CPU indexes to disk.
            save_index = faiss.index_gpu_to_cpu(self.index)
        else:
            save_index = self.index
        faiss.write_index(save_index, index_file)
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Loading index from {index_file}, meta data from {meta_file}')
        self.index = faiss.read_index(index_file)
        if self.device == 'cuda':
            self.index = faiss.index_cpu_to_all_gpus(self.index)
        print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')
        with open(meta_file, 'rb') as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(self.index_id_to_db_id) == self.index.ntotal, \
            'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        self.index_id_to_db_id.extend(db_ids)

    def reset(self):
        # Drop all vectors and the id mapping, keeping the index configuration.
        self.index.reset()
        self.index_id_to_db_id = []
        print(f'Index reset, total data indexed {self.index.ntotal}')
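

# A minimal usage sketch (not part of the original file; the ids, the
# 768-dim size, and the /tmp path are assumptions for illustration):
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    dim = 768
    indexer = Indexer(dim)

    # Index 1000 random passage embeddings under string ids '0'..'999'.
    passage_ids = [str(i) for i in range(1000)]
    passage_embs = rng.standard_normal((1000, dim)).astype('float32')
    indexer.index_data(passage_ids, passage_embs)

    # Retrieve the 5 nearest passages for 3 random queries.
    queries = rng.standard_normal((3, dim)).astype('float32')
    for db_ids, scores in indexer.search_knn(queries, top_docs=5):
        print(db_ids, scores)

    # Round-trip the index and id mapping through disk.
    os.makedirs('/tmp/index_demo', exist_ok=True)
    indexer.serialize('/tmp/index_demo')
    indexer.reset()
    indexer.deserialize_from('/tmp/index_demo')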