from sentence_transformers import SentenceTransformer | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
class ModelSingleton:
    """Hold one shared set of Hugging Face models for the whole process.

    The first time the class is instantiated it loads:

    * a sentence-transformers model (``similarity_model``) and a tokenizer
      loaded separately from the same checkpoint (``similarity_tokenizer``);
    * a Flan-T5 seq2seq model (``flan_model``) and its tokenizer
      (``flan_tokenizer``).

    ``ModelSingleton()`` always returns the same object, and the loading runs
    only once; later constructions are no-ops. ``_initialized`` is set only
    after every load succeeds, so if loading raises, the next construction
    retries.

    The checkpoint IDs are class attributes, so they can be inspected
    (``ModelSingleton.FLAN_MODEL``) or reassigned before the first
    instantiation.

    NOTE(review): there is no locking, so two threads constructing the first
    instance at the same time could both run the loading code. Confirm that
    first construction happens on a single thread.
    """

    # Hub checkpoint used for sentence embeddings / similarity.
    SENTENCE_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    # Hub checkpoint used for seq2seq generation (Flan-T5-xl only).
    FLAN_MODEL = "google/flan-t5-xl"

    # The one shared instance; created on demand by __new__.
    _instance = None
    # Guards __init__ so the expensive model loading runs at most once.
    _initialized = False

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Load all models on the first call; do nothing on later calls.

        Python invokes ``__init__`` every time ``ModelSingleton()`` is
        evaluated, even when ``__new__`` hands back the existing object,
        hence the guard.
        """
        if self._initialized:
            return

        # Sentence-transformer model plus a tokenizer for the same checkpoint.
        # NOTE(review): SentenceTransformer exposes its own tokenizer
        # (``similarity_model.tokenizer``), so this second load may be
        # redundant. Confirm callers don't rely on a separate object before
        # removing it.
        self.similarity_tokenizer = AutoTokenizer.from_pretrained(
            self.SENTENCE_MODEL
        )
        self.similarity_model = SentenceTransformer(self.SENTENCE_MODEL)

        # Flan-T5 generation model and its tokenizer.
        self.flan_tokenizer = AutoTokenizer.from_pretrained(self.FLAN_MODEL)
        self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(self.FLAN_MODEL)

        # Mark as initialized only after every load above succeeded, so a
        # failed load is retried on the next construction.
        self._initialized = True
# Module-level shared instance. Evaluating this line runs
# ModelSingleton.__init__, which loads every model it lists, so importing this
# module triggers that loading. Other modules import ``models`` and use its
# attributes (e.g. ``models.flan_model``) directly.
models: ModelSingleton = ModelSingleton()