import logging
import traceback
from typing import Dict, List, Any, Optional

logger = logging.getLogger(__name__)


class ObjectExtractor:
    """
    Extract and pre-process object detection results.

    Turns a YOLO-style detection result into plain dictionaries, assigns
    detected objects to functional categories, and picks out the core
    objects that characterize a given scene type.
    """

    # Name-based fallback categories, in the priority order used when an
    # object cannot be resolved through ``object_categories``.
    _FALLBACK_CATEGORIES = (
        ("furniture", ("chair", "couch", "bed", "dining table", "toilet")),
        ("plant", ("potted plant",)),
        ("electronics", ("tv", "laptop", "mouse", "remote", "keyboard",
                         "cell phone")),
        ("vehicle", ("bicycle", "car", "motorcycle", "airplane", "bus",
                     "train", "truck", "boat")),
        ("person", ("person",)),
        ("kitchen_items", ("bottle", "wine glass", "cup", "fork", "knife",
                           "spoon", "bowl", "banana", "apple", "sandwich",
                           "orange", "broccoli", "carrot", "hot dog", "pizza",
                           "donut", "cake", "refrigerator", "oven", "toaster",
                           "sink", "microwave")),
        ("sports", ("frisbee", "skis", "snowboard", "sports ball", "kite",
                    "baseball bat", "baseball glove", "skateboard",
                    "surfboard", "tennis racket")),
        ("personal_items", ("handbag", "tie", "suitcase", "umbrella",
                            "backpack")),
    )

    # Class IDs that define each scene type. The IDs follow the 0-indexed
    # COCO-80 numbering used by the rest of this table.
    _SCENE_CORE_CLASS_IDS = {
        "bedroom": [59],                      # bed
        "kitchen": [68, 69, 71, 72],          # microwave, oven, sink, refrigerator
        # NOTE: chair is 56 in COCO; the previous value 58 is "potted plant".
        "living_room": [57, 56, 62],          # sofa, chair, tv
        "dining_area": [60, 42, 43],          # dining table, fork, knife
        "office_workspace": [63, 64, 66, 73]  # laptop, mouse, keyboard, book
    }

    # Minimum confidence for an object to count as a scene-defining object.
    _CORE_OBJECT_MIN_CONFIDENCE = 0.4

    def __init__(self,
                 class_names: Optional[Dict[int, str]] = None,
                 object_categories: Optional[Dict[str, List[int]]] = None):
        """
        Initialize the extractor.

        Args:
            class_names: Mapping from class ID to class name. ``None`` is
                treated as an empty mapping.
            object_categories: Mapping from category name to the class IDs
                belonging to it. ``None`` is treated as an empty mapping.

        Raises:
            Exception: Any error raised during initialization is logged and
                re-raised.
        """
        try:
            self.class_names = class_names or {}
            self.object_categories = object_categories or {}

            # Base confidence threshold applied to every class.
            self.base_conf_threshold = 0.25

            # Per-class adjustment factors (key: lower-case class name).
            # Effective threshold = base_conf_threshold * factor; classes not
            # listed here use the base threshold unchanged (factor 1.0).
            self.dynamic_conf_map = {
                "traffic light": 0.6,
                "car": 0.8,
                "person": 0.7,
            }

            logger.info(f"ObjectExtractor initialized with {len(self.class_names)} class names and {len(self.object_categories)} object categories")

        except Exception as e:
            logger.error(f"Failed to initialize ObjectExtractor: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def _get_dynamic_threshold(self, class_name: str) -> float:
        """
        Return the confidence threshold to use for ``class_name``.

        The threshold is ``base_conf_threshold * factor`` where ``factor``
        comes from ``dynamic_conf_map``; classes without an entry get the
        base threshold.
        """
        # Keys in dynamic_conf_map are lower-case, so normalize before lookup.
        key = class_name.lower()
        factor = self.dynamic_conf_map.get(key, 1.0)
        return self.base_conf_threshold * factor

    def extract_detected_objects(
        self,
        detection_result: Any,
        confidence_threshold: float = 0.25,
        region_analyzer=None
    ) -> List[Dict]:
        """
        Extract object information, including position data, from a result.

        Args:
            detection_result: Detection result exposing ``boxes.xyxy``,
                ``boxes.cls``, ``boxes.conf`` (each supporting
                ``.cpu().numpy()``) and ``orig_shape``.
            confidence_threshold: Currently unused; filtering is done with the
                per-class dynamic threshold. Kept so existing callers that
                pass it keep working.
            region_analyzer: Optional object providing
                ``determine_region(norm_x, norm_y)``; when omitted every
                object's region is ``"unknown"``.

        Returns:
            A list of dictionaries, one per retained detection. An empty list
            is returned when the result is missing, has no ``boxes``
            attribute, or extraction fails.
        """
        try:
            # Debug aid: record the current class mapping state.
            logger.info("ObjectExtractor.extract_detected_objects called")
            logger.info(f"Current class_names keys: {list(self.class_names.keys()) if self.class_names else 'None'}")

            if detection_result is None:
                logger.warning("Detection result is None")
                return []

            if not hasattr(detection_result, 'boxes'):
                logger.error("Detection result does not have boxes attribute")
                return []

            boxes = detection_result.boxes.xyxy.cpu().numpy()
            classes = detection_result.boxes.cls.cpu().numpy().astype(int)
            confidences = detection_result.boxes.conf.cpu().numpy()

            # Image size, used to normalize coordinates into the 0-1 range.
            img_height, img_width = detection_result.orig_shape[:2]

            detected_objects = []
            for box, class_id, confidence in zip(boxes, classes, confidences):
                # A failure on one detection must not discard the others.
                try:
                    class_name = self.class_names.get(int(class_id), f"unknown_class_{class_id}")

                    # Drop detections below the per-class threshold.
                    dyn_thr = self._get_dynamic_threshold(class_name)
                    if confidence < dyn_thr:
                        continue

                    x1, y1, x2, y2 = box
                    width = x2 - x1
                    height = y2 - y1

                    center_x = (x1 + x2) / 2
                    center_y = (y1 + y2) / 2

                    # Normalized position and size (0-1).
                    norm_x = center_x / img_width
                    norm_y = center_y / img_height
                    norm_width = width / img_width
                    norm_height = height / img_height

                    area = width * height
                    norm_area = area / (img_width * img_height)

                    object_region = "unknown"
                    if region_analyzer:
                        object_region = region_analyzer.determine_region(norm_x, norm_y)

                    # Debug aid: record how the class ID was mapped.
                    if class_name.startswith("unknown_class_"):
                        logger.warning(
                            f"Class ID {class_id} not found in class_names. "
                            f"Available keys: {list(self.class_names.keys())}"
                        )
                    else:
                        logger.debug(f"Successfully mapped class ID {class_id} to '{class_name}'")

                    detected_objects.append({
                        "class_id": int(class_id),
                        "class_name": class_name,
                        "confidence": float(confidence),
                        "box": [float(x1), float(y1), float(x2), float(y2)],
                        "center": [float(center_x), float(center_y)],
                        "normalized_center": [float(norm_x), float(norm_y)],
                        "size": [float(width), float(height)],
                        "normalized_size": [float(norm_width), float(norm_height)],
                        "area": float(area),
                        "normalized_area": float(norm_area),
                        "region": object_region
                    })

                except Exception as e:
                    logger.error(f"Error processing object with class_id {class_id}: {str(e)}")
                    continue

            logger.info(f"Extracted {len(detected_objects)} objects from detection result")
            return detected_objects

        except Exception as e:
            logger.error(f"Error extracting detected objects: {str(e)}")
            logger.error(traceback.format_exc())
            return []

    def update_class_names(self, class_names: Dict[int, str]):
        """
        Replace the class ID to class name mapping.

        Args:
            class_names: New mapping; a falsy value resets it to empty.
        """
        try:
            self.class_names = class_names or {}
            logger.info(f"Class names updated: {len(self.class_names)} classes")
            logger.debug(f"Updated class names: {self.class_names}")
        except Exception as e:
            logger.error(f"Failed to update class names: {str(e)}")

    def categorize_object(self, obj: Dict) -> str:
        """
        Assign a detected object to a functional category.

        Resolution order:
            1. ``object_categories`` (by class ID), when configured.
            2. Exact match of the lower-cased class name against the fallback
               name tables.
            3. Substring match against the same tables, in table order, so
               name variants such as ``"tvmonitor"`` still resolve.

        The exact-match pass exists so that a name like ``"carrot"`` is not
        captured by the earlier ``"car"`` (vehicle) entry.

        Args:
            obj: Object dictionary; ``class_id`` and ``class_name`` are read.

        Returns:
            The category name, or ``"misc"`` when nothing matches or an error
            occurs.
        """
        try:
            class_id = obj.get("class_id", -1)
            class_name = obj.get("class_name", "").lower()

            # Prefer the configured ID-based mapping when available.
            if self.object_categories:
                for category, ids in self.object_categories.items():
                    if class_id in ids:
                        return category

            # Exact name match first, across every category.
            for category, names in self._FALLBACK_CATEGORIES:
                if class_name in names:
                    return category

            # Then the looser substring match, in priority order.
            for category, names in self._FALLBACK_CATEGORIES:
                if any(name in class_name for name in names):
                    return category

            return "misc"

        except Exception as e:
            logger.error(f"Error categorizing object: {str(e)}")
            logger.error(traceback.format_exc())
            return "misc"

    def get_object_categories(self, detected_objects: List[Dict]) -> set:
        """
        Collect the distinct categories present among the detected objects.

        Args:
            detected_objects: List of object dictionaries.

        Returns:
            Set of category names; an empty set on error.
        """
        try:
            object_categories = set()
            for obj in detected_objects:
                category = self.categorize_object(obj)
                if category:
                    object_categories.add(category)

            logger.info(f"Found {len(object_categories)} unique object categories")
            return object_categories

        except Exception as e:
            logger.error(f"Error getting object categories: {str(e)}")
            logger.error(traceback.format_exc())
            return set()

    def identify_core_objects_for_scene(self, detected_objects: List[Dict], scene_type: str) -> List[Dict]:
        """
        Select the objects that define the given scene type.

        An object qualifies when its ``class_id`` is listed for
        ``scene_type`` and its confidence is at least 0.4.

        Args:
            detected_objects: List of object dictionaries.
            scene_type: Scene type key, e.g. ``"kitchen"``.

        Returns:
            The qualifying objects in input order; an empty list for unknown
            scene types or on error.
        """
        try:
            core_class_ids = self._SCENE_CORE_CLASS_IDS.get(scene_type, ())
            core_objects = [
                obj for obj in detected_objects
                if obj.get("class_id") in core_class_ids
                and obj.get("confidence", 0) >= self._CORE_OBJECT_MIN_CONFIDENCE
            ]

            logger.info(f"Identified {len(core_objects)} core objects for scene type '{scene_type}'")
            return core_objects

        except Exception as e:
            logger.error(f"Error identifying core objects for scene '{scene_type}': {str(e)}")
            logger.error(traceback.format_exc())
            return []

    def group_objects_by_category_and_region(self, detected_objects: List[Dict]) -> Dict:
        """
        Group objects by category and then by region.

        Args:
            detected_objects: List of object dictionaries.

        Returns:
            ``{category: {region: [objects]}}``. Objects without a
            ``region`` key are filed under ``"center"``. An empty dict is
            returned on error.
        """
        try:
            category_regions = {}

            for obj in detected_objects:
                category = self.categorize_object(obj)
                if not category:
                    continue

                region = obj.get("region", "center")
                regions = category_regions.setdefault(category, {})
                regions.setdefault(region, []).append(obj)

            logger.info(f"Grouped objects into {len(category_regions)} categories across regions")
            return category_regions

        except Exception as e:
            logger.error(f"Error grouping objects by category and region: {str(e)}")
            logger.error(traceback.format_exc())
            return {}

    def filter_objects_by_confidence(self, detected_objects: List[Dict], min_confidence: float) -> List[Dict]:
        """
        Keep only the objects whose confidence reaches ``min_confidence``.

        Args:
            detected_objects: List of object dictionaries.
            min_confidence: Inclusive lower bound on ``confidence``; objects
                without a ``confidence`` key are treated as 0.

        Returns:
            The filtered list. On error the original list is returned
            unfiltered.
        """
        try:
            filtered_objects = [
                obj for obj in detected_objects
                if obj.get("confidence", 0) >= min_confidence
            ]

            logger.info(f"Filtered {len(detected_objects)} objects to {len(filtered_objects)} objects with confidence >= {min_confidence}")
            return filtered_objects

        except Exception as e:
            logger.error(f"Error filtering objects by confidence: {str(e)}")
            logger.error(traceback.format_exc())
            return detected_objects  # Fall back to the original list on error.