# DEPENDENCIES
import sys
from pathlib import Path

import torch
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

from utils.logger import log_info, log_error, ContractAnalyzerLogger
from config.model_config import ModelConfig
from model_manager.model_registry import (
    ModelInfo,
    ModelType,
    ModelStatus,
    ModelRegistry,
)


class ModelLoader:
    """Smart model loader with automatic download, caching, and GPU support."""

    def __init__(self):
        self.registry = ModelRegistry()
        self.config = ModelConfig()
        self.logger = ContractAnalyzerLogger.get_logger()

        # Detect device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        log_info(
            "ModelLoader initialized",
            device=self.device,
            gpu_available=torch.cuda.is_available(),
        )

        # Ensure directories exist
        ModelConfig.ensure_directories()
        log_info(
            "Model directories ensured",
            model_dir=str(self.config.MODEL_DIR),
            cache_dir=str(self.config.CACHE_DIR),
        )

    def _check_model_files_exist(self, local_path: Path) -> bool:
        """Check whether a complete model snapshot exists at the local path."""
        if not local_path.exists():
            return False

        # A usable snapshot needs config.json plus at least one weights file
        # (pytorch_model.bin or model.safetensors); tokenizer files such as
        # vocab.txt and tokenizer_config.json are written alongside them by
        # save_pretrained(), so they are not checked separately.
        has_config = (local_path / "config.json").exists()
        has_model_file = any(
            (local_path / file).exists()
            for file in ["pytorch_model.bin", "model.safetensors"]
        )
        return has_config and has_model_file

    def load_legal_bert(self) -> tuple:
        """Load the Legal-BERT model and tokenizer (nlpaueb/legal-bert-base-uncased)."""
        # Check if already loaded
        if self.registry.is_loaded(ModelType.LEGAL_BERT):
            info = self.registry.get(ModelType.LEGAL_BERT)
            log_info(
                "Legal-BERT already loaded from cache",
                memory_mb=info.memory_size_mb,
                access_count=info.access_count,
            )
            return info.model, info.tokenizer

        # Mark as loading
        self.registry.register(
            ModelType.LEGAL_BERT,
            ModelInfo(
                name="legal-bert",
                type=ModelType.LEGAL_BERT,
                status=ModelStatus.LOADING,
            ),
        )

        try:
            config = self.config.LEGAL_BERT
            local_path = config["local_path"]
            force_download = config.get("force_download", False)

            # Prefer the local cache unless a re-download is forced
            if self._check_model_files_exist(local_path) and not force_download:
                log_info("Loading Legal-BERT from local cache", path=str(local_path))
                model = AutoModel.from_pretrained(str(local_path))
                tokenizer = AutoTokenizer.from_pretrained(str(local_path))
            else:
                log_info(
                    "Downloading Legal-BERT from HuggingFace",
                    model_name=config["model_name"],
                )
                model = AutoModel.from_pretrained(config["model_name"])
                tokenizer = AutoTokenizer.from_pretrained(config["model_name"])

                # Save to local cache
                log_info("Saving Legal-BERT to local cache", path=str(local_path))
                local_path.mkdir(parents=True, exist_ok=True)
                model.save_pretrained(str(local_path))
                tokenizer.save_pretrained(str(local_path))

            # Move to device and switch to inference mode
            model.to(self.device)
            model.eval()

            # Calculate memory size (parameter count x bytes per element)
            memory_mb = sum(
                p.nelement() * p.element_size() for p in model.parameters()
            ) / (1024 * 1024)

            # Register as loaded
            self.registry.register(
                ModelType.LEGAL_BERT,
                ModelInfo(
                    name="legal-bert",
                    type=ModelType.LEGAL_BERT,
                    status=ModelStatus.LOADED,
                    model=model,
                    tokenizer=tokenizer,
                    memory_size_mb=memory_mb,
                    metadata={"device": self.device, "model_name": config["model_name"]},
                ),
            )
            log_info(
                "Legal-BERT loaded successfully",
                memory_mb=round(memory_mb, 2),
                device=self.device,
                parameters=sum(p.numel() for p in model.parameters()),
            )
            return model, tokenizer

        except Exception as e:
            log_error(
                e,
                context={
                    "component": "ModelLoader",
                    "operation": "load_legal_bert",
                    "model_name": self.config.LEGAL_BERT["model_name"],
                },
            )
            self.registry.register(
                ModelType.LEGAL_BERT,
                ModelInfo(
                    name="legal-bert",
                    type=ModelType.LEGAL_BERT,
                    status=ModelStatus.ERROR,
                    error_message=str(e),
                ),
            )
            raise
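    # Illustrative usage sketch (not part of the loader API; assumes the
    # standard transformers inference interface). Once load_legal_bert()
    # returns, the pair can produce clause embeddings like so:
    #
    #     loader = ModelLoader()
    #     model, tokenizer = loader.load_legal_bert()
    #     inputs = tokenizer(
    #         "Tenant shall pay rent on the first of each month.",
    #         return_tensors="pt",
    #         truncation=True,
    #     ).to(loader.device)
    #     with torch.no_grad():
    #         outputs = model(**inputs)
    #     cls_vector = outputs.last_hidden_state[:, 0, :]  # [CLS] embedding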
    def load_classifier_model(self) -> tuple:
        """Load the contract classification model (Legal-BERT with a classification head)."""
        # Check if already loaded
        if self.registry.is_loaded(ModelType.CLASSIFIER):
            info = self.registry.get(ModelType.CLASSIFIER)
            log_info(
                "Classifier model already loaded from cache",
                memory_mb=info.memory_size_mb,
                access_count=info.access_count,
            )
            return info.model, info.tokenizer

        # Mark as loading
        self.registry.register(
            ModelType.CLASSIFIER,
            ModelInfo(
                name="classifier",
                type=ModelType.CLASSIFIER,
                status=ModelStatus.LOADING,
            ),
        )

        try:
            config = self.config.CLASSIFIER_MODEL
            log_info(
                "Loading classifier model (Legal-BERT based)",
                embedding_dim=config["embedding_dim"],
                hidden_dim=config["hidden_dim"],
                num_categories=config["num_categories"],
            )

            # Reuse the Legal-BERT encoder as the classification backbone
            base_model, tokenizer = self.load_legal_bert()

            # Register as loaded; memory_size_mb stays 0.0 because the same
            # Legal-BERT instance is shared, not duplicated in memory
            self.registry.register(
                ModelType.CLASSIFIER,
                ModelInfo(
                    name="classifier",
                    type=ModelType.CLASSIFIER,
                    status=ModelStatus.LOADED,
                    model=base_model,
                    tokenizer=tokenizer,
                    memory_size_mb=0.0,
                    metadata={
                        "device": self.device,
                        "base_model": "legal-bert",
                        "embedding_dim": config["embedding_dim"],
                        "num_classes": config["num_categories"],
                        "purpose": "contract_type_classification",
                    },
                ),
            )
            log_info(
                "Classifier model loaded successfully",
                base_model="legal-bert",
                num_categories=config["num_categories"],
                note="Using Legal-BERT for both clause extraction and classification",
            )
            return base_model, tokenizer

        except Exception as e:
            log_error(
                e,
                context={"component": "ModelLoader", "operation": "load_classifier_model"},
            )
            self.registry.register(
                ModelType.CLASSIFIER,
                ModelInfo(
                    name="classifier",
                    type=ModelType.CLASSIFIER,
                    status=ModelStatus.ERROR,
                    error_message=str(e),
                ),
            )
            raise
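    # Hypothetical sketch (not wired into the loader): load_classifier_model()
    # currently returns the bare Legal-BERT encoder, so a real classifier
    # would wrap it with a linear head sized by the config values, e.g.:
    #
    #     import torch.nn as nn
    #
    #     class ContractClassifierHead(nn.Module):  # hypothetical helper
    #         def __init__(self, encoder, hidden_dim: int, num_categories: int):
    #             super().__init__()
    #             # hidden_dim must match the encoder's hidden size (768 for
    #             # bert-base-style encoders such as Legal-BERT)
    #             self.encoder = encoder
    #             self.head = nn.Linear(hidden_dim, num_categories)
    #
    #         def forward(self, **inputs):
    #             pooled = self.encoder(**inputs).last_hidden_state[:, 0, :]
    #             return self.head(pooled)  # logits over contract types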
log_info("Downloading embedding model from HuggingFace", model_name = config["model_name"]) model = SentenceTransformer(model_name_or_path = config["model_name"]) # Save to local cache log_info("Saving embedding model to local cache", path = str(local_path)) local_path.mkdir(parents = True, exist_ok = True) model.save(str(local_path)) # Move to device if self.device == "cuda": model = model.to(self.device) # Estimate memory size memory_mb = 100 # Register as loaded self.registry.register(ModelType.EMBEDDING, ModelInfo(name = "embedding", type = ModelType.EMBEDDING, status = ModelStatus.LOADED, model = model, memory_size_mb = memory_mb, metadata = {"device": self.device, "model_name": config["model_name"], "dimension": config["dimension"]} ) ) log_info("Embedding model loaded successfully", memory_mb = memory_mb, device = self.device, dimension = config["dimension"], ) return model except Exception as e: log_error(e, context = {"component": "ModelLoader", "operation": "load_embedding_model", "model_name": self.config.EMBEDDING_MODEL["model_name"]}) self.registry.register(ModelType.EMBEDDING, ModelInfo(name = "embedding", type = ModelType.EMBEDDING, status = ModelStatus.ERROR, error_message = str(e), ) ) raise def ensure_models_downloaded(self): """ Ensure all required models are downloaded before use """ log_info("Ensuring all models are downloaded...") try: # Download Legal-BERT if needed if not self.registry.is_loaded(ModelType.LEGAL_BERT): config = self.config.LEGAL_BERT local_path = config["local_path"] if not self._check_model_files_exist(local_path): log_info("Pre-downloading Legal-BERT...") model = AutoModel.from_pretrained(pretrained_model_name_or_path = config["model_name"]) tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = config["model_name"]) local_path.mkdir(parents = True, exist_ok = True) model.save_pretrained(save_directory = str(local_path)) tokenizer.save_pretrained(save_directory = str(local_path)) log_info("Legal-BERT pre-downloaded successfully") # Download embedding model if needed if not self.registry.is_loaded(ModelType.EMBEDDING): config = self.config.EMBEDDING_MODEL local_path = config["local_path"] if not local_path.exists(): log_info("Pre-downloading embedding model...") model = SentenceTransformer(model_name_or_path = config["model_name"]) local_path.mkdir(parents = True, exist_ok = True) model.save(str(local_path)) log_info("Embedding model pre-downloaded successfully") # Note: Classifier model is a stub, no download needed log_info("Classifier model stub - no download required (uses Legal-BERT)") log_info("All models are ready for use") except Exception as e: log_error(e, context={"component": "ModelLoader", "operation": "ensure_models_downloaded"}) raise def get_registry_stats(self) -> dict: """ Get statistics about loaded models """ stats = self.registry.get_stats() log_info("Retrieved registry statistics", total_models = stats["total_models"], loaded_models = stats["loaded_models"], total_memory_mb = stats["total_memory_mb"], ) return stats def clear_cache(self): """ Clear all models from memory """ log_info("Clearing all models from cache") self.registry.clear_all() log_info("All models cleared from cache")