from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import tempfile
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelSingleton:
    """Process-wide singleton that lazily loads and reference-counts models."""

    _instance = None
    _initialized = False
    _reference_counts = {}

    def __new__(cls):
        # Always return the same instance so each model is loaded at most once.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
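
    # Note: the check-then-create above is not thread-safe. If instances may
    # be constructed from multiple threads, guard __new__ with a class-level
    # threading.Lock.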

    def __init__(self):
        if not self._initialized:
            try:
                # Default the Hugging Face cache to the system temp dir if unset.
                cache_dir = os.getenv('TRANSFORMERS_CACHE', tempfile.gettempdir())
                os.environ['TRANSFORMERS_CACHE'] = cache_dir

                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Using device: {self.device}")

                # Models are loaded lazily by the get_* accessors below.
                self.similarity_tokenizer = None
                self.similarity_model = None
                self.flan_tokenizer = None
                self.flan_model = None

                self._reference_counts['similarity'] = 0
                self._reference_counts['flan'] = 0

                self._initialized = True
                logger.info("Model singleton initialized")
            except Exception as e:
                logger.error(f"Error during model initialization: {e}")
                raise

    def get_similarity_model(self):
        """Get the sentence-transformer model, incrementing its reference count."""
        try:
            if self.similarity_model is None:
                logger.info("Loading sentence transformer model...")
                SENTENCE_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
                # SentenceTransformer bundles its own tokenizer; this one is
                # kept only for callers that need direct tokenizer access.
                self.similarity_tokenizer = AutoTokenizer.from_pretrained(
                    SENTENCE_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                self.similarity_model = SentenceTransformer(
                    SENTENCE_MODEL,
                    cache_folder=os.getenv('TRANSFORMERS_CACHE')
                )
                self.similarity_model.to(self.device)
                logger.info("Sentence transformer model loaded successfully")

            self._reference_counts['similarity'] += 1
            return self.similarity_model
        except Exception as e:
            logger.error(f"Error loading sentence transformer model: {e}")
            raise
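
    # Usage sketch (illustrative; the sample strings are placeholders):
    #   model = models.get_similarity_model()
    #   try:
    #       vecs = model.encode(["query text", "candidate text"])
    #   finally:
    #       models.release_similarity_model()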

    def get_flan_model(self):
        """Get the Flan-T5 model, incrementing its reference count."""
        try:
            if self.flan_model is None:
                logger.info("Loading Flan-T5 model...")
                FLAN_MODEL = "google/flan-t5-xl"
                self.flan_tokenizer = AutoTokenizer.from_pretrained(
                    FLAN_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                # Half precision on GPU roughly halves memory for this ~3B-parameter model.
                self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(
                    FLAN_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE'),
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    low_cpu_mem_usage=True
                )
                self.flan_model.to(self.device)
                logger.info("Flan-T5 model loaded successfully")

            self._reference_counts['flan'] += 1
            return self.flan_model
        except Exception as e:
            logger.error(f"Error loading Flan-T5 model: {e}")
            raise
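
    # Usage sketch (illustrative; the prompt text is a placeholder). Once
    # get_flan_model() has run, the matching tokenizer is available on the
    # singleton as models.flan_tokenizer:
    #   model = models.get_flan_model()
    #   try:
    #       inputs = models.flan_tokenizer("Summarize: ...", return_tensors="pt").to(models.device)
    #       output_ids = model.generate(**inputs, max_new_tokens=64)
    #       text = models.flan_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    #   finally:
    #       models.release_flan_model()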

    def release_similarity_model(self):
        """Release one reference to the similarity model."""
        # Clamp at zero so a spurious extra release cannot drive the count
        # negative and cause premature cleanup after a later get.
        self._reference_counts['similarity'] = max(0, self._reference_counts['similarity'] - 1)
        if self._reference_counts['similarity'] == 0:
            self._cleanup_similarity_model()

    def release_flan_model(self):
        """Release one reference to the Flan-T5 model."""
        self._reference_counts['flan'] = max(0, self._reference_counts['flan'] - 1)
        if self._reference_counts['flan'] == 0:
            self._cleanup_flan_model()

    def _cleanup_similarity_model(self):
        """Free the similarity model and tokenizer."""
        if self.similarity_model is not None:
            # Dropping the references lets Python free the weights; the cache
            # flush then returns freed GPU memory to the driver.
            self.similarity_model = None
            self.similarity_tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Similarity model resources cleaned up")

    def _cleanup_flan_model(self):
        """Free the Flan-T5 model and tokenizer."""
        if self.flan_model is not None:
            self.flan_model = None
            self.flan_tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Flan-T5 model resources cleaned up")

    def cleanup(self):
        """Clean up all model resources and reset both reference counts."""
        try:
            self._cleanup_similarity_model()
            self._cleanup_flan_model()
            self._reference_counts['similarity'] = 0
            self._reference_counts['flan'] = 0
            logger.info("All model resources cleaned up successfully")
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")


# Shared module-level instance.
models = ModelSingleton()


def cleanup_models():
    """Module-level hook to free all models, e.g. at application shutdown."""
    models.cleanup()
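

# --- Usage sketch (an illustration added here, not part of the original
# module). A context manager is one way to guarantee that every
# get_similarity_model() call is paired with a release, even when the
# caller raises; the helper name below is hypothetical.
from contextlib import contextmanager


@contextmanager
def similarity_model_handle():
    """Yield the shared similarity model, releasing the reference on exit."""
    model = models.get_similarity_model()
    try:
        yield model
    finally:
        models.release_similarity_model()


if __name__ == "__main__":
    # Encode two placeholder sentences; all-MiniLM-L6-v2 produces
    # 384-dimensional embeddings.
    with similarity_model_handle() as model:
        embeddings = model.encode(["first sentence", "second sentence"])
        print(embeddings.shape)  # expected: (2, 384)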