File size: 6,523 Bytes
3a9903e
adaceb3
 
 
 
 
 
 
 
2902711
 
adaceb3
 
 
 
 
 
 
 
 
9116864
adaceb3
 
 
 
 
 
 
 
 
 
 
 
3a9903e
e6ccaa8
 
 
 
 
 
adaceb3
e6ccaa8
adaceb3
 
e6ccaa8
 
adaceb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1a1b35
3a9903e
 
 
 
adaceb3
 
3a9903e
a5ff066
2902711
 
 
 
 
 
3a9903e
 
 
 
2902711
 
 
 
 
 
 
 
f8a17f6
2902711
 
 
 
 
 
3a9903e
2902711
 
 
3a9903e
2902711
a5ff066
2902711
 
 
3a9903e
2902711
3a9903e
2902711
a5ff066
3a9903e
2902711
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import asyncio
import os
import re
import tempfile

import easyocr
import numpy as np
import pytesseract
import requests
import torch
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageOps, ImageEnhance
from transformers import DonutProcessor, VisionEncoderDecoderModel, pipeline

# FastAPI application object; the route decorators below register on it.
app = FastAPI()
# CORS is fully open: any origin, any method, any header.
# NOTE(review): presumably a development setting — confirm before deploying.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Use the GPU when CUDA is available, otherwise fall back to CPU.
# NOTE(review): only donut_model is moved to `device`; the pipelines below
# are created without a device argument.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Donut processor/model pair, both loaded from the same checkpoint name.
donut_processor = DonutProcessor.from_pretrained("chinmays18/medical-prescription-ocr")
donut_model = VisionEncoderDecoderModel.from_pretrained("chinmays18/medical-prescription-ocr").to(device)
# EasyOCR reader configured for English only.
reader = easyocr.Reader(['en'])
# Two token-classification pipelines; the /api/prescription handler reads the
# 'word' and 'entity_group' keys of their results.
humadex_pipe = pipeline("token-classification", model="HUMADEX/english_medical_ner", aggregation_strategy="simple")
medner_pipe = pipeline("token-classification", model="blaze999/Medical-NER", aggregation_strategy="simple")
# Text-generation pipeline used by the /api/chat handler.
biogpt_pipe = pipeline("text-generation", model="microsoft/BioGPT-Large-PubMedQA")

def advanced_preprocess(image_path, size=(960, 1280), threshold=128):
    """Load an image and return a black-and-white RGB copy for the OCR engines.

    The image is converted to greyscale, resized to ``size`` (the aspect
    ratio is not preserved), contrast-stretched, and then thresholded so
    every pixel is either pure black or pure white.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        size: ``(width, height)`` handed to ``Image.resize``.
        threshold: Grey level below which a pixel becomes black (0); every
            other pixel becomes white (255).

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically; ``convert`` returns an independent, fully decoded
    # image, so nothing below needs the file any more. This lets the caller
    # delete the file right after this function returns.
    with Image.open(image_path) as source:
        grey = source.convert('L')

    grey = grey.resize(size)
    grey = ImageOps.autocontrast(grey)
    grey = ImageEnhance.Contrast(grey).enhance(2)

    # Hard binarisation: everything darker than ``threshold`` becomes ink.
    pixels = np.array(grey)
    binary = np.where(pixels < threshold, 0, 255).astype(np.uint8)
    return Image.fromarray(binary).convert("RGB")

def clean_text(text):
    """Normalise OCR / model output into plain, single-spaced ASCII text.

    Steps, in order:
      1. drop ``<...>`` tags such as ``<TITLE>`` or ``< / ABSTRACT >``;
      2. drop every character that is not a letter, digit, whitespace or one
         of ``- / ( ) . , : %`` (this also removes block glyphs and bullets);
      3. collapse every whitespace run (spaces, tabs, newlines, form feeds)
         into a single space;
      4. remove whitespace that sits between two digits, so ``"1 2 3"``
         becomes ``"123"``.

    Args:
        text: The string to clean. ``None`` or an empty string yields ``""``.

    Returns:
        The cleaned string with leading/trailing whitespace stripped.
    """
    if not text:
        return ""

    # Remove <...> tags, e.g., <TITLE>, < /FREETEXT >, < / ABSTRACT >
    text = re.sub(r'<[^>]+>', '', text)
    # Keep only useful characters; anything else (bars, bullets, symbols) goes.
    text = re.sub(r'[^A-Za-z0-9\s\-/\(\)\.,:%]', '', text)
    # Collapse all whitespace, not just literal spaces, so tabs and newlines
    # do not survive into the output.
    text = re.sub(r'\s+', ' ', text)
    # Remove spaces between digits. Lookarounds leave the digits unconsumed,
    # so every gap in a run like "1 2 3" is closed, not just every other one.
    text = re.sub(r'(?<=\d)\s+(?=\d)', '', text)
    return text.strip()

# Dosage-form keywords and a few brand names, plus whatever word characters
# follow (e.g. "TABLET", "MEFTAL-P"). The leading \b stops matches that start
# in the middle of an unrelated word ("stable" -> "tab").
_DRUG_RE = re.compile(
    r'\b(?:SYP|TAB|CAP|SYRUP|INJECTION|DROPS|INHALER|MEFTAL[- ]?P|CALPOL|DELCON|LEVOLIN)'
    r'[\w\-/()]*',
    re.I,
)
# Group 1 is the whole number (integer part included), group 2 the unit.
_DOSE_RE = re.compile(
    r'(\d+(?:\.\d+)?)\s*(ml|mg|g|mcg|tablet|cap|puff|dose|drops)',
    re.I,
)
# Frequency abbreviations must not be glued to other letters ("abdomen" is
# not "bd"); the "x 5d" duration form is matched without that restriction.
_FREQ_RE = re.compile(
    r'(?<![A-Za-z])(?:qch?|q6h|t\.?d\.?s\.?|qds|b\.?d\.?|sos|daily|once|twice)(?![A-Za-z])'
    r'|x\s*\d+d',
    re.I,
)


def extract_drugs_and_dose(text):
    """Pull drug tokens, doses and frequencies out of prescription text.

    Args:
        text: Cleaned prescription text.

    Returns:
        A 3-tuple of sets ``(drugs, doses, frequencies)``:
          * ``drugs`` - full matched tokens such as ``"TAB"`` or ``"MEFTAL-P"``;
          * ``doses`` - number and unit joined without a space, e.g.
            ``"500mg"`` or ``"2.5ml"``;
          * ``frequencies`` - matched abbreviations such as ``"bd"``,
            ``"tds"`` or ``"x 5d"``.
        Matching is case-insensitive; the original casing is kept in the
        returned strings.
    """
    drugs = {match.group(0) for match in _DRUG_RE.finditer(text)}
    doses = {f"{number}{unit}" for number, unit in _DOSE_RE.findall(text)}
    frequency = {match.group(0) for match in _FREQ_RE.finditer(text)}
    return drugs, doses, frequency

def _entity_words(entities, *labels):
    """Return the 'word' of every entity whose 'entity_group' contains a label.

    The comparison is done on the upper-cased group name; missing keys are
    treated as empty strings.
    """
    return {
        ent.get('word', '')
        for ent in entities
        if any(label in ent.get('entity_group', '').upper() for label in labels)
    }


@app.post("/api/prescription")
async def prescription(file: UploadFile = File(...)):
    """OCR an uploaded prescription image and extract drugs, doses and frequencies.

    The upload is written to a private temporary file, run through three OCR
    engines (Donut, EasyOCR, Tesseract), and the candidate with the most
    distinct whitespace-separated tokens is kept. That text is cleaned, then
    mined with two NER pipelines and the regex extractor; the results of all
    three are unioned.

    Args:
        file: The uploaded image.

    Returns:
        A dict with keys ``"ocr_text"`` (cleaned text), ``"drugs"``,
        ``"doses"`` and ``"frequencies"`` (lists of strings, unordered).

    NOTE(review): the OCR and NER calls below are synchronous and run on the
    event loop for their whole duration.
    """
    # Never build the path from the client-supplied filename: it can contain
    # path separators and two uploads with the same name would collide. Only
    # a short alphanumeric extension is carried over.
    ext = os.path.splitext(os.path.basename(file.filename or ""))[1]
    suffix = ext if re.fullmatch(r'\.[A-Za-z0-9]{1,10}', ext) else ""
    fd, filepath = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())
        img = advanced_preprocess(filepath)

        # Donut: image -> token ids -> text, primed with the OCR task prompt.
        pixel_values = donut_processor(images=img, return_tensors="pt").pixel_values.to(device)
        task_prompt = "<s_ocr>"
        decoder_input_ids = donut_processor.tokenizer(task_prompt, return_tensors="pt").input_ids.to(device)
        generated_ids = donut_model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=512)
        donut_text = donut_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        # EasyOCR reads the original file; Tesseract reads the binarised image.
        easy_text = "\n".join(item[1] for item in reader.readtext(filepath))
        tess_text = pytesseract.image_to_string(img)
    finally:
        # Remove the temp file even when preprocessing or an OCR engine fails.
        os.remove(filepath)

    # Keep the candidate with the largest vocabulary of distinct tokens.
    texts = [donut_text, easy_text, tess_text]
    best_text = max(texts, key=lambda t: len(set(t.strip().split())))
    cleaned = clean_text(best_text)

    humadex_ents = humadex_pipe(cleaned)
    medner_ents = medner_pipe(cleaned)
    regex_drugs, regex_doses, regex_freqs = extract_drugs_and_dose(cleaned)

    out_drugs = _entity_words(humadex_ents, 'DRUG') | _entity_words(medner_ents, 'DRUG') | regex_drugs
    out_doses = (
        _entity_words(humadex_ents, 'DOSE', 'DOSAGE')
        | _entity_words(medner_ents, 'DOSE', 'DOSAGE')
        | regex_doses
    )
    out_freqs = (
        _entity_words(humadex_ents, 'FREQUENCY')
        | _entity_words(medner_ents, 'FREQUENCY')
        | regex_freqs
    )

    return {
        "ocr_text": cleaned,
        "drugs": list(out_drugs),
        "doses": list(out_doses),
        "frequencies": list(out_freqs),
    }

@app.post("/api/chat")
async def chat(message: str = Form(...)):
    """Answer a chat message with the BioGPT text-generation pipeline.

    Args:
        message: The user's message, taken from the ``message`` form field.

    Returns:
        A dict ``{"response": <text>}`` where the text is the first
        generation's ``generated_text`` after passing through ``clean_text``.
    """
    # Generate at most 400 new tokens and keep only the first candidate.
    generations = biogpt_pipe(message, max_new_tokens=400)
    raw_answer = generations[0]["generated_text"]

    # Strip tags and stray symbols from the model output before returning it.
    return {"response": clean_text(raw_answer)}

# D-ID API key, read from the environment once at import time; None when the
# D_ID_API_KEY variable is not set.
D_ID_API_KEY = os.environ.get("D_ID_API_KEY")

@app.post("/api/did_talk")
async def did_talk(request: Request):
    """Create a D-ID talking-head video and poll until its URL is available.

    Expects a JSON object with string fields ``"text"`` (the script) and
    ``"image_url"`` (the source image). The text is cleaned and truncated to
    500 characters, a talk is created via ``POST /talks``, and the talk is
    then polled every 3 seconds, up to 20 times, for a ``result_url``.

    Args:
        request: The incoming request; its body is parsed as JSON.

    Returns:
        ``{"video_url": <url>}`` on success, otherwise a dict with an
        ``"error"`` key (and ``"details"`` when the create call was rejected).

    NOTE(review): the Authorization header format is kept exactly as it was;
    confirm it against the D-ID authentication docs.
    NOTE(review): the ``requests`` calls are still synchronous; they are now
    bounded by a timeout but run on the event loop while in flight.
    """
    poll_attempts = 20
    poll_interval_s = 3
    http_timeout_s = 30

    if not D_ID_API_KEY:
        return {"error": "D_ID_API_KEY is not configured"}

    body = await request.json()
    if (
        not isinstance(body, dict)
        or not isinstance(body.get("text"), str)
        or not body.get("image_url")
    ):
        return {"error": "Request body must be a JSON object with 'text' and 'image_url'"}
    image_url = body["image_url"]

    # ✅ CLEAN TEXT FOR D-ID TTS
    text = clean_text(body["text"])
    text = text[:500]  # Limit to 500 chars for TTS stability

    headers = {
        "Authorization": f"Basic {D_ID_API_KEY}:",
        "Content-Type": "application/json"
    }
    payload = {
        "script": {
            "type": "text",
            "input": text,
            "provider": {"type": "microsoft", "voice_id": "en-US-GuyNeural"}
        },
        "config": {"result_format": "mp4"},
        "source_url": image_url
    }

    try:
        resp = requests.post(
            "https://api.d-id.com/talks",
            headers=headers,
            json=payload,
            timeout=http_timeout_s,
        )
    except requests.RequestException as exc:
        print(f"D-ID create request failed: {exc}")
        return {"error": "Failed to create D-ID talk", "details": str(exc)}
    print(f"D-ID create response: {resp.status_code} - {resp.text}")
    if not resp.ok:
        return {"error": "Failed to create D-ID talk", "details": resp.text}
    talk_id = resp.json()["id"]
    print(f"D-ID talk_id: {talk_id}")

    for attempt in range(1, poll_attempts + 1):
        # asyncio.sleep yields to the event loop; time.sleep would freeze
        # every other request for the whole polling window.
        await asyncio.sleep(poll_interval_s)
        try:
            status = requests.get(
                f"https://api.d-id.com/talks/{talk_id}",
                headers=headers,
                timeout=http_timeout_s,
            )
        except requests.RequestException as exc:
            # Best effort: a failed poll just uses up one attempt.
            print(f"Poll #{attempt} failed: {exc}")
            continue
        if not status.ok:
            print(f"Poll #{attempt} failed: {status.status_code} - {status.text}")
            continue
        status_data = status.json()
        print(f"Poll #{attempt}: {status_data}")
        result_url = status_data.get("result_url")
        if result_url:
            print(f"Video ready: {result_url}")
            return {"video_url": result_url}

    print("Video generation timed out")
    return {"error": "Timed out or video not ready"}