# HexAI_Demo/validate_prescription.py
import os
import re
import json
import logging
import tempfile
import torch
import streamlit as st
from typing import Dict, Any, List
from PIL import Image
# Suppress verbose backend logs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
try:
    from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel
    from huggingface_hub import login
    import google.generativeai as genai
except ImportError as e:
    logging.error(f"Failed to import critical AI libraries: {e}")
    raise
from config import (
    HF_TOKEN, HF_MODELS, GOOGLE_API_KEY,
    GOOGLE_APPLICATION_CREDENTIALS, GEMINI_MODEL_NAME, DEVICE, USE_GPU
)
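# Expected config.py contract (a sketch inferred from the import above; the
# example model IDs are assumptions, not the project's pinned choices):
#   HF_MODELS = {"donut": "naver-clova-ix/donut-base-finetuned-cord-v2",
#                "phi3": "microsoft/Phi-3-mini-4k-instruct"}
#   DEVICE = "cuda" if USE_GPU else "cpu"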
# Configure Logging & Auth
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
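# Logging in up front lets later from_pretrained() calls resolve gated/private repos.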
if HF_TOKEN:
    login(token=HF_TOKEN)
class PrescriptionProcessor:
    """Encapsulates the hybrid Donut + Phi-3 + Gemini prescription pipeline in one class."""

    def __init__(self):
        self.model_cache = {}
        self.temp_cred_file = None
        self._load_all_models()
    def _load_model(self, name: str):
        """Load and cache a Hugging Face model by its configured name ('donut' or 'phi3')."""
        if name in self.model_cache:
            return
        model_id = HF_MODELS.get(name)
        if not model_id:
            return
        logging.info(f"Loading model '{name}' ({model_id}) to device '{DEVICE}'...")
        try:
            # 8-bit quantization (requires bitsandbytes) is only useful on GPU.
            quantization_config = {"load_in_8bit": True} if USE_GPU else {}
            if name == "donut":
                processor = DonutProcessor.from_pretrained(model_id)
                model = VisionEncoderDecoderModel.from_pretrained(model_id, **quantization_config)
                self.model_cache[name] = {"model": model, "processor": processor}
            elif name == "phi3":
                # pipeline() takes model-loading kwargs via model_kwargs, not as top-level kwargs.
                model = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16,
                                 trust_remote_code=True, model_kwargs=quantization_config)
                self.model_cache[name] = {"model": model}
            logging.info(f"Model '{name}' loaded successfully.")
        except Exception as e:
            logging.error(f"Failed to load model '{name}': {e}", exc_info=True)
    def _load_gemini_client(self):
        """Configure the Gemini client, materializing service-account credentials if supplied."""
        if "gemini" in self.model_cache:
            return
        # Credentials arrive as a JSON string; write them to a temp file so the
        # Google SDK can find them through GOOGLE_APPLICATION_CREDENTIALS.
        if creds_json_str := GOOGLE_APPLICATION_CREDENTIALS:
            if not self.temp_cred_file or not os.path.exists(self.temp_cred_file):
                with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tfp:
                    tfp.write(creds_json_str)
                    self.temp_cred_file = tfp.name
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.temp_cred_file
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            self.model_cache["gemini"] = genai.GenerativeModel(GEMINI_MODEL_NAME)
        except Exception as e:
            logging.error(f"Gemini init failed: {e}")
    def _load_all_models(self):
        self._load_model("donut")
        self._load_model("phi3")
        self._load_gemini_client()
    def _run_donut(self, image: Image.Image) -> Dict[str, Any]:
        """Run Donut document parsing on the image and return its structured JSON."""
        components = self.model_cache.get("donut")
        if not components:
            return {"error": "Donut model not available."}
        model, processor = components["model"], components["processor"]
        if not getattr(model, "is_loaded_in_8bit", False):  # 8-bit models cannot be moved with .to()
            model.to(DEVICE)
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(DEVICE)
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)
        outputs = model.generate(
            pixel_values, decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True, use_cache=True, num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]], return_dict_in_generate=True,
        )
        sequence = processor.batch_decode(outputs.sequences)[0].replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # drop the leading task token
        return processor.token2json(sequence)
    def _run_phi3(self, medication_text: str) -> str:
        """Normalize one raw medication line with Phi-3, falling back to the raw text."""
        components = self.model_cache.get("phi3")
        if not components:
            return medication_text
        pipe = components["model"]
        # Ask for a 'Normalized:' marker so the split below has something to anchor on.
        prompt = ("Normalize the following prescription medication line into its components "
                  "(drug, dosage, frequency). Reply on one line starting with 'Normalized:'. "
                  f"Raw text: '{medication_text}'")
        outputs = pipe(prompt, max_new_tokens=100, do_sample=False)
        return outputs[0]["generated_text"].split("Normalized:")[-1].strip()
    def _run_gemini_resolver(self, image: Image.Image, donut_result: Dict, phi3_results: List[str]) -> Dict[str, Any]:
        """Ask Gemini to reconcile the Donut and Phi-3 drafts against the original image."""
        gemini_client = self.model_cache.get("gemini")
        if not gemini_client:
            return {"error": "Gemini resolver not available."}
        prompt = f"""
You are an expert pharmacist's assistant...
(detailed resolver instructions elided)
...

**Draft extractions to reconcile**
Donut output: {json.dumps(donut_result)}
Phi-3 normalized medications: {json.dumps(phi3_results)}

**Final JSON Schema**
```json
{{
    "Name": "string or null", "Date": "string (MM/DD/YYYY) or null", "Age": "string or null", "PhysicianName": "string or null",
    "Medications": [{{"drug_raw": "string", "dosage": "string or null", "frequency": "string or null"}}]
}}
```
"""
        try:
            response = gemini_client.generate_content([prompt, image],
                                                      generation_config={"response_mime_type": "application/json"})
            return json.loads(response.text)
        except Exception as e:
            logging.error(f"Gemini resolver failed: {e}")
            return {"error": f"Gemini failed to resolve data: {e}"}
    def process(self, image_path: str) -> Dict[str, Any]:
        """Run the full pipeline: Donut draft -> Phi-3 refinement -> Gemini resolution."""
        try:
            image = Image.open(image_path).convert("RGB")
            donut_data = self._run_donut(image)
            # Pull candidate medication lines out of Donut's CORD-style output.
            med_lines = [item.get("text", "") for item in donut_data.get("menu", [])
                         if isinstance(item, dict) and "medi" in item.get("category", "").lower()]
            phi3_refined_meds = [self._run_phi3(line) for line in med_lines]
            final_info = self._run_gemini_resolver(image, donut_data, phi3_refined_meds)
            if final_info.get("error"):
                return final_info
            return {"info": final_info, "error": None,
                    "debug_info": {"donut_output": donut_data, "phi3_refinements": phi3_refined_meds}}
        except Exception as e:
            logging.error(f"Hybrid extraction pipeline failed: {e}", exc_info=True)
            return {"error": f"An unexpected error occurred in the pipeline: {e}"}
@st.cache_resource
def get_processor():
    """Build the processor once per Streamlit session and reuse it across reruns."""
    return PrescriptionProcessor()

def extract_prescription_info(image_path: str) -> Dict[str, Any]:
    """Module-level entry point used by the Streamlit app."""
    processor = get_processor()
    return processor.process(image_path)
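
# A minimal local smoke test (an addition, not in the original): "sample_rx.png"
# is a hypothetical path. Outside a running Streamlit app, st.cache_resource
# logs a "no runtime" warning but the decorated function still executes.
if __name__ == "__main__":
    result = extract_prescription_info("sample_rx.png")
    print(json.dumps(result, indent=2))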