# โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
# app.py
#
# ่ฏดๆ˜Ž๏ผš่ฟ™ไธช่„šๆœฌๅฎž็Žฐไบ†
# 1. ็”จๆˆทไธŠไผ ็—…็ถๅ›พๅƒ๏ผˆJPEG/PNG๏ผ‰
# 2. ่ฐƒ็”จ Hugging Face Inference API ็š„้ข„่ฎญ็ปƒโ€œmedical-samโ€ๅˆ†ๅ‰ฒๆจกๅž‹๏ผŒ่ฟ”ๅ›ž mask
# 3. ไฝฟ็”จ LangChain ๅฐ† โ€œknowledgeโ€ ็›ฎๅฝ•ไธ‹ๆ‰€ๆœ‰ PDF ๆ–‡ๆœฌๅšๅ‘้‡ๅŒ–๏ผŒๆž„ๅปบ FAISS ็ดขๅผ•
# 4. ๅฐ†ๅˆ†ๅ‰ฒ็ป“ๆžœ็‰นๅพ๏ผˆ้ข็งฏใ€ไฝ็ฝฎ๏ผ‰ไธŽ็”จๆˆทๅ›พๅƒๅŸบๆœฌไฟกๆฏไธ€่ตท๏ผŒๆ‹ผๅœจไธ€่ตท็”Ÿๆˆๆฃ€็ดข+็”Ÿๆˆๆ็คบ๏ผˆRAG+Agent๏ผ‰
# 5. ็”จ OpenAI GPT-3.5/GPT-4 ็”Ÿๆˆ โ€œๅ‡ ๅฅ่ฏโ€ ๅˆ†ๅ‰ฒ็ป†่Š‚ๆ่ฟฐ
# 6. Gradio ๅ‰็ซฏๅฑ•็คบ๏ผšๅนถๆŽ’ๅฑ•็คบๅŽŸๅ›พ + ๅˆ†ๅ‰ฒ mask + ๆ–‡ๅญ—ๆ่ฟฐ
# โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”
import os
import io
import tempfile
import numpy as np
from PIL import Image
import torch
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
import gradio as gr
# NOTE: these import paths match pre-0.1 LangChain; newer releases moved most
# of them into the langchain_community package.
from langchain.document_loaders import PDFPlumberLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
# ─── 1. Load environment variables ──────────────────────────────────────────
HF_API_TOKEN = os.getenv("HF_API_TOKEN")        # set under Space Settings → Secrets
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")    # set under Space Settings → Secrets
MODEL_CHECKPOINT = "salesforce/segformer-b0-finetuned-ade-512-512"  # generic example; swap in a medical model as needed
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
if OPENAI_API_KEY:  # os.environ values must be strings; guard against a missing secret
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
# ─── 2. Initialize the segmentation model ───────────────────────────────────
# SegformerFeatureExtractor and use_auth_token are the older transformers APIs;
# newer releases prefer SegformerImageProcessor and token=.
feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT, use_auth_token=HF_API_TOKEN)
seg_model = SegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT, use_auth_token=HF_API_TOKEN)
device = "cuda" if torch.cuda.is_available() else "cpu"
seg_model.to(device)
# ─── 3. Build the RAG vector index ──────────────────────────────────────────
def build_knowledge_vectorstore(pdf_folder="knowledge"):
    """Load every PDF under knowledge/, split the pages, embed them, and build a FAISS index."""
    docs = []
    for fn in os.listdir(pdf_folder):
        if fn.lower().endswith(".pdf"):
            loader = PDFPlumberLoader(os.path.join(pdf_folder, fn))
            docs.extend(loader.load())
    # Split into chunks of at most 1000 characters, with 100 characters of overlap
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    splits = text_splitter.split_documents(docs)
    # Embed with Sentence-Transformers
    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)
    vs = FAISS.from_documents(splits, embeddings)
    return vs
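# Optional (a sketch, not part of the original app): persist the FAISS index so
# the Space does not re-embed every PDF on each restart. The "faiss_index"
# directory name is an assumption; save_local/load_local are standard LangChain
# FAISS methods.
def load_or_build_vectorstore(pdf_folder="knowledge", index_dir="faiss_index"):
    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)
    if os.path.isdir(index_dir):
        # Reuse the index saved by a previous run
        return FAISS.load_local(index_dir, embeddings)
    vs = build_knowledge_vectorstore(pdf_folder)
    vs.save_local(index_dir)
    return vs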
# Build the index once at Space startup
vectorstore = build_knowledge_vectorstore(pdf_folder="knowledge")
qa_chain = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    chain_type="stuff",
    retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),  # retrieve the top-3 relevant passages
)
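# Illustrative: the vector store can also be queried directly, without the LLM
# step, which is handy for checking what the index returns for a given query:
#   vectorstore.similarity_search("lesion morphology", k=3)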
# ─── 4. Helpers: segmentation + feature extraction ──────────────────────────
def run_segmentation(image_pil: Image.Image):
    """Run semantic segmentation with the SegFormer model; return the mask as a numpy array."""
    image_rgb = image_pil.convert("RGB")
    inputs = feature_extractor(images=image_rgb, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = seg_model(**inputs)
    logits = outputs.logits  # shape: (1, num_classes, H/4, W/4); SegFormer predicts at 1/4 resolution
    # Upsample to the input size, then take the highest-probability class per pixel
    upsampled = torch.nn.functional.interpolate(
        logits, size=image_rgb.size[::-1], mode="bilinear", align_corners=False
    )
    preds = upsampled.argmax(dim=1)[0].cpu().numpy()  # shape: (H, W)
    return preds
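# Illustrative note (not in the original app): predicted class indices can be
# mapped back to the checkpoint's label names via the model config, e.g.
#   seg_model.config.id2label[int(preds[0, 0])]
# For the ADE20K checkpoint above these are generic scene labels, not lesion classes.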
def extract_mask_stats(mask_np: np.ndarray):
    """
    Simple example, assuming the segmentation model predicts two classes
    (background = 0, lesion = 1):
    1. Count lesion pixels and their share of the image
    2. Compute the lesion bounding box (xmin, ymin, xmax, ymax)
    """
    lesion_pixels = (mask_np == 1)
    total_pixels = mask_np.size
    lesion_count = lesion_pixels.sum()
    area_ratio = float(lesion_count) / total_pixels
    ys, xs = np.where(lesion_pixels)
    if len(xs) > 0 and len(ys) > 0:
        xmin, xmax = int(xs.min()), int(xs.max())
        ymin, ymax = int(ys.min()), int(ys.max())
    else:
        xmin = ymin = xmax = ymax = 0
    return {
        "area_ratio": round(area_ratio, 4),
        "bbox": (xmin, ymin, xmax, ymax),
        "lesion_pixels": int(lesion_count),
    }
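# Quick sanity check for extract_mask_stats on a toy 4x4 mask (illustrative only):
#   >>> toy = np.zeros((4, 4), dtype=np.int64); toy[1:3, 1:3] = 1
#   >>> extract_mask_stats(toy)
#   {'area_ratio': 0.25, 'bbox': (1, 1, 2, 2), 'lesion_pixels': 4}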
# ─── 5. Gradio callback: upload → segment → RAG + agent description → UI ────
def segment_and_describe(image_pil: Image.Image):
    # 1. Gradio hands us a PIL image (see type="pil" below); normalize to RGB
    image_pil = image_pil.convert("RGB")
    # 2. Run the segmentation model
    mask_np = run_segmentation(image_pil)  # shape: (H, W)
    # 3. Extract segmentation features
    stats = extract_mask_stats(mask_np)
    area_pct = stats["area_ratio"] * 100  # convert to a percentage
    bbox = stats["bbox"]
    # 4. Prepare the LLM prompt; first run one round of knowledge retrieval for
    #    the context users care about, e.g. lesion morphology and lesion size
    query_text = (
        f"Please write a medical segmentation analysis from the following information:\n"
        f"- Lesion pixels cover roughly {area_pct:.2f}% of the image.\n"
        f"- Lesion bounding box (xmin, ymin, xmax, ymax) = {bbox}.\n"
        f"- The user uploaded a lesion image.\n"
        f"Drawing on medical knowledge, describe this lesion segmentation result in 2-3 detailed, professional sentences."
    )
    # Retrieve relevant passages via RAG
    rag_answer = qa_chain.run(query_text)
    # 5. Hand the retrieved content plus the statistics to ChatOpenAI so it can
    #    synthesize the final description
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.2)
    full_prompt = (
        f"Below are retrieved medical document excerpts:\n{rag_answer}\n\n"
        f"Combining the excerpts above with the statistics, write a 2-3 sentence professional description of the lesion segmentation result:\n"
        f"- Lesion pixel ratio: {area_pct:.2f}%\n"
        f"- Lesion bounding box: {bbox}\n"
    )
    # ChatOpenAI expects a list of chat messages; predict() wraps a plain string
    description = llm.predict(full_prompt).strip()
    # 6. Turn the lesion mask into an alpha channel and overlay it on the
    #    original image. Comparing against class 1 keeps this consistent with
    #    extract_mask_stats and avoids uint8 overflow when the model predicts
    #    many classes; alpha 128 keeps the tissue visible under the red tint.
    alpha = Image.fromarray((mask_np == 1).astype(np.uint8) * 128, mode="L")
    alpha = alpha.resize(image_pil.size, resample=Image.NEAREST)
    # Build a red overlay whose opacity follows the mask
    red_mask = Image.new("RGBA", image_pil.size, color=(255, 0, 0, 0))
    red_mask.putalpha(alpha)
    overlay = Image.alpha_composite(image_pil.convert("RGBA"), red_mask)
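    # Optional (illustrative, not in the original app): outline the lesion
    # bounding box on the overlay; bbox is in full-image coordinates because
    # run_segmentation upsamples the logits to the input size.
    #   from PIL import ImageDraw
    #   ImageDraw.Draw(overlay).rectangle(stats["bbox"], outline=(255, 255, 0, 255), width=2)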
    return overlay, description
# ─── 6. Build the Gradio front end ──────────────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("## 🏥 Medical Lesion Segmentation + RAG Interpretation Demo\n\n"
                "1. Upload your lesion image (JPEG/PNG)\n"
                "2. Click the button below to get the **segmentation result** plus a **few-sentence detailed description**\n")
    with gr.Row():
        # type="pil" delivers a PIL image to the callback ("file" is not a valid gr.Image type)
        img_in = gr.Image(type="pil", label="Upload lesion image")
        with gr.Column():
            btn = gr.Button("Segment and generate description")
            text_out = gr.Textbox(label="Segmentation + interpretation", lines=4)
    img_out = gr.Image(label="Segmentation overlay")
    btn.click(fn=segment_and_describe, inputs=img_in, outputs=[img_out, text_out])
if __name__ == "__main__":
    demo.launch()