health-assistant / app /services /weight_estimation_service.py
yuting111222's picture
Update health assistant minimal with new services and improvements
a608ddf
# 檔案路徑: backend/app/services/weight_estimation_service.py
import logging
import numpy as np
from PIL import Image
import io
from typing import Dict, Any, List, Optional, Tuple
import torch
from ultralytics import YOLO
# 設置日誌
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 食物密度表 (g/cm³) - 常見食物的平均密度
FOOD_DENSITY_TABLE = {
"rice": 0.8, # 米飯
"fried_rice": 0.7, # 炒飯
"noodles": 0.6, # 麵條
"bread": 0.3, # 麵包
"meat": 1.0, # 肉類
"fish": 1.1, # 魚類
"vegetables": 0.4, # 蔬菜
"fruits": 0.8, # 水果
"soup": 1.0, # 湯類
"default": 0.8 # 預設密度
}
# 參考物尺寸表 (cm)
REFERENCE_OBJECTS = {
"plate": {"diameter": 24.0}, # 標準餐盤直徑
"bowl": {"diameter": 15.0}, # 標準碗直徑
"spoon": {"length": 15.0}, # 湯匙長度
"fork": {"length": 20.0}, # 叉子長度
"default": {"diameter": 24.0} # 預設參考物
}
class WeightEstimationService:
def __init__(self):
"""初始化重量估算服務"""
self.sam_model = None
self.dpt_model = None
self.detection_model = None
self._load_models()
def _load_models(self):
"""載入所需的 AI 模型"""
try:
# 載入 SAM 分割模型
from transformers import SamModel, SamProcessor
logger.info("正在載入 SAM 分割模型...")
self.sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
self.sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# 載入 DPT 深度估計模型
from transformers import pipeline
logger.info("正在載入 DPT 深度估計模型...")
self.dpt_model = pipeline("depth-estimation", model="Intel/dpt-large")
# 載入 YOLOv8 物件偵測模型(用於偵測參考物)
logger.info("正在載入 YOLOv8 物件偵測模型...")
self.detection_model = YOLO("yolov8n.pt") # 你可以改成 yolov5s.pt 或自訂模型
logger.info("所有模型載入完成!")
except Exception as e:
logger.error(f"模型載入失敗: {str(e)}")
raise
def detect_objects(self, image: Image.Image) -> List[Dict[str, Any]]:
"""使用 YOLOv8 偵測圖片中的所有物體"""
try:
results = self.detection_model(image)
detected_objects = []
for result in results[0].boxes.data.tolist():
x1, y1, x2, y2, conf, class_id = result
label = self.detection_model.model.names[int(class_id)].lower()
# 我們對所有高信度的物體都感興趣,除了明確的餐具
if conf > 0.4 and label not in ["spoon", "fork", "knife", "scissors"]:
detected_objects.append({
"label": label,
"bbox": [x1, y1, x2, y2],
"confidence": conf
})
return detected_objects
except Exception as e:
logger.warning(f"物件偵測失敗: {str(e)}")
return []
def segment_food(self, image: Image.Image, input_boxes: List[List[float]]) -> List[np.ndarray]:
"""使用 SAM 根據提供的邊界框分割食物區域"""
if not input_boxes:
return []
try:
# 使用 SAM 進行分割,並提供邊界框作為提示
inputs = self.sam_processor(image, input_boxes=[input_boxes], return_tensors="pt")
with torch.no_grad():
outputs = self.sam_model(**inputs)
# 取得分割遮罩
masks_tensor = self.sam_processor.image_processor.post_process_masks(
outputs.pred_masks.sigmoid(),
inputs["original_sizes"],
inputs["reshaped_input_sizes"]
)[0]
# 將 Tensor 轉換為 list of numpy arrays
masks = [m.squeeze().cpu().numpy().astype(bool) for m in masks_tensor]
return masks
except Exception as e:
logger.error(f"食物分割失敗: {str(e)}")
return []
def estimate_depth(self, image: Image.Image) -> np.ndarray:
"""使用 DPT 進行深度估計"""
try:
# 使用 DPT 進行深度估計
depth_result = self.dpt_model(image)
depth_map = depth_result["depth"]
return np.array(depth_map)
except Exception as e:
logger.error(f"深度估計失敗: {str(e)}")
# 回傳一個預設的深度圖
return np.ones((image.height, image.width))
def calculate_volume_and_weight(self,
mask: np.ndarray,
depth_map: np.ndarray,
food_type: str,
reference_object: Optional[Dict[str, Any]] = None) -> Tuple[float, float, float]:
"""計算體積和重量"""
try:
# 計算食物區域的像素數量
food_pixels = np.sum(mask)
# 計算食物區域的平均深度
food_depth = np.mean(depth_map[mask])
# 估算體積(相對體積)
relative_volume = food_pixels * food_depth
# 如果有參考物,進行尺寸校正
if reference_object:
ref_type = reference_object["label"] # Changed from "type" to "label"
if ref_type in REFERENCE_OBJECTS:
ref_size = REFERENCE_OBJECTS[ref_type]
# 根據參考物尺寸校正體積
if "diameter" in ref_size:
# 圓形參考物(如餐盤)
pixel_to_cm_ratio = ref_size["diameter"] / np.sqrt(food_pixels / np.pi)
else:
# 線性參考物(如餐具)
pixel_to_cm_ratio = ref_size["length"] / np.sqrt(food_pixels)
# 校正體積
actual_volume = relative_volume * (pixel_to_cm_ratio ** 3)
confidence = 0.85 # 有參考物時信心度較高
error_range = 0.15 # ±15% 誤差
else:
actual_volume = relative_volume * 0.1 # 預設校正係數
confidence = 0.6
error_range = 0.3
else:
# 無參考物,使用預設值
actual_volume = relative_volume * 0.1 # 預設校正係數
confidence = 0.5 # 無參考物時信心度較低
error_range = 0.4 # ±40% 誤差
# 根據食物類型取得密度
density = self.get_food_density(food_type)
# 計算重量 (g)
weight = actual_volume * density
# 對單一物件的重量做一個合理性檢查
if weight > 1500: # > 1.5kg
logger.warning(f"單一物件預估重量 {weight:.2f}g 過高,可能不準確。")
return weight, confidence, error_range
except Exception as e:
logger.error(f"體積重量計算失敗: {str(e)}")
return 150.0, 0.3, 0.5 # 預設值
def get_food_density(self, food_name: str) -> float:
"""根據食物名稱取得密度"""
food_name_lower = food_name.lower()
# 簡單的關鍵字匹配
if any(keyword in food_name_lower for keyword in ["rice", "飯"]):
return FOOD_DENSITY_TABLE["rice"]
elif any(keyword in food_name_lower for keyword in ["noodle", "麵"]):
return FOOD_DENSITY_TABLE["noodles"]
elif any(keyword in food_name_lower for keyword in ["meat", "肉", "chicken", "pork", "beef", "lamb"]):
return FOOD_DENSITY_TABLE["meat"]
elif any(keyword in food_name_lower for keyword in ["vegetable", "菜"]):
return FOOD_DENSITY_TABLE["vegetables"]
else:
return FOOD_DENSITY_TABLE["default"]
# 全域服務實例
weight_service = WeightEstimationService()
async def estimate_food_weight(image_bytes: bytes, debug: bool = False) -> Dict[str, Any]:
"""
整合食物辨識、重量估算與營養分析的主函數 (YOLO + SAM 引導模式)
"""
debug_dir = None
try:
if debug:
import os
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
debug_dir = os.path.join("debug_output", timestamp)
os.makedirs(debug_dir, exist_ok=True)
# 將 bytes 轉換為 PIL Image
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
if debug:
image.save(os.path.join(debug_dir, "00_original.jpg"))
# 1. 物件偵測 (YOLO),取得所有物件的邊界框
all_objects = weight_service.detect_objects(image)
if not all_objects:
note = "無法從圖片中偵測到任何物體。"
result = {"detected_foods": [], "total_estimated_weight": 0, "total_nutrition": {}, "note": note}
if debug: result["debug_output_path"] = debug_dir
return result
if debug:
from PIL import ImageDraw
debug_image = image.copy()
draw = ImageDraw.Draw(debug_image)
for obj in all_objects:
bbox = obj.get("bbox")
label = obj.get("label", "unknown")
draw.rectangle(bbox, outline="red", width=3)
draw.text((bbox[0], bbox[1]), label, fill="red")
debug_image.save(os.path.join(debug_dir, "01_detected_objects.jpg"))
# 2. 尋找參考物 (如餐盤、碗)
reference_objects = [obj for obj in all_objects if obj["label"] in ["plate", "bowl"]]
reference_object = max(reference_objects, key=lambda x: x["confidence"]) if reference_objects else None
# 3. 深度估計 (DPT),只需執行一次
depth_map = weight_service.estimate_depth(image)
if debug:
depth_for_save = (depth_map - np.min(depth_map)) / (np.max(depth_map) - np.min(depth_map) + 1e-6) * 255.0
Image.fromarray(depth_for_save.astype(np.uint8)).convert("L").save(os.path.join(debug_dir, "03_depth_map.png"))
# 載入相關服務
from .ai_service import classify_food_image
from .nutrition_api_service import fetch_nutrition_data
detected_foods = []
total_nutrition = {"calories": 0, "protein": 0, "carbs": 0, "fat": 0, "fiber": 0}
# 4. 遍歷每個偵測到的物件 (YOLO Box)
food_objects = [obj for obj in all_objects if obj["label"] not in ["plate", "bowl"]]
for i, food_obj in enumerate(food_objects):
try:
# a. 使用物件的邊界框提示 SAM 進行精準分割
input_box = [food_obj["bbox"]]
masks = weight_service.segment_food(image, input_boxes=input_box)
if not masks: continue
# SAM 對於一個 prompt 可能回傳多個 mask,我們選最大的一個
mask = max(masks, key=lambda m: np.sum(m))
# b. 根據遮罩裁切出單一食物的圖片 (辨識用)
# (此部分邏輯與先前版本相同)
rows, cols = np.any(mask, axis=1), np.any(mask, axis=0)
if not np.any(rows) or not np.any(cols): continue
rmin, rmax = np.where(rows)[0][[0, -1]]
cmin, cmax = np.where(cols)[0][[0, -1]]
item_array = np.array(image); item_rgba = np.zeros((*item_array.shape[:2], 4), dtype=np.uint8)
item_rgba[:,:,:3] = item_array; item_rgba[:,:,3] = mask * 255
cropped_pil = Image.fromarray(item_rgba[rmin:rmax+1, cmin:cmax+1, :], 'RGBA')
buffer = io.BytesIO(); cropped_pil.save(buffer, format="PNG"); item_image_bytes = buffer.getvalue()
if debug:
cropped_pil.save(os.path.join(debug_dir, f"item_{i}_{food_obj['label']}_cropped.png"))
# c. 辨識食物種類 (使用更精準的食物辨識模型)
food_name = classify_food_image(item_image_bytes)
# d. 計算體積和重量
weight, confidence, error_range = weight_service.calculate_volume_and_weight(
mask, depth_map, food_name, reference_object
)
# e. 查詢營養資訊
nutrition_info = fetch_nutrition_data(food_name)
if nutrition_info is None:
nutrition_info = {"calories": 0, "protein": 0, "carbs": 0, "fat": 0, "fiber": 0}
# f. 根據重量調整營養素
weight_ratio = weight / 100
adjusted_nutrition = {k: v * weight_ratio for k, v in nutrition_info.items()}
# g. 累加總營養
for key in total_nutrition: total_nutrition[key] += adjusted_nutrition.get(key, 0)
# h. 儲存單項食物結果
detected_foods.append({
"food_name": food_name,
"estimated_weight": round(weight, 1),
"nutrition": {k: round(v, 1) for k, v in adjusted_nutrition.items()}
})
except Exception as item_e:
logger.error(f"處理物件 '{food_obj['label']}' 時失敗: {str(item_e)}")
continue
# 5. 生成備註
note = f"已使用 YOLO+SAM 模型成功分析 {len(detected_foods)} 項食物。"
if reference_object:
note += f" 檢測到參考物:{reference_object['label']},準確度較高。"
else:
note += " 未檢測到參考物,重量為估算值,結果僅供參考。"
result = {
"detected_foods": detected_foods,
"total_estimated_weight": round(sum(item['estimated_weight'] for item in detected_foods), 1),
"total_nutrition": {k: round(v, 1) for k, v in total_nutrition.items()},
"reference_object": reference_object["label"] if reference_object else None,
"note": note
}
if debug:
# 儲存最終分割圖
overlay_img = image.copy()
overlay_array = np.array(overlay_img)
# Find all masks again to draw
all_food_boxes = [obj['bbox'] for obj in food_objects]
all_masks = weight_service.segment_food(image, input_boxes=all_food_boxes)
for mask in all_masks:
color = np.random.randint(0, 255, size=3, dtype=np.uint8)
overlay_array[mask] = (overlay_array[mask] * 0.5 + color * 0.5).astype(np.uint8)
Image.fromarray(overlay_array).save(os.path.join(debug_dir, "02_final_segmentation.jpg"))
result["debug_output_path"] = debug_dir
return result
except Exception as e:
logger.error(f"多食物重量估算主流程失敗: {str(e)}")
# 回傳包含錯誤訊息的標準結構
result = {
"detected_foods": [],
"total_estimated_weight": 0,
"total_nutrition": {},
"reference_object": None,
"note": f"分析失敗: {str(e)}"
}
if debug and debug_dir:
result["debug_output_path"] = debug_dir
return result