# ─────────────────────────────────────────────────────────────────────
# app.py
#
# This script implements:
# 1. User uploads a lesion image (JPEG/PNG)
# 2. Runs a pretrained Hugging Face segmentation model to produce a mask
#    (the comment originally referenced an Inference-API "medical-sam"
#    model; this version loads a SegFormer checkpoint locally instead)
# 3. Uses LangChain to load and chunk every PDF under the "knowledge"
#    directory, embed the chunks, and build a FAISS index
# 4. Combines segmentation features (area, position) with basic image
#    info into a retrieval + generation prompt (RAG)
# 5. Uses OpenAI GPT-3.5/GPT-4 to generate a few-sentence description
#    of the segmentation details
# 6. Serves a Gradio frontend showing the original image with the
#    segmentation mask overlaid, plus the text description
# ─────────────────────────────────────────────────────────────────────
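# Dependency note (an assumption, not part of the original file): based on
# the imports below, a requirements.txt for this Space would need roughly
#   torch, transformers, gradio, Pillow, numpy,
#   langchain, faiss-cpu, pdfplumber, sentence-transformers, openai
# Pin versions to whatever combination you actually test against.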
import os
import numpy as np
from PIL import Image
import torch
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
import gradio as gr
from langchain.document_loaders import PDFPlumberLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
# ─── 1. Load environment variables ───────────────────────────────────
HF_API_TOKEN = os.getenv("HF_API_TOKEN")        # set in Space Settings → Secrets
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")    # set in Space Settings → Secrets
MODEL_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"  # general-purpose example; swap in a medical model as needed
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
if not OPENAI_API_KEY:
    # Fail fast with a clear message instead of a TypeError from os.environ
    raise RuntimeError("OPENAI_API_KEY is not set; add it under Space Settings → Secrets")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
# ─── 2. Initialize the segmentation model ────────────────────────────
feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT, token=HF_API_TOKEN)
seg_model = SegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT, token=HF_API_TOKEN)
device = "cuda" if torch.cuda.is_available() else "cpu"
seg_model.to(device)
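# Note (an assumption about deployment, not from the original): segformer-b0
# is the smallest SegFormer variant, which keeps cold-start and CPU inference
# time manageable on a free Space; the larger b1–b5 checkpoints trade startup
# and latency for accuracy.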
# ─── 3. Build the RAG vector index ───────────────────────────────────
def build_knowledge_vectorstore(pdf_folder="knowledge"):
    """Load every PDF under knowledge/, chunk it, embed the chunks, and build a FAISS index."""
    docs = []
    for fn in os.listdir(pdf_folder):
        if fn.lower().endswith(".pdf"):
            loader = PDFPlumberLoader(os.path.join(pdf_folder, fn))
            docs.extend(loader.load())
    # Split pages so no chunk exceeds 1000 characters
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    splits = text_splitter.split_documents(docs)
    # Embed with Sentence-Transformers
    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)
    vs = FAISS.from_documents(splits, embeddings)
    return vs
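# Optional persistence sketch (an assumption, not in the original): the
# LangChain FAISS wrapper can serialize the index so restarts skip the
# embedding pass. "faiss_index" is an arbitrary local path.
#   vectorstore.save_local("faiss_index")
#   vectorstore = FAISS.load_local("faiss_index", SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL))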
# Build the index once at Space startup
vectorstore = build_knowledge_vectorstore(pdf_folder="knowledge")
qa_chain = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    chain_type="stuff",
    retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),  # retrieve the top-3 relevant passages
)
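# Usage sketch (hypothetical query; assumes the knowledge/ PDFs cover it):
#   qa_chain.run("What does a large lesion area ratio typically indicate?")
# returns a string answer grounded in the top-3 retrieved chunks.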
# ─── 4. Helper functions: segmentation + feature extraction ─────────
def run_segmentation(image_pil: Image.Image):
    """Run semantic segmentation with SegFormer; returns the mask as a numpy array."""
    image_rgb = image_pil.convert("RGB")
    inputs = feature_extractor(images=image_rgb, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = seg_model(**inputs)
    logits = outputs.logits  # shape: (1, num_classes, H/4, W/4) relative to the model input
    # Upsample to the original image size so mask coordinates (and the
    # bounding box derived from them) line up with the uploaded image
    upsampled = torch.nn.functional.interpolate(
        logits, size=image_rgb.size[::-1], mode="bilinear", align_corners=False
    )
    # Take the highest-scoring class per pixel
    preds = upsampled.argmax(dim=1)[0].cpu().numpy()  # shape: (H, W)
    return preds
def extract_mask_stats(mask_np: np.ndarray):
    """
    Simplified example that assumes a binary segmentation
    (background = 0, lesion = 1):
    1. Count lesion pixels and their share of the image
    2. Compute the lesion bounding box (xmin, ymin, xmax, ymax)
    """
    lesion_pixels = (mask_np == 1)
    total_pixels = mask_np.size
    lesion_count = lesion_pixels.sum()
    area_ratio = float(lesion_count) / total_pixels
    ys, xs = np.where(lesion_pixels)
    if len(xs) > 0 and len(ys) > 0:
        xmin, xmax = int(xs.min()), int(xs.max())
        ymin, ymax = int(ys.min()), int(ys.max())
    else:
        xmin = ymin = xmax = ymax = 0
    return {
        "area_ratio": round(area_ratio, 4),
        "bbox": (xmin, ymin, xmax, ymax),
        "lesion_pixels": int(lesion_count),
    }
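# Quick sanity check (synthetic values, not from the original):
#   m = np.zeros((10, 10), dtype=np.uint8)
#   m[2:5, 3:7] = 1
#   extract_mask_stats(m)
#   # -> {"area_ratio": 0.12, "bbox": (3, 2, 6, 4), "lesion_pixels": 12}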
# ─── 5. Gradio callback: upload → segment → RAG description → frontend ───
def segment_and_describe(image_file):
    # 1. Open the uploaded file as a PIL Image
    image_pil = Image.open(image_file).convert("RGB")
    # 2. Run the segmentation model
    mask_np = run_segmentation(image_pil)  # shape: (H, W)
    # 3. Extract segmentation features
    stats = extract_mask_stats(mask_np)
    area_pct = stats["area_ratio"] * 100  # convert to a percentage
    bbox = stats["bbox"]
    # 4. Prepare the LLM prompt, starting with a knowledge-base retrieval
    #    for the context the user cares about (lesion shape, size, etc.)
    query_text = (
        f"Please analyze the following medical segmentation result:\n"
        f"- Lesion pixels cover roughly {area_pct:.2f}% of the image.\n"
        f"- Lesion bounding box (xmin, ymin, xmax, ymax) = {bbox}.\n"
        f"- The uploaded image is a lesion image.\n"
        f"Drawing on medical knowledge, give a detailed professional "
        f"description of this segmentation result in 2-3 sentences."
    )
    # Retrieve relevant passages via RAG
    rag_answer = qa_chain.run(query_text)
    # 5. Pass the retrieved content plus the statistics to ChatOpenAI so it
    #    can synthesize the final description
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.2)
    full_prompt = (
        f"Below are retrieved medical document passages:\n{rag_answer}\n\n"
        f"Combining those passages with the statistics below, write a 2-3 "
        f"sentence professional description of the lesion segmentation result:\n"
        f"- Lesion pixel share: {area_pct:.2f}%\n"
        f"- Lesion bounding box: {bbox}\n"
    )
    # ChatOpenAI expects chat messages when called directly; predict()
    # accepts a plain string and returns a string
    description = llm.predict(full_prompt).strip()
    # 6. Turn the mask into an alpha channel and overlay it on the original image
    lesion_alpha = Image.fromarray(((mask_np == 1) * 128).astype(np.uint8))
    # NEAREST keeps the mask binary instead of blurring class boundaries
    lesion_alpha = lesion_alpha.resize(image_pil.size, resample=Image.NEAREST)
    # Build a semi-transparent red overlay from the lesion mask
    red_mask = Image.new("RGBA", image_pil.size, color=(255, 0, 0, 0))
    red_mask.putalpha(lesion_alpha)
    overlay = Image.alpha_composite(image_pil.convert("RGBA"), red_mask)
    return overlay, description
# ─── 6. Build the Gradio frontend ────────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("## 🏥 Medical Lesion Segmentation + RAG Interpretation Demo\n\n"
                "1. Upload a lesion image (JPEG/PNG)\n"
                "2. Click the button below to get the **segmentation result** plus **a short detailed description**\n")
    with gr.Row():
        img_in = gr.Image(type="filepath", label="Upload lesion image")
        with gr.Column():
            btn = gr.Button("Segment and generate description")
            text_out = gr.Textbox(label="Segmentation + interpretation", lines=4)
    img_out = gr.Image(label="Overlaid segmentation result")
    btn.click(fn=segment_and_describe, inputs=img_in, outputs=[img_out, text_out])
if __name__ == "__main__":
    demo.launch()