from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import tempfile
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelSingleton:
    """Lazily loads heavyweight models once and shares them process-wide,
    freeing resources via reference counting when no caller needs them."""

    _instance = None
    _initialized = False
    _models = {}
    _reference_counts = {}

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not self._initialized:
            try:
                # Default the cache directory to a temporary directory
                cache_dir = os.getenv('TRANSFORMERS_CACHE', tempfile.gettempdir())
                os.environ['TRANSFORMERS_CACHE'] = cache_dir

                # Get device
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Using device: {self.device}")

                # Models are loaded lazily; start with None values
                self.similarity_tokenizer = None
                self.similarity_model = None
                self.flan_tokenizer = None
                self.flan_model = None

                # Initialize reference counts
                self._reference_counts['similarity'] = 0
                self._reference_counts['flan'] = 0

                self._initialized = True
                logger.info("Model singleton initialized")
            except Exception as e:
                logger.error(f"Error during model initialization: {e}")
                raise

    def get_similarity_model(self):
        """Get the sentence transformer model, incrementing its reference count"""
        try:
            if self.similarity_model is None:
                logger.info("Loading sentence transformer model...")
                SENTENCE_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
                self.similarity_tokenizer = AutoTokenizer.from_pretrained(
                    SENTENCE_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                self.similarity_model = SentenceTransformer(
                    SENTENCE_MODEL,
                    cache_folder=os.getenv('TRANSFORMERS_CACHE')
                )
                self.similarity_model.to(self.device)
                logger.info("Sentence transformer model loaded successfully")
            self._reference_counts['similarity'] += 1
            return self.similarity_model
        except Exception as e:
            logger.error(f"Error loading sentence transformer model: {e}")
            raise

    def get_flan_model(self):
        """Get the Flan-T5 model, incrementing its reference count"""
        try:
            if self.flan_model is None:
                logger.info("Loading Flan-T5 model...")
                FLAN_MODEL = "google/flan-t5-xl"
                self.flan_tokenizer = AutoTokenizer.from_pretrained(
                    FLAN_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(
                    FLAN_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE'),
                    # Half precision on GPU halves memory; full precision on CPU
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    low_cpu_mem_usage=True
                )
                self.flan_model.to(self.device)
                logger.info("Flan-T5 model loaded successfully")
            self._reference_counts['flan'] += 1
            return self.flan_model
        except Exception as e:
            logger.error(f"Error loading Flan-T5 model: {e}")
            raise

    def release_similarity_model(self):
        """Release one reference to the similarity model"""
        self._reference_counts['similarity'] -= 1
        if self._reference_counts['similarity'] <= 0:
            self._cleanup_similarity_model()

    def release_flan_model(self):
        """Release one reference to the Flan-T5 model"""
        self._reference_counts['flan'] -= 1
        if self._reference_counts['flan'] <= 0:
            self._cleanup_flan_model()

    def _cleanup_similarity_model(self):
        """Clean up similarity model resources"""
        if self.similarity_model is not None:
            del self.similarity_model
            self.similarity_model = None
            self.similarity_tokenizer = None
            torch.cuda.empty_cache()
            logger.info("Similarity model resources cleaned up")
        # Reset the count so stray extra releases cannot drive it negative
        self._reference_counts['similarity'] = 0

    def _cleanup_flan_model(self):
        """Clean up Flan-T5 model resources"""
        if self.flan_model is not None:
            del self.flan_model
            self.flan_model = None
            self.flan_tokenizer = None
            torch.cuda.empty_cache()
            logger.info("Flan-T5 model resources cleaned up")
        self._reference_counts['flan'] = 0

    def cleanup(self):
        """Clean up all model resources"""
        try:
            self._cleanup_similarity_model()
            self._cleanup_flan_model()
            logger.info("All model resources cleaned up successfully")
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")


# Create global instance
models = ModelSingleton()


# Module-level convenience wrapper around the global instance's cleanup
def cleanup_models():
    models.cleanup()
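
# --- Usage sketch (illustrative only; the example sentence and variable
# names below are hypothetical, not part of the module's API) ---
# Acquire a model through the singleton, use it, then release it so the
# reference counting above can free GPU memory once the count hits zero.
if __name__ == "__main__":
    model = models.get_similarity_model()
    try:
        # SentenceTransformer.encode() returns a numpy array of
        # shape (num_sentences, embedding_dim)
        embeddings = model.encode(["reference counting keeps models resident"])
        logger.info(f"Embedding shape: {embeddings.shape}")
    finally:
        models.release_similarity_model()
    cleanup_models()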