# answer-grading-app / all_models.py
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import tempfile
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelSingleton:
    """Singleton that loads the sentence-transformer and Flan-T5 models once and shares them."""
_instance = None
_initialized = False

    def __new__(cls):
        # Create the instance only once; later constructions return the same object
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

    def __init__(self):
        # Guard so repeated constructions don't reload the models
        if not self._initialized:
            try:
                # Cache models under TRANSFORMERS_CACHE if set, otherwise the system temp directory
                cache_dir = os.getenv('TRANSFORMERS_CACHE', tempfile.gettempdir())
                os.environ['TRANSFORMERS_CACHE'] = cache_dir
# Get device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
# Sentence transformer model
try:
logger.info("Loading sentence transformer model...")
SENTENCE_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
                    # Exposed for callers that need direct tokenization; SentenceTransformer
                    # bundles its own tokenizer internally, so this is a convenience handle
                    self.similarity_tokenizer = AutoTokenizer.from_pretrained(
                        SENTENCE_MODEL,
                        cache_dir=cache_dir
                    )
self.similarity_model = SentenceTransformer(
SENTENCE_MODEL,
cache_folder=cache_dir
)
self.similarity_model.to(self.device)
logger.info("Sentence transformer model loaded successfully")
except Exception as e:
logger.error(f"Error loading sentence transformer model: {e}")
raise
# Flan-T5-xl model
try:
logger.info("Loading Flan-T5 model...")
FLAN_MODEL = "google/flan-t5-xl"
self.flan_tokenizer = AutoTokenizer.from_pretrained(
FLAN_MODEL,
cache_dir=cache_dir
)
                    self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(
                        FLAN_MODEL,
                        cache_dir=cache_dir,
                        # fp16 halves GPU memory for this ~3B-parameter model; CPU inference needs fp32
                        torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                    )
self.flan_model.to(self.device)
logger.info("Flan-T5 model loaded successfully")
except Exception as e:
logger.error(f"Error loading Flan-T5 model: {e}")
raise
self._initialized = True
logger.info("All models initialized successfully")
except Exception as e:
logger.error(f"Error during model initialization: {e}")
raise

    def cleanup(self):
        """Release model references and free any cached GPU memory."""
        try:
            if hasattr(self, 'similarity_model'):
                del self.similarity_model
            if hasattr(self, 'flan_model'):
                del self.flan_model
            # empty_cache is a no-op without CUDA, but guard it for clarity
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Model resources cleaned up successfully")
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")

# Create a single global instance shared by the rest of the app
models = ModelSingleton()

# Module-level convenience wrapper around the singleton's cleanup
def cleanup_models():
    models.cleanup()
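
# A minimal usage sketch, not part of the original module: it shows how downstream
# code might exercise the singleton. The example sentences and prompt below are
# illustrative assumptions, and running this downloads several GB of model weights.
if __name__ == "__main__":
    from sentence_transformers import util

    # __new__ guarantees every construction returns the same shared instance
    assert ModelSingleton() is models

    # Semantic similarity between a reference answer and a student answer
    embeddings = models.similarity_model.encode(
        ["Mitochondria produce most of the cell's energy.",
         "The mitochondria is the powerhouse of the cell."],
        convert_to_tensor=True,
    )
    print("cosine similarity:", util.cos_sim(embeddings[0], embeddings[1]).item())

    # Seq2seq generation with Flan-T5
    inputs = models.flan_tokenizer(
        "Is this answer correct? Answer yes or no: The sun rises in the east.",
        return_tensors="pt",
    ).to(models.device)
    output_ids = models.flan_model.generate(**inputs, max_new_tokens=16)
    print(models.flan_tokenizer.decode(output_ids[0], skip_special_tokens=True))

    cleanup_models()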