# File path: backend/app/services/weight_estimation_service.py

import logging
import numpy as np
from PIL import Image
import io
from typing import Dict, Any, List, Optional, Tuple
import torch
from ultralytics import YOLO

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Food density table (g/cm³): average densities of common foods
FOOD_DENSITY_TABLE = {
    "rice": 0.8,          # cooked rice
    "fried_rice": 0.7,    # fried rice
    "noodles": 0.6,       # noodles
    "bread": 0.3,         # bread
    "meat": 1.0,          # meat
    "fish": 1.1,          # fish
    "vegetables": 0.4,    # vegetables
    "fruits": 0.8,        # fruit
    "soup": 1.0,          # soup
    "default": 0.8,       # fallback density
}

# Reference object dimensions (cm)
REFERENCE_OBJECTS = {
    "plate": {"diameter": 24.0},    # standard dinner plate diameter
    "bowl": {"diameter": 15.0},     # standard bowl diameter
    "spoon": {"length": 15.0},      # spoon length
    "fork": {"length": 20.0},       # fork length
    "default": {"diameter": 24.0},  # fallback reference object
}
class WeightEstimationService:
    def __init__(self):
        """Initialize the weight estimation service."""
        self.sam_model = None
        self.dpt_model = None
        self.detection_model = None
        self._load_models()

    def _load_models(self):
        """Load the required AI models."""
        try:
            # Load the SAM segmentation model
            from transformers import SamModel, SamProcessor
            logger.info("Loading SAM segmentation model...")
            self.sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
            self.sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

            # Load the DPT depth estimation model
            from transformers import pipeline
            logger.info("Loading DPT depth estimation model...")
            self.dpt_model = pipeline("depth-estimation", model="Intel/dpt-large")

            # Load the YOLOv8 object detection model (used to find reference objects)
            logger.info("Loading YOLOv8 object detection model...")
            self.detection_model = YOLO("yolov8n.pt")  # can be swapped for yolov5s.pt or a custom model
            logger.info("All models loaded!")
        except Exception as e:
            logger.error(f"Failed to load models: {str(e)}")
            raise

    def detect_objects(self, image: Image.Image) -> List[Dict[str, Any]]:
        """Detect all objects in the image with YOLOv8."""
        try:
            results = self.detection_model(image)
            detected_objects = []
            for result in results[0].boxes.data.tolist():
                x1, y1, x2, y2, conf, class_id = result
                label = self.detection_model.model.names[int(class_id)].lower()
                # Keep every high-confidence object except obvious cutlery
                if conf > 0.4 and label not in ["spoon", "fork", "knife", "scissors"]:
                    detected_objects.append({
                        "label": label,
                        "bbox": [x1, y1, x2, y2],
                        "confidence": conf,
                    })
            return detected_objects
        except Exception as e:
            logger.warning(f"Object detection failed: {str(e)}")
            return []

    def segment_food(self, image: Image.Image, input_boxes: List[List[float]]) -> List[np.ndarray]:
        """Segment food regions with SAM, using the given bounding boxes as prompts."""
        if not input_boxes:
            return []
        try:
            # Run SAM with the bounding boxes as prompts
            inputs = self.sam_processor(image, input_boxes=[input_boxes], return_tensors="pt")
            with torch.no_grad():
                outputs = self.sam_model(**inputs)
            # Upscale to full-resolution boolean masks. Pass the raw logits here:
            # post_process_masks binarizes at threshold 0, so applying sigmoid
            # first would turn every pixel on.
            masks_tensor = self.sam_processor.image_processor.post_process_masks(
                outputs.pred_masks,
                inputs["original_sizes"],
                inputs["reshaped_input_sizes"]
            )[0]  # shape: (num_boxes, num_masks_per_box, H, W) for this image
            # Flatten into a list of 2-D numpy masks (SAM proposes several masks per prompt)
            masks = [m.cpu().numpy().astype(bool) for box_masks in masks_tensor for m in box_masks]
            return masks
        except Exception as e:
            logger.error(f"Food segmentation failed: {str(e)}")
            return []
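
    # Usage sketch for segment_food (illustrative values only; "img" stands for
    # any RGB PIL image and the box coordinates are hypothetical pixels):
    #   boxes = [[50.0, 40.0, 400.0, 380.0]]                 # one [x1, y1, x2, y2] prompt
    #   candidates = weight_service.segment_food(img, boxes)
    #   best = max(candidates, key=lambda m: m.sum())        # several masks come back per prompt
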
    def estimate_depth(self, image: Image.Image) -> np.ndarray:
        """Estimate a relative depth map with DPT."""
        try:
            # Run the DPT depth-estimation pipeline
            depth_result = self.dpt_model(image)
            depth_map = depth_result["depth"]
            return np.array(depth_map)
        except Exception as e:
            logger.error(f"Depth estimation failed: {str(e)}")
            # Fall back to a flat depth map
            return np.ones((image.height, image.width))

    def calculate_volume_and_weight(self,
                                    mask: np.ndarray,
                                    depth_map: np.ndarray,
                                    food_type: str,
                                    reference_object: Optional[Dict[str, Any]] = None) -> Tuple[float, float, float]:
        """Estimate volume and weight for one segmented food item."""
        try:
            # Number of pixels covered by the food mask
            food_pixels = np.sum(mask)
            # Mean relative depth over the food region
            food_depth = np.mean(depth_map[mask])
            # Relative (unit-less) volume estimate
            relative_volume = food_pixels * food_depth
            # If a reference object is available, derive a pixel-to-cm scale from it
            if reference_object:
                ref_type = reference_object["label"]  # reference objects carry a "label" key from detect_objects
                if ref_type in REFERENCE_OBJECTS:
                    ref_size = REFERENCE_OBJECTS[ref_type]
                    # Pixel extent of the reference object, taken from its bounding box
                    x1, y1, x2, y2 = reference_object["bbox"]
                    ref_pixel_extent = max(x2 - x1, y2 - y1, 1.0)
                    if "diameter" in ref_size:
                        # Circular reference (e.g. a plate)
                        pixel_to_cm_ratio = ref_size["diameter"] / ref_pixel_extent
                    else:
                        # Linear reference (e.g. cutlery)
                        pixel_to_cm_ratio = ref_size["length"] / ref_pixel_extent
                    # Scale the relative volume into cm³
                    actual_volume = relative_volume * (pixel_to_cm_ratio ** 3)
                    confidence = 0.85   # higher confidence when a reference object is present
                    error_range = 0.15  # roughly ±15%
                else:
                    actual_volume = relative_volume * 0.1  # default scaling factor
                    confidence = 0.6
                    error_range = 0.3
            else:
                # No reference object: fall back to a default scaling factor
                actual_volume = relative_volume * 0.1
                confidence = 0.5   # lower confidence without a reference
                error_range = 0.4  # roughly ±40%
            # Look up the density for this food type
            density = self.get_food_density(food_type)
            # Weight in grams
            weight = actual_volume * density
            # Sanity check for a single item
            if weight > 1500:  # > 1.5 kg
                logger.warning(f"Estimated weight of {weight:.2f}g for a single item is unusually high and may be inaccurate.")
            return weight, confidence, error_range
        except Exception as e:
            logger.error(f"Volume/weight calculation failed: {str(e)}")
            return 150.0, 0.3, 0.5  # fallback values
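
    # Worked example of the scaling heuristic above (assumed numbers): a 24 cm
    # plate whose bounding box spans 480 px gives pixel_to_cm_ratio = 24 / 480
    # = 0.05 cm/px, so a relative volume of 50 000 pixel·depth-units becomes
    # 50 000 * 0.05**3 = 6.25 cm³ per depth-unit, and rice at 0.8 g/cm³ then
    # comes out at 6.25 * 0.8 = 5 g per depth-unit of average thickness.
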
    def get_food_density(self, food_name: str) -> float:
        """Look up a density (g/cm³) from the food name."""
        food_name_lower = food_name.lower()
        # Simple keyword matching (Chinese keywords kept for Chinese food names)
        if any(keyword in food_name_lower for keyword in ["rice", "飯"]):
            return FOOD_DENSITY_TABLE["rice"]
        elif any(keyword in food_name_lower for keyword in ["noodle", "麵"]):
            return FOOD_DENSITY_TABLE["noodles"]
        elif any(keyword in food_name_lower for keyword in ["meat", "肉", "chicken", "pork", "beef", "lamb"]):
            return FOOD_DENSITY_TABLE["meat"]
        elif any(keyword in food_name_lower for keyword in ["vegetable", "菜"]):
            return FOOD_DENSITY_TABLE["vegetables"]
        else:
            return FOOD_DENSITY_TABLE["default"]

# Module-level service instance
weight_service = WeightEstimationService()


async def estimate_food_weight(image_bytes: bytes, debug: bool = False) -> Dict[str, Any]:
    """
    Main entry point combining food recognition, weight estimation and
    nutrition analysis (YOLO-guided SAM mode).
    """
    debug_dir = None
    try:
        if debug:
            import os
            from datetime import datetime
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            debug_dir = os.path.join("debug_output", timestamp)
            os.makedirs(debug_dir, exist_ok=True)
        # Decode the raw bytes into a PIL image
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if debug:
            image.save(os.path.join(debug_dir, "00_original.jpg"))
        # 1. Object detection (YOLO): bounding boxes for every object
        all_objects = weight_service.detect_objects(image)
        if not all_objects:
            note = "No objects could be detected in the image."
            result = {"detected_foods": [], "total_estimated_weight": 0, "total_nutrition": {}, "note": note}
            if debug:
                result["debug_output_path"] = debug_dir
            return result
        if debug:
            from PIL import ImageDraw
            debug_image = image.copy()
            draw = ImageDraw.Draw(debug_image)
            for obj in all_objects:
                bbox = obj.get("bbox")
                label = obj.get("label", "unknown")
                draw.rectangle(bbox, outline="red", width=3)
                draw.text((bbox[0], bbox[1]), label, fill="red")
            debug_image.save(os.path.join(debug_dir, "01_detected_objects.jpg"))
        # 2. Find a reference object (e.g. plate or bowl)
        reference_objects = [obj for obj in all_objects if obj["label"] in ["plate", "bowl"]]
        reference_object = max(reference_objects, key=lambda x: x["confidence"]) if reference_objects else None
        # 3. Depth estimation (DPT), only needs to run once per image
        depth_map = weight_service.estimate_depth(image)
        if debug:
            depth_for_save = (depth_map - np.min(depth_map)) / (np.max(depth_map) - np.min(depth_map) + 1e-6) * 255.0
            Image.fromarray(depth_for_save.astype(np.uint8)).convert("L").save(os.path.join(debug_dir, "03_depth_map.png"))
        # Import the companion services
        from .ai_service import classify_food_image
        from .nutrition_api_service import fetch_nutrition_data
        detected_foods = []
        total_nutrition = {"calories": 0, "protein": 0, "carbs": 0, "fat": 0, "fiber": 0}
        # 4. Process each detected food object (YOLO box)
        food_objects = [obj for obj in all_objects if obj["label"] not in ["plate", "bowl"]]
        for i, food_obj in enumerate(food_objects):
            try:
                # a. Prompt SAM with the object's bounding box for a precise mask
                input_box = [food_obj["bbox"]]
                masks = weight_service.segment_food(image, input_boxes=input_box)
                if not masks:
                    continue
                # SAM may return several candidate masks per prompt; keep the largest
                mask = max(masks, key=lambda m: np.sum(m))
                # b. Crop the single food item using the mask (for classification)
                rows, cols = np.any(mask, axis=1), np.any(mask, axis=0)
                if not np.any(rows) or not np.any(cols):
                    continue
                rmin, rmax = np.where(rows)[0][[0, -1]]
                cmin, cmax = np.where(cols)[0][[0, -1]]
                item_array = np.array(image)
                item_rgba = np.zeros((*item_array.shape[:2], 4), dtype=np.uint8)
                item_rgba[:, :, :3] = item_array
                item_rgba[:, :, 3] = mask * 255
                cropped_pil = Image.fromarray(item_rgba[rmin:rmax + 1, cmin:cmax + 1, :], 'RGBA')
                buffer = io.BytesIO()
                cropped_pil.save(buffer, format="PNG")
                item_image_bytes = buffer.getvalue()
                if debug:
                    cropped_pil.save(os.path.join(debug_dir, f"item_{i}_{food_obj['label']}_cropped.png"))
                # c. Classify the food type (dedicated food classification model)
                food_name = classify_food_image(item_image_bytes)
                # d. Estimate volume and weight
                weight, confidence, error_range = weight_service.calculate_volume_and_weight(
                    mask, depth_map, food_name, reference_object
                )
                # e. Look up nutrition information
                nutrition_info = fetch_nutrition_data(food_name)
                if nutrition_info is None:
                    nutrition_info = {"calories": 0, "protein": 0, "carbs": 0, "fat": 0, "fiber": 0}
                # f. Scale nutrients by the estimated weight (nutrition data is per 100 g)
                weight_ratio = weight / 100
                adjusted_nutrition = {k: v * weight_ratio for k, v in nutrition_info.items()}
                # g. Accumulate the totals
                for key in total_nutrition:
                    total_nutrition[key] += adjusted_nutrition.get(key, 0)
                # h. Record the per-item result
                detected_foods.append({
                    "food_name": food_name,
                    "estimated_weight": round(weight, 1),
                    "nutrition": {k: round(v, 1) for k, v in adjusted_nutrition.items()}
                })
            except Exception as item_e:
                logger.error(f"Failed to process object '{food_obj['label']}': {str(item_e)}")
                continue
        # 5. Build the summary note
        note = f"Analyzed {len(detected_foods)} food item(s) with the YOLO+SAM pipeline."
        if reference_object:
            note += f" Reference object detected: {reference_object['label']}; accuracy should be higher."
        else:
            note += " No reference object detected; weights are rough estimates only."
        result = {
            "detected_foods": detected_foods,
            "total_estimated_weight": round(sum(item['estimated_weight'] for item in detected_foods), 1),
            "total_nutrition": {k: round(v, 1) for k, v in total_nutrition.items()},
            "reference_object": reference_object["label"] if reference_object else None,
            "note": note
        }
        if debug:
            # Save the final segmentation overlay
            overlay_array = np.array(image.copy())
            # Segment all food boxes again, purely for visualization
            all_food_boxes = [obj['bbox'] for obj in food_objects]
            all_masks = weight_service.segment_food(image, input_boxes=all_food_boxes)
            for mask in all_masks:
                color = np.random.randint(0, 255, size=3, dtype=np.uint8)
                overlay_array[mask] = (overlay_array[mask] * 0.5 + color * 0.5).astype(np.uint8)
            Image.fromarray(overlay_array).save(os.path.join(debug_dir, "02_final_segmentation.jpg"))
            result["debug_output_path"] = debug_dir
        return result
    except Exception as e:
        logger.error(f"Multi-food weight estimation pipeline failed: {str(e)}")
        # Return the standard structure with an error note
        result = {
            "detected_foods": [],
            "total_estimated_weight": 0,
            "total_nutrition": {},
            "reference_object": None,
            "note": f"Analysis failed: {str(e)}"
        }
        if debug and debug_dir:
            result["debug_output_path"] = debug_dir
        return result
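

# Minimal local usage sketch (not part of the service wiring). It assumes a
# sample image at the hypothetical path "sample_meal.jpg" and that the module
# is run in its package context (e.g. `python -m app.services.weight_estimation_service`
# from backend/) so the relative imports inside estimate_food_weight resolve.
if __name__ == "__main__":
    import asyncio

    with open("sample_meal.jpg", "rb") as f:
        sample_bytes = f.read()

    # Run the async pipeline once and print the summary fields.
    analysis = asyncio.run(estimate_food_weight(sample_bytes, debug=True))
    print(analysis["note"])
    print("Total estimated weight (g):", analysis["total_estimated_weight"])
    for item in analysis["detected_foods"]:
        print(f"- {item['food_name']}: {item['estimated_weight']} g")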