import os
import re
import json
import logging
import tempfile

import torch
import streamlit as st
from typing import Dict, Any, List
from PIL import Image

# Suppress verbose backend logs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

try:
    from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel
    from huggingface_hub import login
    import google.generativeai as genai
except ImportError as e:
    logging.error(f"Failed to import critical AI libraries: {e}")
    raise

from config import (
    HF_TOKEN,
    HF_MODELS,
    GOOGLE_API_KEY,
    GOOGLE_APPLICATION_CREDENTIALS,
    GEMINI_MODEL_NAME,
    DEVICE,
    USE_GPU,
)

# Configure logging & Hugging Face auth
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
if HF_TOKEN:
    login(token=HF_TOKEN)


class PrescriptionProcessor:
    """Encapsulates the entire hybrid pipeline in one object (resolving the
    earlier 'self' error: all model state lives on the instance)."""

    def __init__(self):
        self.model_cache = {}
        self.temp_cred_file = None
        self._load_all_models()

    def _load_model(self, name: str):
        """Load a Hugging Face model into the cache, skipping if already present."""
        if name in self.model_cache:
            return
        model_id = HF_MODELS.get(name)
        if not model_id:
            return
        logging.info(f"Loading model '{name}' ({model_id}) to device '{DEVICE}'...")
        try:
            # 8-bit quantization only applies on GPU (requires bitsandbytes).
            quantization_config = {"load_in_8bit": True} if USE_GPU else {}
            if name == "donut":
                processor = DonutProcessor.from_pretrained(model_id)
                model = VisionEncoderDecoderModel.from_pretrained(model_id, **quantization_config)
                self.model_cache[name] = {"model": model, "processor": processor}
            elif name == "phi3":
                # Quantization kwargs must be routed through `model_kwargs` so the
                # pipeline forwards them to `from_pretrained` rather than treating
                # them as generation parameters.
                model = pipeline(
                    "text-generation",
                    model=model_id,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    model_kwargs=quantization_config,
                )
                self.model_cache[name] = {"model": model}
            logging.info(f"Model '{name}' loaded successfully.")
        except Exception as e:
            logging.error(f"Failed to load model '{name}': {e}", exc_info=True)

    def _load_gemini_client(self):
        if "gemini" in self.model_cache:
            return
        # If service-account credentials arrive as a JSON string, write them to a
        # temp file and point the standard env var at it.
        if creds_json_str := GOOGLE_APPLICATION_CREDENTIALS:
            if not self.temp_cred_file or not os.path.exists(self.temp_cred_file):
                with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tfp:
                    tfp.write(creds_json_str)
                    self.temp_cred_file = tfp.name
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.temp_cred_file
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            self.model_cache["gemini"] = genai.GenerativeModel(GEMINI_MODEL_NAME)
        except Exception as e:
            logging.error(f"Gemini init failed: {e}")

    def _load_all_models(self):
        self._load_model("donut")
        self._load_model("phi3")
        self._load_gemini_client()

    def _run_donut(self, image: Image.Image) -> Dict[str, Any]:
        components = self.model_cache.get("donut")
        if not components:
            return {"error": "Donut model not available."}
        model, processor = components["model"].to(DEVICE), components["processor"]
        # Task prompt for a CORD-v2-style fine-tune (the downstream code reads the
        # CORD 'menu' field); adjust if a different checkpoint is used.
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(DEVICE)
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # Strip the leading task-start token before converting to JSON.
        sequence = re.sub(r"<.*?>", "", sequence, count=1)
        return processor.token2json(sequence)

    def _run_phi3(self, medication_text: str) -> str:
        components = self.model_cache.get("phi3")
        if not components:
            return medication_text
        pipe = components["model"]
        # End the prompt with the "Normalized:" anchor so the split below isolates
        # the model's completion from the echoed prompt.
        prompt = (
            "Normalize the following prescription medication line into its "
            f"components (drug, dosage, frequency). Raw text: '{medication_text}'\n"
            "Normalized:"
        )
        outputs = pipe(prompt, max_new_tokens=100, do_sample=False)
        return outputs[0]["generated_text"].split("Normalized:")[-1].strip()

    def _run_gemini_resolver(self, image: Image.Image, donut_result: Dict, phi3_results: List[str]) -> Dict[str, Any]:
        gemini_client = self.model_cache.get("gemini")
        if not gemini_client:
            return {"error": "Gemini resolver not available."}
        # The elided prompt below is expected to interpolate donut_result and
        # phi3_results so Gemini can arbitrate between the two model outputs.
        prompt = f"""
        You are an expert pharmacist's assistant...
        (Your detailed prompt from the previous turn goes here)
        ...
        **Final JSON Schema**
        ```json
        {{
          "Name": "string or null",
          "Date": "string (MM/DD/YYYY) or null",
          "Age": "string or null",
          "PhysicianName": "string or null",
          "Medications": [{{"drug_raw": "string", "dosage": "string or null", "frequency": "string or null"}}]
        }}
        ```
        """
        try:
            response = gemini_client.generate_content(
                [prompt, image],
                generation_config={"response_mime_type": "application/json"},
            )
            return json.loads(response.text)
        except Exception as e:
            logging.error(f"Gemini resolver failed: {e}")
            # This is where the original error was being generated from.
            return {"error": f"Gemini failed to resolve data: {e}"}

    def process(self, image_path: str) -> Dict[str, Any]:
        try:
            image = Image.open(image_path).convert("RGB")
            donut_data = self._run_donut(image)
            med_lines = [
                item.get("text", "")
                for item in donut_data.get("menu", [])
                if "medi" in item.get("category", "").lower()
            ]
            phi3_refined_meds = [self._run_phi3(line) for line in med_lines]
            final_info = self._run_gemini_resolver(image, donut_data, phi3_refined_meds)
            if final_info.get("error"):
                return final_info
            return {
                "info": final_info,
                "error": None,
                "debug_info": {"donut_output": donut_data, "phi3_refinements": phi3_refined_meds},
            }
        except Exception as e:
            logging.error(f"Hybrid extraction pipeline failed: {e}", exc_info=True)
            return {"error": f"An unexpected error occurred in the pipeline: {e}"}


@st.cache_resource
def get_processor() -> PrescriptionProcessor:
    return PrescriptionProcessor()


def extract_prescription_info(image_path: str) -> Dict[str, Any]:
    processor = get_processor()
    return processor.process(image_path)
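

# --- Usage sketch (illustrative, not part of the module) ---------------------
# A minimal way to exercise the pipeline outside Streamlit. The image path
# "sample_prescription.jpg" is a hypothetical file, and running this will load
# the full model stack; st.cache_resource degrades gracefully (with a warning)
# when no Streamlit runtime is active.
if __name__ == "__main__":
    result = extract_prescription_info("sample_prescription.jpg")
    if result.get("error"):
        print(f"Extraction failed: {result['error']}")
    else:
        print(json.dumps(result["info"], indent=2))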