|
|
import asyncio
import os, re
import tempfile

from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.middleware.cors import CORSMiddleware
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel, pipeline
from PIL import Image, ImageOps, ImageEnhance
import easyocr
import pytesseract
import numpy as np
import requests
|
|
|
|
|
app = FastAPI()

# Allow cross-origin browser requests from any origin, using any method and
# any header.
# NOTE(review): fully open CORS is convenient when the front-end is served
# from a different origin; consider restricting allow_origins in production.
_cors_settings = dict(allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
|
|
# The Donut OCR model is placed on the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Donut OCR checkpoint: the processor (image preprocessing + tokenizer) and the
# vision-encoder-decoder weights come from the same Hugging Face repository.
_DONUT_CHECKPOINT = "chinmays18/medical-prescription-ocr"
donut_processor = DonutProcessor.from_pretrained(_DONUT_CHECKPOINT)
donut_model = VisionEncoderDecoderModel.from_pretrained(_DONUT_CHECKPOINT)
donut_model = donut_model.to(device)

# English-only EasyOCR reader.
reader = easyocr.Reader(['en'])

# Two medical NER models; "simple" aggregation merges sub-word tokens into
# entity spans that carry an "entity_group" label.
humadex_pipe = pipeline(
    "token-classification",
    model="HUMADEX/english_medical_ner",
    aggregation_strategy="simple",
)
medner_pipe = pipeline(
    "token-classification",
    model="blaze999/Medical-NER",
    aggregation_strategy="simple",
)

# Biomedical text generator used by the /api/chat endpoint.
# NOTE(review): unlike donut_model, none of the pipelines receives device=...,
# so presumably they run on the CPU even when CUDA is available — confirm.
biogpt_pipe = pipeline("text-generation", model="microsoft/BioGPT-Large-PubMedQA")
|
|
|
|
|
def advanced_preprocess(image_path, size=(960, 1280), threshold=128):
    """Load an image and binarise it for OCR.

    The image is converted to grayscale and resized to *size* (the aspect ratio
    is not preserved). It is then auto-contrasted, contrast-boosted 2x, and
    thresholded so every pixel becomes pure black or pure white.

    Args:
        image_path: anything ``PIL.Image.open`` accepts (typically a file path).
        size: target ``(width, height)``; defaults to 960x1280.
        threshold: gray level strictly below which a pixel becomes black; all
            other pixels become white.

    Returns:
        An ``"RGB"`` PIL image whose pixels are all (0, 0, 0) or (255, 255, 255).
    """
    # Open via a context manager so the underlying file handle is released.
    # convert() returns a new, fully loaded image that outlives the `with`.
    with Image.open(image_path) as src:
        img = src.convert('L')

    img = img.resize(size)
    img = ImageOps.autocontrast(img)
    img = ImageEnhance.Contrast(img).enhance(2)

    # Hard threshold to a two-level image.
    npimg = np.array(img)
    npimg = np.where(npimg < threshold, 0, 255).astype(np.uint8)

    # Downstream OCR consumers (Donut processor, Tesseract) get a 3-channel image.
    return Image.fromarray(npimg).convert("RGB")
|
|
|
|
|
def clean_text(text):
    """Normalise raw OCR / model output into one line of plain text.

    Steps, in order:
      1. drop anything that looks like a markup tag (e.g. Donut's ``<s_ocr>``);
      2. drop block/box-drawing runs, em dashes and bullet glyphs;
      3. keep only ASCII letters, digits, whitespace and ``- / ( ) . , : %``;
      4. collapse every whitespace run (spaces, newlines, tabs, CR, form feeds)
         into a single space;
      5. re-join digits that are separated only by whitespace ("5 00" -> "500").

    Args:
        text: the string to clean.

    Returns:
        The cleaned string with leading/trailing whitespace removed.
    """
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'[▃▅▇█━—]+', '', text)
    text = re.sub(r'[•◆■★●]', '', text)
    text = re.sub(r'[^A-Za-z0-9\s\-/\(\)\.,:%]', '', text)

    # Collapse *all* whitespace, not only spaces: mapping just \n and \f let
    # tabs and the \r of "\r\n" line endings survive into the "single-line" text.
    text = re.sub(r'\s+', ' ', text)

    # Re-join digits split by OCR whitespace ("5 00" -> "500").
    # NOTE(review): this also fuses genuinely separate numbers ("2 500mg" ->
    # "2500mg"), and runs of single digits merge only pairwise ("1 2 3" ->
    # "12 3") because regex matches do not overlap — confirm this is acceptable
    # for dosage text.
    text = re.sub(r'(\d)\s+(\d)', r'\1\2', text)

    return text.strip()
|
|
|
|
|
def extract_drugs_and_dose(text):
    """Extract drug mentions, doses and dosing frequencies using keyword regexes.

    All matching is case-insensitive. ``(?<!\\w)`` / ``(?!\\w)`` act as word
    boundaries that also work next to punctuation such as "t.d.s.". Without
    them, "tab" matched inside "stable" and "bd" inside "abdomen".

    Args:
        text: prescription text (typically the output of ``clean_text``).

    Returns:
        A tuple ``(drugs, doses, frequencies)`` of sets of strings:
          * drugs: a dosage-form keyword or known brand name, plus any directly
            attached word characters ("TAB", "Tablet", "Meftal-P");
          * doses: "<amount><unit>" with no space ("500 mg" -> "500mg");
          * frequencies: the matched frequency token ("tds", "b.d.", "x 5d").
    """
    # No capturing groups, so findall returns the whole match, not just the
    # keyword group as before.
    drugs = set(re.findall(
        r'(?<!\w)(?:SYP|TAB|CAP|SYRUP|INJECTION|DROPS|INHALER|MEFTAL[- ]?P|CALPOL|DELCON|LEVOLIN)[\w\-/()]*',
        text, re.I))

    # Named groups keep the numeric amount. The previous tuple-based
    # reconstruction dropped the integer part, turning "500mg" into "mg".
    doses = {
        m.group('amount') + m.group('unit')
        for m in re.finditer(
            r'(?P<amount>\d+(?:\.\d+)?)\s*'
            r'(?P<unit>ml|mg|mcg|g|tablets?|caps?|puffs?|doses?|drops?)(?!\w)',
            text, re.I)
    }

    frequency = set(re.findall(
        r'(?<!\w)(?:qch?|q6h|tds|t\.?d\.?s\.?|qds|b\.?d\.?|bd|sos|daily|once|twice|x\s*\d+\s*d(?:ays?)?)(?!\w)',
        text, re.I))

    return drugs, doses, frequency
|
|
|
|
|
def _entity_words(entities, *labels):
    """Return the ``word`` of every NER entity whose upper-cased ``entity_group`` contains one of *labels*.

    *labels* are expected in upper case. This is a substring test, so e.g.
    "DRUG" would also match a group named "DRUG_NAME".
    """
    return {
        ent.get('word', '')
        for ent in entities
        if any(label in ent.get('entity_group', '').upper() for label in labels)
    }


def _ocr_best_text(filepath):
    """Run Donut, EasyOCR and Tesseract on the image at *filepath* and return the
    cleaned transcript that has the most distinct whitespace-separated tokens."""
    img = advanced_preprocess(filepath)

    pixel_values = donut_processor(images=img, return_tensors="pt").pixel_values.to(device)
    task_prompt = "<s_ocr>"
    decoder_input_ids = donut_processor.tokenizer(task_prompt, return_tensors="pt").input_ids.to(device)
    generated_ids = donut_model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=512)
    donut_text = donut_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    # EasyOCR reads the original upload, not the binarised image.
    easy_text = "\n".join(t[1] for t in reader.readtext(filepath))
    tess_text = pytesseract.image_to_string(img)

    # Heuristic: prefer the transcript with the largest vocabulary
    # (ties go to the earlier engine in this list).
    texts = [donut_text, easy_text, tess_text]
    best_text = max(texts, key=lambda t: len(set(t.split())))
    return clean_text(best_text)


@app.post("/api/prescription")
async def prescription(file: UploadFile = File(...)):
    """OCR an uploaded prescription image and extract drugs, doses and frequencies.

    Results from two NER models and the regex extractor are merged.

    Returns:
        ``{"ocr_text": str, "drugs": list, "doses": list, "frequencies": list}``.
    """
    # Never build the temp path from the client-supplied filename: a name like
    # "../../x" would escape the working directory, and concurrent uploads with
    # the same name would overwrite each other. Only a plain extension is kept.
    suffix = os.path.splitext(file.filename or "")[1]
    if not re.fullmatch(r"\.[A-Za-z0-9]{1,10}", suffix):
        suffix = ""
    fd, filepath = tempfile.mkstemp(prefix="prescription_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())
        # NOTE(review): model inference below is synchronous inside an async
        # handler, so it blocks the event loop while it runs.
        cleaned = _ocr_best_text(filepath)
    finally:
        # Always remove the upload, even if preprocessing or OCR raises.
        os.remove(filepath)

    humadex_ents = humadex_pipe(cleaned)
    medner_ents = medner_pipe(cleaned)
    regex_drugs, regex_doses, regex_freqs = extract_drugs_and_dose(cleaned)

    out_drugs = _entity_words(humadex_ents, 'DRUG') | _entity_words(medner_ents, 'DRUG') | regex_drugs
    out_doses = (
        _entity_words(humadex_ents, 'DOSE', 'DOSAGE')
        | _entity_words(medner_ents, 'DOSE', 'DOSAGE')
        | regex_doses
    )
    out_freqs = (
        _entity_words(humadex_ents, 'FREQUENCY')
        | _entity_words(medner_ents, 'FREQUENCY')
        | regex_freqs
    )

    return {
        "ocr_text": cleaned,
        "drugs": list(out_drugs),
        "doses": list(out_doses),
        "frequencies": list(out_freqs),
    }
|
|
|
|
|
@app.post("/api/chat")
async def chat(message: str = Form(...)):
    """Answer *message* with BioGPT and return the cleaned continuation.

    Returns:
        ``{"response": str}``: the generated text (without the prompt) after
        ``clean_text``.
    """
    # return_full_text=False: by default the text-generation pipeline returns
    # the prompt followed by the continuation, which echoed the user's message
    # back as the start of every answer.
    # NOTE(review): generation is synchronous inside an async handler and
    # blocks the event loop while it runs.
    outputs = biogpt_pipe(message, max_new_tokens=400, return_full_text=False)
    result = clean_text(outputs[0]["generated_text"])
    return {"response": result}
|
|
|
|
|
# D-ID API credential, read once at import time; None when the variable is unset.
D_ID_API_KEY = os.getenv("D_ID_API_KEY")
|
|
|
|
|
@app.post("/api/did_talk")
async def did_talk(request: Request):
    """Create a D-ID "talk" video of a presenter image speaking some text.

    Expects a JSON body ``{"text": str, "image_url": str}``. The text is passed
    through ``clean_text`` and truncated to 500 characters. A talk is then
    created via the D-ID REST API and polled every 3 s (at most 20 times) until
    it reports a ``result_url``.

    Returns:
        ``{"video_url": ...}`` on success, otherwise ``{"error": ..., ...}``.
    """
    if not D_ID_API_KEY:
        # Previously this sent "Basic None:" and failed remotely.
        return {"error": "D_ID_API_KEY is not configured"}

    body = await request.json()
    text = body.get("text") if isinstance(body, dict) else None
    image_url = body.get("image_url") if isinstance(body, dict) else None
    if not isinstance(text, str) or not image_url:
        return {"error": "Request body must be a JSON object with 'text' and 'image_url'"}

    text = clean_text(text)[:500]

    headers = {
        # NOTE(review): the trailing ":" is kept from the original; confirm it
        # matches the format of the D-ID key in use.
        "Authorization": f"Basic {D_ID_API_KEY}:",
        "Content-Type": "application/json",
    }
    payload = {
        "script": {
            "type": "text",
            "input": text,
            "provider": {"type": "microsoft", "voice_id": "en-US-GuyNeural"},
        },
        "config": {"result_format": "mp4"},
        "source_url": image_url,
    }

    # requests is blocking: run it in a worker thread so the event loop stays
    # responsive, and always pass a timeout so a stalled connection cannot
    # hang the request forever.
    http_timeout = 30
    try:
        resp = await asyncio.to_thread(
            requests.post, "https://api.d-id.com/talks",
            headers=headers, json=payload, timeout=http_timeout,
        )
    except requests.RequestException as exc:
        return {"error": "Failed to create D-ID talk", "details": str(exc)}
    print(f"D-ID create response: {resp.status_code} - {resp.text}")
    if not resp.ok:
        return {"error": "Failed to create D-ID talk", "details": resp.text}
    try:
        talk_id = resp.json()["id"]
    except (ValueError, KeyError, TypeError):
        return {"error": "Failed to create D-ID talk", "details": resp.text}
    print(f"D-ID talk_id: {talk_id}")

    for attempt in range(1, 21):
        # asyncio.sleep, not time.sleep: a blocking sleep in an async endpoint
        # froze the whole server for up to a minute per call.
        await asyncio.sleep(3)
        try:
            status = await asyncio.to_thread(
                requests.get, f"https://api.d-id.com/talks/{talk_id}",
                headers=headers, timeout=http_timeout,
            )
            status_data = status.json()
        except (requests.RequestException, ValueError) as exc:
            # Transient network / decoding failure: try again on the next poll.
            print(f"Poll #{attempt} failed: {exc}")
            continue
        print(f"Poll #{attempt}: {status_data}")
        if not isinstance(status_data, dict):
            continue
        if status_data.get("result_url"):
            print(f"Video ready: {status_data['result_url']}")
            return {"video_url": status_data["result_url"]}
        # A talk that failed on D-ID's side will never get a result_url; stop
        # polling instead of waiting for the timeout.
        if status_data.get("status") in ("error", "rejected"):
            return {"error": "D-ID talk failed", "details": status_data}

    print("Video generation timed out")
    return {"error": "Timed out or video not ready"}
|
|
|