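"""Singleton manager for the ML models used by the answer grading app.

Resolves a writable cache directory (home, system temp, or current working
directory fallbacks), points the Hugging Face and torch caches at it, and
lazily loads each model with simple reference counting so resources can be
released when no longer in use.
"""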
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import tempfile
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def ensure_full_permissions(path):
    """Grant full read/write permissions to a file or directory tree."""
    try:
        if os.path.isdir(path):
            # Make the directory itself world-writable, then walk the tree
            # so nested directories and files are covered as well.
            os.chmod(path, 0o777)
            for root, dirs, files in os.walk(path):
                for d in dirs:
                    os.chmod(os.path.join(root, d), 0o777)
                for f in files:
                    os.chmod(os.path.join(root, f), 0o666)
        else:
            # Plain files need read/write but not execute.
            os.chmod(path, 0o666)
        return True
    except Exception as e:
        logger.error(f"Error setting permissions for {path}: {e}")
        return False


def check_directory_permissions(path):
    """Check that a directory exists and normalize its permissions."""
    try:
        if not os.path.exists(path):
            logger.warning(f"Directory does not exist: {path}")
            return False
        ensure_full_permissions(path)
        return True
    except Exception as e:
        logger.error(f"Error checking permissions for {path}: {e}")
        return False

def get_cache_dir():
    """Get a user-accessible cache directory, trying several locations in order."""
    # First choice: a hidden cache directory under the user's home.
    try:
        home_dir = os.path.expanduser('~')
        if not os.path.exists(home_dir):
            raise FileNotFoundError(f"Home directory does not exist: {home_dir}")

        cache_dir = os.path.join(home_dir, '.cache', 'answer_grading_app')
        logger.info(f"Attempting to use cache directory: {cache_dir}")

        os.makedirs(cache_dir, mode=0o777, exist_ok=True)
        ensure_full_permissions(cache_dir)

        logger.info(f"Successfully created and verified cache directory: {cache_dir}")
        return cache_dir
    except Exception as e:
        logger.warning(f"Could not use home directory cache: {e}")

    # Second choice: a named subdirectory of the system temp directory.
    try:
        temp_dir = os.path.join(tempfile.gettempdir(), 'answer_grading_app')
        logger.info(f"Attempting to use temporary directory: {temp_dir}")

        os.makedirs(temp_dir, mode=0o777, exist_ok=True)
        ensure_full_permissions(temp_dir)

        logger.info(f"Using temporary directory: {temp_dir}")
        return temp_dir
    except Exception as e:
        logger.warning(f"Could not use temp directory: {e}")

    # Third choice: a .cache directory under the current working directory.
    try:
        current_dir = os.path.join(os.getcwd(), '.cache')
        logger.info(f"Attempting to use current directory: {current_dir}")

        os.makedirs(current_dir, mode=0o777, exist_ok=True)
        ensure_full_permissions(current_dir)

        logger.info(f"Using current directory: {current_dir}")
        return current_dir
    except Exception as e:
        logger.error(f"Could not create any cache directory: {e}")

    # Last resort: a throwaway directory created by tempfile.
    temp_dir = tempfile.mkdtemp()
    ensure_full_permissions(temp_dir)
    logger.info(f"Created temporary directory as last resort: {temp_dir}")
    return temp_dir

class ModelSingleton:
    _instance = None
    _initialized = False
    _reference_counts = {}

    def __new__(cls):
        # Classic singleton: create the instance once, return it thereafter.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not self._initialized:
            try:
                logger.info("Initializing ModelSingleton...")

                # Resolve a writable root directory for all model caches.
                self.cache_dir = get_cache_dir()
                logger.info(f"Using main cache directory: {self.cache_dir}")

                # One subdirectory per library cache.
                self.cache_dirs = {
                    'transformers': os.path.join(self.cache_dir, 'transformers'),
                    'huggingface': os.path.join(self.cache_dir, 'huggingface'),
                    'torch': os.path.join(self.cache_dir, 'torch'),
                    'cache': os.path.join(self.cache_dir, 'cache'),
                    'sentence_transformers': os.path.join(self.cache_dir, 'sentence_transformers'),
                    'fasttext': os.path.join(self.cache_dir, 'fasttext')
                }

                # Create each cache directory, falling back to the system
                # temp directory if creation fails.
                for name, path in self.cache_dirs.items():
                    try:
                        os.makedirs(path, mode=0o777, exist_ok=True)
                        ensure_full_permissions(path)
                        logger.info(f"Successfully created {name} cache directory: {path}")

                        # Write and remove a probe file to confirm write access.
                        test_file = os.path.join(path, '.write_test')
                        try:
                            with open(test_file, 'w') as f:
                                f.write('test')
                            os.chmod(test_file, 0o666)
                            os.remove(test_file)
                            logger.info(f"Verified write permissions for {name} cache directory")
                        except Exception as e:
                            logger.error(f"Failed to verify write permissions for {name} cache directory: {e}")
                            ensure_full_permissions(path)

                    except Exception as e:
                        logger.error(f"Error creating {name} cache directory: {e}")
                        temp_path = os.path.join(tempfile.gettempdir(), 'answer_grading_app', name)
                        os.makedirs(temp_path, mode=0o777, exist_ok=True)
                        ensure_full_permissions(temp_path)
                        self.cache_dirs[name] = temp_path
                        logger.info(f"Using fallback directory for {name}: {temp_path}")

                # Point each library at our cache directories before any
                # model download happens.
                os.environ['TRANSFORMERS_CACHE'] = self.cache_dirs['transformers']
                os.environ['HF_HOME'] = self.cache_dirs['huggingface']
                os.environ['TORCH_HOME'] = self.cache_dirs['torch']
                os.environ['XDG_CACHE_HOME'] = self.cache_dirs['cache']
                os.environ['SENTENCE_TRANSFORMERS_HOME'] = self.cache_dirs['sentence_transformers']

                # Defensive re-check: some directories may have been swapped
                # for fallbacks above, so re-sync any stale values.
                for env_var, key in [
                    ('TRANSFORMERS_CACHE', 'transformers'),
                    ('HF_HOME', 'huggingface'),
                    ('TORCH_HOME', 'torch'),
                    ('XDG_CACHE_HOME', 'cache'),
                    ('SENTENCE_TRANSFORMERS_HOME', 'sentence_transformers')
                ]:
                    if os.environ.get(env_var) != self.cache_dirs[key]:
                        logger.warning(f"Environment variable {env_var} does not match expected path")
                        os.environ[env_var] = self.cache_dirs[key]

                self.device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Using device: {self.device}")

                # Models are loaded lazily by the get_* methods below.
                self.similarity_tokenizer = None
                self.similarity_model = None
                self.flan_tokenizer = None
                self.flan_model = None
                self.trocr_processor = None
                self.trocr_model = None
                self.vit_model = None
                self.vit_processor = None

                # One reference count per model family.
                self._reference_counts = {
                    'similarity': 0,
                    'flan': 0,
                    'trocr': 0,
                    'vit': 0
                }

                self._initialized = True
                logger.info("ModelSingleton initialization completed successfully")

            except Exception as e:
                logger.error(f"Error during ModelSingleton initialization: {e}")
                raise

    def get_similarity_model(self):
        """Get sentence transformer model with reference counting"""
        try:
            if self.similarity_model is None:
                logger.info("Loading sentence transformer model...")
                SENTENCE_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
                self.similarity_tokenizer = AutoTokenizer.from_pretrained(
                    SENTENCE_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                self.similarity_model = SentenceTransformer(
                    SENTENCE_MODEL,
                    cache_folder=os.getenv('TRANSFORMERS_CACHE')
                )
                self.similarity_model.to(self.device)
                logger.info("Sentence transformer model loaded successfully")

            self._reference_counts['similarity'] += 1
            return self.similarity_model
        except Exception as e:
            logger.error(f"Error loading sentence transformer model: {e}")
            raise

    def get_flan_model(self):
        """Get Flan-T5 model with reference counting"""
        try:
            if self.flan_model is None:
                logger.info("Loading Flan-T5 model...")
                FLAN_MODEL = "google/flan-t5-xl"
                self.flan_tokenizer = AutoTokenizer.from_pretrained(
                    FLAN_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                # Half precision on GPU roughly halves memory use; fall back
                # to full precision on CPU, where fp16 is poorly supported.
                self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(
                    FLAN_MODEL,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE'),
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    low_cpu_mem_usage=True
                )
                self.flan_model.to(self.device)
                logger.info("Flan-T5 model loaded successfully")

            self._reference_counts['flan'] += 1
            return self.flan_model
        except Exception as e:
            logger.error(f"Error loading Flan-T5 model: {e}")
            raise

    def get_trocr_model(self):
        """Get TrOCR model with reference counting"""
        try:
            if self.trocr_model is None:
                from transformers import TrOCRProcessor, VisionEncoderDecoderModel
                logger.info("Loading TrOCR model...")
                MODEL_NAME = "microsoft/trocr-large-handwritten"
                # Pass cache_dir explicitly, matching the other loaders.
                self.trocr_processor = TrOCRProcessor.from_pretrained(
                    MODEL_NAME,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                self.trocr_model = VisionEncoderDecoderModel.from_pretrained(
                    MODEL_NAME,
                    cache_dir=os.getenv('TRANSFORMERS_CACHE')
                )
                self.trocr_model.to(self.device)
                logger.info("TrOCR model loaded successfully")

            self._reference_counts['trocr'] += 1
            return self.trocr_model, self.trocr_processor
        except Exception as e:
            logger.error(f"Error loading TrOCR model: {e}")
            raise

    def get_vit_model(self):
        """Get ViT model using only local files - no downloads"""
        try:
            if self.vit_model is None:
                from transformers import ViTConfig, ViTImageProcessor, ViTForImageClassification
                logger.info("Loading local ViT model from files...")

                # The model directory lives at <project root>/models/vit-base-beans.
                model_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
                model_path = os.path.join(model_root, 'models', 'vit-base-beans')
                logger.info(f"Using local model directory: {model_path}")

                if not os.path.exists(model_path):
                    raise FileNotFoundError(f"Local model directory not found at: {model_path}")

                # Verify that both the weights and the config are present.
                model_file = os.path.join(model_path, 'model.safetensors')
                config_file = os.path.join(model_path, 'config.json')

                if not os.path.exists(model_file):
                    raise FileNotFoundError(f"Local model weights file not found at: {model_file}")
                if not os.path.exists(config_file):
                    raise FileNotFoundError(f"Local model config file not found at: {config_file}")

                logger.info("Found all required local model files:")
                logger.info(f"- Using model weights: {model_file}")
                logger.info(f"- Using config file: {config_file}")

                logger.info("Loading model configuration from local file...")
                config = ViTConfig.from_json_file(config_file)

                # Build the image processor from the model config rather than
                # downloading a preprocessor config.
                logger.info("Creating image processor from local config...")
                self.vit_processor = ViTImageProcessor(
                    do_resize=True,
                    size=config.image_size,
                    do_normalize=True
                )

                logger.info("Loading model weights from local file...")
                self.vit_model = ViTForImageClassification.from_pretrained(
                    model_path,
                    config=config,
                    local_files_only=True,
                    use_safetensors=True,
                    trust_remote_code=False
                )

                logger.info(f"Moving model to {self.device}...")
                self.vit_model.to(self.device)
                self.vit_model.eval()
                logger.info("Local model loaded successfully!")

            self._reference_counts['vit'] += 1
            return self.vit_model, self.vit_processor

        except Exception as e:
            logger.error(f"Error loading local ViT model: {e}")
            raise

    def release_similarity_model(self):
        """Release one reference to the similarity model."""
        # Clamp at zero so repeated releases cannot drive the count negative,
        # which would trigger premature cleanup on the next acquire/release cycle.
        self._reference_counts['similarity'] = max(0, self._reference_counts['similarity'] - 1)
        if self._reference_counts['similarity'] == 0:
            self._cleanup_similarity_model()

    def release_flan_model(self):
        """Release one reference to the Flan-T5 model."""
        self._reference_counts['flan'] = max(0, self._reference_counts['flan'] - 1)
        if self._reference_counts['flan'] == 0:
            self._cleanup_flan_model()

    def release_trocr_model(self):
        """Release one reference to the TrOCR model."""
        self._reference_counts['trocr'] = max(0, self._reference_counts['trocr'] - 1)
        if self._reference_counts['trocr'] == 0:
            self._cleanup_trocr_model()

    def release_vit_model(self):
        """Release one reference to the ViT model."""
        self._reference_counts['vit'] = max(0, self._reference_counts['vit'] - 1)
        if self._reference_counts['vit'] == 0:
            self._cleanup_vit_model()

    def _cleanup_similarity_model(self):
        """Clean up similarity model resources"""
        if self.similarity_model is not None:
            del self.similarity_model
            self.similarity_model = None
            self.similarity_tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Similarity model resources cleaned up")

    def _cleanup_flan_model(self):
        """Clean up Flan-T5 model resources"""
        if self.flan_model is not None:
            del self.flan_model
            self.flan_model = None
            self.flan_tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Flan-T5 model resources cleaned up")

    def _cleanup_trocr_model(self):
        """Clean up TrOCR model resources"""
        if self.trocr_model is not None:
            del self.trocr_model
            del self.trocr_processor
            self.trocr_model = None
            self.trocr_processor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("TrOCR model resources cleaned up")

    def _cleanup_vit_model(self):
        """Clean up ViT model resources"""
        if self.vit_model is not None:
            del self.vit_model
            del self.vit_processor
            self.vit_model = None
            self.vit_processor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("ViT model resources cleaned up")

    def cleanup(self):
        """Clean up all model resources"""
        try:
            logger.info("Starting model cleanup...")

            # Each helper is a no-op when its model is not loaded, so all
            # four can be called unconditionally here.
            self._cleanup_similarity_model()
            self._cleanup_flan_model()
            self._cleanup_trocr_model()
            self._cleanup_vit_model()

            # Reset every reference count.
            for model_type in self._reference_counts:
                self._reference_counts[model_type] = 0

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Model cleanup completed successfully")
        except Exception as e:
            logger.error(f"Error during model cleanup: {e}")

# Shared singleton instance used across the application.
models = ModelSingleton()


def cleanup_models():
    """Release every model held by the shared singleton."""
    models.cleanup()
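
# A minimal usage sketch (illustrative only, not part of the application):
# it assumes the similarity model can be downloaded or is already cached,
# and uses sentence-transformers' standard encode() API.
if __name__ == "__main__":
    import atexit

    # Ensure all models are released when the interpreter exits.
    atexit.register(cleanup_models)

    similarity = models.get_similarity_model()
    try:
        embeddings = similarity.encode(["a student answer", "a reference answer"])
        print(f"Embedding shape: {embeddings.shape}")
    finally:
        # Drop our reference; the model is freed once its count reaches zero.
        models.release_similarity_model()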