import os
import re
import json
import logging
import tempfile
import torch
import streamlit as st
from typing import Dict, Any, List
from PIL import Image

# Suppress verbose backend logs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

try:
    from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel
    from huggingface_hub import login
    import google.generativeai as genai
except ImportError as e:
    logging.error(f"Failed to import critical AI libraries: {e}")
    raise

from config import (
    HF_TOKEN, HF_MODELS, GOOGLE_API_KEY,
    GOOGLE_APPLICATION_CREDENTIALS, GEMINI_MODEL_NAME, DEVICE, USE_GPU
)

# Configure Logging & Auth
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
if HF_TOKEN: login(token=HF_TOKEN)

class PrescriptionProcessor:
    """Encapsulates the entire hybrid pipeline to resolve the 'self' error."""
    def __init__(self):
        self.model_cache = {}
        self.temp_cred_file = None
        self._load_all_models()

    def _load_model(self, name: str):
        if name in self.model_cache: return
        model_id = HF_MODELS.get(name)
        if not model_id: return

        logging.info(f"Loading model '{name}' ({model_id}) to device '{DEVICE}'...")
        try:
            # 8-bit loading flags must reach the model loader: unpacked for
            # from_pretrained, but via model_kwargs for pipeline().
            model_kwargs = {"load_in_8bit": True} if USE_GPU else {}
            if name == "donut":
                processor = DonutProcessor.from_pretrained(model_id)
                model = VisionEncoderDecoderModel.from_pretrained(model_id, **model_kwargs)
                self.model_cache[name] = {"model": model, "processor": processor}
            elif name == "phi3":
                model = pipeline(
                    "text-generation", model=model_id, torch_dtype=torch.bfloat16,
                    trust_remote_code=True, model_kwargs=model_kwargs,
                )
                self.model_cache[name] = {"model": model}
            logging.info(f"Model '{name}' loaded successfully.")
        except Exception as e:
            logging.error(f"Failed to load model '{name}': {e}", exc_info=True)

    def _load_gemini_client(self):
        if "gemini" in self.model_cache: return
        if creds_json_str := GOOGLE_APPLICATION_CREDENTIALS:
            if not self.temp_cred_file or not os.path.exists(self.temp_cred_file):
                with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tfp:
                    tfp.write(creds_json_str)
                    self.temp_cred_file = tfp.name
                os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.temp_cred_file
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            self.model_cache["gemini"] = genai.GenerativeModel(GEMINI_MODEL_NAME)
        except Exception as e:
            logging.error(f"Gemini init failed: {e}")

    def _load_all_models(self):
        self._load_model("donut")
        self._load_model("phi3")
        self._load_gemini_client()

    def _run_donut(self, image: Image.Image) -> Dict[str, Any]:
        components = self.model_cache.get("donut")
        if not components: return {"error": "Donut model not available."}
        model, processor = components["model"], components["processor"]
        # 8-bit quantized weights are placed on the GPU at load time and cannot be moved.
        if not getattr(model, "is_loaded_in_8bit", False):
            model = model.to(DEVICE)
        task_prompt = "<s_cord-v2>"  # task token for the CORD-v2 fine-tune
        decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(DEVICE)
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)
        outputs = model.generate(
            pixel_values, decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True, use_cache=True, num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
        sequence = processor.batch_decode(outputs.sequences)[0].replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # strip the leading task-start token
        return processor.token2json(sequence)

    def _run_phi3(self, medication_text: str) -> str:
        components = self.model_cache.get("phi3")
        if not components: return medication_text
        pipe = components["model"]
        # Ending the prompt with "Normalized:" gives the split below a
        # reliable anchor for isolating the model's answer.
        prompt = (
            "Normalize the following prescription medication line into its "
            f"components (drug, dosage, frequency). Raw text: '{medication_text}'\n"
            "Normalized:"
        )
        outputs = pipe(prompt, max_new_tokens=100, do_sample=False)
        return outputs[0]['generated_text'].split("Normalized:")[-1].strip()

    def _run_gemini_resolver(self, image: Image.Image, donut_result: Dict, phi3_results: List[str]) -> Dict[str, Any]:
        gemini_client = self.model_cache.get("gemini")
        if not gemini_client: return {"error": "Gemini resolver not available."}
        prompt = f"""
        You are an expert pharmacist’s assistant...
        (Your detailed prompt from the previous turn goes here)
        ...
        **Final JSON Schema**
        ```json
        {{
          "Name": "string or null", "Date": "string (MM/DD/YYYY) or null", "Age": "string or null", "PhysicianName": "string or null",
          "Medications": [{{"drug_raw": "string", "dosage": "string or null", "frequency": "string or null"}}]
        }}
        ```
        """
        try:
            response = gemini_client.generate_content([prompt, image], generation_config={"response_mime_type": "application/json"})
            return json.loads(response.text)
        except Exception as e:
            logging.error(f"Gemini resolver failed: {e}")
            # Surface the failure to the caller instead of raising.
            return {"error": f"Gemini failed to resolve data: {e}"}

    def process(self, image_path: str) -> Dict[str, Any]:
        try:
            image = Image.open(image_path).convert("RGB")
            donut_data = self._run_donut(image)
            menu = donut_data.get('menu', []) if isinstance(donut_data, dict) else []
            if isinstance(menu, dict): menu = [menu]  # token2json returns a dict for a single entry
            med_lines = [item.get('text', '') for item in menu
                         if isinstance(item, dict) and 'medi' in item.get('category', '').lower()]
            phi3_refined_meds = [self._run_phi3(line) for line in med_lines]
            final_info = self._run_gemini_resolver(image, donut_data, phi3_refined_meds)
            if final_info.get("error"): return final_info
            return {"info": final_info, "error": None, "debug_info": {"donut_output": donut_data, "phi3_refinements": phi3_refined_meds}}
        except Exception as e:
            logging.error(f"Hybrid extraction pipeline failed: {e}", exc_info=True)
            return {"error": f"An unexpected error occurred in the pipeline: {e}"}

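# st.cache_resource keeps a single PrescriptionProcessor per Streamlit
# process, so the heavy model loads in __init__ run once rather than on
# every script rerun.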
@st.cache_resource
def get_processor():
    return PrescriptionProcessor()

def extract_prescription_info(image_path: str) -> Dict[str, Any]:
    processor = get_processor()
    return processor.process(image_path)
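
# Minimal usage sketch. The image path below is a hypothetical example,
# not part of the original module.
if __name__ == "__main__":
    result = extract_prescription_info("sample_prescription.png")
    if result.get("error"):
        print(f"Extraction failed: {result['error']}")
    else:
        print(json.dumps(result["info"], indent=2))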