# DEPENDENCIES
import sys
import torch
from pathlib import Path
from transformers import AutoModel
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))

from utils.logger import log_info
from utils.logger import log_error
from config.model_config import ModelConfig
from utils.logger import ContractAnalyzerLogger
from model_manager.model_registry import ModelInfo
from model_manager.model_registry import ModelType
from model_manager.model_registry import ModelStatus
from model_manager.model_registry import ModelRegistry


class ModelLoader:
    """
    Smart model loader with automatic download, caching, and GPU support
    """
    def __init__(self):
        self.registry = ModelRegistry()
        self.config   = ModelConfig()
        self.logger   = ContractAnalyzerLogger.get_logger()
        
        # Detect device
        self.device   = "cuda" if torch.cuda.is_available() else "cpu"

        log_info(f"ModelLoader initialized", device = self.device, gpu_available = torch.cuda.is_available())
        
        # Ensure directories exist
        ModelConfig.ensure_directories()
        log_info("Model directories ensured", 
                 model_dir = str(self.config.MODEL_DIR),
                 cache_dir = str(self.config.CACHE_DIR),
                )

    
    def _check_model_files_exist(self, local_path: Path) -> bool:
        """
        Check if all required model files exist in local path
        """
        if not local_path.exists():
            return False
            
        # A usable cache needs config.json plus at least one weights file;
        # tokenizer files are saved alongside them by save_pretrained
        has_config     = (local_path / "config.json").exists()
        has_model_file = any((local_path / file).exists() for file in ["pytorch_model.bin", "model.safetensors"])
        
        return has_config and has_model_file

    
    def load_legal_bert(self) -> tuple:
        """
        Load Legal-BERT model and tokenizer (nlpaueb/legal-bert-base-uncased)
        """
        # Check if already loaded
        if self.registry.is_loaded(ModelType.LEGAL_BERT):
            info = self.registry.get(ModelType.LEGAL_BERT)

            log_info("Legal-BERT already loaded from cache",
                     memory_mb    = info.memory_size_mb,
                     access_count = info.access_count,
                    )

            return info.model, info.tokenizer
        
        # Mark as loading
        self.registry.register(ModelType.LEGAL_BERT, 
                               ModelInfo(name   = "legal-bert", 
                                         type   = ModelType.LEGAL_BERT, 
                                         status = ModelStatus.LOADING,
                                        )
                              )
        
        try:
            config         = self.config.LEGAL_BERT
            local_path     = config["local_path"]
            force_download = config.get("force_download", False)
            
            # Check if we should use local cache
            if self._check_model_files_exist(local_path) and not force_download:
                log_info(f"Loading Legal-BERT from local cache", path=str(local_path))
                
                model     = AutoModel.from_pretrained(pretrained_model_name_or_path = str(local_path))
                tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = str(local_path))

            else:
                log_info(f"Downloading Legal-BERT from HuggingFace", model_name = config["model_name"])
                
                model     = AutoModel.from_pretrained(pretrained_model_name_or_path = config["model_name"])
                tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = config["model_name"])
                
                # Save to local cache
                log_info(f"Saving Legal-BERT to local cache", path = str(local_path))
                local_path.mkdir(parents = True, exist_ok = True)

                model.save_pretrained(save_directory = str(local_path))
                tokenizer.save_pretrained(save_directory = str(local_path))
            
            # Move to device
            model.to(self.device)
            model.eval()
            
            # Calculate memory size
            memory_mb = sum(p.nelement() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            
            # Register as loaded
            self.registry.register(ModelType.LEGAL_BERT,
                                   ModelInfo(name           = "legal-bert",
                                             type           = ModelType.LEGAL_BERT,
                                             status         = ModelStatus.LOADED,
                                             model          = model,
                                             tokenizer      = tokenizer,
                                             memory_size_mb = memory_mb,
                                             metadata       = {"device" : self.device, "model_name" : config["model_name"]}
                                            )
                                  )
            
            log_info("Legal-BERT loaded successfully",
                     memory_mb  = round(memory_mb, 2),
                     device     = self.device,
                     parameters = sum(p.numel() for p in model.parameters()),
                    )
            
            return model, tokenizer
            
        except Exception as e:
            log_error(e, context = {"component": "ModelLoader", "operation": "load_legal_bert", "model_name": self.config.LEGAL_BERT["model_name"]})
            
            self.registry.register(ModelType.LEGAL_BERT,
                                   ModelInfo(name          = "legal-bert",
                                             type          = ModelType.LEGAL_BERT,
                                             status        = ModelStatus.ERROR,
                                             error_message = str(e),
                                            )
                                  )
            raise
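
    # Usage sketch: derive a clause embedding by mean-pooling Legal-BERT's last
    # hidden state (assumes a ModelLoader instance named `loader`; names are
    # illustrative only):
    #
    #     model, tokenizer = loader.load_legal_bert()
    #     inputs = tokenizer("Indemnification clause text", return_tensors = "pt").to(loader.device)
    #     with torch.no_grad():
    #         hidden = model(**inputs).last_hidden_state   # (1, seq_len, hidden_dim)
    #     embedding = hidden.mean(dim = 1)                 # (1, hidden_dim)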


    def load_classifier_model(self) -> tuple:
        """
        Load contract classification model using Legal-BERT with classification head
        """
        # Check if already loaded
        if self.registry.is_loaded(ModelType.CLASSIFIER):
            info = self.registry.get(ModelType.CLASSIFIER)

            log_info("Classifier model already loaded from cache",
                     memory_mb    = info.memory_size_mb,
                     access_count = info.access_count,
                    )

            return info.model, info.tokenizer
        
        # Mark as loading
        self.registry.register(ModelType.CLASSIFIER,
                               ModelInfo(name   = "classifier", 
                                         type   = ModelType.CLASSIFIER, 
                                         status = ModelStatus.LOADING,
                                        )
                              )
        
        try:
            config = self.config.CLASSIFIER_MODEL
            
            log_info("Loading classifier model (Legal-BERT based)", 
                     embedding_dim  = config["embedding_dim"],
                     hidden_dim     = config["hidden_dim"],
                     num_categories = config["num_categories"],
                    )
            
            # Use the Legal-BERT model but prepare it for classification
            base_model, tokenizer = self.load_legal_bert()
            
            # Register as loaded (sharing the same Legal-BERT instance)
            self.registry.register(ModelType.CLASSIFIER,
                                   ModelInfo(name           = "classifier",
                                             type           = ModelType.CLASSIFIER,
                                             status         = ModelStatus.LOADED,
                                             model          = base_model,    # shared Legal-BERT instance
                                             tokenizer      = tokenizer,     # shared Legal-BERT tokenizer
                                             memory_size_mb = 0.0,           # no memory beyond Legal-BERT itself
                                             metadata       = {"device"        : self.device, 
                                                               "base_model"    : "legal-bert",
                                                               "embedding_dim" : config["embedding_dim"],
                                                               "num_classes"   : config["num_categories"],
                                                               "purpose"       : "contract_type_classification",
                                                              }
                                            )
                                  )
            
            log_info("Classifier model loaded successfully",
                     base_model     = "legal-bert",
                     num_categories = config["num_categories"],
                     note           = "Using Legal-BERT for both clause extraction and classification",
                    )
            
            return base_model, tokenizer
            
        except Exception as e:
            log_error(e, context = {"component": "ModelLoader", "operation": "load_classifier_model"})
            
            self.registry.register(ModelType.CLASSIFIER,
                                   ModelInfo(name          = "classifier",
                                             type          = ModelType.CLASSIFIER,
                                             status        = ModelStatus.ERROR,
                                             error_message = str(e),
                                            )
                                  )
            raise
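
    # Note: no classification head is built here; the shared Legal-BERT encoder is
    # returned as-is. A minimal sketch of a head that downstream code could attach
    # on top of pooled embeddings (hypothetical, not part of this module):
    #
    #     head = torch.nn.Sequential(
    #         torch.nn.Linear(config["embedding_dim"], config["hidden_dim"]),
    #         torch.nn.ReLU(),
    #         torch.nn.Linear(config["hidden_dim"], config["num_categories"]),
    #     ).to(self.device)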

    
    def load_embedding_model(self) -> SentenceTransformer:
        """
        Load sentence transformer for embeddings
        """
        # Check if already loaded
        if self.registry.is_loaded(ModelType.EMBEDDING):
            info = self.registry.get(ModelType.EMBEDDING)

            log_info("Embedding model already loaded from cache",
                     memory_mb    = info.memory_size_mb,
                     access_count = info.access_count,
                    )
            return info.model
        
        # Mark as loading
        self.registry.register(ModelType.EMBEDDING,
                               ModelInfo(name   = "embedding", 
                                         type   = ModelType.EMBEDDING, 
                                         status = ModelStatus.LOADING,
                                        )
                              )
        
        try:
            config         = self.config.EMBEDDING_MODEL
            local_path     = config["local_path"]
            force_download = config.get("force_download", False)
            
            # Check if we should use local cache
            if local_path.exists() and not force_download:
                log_info("Loading embedding model from local cache", path = str(local_path))

                model = SentenceTransformer(model_name_or_path = str(local_path))

            else:
                log_info("Downloading embedding model from HuggingFace", model_name = config["model_name"])
                
                model = SentenceTransformer(model_name_or_path = config["model_name"]) 
                
                # Save to local cache
                log_info("Saving embedding model to local cache", path = str(local_path))
                local_path.mkdir(parents = True, exist_ok = True)
                model.save(str(local_path))
            
            # Move to device
            if self.device == "cuda":
                model = model.to(self.device)
            
            # Calculate memory size (SentenceTransformer is a torch.nn module, so parameters() works)
            memory_mb = sum(p.nelement() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            
            # Register as loaded
            self.registry.register(ModelType.EMBEDDING,
                                   ModelInfo(name           = "embedding",
                                             type           = ModelType.EMBEDDING,
                                             status         = ModelStatus.LOADED,
                                             model          = model,
                                             memory_size_mb = memory_mb,
                                             metadata       = {"device": self.device, "model_name": config["model_name"], "dimension": config["dimension"]}
                                            )
                                  )
            
            log_info("Embedding model loaded successfully",
                     memory_mb = round(memory_mb, 2),
                     device    = self.device,
                     dimension = config["dimension"],
                    )
            
            return model
            
        except Exception as e:
            log_error(e, context = {"component": "ModelLoader", "operation": "load_embedding_model", "model_name": self.config.EMBEDDING_MODEL["model_name"]})
            
            self.registry.register(ModelType.EMBEDDING,
                                   ModelInfo(name          = "embedding",
                                             type          = ModelType.EMBEDDING,
                                             status        = ModelStatus.ERROR,
                                             error_message = str(e),
                                            )
                                  )
            raise
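
    # Usage sketch: SentenceTransformer.encode returns numpy arrays of shape
    # (n, dimension) by default (assumes a ModelLoader instance named `loader`):
    #
    #     embedder = loader.load_embedding_model()
    #     vectors  = embedder.encode(["Governing law: New York.", "Term: 24 months."])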

    
    def ensure_models_downloaded(self):
        """
        Ensure all required models are downloaded before use
        """
        log_info("Ensuring all models are downloaded...")
        
        try:
            # Download Legal-BERT if needed
            if not self.registry.is_loaded(ModelType.LEGAL_BERT):
                config     = self.config.LEGAL_BERT
                local_path = config["local_path"]
                
                if not self._check_model_files_exist(local_path):
                    log_info("Pre-downloading Legal-BERT...")

                    model     = AutoModel.from_pretrained(pretrained_model_name_or_path = config["model_name"])
                    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = config["model_name"])
                    
                    local_path.mkdir(parents = True, exist_ok = True)
                    model.save_pretrained(save_directory = str(local_path))
                    tokenizer.save_pretrained(save_directory = str(local_path))

                    log_info("Legal-BERT pre-downloaded successfully")
            
            # Download embedding model if needed
            if not self.registry.is_loaded(ModelType.EMBEDDING):
                config     = self.config.EMBEDDING_MODEL
                local_path = config["local_path"]
                
                if not local_path.exists():
                    log_info("Pre-downloading embedding model...")
                    model = SentenceTransformer(model_name_or_path = config["model_name"])

                    local_path.mkdir(parents = True, exist_ok = True)

                    model.save(str(local_path))
                    log_info("Embedding model pre-downloaded successfully")
            
            # Note: Classifier model is a stub, no download needed
            log_info("Classifier model stub - no download required (uses Legal-BERT)")
                    
            log_info("All models are ready for use")
            
        except Exception as e:
            log_error(e, context={"component": "ModelLoader", "operation": "ensure_models_downloaded"})
            raise

    
    def get_registry_stats(self) -> dict:
        """
        Get statistics about loaded models
        """
        stats = self.registry.get_stats()
        log_info("Retrieved registry statistics",
                 total_models    = stats["total_models"],
                 loaded_models   = stats["loaded_models"],
                 total_memory_mb = stats["total_memory_mb"],
                )

        return stats
    

    def clear_cache(self):
        """
        Clear all models from memory
        """
        log_info("Clearing all models from cache")
        self.registry.clear_all()
        log_info("All models cleared from cache")