Spaces:
Running
on
Zero
Running
on
Zero
File size: 15,469 Bytes
e6a18b7 e3868ba e6a18b7 e3868ba e6a18b7 be82503 ac7b808 be82503 e6a18b7 95b3ba7 e6a18b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 |
import logging
import re
import traceback
from typing import Dict, List, Any, Optional
# Module-level logger shared by every ObjectExtractor instance.
logger = logging.getLogger(__name__)


class ObjectExtractor:
    """
    Extract and pre-process object-detection results.

    Converts a YOLO-style detection result into plain dictionaries (class,
    confidence, absolute and normalized geometry, region), assigns objects to
    functional categories, and selects the "core" objects that characterise a
    given scene type.
    """

    # Default per-class multipliers applied to ``base_conf_threshold``.
    # Keys are lower-case class names; classes not listed use a factor of 1.0.
    _DEFAULT_DYNAMIC_CONF_MAP: Dict[str, float] = {
        "traffic light": 0.6,
        "car": 0.8,
        "person": 0.7,
    }

    # Fallback category -> class-name terms (COCO names), checked in this order
    # when ``object_categories`` does not contain the object's class id.
    _FALLBACK_CATEGORY_TERMS = (
        ("furniture", ("chair", "couch", "bed", "dining table", "toilet")),
        ("plant", ("potted plant",)),
        ("electronics", ("tv", "laptop", "mouse", "remote", "keyboard", "cell phone")),
        ("vehicle", ("bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat")),
        ("person", ("person",)),
        ("kitchen_items", (
            "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
            "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog",
            "pizza", "donut", "cake", "refrigerator", "oven", "toaster", "sink", "microwave",
        )),
        ("sports", (
            "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
            "baseball glove", "skateboard", "surfboard", "tennis racket",
        )),
        ("personal_items", ("handbag", "tie", "suitcase", "umbrella", "backpack")),
    )

    # One compiled pattern per category. A term only matches when it is not
    # glued to other letters, so "office chair" still matches "chair" but
    # "carrot" no longer matches "car" (which used to mis-file carrots as vehicles).
    _FALLBACK_CATEGORY_PATTERNS = tuple(
        (
            category,
            re.compile(
                r"(?<![a-z])(?:" + "|".join(re.escape(term) for term in terms) + r")(?![a-z])"
            ),
        )
        for category, terms in _FALLBACK_CATEGORY_TERMS
    )

    # Scene type -> class ids whose presence defines that scene.
    # NOTE: ids assume the COCO-80 ordering used by the name comments below.
    _SCENE_CORE_CLASS_IDS: Dict[str, List[int]] = {
        "bedroom": [59],  # bed
        "kitchen": [68, 69, 71, 72],  # microwave, oven, sink, refrigerator
        # Chair is id 56; this entry previously used 58 (potted plant), which
        # contradicted the intended "sofa, chair, tv".
        "living_room": [57, 56, 62],  # couch (sofa), chair, tv
        "dining_area": [60, 42, 43],  # dining table, fork, knife
        "office_workspace": [63, 64, 66, 73],  # laptop, mouse, keyboard, book
    }

    def __init__(
        self,
        class_names: Optional[Dict[int, str]] = None,
        object_categories: Optional[Dict[str, List[int]]] = None,
        base_conf_threshold: float = 0.25,
        dynamic_conf_map: Optional[Dict[str, float]] = None,
    ):
        """
        Initialize the extractor.

        Args:
            class_names: Mapping from class id to class name.
            object_categories: Mapping from category name to the class ids it
                contains; consulted before the name-based fallback in
                ``categorize_object``.
            base_conf_threshold: Base confidence threshold (default 0.25).
            dynamic_conf_map: Per-class multipliers (class name -> factor) for
                the base threshold. Keys are matched case-insensitively. When
                omitted, a built-in default map is used.

        Raises:
            Any exception raised during setup is logged and re-raised.
        """
        try:
            self.class_names = class_names or {}
            self.object_categories = object_categories or {}
            self.base_conf_threshold = float(base_conf_threshold)
            source_map = (
                self._DEFAULT_DYNAMIC_CONF_MAP if dynamic_conf_map is None else dynamic_conf_map
            )
            # Fresh per-instance dict; keys lower-cased because the lookup in
            # _get_dynamic_threshold lower-cases the class name.
            self.dynamic_conf_map = {
                str(name).lower(): float(factor) for name, factor in source_map.items()
            }
            logger.info(f"ObjectExtractor initialized with {len(self.class_names)} class names and {len(self.object_categories)} object categories")
        except Exception as e:
            logger.error(f"Failed to initialize ObjectExtractor: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def _get_dynamic_threshold(self, class_name: str) -> float:
        """
        Return the confidence threshold for ``class_name``.

        threshold = base_conf_threshold * factor, where factor comes from
        ``dynamic_conf_map`` (case-insensitive) and defaults to 1.0.
        """
        factor = self.dynamic_conf_map.get(class_name.lower(), 1.0)
        return self.base_conf_threshold * factor

    @staticmethod
    def _build_object_info(box, class_id: int, class_name: str, confidence: float,
                           img_width, img_height, region_analyzer=None) -> Dict:
        """Build one object record with absolute/normalized geometry and region."""
        x1, y1, x2, y2 = box
        width = x2 - x1
        height = y2 - y1
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        # Normalized coordinates are fractions of the image size (0-1 for in-frame boxes).
        norm_x = center_x / img_width
        norm_y = center_y / img_height
        area = width * height

        object_region = "unknown"
        if region_analyzer:
            object_region = region_analyzer.determine_region(norm_x, norm_y)

        return {
            "class_id": class_id,
            "class_name": class_name,
            "confidence": confidence,
            "box": [float(x1), float(y1), float(x2), float(y2)],
            "center": [float(center_x), float(center_y)],
            "normalized_center": [float(norm_x), float(norm_y)],
            "size": [float(width), float(height)],
            "normalized_size": [float(width / img_width), float(height / img_height)],
            "area": float(area),
            "normalized_area": float(area / (img_width * img_height)),
            "region": object_region,
        }

    def extract_detected_objects(
        self,
        detection_result: Any,
        confidence_threshold: float = 0.25,
        region_analyzer=None,
    ) -> List[Dict]:
        """
        Extract per-object information, including position, from a detection result.

        Each detection is kept only if its confidence reaches the per-class
        threshold from ``_get_dynamic_threshold``.

        Args:
            detection_result: YOLO-style result exposing ``boxes.xyxy``,
                ``boxes.cls`` and ``boxes.conf`` (each supporting
                ``.cpu().numpy()``) and ``orig_shape`` as (height, width, ...).
            confidence_threshold: Unused; kept for backward compatibility.
                Per-class thresholds derived from ``base_conf_threshold`` and
                ``dynamic_conf_map`` apply instead.
            region_analyzer: Optional object whose
                ``determine_region(norm_x, norm_y)`` labels the object's region.

        Returns:
            A list of dicts with keys class_id, class_name, confidence, box,
            center, normalized_center, size, normalized_size, area,
            normalized_area and region. An empty list is returned when the
            result is None, lacks ``boxes``, has a non-positive image size, or
            an unexpected error occurs. Objects that fail individually are
            logged and skipped.
        """
        try:
            logger.debug("ObjectExtractor.extract_detected_objects called")
            logger.debug(
                "Current class_names keys: %s",
                list(self.class_names.keys()) if self.class_names else "None",
            )

            if detection_result is None:
                logger.warning("Detection result is None")
                return []
            if not hasattr(detection_result, "boxes"):
                logger.error("Detection result does not have boxes attribute")
                return []

            boxes = detection_result.boxes.xyxy.cpu().numpy()
            classes = detection_result.boxes.cls.cpu().numpy().astype(int)
            confidences = detection_result.boxes.conf.cpu().numpy()
            img_height, img_width = detection_result.orig_shape[:2]
            # Guard: a zero/negative size would make every normalized value inf/NaN.
            if img_width <= 0 or img_height <= 0:
                logger.error(f"Invalid image size {img_width}x{img_height}; cannot normalize coordinates")
                return []

            detected_objects = []
            for box, class_id, confidence in zip(boxes, classes, confidences):
                try:
                    class_name = self.class_names.get(int(class_id), f"unknown_class_{class_id}")
                    if confidence < self._get_dynamic_threshold(class_name):
                        continue

                    if class_name.startswith("unknown_class_"):
                        logger.warning(
                            f"Class ID {class_id} not found in class_names. "
                            f"Available keys: {list(self.class_names.keys())}"
                        )
                    else:
                        logger.debug("Successfully mapped class ID %s to '%s'", class_id, class_name)

                    detected_objects.append(
                        self._build_object_info(
                            box, int(class_id), class_name, float(confidence),
                            img_width, img_height, region_analyzer,
                        )
                    )
                except Exception as e:
                    logger.error(f"Error processing object with class_id {class_id}: {str(e)}")
                    continue

            logger.info(f"Extracted {len(detected_objects)} objects from detection result")
            return detected_objects

        except Exception as e:
            logger.error(f"Error extracting detected objects: {str(e)}")
            logger.error(traceback.format_exc())
            return []

    def update_class_names(self, class_names: Optional[Dict[int, str]]):
        """
        Replace the class-id -> class-name mapping.

        Args:
            class_names: New mapping; None or empty resets it to ``{}``.
        """
        try:
            self.class_names = class_names or {}
            logger.info(f"Class names updated: {len(self.class_names)} classes")
            logger.debug(f"Updated class names: {self.class_names}")
        except Exception as e:
            logger.error(f"Failed to update class names: {str(e)}")

    def categorize_object(self, obj: Dict) -> str:
        """
        Assign a detected object to a functional category (used for region identification).

        ``object_categories`` (by class id) is checked first. Otherwise the
        lower-cased class name is matched against built-in COCO term lists,
        with terms matching only when not embedded in a longer word.

        Args:
            obj: Object dict; reads "class_id" and "class_name".

        Returns:
            The category name, or "misc" when nothing matches or on error.
        """
        try:
            class_id = obj.get("class_id", -1)
            class_name = (obj.get("class_name") or "").lower()

            for category, ids in self.object_categories.items():
                if class_id in ids:
                    return category

            for category, pattern in self._FALLBACK_CATEGORY_PATTERNS:
                if pattern.search(class_name):
                    return category
            return "misc"
        except Exception as e:
            logger.error(f"Error categorizing object: {str(e)}")
            logger.error(traceback.format_exc())
            return "misc"

    def get_object_categories(self, detected_objects: List[Dict]) -> set:
        """
        Return the set of unique categories among ``detected_objects``.

        Args:
            detected_objects: Detected object dicts.

        Returns:
            Set of category names (empty on error).
        """
        try:
            object_categories = {
                category
                for category in map(self.categorize_object, detected_objects)
                if category
            }
            logger.info(f"Found {len(object_categories)} unique object categories")
            return object_categories
        except Exception as e:
            logger.error(f"Error getting object categories: {str(e)}")
            logger.error(traceback.format_exc())
            return set()

    def identify_core_objects_for_scene(
        self,
        detected_objects: List[Dict],
        scene_type: str,
        min_confidence: float = 0.4,
    ) -> List[Dict]:
        """
        Select the objects that define a given scene type.

        Args:
            detected_objects: Detected object dicts; reads "class_id" and "confidence".
            scene_type: Scene key, e.g. "bedroom" or "kitchen". Unknown scene
                types yield an empty list.
            min_confidence: Minimum confidence for a core object (default 0.4).

        Returns:
            The matching objects, in input order (empty on error).
        """
        try:
            core_class_ids = self._SCENE_CORE_CLASS_IDS.get(scene_type, [])
            core_objects = [
                obj for obj in detected_objects
                if obj.get("class_id") in core_class_ids
                and obj.get("confidence", 0) >= min_confidence
            ]
            logger.info(f"Identified {len(core_objects)} core objects for scene type '{scene_type}'")
            return core_objects
        except Exception as e:
            logger.error(f"Error identifying core objects for scene '{scene_type}': {str(e)}")
            logger.error(traceback.format_exc())
            return []

    def group_objects_by_category_and_region(self, detected_objects: List[Dict]) -> Dict:
        """
        Group objects as ``{category: {region: [objects]}}``.

        Objects without a "region" key are grouped under "center".

        Args:
            detected_objects: Detected object dicts.

        Returns:
            The nested grouping dict (empty on error).
        """
        try:
            category_regions: Dict[str, Dict[str, List[Dict]]] = {}
            for obj in detected_objects:
                category = self.categorize_object(obj)
                if not category:
                    continue
                region = obj.get("region", "center")
                category_regions.setdefault(category, {}).setdefault(region, []).append(obj)

            logger.info(f"Grouped objects into {len(category_regions)} categories across regions")
            return category_regions
        except Exception as e:
            logger.error(f"Error grouping objects by category and region: {str(e)}")
            logger.error(traceback.format_exc())
            return {}

    def filter_objects_by_confidence(self, detected_objects: List[Dict], min_confidence: float) -> List[Dict]:
        """
        Keep only objects whose "confidence" (missing counts as 0) is >= ``min_confidence``.

        Args:
            detected_objects: Detected object dicts.
            min_confidence: Inclusive minimum confidence.

        Returns:
            The filtered list. On error, the original list is returned unchanged.
        """
        try:
            filtered_objects = [
                obj for obj in detected_objects
                if obj.get("confidence", 0) >= min_confidence
            ]
            logger.info(f"Filtered {len(detected_objects)} objects to {len(filtered_objects)} objects with confidence >= {min_confidence}")
            return filtered_objects
        except Exception as e:
            logger.error(f"Error filtering objects by confidence: {str(e)}")
            logger.error(traceback.format_exc())
            return detected_objects  # best-effort: fall back to the unfiltered input
|