import os
import logging
import re
from datetime import datetime, timedelta
from dotenv import load_dotenv
from cryptography.fernet import Fernet
from simple_salesforce import Salesforce
from transformers import pipeline
from PIL import Image
import pytesseract
import pandas as pd
from docx import Document
import PyPDF2
import gradio as gr
from pdf2image import convert_from_path
import tempfile
from pytz import timezone
import shutil
import unicodedata
import asyncio
import torch
# Global variables for caching
_sf = None
_summarizer = None
_fernet = None
_lock = asyncio.Lock()
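# NOTE: on Python 3.10+ an asyncio.Lock() created at import time binds to the
# running event loop lazily, on first use; on older interpreters it may bind to
# the default loop at construction, so prefer creating the lock inside a
# coroutine if those versions must be supported.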
# Setup logging
log_file = os.path.join(tempfile.gettempdir(), 'app.log')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file)]
)
logger = logging.getLogger(__name__)
# Preload models and dependencies
def init_globals():
    global _summarizer, _fernet
    load_dotenv()
    required_env_vars = [
        'ENCRYPTION_KEY', 'SALESFORCE_USERNAME', 'SALESFORCE_PASSWORD',
        'SALESFORCE_SECURITY_TOKEN', 'SALESFORCE_DOMAIN'
    ]
    env = {var: os.getenv(var) for var in required_env_vars}
    if missing := [k for k in required_env_vars if not env[k]]:
        logger.error(f"Missing env vars: {', '.join(missing)}")
        return False
    try:
        _fernet = Fernet(env['ENCRYPTION_KEY'].encode())
    except Exception as e:
        logger.error(f"Invalid encryption key: {str(e)}")
        return False
    try:
        _summarizer = pipeline(
            "summarization",
            model="t5-small",
            tokenizer="t5-small",
            framework="pt",
            device=0 if torch.cuda.is_available() else -1
        )
        logger.info("Summarizer initialized successfully")
    except Exception as e:
        logger.error(f"Summarizer init failed: {str(e)}")
        return False
    return True
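# A minimal .env sketch for init_globals() -- every value below is a
# placeholder, not a real credential. A valid Fernet key can be generated once
# with:
#   python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
#
#   ENCRYPTION_KEY=<44-char urlsafe-base64 key from the command above>
#   SALESFORCE_USERNAME=integration.user@example.com
#   SALESFORCE_PASSWORD=<password>
#   SALESFORCE_SECURITY_TOKEN=<security token>
#   SALESFORCE_DOMAIN=login   # 'test' for a sandbox org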
# Check critical dependencies (Tesseract for OCR, Poppler for PDF-to-image)
def check_dependencies():
    missing = []
    try:
        tesseract_path = shutil.which('tesseract')
        if tesseract_path:
            pytesseract.pytesseract.tesseract_cmd = tesseract_path
        else:
            logger.warning("Tesseract not found. OCR unavailable.")
            missing.append("Tesseract")
        if not shutil.which('pdfinfo'):
            logger.warning("Poppler not found. Scanned-PDF conversion unavailable.")
            missing.append("Poppler")
        return missing, []
    except Exception as e:
        logger.error(f"Dependency check failed: {str(e)}")
        return ["Tesseract", "Poppler"], []

if not init_globals():
    raise RuntimeError("Failed to initialize global dependencies")

missing_deps, _ = check_dependencies()
if missing_deps:
    logger.warning(f"Missing dependencies: {', '.join(missing_deps)}")
# Salesforce connection (async, cached, with exponential backoff)
async def init_salesforce(max_retries=2, initial_delay=1):
    global _sf
    async with _lock:
        if _sf is not None:
            return _sf
        for attempt in range(max_retries):
            try:
                # simple_salesforce is synchronous, so run the login off-loop
                _sf = await asyncio.get_running_loop().run_in_executor(
                    None,
                    lambda: Salesforce(
                        username=os.getenv('SALESFORCE_USERNAME'),
                        password=os.getenv('SALESFORCE_PASSWORD'),
                        security_token=os.getenv('SALESFORCE_SECURITY_TOKEN'),
                        domain=os.getenv('SALESFORCE_DOMAIN'),
                        version='58.0'
                    )
                )
                logger.info("Salesforce connection established")
                return _sf
            except Exception as e:
                logger.error(f"Salesforce connection attempt {attempt + 1} failed: {str(e)}")
                if attempt < max_retries - 1:
                    await asyncio.sleep(initial_delay * (2 ** attempt))
        raise ValueError("Salesforce connection failed after retries")
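# Hypothetical standalone use (inside the app the Gradio handler awaits this
# directly):
#   sf = asyncio.run(init_salesforce())
# With the defaults there is a single 1 s backoff between the two attempts
# (initial_delay * 2**attempt).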
# Preprocess image for OCR: grayscale generally improves Tesseract accuracy
def preprocess_image(image):
    try:
        return image.convert('L')
    except Exception as e:
        logger.error(f"Image preprocess failed: {str(e)}")
        return image
# Clean text (optimized)
def clean_text(text):
    try:
        if not text:
            return ""
        text = unicodedata.normalize('NFKC', text)
        text = re.sub(r'\s+', ' ', text.strip())
        return text[:512]
    except Exception as e:
        logger.error(f"Text cleaning failed: {str(e)}")
        return ""
# Validate file
def validate_file(file_path):
    ext = os.path.splitext(file_path)[1].lower()
    if ext not in ['.pdf', '.docx', '.png', '.jpg', '.jpeg', '.csv', '.xls', '.xlsx']:
        return False, f"Unsupported file type: {ext}"
    if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
        return False, f"File not found or empty: {file_path}"
    return True, None
# Extract text (async)
async def extract_text_async(file_path):
    is_valid, error = validate_file(file_path)
    if not is_valid:
        return None, error
    ext = os.path.splitext(file_path)[1].lower()
    try:
        if ext == '.pdf':
            # Only the first page is sampled, to keep latency low
            with open(file_path, 'rb') as f:
                pdf_reader = PyPDF2.PdfReader(f)
                text = "".join([p.extract_text() or "" for p in pdf_reader.pages[:1]])
            # Fall back to OCR for scanned PDFs with little or no text layer
            if not text or len(text.strip()) < 50:
                images = convert_from_path(file_path, dpi=100, first_page=1, last_page=1, thread_count=2)
                text = pytesseract.image_to_string(preprocess_image(images[0]), config='--psm 6')
            logger.info(f"Extracted text: {text[:100]}...")
        elif ext == '.docx':
            doc = Document(file_path)
            text = "\n".join([p.text for p in doc.paragraphs if p.text.strip()][:25])
        elif ext in ['.png', '.jpg', '.jpeg']:
            img = preprocess_image(Image.open(file_path))
            text = pytesseract.image_to_string(img, config='--psm 6')
        elif ext in ['.csv', '.xls', '.xlsx']:
            df = pd.read_csv(file_path, encoding='utf-8') if ext == '.csv' else pd.read_excel(file_path)
            text = " ".join(df.astype(str).values.flatten())[:500]
        text = clean_text(text)
        if not text or len(text) < 50:
            return None, f"No valid text extracted from {file_path}"
        return text, None
    except Exception as e:
        logger.error(f"Text extraction failed: {str(e)} with file {file_path}")
        return None, f"Text extraction failed: {str(e)}"
# Parse dates: the first explicit ISO date is the start; a stated term in
# years (approximated as 365 days each) gives the end
def parse_dates(text):
    ist = timezone('Asia/Kolkata')
    current_date = datetime.now(ist).strftime('%Y-%m-%d')
    try:
        date_pattern = r'\b\d{4}-\d{2}-\d{2}\b'
        term_pattern = r'(?:term|duration)\s*(?:of|for)\s*(\d+)\s*(?:year|years)'
        parsed_dates = re.findall(date_pattern, text)
        term_match = re.search(term_pattern, text, re.IGNORECASE)
        term_years = int(term_match.group(1)) if term_match else 1
        start_date = parsed_dates[0] if parsed_dates else current_date
        # strptime also validates the extracted date; a malformed match falls
        # through to the current-date fallback below
        end_date = (datetime.strptime(start_date, '%Y-%m-%d') + timedelta(days=term_years * 365)).strftime('%Y-%m-%d')
        logger.info(f"Parsed dates - Start: {start_date}, End: {end_date}")
        return start_date, end_date
    except Exception as e:
        logger.error(f"Date parsing failed: {str(e)} with text {text[:50]}...")
        return current_date, current_date
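# Worked example: for text containing "effective 2024-01-01 ... term of 2 years",
# start_date is 2024-01-01 and end_date is 2025-12-31 (730 days later, since
# each year of the term counts as 365 days).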
# Summarize contract (async). NOTE: the preloaded summarizer pipeline is
# accepted but not invoked here; aspect summaries come from regex heuristics.
async def summarize_contract_async(text, summarizer, file_name):
    aspects = ["parties", "payment terms", "obligations", "termination clauses"]
    try:
        if not text or len(text.strip()) < 50:
            ist = timezone('Asia/Kolkata')
            current_date = datetime.now(ist).strftime('%Y-%m-%d')
            return {
                "full_summary": "No summary due to insufficient text",
                "aspect_summaries": {asp: "Not extracted" for asp in aspects},
                "start_date": current_date,
                "end_date": current_date
            }, None
        text = clean_text(text)[:512]
        aspect_summaries = {}
        for asp in aspects:
            if asp == "parties":
                match = re.search(r'(?:parties|between)\s+([A-Za-z\s&]+?)(?:\sand|,|\.)', text, re.IGNORECASE)
                aspect_summaries[asp] = match.group(1).strip()[:100] if match else "Not extracted"
            elif asp == "payment terms":
                match = re.search(r'(?:payment|terms)\s+([\d,.]+\s*(?:EUR|USD|INR)\s*(?:monthly|annually|quarterly))', text, re.IGNORECASE)
                aspect_summaries[asp] = match.group(1)[:100] if match else "Not extracted"
            elif asp == "obligations":
                match = re.search(r'(?:obligations|services|duties)\s+(.+?)(?:\bby\b|,|\.)', text, re.IGNORECASE)
                aspect_summaries[asp] = match.group(1).strip()[:100] if match else "Not extracted"
            elif asp == "termination clauses":
                match = re.search(r'(?:termination|notice)\s+(\d+\s*days\'?\s*notice)', text, re.IGNORECASE)
                aspect_summaries[asp] = match.group(1)[:100] if match else "Not extracted"
        # Custom summary template
        parties = aspect_summaries.get("parties", "Not extracted")
        obligations = aspect_summaries.get("obligations", "Not extracted")
        if parties != "Not extracted" and obligations != "Not extracted":
            full_summary = f"Logistics agreement between {parties} for {obligations}..."
        else:
            full_summary = text[:60] + "..."
        logger.info(f"Final summary: {full_summary}")
        start_date, end_date = parse_dates(text)
        return {
            "full_summary": full_summary,
            "aspect_summaries": aspect_summaries,
            "start_date": start_date,
            "end_date": end_date
        }, None
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)} with text {text[:50]}...")
        ist = timezone('Asia/Kolkata')
        current_date = datetime.now(ist).strftime('%Y-%m-%d')
        return {
            "full_summary": text[:60] + "..." if len(text) > 60 else text,
            "aspect_summaries": {asp: "Not extracted" for asp in aspects},
            "start_date": current_date,
            "end_date": current_date
        }, f"Summarization error: {str(e)}"
# Create Contract Document (async); reuses an existing record with the same name
async def create_contract_document(sf, file_name):
    # Salesforce datetime fields expect UTC, so stamp with an actual UTC time
    current_time = datetime.now(timezone('UTC')).strftime('%Y-%m-%dT%H:%M:%SZ')
    try:
        escaped_file_name = file_name.replace("'", "\\'")
        query = f"SELECT Id FROM Contract_Document__c WHERE Name = '{escaped_file_name}' LIMIT 1"
        result = await asyncio.get_running_loop().run_in_executor(None, sf.query, query)
        if result['totalSize'] > 0:
            return result['records'][0]['Id'], None
        record = {
            'Name': file_name,
            'Document_URL__c': '',
            'Upload_Date__c': current_time,
            'Status__c': 'Uploaded'
        }
        result = await asyncio.get_running_loop().run_in_executor(None, sf.Contract_Document__c.create, record)
        return result['id'], None
    except Exception as e:
        logger.error(f"Contract document creation failed: {str(e)}")
        return None, f"Contract document creation failed: {str(e)}"
# Store summary in Salesforce (async)
async def store_in_salesforce(sf, summary_data, file_name, contract_doc_id):
    try:
        if not contract_doc_id:
            return None, "Contract document ID is missing"
        query = f"SELECT Id FROM Contract_Summary__c WHERE Contract_Document__c = '{contract_doc_id}' LIMIT 1"
        result = await asyncio.get_running_loop().run_in_executor(None, sf.query, query)
        if result['totalSize'] > 0:
            return {'id': result['records'][0]['Id']}, None
        encrypted_summary = _fernet.encrypt(summary_data['full_summary'].encode()).decode()
        def truncate(text, length=100):
            return text[:length] if text else 'Not extracted'
        record = {
            'Name': file_name,
            'Contract_Document__c': contract_doc_id,
            'Parties__c': truncate(summary_data['aspect_summaries'].get('parties', 'Not extracted')),
            'Payment_Terms__c': truncate(summary_data['aspect_summaries'].get('payment terms', 'Not extracted')),
            'Obligations__c': truncate(summary_data['aspect_summaries'].get('obligations', 'Not extracted')),
            'Termination_Clause__c': truncate(summary_data['aspect_summaries'].get('termination clauses', 'Not extracted')),
            'Custom_Field_1__c': encrypted_summary,
            'Validation_Status__c': 'Pending',
            'Start_Date__c': summary_data['start_date'][:10],
            'End_Date__c': summary_data['end_date'][:10],
        }
        result = await asyncio.get_running_loop().run_in_executor(None, sf.Contract_Summary__c.create, record)
        return result, None
    except Exception as e:
        logger.error(f"Store summary failed: {str(e)}")
        return None, f"Store summary failed: {str(e)}"
# Generate CSV report (async)
async def generate_report(sf, output_file, contract_doc_id):
    try:
        if not contract_doc_id:
            return pd.DataFrame(columns=['Field', 'Value']), "Contract document ID is missing"
        query = (
            f"SELECT Id, Name, Parties__c, Payment_Terms__c, Obligations__c, Termination_Clause__c, Custom_Field_1__c, "
            f"Validation_Status__c, Start_Date__c, End_Date__c "
            f"FROM Contract_Summary__c WHERE Contract_Document__c = '{contract_doc_id}' LIMIT 1"
        )
        results = (await asyncio.get_running_loop().run_in_executor(None, sf.query, query))['records']
        rows = []
        for r in results:
            decrypted_summary = _fernet.decrypt(r['Custom_Field_1__c'].encode()).decode() if r.get('Custom_Field_1__c') else 'Not extracted'
            # Salesforce returns None (not a missing key) for empty fields,
            # so guard with `or` before slicing
            fields = [
                ('Contract Name', r.get('Name') or 'Not extracted'),
                ('Parties', (r.get('Parties__c') or 'Not extracted')[:100]),
                ('Payment Terms', (r.get('Payment_Terms__c') or 'Not extracted')[:100]),
                ('Obligations', (r.get('Obligations__c') or 'Not extracted')[:100]),
                ('Termination Clause', (r.get('Termination_Clause__c') or 'Not extracted')[:100]),
                ('Full Summary', decrypted_summary[:100]),
                ('Validation Status', r.get('Validation_Status__c') or 'Not extracted'),
                ('Start Date', r.get('Start_Date__c') or 'Not extracted'),
                ('End Date', r.get('End_Date__c') or 'Not extracted'),
            ]
            rows.extend(fields)
        # Create DataFrame without the "Summary Report" header row
        df = pd.DataFrame(rows, columns=['Field', 'Value']) if rows else pd.DataFrame(columns=['Field', 'Value'])
        df.to_csv(output_file, index=False, encoding='utf-8')
        return df, output_file
    except Exception as e:
        logger.error(f"Report generation failed: {str(e)}")
        return pd.DataFrame(columns=['Field', 'Value']), f"Report generation failed: {str(e)}"
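# The resulting CSV is a two-column key/value sheet, e.g. (values illustrative):
#   Field,Value
#   Contract Name,sample_contract.pdf
#   Parties,Acme Logistics
#   ...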
# Gradio interface function (async)
async def gradio_process_async(file, progress=gr.Progress()):
    empty = pd.DataFrame(columns=['Field', 'Value'])
    try:
        if not file:
            return empty, None
        file_path = file.name if hasattr(file, 'name') else file
        file_name = os.path.basename(file_path)
        progress(0.1, desc="Validating...")
        is_valid, error = validate_file(file_path)
        if not is_valid:
            return empty, None
        progress(0.2, desc="Extracting text...")
        text, error = await extract_text_async(file_path)
        if error:
            return empty, None
        progress(0.4, desc="Initializing...")
        sf = await init_salesforce()
        progress(0.5, desc="Summarizing...")
        summary_data, err = await summarize_contract_async(text, _summarizer, file_name)
        if err:
            return empty, None
        progress(0.7, desc="Storing in Salesforce...")
        contract_doc_id, err = await create_contract_document(sf, file_name)
        if err or not contract_doc_id:
            return empty, None
        store_result, err = await store_in_salesforce(sf, summary_data, file_name, contract_doc_id)
        if err:
            return empty, None
        progress(0.9, desc="Generating report...")
        base_name = os.path.splitext(file_name)[0]
        csv_path = os.path.join(tempfile.gettempdir(), f"contract_summary_{base_name}.csv")
        report_df, csv_path = await generate_report(sf, csv_path, contract_doc_id)
        # On failure generate_report returns an error string in the path slot
        if not csv_path or not os.path.isfile(csv_path):
            return empty, None
        progress(1.0, desc="Complete!")
        return report_df, csv_path
    except Exception as e:
        logger.error(f"Processing error for {file_name if 'file_name' in locals() else 'file'}: {str(e)} at {datetime.now(timezone('Asia/Kolkata')).strftime('%Y-%m-%d %H:%M:%S')}")
        return pd.DataFrame(columns=['Field', 'Value']), None
# Gradio UI setup
with gr.Blocks(theme="soft", css="""
    .gr-button {
        background-color: #6A5ACD;
        color: white;
        font-weight: bold;
        font-size: 16px;
        border: none;
        padding: 5px 20px;
    }
    .gr-button:hover {
        background-color: #5A4ABF;
        color: white;
        font-weight: bold;
        font-size: 16px;
    }
    .gr-label {
        color: #6A5ACD;
        font-weight: bold;
        font-size: 16px;
        background-color: #F0F0FF;
        padding: 5px;
    }
    .gr-textbox { border: 1px solid #6A5ACD; background-color: white; }
    .gr-file { border: 1px solid #6A5ACD; background-color: white; }
    .gr-dataframe {
        border: 1px solid #6A5ACD;
        background-color: white;
    }
    .gr-dataframe td, .gr-dataframe th {
        color: #6A5ACD;
        font-weight: bold;
        font-size: 16px;
    }
    #summary-report {
        color: #6A5ACD;
        font-weight: bold;
        font-size: 16px;
        background-color: #F0F0FF;
        padding: 5px;
    }
    .gr-dataframe tr:first-child td {
        background-color: #F0F0FF;
        color: #6A5ACD;
        font-weight: bold;
        font-size: 16px;
        padding: 5px;
    }
""") as iface:
    file_input = gr.File(label="Upload Contract (PDF, DOCX, PNG/JPG, CSV, XLS/XLSX)")
    submit_btn = gr.Button("Submit")
    report_output = gr.DataFrame(label="Summary Report", headers=['Field', 'Value'], interactive=False, elem_id="summary-report")
    csv_output = gr.File(label="Download CSV")
    submit_btn.click(
        fn=gradio_process_async,
        inputs=[file_input],
        outputs=[report_output, csv_output]
    )
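# Note: with Gradio 3.x, gr.Progress updates require the request queue to be
# enabled (call iface.queue() before launch()); Gradio 4.x queues by default.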
if __name__ == "__main__":
    logger.info("Application startup at %s", datetime.now(timezone('Asia/Kolkata')).strftime('%Y-%m-%d %H:%M:%S'))
    iface.launch()