# answer-grading-app / all_models.py
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import gc
import os
import tempfile
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelSingleton:
    """Process-wide singleton that lazily loads and reference-counts the
    sentence-transformer and Flan-T5 models so they are shared across the
    app instead of being re-loaded per request."""

    _instance = None
    _initialized = False
    _reference_counts = {}

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        try:
            # Cache downloads in a temp directory unless TRANSFORMERS_CACHE is set
            cache_dir = os.getenv('TRANSFORMERS_CACHE', tempfile.gettempdir())
            os.environ['TRANSFORMERS_CACHE'] = cache_dir

            # Pick the compute device
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            # Models are loaded lazily on first use
            self.similarity_tokenizer = None
            self.similarity_model = None
            self.flan_tokenizer = None
            self.flan_model = None

            # Reference counts track how many callers currently hold each model
            self._reference_counts['similarity'] = 0
            self._reference_counts['flan'] = 0
            self._initialized = True
            logger.info("Model singleton initialized")
        except Exception as e:
            logger.error(f"Error during model initialization: {e}")
            raise

    def get_similarity_model(self):
        """Get sentence transformer model with reference counting"""
        try:
            if self.similarity_model is None:
                logger.info("Loading sentence transformer model...")
                SENTENCE_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
                # Tokenizer is kept for callers that need direct token access;
                # SentenceTransformer bundles its own tokenizer internally.
                self.similarity_tokenizer = AutoTokenizer.from_pretrained(
                    SENTENCE_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                self.similarity_model = SentenceTransformer(
                    SENTENCE_MODEL,
                    cache_folder=os.getenv('TRANSFORMERS_CACHE')
                )
                self.similarity_model.to(self.device)
                logger.info("Sentence transformer model loaded successfully")
            # Every successful get counts as one outstanding reference
            self._reference_counts['similarity'] += 1
            return self.similarity_model
        except Exception as e:
            logger.error(f"Error loading sentence transformer model: {e}")
            raise
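
    # Usage sketch (illustrative only, not executed here): the returned object
    # is a standard SentenceTransformer, so a caller would typically do
    #   emb = models.get_similarity_model().encode(["some text"], convert_to_tensor=True)
    #   ... use emb ...
    #   models.release_similarity_model()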

    def get_flan_model(self):
        """Get Flan-T5 model with reference counting"""
        try:
            if self.flan_model is None:
                logger.info("Loading Flan-T5 model...")
                FLAN_MODEL = "google/flan-t5-xl"
                self.flan_tokenizer = AutoTokenizer.from_pretrained(
                    FLAN_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                # fp16 halves GPU memory; fp32 is kept for CPU inference
                self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(
                    FLAN_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE'),
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    low_cpu_mem_usage=True
                )
                self.flan_model.to(self.device)
                logger.info("Flan-T5 model loaded successfully")
            # Every successful get counts as one outstanding reference
            self._reference_counts['flan'] += 1
            return self.flan_model
        except Exception as e:
            logger.error(f"Error loading Flan-T5 model: {e}")
            raise
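
    # Usage sketch (illustrative only): pair the model with the stored
    # tokenizer, e.g.
    #   model = models.get_flan_model()
    #   inputs = models.flan_tokenizer("Grade this answer: ...", return_tensors="pt").to(models.device)
    #   output_ids = model.generate(**inputs, max_new_tokens=64)
    #   text = models.flan_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    #   models.release_flan_model()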

    def release_similarity_model(self):
        """Release reference to similarity model"""
        # Clamp at zero so a stray double release cannot drive the count negative
        self._reference_counts['similarity'] = max(0, self._reference_counts['similarity'] - 1)
        if self._reference_counts['similarity'] == 0:
            self._cleanup_similarity_model()

    def release_flan_model(self):
        """Release reference to Flan-T5 model"""
        self._reference_counts['flan'] = max(0, self._reference_counts['flan'] - 1)
        if self._reference_counts['flan'] == 0:
            self._cleanup_flan_model()

    def _cleanup_similarity_model(self):
        """Clean up similarity model resources"""
        if self.similarity_model is not None:
            del self.similarity_model
            self.similarity_model = None
            self.similarity_tokenizer = None
            # Force a collection so the freed weights actually leave memory
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Similarity model resources cleaned up")

    def _cleanup_flan_model(self):
        """Clean up Flan-T5 model resources"""
        if self.flan_model is not None:
            del self.flan_model
            self.flan_model = None
            self.flan_tokenizer = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Flan-T5 model resources cleaned up")

    def cleanup(self):
        """Clean up all model resources"""
        try:
            self._cleanup_similarity_model()
            self._cleanup_flan_model()
            self._reference_counts['similarity'] = 0
            self._reference_counts['flan'] = 0
            logger.info("All model resources cleaned up successfully")
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")

# Create the shared global instance
models = ModelSingleton()

def cleanup_models():
    """Convenience wrapper so callers can free all models in one call."""
    models.cleanup()
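
# Minimal usage sketch (assumes network access to download the model on first
# run; the answer strings below are illustrative only). The Flan-T5 model
# follows the same get/release pattern via get_flan_model()/release_flan_model().
if __name__ == "__main__":
    import atexit
    from sentence_transformers import util

    # Ensure models are freed even if the demo exits early
    atexit.register(cleanup_models)

    sim_model = models.get_similarity_model()
    try:
        # Score a student answer against a reference answer via cosine similarity
        embeddings = sim_model.encode(
            ["The mitochondria produces ATP.",
             "Mitochondria generate ATP for the cell."],
            convert_to_tensor=True,
        )
        score = util.cos_sim(embeddings[0], embeddings[1]).item()
        print(f"similarity: {score:.3f}")
    finally:
        models.release_similarity_model()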