File size: 8,344 Bytes
d94b4b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
# app.py
#
# ่ฏดๆ˜Ž๏ผš่ฟ™ไธช่„šๆœฌๅฎž็Žฐไบ†
#  1. ็”จๆˆทไธŠไผ ็—…็ถๅ›พๅƒ๏ผˆJPEG/PNG๏ผ‰
#  2. ่ฐƒ็”จ Hugging Face Inference API ็š„้ข„่ฎญ็ปƒโ€œmedical-samโ€ๅˆ†ๅ‰ฒๆจกๅž‹๏ผŒ่ฟ”ๅ›ž mask
#  3. ไฝฟ็”จ LangChain ๅฐ† โ€œknowledgeโ€ ็›ฎๅฝ•ไธ‹ๆ‰€ๆœ‰ PDF ๆ–‡ๆœฌๅšๅ‘้‡ๅŒ–๏ผŒๆž„ๅปบ FAISS ็ดขๅผ•
#  4. ๅฐ†ๅˆ†ๅ‰ฒ็ป“ๆžœ็‰นๅพ๏ผˆ้ข็งฏใ€ไฝ็ฝฎ๏ผ‰ไธŽ็”จๆˆทๅ›พๅƒๅŸบๆœฌไฟกๆฏไธ€่ตท๏ผŒๆ‹ผๅœจไธ€่ตท็”Ÿๆˆๆฃ€็ดข+็”Ÿๆˆๆ็คบ๏ผˆRAG+Agent๏ผ‰
#  5. ็”จ OpenAI GPT-3.5/GPT-4 ็”Ÿๆˆ โ€œๅ‡ ๅฅ่ฏโ€ ๅˆ†ๅ‰ฒ็ป†่Š‚ๆ่ฟฐ
#  6. Gradio ๅ‰็ซฏๅฑ•็คบ๏ผšๅนถๆŽ’ๅฑ•็คบๅŽŸๅ›พ + ๅˆ†ๅ‰ฒ mask + ๆ–‡ๅญ—ๆ่ฟฐ
# โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”

import os
import io
import tempfile
import numpy as np
from PIL import Image

import torch
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
import gradio as gr

from langchain.document_loaders import PDFPlumberLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA

# โ”€โ”€โ”€ ไธ€ใ€ๅŠ ่ฝฝ็Žฏๅขƒๅ˜้‡ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
HF_API_TOKEN       = os.getenv("HF_API_TOKEN")       # ไปŽ Space Settings โ†’ Secrets ๅกซๅ…ฅ
OPENAI_API_KEY     = os.getenv("OPENAI_API_KEY")     # ไปŽ Space Settings โ†’ Secrets ๅกซๅ…ฅ
MODEL_CHECKPOINT   = "salesforce/segformer-b0-finetuned-ade-512-512"  # ็คบไพ‹้€š็”จๅˆ†ๅ‰ฒ๏ผŒๅฏๆ นๆฎ้œ€่ฆๆขๆˆๅŒปๅญฆไธ“็”จๆจกๅž‹
EMBEDDING_MODEL    = "sentence-transformers/all-MiniLM-L6-v2"

os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

# โ”€โ”€โ”€ ไบŒใ€ๅˆๅง‹ๅŒ–ๅˆ†ๅ‰ฒๆจกๅž‹ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT, use_auth_token=HF_API_TOKEN)
seg_model = SegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT, use_auth_token=HF_API_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
seg_model.to(device)

# โ”€โ”€โ”€ ไธ‰ใ€ๆž„ๅปบ RAG ๅ‘้‡็ดขๅผ• โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def build_knowledge_vectorstore(pdf_folder="knowledge"):
    """Load every PDF under *pdf_folder*, chunk the text, embed it, and
    build a FAISS vector index.

    Args:
        pdf_folder: Directory scanned (non-recursively) for ``*.pdf`` files.

    Returns:
        A FAISS vector store over the chunked documents.

    Raises:
        FileNotFoundError: If *pdf_folder* does not exist (previously this
            surfaced as an opaque ``os.listdir`` OSError).
        ValueError: If no PDF content was loaded — FAISS cannot build an
            index from zero documents, and failing early gives a clear error.
    """
    if not os.path.isdir(pdf_folder):
        raise FileNotFoundError(f"knowledge folder not found: {pdf_folder!r}")

    docs = []
    # sorted() makes the index build order deterministic across filesystems.
    for fn in sorted(os.listdir(pdf_folder)):
        if fn.lower().endswith(".pdf"):
            loader = PDFPlumberLoader(os.path.join(pdf_folder, fn))
            docs.extend(loader.load())

    if not docs:
        raise ValueError(f"no PDF documents found in {pdf_folder!r}")

    # Split pages into overlapping chunks of at most 1000 characters.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    splits = text_splitter.split_documents(docs)
    # Embed with Sentence-Transformers and index with FAISS.
    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)
    return FAISS.from_documents(splits, embeddings)

# ๅฝ“ Space ๅฏๅŠจๆ—ถ๏ผŒๅ…ˆๆž„ๅปบไธ€ๆฌก็ดขๅผ•
vectorstore = build_knowledge_vectorstore(pdf_folder="knowledge")
qa_chain = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    chain_type="stuff",
    retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),  # ๆฃ€็ดข Top-3 ็›ธๅ…ณๆฎต่ฝ
)

# โ”€โ”€โ”€ ๅ››ใ€่พ…ๅŠฉๅ‡ฝๆ•ฐ๏ผšๅˆ†ๅ‰ฒ + ็‰นๅพๆๅ– โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def run_segmentation(image_pil: Image.Image):
    """Run SegFormer semantic segmentation on *image_pil*.

    Args:
        image_pil: Input image; converted to RGB before preprocessing.

    Returns:
        A ``numpy.ndarray`` of shape (H, W) holding the argmax class index
        per pixel, at the model's output resolution.
    """
    rgb = image_pil.convert("RGB")
    batch = feature_extractor(images=rgb, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = seg_model(**batch).logits  # (1, num_classes, H, W)
    # Collapse the class dimension: each pixel keeps its most likely class.
    class_map = logits.argmax(dim=1).squeeze(0)
    return class_map.cpu().numpy()

def extract_mask_stats(mask_np: np.ndarray):
    """Summarize a binary lesion mask (assumes background=0, lesion=1).

    Args:
        mask_np: Integer class map of shape (H, W).

    Returns:
        dict with:
            area_ratio:    lesion pixels / total pixels, rounded to 4 decimals.
            bbox:          (xmin, ymin, xmax, ymax) of the lesion region;
                           all zeros when no lesion pixel exists.
            lesion_pixels: absolute lesion pixel count.
    """
    lesion = mask_np == 1
    n_lesion = int(lesion.sum())
    ratio = float(n_lesion) / mask_np.size

    rows, cols = np.nonzero(lesion)
    if rows.size:
        bbox = (int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max()))
    else:
        bbox = (0, 0, 0, 0)

    return {
        "area_ratio": round(ratio, 4),
        "bbox": bbox,
        "lesion_pixels": n_lesion,
    }

# โ”€โ”€โ”€ ไบ”ใ€Gradio ๅ›ž่ฐƒๅ‡ฝๆ•ฐ๏ผšไธŠไผ โ†’ๅˆ†ๅ‰ฒโ†’RAG+Agent ็”Ÿๆˆๆ่ฟฐ โ†’ ่ฟ”ๅ›žๅ‰็ซฏ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def segment_and_describe(image_file):
    """Gradio callback: segment an uploaded lesion image and describe it.

    Pipeline: load image → SegFormer segmentation → mask statistics →
    RAG retrieval over the knowledge base → LLM-written description →
    red lesion-mask overlay on the original image.

    Args:
        image_file: The upload from Gradio (a file path or file-like object;
            anything ``PIL.Image.open`` accepts).

    Returns:
        Tuple ``(overlay, description)`` — an RGBA PIL image with the lesion
        highlighted in red, and the generated text description.
    """
    # 1. Load the upload as an RGB PIL image.
    image_pil = Image.open(image_file).convert("RGB")
    # 2. Run the segmentation model → (H, W) class-index map.
    mask_np = run_segmentation(image_pil)
    # 3. Extract lesion statistics (assumes lesion class == 1).
    stats = extract_mask_stats(mask_np)
    area_pct = stats["area_ratio"] * 100  # fraction → percent
    bbox = stats["bbox"]

    # 4. Build the retrieval query and run it through the RAG chain
    #    to pull in relevant knowledge-base context.
    query_text = (
        f"请根据以下信息撰写医学分割解析：\n"
        f"- 病灶像素占比约 {area_pct:.2f}%。\n"
        f"- 病灶包围盒坐标 (xmin, ymin, xmax, ymax) = {bbox}。\n"
        f"- 用户上传的图像类型为病灶图像。\n"
        f"请结合医学知识，对此病灶分割结果做 2–3 句详细专业描述。"
    )
    rag_answer = qa_chain.run(query_text)

    # 5. Ask the chat model to merge the retrieved context and the
    #    statistics into the final description.
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.2)
    full_prompt = (
        f"下面是检索到的医学文档片段：\n{rag_answer}\n\n"
        f"请结合上述片段和统计信息，对病灶分割结果撰写 2–3 句专业描述：\n"
        f"- 病灶像素占比：{area_pct:.2f}%\n"
        f"- 病灶包围盒：{bbox}\n"
    )
    # BUG FIX: ChatOpenAI.__call__ expects a list of messages; calling it
    # with a plain string raised at runtime. predict() accepts raw text
    # and returns the reply as a str (so no .content access is needed).
    description = llm.predict(full_prompt).strip()

    # 6. Build a red overlay from the lesion mask.
    # BUG FIX: the old code scaled raw class indices by 255 into uint8,
    # which wraps around for multi-class class maps; use a binary
    # lesion mask (class 1) consistent with extract_mask_stats.
    lesion_alpha = Image.fromarray((mask_np == 1).astype(np.uint8) * 255, mode="L")
    # BUG FIX: label masks must be resized with NEAREST — smooth resampling
    # would interpolate spurious in-between alpha values along the border.
    lesion_alpha = lesion_alpha.resize(image_pil.size, resample=Image.NEAREST)
    red_mask = Image.new("RGBA", image_pil.size, color=(255, 0, 0, 0))
    red_mask.putalpha(lesion_alpha)
    overlay = Image.alpha_composite(image_pil.convert("RGBA"), red_mask)

    return overlay, description

# โ”€โ”€โ”€ ๅ…ญใ€Gradio ๅ‰็ซฏ็•Œ้ขๆž„ๅปบ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
# ─── 6. Gradio front-end layout ───────────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("## 🏥 医学病灶分割 + RAG 专业解读 Demo\n\n"
                "1. 上传你的病灶图像（JPEG/PNG）\n"
                "2. 点击下方按钮，自动返回 **分割结果** + **几句话详细描述**\n")
    with gr.Row():
        # BUG FIX: gr.Image does not accept type="file" (valid values are
        # "numpy" / "pil" / "filepath"); "filepath" hands the callback a
        # path string, which Image.open() in segment_and_describe accepts.
        img_in = gr.Image(type="filepath", label="上传病灶图像")
        with gr.Column():
            btn = gr.Button("开始分割并生成描述")
            text_out = gr.Textbox(label="分割+解读结果", lines=4)
    img_out = gr.Image(label="叠加分割结果")

    # Wire the button to the full segment → retrieve → describe pipeline.
    btn.click(fn=segment_and_describe, inputs=img_in, outputs=[img_out, text_out])

if __name__ == "__main__":
    demo.launch()