File size: 6,523 Bytes
3a9903e adaceb3 2902711 adaceb3 9116864 adaceb3 3a9903e e6ccaa8 adaceb3 e6ccaa8 adaceb3 e6ccaa8 adaceb3 a1a1b35 3a9903e adaceb3 3a9903e a5ff066 2902711 3a9903e 2902711 f8a17f6 2902711 3a9903e 2902711 3a9903e 2902711 a5ff066 2902711 3a9903e 2902711 3a9903e 2902711 a5ff066 3a9903e 2902711 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import asyncio
import os
import re
import tempfile

from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.middleware.cors import CORSMiddleware
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel, pipeline
from PIL import Image, ImageOps, ImageEnhance
import easyocr
import pytesseract
import numpy as np
import requests
# --- Application and model setup -------------------------------------------
# Everything below runs once, at import time; the endpoint handlers further
# down use these module-level objects directly.

# FastAPI application with CORS opened to every origin, method and header.
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# The Donut model runs on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Donut processor and model are loaded from the same checkpoint; only the
# model is moved to `device`.
donut_processor = DonutProcessor.from_pretrained("chinmays18/medical-prescription-ocr")
donut_model = VisionEncoderDecoderModel.from_pretrained("chinmays18/medical-prescription-ocr").to(device)
# EasyOCR reader configured with the 'en' language list.
reader = easyocr.Reader(['en'])
# Two token-classification pipelines; the prescription endpoint reads the
# `entity_group` and `word` keys of their results.
# NOTE(review): no device is passed to the pipelines here, so they presumably
# use the transformers default device - confirm if GPU inference is expected.
humadex_pipe = pipeline("token-classification", model="HUMADEX/english_medical_ner", aggregation_strategy="simple")
medner_pipe = pipeline("token-classification", model="blaze999/Medical-NER", aggregation_strategy="simple")
# Text-generation pipeline used by the /api/chat endpoint.
biogpt_pipe = pipeline("text-generation", model="microsoft/BioGPT-Large-PubMedQA")
def advanced_preprocess(image_path):
    """Load an image and turn it into a high-contrast black/white RGB image.

    Steps: convert to grayscale, resize to 960x1280 (width x height), apply
    auto-contrast, double the contrast, then binarize with a fixed threshold
    of 128 (pixels below it become 0, all others 255).

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A new ``PIL.Image.Image`` in ``"RGB"`` mode containing only pure
        black and pure white pixels.
    """
    # Open through a context manager so the underlying file handle is
    # released deterministically. The caller deletes the file right after
    # OCR, and Pillow may otherwise keep it open (e.g. multi-frame formats).
    with Image.open(image_path) as source:
        # convert() loads the pixel data, so the result is independent of
        # the file once the handle is closed.
        gray = source.convert('L')

    gray = gray.resize((960, 1280))
    gray = ImageOps.autocontrast(gray)
    gray = ImageEnhance.Contrast(gray).enhance(2)

    # Hard threshold: everything darker than mid-gray becomes black.
    pixels = np.array(gray)
    pixels = np.where(pixels < 128, 0, 255).astype(np.uint8)

    return Image.fromarray(pixels).convert("RGB")
def clean_text(text):
    """Strip markup, decorative glyphs and stray symbols from free text.

    Used on OCR output and on generated text before it is returned to the
    client or sent to text-to-speech.

    Processing order:
      1. Drop ``<...>`` tags (e.g. ``<TITLE>``, ``< / ABSTRACT >``).
      2. Drop bar/block glyphs and bullet glyphs.
      3. Turn newlines and form feeds into spaces.
      4. Keep only letters, digits, whitespace and ``- / ( ) . , : %``.
      5. Collapse runs of spaces into one.
      6. Remove whitespace that sits between two digits.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string with leading/trailing whitespace removed.
    """
    # Remove <...> tags, e.g., <TITLE>, < /FREETEXT >, < / ABSTRACT >
    text = re.sub(r'<[^>]+>', '', text)
    # Remove common bar/block Unicode, including standalone ones
    text = re.sub(r'[▃▅▇█━—]+', '', text)
    # Remove common bullet glyphs
    text = re.sub(r'[•◆■★●]', '', text)
    # Flatten line breaks and form feeds into spaces
    text = text.replace('\n', ' ').replace('\f', ' ')
    # Only keep useful punctuation, then collapse spaces
    text = re.sub(r'[^A-Za-z0-9\s\-/\(\)\.,:%]', '', text)
    text = re.sub(' +', ' ', text)
    # Remove whitespace between digits. Lookarounds are used so the digits
    # themselves are not consumed: with a consuming pattern, "1 2 3" would
    # only be partly joined ("12 3") because matches cannot overlap.
    text = re.sub(r'(?<=\d)\s+(?=\d)', '', text)
    return text.strip()
def extract_drugs_and_dose(text):
    """Pull drug markers, doses and dosing frequencies out of text via regex.

    All matching is case-insensitive and matched text keeps its original
    casing.

    Args:
        text: Cleaned prescription text.

    Returns:
        A 3-tuple ``(drugs, doses, frequencies)`` of sets of strings:

        * ``drugs`` - the matched form/brand keyword (``TAB``, ``SYP``,
          ``CALPOL``, ``MEFTAL-P``, ...).
        * ``doses`` - amount and unit joined without a space, e.g.
          ``"500mg"`` or ``"2.5ml"``.
        * ``frequencies`` - matched frequency abbreviations/words such as
          ``bd``, ``tds``, ``sos``, ``daily`` or ``x 5d``.
    """
    # NOTE(review): only the keyword group is returned; characters matched by
    # the trailing character class are consumed but discarded ("TABLET" is
    # reported as "TAB"). Kept as-is to preserve the existing output.
    drugs = re.findall(
        r'(SYP|TAB|CAP|SYRUP|INJECTION|DROPS|INHALER|MEFTAL[- ]?P|CALPOL|DELCON|LEVOLIN)[\w\-\/\(\)]*',
        text,
        re.I,
    )

    # Capture the WHOLE number (integer part plus optional decimals) in one
    # group. Capturing only the decimal part made findall() drop the integer
    # digits, so "500 mg" was reported as just "mg" and "2.5 ml" as ".5ml".
    dose_matches = re.findall(
        r'(\d+(?:\.\d+)?)\s*(ml|mg|g|mcg|tablet|cap|puff|dose|drops)',
        text,
        re.I,
    )
    doses = {amount + unit for amount, unit in dose_matches}

    frequency = re.findall(
        r'(qc[h]?|q6h|tds|t[.]?d[.]?s[.]?|qds|b[.]?d[.]?|bd|sos|daily|once|twice|x\s*\d+d)',
        text,
        re.I,
    )

    return set(drugs), doses, set(frequency)
def _entity_words(entities, *labels):
    """Return the ``word`` of each entity whose group name contains any label."""
    return {
        ent.get('word', '')
        for ent in entities
        if any(label in ent.get('entity_group', '').upper() for label in labels)
    }


@app.post("/api/prescription")
async def prescription(file: UploadFile = File(...)):
    """OCR an uploaded prescription image and extract drugs, doses, frequencies.

    The upload is written to a temporary file, read by three OCR engines
    (Donut, EasyOCR, Tesseract), and the candidate with the most distinct
    whitespace-separated tokens is kept. That text is cleaned, then run
    through both NER pipelines and the regex extractor; the results are
    merged per category.

    Args:
        file: The uploaded image.

    Returns:
        A dict with ``ocr_text`` (the cleaned best OCR text) and the sorted
        lists ``drugs``, ``doses`` and ``frequencies``.
    """
    # Never build the path from the client-supplied filename: it may contain
    # directory components, may be None, and collides between concurrent
    # uploads of the same name. Only the extension is reused.
    suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
    fd, filepath = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())

        img = advanced_preprocess(filepath)

        # Donut: preprocessed image -> generated token ids -> decoded text.
        pixel_values = donut_processor(images=img, return_tensors="pt").pixel_values.to(device)
        task_prompt = "<s_ocr>"
        decoder_input_ids = donut_processor.tokenizer(task_prompt, return_tensors="pt").input_ids.to(device)
        generated_ids = donut_model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=512)
        donut_text = donut_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        # EasyOCR reads the original file; Tesseract gets the preprocessed image.
        easy_text = "\n".join(t[1] for t in reader.readtext(filepath))
        tess_text = pytesseract.image_to_string(img)
    finally:
        # Clean up even when an OCR step raises, so failed requests do not
        # leave temp files behind.
        os.remove(filepath)

    # Pick the candidate with the largest number of distinct tokens.
    texts = [donut_text, easy_text, tess_text]
    best_text = max(texts, key=lambda t: len(set(t.strip().split())))
    cleaned = clean_text(best_text)

    humadex_ents = humadex_pipe(cleaned)
    medner_ents = medner_pipe(cleaned)
    regex_drugs, regex_doses, regex_freqs = extract_drugs_and_dose(cleaned)

    out_drugs = _entity_words(humadex_ents, 'DRUG') | _entity_words(medner_ents, 'DRUG') | regex_drugs
    out_doses = (
        _entity_words(humadex_ents, 'DOSE', 'DOSAGE')
        | _entity_words(medner_ents, 'DOSE', 'DOSAGE')
        | regex_doses
    )
    out_freqs = (
        _entity_words(humadex_ents, 'FREQUENCY')
        | _entity_words(medner_ents, 'FREQUENCY')
        | regex_freqs
    )

    # Sets have no defined order; sort so identical input yields identical output.
    return {
        "ocr_text": cleaned,
        "drugs": sorted(out_drugs),
        "doses": sorted(out_doses),
        "frequencies": sorted(out_freqs),
    }
@app.post("/api/chat")
async def chat(message: str = Form(...)):
    """Generate a BioGPT reply to a form-posted message.

    Args:
        message: The user's message, sent as a form field.

    Returns:
        ``{"response": <text>}`` where the text is the first generation
        produced by ``biogpt_pipe`` (up to 400 new tokens) after it has been
        passed through ``clean_text``.
    """
    generations = biogpt_pipe(message, max_new_tokens=400)
    raw_answer = generations[0]["generated_text"]
    # Strip tags and stray symbols from the model output before returning it.
    return {"response": clean_text(raw_answer)}
# D-ID API key, read once at import time from the environment; it is None
# when the D_ID_API_KEY variable is not set.
D_ID_API_KEY = os.environ.get("D_ID_API_KEY")
_DID_TALKS_URL = "https://api.d-id.com/talks"
_DID_REQUEST_TIMEOUT_S = 30   # per HTTP call to D-ID
_DID_POLL_ATTEMPTS = 20       # 20 polls x 3 s = about one minute of waiting
_DID_POLL_INTERVAL_S = 3


@app.post("/api/did_talk")
async def did_talk(request: Request):
    """Create a D-ID talking video for a text script and wait for its URL.

    The JSON request body must contain ``text`` (the script to speak) and
    ``image_url`` (the source image). The script is passed through
    ``clean_text`` and cut to 500 characters before being sent to D-ID.
    After the talk is created, its status is polled every 3 seconds, up to
    20 times, until a ``result_url`` is present.

    Returns:
        ``{"video_url": <url>}`` on success, otherwise ``{"error": <message>}``
        (plus ``details`` when D-ID or the HTTP layer supplied any).
    """
    if not D_ID_API_KEY:
        # Without a key the header would literally read "Basic None:".
        return {"error": "D_ID_API_KEY is not configured"}

    body = await request.json()
    try:
        text = body["text"]
        image_url = body["image_url"]
    except (KeyError, TypeError):
        return {"error": "Request body must include 'text' and 'image_url'"}

    # Clean the script for TTS and limit it to 500 chars for stability.
    text = clean_text(text)
    text = text[:500]

    headers = {
        "Authorization": f"Basic {D_ID_API_KEY}:",
        "Content-Type": "application/json"
    }
    payload = {
        "script": {
            "type": "text",
            "input": text,
            "provider": {"type": "microsoft", "voice_id": "en-US-GuyNeural"}
        },
        "config": {"result_format": "mp4"},
        "source_url": image_url
    }

    # NOTE: the requests calls are still synchronous; the timeout bounds how
    # long each one can hold up the event loop.
    try:
        resp = requests.post(
            _DID_TALKS_URL, headers=headers, json=payload, timeout=_DID_REQUEST_TIMEOUT_S
        )
    except requests.RequestException as exc:
        return {"error": "Failed to create D-ID talk", "details": str(exc)}

    print(f"D-ID create response: {resp.status_code} - {resp.text}")
    if not resp.ok:
        return {"error": "Failed to create D-ID talk", "details": resp.text}

    talk_id = resp.json()["id"]
    print(f"D-ID talk_id: {talk_id}")

    for attempt in range(1, _DID_POLL_ATTEMPTS + 1):
        # asyncio.sleep yields to the event loop; time.sleep would freeze
        # every other request for the whole polling period.
        await asyncio.sleep(_DID_POLL_INTERVAL_S)
        try:
            status = requests.get(
                f"{_DID_TALKS_URL}/{talk_id}", headers=headers, timeout=_DID_REQUEST_TIMEOUT_S
            )
            status_data = status.json()
        except (requests.RequestException, ValueError) as exc:
            # A transient network error or non-JSON reply counts as a failed
            # poll; keep trying until the attempts run out.
            print(f"Poll #{attempt}: failed - {exc}")
            continue

        print(f"Poll #{attempt}: {status_data}")
        result_url = status_data.get("result_url")
        if result_url:
            print(f"Video ready: {result_url}")
            return {"video_url": result_url}

    print("Video generation timed out")
    return {"error": "Timed out or video not ready"}
|