|
from sentence_transformers import SentenceTransformer |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torch |
|
import os |
|
import tempfile |
|
import logging |
|
|
|
|
|
# Configure root logging so model-download/load progress is visible.
# NOTE(review): basicConfig at import time reconfigures logging for the whole
# process — confirm this module is meant to own global logging setup.
logging.basicConfig(level=logging.INFO)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
class ModelSingleton:
    """Process-wide singleton that loads and owns the NLP models.

    Holds a MiniLM sentence-embedding model (``similarity_model`` plus a
    standalone ``similarity_tokenizer``) and a Flan-T5 XL seq2seq model
    (``flan_model`` / ``flan_tokenizer``), all moved to ``device`` ("cuda"
    when available, otherwise "cpu"). Weights are cached in the directory
    named by the TRANSFORMERS_CACHE environment variable, defaulting to
    the system temp directory.
    """

    _instance = None      # the one shared instance
    _initialized = False  # True once the expensive model loading has run

    def __new__(cls):
        # Classic singleton: every construction yields the same object.
        # NOTE(review): not thread-safe — two threads racing through the
        # first construction could each create an instance; add a lock if
        # this is built from multiple threads.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on *every* ModelSingleton() call; the flag makes
        # the model loading below happen only once.
        if self._initialized:
            return
        try:
            # Honor a pre-set TRANSFORMERS_CACHE, fall back to the temp
            # dir, and export the result so the HF libraries use it too.
            cache_dir = os.getenv('TRANSFORMERS_CACHE', tempfile.gettempdir())
            os.environ['TRANSFORMERS_CACHE'] = cache_dir

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info("Using device: %s", self.device)

            self._load_similarity_model(cache_dir)
            self._load_flan_model(cache_dir)

            # Set on the class explicitly; a partially-failed run leaves it
            # False so the next construction retries loading.
            type(self)._initialized = True
            logger.info("All models initialized successfully")

        except Exception:
            # logger.exception records the traceback before re-raising.
            logger.exception("Error during model initialization")
            raise

    def _load_similarity_model(self, cache_dir):
        """Load the MiniLM embedding model and tokenizer onto ``device``."""
        try:
            logger.info("Loading sentence transformer model...")
            SENTENCE_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
            # NOTE(review): SentenceTransformer bundles its own tokenizer;
            # this standalone tokenizer is kept for interface compatibility
            # but looks redundant — confirm callers actually use it.
            self.similarity_tokenizer = AutoTokenizer.from_pretrained(
                SENTENCE_MODEL,
                cache_dir=cache_dir
            )
            self.similarity_model = SentenceTransformer(
                SENTENCE_MODEL,
                cache_folder=cache_dir
            )
            self.similarity_model.to(self.device)
            logger.info("Sentence transformer model loaded successfully")
        except Exception as e:
            logger.error("Error loading sentence transformer model: %s", e)
            raise

    def _load_flan_model(self, cache_dir):
        """Load Flan-T5 XL (fp16 on GPU, fp32 on CPU) onto ``device``."""
        try:
            logger.info("Loading Flan-T5 model...")
            FLAN_MODEL = "google/flan-t5-xl"
            self.flan_tokenizer = AutoTokenizer.from_pretrained(
                FLAN_MODEL,
                cache_dir=cache_dir
            )
            self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(
                FLAN_MODEL,
                cache_dir=cache_dir,
                # Half precision halves GPU memory; CPU needs full fp32.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )
            self.flan_model.to(self.device)
            logger.info("Flan-T5 model loaded successfully")
        except Exception as e:
            logger.error("Error loading Flan-T5 model: %s", e)
            raise

    def cleanup(self):
        """Drop model references and release cached CUDA memory.

        Best-effort: failures are logged, never raised.
        NOTE(review): does not reset ``_initialized``/``_instance``, so the
        models cannot be reloaded after cleanup — confirm that is intended.
        """
        try:
            if hasattr(self, 'similarity_model'):
                del self.similarity_model
            if hasattr(self, 'flan_model'):
                del self.flan_model
            # Guarded: empty_cache is only meaningful when CUDA exists.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Model resources cleaned up successfully")
        except Exception as e:
            logger.error("Error during cleanup: %s", e)
|
|
|
|
|
# Eagerly construct the singleton at import time, so models are downloaded
# and loaded before first use. NOTE(review): this makes importing the module
# slow and able to raise — confirm eager (vs. lazy) loading is intended.
models = ModelSingleton()
|
|
|
|
|
def cleanup_models():
    """Release the module-level ModelSingleton's model resources."""
    models.cleanup()