shanis185 committed on
Commit
559e937
·
verified ·
1 Parent(s): 9d4af7d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +91 -0
  2. packages.txt +5 -0
  3. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from typing import List
4
+ import torch
5
+ from transformers import DonutProcessor, VisionEncoderDecoderModel, pipeline
6
+ from PIL import Image, ImageOps, ImageEnhance
7
+ import easyocr
8
+ import pytesseract
9
+ import numpy as np
10
+ import os, re
11
+
12
# HTTP application. CORS is configured to accept any origin, method and header.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Everything below is loaded once, at import time, and shared by all requests.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Donut OCR: the processor and the model come from the same checkpoint.
_DONUT_CHECKPOINT = "chinmays18/medical-prescription-ocr"
donut_processor = DonutProcessor.from_pretrained(_DONUT_CHECKPOINT)
donut_model = VisionEncoderDecoderModel.from_pretrained(_DONUT_CHECKPOINT)
donut_model = donut_model.to(device)

# Secondary OCR engine (English only).
reader = easyocr.Reader(['en'])

# Two token-classification pipelines for medical entities, plus a
# text-generation pipeline used by the chat endpoint. No explicit device is
# passed to these pipelines.
humadex_pipe = pipeline(
    "token-classification",
    model="HUMADEX/english_medical_ner",
    aggregation_strategy="simple",
)
medner_pipe = pipeline(
    "token-classification",
    model="blaze999/Medical-NER",
    aggregation_strategy="simple",
)
biogpt_pipe = pipeline(
    "text-generation",
    model="microsoft/BioGPT-Large-PubMedQA",
)
22
+
23
def advanced_preprocess(image_path):
    """Load an image from disk and binarize it for OCR.

    The image is converted to grayscale, resized to 960x1280, contrast
    stretched and boosted, then thresholded so every pixel is either 0 or 255.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A PIL image in "RGB" mode containing only black and white pixels.
    """
    # Open inside a context manager so the underlying file handle is released
    # as soon as the pixel data has been copied by convert(); otherwise the
    # handle lingers and can prevent the caller from deleting the file.
    with Image.open(image_path) as source:
        gray = source.convert('L')

    gray = gray.resize((960, 1280))
    gray = ImageOps.autocontrast(gray)
    gray = ImageEnhance.Contrast(gray).enhance(2)

    # Hard threshold at mid-gray: values below 128 become 0, the rest 255.
    pixels = np.array(gray)
    pixels = np.where(pixels < 128, 0, 255).astype(np.uint8)

    return Image.fromarray(pixels).convert("RGB")
32
+
33
def clean_text(text):
    """Normalize raw OCR output into a single-line string.

    Every character other than ASCII letters, digits, whitespace and the
    punctuation ``- / ( ) . , :`` is removed, then each run of whitespace
    (spaces, tabs, newlines, carriage returns, form feeds) is collapsed into
    one space.

    Args:
        text: Raw text to clean.

    Returns:
        The cleaned text with no leading or trailing whitespace.
    """
    text = re.sub(r'[^A-Za-z0-9\s\-/\(\)\.,:]', '', text)
    # Collapse *all* whitespace, not only spaces: the character class above
    # keeps tabs and carriage returns, which would otherwise leak through.
    return re.sub(r'\s+', ' ', text).strip()
38
+
39
def extract_drugs_and_dose(text):
    """Extract drug mentions, doses and dosing frequencies with regexes.

    Matching is case-insensitive and results keep the casing found in ``text``.

    Args:
        text: Prescription text to scan.

    Returns:
        A ``(drugs, doses, frequencies)`` tuple of sets of strings. Drugs are
        the full matched tokens (keyword plus any attached suffix, e.g.
        ``"CALPOL250"``); doses are the amount joined to its unit without
        whitespace (e.g. ``"2.5mg"``).
    """
    # Non-capturing group: with a capturing group re.findall would return only
    # the keyword and silently drop the suffix matched after it.
    drug_pattern = (
        r'(?:SYP|TAB|CAP|SYRUP|INJECTION|DROPS|INHALER|MEFTAL[- ]?P|CALPOL|DELCON|LEVOLIN)'
        r'[\w\-\/\(\)]*'
    )
    # Capture the whole number (integer and optional fraction) and the unit, so
    # the amount is not lost when the two parts are joined below.
    # NOTE(review): units are not anchored with word boundaries, so "5 given"
    # would yield "5g" -- confirm whether that is acceptable.
    dose_pattern = r'(\d+(?:\.\d+)?)\s*(ml|mg|g|mcg|tablet|cap|puff|dose|drops)'
    # NOTE(review): abbreviations are matched anywhere, including inside longer
    # words (e.g. "bd" in "abdomen") -- confirm whether that is acceptable.
    frequency_pattern = (
        r'(qc[h]?|q6h|tds|t[.]?d[.]?s[.]?|qds|b[.]?d[.]?|bd|sos|daily|once|twice|x\s*\d+d)'
    )

    drugs = set(re.findall(drug_pattern, text, re.I))
    doses = {amount + unit for amount, unit in re.findall(dose_pattern, text, re.I)}
    frequency = set(re.findall(frequency_pattern, text, re.I))
    return drugs, doses, frequency
45
+
46
def _entity_words(entities, *labels):
    """Return the ``word`` of every NER entity whose group name contains any of ``labels``."""
    return {
        ent.get('word', '')
        for ent in entities
        if any(label in ent.get('entity_group', '').upper() for label in labels)
    }


@app.post("/api/prescription")
async def prescription(file: UploadFile = File(...)):
    """Run OCR and medical entity extraction on an uploaded prescription image.

    The upload is written to a temporary file, read by three OCR engines
    (Donut, EasyOCR, Tesseract), and the transcription with the most distinct
    whitespace-separated tokens is kept. That text is cleaned and then mined
    for drugs, doses and frequencies by two NER pipelines and a regex pass,
    whose results are merged.

    Args:
        file: The uploaded image.

    Returns:
        A dict with keys ``ocr_text`` (cleaned text), ``drugs``, ``doses`` and
        ``frequencies`` (each a list of unique strings, in no particular order).
    """
    import tempfile

    # Never build the path from the client-supplied filename: it is untrusted
    # (may contain path separators) and identical names from concurrent
    # requests would overwrite each other. Only the extension is reused.
    suffix = os.path.splitext(file.filename or "")[1]
    with tempfile.NamedTemporaryFile(delete=False, prefix="temp_", suffix=suffix) as tmp:
        tmp.write(await file.read())
        filepath = tmp.name

    try:
        img = advanced_preprocess(filepath)

        # OCR
        pixel_values = donut_processor(images=img, return_tensors="pt").pixel_values.to(device)
        task_prompt = "<s_ocr>"
        decoder_input_ids = donut_processor.tokenizer(task_prompt, return_tensors="pt").input_ids.to(device)
        generated_ids = donut_model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=512)
        donut_text = donut_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        # EasyOCR reads the original upload; Tesseract reads the binarized image.
        easy_text = "\n".join([t[1] for t in reader.readtext(filepath)])
        tess_text = pytesseract.image_to_string(img)
    finally:
        # Remove the temp file even when preprocessing or an OCR engine raises.
        os.remove(filepath)

    # Keep the transcription with the largest vocabulary of distinct tokens.
    texts = [donut_text, easy_text, tess_text]
    best_text = max(texts, key=lambda t: len(set(t.strip().split())))
    cleaned = clean_text(best_text)

    # Entity extraction: union of both NER models and the regex pass.
    humadex_ents = humadex_pipe(cleaned)
    medner_ents = medner_pipe(cleaned)
    regex_drugs, regex_doses, regex_freqs = extract_drugs_and_dose(cleaned)

    out_drugs = (
        _entity_words(humadex_ents, 'DRUG')
        | _entity_words(medner_ents, 'DRUG')
        | regex_drugs
    )
    out_doses = (
        _entity_words(humadex_ents, 'DOSE', 'DOSAGE')
        | _entity_words(medner_ents, 'DOSE', 'DOSAGE')
        | regex_doses
    )
    out_freqs = (
        _entity_words(humadex_ents, 'FREQUENCY')
        | _entity_words(medner_ents, 'FREQUENCY')
        | regex_freqs
    )

    return {
        "ocr_text": cleaned,
        "drugs": list(out_drugs),
        "doses": list(out_doses),
        "frequencies": list(out_freqs),
    }
84
+
85
@app.post("/api/chat")
async def chat(message: str = Form(...)):
    """Generate a BioGPT continuation for a form-submitted message.

    Args:
        message: The text sent in the ``message`` form field.

    Returns:
        A dict whose ``response`` key holds the ``generated_text`` of the
        first generation returned by the pipeline.
    """
    # Query BioGPT, allowing at most 200 newly generated tokens.
    generations = biogpt_pipe(message, max_new_tokens=200)
    first_generation = generations[0]
    return {"response": first_generation["generated_text"]}
90
+
91
+ # Optionally, add D-ID . . .
packages.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ tesseract-ocr
2
+ libglib2.0-0
3
+ libsm6
4
+ libxext6
5
+ libxrender-dev
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ pillow
6
+ easyocr
7
+ pytesseract
8
+ scipy
9
+ huggingface-hub
10
+ python-multipart
11
+ sacremoses
12
+ protobuf