# answer-grading-app / all_models.py
# Uploaded by yamanavijayavardhan in commit 51c49bc:
# "Initial upload of answer grading application"
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class ModelSingleton:
    """Process-wide holder for the sentence-similarity and Flan-T5 models.

    Every ``ModelSingleton()`` call returns the same object: ``__new__``
    caches the first instance on the class, and ``__init__`` loads the
    models only the first time it runs successfully.

    Attributes available after initialisation:
        similarity_tokenizer: tokenizer loaded for the sentence model.
        similarity_model: ``SentenceTransformer`` for the sentence model.
        flan_tokenizer: tokenizer for ``google/flan-t5-xl``.
        flan_model: ``AutoModelForSeq2SeqLM`` for ``google/flan-t5-xl``.
    """

    # Cached shared instance; None until the first construction.
    _instance = None
    # Class-level default. __init__ shadows it with an instance attribute
    # set to True once every model has been loaded.
    _initialized = False

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def __init__(self):
        """Load all models once; subsequent constructions are no-ops.

        Python invokes ``__init__`` on every ``ModelSingleton()`` call even
        though ``__new__`` returns the cached object, so the flag prevents
        reloading. The flag is set only after every load succeeds, so if a
        load raises, the next construction retries the loading.

        NOTE(review): no locking is used, so concurrent first-time
        construction from several threads could load the models twice.
        """
        if self._initialized:
            return

        # Sentence-similarity model, plus a standalone tokenizer for it.
        # NOTE(review): SentenceTransformer presumably loads its own
        # tokenizer internally, so this separate copy may be redundant —
        # confirm whether callers rely on `similarity_tokenizer`.
        sentence_model_id = "sentence-transformers/all-MiniLM-L6-v2"
        self.similarity_tokenizer = AutoTokenizer.from_pretrained(sentence_model_id)
        self.similarity_model = SentenceTransformer(sentence_model_id)

        # Flan-T5-XL is the only generative (seq2seq) model loaded here.
        flan_model_id = "google/flan-t5-xl"
        self.flan_tokenizer = AutoTokenizer.from_pretrained(flan_model_id)
        self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(flan_model_id)

        # Mark success last so a failed load above is retried next time.
        self._initialized = True
# Shared module-level instance, created when this module is imported.
# Constructing it runs ModelSingleton.__init__, which loads every model, so
# importing this module performs all model loading up front.
models: ModelSingleton = ModelSingleton()