"""FoodAI demo: Florence-2 caption/detection -> food lookup -> nutrition advice.

The module lazily loads a Florence-2 checkpoint, turns its text output into a
list of food labels, estimates a weight per label from plate size and portion,
looks nutrition values up in a small built-in table and renders the result in
a Gradio UI.
"""

import importlib.machinery
import os
import re
import sys
import types
from typing import Dict, List

import gradio as gr
import torch
from PIL import Image


# ========== 1) Stub out flash_attn so the remote code's hard import check passes ==========
def _make_pkg_stub(fullname: str) -> types.ModuleType:
    """Register an empty package-like module named ``fullname`` in ``sys.modules``.

    Returns the newly created module object.
    """
    m = types.ModuleType(fullname)
    m.__file__ = ""  # the stub has no backing file
    m.__package__ = fullname.rpartition(".")[0]
    m.__path__ = []  # a __path__ attribute marks the module as a package
    m.__spec__ = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True)
    sys.modules[fullname] = m
    return m


for name in [
    "flash_attn",
    "flash_attn.ops",
    "flash_attn.layers",
    "flash_attn.functional",
    "flash_attn.bert_padding",
    "flash_attn.flash_attn_interface",
]:
    # Never shadow a real installation that is already imported.
    if name not in sys.modules:
        _make_pkg_stub(name)


# ========== 2) Florence-2 loading (eager attention + dtype alignment) ==========
# Imported after the stubs above are registered, as in the original file order.
from transformers import AutoProcessor, AutoModelForCausalLM  # noqa: E402

MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): these prompts were empty strings in the source, most likely
# because the angle-bracket tokens were stripped somewhere upstream. They are
# restored to the task prompts documented on the Florence-2 model card.
TASK_TOKENS = {
    "caption": "<CAPTION>",
    "object_detection": "<OD>",
}

_processor = None
_model = None


def get_florence2():
    """Return the cached ``(processor, model)`` pair, loading both on first use."""
    global _processor, _model
    if _processor is None or _model is None:
        _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            attn_implementation="eager",  # key point: do not depend on flash_attn
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device).eval()
        _model.config.use_cache = False
    return _processor, _model


@torch.inference_mode()
def florence2_text(image: Image.Image, task: str = "caption"):
    """Run Florence-2 on ``image`` for ``task`` and return the decoded text.

    ``task`` is a key of ``TASK_TOKENS``; unknown keys fall back to the
    caption prompt.
    """
    proc, mdl = get_florence2()
    # Both supported tasks use the "prompt token only" format (no extra text input).
    text = TASK_TOKENS.get(task, TASK_TOKENS["caption"])

    # Pre-process on CPU, then move tensors to the target device by hand;
    # floating-point tensors are also cast to the model's dtype.
    batch = proc(text=text, images=image, return_tensors="pt")
    inputs = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            if v.is_floating_point():
                inputs[k] = v.to(device=device, dtype=mdl.dtype)
            else:
                inputs[k] = v.to(device=device)
        else:
            inputs[k] = v

    ids = mdl.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
        num_beams=1,
        use_cache=False,  # KV-cache disabled on purpose
        early_stopping=False,  # has no effect with num_beams=1; set explicitly for clarity
        eos_token_id=getattr(getattr(proc, "tokenizer", None), "eos_token_id", None),
    )
    out = proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
    # Drop everything up to and including the first ">" (presumably a leaked
    # leading task token -- TODO confirm against real model output).
    if ">" in out:
        out = out.split(">", 1)[-1].strip()
    return out


# ========== 3) Backend logic (food DB / aliases / weight estimate / rules) ==========
# Nutrition values are scaled by weight_g / 100 in grams_to_nutrition();
# base_g is the starting weight used by estimate_weight().
FOOD_DB = {
    "rice": {"kcal": 130, "carb_g": 28, "protein_g": 2.4, "fat_g": 0.3, "sodium_mg": 0, "cat": "全榖雜糧類", "base_g": 150, "tip": "主食可改糙米/全穀增加膳食纖維"},
    "noodles": {"kcal": 138, "carb_g": 25, "protein_g": 4.5, "fat_g": 1.9, "sodium_mg": 170, "cat": "全榖雜糧類", "base_g": 180, "tip": "盡量選清湯少油，避免重鹹湯底"},
    "bread": {"kcal": 265, "carb_g": 49, "protein_g": 9.0, "fat_g": 3.2, "sodium_mg": 490, "cat": "全榖雜糧類", "base_g": 60, "tip": "可選全麥減少抹醬、甜餡"},
    "broccoli": {"kcal": 35, "carb_g": 7, "protein_g": 2.4, "fat_g": 0.4, "sodium_mg": 33, "cat": "蔬菜類", "base_g": 80, "tip": "川燙/清炒保留口感與維生素"},
    "spinach": {"kcal": 23, "carb_g": 3.6, "protein_g": 2.9, "fat_g": 0.4, "sodium_mg": 70, "cat": "蔬菜類", "base_g": 80, "tip": "川燙後快炒，少鹽少油"},
    "chicken": {"kcal": 215, "carb_g": 0, "protein_g": 27, "fat_g": 12, "sodium_mg": 90, "cat": "豆魚蛋肉類", "base_g": 120, "tip": "去皮烹調、烤/氣炸取代油炸"},
    "soy_braised_chicken_leg": {"kcal": 220, "carb_g": 0, "protein_g": 24, "fat_g": 12, "sodium_mg": 550, "cat": "豆魚蛋肉類", "base_g": 130, "tip": "減醬油與滷汁、可先汆燙再滷"},
    "salmon": {"kcal": 208, "carb_g": 0, "protein_g": 20, "fat_g": 13, "sodium_mg": 60, "cat": "豆魚蛋肉類", "base_g": 120, "tip": "烤/蒸保留 Omega-3，少鹽少醬"},
    "pork_chop": {"kcal": 242, "carb_g": 0, "protein_g": 27, "fat_g": 14, "sodium_mg": 75, "cat": "豆魚蛋肉類", "base_g": 120, "tip": "少裹粉油炸，改煎烤並瀝油"},
    "tofu": {"kcal": 76, "carb_g": 1.9, "protein_g": 8.1, "fat_g": 4.8, "sodium_mg": 7, "cat": "豆魚蛋肉類", "base_g": 120, "tip": "少勾芡、少滷汁，清蒸清爽"},
    "egg": {"kcal": 155, "carb_g": 1.1, "protein_g": 13, "fat_g": 11, "sodium_mg": 124, "cat": "豆魚蛋肉類", "base_g": 60, "tip": "水煮/荷包少油，避免重鹹醬料"},
    "banana": {"kcal": 89, "carb_g": 23, "protein_g": 1.1, "fat_g": 0.3, "sodium_mg": 1, "cat": "水果類", "base_g": 100, "tip": "控制份量，避免一次過量"},
    "miso_soup": {"kcal": 36, "carb_g": 4.3, "protein_g": 2.0, "fat_g": 1.3, "sodium_mg": 550, "cat": "湯品/飲品", "base_g": 200, "tip": "味噌湯偏鹹，建議少量品嚐"},
    "salad": {"kcal": 30, "carb_g": 5, "protein_g": 1.5, "fat_g": 0.5, "sodium_mg": 40, "cat": "蔬菜類", "base_g": 100, "tip": "少醬少油，優先清爽調味"},
    "fish": {"kcal": 170, "carb_g": 0, "protein_g": 22, "fat_g": 8, "sodium_mg": 70, "cat": "豆魚蛋肉類", "base_g": 120, "tip": "蒸/烤/煎少油，避免重鹹醬汁"},
}

# Maps alternative spellings (English and Chinese) to FOOD_DB keys.
ALIASES = {
    "white rice": "rice", "steamed rice": "rice", "飯": "rice", "白飯": "rice",
    "麵": "noodles", "拉麵": "noodles", "麵條": "noodles", "義大利麵": "noodles",
    "麵包": "bread", "吐司": "bread",
    "雞肉": "chicken", "雞胸": "chicken", "烤雞": "chicken",
    "滷雞腿": "soy_braised_chicken_leg", "醬油雞腿": "soy_braised_chicken_leg",
    "鮭魚": "salmon", "三文魚": "salmon",
    "豬排": "pork_chop",
    "豆腐": "tofu",
    "蛋": "egg", "水煮蛋": "egg", "荷包蛋": "egg",
    "花椰菜": "broccoli", "青花菜": "broccoli", "菠菜": "spinach",
    "香蕉": "banana", "味噌湯": "miso_soup",
}

# Per-meal limits checked by eval_rules(), keyed by condition code.
RULES = {"T2DM": {"carb_g_per_meal_max": 60}, "HTN": {"sodium_mg_per_meal_max": 600}}
# Portion label (as shown in the UI) -> weight multiplier.
PORTION_MUL = {"小": 0.8, "中": 1.0, "大": 1.2}

# Numeric nutrition fields carried by every FOOD_DB entry and every item dict.
_NUTRIENT_KEYS = ("kcal", "carb_g", "protein_g", "fat_g", "sodium_mg")


def _term_in_text(term: str, lower_text: str) -> bool:
    """Return True when ``term`` is mentioned in the already-lowercased text.

    ASCII terms must be delimited by non-letters, may carry a plural
    ``s``/``es`` suffix, and treat ``_`` and whitespace as interchangeable.
    Non-ASCII terms (the Chinese aliases) use plain substring matching.
    """
    needle = term.lower()
    if not needle.isascii():
        return needle in lower_text
    words = [re.escape(w) for w in re.split(r"[\s_]+", needle) if w]
    if not words:
        return False
    pattern = r"(?<![a-z])" + r"[\s_]+".join(words) + r"(?:e?s)?(?![a-z])"
    return re.search(pattern, lower_text) is not None


def detect_foods_from_text(text: str) -> List[str]:
    """Return the FOOD_DB keys mentioned in ``text``, directly or via ALIASES.

    The result has no duplicates and a stable order: FOOD_DB keys first (in
    table order), then keys reached through aliases (in ALIASES order).
    """
    if not text:
        return []
    lower = text.lower()
    found: Dict[str, None] = {}  # dict used as an insertion-ordered set
    for key in FOOD_DB:
        if _term_in_text(key, lower):
            found[key] = None
    for alias, key in ALIASES.items():
        if _term_in_text(alias, lower):
            found[key] = None
    return list(found)


# Free-form term extraction (unknown foods are allowed)
DEFAULT_BASE_G = 100
STOPWORDS = {
    # English
    "a", "an", "the", "with", "and", "of", "on", "in", "to", "served", "over", "side", "sides",
    "set", "dish", "meal", "mixed", "assorted", "fresh", "hot", "cold", "topped", "style", "seasoned",
    # Chinese
    "便當", "套餐", "一盤", "一碗", "配菜", "附餐", "湯", "沙拉", "醬", "佐", "搭配", "附", "拌", "炒", "滷", "炸", "烤", "蒸", "煮",
}
COLOR_WORDS = {"white", "black", "red", "green", "yellow", "orange", "brown", "purple", "pink", "golden"}
UTENSILS = {"plate", "bowl", "tray", "box", "cup", "glass", "plateful", "bento"}
ADJ_MISC = {"piece", "slice", "fillet", "serving", "topped", "mixed", "assorted"}

# Common food nouns, scanned for in the whole sentence as extra candidates.
FOOD_LIKE = {
    "salad", "fish", "chicken", "beef", "pork", "shrimp", "tofu", "egg",
    "rice", "noodles", "bread", "soup", "vegetables", "veggies", "fruit",
}


def extract_food_terms_free(text: str):
    """Extract food terms from a caption, keeping terms that are not in FOOD_DB.

    Steps:
      1. "piece/slice/fillet/serving of X" -> X.
      2. Split into fragments (comma, semicolon, period, "and", "with",
         newline), drop colour / utensil / adjective / stop words and keep
         the last remaining word of each fragment.
      3. Add every FOOD_LIKE noun that appears as a whole word.

    Each term is mapped through ALIASES; unmatched terms are kept verbatim.
    The returned list has no duplicates and a stable order.
    """
    if not text:
        return []
    t = text.strip().lower()
    hits: Dict[str, None] = {}  # dict used as an insertion-ordered set
    ignored = COLOR_WORDS | UTENSILS | ADJ_MISC | STOPWORDS

    # 1) Special case "X of Y": take Y directly.
    for m in re.findall(r"(?:piece|slice|fillet|serving)\s+of\s+([a-z\u4e00-\u9fff]+)", t, flags=re.I):
        y = m.strip()
        if y in ignored:
            continue
        hits[ALIASES.get(y, y)] = None

    # 2) Fragment splitting (comma, semicolon, period, and, with, newline).
    for p in re.split(r"(?:,|;|\.|\band\b|\bwith\b|\n)+", t, flags=re.I):
        if not p:
            continue
        # Keep runs of lowercase ASCII letters or CJK ideographs, minus noise words.
        toks = [
            w for w in re.findall(r"[a-z\u4e00-\u9fff]+", p)
            if w not in ignored and len(w) >= 2
        ]
        if not toks:
            continue
        head = toks[-1]  # the last word is usually the noun, e.g. "salad" / "fish"
        hits[ALIASES.get(head, head)] = None

    # 3) Add common food nouns found anywhere in the sentence (sorted so the
    #    result order does not depend on set iteration order).
    for w in sorted(FOOD_LIKE):
        if re.search(rf"\b{re.escape(w)}\b", t):
            hits[ALIASES.get(w, w)] = None

    return list(hits)


def _scale_grams(base_g: float, plate_cm: float, portion: str) -> int:
    """Scale ``base_g`` by portion multiplier and plate size (24 cm = 1.0), floored at 10 g."""
    mul = PORTION_MUL.get(portion, 1.0)
    grams = int(base_g * mul * (plate_cm / 24))
    return max(10, grams)


def estimate_weight(name: str, plate_cm: int, portion: str) -> int:
    """Estimate the weight in grams of ``name`` for the given plate size and portion.

    Uses the food's ``base_g`` from FOOD_DB, or DEFAULT_BASE_G when the name
    is unknown. Unknown portion labels use a multiplier of 1.0.
    """
    base = FOOD_DB.get(name, {}).get("base_g", DEFAULT_BASE_G)
    return _scale_grams(base, plate_cm, portion)


def grams_to_nutrition(name: str, grams: int) -> Dict:
    """Build the item dict for a known food weighing ``grams``.

    Nutrient values are the FOOD_DB values scaled by ``grams / 100`` and
    rounded to one decimal.

    Raises:
        KeyError: if ``name`` is not a FOOD_DB key.
    """
    info = FOOD_DB[name]
    ratio = grams / 100.0
    out = {"name": name, "cat": info["cat"], "weight_g": grams, "tip": info.get("tip", "")}
    for k in _NUTRIENT_KEYS:
        out[k] = round(info[k] * ratio, 1)
    return out


def make_placeholder_item(name: str, plate_cm: int, portion: str):
    """Build an item dict for a food that has no nutrition data.

    The weight uses the same scaling as ``estimate_weight`` with
    DEFAULT_BASE_G; every nutrient field holds a placeholder string.
    """
    grams = _scale_grams(DEFAULT_BASE_G, plate_cm, portion)
    return {
        "name": name, "cat": "未分類", "weight_g": grams,
        "kcal": "待新增資訊", "carb_g": "待新增資訊", "protein_g": "待新增資訊",
        "fat_g": "待新增資訊", "sodium_mg": "待新增資訊", "tip": "待新增資訊",
    }


def eval_rules(items: List[Dict], conditions: List[str]):
    """Sum nutrients over ``items`` and check them against the RULES limits.

    Returns:
        ``(totals, advice, cats)``: nutrient totals (empty when no item has
        numeric values), advice strings for the selected conditions, and the
        number of items per category.
    """
    conditions = conditions or []  # tolerate None from the UI
    totals = {}
    for it in items:
        # Placeholder items carry strings instead of numbers; skip them.
        if all(isinstance(it.get(k), (int, float)) for k in _NUTRIENT_KEYS):
            for k in _NUTRIENT_KEYS:
                totals[k] = round(totals.get(k, 0) + float(it[k]), 1)

    advice = []
    if "T2DM" in conditions and totals.get("carb_g", 0) > RULES["T2DM"]["carb_g_per_meal_max"]:
        advice.append("【糖尿病】碳水偏高，建議主食減量或改全穀。")
    if "HTN" in conditions and totals.get("sodium_mg", 0) > RULES["HTN"]["sodium_mg_per_meal_max"]:
        advice.append("【高血壓】鈉含量偏高，少鹽、避免重口味與滷味/湯品。")

    cats = {}
    for it in items:
        cats[it["cat"]] = cats.get(it["cat"], 0) + 1
    return totals, advice, cats


# ========== 4) Gradio interface ==========
def run_pipeline(image, plate_cm, portion, conditions, task_mode, dev_mode):
    """Gradio callback: image -> model text -> food items -> report.

    Returns:
        ``(report_text, raw_model_text, items, totals)`` matching the four
        output components wired to the button below.
    """
    if image is None:
        return "請先上傳一張照片。", "", [], {}

    # 1) Text output (dev mode skips the model and uses a fixed caption).
    if dev_mode:
        txt = "A bento with white rice, broccoli and grilled chicken thigh."
    else:
        t = "caption" if task_mode == "描述 (Caption)" else "object_detection"
        txt = florence2_text(image, task=t)

    # 2) Merge both detectors, free-form terms first, de-duplicated by canonical key.
    labels_known = detect_foods_from_text(txt)
    labels_free = extract_food_terms_free(txt)
    labels_all = []
    seen = set()
    for term in labels_free + labels_known:
        key = ALIASES.get(term, term)
        if key not in seen:
            labels_all.append(key)
            seen.add(key)

    # 3) Build items for at most six labels (unknown foods get placeholders).
    items = []
    for name in labels_all[:6]:
        if name in FOOD_DB:
            g = estimate_weight(name, plate_cm, portion)
            items.append(grams_to_nutrition(name, g))
        else:
            items.append(make_placeholder_item(name, plate_cm, portion))

    totals, advice, cats = eval_rules(items, conditions)

    # 4) Assemble the report.
    lines = [f"模型輸出：{txt}", ""]
    if labels_all:
        lines.append("偵測到: " + ", ".join(labels_all))
    else:
        lines.append("偵測到: (無)")
    lines.append("")
    for it in items:
        # Values are numbers for known foods and placeholder strings otherwise;
        # both are interpolated as-is.
        lines.append(
            f"- {it['name']} ({it['cat']}) {it['weight_g']} g → "
            f"{it['kcal']} kcal, C{it['carb_g']} g, P{it['protein_g']} g, "
            f"F{it['fat_g']} g, Na{it['sodium_mg']} mg"
        )
    if totals:
        lines.append("")
        lines.append(f"總計：{totals.get('kcal',0)} kcal，碳水 {totals.get('carb_g',0)} g，蛋白 {totals.get('protein_g',0)} g，脂肪 {totals.get('fat_g',0)} g，鈉 {totals.get('sodium_mg',0)} mg")
    if advice:
        lines.append("建議：" + " ".join(advice))
    return "\n".join(lines), txt, items, totals


with gr.Blocks(title="FoodAI · Florence-2 Demo") as demo:
    gr.Markdown("# 🍱 FoodAI · Florence-2 Demo\n上傳餐點 → 產生描述/偵測 → 估營養/建議\n\n> 開發模式：不跑模型，固定假字串方便測試 UI/流程。")
    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(type="pil", label="上傳圖片")
            plate = gr.Slider(18, 28, value=24, step=1, label="盤子直徑 (cm)")
            portion = gr.Radio(["小", "中", "大"], value="中", label="份量")
            cond = gr.CheckboxGroup(["T2DM", "HTN"], label="狀況")
            task_mode = gr.Radio(["描述 (Caption)", "偵測 (Object Detection)"], value="描述 (Caption)", label="任務")
            dev_mode = gr.Checkbox(label="開發模式（不跑模型）", value=False)
            btn = gr.Button("開始分析", variant="primary")
        with gr.Column(scale=1):
            out_md = gr.Markdown(label="結果")
            raw = gr.Textbox(label="模型原始輸出", lines=4)
            js = gr.JSON(label="逐項結果")
            total = gr.JSON(label="總計")
    btn.click(
        run_pipeline,
        inputs=[img, plate, portion, cond, task_mode, dev_mode],
        outputs=[out_md, raw, js, total],
    )

if __name__ == "__main__":
    # Runs unchanged locally and on Spaces; the port comes from $PORT.
    PORT = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=PORT)