import torch
import logging
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer, util
import numpy as np
from typing import List, Dict, Any, Optional
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import json
import os


class EnhancedModelManager:
    """
    Model manager for legal document analysis.

    Loads one summarization pipeline, up to three question-answering
    pipelines and up to three sentence embedders, and exposes two public
    entry points:

    * ``generate_enhanced_summary`` - abstractive summary with an extractive
      TF-IDF fallback.
    * ``answer_question_enhanced`` - QA over a context that is first narrowed
      to the sentences most similar to the question.

    All private helpers are best-effort: they log a warning and return a
    sensible default instead of raising.
    """

    # (model key, ensemble weight, label used in log messages), in call order.
    _QA_ENSEMBLE = (
        ('legal_qa', 0.5, "Legal QA model"),
        ('bert_qa', 0.3, "RoBERTa QA model"),
        ('distilbert_qa', 0.2, "DistilBERT QA model"),
    )

    # Question keywords (matched as substrings of the lower-cased question),
    # checked in this order; the first matching category wins.
    _QUESTION_KEYWORDS = (
        ('duration', ('how long', 'duration', 'period')),
        ('money', ('how much', 'cost', 'price', 'amount')),
        ('date', ('when', 'date')),
    )

    _DURATION_RE = re.compile(r'\d+\s*(years?|months?|days?|weeks?)', re.IGNORECASE)
    # Either a comma-grouped amount ($1,500.00) or a plain run of digits
    # ($1500).  The grouped form is tried first so "$1,500" is not cut at "$1".
    _MONEY_RE = re.compile(r'\$\d{1,3}(?:,\d{3})+(?:\.\d{2})?|\$\d+(?:\.\d{2})?')
    _DATE_RE = re.compile(r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}')

    _SENTENCE_END = ('.', '!', '?')

    # Substrings that mark a sentence as relevant when padding a short summary.
    _LEGAL_KEYWORDS = (
        'lease', 'rent', 'payment', 'term', 'agreement', 'lessor', 'lessee',
        'covenant', 'obligation', 'right', 'duty', 'termination', 'renewal',
    )

    # Number of most-similar sentences kept by _enhance_context.
    _TOP_CONTEXT_SENTENCES = 5

    def __init__(self):
        """Pick the device and eagerly load every model (may raise)."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.models = {}
        self.embedders = {}
        self.initialize_models()

    def _pipeline_device(self) -> int:
        """Return the device index passed to ``pipeline``: 0 for CUDA, -1 for CPU."""
        return 0 if self.device == "cuda" else -1

    def initialize_models(self):
        """
        Load all pipelines and embedders into ``self.models`` / ``self.embedders``.

        The legal summarizer, the legal QA model and the ``mpnet`` embedder are
        mandatory: a failure loading any of them is logged and re-raised.  The
        remaining QA models and embedders are optional: a failure is logged as
        a warning and the model is simply left out.
        """
        device_index = self._pipeline_device()
        try:
            # === Summarization Models ===
            logging.info("Loading summarization models...")
            self.models['legal_summarizer'] = pipeline(
                "summarization",
                model="TheGod-2003/legal-summarizer",
                tokenizer="TheGod-2003/legal-summarizer",
                device=device_index,
            )
            logging.info("Legal summarization model loaded successfully")

            # === QA Models ===
            logging.info("Loading QA models...")
            self.models['legal_qa'] = pipeline(
                "question-answering",
                model="TheGod-2003/legal_QA_model",
                tokenizer="TheGod-2003/legal_QA_model",
                device=device_index,
            )

            optional_qa = (
                ('bert_qa', "deepset/roberta-base-squad2", "RoBERTa QA model"),
                ('distilbert_qa', "distilbert-base-cased-distilled-squad", "DistilBERT QA model"),
            )
            for key, model_name, label in optional_qa:
                try:
                    self.models[key] = pipeline(
                        "question-answering",
                        model=model_name,
                        device=device_index,
                    )
                except Exception as e:
                    logging.warning("Could not load %s: %s", label, e)

            # === Embedding Models ===
            logging.info("Loading embedding models...")
            self.embedders['mpnet'] = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

            optional_embedders = (
                ('all_minilm', 'sentence-transformers/all-MiniLM-L6-v2', "all-MiniLM embedder"),
                ('paraphrase', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
                 "paraphrase embedder"),
            )
            for key, model_name, label in optional_embedders:
                try:
                    self.embedders[key] = SentenceTransformer(model_name)
                except Exception as e:
                    logging.warning("Could not load %s: %s", label, e)

            logging.info("All models loaded successfully")
        except Exception as e:
            logging.error("Error initializing models: %s", e)
            raise

    def generate_enhanced_summary(self, text: str, max_length: int = 4096,
                                  min_length: int = 200) -> Dict[str, Any]:
        """
        Summarize ``text`` with the legal summarizer, falling back to TF-IDF
        extraction if the model call raises.

        Args:
            text: Raw document text.
            max_length: ``max_length`` forwarded to the summarization pipeline
                (also the budget basis for the extractive fallback).
            min_length: ``min_length`` forwarded to the pipeline; also the word
                count below which a retry is attempted.

        Returns:
            Dict with keys ``summary``, ``model_summaries``, ``weights`` and
            ``confidence``.

        Raises:
            RuntimeError: if no summary could be produced (e.g. the summarizer
                is not loaded, or both it and the fallback produced nothing).
            Exception: any other unexpected error is logged and re-raised.
        """
        try:
            summaries = []
            weights = []

            cleaned_text = self._preprocess_text(text)
            cleaned_text = self._handle_long_documents(cleaned_text)

            if 'legal_summarizer' in self.models:
                try:
                    summary = self._summarize_with_legal_model(cleaned_text, max_length, min_length)
                    summaries.append(summary)
                    weights.append(1.0)
                except Exception as e:
                    logging.warning("Legal summarizer failed: %s", e)
                    # Fallback to extractive summarization
                    fallback_summary = self._extractive_summarization(cleaned_text, max_length)
                    if fallback_summary:
                        summaries.append(fallback_summary)
                        weights.append(1.0)

            if not summaries:
                raise RuntimeError("No models could generate summaries")

            final_summary = self._ensemble_summaries(summaries, weights)
            final_summary = self._postprocess_summary(final_summary, summaries, min_sentences=8)

            return {
                'summary': final_summary,
                'model_summaries': summaries,
                'weights': weights,
                'confidence': self._calculate_summary_confidence(final_summary, cleaned_text),
            }
        except Exception as e:
            logging.error("Error in enhanced summary generation: %s", e)
            raise

    def _summarize_with_legal_model(self, cleaned_text: str, max_length: int,
                                    min_length: int) -> str:
        """
        Run the legal summarizer once with sampling and, if the result is
        short or does not end a sentence, once more deterministically; return
        whichever result has more words (the first one on a tie).
        """
        summarizer = self.models['legal_summarizer']
        eos_id = summarizer.tokenizer.eos_token_id

        summary = summarizer(
            cleaned_text,
            max_length=max_length,
            min_length=min_length,
            num_beams=5,
            length_penalty=1.2,        # slightly favor longer summaries
            repetition_penalty=1.5,
            no_repeat_ngram_size=2,
            early_stopping=False,      # do not stop beams prematurely
            do_sample=True,
            temperature=0.7,
            top_p=0.9,                 # nucleus sampling
            pad_token_id=eos_id,
            eos_token_id=eos_id,
        )[0]['summary_text']
        summary = self._ensure_complete_summary(summary, cleaned_text)

        too_short = len(summary.split()) < min_length
        unfinished = not summary.strip().endswith(self._SENTENCE_END)
        if too_short or unfinished:
            logging.info("Summary too short or incomplete, retrying with different parameters...")
            retry_summary = summarizer(
                cleaned_text,
                max_length=max_length * 2,  # allow twice the output length
                min_length=min_length,
                num_beams=3,                # fewer beams for faster generation
                length_penalty=1.5,         # favor longer summaries
                repetition_penalty=1.2,
                no_repeat_ngram_size=1,
                early_stopping=False,
                do_sample=False,            # deterministic output
                pad_token_id=eos_id,
                eos_token_id=eos_id,
            )[0]['summary_text']
            retry_summary = self._ensure_complete_summary(retry_summary, cleaned_text)
            if len(retry_summary.split()) > len(summary.split()):
                summary = retry_summary

        return summary

    def answer_question_enhanced(self, question: str, context: str) -> Dict[str, Any]:
        """
        Answer ``question`` from ``context`` using every loaded QA model.

        The context is first narrowed by ``_enhance_context``; each available
        model in ``_QA_ENSEMBLE`` is queried, the best answer is chosen by
        ``_ensemble_answers`` and then post-processed by ``_enhance_answer``.

        Returns:
            Dict with keys ``answer``, ``confidence`` (weighted average of the
            model scores, as a float), ``model_answers``, ``model_scores`` and
            ``context_used``.

        Raises:
            RuntimeError: if no model produced an answer.
            Exception: any other unexpected error is logged and re-raised.
        """
        try:
            enhanced_context = self._enhance_context(question, context)

            answers = []
            scores = []
            weights = []

            for model_key, weight, label in self._QA_ENSEMBLE:
                if model_key not in self.models:
                    continue
                try:
                    result = self.models[model_key](
                        question=question,
                        context=enhanced_context,
                    )
                    # Read both fields before appending so the three lists
                    # can never get out of step if one lookup fails.
                    answer_text, answer_score = result['answer'], result['score']
                except Exception as e:
                    logging.warning("%s failed: %s", label, e)
                    continue
                answers.append(answer_text)
                scores.append(answer_score)
                weights.append(weight)

            if not answers:
                raise RuntimeError("No models could generate answers")

            final_answer = self._ensemble_answers(answers, scores, weights)
            enhanced_answer = self._enhance_answer(final_answer, question, enhanced_context)

            return {
                'answer': enhanced_answer,
                'confidence': float(np.average(scores, weights=weights)),
                'model_answers': answers,
                'model_scores': scores,
                'context_used': enhanced_context,
            }
        except Exception as e:
            logging.error("Error in enhanced QA: %s", e)
            raise

    def _enhance_context(self, question: str, context: str) -> str:
        """
        Reduce ``context`` to the sentences most similar to ``question``.

        Cosine similarities from every working embedder are averaged; the top
        ``_TOP_CONTEXT_SENTENCES`` sentences are returned in document order.
        Contexts of three sentences or fewer, and any failure, return the
        context unchanged.
        """
        try:
            sentences = self._split_into_sentences(context)
            if len(sentences) <= 3:
                return context

            similarity_by_model = {}
            for name, embedder in self.embedders.items():
                try:
                    sentence_vecs = embedder.encode(sentences, convert_to_tensor=True)
                    question_vec = embedder.encode(question, convert_to_tensor=True)
                    sims = util.cos_sim(question_vec, sentence_vecs)[0]
                    similarity_by_model[name] = sims.cpu().numpy()
                except Exception as e:
                    logging.warning("Embedding model %s failed: %s", name, e)

            if not similarity_by_model:
                return context

            mean_similarity = np.mean(list(similarity_by_model.values()), axis=0)
            top_indices = np.argsort(mean_similarity)[-self._TOP_CONTEXT_SENTENCES:]

            # Re-emit the winners in their original document order.
            return " ".join(sentences[i] for i in sorted(top_indices))
        except Exception as e:
            logging.warning("Context enhancement failed: %s", e)
            return context

    def _ensemble_summaries(self, summaries: List[str], weights: List[float]) -> str:
        """
        Return the first (primary) summary, or "" when there is none.

        NOTE: this is a placeholder - ``weights`` is accepted for interface
        compatibility but no actual combination of summaries is performed.
        """
        return summaries[0] if summaries else ""

    def _ensemble_answers(self, answers: List[str], scores: List[float],
                          weights: List[float]) -> str:
        """
        Pick the answer with the highest ``score * normalized weight``.

        Falls back to the first answer (or "") if the computation fails.
        """
        try:
            if len(answers) == 1:
                return answers[0]

            normalized = np.array(weights) / np.sum(weights)
            weighted_scores = np.array(scores) * normalized
            return answers[int(np.argmax(weighted_scores))]
        except Exception as e:
            logging.warning("Answer ensemble failed: %s", e)
            return answers[0] if answers else ""

    def _enhance_answer(self, answer: str, question: str, context: str) -> str:
        """
        Clean ``answer`` and, if it is not supported by ``context``, try to
        replace it with a pattern-extracted value from the context.
        """
        try:
            answer = self._apply_legal_postprocessing(answer.strip(), question)

            if not self._validate_answer_context(answer, context):
                extracted = self._extract_answer_from_context(question, context)
                if extracted:
                    answer = extracted

            return answer
        except Exception as e:
            logging.warning("Answer enhancement failed: %s", e)
            return answer

    def _question_pattern(self, question: str):
        """
        Return the compiled regex for the question's category (duration,
        money or date), or ``None`` if no keyword matches.
        """
        question_lower = question.lower()
        patterns = {
            'duration': self._DURATION_RE,
            'money': self._MONEY_RE,
            'date': self._DATE_RE,
        }
        for category, keywords in self._QUESTION_KEYWORDS:
            if any(keyword in question_lower for keyword in keywords):
                return patterns[category]
        return None

    def _apply_legal_postprocessing(self, answer: str, question: str) -> str:
        """
        Strip section/article references from ``answer`` and, for duration,
        money and date questions, narrow it to the first matching value.

        Amounts both with and without thousands separators are recognised
        (``$1,500.00`` and ``$1500``).
        """
        try:
            answer = re.sub(r'\b(SEC\.|Section|Article)\s*\d+\.?', '', answer, flags=re.IGNORECASE)
            answer = re.sub(r'\s+', ' ', answer)

            pattern = self._question_pattern(question)
            if pattern is not None:
                match = pattern.search(answer)
                if match:
                    return match.group(0)

            return answer.strip()
        except Exception as e:
            logging.warning("Legal post-processing failed: %s", e)
            return answer

    def _validate_answer_context(self, answer: str, context: str) -> bool:
        """
        Return True when at least half of the answer's words longer than three
        characters occur (case-insensitively, as whitespace-separated tokens)
        in ``context``.  Answers with no such words, and failures, count as valid.
        """
        try:
            answer_terms = {word.lower() for word in answer.split() if len(word) > 3}
            if not answer_terms:
                return True

            context_terms = {word.lower() for word in context.split()}
            overlap = len(answer_terms & context_terms) / len(answer_terms)
            return overlap >= 0.5
        except Exception as e:
            logging.warning("Answer validation failed: %s", e)
            return True

    def _extract_answer_from_context(self, question: str, context: str) -> Optional[str]:
        """
        Return the first duration / amount / date found in ``context`` that
        matches the question's category, or ``None``.
        """
        try:
            pattern = self._question_pattern(question)
            if pattern is None:
                return None
            match = pattern.search(context)
            return match.group(0) if match else None
        except Exception as e:
            logging.warning("Answer extraction failed: %s", e)
            return None

    def _preprocess_text(self, text: str) -> str:
        """
        Normalize raw document text before it is sent to a model.

        Removes layout artifacts and markup, rewrites ``SEC. 5`` / ``Article 5``
        as ``Section 5``, replaces every non-ASCII run with a space, and
        collapses whitespace last so that no step leaves double spaces behind.
        Returns the input unchanged if anything fails.
        """
        try:
            # Layout artifacts (backslashes, newlines, bullets, rulers) and tags.
            text = re.sub(r'[\\\n\r\u200b\u2022\u00a0_=]+', ' ', text)
            text = re.sub(r'<.*?>', ' ', text)

            # Normalize section references rather than deleting them.
            text = re.sub(r'\b(SEC\.|Section|Article)\s*(\d+)\.?', r'Section \2', text,
                          flags=re.IGNORECASE)

            # Exactly one space between a sentence terminator and a capital.
            text = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text)

            # NOTE: this drops *all* non-ASCII characters, including symbols
            # such as the section sign.
            text = re.sub(r'[^\x00-\x7F]+', ' ', text)

            # Make party terms stand alone as tokens.
            text = re.sub(r'\b(Lessee|Lessor|Party|Parties)\b', r' \1 ', text, flags=re.IGNORECASE)

            # Collapse whitespace last: the two substitutions above add spaces.
            text = re.sub(r'\s{2,}', ' ', text)

            return text.strip()
        except Exception as e:
            logging.warning("Text preprocessing failed: %s", e)
            return text

    def _chunk_text_for_summarization(self, text: str, max_words: int = 8000) -> str:
        """
        Shorten ``text`` to at most ``max_words`` words by keeping the first
        60% and last 20% of its sentences.

        Texts within the limit, or with 50 sentences or fewer, are returned
        unchanged.
        """
        try:
            if len(text.split()) <= max_words:
                return text

            sentences = self._split_into_sentences(text)
            total_sentences = len(sentences)
            if total_sentences <= 50:
                return text

            head_count = int(total_sentences * 0.6)
            tail_count = int(total_sentences * 0.2)
            chunk_words = " ".join(sentences[:head_count] + sentences[-tail_count:]).split()

            # Joining by single spaces is idempotent, so this is also the
            # right result when the selection is already within the limit.
            return " ".join(chunk_words[:max_words])
        except Exception as e:
            logging.warning("Text chunking failed: %s", e)
            return text

    @staticmethod
    def _count_fitting(sentences, word_budget: int) -> int:
        """Return how many leading sentences fit within ``word_budget`` words."""
        used = 0
        count = 0
        for sentence in sentences:
            size = len(sentence.split())
            if used + size > word_budget:
                break
            used += size
            count += 1
        return count

    def _handle_long_documents(self, text: str) -> str:
        """
        Bound very long documents to a fixed word budget while keeping
        material from the beginning, middle and end.

        The budget is 80% of 12000, i.e. 9600 words; the word count is used
        only as a rough proxy for the model's token limit (assumes a ~16k
        token context window, per the original notes - confirm for the model
        actually deployed).  Roughly 40% of the budget is spent on leading
        sentences, 20% on sentences from the middle and 40% on trailing
        sentences.  Documents with fewer than 10 sentences are truncated at
        word level.  Returns the input unchanged if it already fits or if
        anything fails.
        """
        try:
            max_tokens = 12000
            word_budget = int(max_tokens * 0.8)

            words = text.split()
            if len(words) <= word_budget:
                return text

            truncated = " ".join(words[:word_budget])

            sentences = self._split_into_sentences(text)
            if len(sentences) < 10:
                # Too few sentence boundaries to sample sections from.
                return truncated

            head_budget = int(word_budget * 0.4)
            middle_budget = int(word_budget * 0.2)
            tail_budget = word_budget - head_budget - middle_budget

            head_count = self._count_fitting(sentences, head_budget)
            remaining = sentences[head_count:]
            tail_count = self._count_fitting(reversed(remaining), tail_budget)

            middle_pool = remaining[:len(remaining) - tail_count]
            centre = len(middle_pool) // 2
            middle_count = self._count_fitting(middle_pool[centre:], middle_budget)

            key_sentences = (
                sentences[:head_count]
                + middle_pool[centre:centre + middle_count]
                + remaining[len(remaining) - tail_count:]
            )
            if not key_sentences:
                # Every sentence was larger than its section budget.
                return truncated

            combined_words = " ".join(key_sentences).split()
            return " ".join(combined_words[:word_budget])
        except Exception as e:
            logging.warning("Long document handling failed: %s", e)
            return text

    def _ensure_complete_summary(self, summary: str, original_text: str) -> str:
        """
        Drop a trailing partial sentence from ``summary`` and, if fewer than
        50 words remain, append keyword-bearing sentences from ``original_text``.
        """
        try:
            if not summary:
                return summary

            if not summary.rstrip().endswith(self._SENTENCE_END):
                pieces = summary.split('. ')
                if len(pieces) > 1:
                    # Discard the unfinished last piece and re-terminate.
                    summary = '. '.join(pieces[:-1]) + '.'

            if len(summary.split()) < 50:
                additional_content = self._extract_key_sentences(original_text, 100)
                if additional_content:
                    summary = summary + " " + additional_content

            return summary.strip()
        except Exception as e:
            logging.warning("Summary completion check failed: %s", e)
            return summary

    def _extract_key_sentences(self, text: str, max_words: int = 100) -> str:
        """
        Return, in document order, sentences containing any of
        ``_LEGAL_KEYWORDS`` until the next one would exceed ``max_words``.
        Returns "" on failure.
        """
        try:
            key_sentences = []
            word_count = 0
            for sentence in self._split_into_sentences(text):
                lowered = sentence.lower()
                if not any(keyword in lowered for keyword in self._LEGAL_KEYWORDS):
                    continue
                sentence_words = len(sentence.split())
                if word_count + sentence_words > max_words:
                    break
                key_sentences.append(sentence)
                word_count += sentence_words

            return " ".join(key_sentences)
        except Exception as e:
            logging.warning("Key sentence extraction failed: %s", e)
            return ""

    def _extractive_summarization(self, text: str, max_length: int) -> str:
        """
        Fallback summary: pick the sentences with the highest summed TF-IDF
        weight that fit in ``max_length // 2`` words, in document order.

        A sentence that does not fit is skipped (lower-ranked, shorter ones
        may still be added).  Texts with fewer than three sentences are
        returned as-is; on failure the first ``max_length`` characters are
        returned.
        """
        try:
            sentences = self._split_into_sentences(text)
            if len(sentences) < 3:
                return text

            vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)
            tfidf_matrix = vectorizer.fit_transform(sentences)

            # One importance score per sentence: the sum of its TF-IDF weights.
            scores = np.asarray(tfidf_matrix.sum(axis=1)).ravel()
            ranked = sorted(range(len(sentences)), key=lambda i: (scores[i], i), reverse=True)

            word_budget = max_length // 2  # conservative estimate
            selected_indices = []
            total_words = 0
            for idx in ranked:
                sentence_words = len(sentences[idx].split())
                if total_words + sentence_words > word_budget:
                    continue
                selected_indices.append(idx)
                total_words += sentence_words

            selected_indices.sort()
            return " ".join(sentences[i] for i in selected_indices)
        except Exception as e:
            logging.warning("Extractive summarization failed: %s", e)
            return text[:max_length] if len(text) > max_length else text

    @staticmethod
    def _normalize_sentence_piece(piece: str) -> str:
        """Strip whitespace and a single trailing period from a '. '-split piece."""
        piece = piece.strip()
        if piece.endswith('.'):
            piece = piece[:-1].rstrip()
        return piece

    def _postprocess_summary(self, summary: str, all_summaries: Optional[List[str]] = None,
                             min_sentences: int = 10) -> str:
        """
        Clean ``summary``, remove repeated sentences and, if it has fewer than
        ``min_sentences`` sentences, top it up with unseen sentences from
        ``all_summaries``.

        Sentences are compared without their trailing period, so the final
        sentence is de-duplicated like any other, and the result always ends
        with terminal punctuation.  Returns the input unchanged on failure.
        """
        try:
            cleaned = re.sub(r'[\\\n\r\u200b\u2022\u00a0_=]+', ' ', summary)
            cleaned = re.sub(r'[^\x00-\x7F]+', ' ', cleaned)
            cleaned = re.sub(r'\s{2,}', ' ', cleaned)

            unique_sentences = []
            seen = set()

            def add_piece(raw_piece: str) -> None:
                piece = self._normalize_sentence_piece(raw_piece)
                if piece and piece not in seen:
                    seen.add(piece)
                    unique_sentences.append(piece)

            for raw_piece in cleaned.split('. '):
                add_piece(raw_piece)

            if all_summaries is not None and len(unique_sentences) < min_sentences:
                for other in all_summaries:
                    for raw_piece in other.split('. '):
                        if len(unique_sentences) >= min_sentences:
                            break
                        add_piece(raw_piece)

            result = '. '.join(unique_sentences)
            if result and not result.endswith(self._SENTENCE_END):
                result += '.'
            return result
        except Exception as e:
            logging.warning("Summary post-processing failed: %s", e)
            return summary

    def _split_into_sentences(self, text: str) -> List[str]:
        """
        Split ``text`` into sentences at ``.``/``!``/``?`` followed by a capital.

        Fragments of 10 characters or fewer are dropped.  A fragment starting
        with ``Sec``, ``Art``, ``Clause`` or ``Para`` is treated as the
        continuation of a split abbreviation and merged into the previous
        sentence.  Returns ``[text]`` if nothing survives or on failure.
        """
        try:
            spaced = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text)
            fragments = re.split(r'(?<=[.!?])\s+(?=[A-Z])', spaced)

            cleaned_sentences = []
            for fragment in fragments:
                fragment = fragment.strip()
                if len(fragment) <= 10:
                    continue  # too short to be a meaningful sentence
                is_continuation = fragment.startswith(('Sec', 'Art', 'Clause', 'Para'))
                if is_continuation and cleaned_sentences:
                    cleaned_sentences[-1] = cleaned_sentences[-1] + " " + fragment
                else:
                    cleaned_sentences.append(fragment)

            # Return the regex-spaced text, as the original did, when empty.
            return cleaned_sentences if cleaned_sentences else [spaced]
        except Exception as e:
            logging.warning("Sentence splitting failed: %s", e)
            return [text]

    def _calculate_summary_confidence(self, summary: str, original_text: str) -> float:
        """
        Score a summary in [0, 1] as twice the fraction of the original's
        distinct long words (more than three characters) that it contains,
        capped at 1.0.

        Returns 0.0 for an empty or very short summary and 0.5 when the
        original has no long words or the computation fails.
        """
        try:
            if not summary or len(summary) < 10:
                return 0.0

            original_terms = {w.lower() for w in original_text.split() if len(w) > 3}
            if not original_terms:
                return 0.5  # default confidence

            summary_terms = {w.lower() for w in summary.split() if len(w) > 3}
            overlap = len(summary_terms & original_terms) / len(original_terms)
            return min(overlap * 2, 1.0)
        except Exception as e:
            logging.warning("Confidence calculation failed: %s", e)
            return 0.5


# Global instance (loads all models at import time)
enhanced_model_manager = EnhancedModelManager()