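"""Hybrid prescription-extraction pipeline.

Chains three models: Donut for OCR-free document parsing, Phi-3 for normalizing
raw medication lines, and Gemini for resolving everything into a final JSON
record. Intended to be driven from a Streamlit app via extract_prescription_info().
"""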
import os
import re
import json
import logging
import tempfile
import torch
import streamlit as st
from typing import Dict, Any, List
from PIL import Image

# Suppress verbose backend logs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

try:
    from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel
    from huggingface_hub import login
    import google.generativeai as genai
except ImportError as e:
    logging.error(f"Failed to import critical AI libraries: {e}")
    raise

from config import (
    HF_TOKEN, HF_MODELS, GOOGLE_API_KEY,
    GOOGLE_APPLICATION_CREDENTIALS, GEMINI_MODEL_NAME, DEVICE, USE_GPU
)
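
# config.py is expected to provide (shapes inferred from how they are used below):
#   HF_TOKEN                        - Hugging Face access token, or None
#   HF_MODELS                       - dict mapping "donut" / "phi3" to model IDs
#   GOOGLE_API_KEY                  - API key for google.generativeai
#   GOOGLE_APPLICATION_CREDENTIALS  - service-account JSON as a string, or None
#   GEMINI_MODEL_NAME               - Gemini model identifier
#   DEVICE                          - torch device string, e.g. "cuda" or "cpu"
#   USE_GPU                         - bool; enables 8-bit loading when True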

# Configure Logging & Auth
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
if HF_TOKEN:
    login(token=HF_TOKEN)

class PrescriptionProcessor:
    """Encapsulates the hybrid Donut + Phi-3 + Gemini pipeline behind a single object."""

    def __init__(self):
        self.model_cache = {}
        self.temp_cred_file = None
        self._load_all_models()

    def _load_model(self, name: str):
        if name in self.model_cache:
            return
        model_id = HF_MODELS.get(name)
        if not model_id:
            return
        logging.info(f"Loading model '{name}' ({model_id}) to device '{DEVICE}'...")
        try:
            quantization_config = {"load_in_8bit": True} if USE_GPU else {}
            if name == "donut":
                processor = DonutProcessor.from_pretrained(model_id)
                model = VisionEncoderDecoderModel.from_pretrained(model_id, **quantization_config)
                self.model_cache[name] = {"model": model, "processor": processor}
            elif name == "phi3":
                # pipeline() forwards loading options to from_pretrained via model_kwargs,
                # so pass the quantization flag there rather than as a top-level kwarg
                model = pipeline(
                    "text-generation",
                    model=model_id,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    model_kwargs=quantization_config,
                )
                self.model_cache[name] = {"model": model}
            logging.info(f"Model '{name}' loaded successfully.")
        except Exception as e:
            logging.error(f"Failed to load model '{name}': {e}", exc_info=True)

    def _load_gemini_client(self):
        if "gemini" in self.model_cache:
            return
        # Credentials may arrive as a raw JSON string; persist them to a temp file for the SDK
        if creds_json_str := GOOGLE_APPLICATION_CREDENTIALS:
            if not self.temp_cred_file or not os.path.exists(self.temp_cred_file):
                with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tfp:
                    tfp.write(creds_json_str)
                    self.temp_cred_file = tfp.name
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.temp_cred_file
        try:
            genai.configure(api_key=GOOGLE_API_KEY)
            self.model_cache["gemini"] = genai.GenerativeModel(GEMINI_MODEL_NAME)
        except Exception as e:
            logging.error(f"Gemini init failed: {e}")

    def _load_all_models(self):
        self._load_model("donut")
        self._load_model("phi3")
        self._load_gemini_client()

    def _run_donut(self, image: Image.Image) -> Dict[str, Any]:
        components = self.model_cache.get("donut")
        if not components:
            return {"error": "Donut model not available."}
        model, processor = components["model"].to(DEVICE), components["processor"]
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(DEVICE)
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)
        outputs = model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=model.decoder.config.max_position_embeddings, early_stopping=True, use_cache=True, num_beams=1, bad_words_ids=[[processor.tokenizer.unk_token_id]], return_dict_in_generate=True)
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
        # Drop the leading task-start token before converting the output to JSON
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
        return processor.token2json(sequence)

    def _run_phi3(self, medication_text: str) -> str:
        components = self.model_cache.get("phi3")
        if not components:
            return medication_text
        pipe = components["model"]
        # End the prompt with "Normalized:" so the completion can be split off reliably
        prompt = (
            "Normalize the following prescription medication line into its components "
            f"(drug, dosage, frequency). Raw text: '{medication_text}'\nNormalized:"
        )
        outputs = pipe(prompt, max_new_tokens=100, do_sample=False)
        return outputs[0]["generated_text"].split("Normalized:")[-1].strip()

    def _run_gemini_resolver(self, image: Image.Image, donut_result: Dict, phi3_results: List[str]) -> Dict[str, Any]:
        gemini_client = self.model_cache.get("gemini")
        if not gemini_client:
            return {"error": "Gemini resolver not available."}
        prompt = f"""
        You are an expert pharmacist’s assistant...
        (Your detailed prompt from the previous turn goes here)
        ...
        **Final JSON Schema**
        ```json
        {{
            "Name": "string or null", "Date": "string (MM/DD/YYYY) or null", "Age": "string or null", "PhysicianName": "string or null",
            "Medications": [{{"drug_raw": "string", "dosage": "string or null", "frequency": "string or null"}}]
        }}
        ```
        """
        try:
            response = gemini_client.generate_content([prompt, image], generation_config={"response_mime_type": "application/json"})
            return json.loads(response.text)
        except Exception as e:
            logging.error(f"Gemini resolver failed: {e}")
            # Surface the resolver failure to the caller instead of raising
            return {"error": f"Gemini failed to resolve data: {e}"}

    def process(self, image_path: str) -> Dict[str, Any]:
        try:
            image = Image.open(image_path).convert("RGB")
            donut_data = self._run_donut(image)
            med_lines = [item.get("text", "") for item in donut_data.get("menu", []) if "medi" in item.get("category", "").lower()]
            phi3_refined_meds = [self._run_phi3(line) for line in med_lines]
            final_info = self._run_gemini_resolver(image, donut_data, phi3_refined_meds)
            if final_info.get("error"):
                return final_info
            return {"info": final_info, "error": None, "debug_info": {"donut_output": donut_data, "phi3_refinements": phi3_refined_meds}}
        except Exception as e:
            logging.error(f"Hybrid extraction pipeline failed: {e}", exc_info=True)
            return {"error": f"An unexpected error occurred in the pipeline: {e}"}

@st.cache_resource  # cache the heavy model loads across Streamlit reruns
def get_processor():
    return PrescriptionProcessor()

def extract_prescription_info(image_path: str) -> Dict[str, Any]:
    processor = get_processor()
    return processor.process(image_path)
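

# --- Minimal local usage sketch -------------------------------------------------
# Illustrative only: the image path below is a hypothetical placeholder, and the
# Streamlit app is expected to call extract_prescription_info() from its own UI
# code rather than running this module directly.
if __name__ == "__main__":
    result = extract_prescription_info("sample_prescription.jpg")
    if result.get("error"):
        print(f"Extraction failed: {result['error']}")
    else:
        print(json.dumps(result["info"], indent=2))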