import os
import re
import json
import logging
import tempfile
import torch
import streamlit as st
from typing import Dict, Any, List
from PIL import Image
# Suppress verbose backend logs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
try:
    from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel
    from huggingface_hub import login
    import google.generativeai as genai
except ImportError as e:
    logging.error(f"Failed to import critical AI libraries: {e}")
    raise
from config import (
    HF_TOKEN, HF_MODELS, GOOGLE_API_KEY,
    GOOGLE_APPLICATION_CREDENTIALS, GEMINI_MODEL_NAME, DEVICE, USE_GPU
)
# Configure Logging & Auth
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
if HF_TOKEN:
    login(token=HF_TOKEN)
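
# The processor class below chains three models: Donut parses the prescription
# image into structured JSON, Phi-3 normalises the raw medication lines, and
# Gemini acts as a multimodal resolver that reconciles both against the image.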
class PrescriptionProcessor:
    """Encapsulates the hybrid Donut + Phi-3 + Gemini extraction pipeline."""

    def __init__(self):
        self.model_cache = {}
        self.temp_cred_file = None
        self._load_all_models()
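
    # Lazily load a Hugging Face model named in HF_MODELS, caching it so a
    # repeat call is a no-op. Failures are logged rather than raised so a
    # missing model degrades the pipeline instead of crashing the app.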
    def _load_model(self, name: str):
        if name in self.model_cache:
            return
        model_id = HF_MODELS.get(name)
        if not model_id:
            return
        logging.info(f"Loading model '{name}' ({model_id}) to device '{DEVICE}'...")
        try:
            # 8-bit loading needs bitsandbytes and a GPU; use full precision otherwise.
            quantization_config = {"load_in_8bit": True} if USE_GPU else {}
            if name == "donut":
                processor = DonutProcessor.from_pretrained(model_id)
                model = VisionEncoderDecoderModel.from_pretrained(model_id, **quantization_config)
                self.model_cache[name] = {"model": model, "processor": processor}
            elif name == "phi3":
                # pipeline() does not accept model-loading kwargs directly;
                # they must be forwarded to from_pretrained via model_kwargs.
                model = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, model_kwargs=quantization_config)
                self.model_cache[name] = {"model": model}
            logging.info(f"Model '{name}' loaded successfully.")
        except Exception as e:
            logging.error(f"Failed to load model '{name}': {e}", exc_info=True)
    def _load_gemini_client(self):
        if "gemini" in self.model_cache:
            return
        if creds_json_str := GOOGLE_APPLICATION_CREDENTIALS:
            if not self.temp_cred_file or not os.path.exists(self.temp_cred_file):
                with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tfp:
                    tfp.write(creds_json_str)
                    self.temp_cred_file = tfp.name
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.temp_cred_file
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            self.model_cache["gemini"] = genai.GenerativeModel(GEMINI_MODEL_NAME)
        except Exception as e:
            logging.error(f"Gemini init failed: {e}")
    def _load_all_models(self):
        self._load_model("donut")
        self._load_model("phi3")
        self._load_gemini_client()
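
    # Stage 1: Donut performs OCR-free parsing of the prescription image and
    # decodes the generated token sequence into structured JSON.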
    def _run_donut(self, image: Image.Image) -> Dict[str, Any]:
        components = self.model_cache.get("donut")
        if not components:
            return {"error": "Donut model not available."}
        model, processor = components["model"].to(DEVICE), components["processor"]
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(DEVICE)
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)
        outputs = model.generate(
            pixel_values, decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True, use_cache=True, num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1)  # strip the leading task start token
        return processor.token2json(sequence)
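
    # Stage 2: Phi-3 normalises each raw medication line. The prompt ends with
    # "Normalized:" so the completion can be split off the echoed prompt.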
    def _run_phi3(self, medication_text: str) -> str:
        components = self.model_cache.get("phi3")
        if not components:
            return medication_text
        pipe = components["model"]
        prompt = (
            "Normalize the following prescription medication line into its components "
            f"(drug, dosage, frequency). Raw text: '{medication_text}'\nNormalized:"
        )
        outputs = pipe(prompt, max_new_tokens=100, do_sample=False)
        return outputs[0]["generated_text"].split("Normalized:")[-1].strip()
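
    # Stage 3: Gemini sees the original image alongside the Donut and Phi-3
    # candidates, arbitrates between them, and returns the final JSON record.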
    def _run_gemini_resolver(self, image: Image.Image, donut_result: Dict, phi3_results: List[str]) -> Dict[str, Any]:
        gemini_client = self.model_cache.get("gemini")
        if not gemini_client:
            return {"error": "Gemini resolver not available."}
        prompt = f"""
You are an expert pharmacist's assistant...
(Detailed resolver instructions elided in the source.)
...
**Candidate Extractions to Reconcile**
Donut output: {json.dumps(donut_result)}
Phi-3 refined medications: {json.dumps(phi3_results)}

**Final JSON Schema**
```json
{{
    "Name": "string or null", "Date": "string (MM/DD/YYYY) or null", "Age": "string or null", "PhysicianName": "string or null",
    "Medications": [{{"drug_raw": "string", "dosage": "string or null", "frequency": "string or null"}}]
}}
```
"""
        try:
            response = gemini_client.generate_content([prompt, image], generation_config={"response_mime_type": "application/json"})
            return json.loads(response.text)
        except Exception as e:
            logging.error(f"Gemini resolver failed: {e}")
            # Surface the failure to the caller rather than raising.
            return {"error": f"Gemini failed to resolve data: {e}"}
    def process(self, image_path: str) -> Dict[str, Any]:
        try:
            image = Image.open(image_path).convert("RGB")
            donut_data = self._run_donut(image)
            med_lines = [item.get("text", "") for item in donut_data.get("menu", []) if "medi" in item.get("category", "").lower()]
            phi3_refined_meds = [self._run_phi3(line) for line in med_lines]
            final_info = self._run_gemini_resolver(image, donut_data, phi3_refined_meds)
            if final_info.get("error"):
                return final_info
            return {
                "info": final_info,
                "error": None,
                "debug_info": {"donut_output": donut_data, "phi3_refinements": phi3_refined_meds},
            }
        except Exception as e:
            logging.error(f"Hybrid extraction pipeline failed: {e}", exc_info=True)
            return {"error": f"An unexpected error occurred in the pipeline: {e}"}
@st.cache_resource
def get_processor():
    return PrescriptionProcessor()


def extract_prescription_info(image_path: str) -> Dict[str, Any]:
    processor = get_processor()
    return processor.process(image_path)