#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Embedding-backed document store with cosine-similarity retrieval."""
import json
import os
import pickle
import time
from collections import defaultdict
from typing import Dict, List, Tuple, Union

import torch
import yaml
from openai import OpenAI
from PIL import Image
from sentence_transformers import SentenceTransformer


class MemoryIndex:
    """Keep documents next to their embedding vectors and rank them by similarity.

    ``documents`` maps a document id to its content and ``document_vectors``
    maps the same id to the embedding used for ranking. Embeddings come either
    from a local SentenceTransformer model or, when ``use_openai`` is true,
    from the OpenAI embeddings endpoint.
    """

    # Model requested from the OpenAI embeddings endpoint.
    OPENAI_EMBEDDING_MODEL = "text-embedding-3-small"
    # Seconds to wait before retrying a failed OpenAI embedding request.
    RETRY_DELAY_SECONDS = 5

    def __init__(self, number_of_neighbours, use_openai=False,
                 config_path='test_configs/llama2_test_config.yaml'):
        """Create an empty index.

        Args:
            number_of_neighbours: How many documents ``search_by_similarity``
                returns; ``-1`` means "return every document". Converted with
                ``int()``.
            use_openai: When true, embed text through the OpenAI API using the
                key found in the ``OPENAI_API_KEY`` environment variable.
            config_path: YAML file whose ``model.minigpt4_gpu_id`` entry selects
                the CUDA device that local embeddings are moved to.
        """
        self.documents = {}
        self.document_vectors = {}
        self.use_openai = use_openai
        if use_openai:
            api_key = os.getenv("OPENAI_API_KEY")
            self.client = OpenAI(api_key=api_key)
        # The local model is loaded even in OpenAI mode, as in the original code.
        self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
        # self.model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
        with open(config_path) as file:
            # NOTE(review): FullLoader is kept for compatibility; prefer
            # yaml.safe_load if the config never needs Python-specific tags.
            config = yaml.load(file, Loader=yaml.FullLoader)
        embedding_gpu_id = config['model']['minigpt4_gpu_id']
        self.device = f"cuda:{embedding_gpu_id}" if torch.cuda.is_available() else "cpu"
        self.number_of_neighbours = int(number_of_neighbours)

    def load_documents_from_json(self, file_path, emdedding_path=""):
        """Load documents from a JSON object, embed each one, and optionally cache.

        Args:
            file_path: JSON file containing a mapping of document id to content.
            emdedding_path: Where to pickle ``[documents, document_vectors]``.
                When empty (the default) nothing is written; previously an
                empty path caused ``open("")`` to fail after all the embedding
                work had already been done.

        Returns:
            ``emdedding_path`` exactly as it was passed in.
        """
        with open(file_path, 'r') as file:
            data = json.load(file)
        for doc_id, doc_data in data.items():
            self.documents[doc_id] = doc_data
            self.document_vectors[doc_id] = self._compute_sentence_embedding(doc_data)
        if emdedding_path:
            # Persist both mappings so they can be restored with
            # load_embeddings_from_pkl without recomputing the embeddings.
            with open(emdedding_path, 'wb') as file:
                pickle.dump([self.documents, self.document_vectors], file)
        return emdedding_path

    def load_embeddings_from_pkl(self, pkl_file_path):
        """Restore ``documents`` and ``document_vectors`` from a pickle file.

        The file must hold a two-item sequence ``[documents, document_vectors]``
        as written by ``load_documents_from_json``.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load files
        # that this project produced itself.
        with open(pkl_file_path, 'rb') as file:
            data = pickle.load(file)
        self.documents = data[0]
        self.document_vectors = data[1]

    def load_data_from_pkl(self, pkl_file_path):
        """Merge a pickled ``{doc_id: value}`` mapping into the index.

        Every value is stored as both the document and its vector.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load files
        # that this project produced itself.
        with open(pkl_file_path, 'rb') as file:
            data = pickle.load(file)
        for doc_id, doc_data in data.items():
            self.documents[doc_id] = doc_data
            # NOTE(review): the same object is used as the vector; presumably
            # the pickle already contains embeddings -- confirm with the caller.
            self.document_vectors[doc_id] = doc_data

    def _compute_sentence_embedding(self, text: str) -> torch.Tensor:
        """Return the embedding of ``text`` as a tensor.

        In local mode the SentenceTransformer output is moved to ``self.device``.
        In OpenAI mode the request is retried every ``RETRY_DELAY_SECONDS``
        until it succeeds, and the result is a tensor built from the returned
        list (it is not moved to ``self.device``).
        """
        if not self.use_openai:
            return self.model.encode(text, convert_to_tensor=True).to(self.device)
        # NOTE(review): the retry is unbounded, so a permanent failure (for
        # example an invalid API key) loops until the process is interrupted.
        while True:
            try:
                response = self.client.embeddings.create(
                    input=[text], model=self.OPENAI_EMBEDDING_MODEL)
                # Convert the list to a PyTorch tensor
                return torch.tensor(response.data[0].embedding)
            except Exception as e:
                print("error", e)
                print("text", text)
                # sleep and try again
                time.sleep(self.RETRY_DELAY_SECONDS)

    def search_by_similarity(self, query: str) -> Tuple[List, List]:
        """Rank the stored documents by cosine similarity to ``query``.

        Returns:
            A ``(documents, doc_ids)`` pair of parallel lists:

            * every document in insertion order when ``number_of_neighbours``
              is ``-1`` (no ranking is performed);
            * every document, best match first, when ``number_of_neighbours``
              exceeds the number of stored vectors;
            * otherwise the best ``number_of_neighbours`` matches, except that
              a single-neighbour search whose best match has the id
              ``'summary'`` also returns the runner-up.
        """
        # Returning everything needs no ranking, so skip the query embedding
        # (an API call in OpenAI mode) and the per-document scoring.
        if self.number_of_neighbours == -1:
            return list(self.documents.values()), list(self.documents.keys())

        query_vector = self._compute_sentence_embedding(query)
        scores = {
            doc_id: torch.nn.functional.cosine_similarity(query_vector, doc_vector, dim=0).item()
            for doc_id, doc_vector in self.document_vectors.items()
        }
        sorted_doc_ids = sorted(scores, key=scores.get, reverse=True)
        sorted_documents = [self.documents[doc_id] for doc_id in sorted_doc_ids]

        if self.number_of_neighbours > len(sorted_documents):
            return sorted_documents, sorted_doc_ids
        # If the retrieved document is the summary, return the summary and the
        # next document to guarantee that a clip name is always retrieved.
        if self.number_of_neighbours == 1 and sorted_doc_ids[0] == 'summary':
            return sorted_documents[0:2], sorted_doc_ids[:2]
        print("Number of neighbours", self.number_of_neighbours)
        return (sorted_documents[:self.number_of_neighbours],
                sorted_doc_ids[:self.number_of_neighbours])


# # main function
# if __name__ == "__main__":
#     memory_index = MemoryIndex(-1,use_openai=True)
#     memory_index.load_documents_from_json('workspace/results/llama_vid/tt0035423.json')
#     print(memory_index.documents.keys())
#     docs,keys=memory_index.search_by_similarity('kerolos')