# legal-doc-backend/backend/app/utils/enhanced_models.py
# Harsh Upadhyay
# adding backend to spaces with initial commit.
# 8397f09
import torch
import logging
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer, util
import numpy as np
from typing import List, Dict, Any, Optional
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import json
import os
class EnhancedModelManager:
    """Load and orchestrate the models used for legal document analysis.

    The manager keeps two registries:

    * ``self.models``    -- Hugging Face pipelines keyed by a short name
      (one summarizer and up to three extractive QA models).
    * ``self.embedders`` -- sentence-embedding models used to pick the
      sentences of a context that are most relevant to a question.

    The public entry points are :meth:`generate_enhanced_summary` and
    :meth:`answer_question_enhanced`; everything prefixed with ``_`` is a
    text-processing helper.
    """

    # Optional QA pipelines: (registry key, model id, label used in logs).
    _OPTIONAL_QA_MODELS = (
        ('bert_qa', "deepset/roberta-base-squad2", "RoBERTa QA model"),
        ('distilbert_qa', "distilbert-base-cased-distilled-squad", "DistilBERT QA model"),
    )

    # Optional embedders: (registry key, model id, label used in logs).
    _OPTIONAL_EMBEDDERS = (
        ('all_minilm', 'sentence-transformers/all-MiniLM-L6-v2', "all-MiniLM embedder"),
        ('paraphrase', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', "paraphrase embedder"),
    )

    # QA ensemble members in priority order: (registry key, weight, log label).
    # The legal-specific model carries the highest weight.
    _QA_ENSEMBLE = (
        ('legal_qa', 0.5, "Legal QA model"),
        ('bert_qa', 0.3, "RoBERTa QA model"),
        ('distilbert_qa', 0.2, "DistilBERT QA model"),
    )

    # Question categories, checked in order: (keywords looked up as substrings
    # of the lower-cased question, pattern extracting the matching value).
    # The money pattern accepts amounts with or without thousands separators,
    # so "$1500" is not truncated to "$150".
    _QUESTION_PATTERNS = (
        (('how long', 'duration', 'period'),
         re.compile(r'\d+\s*(?:years?|months?|days?|weeks?)', re.IGNORECASE)),
        (('how much', 'cost', 'price', 'amount'),
         re.compile(r'\$\d+(?:,\d{3})*(?:\.\d+)?')),
        (('when', 'date'),
         re.compile(r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}')),
    )

    # A sentence starting with one of these whole words is treated as a
    # continuation of the previous sentence.  Whole-word matching keeps
    # ordinary words such as "Security" or "Second" from being merged.
    _CONTINUATION_RE = re.compile(r'(?:Sec|Section|Art|Article|Clause|Para|Paragraph)s?\b')

    # Punctuation removed from both ends of a token before term comparison.
    _TOKEN_STRIP_CHARS = '.,;:!?"\'()[]{}'

    def __init__(self):
        """Pick the compute device and load every model immediately."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.models: Dict[str, Any] = {}
        self.embedders: Dict[str, Any] = {}
        self.initialize_models()

    def initialize_models(self):
        """Populate ``self.models`` and ``self.embedders``.

        The legal summarizer, the legal QA model and the ``mpnet`` embedder
        are required: a failure to load any of them is logged and re-raised.
        The remaining QA models and embedders are optional: a failure is
        logged as a warning and the model is simply left out of the registry.
        """
        # transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
        pipeline_device = 0 if self.device == "cuda" else -1
        try:
            # === Summarization Models ===
            logging.info("Loading summarization models...")
            self.models['legal_summarizer'] = pipeline(
                "summarization",
                model="TheGod-2003/legal-summarizer",
                tokenizer="TheGod-2003/legal-summarizer",
                device=pipeline_device,
            )
            logging.info("Legal summarization model loaded successfully")

            # === QA Models ===
            logging.info("Loading QA models...")
            self.models['legal_qa'] = pipeline(
                "question-answering",
                model="TheGod-2003/legal_QA_model",
                tokenizer="TheGod-2003/legal_QA_model",
                device=pipeline_device,
            )
            for key, model_id, label in self._OPTIONAL_QA_MODELS:
                try:
                    self.models[key] = pipeline(
                        "question-answering",
                        model=model_id,
                        device=pipeline_device,
                    )
                except Exception as e:
                    logging.warning(f"Could not load {label}: {e}")

            # === Embedding Models ===
            logging.info("Loading embedding models...")
            self.embedders['mpnet'] = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
            for key, model_id, label in self._OPTIONAL_EMBEDDERS:
                try:
                    self.embedders[key] = SentenceTransformer(model_id)
                except Exception as e:
                    logging.warning(f"Could not load {label}: {e}")

            logging.info("All models loaded successfully")
        except Exception as e:
            logging.error(f"Error initializing models: {e}")
            raise

    def generate_enhanced_summary(self, text: str, max_length: int = 4096, min_length: int = 200) -> Dict[str, Any]:
        """Summarize ``text`` with the legal summarizer.

        The text is cleaned and shortened to fit the model first.  If the
        summarizer raises, a TF-IDF extractive summary is used instead.

        Args:
            text: Raw document text.
            max_length: ``max_length`` passed to the summarization pipeline.
            min_length: ``min_length`` passed to the summarization pipeline;
                also the word-count threshold below which a retry is made.

        Returns:
            A dict with keys ``summary`` (post-processed text),
            ``model_summaries`` (raw per-model summaries), ``weights`` and
            ``confidence`` (see :meth:`_calculate_summary_confidence`).

        Raises:
            Exception: if no summary could be produced; any error is logged
                and re-raised.
        """
        try:
            summaries: List[str] = []
            weights: List[float] = []
            cleaned_text = self._handle_long_documents(self._preprocess_text(text))

            if 'legal_summarizer' in self.models:
                try:
                    summaries.append(
                        self._summarize_with_legal_model(cleaned_text, max_length, min_length)
                    )
                    weights.append(1.0)
                except Exception as e:
                    logging.warning(f"Legal summarizer failed: {e}")
                    # Fallback to extractive summarization
                    fallback_summary = self._extractive_summarization(cleaned_text, max_length)
                    if fallback_summary:
                        summaries.append(fallback_summary)
                        weights.append(1.0)

            if not summaries:
                raise RuntimeError("No models could generate summaries")

            final_summary = self._ensemble_summaries(summaries, weights)
            final_summary = self._postprocess_summary(final_summary, summaries, min_sentences=8)
            return {
                'summary': final_summary,
                'model_summaries': summaries,
                'weights': weights,
                'confidence': self._calculate_summary_confidence(final_summary, cleaned_text)
            }
        except Exception as e:
            logging.error(f"Error in enhanced summary generation: {e}")
            raise

    def _summarize_with_legal_model(self, cleaned_text: str, max_length: int, min_length: int) -> str:
        """Run the legal summarizer, retrying once if the result looks cut off.

        The first pass combines beam search with sampling.  When the result
        has fewer than ``min_length`` words or does not end with sentence
        punctuation, a second deterministic pass with a doubled length limit
        is made, and the longer of the two summaries is returned.
        """
        summarizer = self.models['legal_summarizer']
        eos_token_id = summarizer.tokenizer.eos_token_id

        summary = summarizer(
            cleaned_text,
            max_length=max_length,
            min_length=min_length,
            num_beams=5,
            length_penalty=1.2,  # > 1 favours longer summaries
            repetition_penalty=1.5,
            no_repeat_ngram_size=2,
            early_stopping=False,  # let every beam run to completion
            do_sample=True,
            temperature=0.7,
            top_p=0.9,  # nucleus sampling
            pad_token_id=eos_token_id,
            eos_token_id=eos_token_id
        )[0]['summary_text']
        summary = self._ensure_complete_summary(summary, cleaned_text)

        if len(summary.split()) < min_length or not summary.strip().endswith(('.', '!', '?')):
            logging.info("Summary too short or incomplete, retrying with different parameters...")
            retry_summary = summarizer(
                cleaned_text,
                max_length=max_length * 2,  # give the model more room
                min_length=min_length,
                num_beams=3,  # fewer beams keep the retry cheaper
                length_penalty=1.5,
                repetition_penalty=1.2,
                no_repeat_ngram_size=1,
                early_stopping=False,
                do_sample=False,  # deterministic output for the retry
                pad_token_id=eos_token_id,
                eos_token_id=eos_token_id
            )[0]['summary_text']
            retry_summary = self._ensure_complete_summary(retry_summary, cleaned_text)
            if len(retry_summary.split()) > len(summary.split()):
                summary = retry_summary
        return summary

    def answer_question_enhanced(self, question: str, context: str) -> Dict[str, Any]:
        """Answer ``question`` from ``context`` with the QA ensemble.

        Every available QA model is asked; a model that raises is skipped
        with a warning.  The answer with the best weighted score wins and is
        then post-processed and validated against the context.

        Returns:
            A dict with keys ``answer``, ``confidence`` (weighted mean of the
            model scores), ``model_answers``, ``model_scores`` and
            ``context_used`` (the context actually given to the models).

        Raises:
            Exception: if no model produced an answer; any error is logged
                and re-raised.
        """
        try:
            enhanced_context = self._enhance_context(question, context)
            answers: List[str] = []
            scores: List[float] = []
            weights: List[float] = []

            for model_key, weight, label in self._QA_ENSEMBLE:
                qa_model = self.models.get(model_key)
                if qa_model is None:
                    continue
                try:
                    result = qa_model(question=question, context=enhanced_context)
                    # Read both fields before appending so the three lists
                    # can never get out of step with each other.
                    answer, score = result['answer'], result['score']
                except Exception as e:
                    logging.warning(f"{label} failed: {e}")
                    continue
                answers.append(answer)
                scores.append(score)
                weights.append(weight)

            if not answers:
                raise RuntimeError("No models could generate answers")

            final_answer = self._ensemble_answers(answers, scores, weights)
            enhanced_answer = self._enhance_answer(final_answer, question, enhanced_context)
            return {
                'answer': enhanced_answer,
                'confidence': float(np.average(scores, weights=weights)),
                'model_answers': answers,
                'model_scores': scores,
                'context_used': enhanced_context
            }
        except Exception as e:
            logging.error(f"Error in enhanced QA: {e}")
            raise

    def _enhance_context(self, question: str, context: str, top_k: int = 5) -> str:
        """Reduce ``context`` to the sentences most similar to ``question``.

        Cosine similarities from every working embedder are averaged, the
        ``top_k`` best sentences are kept and re-joined in document order.
        The context is returned unchanged when it has three sentences or
        fewer, when ``top_k`` is not positive, when no embedder produced
        similarities, or when anything fails.
        """
        try:
            sentences = self._split_into_sentences(context)
            if len(sentences) <= 3 or top_k <= 0:
                return context

            similarity_by_model = {}
            for name, embedder in self.embedders.items():
                try:
                    sentence_embeddings = embedder.encode(sentences, convert_to_tensor=True)
                    question_embedding = embedder.encode(question, convert_to_tensor=True)
                    similarities = util.cos_sim(question_embedding, sentence_embeddings)[0]
                    similarity_by_model[name] = similarities.cpu().numpy()
                except Exception as e:
                    logging.warning(f"Embedding model {name} failed: {e}")
            if not similarity_by_model:
                return context

            ensemble_similarities = np.mean(list(similarity_by_model.values()), axis=0)
            top_indices = np.argsort(ensemble_similarities)[-top_k:]
            # Restore document order so the excerpt still reads naturally.
            return " ".join(sentences[i] for i in sorted(top_indices))
        except Exception as e:
            logging.warning(f"Context enhancement failed: {e}")
            return context

    def _ensemble_summaries(self, summaries: List[str], weights: List[float]) -> str:
        """Pick the final summary from ``summaries``.

        Currently this returns the first (primary model) summary, or an empty
        string when there is none.  ``weights`` is accepted for interface
        stability but is not used yet.
        """
        return summaries[0] if summaries else ""

    def _ensemble_answers(self, answers: List[str], scores: List[float], weights: List[float]) -> str:
        """Return the answer whose ``score * weight`` is highest.

        Falls back to the first answer (or ``""``) if the computation fails,
        for example because the lists have different lengths.
        """
        try:
            if len(answers) == 1:
                return answers[0]
            weighted_scores = np.asarray(scores, dtype=float) * np.asarray(weights, dtype=float)
            return answers[int(np.argmax(weighted_scores))]
        except Exception as e:
            logging.warning(f"Answer ensemble failed: {e}")
            return answers[0] if answers else ""

    def _enhance_answer(self, answer: str, question: str, context: str) -> str:
        """Clean ``answer`` and replace it if it is not supported by ``context``.

        When fewer than half of the answer's terms occur in the context, a
        pattern-based extraction from the context is tried and, if it finds
        something, used instead.
        """
        try:
            answer = self._apply_legal_postprocessing(answer.strip(), question)
            if not self._validate_answer_context(answer, context):
                extracted_answer = self._extract_answer_from_context(question, context)
                if extracted_answer:
                    answer = extracted_answer
            return answer
        except Exception as e:
            logging.warning(f"Answer enhancement failed: {e}")
            return answer

    def _extract_typed_value(self, question: str, text: str) -> Optional[str]:
        """Return the duration / amount / date in ``text`` that ``question`` asks for.

        Only the first question category whose keywords occur in the question
        is considered; ``None`` is returned when no category applies or its
        pattern does not match ``text``.
        """
        question_lower = question.lower()
        for keywords, pattern in self._QUESTION_PATTERNS:
            if any(keyword in question_lower for keyword in keywords):
                match = pattern.search(text)
                return match.group(0) if match else None
        return None

    def _apply_legal_postprocessing(self, answer: str, question: str) -> str:
        """Strip section references from ``answer`` and narrow it to the asked value.

        For "how long", "how much" and "when" style questions the answer is
        reduced to the matching duration, dollar amount or date when one is
        present; otherwise the cleaned answer is returned.
        """
        try:
            # Remove common legal document artifacts
            answer = re.sub(r'\b(SEC\.|Section|Article)\s*\d+\.?', '', answer, flags=re.IGNORECASE)
            answer = re.sub(r'\s+', ' ', answer)
            typed_value = self._extract_typed_value(question, answer)
            if typed_value is not None:
                return typed_value
            return answer.strip()
        except Exception as e:
            logging.warning(f"Legal post-processing failed: {e}")
            return answer

    @classmethod
    def _tokenize_terms(cls, text: str, min_length: int = 0) -> set:
        """Lower-cased tokens of ``text`` longer than ``min_length`` characters.

        Punctuation is stripped from both ends of each token so that
        ``"rent."`` and ``"rent"`` compare equal.
        """
        tokens = (word.strip(cls._TOKEN_STRIP_CHARS).lower() for word in text.split())
        return {token for token in tokens if len(token) > min_length}

    def _validate_answer_context(self, answer: str, context: str) -> bool:
        """Return True if at least half of the answer's terms occur in ``context``.

        Only answer terms longer than three characters are considered; an
        answer without such terms is accepted.  Errors also yield True so
        that validation never discards an answer by accident.
        """
        try:
            answer_terms = self._tokenize_terms(answer, min_length=3)
            if not answer_terms:
                return True
            context_terms = self._tokenize_terms(context)
            overlap = len(answer_terms & context_terms) / len(answer_terms)
            return overlap >= 0.5
        except Exception as e:
            logging.warning(f"Answer validation failed: {e}")
            return True

    def _extract_answer_from_context(self, question: str, context: str) -> Optional[str]:
        """Pull the first duration, amount or date the question asks for out of ``context``.

        Returns ``None`` when the question type is not recognised, nothing
        matches, or an error occurs.
        """
        try:
            return self._extract_typed_value(question, context)
        except Exception as e:
            logging.warning(f"Answer extraction failed: {e}")
            return None

    def _preprocess_text(self, text: str) -> str:
        """Normalise raw document text before it is given to a model.

        Removes layout artifacts and markup tags, rewrites ``SEC. 5`` /
        ``Article 5`` style references to ``Section 5``, replaces every
        non-ASCII run with a space and pads party names with spaces.
        Whitespace is collapsed last so that none of the earlier steps can
        leave double spaces behind.  On error the text is returned as it
        stood when the error occurred.
        """
        try:
            # Remove common artifacts but preserve legal structure
            text = re.sub(r'[\\\n\r\u200b\u2022\u00a0_=]+', ' ', text)
            text = re.sub(r'<.*?>', ' ', text)
            # Normalise section references instead of removing them
            text = re.sub(r'\b(SEC\.|Section|Article)\s*(\d+)\.?', r'Section \2', text, flags=re.IGNORECASE)
            # Exactly one space between a sentence end and the next capital
            text = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text)
            # NOTE: this drops *all* non-ASCII characters, including symbols
            # such as the section sign and typographic quotes.
            text = re.sub(r'[^\x00-\x7F]+', ' ', text)
            text = re.sub(r'\b(Lessee|Lessor|Party|Parties)\b', r' \1 ', text, flags=re.IGNORECASE)
            # Collapse whitespace after every step that may have added some
            text = re.sub(r'\s{2,}', ' ', text)
            return text.strip()
        except Exception as e:
            logging.warning(f"Text preprocessing failed: {e}")
            return text

    def _chunk_text_for_summarization(self, text: str, max_words: int = 8000) -> str:
        """Shorten ``text`` to at most ``max_words`` words.

        Texts within the limit, or with 50 sentences or fewer, are returned
        unchanged.  Otherwise the first 60% and the last 20% of the sentences
        are kept and the result is cut to ``max_words`` words.
        """
        try:
            if len(text.split()) <= max_words:
                return text
            sentences = self._split_into_sentences(text)
            total_sentences = len(sentences)
            if total_sentences <= 50:
                return text

            first_portion = int(total_sentences * 0.6)
            last_portion = int(total_sentences * 0.2)
            chunked_text = " ".join(sentences[:first_portion] + sentences[-last_portion:])

            chunk_words = chunked_text.split()
            if len(chunk_words) > max_words:
                chunked_text = " ".join(chunk_words[:max_words])
            return chunked_text
        except Exception as e:
            logging.warning(f"Text chunking failed: {e}")
            return text

    def _handle_long_documents(self, text: str) -> str:
        """Shrink a very long document to a budget of 9,600 words.

        The budget is 80% of a 12,000 token allowance that the original
        author chose for the summarizer (described there as an LED-16384
        model).  Documents within the budget, or with fewer than ten
        sentences, are returned unchanged.  Otherwise whole sentences are
        sampled from three regions so that the end of the document is
        represented too: 40% of the budget from the beginning, 20% from the
        middle and 40% from the end.
        """
        try:
            max_tokens = 12000
            word_budget = max_tokens * 4 // 5
            if len(text.split()) <= word_budget:
                return text

            sentences = self._split_into_sentences(text)
            total_sentences = len(sentences)
            if total_sentences < 10:
                return text

            lengths = [len(sentence.split()) for sentence in sentences]
            head_budget = word_budget * 2 // 5
            middle_budget = word_budget // 5
            tail_budget = word_budget - head_budget - middle_budget

            # Grow the head window forwards from the first sentence.
            head_end, used = 0, 0
            while head_end < total_sentences and used + lengths[head_end] <= head_budget:
                used += lengths[head_end]
                head_end += 1

            # Grow the tail window backwards from the last sentence.
            tail_start, used = total_sentences, 0
            while tail_start > head_end and used + lengths[tail_start - 1] <= tail_budget:
                used += lengths[tail_start - 1]
                tail_start -= 1

            # Grow the middle window forwards from the centre of the gap.
            middle_start = (head_end + tail_start) // 2
            middle_end, used = middle_start, 0
            while middle_end < tail_start and used + lengths[middle_end] <= middle_budget:
                used += lengths[middle_end]
                middle_end += 1

            key_sentences = (
                sentences[:head_end] +
                sentences[middle_start:middle_end] +
                sentences[tail_start:]
            )
            words = " ".join(key_sentences).split()
            if not words:
                # No sentence fits its window (extremely long sentences):
                # fall back to plain truncation of the original text.
                words = text.split()
            return " ".join(words[:word_budget])
        except Exception as e:
            logging.warning(f"Long document handling failed: {e}")
            return text

    def _ensure_complete_summary(self, summary: str, original_text: str) -> str:
        """Trim a trailing sentence fragment and pad very short summaries.

        If the summary does not end with ``.``, ``!`` or ``?`` it is cut
        after the last sentence terminator that is followed by whitespace
        (it is left alone when there is none).  A summary shorter than 50
        words is extended with keyword sentences from ``original_text``.
        """
        try:
            if not summary:
                return summary

            trimmed = summary.rstrip()
            if not trimmed.endswith(('.', '!', '?')):
                boundaries = [match.end() for match in re.finditer(r'[.!?](?=\s)', trimmed)]
                if boundaries:
                    # Remove the incomplete last sentence
                    summary = trimmed[:boundaries[-1]]

            if len(summary.split()) < 50:
                additional_content = self._extract_key_sentences(original_text, 100)
                if additional_content:
                    summary = summary + " " + additional_content
            return summary.strip()
        except Exception as e:
            logging.warning(f"Summary completion check failed: {e}")
            return summary

    def _extract_key_sentences(self, text: str, max_words: int = 100) -> str:
        """Collect sentences containing lease-related keywords.

        Sentences are taken in document order until adding the next matching
        sentence would exceed ``max_words``.  Keywords are matched as
        substrings of the lower-cased sentence.  Returns ``""`` on error.
        """
        try:
            legal_keywords = ('lease', 'rent', 'payment', 'term', 'agreement', 'lessor', 'lessee',
                              'covenant', 'obligation', 'right', 'duty', 'termination', 'renewal')
            key_sentences = []
            word_count = 0
            for sentence in self._split_into_sentences(text):
                sentence_lower = sentence.lower()
                if not any(keyword in sentence_lower for keyword in legal_keywords):
                    continue
                sentence_words = len(sentence.split())
                if word_count + sentence_words > max_words:
                    break
                key_sentences.append(sentence)
                word_count += sentence_words
            return " ".join(key_sentences)
        except Exception as e:
            logging.warning(f"Key sentence extraction failed: {e}")
            return ""

    def _extractive_summarization(self, text: str, max_length: int) -> str:
        """Fallback summary built from the highest-scoring TF-IDF sentences.

        Sentences are ranked by the sum of their TF-IDF weights and taken
        best-first until the next one would exceed ``max_length // 2`` words;
        the selection is emitted in document order.  Texts with fewer than
        three sentences are returned unchanged.  On error the first
        ``max_length`` characters of the text are returned.
        """
        try:
            # Imported here so that a missing scikit-learn only disables this
            # fallback (handled by the except below).
            from sklearn.feature_extraction.text import TfidfVectorizer

            sentences = self._split_into_sentences(text)
            if len(sentences) < 3:
                return text

            vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)
            tfidf_matrix = vectorizer.fit_transform(sentences)
            ranked = sorted(
                ((tfidf_matrix[i].sum(), i) for i in range(len(sentences))),
                reverse=True,
            )

            word_budget = max_length // 2  # Conservative estimate
            selected_indices = []
            total_words = 0
            for _score, idx in ranked:
                sentence_words = len(sentences[idx].split())
                if total_words + sentence_words > word_budget:
                    break
                selected_indices.append(idx)
                total_words += sentence_words

            return " ".join(sentences[i] for i in sorted(selected_indices))
        except Exception as e:
            logging.warning(f"Extractive summarization failed: {e}")
            return text[:max_length] if len(text) > max_length else text

    def _postprocess_summary(self, summary: str, all_summaries: Optional[List[str]] = None, min_sentences: int = 10) -> str:
        """Clean a summary, drop repeated sentences and pad it if too short.

        Sentences are compared without their trailing period, so a repeat is
        detected even when only one copy carries the final ``.``.  When fewer
        than ``min_sentences`` unique sentences remain, unseen sentences from
        ``all_summaries`` are appended until the minimum is reached.  A
        non-empty result always ends with ``.``, ``!`` or ``?``.
        """
        try:
            summary = re.sub(r'[\\\n\r\u200b\u2022\u00a0_=]+', ' ', summary)
            summary = re.sub(r'[^\x00-\x7F]+', ' ', summary)
            summary = re.sub(r'\s{2,}', ' ', summary)

            unique_sentences: List[str] = []
            seen = set()

            def add_sentence(fragment: str) -> None:
                # Normalise so "X" and "X." count as the same sentence.
                normalized = fragment.strip().rstrip('.').rstrip()
                if normalized and normalized not in seen:
                    seen.add(normalized)
                    unique_sentences.append(normalized)

            for fragment in summary.split('. '):
                add_sentence(fragment)

            if all_summaries is not None and len(unique_sentences) < min_sentences:
                for other_summary in all_summaries:
                    for fragment in other_summary.split('. '):
                        add_sentence(fragment)
                        if len(unique_sentences) >= min_sentences:
                            break
                    if len(unique_sentences) >= min_sentences:
                        break

            result = '. '.join(unique_sentences)
            if result and not result.endswith(('.', '!', '?')):
                result += '.'
            return result
        except Exception as e:
            logging.warning(f"Summary post-processing failed: {e}")
            return summary

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split ``text`` into sentences.

        A boundary is sentence punctuation followed by a capital letter.
        Fragments of ten characters or fewer are dropped.  A fragment that
        starts with a reference word (``Sec``, ``Section``, ``Art``,
        ``Article``, ``Clause``, ``Para``, ``Paragraph``) is appended to the
        previous sentence.  Returns ``[text]`` when nothing usable remains or
        on error.
        """
        try:
            text = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text)
            cleaned_sentences: List[str] = []
            for fragment in re.split(r'(?<=[.!?])\s+(?=[A-Z])', text):
                fragment = fragment.strip()
                if len(fragment) <= 10:  # Filter out very short fragments
                    continue
                if cleaned_sentences and self._CONTINUATION_RE.match(fragment):
                    cleaned_sentences[-1] = cleaned_sentences[-1] + " " + fragment
                else:
                    cleaned_sentences.append(fragment)
            return cleaned_sentences if cleaned_sentences else [text]
        except Exception as e:
            logging.warning(f"Sentence splitting failed: {e}")
            return [text]

    def _calculate_summary_confidence(self, summary: str, original_text: str) -> float:
        """Score a summary between 0.0 and 1.0 by term overlap with the source.

        The score is twice the fraction of the source's distinct terms
        (longer than three characters) that also occur in the summary, capped
        at 1.0.  Summaries shorter than ten characters score 0.0; 0.5 is
        returned when the source has no such terms or an error occurs.
        """
        try:
            if not summary or len(summary) < 10:
                return 0.0
            summary_terms = {word.lower() for word in summary.split() if len(word) > 3}
            original_terms = {word.lower() for word in original_text.split() if len(word) > 3}
            if not original_terms:
                return 0.5  # Default confidence
            overlap = len(summary_terms & original_terms) / len(original_terms)
            return min(overlap * 2, 1.0)
        except Exception as e:
            logging.warning(f"Confidence calculation failed: {e}")
            return 0.5
# Global instance shared by importers of this module.
# NOTE: constructing it here runs EnhancedModelManager.__init__, which calls
# initialize_models(), so every model is loaded as a side effect of importing
# this module, and an error while loading a required model is re-raised to
# the importer.
enhanced_model_manager = EnhancedModelManager()