"""Shared Hugging Face models: a sentence-similarity encoder and Flan-T5.

Importing this module constructs the process-wide ``models`` singleton, which
loads both models immediately. Call ``cleanup_models()`` to release them; a
subsequent ``ModelSingleton()`` call loads them again.
"""

import gc
import logging
import os
import tempfile

import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelSingleton:
    """Process-wide holder for the sentence-similarity and Flan-T5 models.

    Every ``ModelSingleton()`` call returns the same instance. Models are
    loaded by the first ``__init__`` that completes successfully; a failed
    attempt leaves the instance uninitialized so the next call retries.
    After :meth:`cleanup`, the next ``ModelSingleton()`` call reloads.

    Attributes (set once initialized):
        device: ``"cuda"`` if ``torch.cuda.is_available()``, else ``"cpu"``.
        similarity_tokenizer: tokenizer for ``SENTENCE_MODEL``.
        similarity_model: ``SentenceTransformer`` for ``SENTENCE_MODEL``,
            moved to ``device``.
        flan_tokenizer: tokenizer for ``FLAN_MODEL``.
        flan_model: seq2seq model for ``FLAN_MODEL`` on ``device``
            (float16 on CUDA, float32 on CPU).
    """

    SENTENCE_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    FLAN_MODEL = "google/flan-t5-xl"

    # Attributes released by cleanup().
    _MODEL_ATTRS = ("similarity_model", "flan_model")

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        try:
            # Honour an existing TRANSFORMERS_CACHE, otherwise fall back to the
            # system temp directory, and export the choice to the environment.
            cache_dir = os.getenv('TRANSFORMERS_CACHE', tempfile.gettempdir())
            os.environ['TRANSFORMERS_CACHE'] = cache_dir

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            self._load_similarity_model(cache_dir)
            self._load_flan_model(cache_dir)

            # Mark initialized only after both loads succeed, so a failed
            # attempt is retried on the next ModelSingleton() call.
            self._initialized = True
            logger.info("All models initialized successfully")
        except Exception as e:
            logger.error(f"Error during model initialization: {e}")
            raise

    def _load_similarity_model(self, cache_dir):
        """Load the sentence-transformer tokenizer and model onto ``self.device``.

        Raises:
            Exception: whatever the Hugging Face loaders raise (logged first).
        """
        try:
            logger.info("Loading sentence transformer model...")
            self.similarity_tokenizer = AutoTokenizer.from_pretrained(
                self.SENTENCE_MODEL, cache_dir=cache_dir
            )
            self.similarity_model = SentenceTransformer(
                self.SENTENCE_MODEL, cache_folder=cache_dir
            )
            self.similarity_model.to(self.device)
            logger.info("Sentence transformer model loaded successfully")
        except Exception as e:
            # logger.exception keeps the traceback; logger.error would drop it.
            logger.exception(f"Error loading sentence transformer model: {e}")
            raise

    def _load_flan_model(self, cache_dir):
        """Load the Flan-T5 tokenizer and model onto ``self.device``.

        Uses float16 weights on CUDA to halve GPU memory, float32 on CPU.

        Raises:
            Exception: whatever the Hugging Face loaders raise (logged first).
        """
        try:
            logger.info("Loading Flan-T5 model...")
            self.flan_tokenizer = AutoTokenizer.from_pretrained(
                self.FLAN_MODEL, cache_dir=cache_dir
            )
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(
                self.FLAN_MODEL,
                cache_dir=cache_dir,
                torch_dtype=dtype,
            )
            self.flan_model.to(self.device)
            logger.info("Flan-T5 model loaded successfully")
        except Exception as e:
            logger.exception(f"Error loading Flan-T5 model: {e}")
            raise

    def cleanup(self):
        """Clean up model resources.

        Drops the similarity and Flan-T5 models, runs the garbage collector,
        and empties the CUDA allocator cache when CUDA is available. Errors
        are logged rather than raised (best-effort).

        The instance is marked uninitialized first. The next
        ``ModelSingleton()`` call therefore reloads the models instead of
        returning an instance whose model attributes no longer exist.
        """
        self._initialized = False
        try:
            for attr in self._MODEL_ATTRS:
                if hasattr(self, attr):
                    delattr(self, attr)
            # Collect first so tensors kept alive only by reference cycles are
            # freed before their cached CUDA blocks are handed back.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Model resources cleaned up successfully")
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")


# Create a global instance
models = ModelSingleton()


def cleanup_models():
    """Release the models held by the global ``models`` instance."""
    models.cleanup()