# NOTE: Hugging Face Space page metadata captured with the file
# (extraction residue, not part of the program):
#   status: Runtime error · file size: 8,344 bytes · commit d94b4b9
# ─────────────────────────────────────────────────────────────────────
# app.py
#
# Overview — this script implements:
# 1. Upload of a user's lesion image (JPEG/PNG)
# 2. Calling a pretrained "medical-sam"-style segmentation model via the
#    Hugging Face API to produce a mask
# 3. Using LangChain to load all PDFs under the "knowledge" directory,
#    vectorize them, and build a FAISS index
# 4. Combining segmentation features (area, location) with basic image
#    info into a retrieval + generation prompt (RAG + Agent)
# 5. Using OpenAI GPT-3.5/GPT-4 to generate a few-sentence description
#    of the segmentation details
# 6. A Gradio front-end showing the original image + segmentation mask
#    + the text description
# ─────────────────────────────────────────────────────────────────────
import os
import io
import tempfile
import numpy as np
from PIL import Image
import torch
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
import gradio as gr
from langchain.document_loaders import PDFPlumberLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
# ─── 1. Load environment variables ───────────────────────────────────
HF_API_TOKEN = os.getenv("HF_API_TOKEN")      # set in Space Settings → Secrets
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # set in Space Settings → Secrets
# Generic example checkpoint; swap in a medical-specific model as needed.
MODEL_CHECKPOINT = "salesforce/segformer-b0-finetuned-ade-512-512"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# BUG FIX: os.environ values must be strings; assigning None (secret not
# configured) raised TypeError at import time and crashed the whole Space.
if OPENAI_API_KEY:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
# ─── 2. Initialize the segmentation model ────────────────────────────
# Pick the compute device first, then load the checkpoint once at startup
# and move the model onto that device.
device = "cuda" if torch.cuda.is_available() else "cpu"
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    MODEL_CHECKPOINT, use_auth_token=HF_API_TOKEN
)
seg_model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_CHECKPOINT, use_auth_token=HF_API_TOKEN
).to(device)
# ─── 3. Build the RAG vector index ───────────────────────────────────
def build_knowledge_vectorstore(pdf_folder="knowledge"):
    """Load every PDF under *pdf_folder*, split by page, embed, and
    build a FAISS index.

    Args:
        pdf_folder: Directory scanned (non-recursively) for ``*.pdf``.

    Returns:
        A FAISS vector store over the chunked documents.

    Raises:
        FileNotFoundError: If *pdf_folder* does not exist.
        ValueError: If no PDF content was found — FAISS cannot build an
            empty index and its own error is cryptic.
    """
    if not os.path.isdir(pdf_folder):
        raise FileNotFoundError(f"knowledge folder not found: {pdf_folder!r}")
    docs = []
    # sorted() makes the index build deterministic across filesystems.
    for fn in sorted(os.listdir(pdf_folder)):
        if fn.lower().endswith(".pdf"):
            loader = PDFPlumberLoader(os.path.join(pdf_folder, fn))
            docs.extend(loader.load())
    if not docs:
        raise ValueError(f"no PDF documents found in {pdf_folder!r}")
    # Split pages into chunks of at most 1000 chars (100-char overlap).
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    splits = text_splitter.split_documents(docs)
    # Embed with Sentence-Transformers and index with FAISS.
    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)
    return FAISS.from_documents(splits, embeddings)
# Build the FAISS index once when the Space starts, then wire it into a
# RetrievalQA chain that stuffs the retrieved chunks into the prompt.
vectorstore = build_knowledge_vectorstore(pdf_folder="knowledge")
# Retrieve the top-3 most relevant chunks per query.
_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
qa_chain = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    chain_type="stuff",
    retriever=_retriever,
)
# ─── 4. Helpers: segmentation + feature extraction ───────────────────
def run_segmentation(image_pil: Image.Image) -> np.ndarray:
    """Run SegFormer semantic segmentation on one image.

    Args:
        image_pil: Input image (any mode; converted to RGB).

    Returns:
        ``np.ndarray`` of shape (H, W) holding the predicted class id
        per pixel, at the original image resolution.
    """
    image_rgb = image_pil.convert("RGB")
    inputs = feature_extractor(images=image_rgb, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = seg_model(**inputs).logits  # (1, num_classes, H/4, W/4)
    # BUG FIX: SegFormer emits logits at 1/4 of the input resolution.
    # Upsample them to the input size *before* argmax (the documented
    # post-processing) so the mask aligns pixel-for-pixel with the image,
    # instead of argmax-ing at low resolution and stretching the mask later.
    upsampled = torch.nn.functional.interpolate(
        logits,
        size=image_rgb.size[::-1],  # PIL size is (W, H); interpolate wants (H, W)
        mode="bilinear",
        align_corners=False,
    )
    return torch.argmax(upsampled, dim=1)[0].cpu().numpy()
def extract_mask_stats(mask_np: np.ndarray) -> dict:
    """Summarize a binary segmentation mask (background = 0, lesion = 1).

    Args:
        mask_np: (H, W) integer array of per-pixel class ids.

    Returns:
        Dict with:
            ``area_ratio``: lesion pixels / total pixels, rounded to 4 dp.
            ``bbox``: (xmin, ymin, xmax, ymax) of the lesion region, all
                zeros when no lesion pixel is present.
            ``lesion_pixels``: lesion pixel count as an int.
    """
    lesion = mask_np == 1
    n_lesion = int(np.count_nonzero(lesion))
    ratio = float(n_lesion) / mask_np.size
    rows, cols = np.nonzero(lesion)
    if rows.size and cols.size:
        bbox = (int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max()))
    else:
        bbox = (0, 0, 0, 0)
    return {
        "area_ratio": round(ratio, 4),
        "bbox": bbox,
        "lesion_pixels": n_lesion,
    }
# ─── 5. Gradio callback: upload → segment → RAG+Agent description ────
def segment_and_describe(image_file):
    """Segment an uploaded image, retrieve knowledge via RAG, and
    generate a short professional description.

    Args:
        image_file: The uploaded image — a file path or file-like object
            (both are accepted by ``PIL.Image.open``).

    Returns:
        Tuple ``(overlay, description)``: the input image with the
        lesion mask composited in red, and the generated text.
    """
    # 1. Decode the upload into an RGB PIL image.
    image_pil = Image.open(image_file).convert("RGB")
    # 2. Run the segmentation model.
    mask_np = run_segmentation(image_pil)  # (H, W) class-id map
    # 3. Extract mask statistics (area ratio, bounding box).
    stats = extract_mask_stats(mask_np)
    area_pct = stats["area_ratio"] * 100  # as a percentage
    bbox = stats["bbox"]
    # 4. Build the knowledge-retrieval query (lesion shape/size context).
    query_text = (
        f"请根据以下信息撰写医学分割解析：\n"
        f"- 病灶像素占比约 {area_pct:.2f}%。\n"
        f"- 病灶包围盒坐标 (xmin, ymin, xmax, ymax) = {bbox}。\n"
        f"- 用户上传的图像类型为病灶图像。\n"
        f"请结合医学知识，对此病灶分割结果做 2–3 句详细专业描述。"
    )
    # Retrieve relevant document chunks through the RAG chain.
    rag_answer = qa_chain.run(query_text)
    # 5. Ask the chat model to merge retrieved context + stats into the
    # final description.
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.2)
    full_prompt = (
        f"下面是检索到的医学文档片段：\n{rag_answer}\n\n"
        f"请结合上述片段和统计信息，对病灶分割结果撰写 2–3 句专业描述：\n"
        f"- 病灶像素占比：{area_pct:.2f}%\n"
        f"- 病灶包围盒：{bbox}\n"
    )
    # BUG FIX: calling a LangChain chat model directly with a raw string
    # fails — __call__ expects a list of messages. predict() accepts a
    # plain string and returns a plain string.
    description = llm.predict(full_prompt).strip()
    # 6. Composite the mask over the original image in red.
    # BUG FIX: (mask_np * 255).astype(np.uint8) wrapped modulo 256 for
    # class ids > 1; binarize on the lesion class (1) instead, consistent
    # with extract_mask_stats.
    alpha = np.where(mask_np == 1, 255, 0).astype(np.uint8)
    mask_l = Image.fromarray(alpha, mode="L")
    # NEAREST keeps mask edges hard instead of interpolating the alpha.
    mask_l = mask_l.resize(image_pil.size, Image.NEAREST)
    red_mask = Image.new("RGBA", image_pil.size, color=(255, 0, 0, 0))
    red_mask.putalpha(mask_l)
    overlay = Image.alpha_composite(image_pil.convert("RGBA"), red_mask)
    return overlay, description
# ─── 6. Build the Gradio front-end ───────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("## 🏥 医学病灶分割 + RAG 专业解读 Demo\n\n"
                "1. 上传你的病灶图像（JPEG/PNG）\n"
                "2. 点击下方按钮，自动返回 **分割结果** + **几句话详细描述**\n")
    with gr.Row():
        # BUG FIX: type="file" is not a valid gr.Image type in current
        # Gradio (valid: "numpy"/"pil"/"filepath") and crashes at startup.
        # "filepath" hands the callback a path string, which
        # PIL.Image.open() accepts just as well as a file object.
        img_in = gr.Image(type="filepath", label="上传病灶图像")
        with gr.Column():
            btn = gr.Button("开始分割并生成描述")
            text_out = gr.Textbox(label="分割+解读结果", lines=4)
            img_out = gr.Image(label="叠加分割结果")
    btn.click(fn=segment_and_describe, inputs=img_in, outputs=[img_out, text_out])

if __name__ == "__main__":
    demo.launch()
|