import atexit
import functools
import os
import re
import tempfile
import threading
from queue import Queue
from threading import Event, Thread

from flask import Flask, request, jsonify
from paddleocr import PaddleOCR
from PIL import Image

# --- NEW: Import the NLP analysis function ---
from nlp_service import analyze_text

# --- Configuration ---
LANG = 'en'       # Default language, can be overridden if needed
NUM_WORKERS = 2   # Number of OCR worker threads
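
# Note (informational): on the first run PaddleOCR typically downloads the
# detection, classification, and recognition models for LANG, so the initial
# startup of the workers below can take noticeably longer than later restarts.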

# --- PaddleOCR Model Manager ---
class PaddleOCRModelManager(object):
    def __init__(self,
                 num_workers,
                 model_factory):
        super().__init__()
        self._model_factory = model_factory
        self._queue = Queue()
        self._workers = []
        self._model_initialized_event = Event()
        print(f"Initializing {num_workers} OCR worker(s)...")
        for i in range(num_workers):
            print(f"Starting worker {i+1}...")
            worker = Thread(target=self._worker, daemon=True)
            worker.start()
            self._model_initialized_event.wait()  # Wait for this worker's model to load
            self._model_initialized_event.clear()
            self._workers.append(worker)
        print("All OCR workers initialized.")
    def infer(self, *args, **kwargs):
        result_queue = Queue(maxsize=1)
        self._queue.put((args, kwargs, result_queue))
        success, payload = result_queue.get()
        if success:
            return payload
        else:
            print(f"Error during OCR inference: {payload}")
            raise payload

    def close(self):
        print("Shutting down OCR workers...")
        for _ in self._workers:
            self._queue.put(None)
        print("OCR worker shutdown signaled.")
    def _worker(self):
        print(f"Worker thread {threading.current_thread().name}: Loading PaddleOCR model ({LANG})...")
        try:
            model = self._model_factory()
            print(f"Worker thread {threading.current_thread().name}: Model loaded.")
            self._model_initialized_event.set()
        except Exception as e:
            print(f"FATAL: Worker thread {threading.current_thread().name} failed to load model: {e}")
            self._model_initialized_event.set()
            return
        while True:
            item = self._queue.get()
            if item is None:
                print(f"Worker thread {threading.current_thread().name}: Exiting.")
                break
            args, kwargs, result_queue = item
            try:
                result = model.ocr(*args, **kwargs)
                if result and result[0]:
                    result_queue.put((True, result[0]))
                else:
                    result_queue.put((True, []))
            except Exception as e:
                print(f"Worker thread {threading.current_thread().name}: Error processing request: {e}")
                result_queue.put((False, e))
            finally:
                self._queue.task_done()
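
# Usage sketch (illustrative, mirrors the call in the endpoint further below):
# request handlers submit work through the manager instead of touching a model
# directly, e.g.
#
#     lines = ocr_manager.infer("/tmp/receipt.jpg", cls=True)
#
# infer() blocks until one of the worker threads has run model.ocr(...) and put
# either the first page's line list or an empty list onto the per-call result
# queue; exceptions raised in the worker are re-raised in the calling thread.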

# --- Amount Extraction Logic ---
def find_main_amount(ocr_results):
    if not ocr_results:
        return None

    # Matches currency-like amounts ("1,234.56", "12.50") and bare integers,
    # while skipping numbers preceded by '%' and integers followed by ".<digit>".
    amount_regex = re.compile(r'(?<!%)\b\d{1,3}(?:,?\d{3})*(?:\.\d{2})\b|\b\d+\.\d{2}\b|\b\d+\b(?!\.\d{1})')

    # Prioritized keywords
    priority_keywords = ['grand total', 'total amount', 'amount due', 'to pay', 'bill total', 'total payable']
    secondary_keywords = ['total', 'balance', 'net amount', 'paid', 'charge', 'net total']
    lower_priority_keywords = ['subtotal', 'sub total']

    parsed_lines = []
    for i, line_info in enumerate(ocr_results):
        if not line_info or len(line_info) < 2 or len(line_info[1]) < 1:
            continue
        text = line_info[1][0].lower().strip()
        confidence = line_info[1][1]
        numbers_in_line = amount_regex.findall(text)
        float_numbers = []
        for num_str in numbers_in_line:
            try:
                # Avoid converting year-like numbers if they stand alone on short lines
                if len(text) < 7 and '.' not in num_str and 1900 < int(num_str.replace(',', '')) < 2100:
                    # Skip it if it is the only number on the line and looks like a year
                    if len(numbers_in_line) == 1 and len(num_str) == 4:
                        continue
                float_numbers.append(float(num_str.replace(',', '')))
            except ValueError:
                continue

        # Check for keywords
        has_priority_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in priority_keywords)
        has_secondary_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in secondary_keywords)
        has_lower_priority_keyword = any(re.search(r'\b' + re.escape(kw) + r'\b', text) for kw in lower_priority_keywords)

        parsed_lines.append({
            "index": i,
            "text": text,
            "numbers": float_numbers,
            "has_priority_keyword": has_priority_keyword,
            "has_secondary_keyword": has_secondary_keyword,
            "has_lower_priority_keyword": has_lower_priority_keyword,
            "confidence": confidence
        })

    # --- Strategy to find the best candidate ---

    # 1. Look for numbers on the SAME line as PRIORITY keywords, or on the line IMMEDIATELY AFTER
    priority_candidates = []
    for i, line in enumerate(parsed_lines):
        if line["has_priority_keyword"]:
            if line["numbers"]:
                priority_candidates.extend(line["numbers"])
            # Check the next line if the current line has no numbers
            elif i + 1 < len(parsed_lines) and parsed_lines[i+1]["numbers"]:
                priority_candidates.extend(parsed_lines[i+1]["numbers"])
    if priority_candidates:
        # The largest number on/near these lines is usually the final total
        return max(priority_candidates)

    # 2. Look for numbers on the SAME line as SECONDARY keywords, or on the line IMMEDIATELY AFTER
    secondary_candidates = []
    for i, line in enumerate(parsed_lines):
        if line["has_secondary_keyword"]:
            if line["numbers"]:
                secondary_candidates.extend(line["numbers"])
            # Check the next line if the current line has no numbers
            elif i + 1 < len(parsed_lines) and parsed_lines[i+1]["numbers"]:
                secondary_candidates.extend(parsed_lines[i+1]["numbers"])
    if secondary_candidates:
        # Only secondary keywords were found: return the largest number on/near those lines
        return max(secondary_candidates)

    # 3. (Removed) Proximity search around keywords; less reliable and covered by steps 1 and 2

    # 4. Look for numbers on the SAME line as LOWER PRIORITY keywords (subtotal), or on the line IMMEDIATELY AFTER
    lower_priority_candidates = []
    for i, line in enumerate(parsed_lines):
        if line["has_lower_priority_keyword"]:
            if line["numbers"]:
                lower_priority_candidates.extend(line["numbers"])
            # Check the next line if the current line has no numbers
            elif i + 1 < len(parsed_lines) and parsed_lines[i+1]["numbers"]:
                lower_priority_candidates.extend(parsed_lines[i+1]["numbers"])
    # Subtotals are not returned directly; they only feed the fallback below

    # 5. Fallback: largest plausible number overall (excluding subtotals if other numbers exist)
    print("Warning: No numbers found on/near priority/secondary keyword lines. Using fallback.")
    all_numbers = []
    # Collect subtotal amounts so the fallback can avoid preferring them
    subtotal_numbers = {num for line in parsed_lines if line["has_lower_priority_keyword"] for num in line["numbers"]}
    # Also treat numbers on the line after a lower-priority keyword as subtotals
    for i, line in enumerate(parsed_lines):
        if line["has_lower_priority_keyword"] and not line["numbers"] and i + 1 < len(parsed_lines):
            subtotal_numbers.update(parsed_lines[i+1]["numbers"])
    for line in parsed_lines:
        all_numbers.extend(line["numbers"])
    if all_numbers:
        unique_numbers = list(set(all_numbers))
        # Filter out potential quantities/years/small irrelevant numbers
        plausible_numbers = [n for n in unique_numbers if n >= 0.01]  # Keep small decimals too
        # Stricter filter for large numbers: exclude large integers (likely IDs or phone numbers);
        # keep numbers < 50000 or numbers with a non-zero fractional part
        plausible_numbers = [n for n in plausible_numbers if n < 50000 or (n != int(n))]
        # If we have plausible numbers other than subtotals, prefer them
        non_subtotal_plausible = [n for n in plausible_numbers if n not in subtotal_numbers]
        if non_subtotal_plausible:
            return max(non_subtotal_plausible)
        elif plausible_numbers:
            # Only subtotals (or nothing else plausible) were found; return the largest as a last resort
            return max(plausible_numbers)

    # 6. If still nothing, return None
    print("Warning: Could not determine main amount.")
    return None
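
# Illustrative example (synthetic data): for a receipt whose OCR lines read
#   "burger 2 x 5.00", "subtotal", "10.00", "tax 1.25", "grand total 11.25"
# step 1 matches the priority keyword "grand total" on the last line and the
# function returns 11.25; the subtotal on the earlier lines is never preferred.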

# --- Flask App Setup ---
app = Flask(__name__)
# The NLP blueprint is no longer registered; analyze_text is called directly
# inside the process_message handler below.

# --- Initialize OCR Manager ---
ocr_model_factory = functools.partial(PaddleOCR, lang=LANG, use_angle_cls=True, use_gpu=False, show_log=False)
ocr_manager = PaddleOCRModelManager(num_workers=NUM_WORKERS, model_factory=ocr_model_factory)

# Register cleanup function
atexit.register(ocr_manager.close)
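
# Note: constructing the manager above blocks until every worker has loaded its
# model, so the app only starts serving requests once OCR is ready.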

# --- API Endpoint ---
# The route path below is an assumption for illustration; adjust it to match your client.
@app.route('/extract-expense', methods=['POST'])
def extract_expense():
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    if file:
        temp_file_path = None  # Initialize so the finally block can check it
        try:
            # Save the upload to a temporary file
            _, file_extension = os.path.splitext(file.filename)
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
                file.save(temp_file.name)
                temp_file_path = temp_file.name

            # Perform OCR
            ocr_result = ocr_manager.infer(temp_file_path, cls=True)

            # Process OCR results
            extracted_text = ""
            main_amount_ocr = None
            if ocr_result:
                extracted_lines = [line[1][0] for line in ocr_result if line and len(line) > 1 and len(line[1]) > 0]
                extracted_text = "\n".join(extracted_lines)
                main_amount_ocr = find_main_amount(ocr_result)

            # Construct the response (OCR results only; NLP analysis of free text
            # is handled by the separate process_message endpoint)
            response_data = {
                "type": "photo",
                "extracted_text": extracted_text,
                "main_amount_ocr": main_amount_ocr,  # Amount found by the OCR regex logic
            }
            return jsonify(response_data)
        except Exception as e:
            print(f"Error processing file: {e}")
            import traceback
            traceback.print_exc()
            return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500
        finally:
            if temp_file_path and os.path.exists(temp_file_path):
                os.remove(temp_file_path)
    return jsonify({"error": "File processing failed"}), 500
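
# Example request (illustrative; assumes the route path chosen above):
#   curl -X POST -F "file=@receipt.jpg" http://localhost:7860/extract-expense
# returns JSON shaped like:
#   {"type": "photo", "extracted_text": "...", "main_amount_ocr": 11.25}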

# --- NEW: NLP Message Endpoint ---
# The route path below is an assumption for illustration; adjust it to match your client.
@app.route('/process-message', methods=['POST'])
def process_message():
    data = request.get_json()
    if not data or 'text' not in data:
        return jsonify({"error": "Missing 'text' field in JSON payload"}), 400
    text_message = data['text']
    if not text_message:
        return jsonify({"error": "'text' field cannot be empty"}), 400

    nlp_analysis_result = None
    nlp_error = None
    try:
        # Call the imported analysis function
        nlp_analysis_result = analyze_text(text_message)
        print(f"NLP Service Analysis Result: {nlp_analysis_result}")

        # Check whether the NLP analysis reported a failure or requires a fallback
        status = nlp_analysis_result.get("status")
        if status == "failed":
            nlp_error = nlp_analysis_result.get("message", "NLP processing failed")
            # Return the failure result from the NLP service;
            # 400 signals a client-side problem such as unusable text
            return jsonify(nlp_analysis_result), 400
        elif status == "fallback_required":
            # Return the fallback result (e.g., for queries): 200, but the
            # payload tells the caller that a fallback is needed
            return jsonify(nlp_analysis_result), 200

        # Return the successful analysis result
        return jsonify(nlp_analysis_result)
    except Exception as nlp_e:
        nlp_error = f"Error calling NLP analysis function: {nlp_e}"
        print(f"Error calling NLP function: {nlp_error}")
        return jsonify({"error": "An internal error occurred during NLP processing", "details": nlp_error}), 500
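
# Example request (illustrative; assumes the route path chosen above):
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"text": "spent 250 on groceries"}' http://localhost:7860/process-message
# The response body is whatever analyze_text() returns, passed through as JSON.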

# --- NEW: Health Check Endpoint ---
@app.route('/health', methods=['GET'])  # Route path assumed for illustration
def health_check():
    # Further checks could be added here (e.g., whether the OCR workers are alive)
    return jsonify({"status": "ok"}), 200

# --- Run the App ---
if __name__ == '__main__':
    # Port 7860 is the port expected by Hugging Face Spaces;
    # host='0.0.0.0' makes the server reachable from inside Docker/Spaces
    app.run(host='0.0.0.0', port=7860, debug=False)