VisionScout / object_group_processor.py
DawnC's picture
Upload 14 files
12d9ea9 verified
import logging
from typing import Dict, List, Tuple, Optional, Any
class ObjectGroupProcessor:
    """
    Object group processor: grouping, de-duplication, ordering and clause generation.

    Groups detected objects by class, removes spatially duplicated detections,
    orders the resulting groups by priority and renders one descriptive clause
    per group.
    """

    # Classes accepted at a relaxed confidence threshold; also promoted in the
    # sort order when at least two of them are present.
    _RELAXED_THRESHOLD_CLASSES = frozenset({"potted plant", "vase", "clock", "book"})
    _FURNITURE_CLASSES = frozenset({"dining table", "chair", "sofa", "bed"})
    _VEHICLE_CLASSES = frozenset({"car", "bus", "truck", "traffic light"})
    # Classes whose group sizes are traced at DEBUG level.
    _TRACED_CLASSES = ("car", "traffic light", "person", "handbag")

    # Two centers closer than this Manhattan distance (normalized coordinates)
    # are treated as the same physical object.
    _DUPLICATE_DISTANCE = 0.15
    # Center assumed for objects that carry no usable "normalized_center".
    _DEFAULT_CENTER = (0.5, 0.5)

    # Fallback mapping from a cleaned region name to a location phrase.
    _REGION_PHRASES = {
        "top left": "in the upper left area",
        "top center": "in the upper area",
        "top right": "in the upper right area",
        "middle left": "on the left side",
        "middle center": "in the center",
        "center": "in the center",
        "middle right": "on the right side",
        "bottom left": "in the lower left area",
        "bottom center": "in the lower area",
        "bottom right": "in the lower right area"
    }

    # Fallback number words and irregular plurals used when no text optimizer is set.
    _NUMBER_WORDS = {
        2: "two", 3: "three", 4: "four", 5: "five", 6: "six",
        7: "seven", 8: "eight", 9: "nine", 10: "ten",
        11: "eleven", 12: "twelve"
    }
    _IRREGULAR_PLURALS = {
        "person": "people",
        "man": "men",
        "woman": "women",
        "child": "children",
        "knife": "knives",
        "mouse": "mice",
        "sheep": "sheep",
    }

    def __init__(self, confidence_threshold_for_description: float = 0.25,
                 spatial_handler: Optional[Any] = None,
                 text_optimizer: Optional[Any] = None):
        """
        Initialize the object group processor.

        Args:
            confidence_threshold_for_description: Base confidence threshold a
                class must reach (on average) to be described.
            spatial_handler: Optional spatial handler; when set, its
                ``generate_spatial_description`` is used for single objects.
            text_optimizer: Optional text optimizer; when set, its
                ``normalize_object_class_name`` and
                ``format_object_count_description`` replace the built-in fallbacks.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.confidence_threshold_for_description = confidence_threshold_for_description
        self.spatial_handler = spatial_handler
        self.text_optimizer = text_optimizer

    def group_objects_by_class(self, confident_objects: List[Dict],
                               object_statistics: Optional[Dict]) -> Dict[str, List[Dict]]:
        """
        Group objects by class name.

        When ``object_statistics`` is provided, only classes whose average
        confidence reaches a per-class dynamic threshold are kept, and each
        group is capped at the count reported by the statistics. Without
        statistics, every object with a usable class name is grouped and no
        thresholding is applied.

        Args:
            confident_objects: Objects that already passed confidence filtering.
            object_statistics: Optional per-class statistics; the "count" and
                "avg_confidence" entries are read.

        Returns:
            Dict[str, List[Dict]]: Objects grouped by class name.
        """
        objects_by_class: Dict[str, List[Dict]] = {}

        if not object_statistics:
            # Fallback: plain grouping, skipping unnamed / unknown objects.
            for obj in confident_objects:
                name = obj.get("class_name", "unknown object")
                if not name or name == "unknown object":
                    continue
                objects_by_class.setdefault(name, []).append(obj)
            return objects_by_class

        base_threshold = self.confidence_threshold_for_description
        for class_name, stats in object_statistics.items():
            count = stats.get("count", 0)
            avg_confidence = stats.get("avg_confidence", 0)

            # Dynamic threshold: relaxed for selected classes and for classes
            # detected at least three times.
            if class_name in self._RELAXED_THRESHOLD_CLASSES:
                dynamic_threshold = max(0.15, base_threshold * 0.6)
            elif count >= 3:
                dynamic_threshold = max(0.2, base_threshold * 0.8)
            else:
                dynamic_threshold = base_threshold

            if count <= 0 or avg_confidence < dynamic_threshold:
                continue

            matching_objects = [obj for obj in confident_objects
                                if obj.get("class_name") == class_name]
            if not matching_objects:
                continue

            # Never report more objects than the statistics claim.
            actual_count = min(count, len(matching_objects))
            objects_by_class[class_name] = matching_objects[:actual_count]

            if class_name in self._TRACED_CLASSES:
                self.logger.debug("%s: %d objects before spatial deduplication",
                                  class_name, len(objects_by_class[class_name]))

        return objects_by_class

    def remove_duplicate_objects(self, objects_by_class: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
        """
        Remove spatially duplicated objects.

        An object is dropped when its center lies within ``_DUPLICATE_DISTANCE``
        (Manhattan distance) of any object kept earlier. The list of kept
        centers is shared across all classes, so an object can be dropped
        because of an earlier object of a different class.

        Args:
            objects_by_class: Objects grouped by class name.

        Returns:
            Dict[str, List[Dict]]: Groups after de-duplication; classes left
            with no objects are omitted.
        """
        deduplicated_objects_by_class: Dict[str, List[Dict]] = {}
        kept_centers: List[Tuple[float, float]] = []

        for class_name, group_of_objects in objects_by_class.items():
            unique_objects = []
            for obj in group_of_objects:
                center_x, center_y = self._object_center(obj)
                is_duplicate = any(
                    abs(center_x - kept_x) + abs(center_y - kept_y) < self._DUPLICATE_DISTANCE
                    for kept_x, kept_y in kept_centers
                )
                if is_duplicate:
                    continue
                unique_objects.append(obj)
                kept_centers.append((center_x, center_y))

            if unique_objects:
                deduplicated_objects_by_class[class_name] = unique_objects

        for class_name in self._TRACED_CLASSES:
            if class_name in deduplicated_objects_by_class:
                self.logger.debug("%s: %d objects after spatial deduplication",
                                  class_name, len(deduplicated_objects_by_class[class_name]))

        return deduplicated_objects_by_class

    def _object_center(self, obj: Dict) -> Tuple[float, float]:
        """Return the object's normalized center, or the default center if it is missing or malformed."""
        center = obj.get("normalized_center", self._DEFAULT_CENTER)
        try:
            return float(center[0]), float(center[1])
        except (TypeError, ValueError, IndexError, KeyError):
            # None, too-short sequences or non-numeric entries: use the default
            # rather than aborting the whole de-duplication pass.
            return self._DEFAULT_CENTER

    def sort_object_groups(self, objects_by_class: Dict[str, List[Dict]]) -> List[Tuple[str, List[Dict]]]:
        """
        Sort object groups for description.

        Groups are ordered by priority (people, then furniture, then vehicles
        and other frequent classes, then everything else), then by descending
        group size, then by descending average normalized area.

        Args:
            objects_by_class: Objects grouped by class name.

        Returns:
            List[Tuple[str, List[Dict]]]: ``(class_name, objects)`` pairs in
            description order.
        """
        def sort_key_object_groups(item_tuple: Tuple[str, List[Dict]]):
            class_name_key, obj_group_list = item_tuple
            count = len(obj_group_list)
            normalized_class_name = self._normalize_object_class_name(class_name_key)

            if normalized_class_name == "person":
                priority = 0
            elif normalized_class_name in self._FURNITURE_CLASSES:
                priority = 1
            elif normalized_class_name in self._VEHICLE_CLASSES:
                priority = 2
            elif count >= 3:
                priority = 2
            elif normalized_class_name in self._RELAXED_THRESHOLD_CLASSES and count >= 2:
                priority = 2
            else:
                priority = 3

            if obj_group_list:
                total_area = sum((o.get("normalized_area") or 0.0) for o in obj_group_list)
                avg_area = total_area / count
            else:
                avg_area = 0

            # No separate quantity bonus: it would be a function of `count`,
            # which is already compared before the area.
            return (priority, -count, -avg_area)

        return sorted(objects_by_class.items(), key=sort_key_object_groups)

    def generate_object_clauses(self, sorted_object_groups: List[Tuple[str, List[Dict]]],
                                object_statistics: Optional[Dict],
                                scene_type: str,
                                image_width: Optional[int],
                                image_height: Optional[int],
                                region_analyzer: Optional[Any] = None) -> List[str]:
        """
        Generate one descriptive clause per object group.

        The count used in the clause comes from ``object_statistics`` when the
        class is listed there, otherwise from the group size. The same count
        drives the location suffix so that subject and verb agree.

        Args:
            sorted_object_groups: Object groups in description order.
            object_statistics: Optional per-class statistics ("count" is read).
            scene_type: Scene type forwarded to the count formatter.
            image_width: Image width forwarded to the spatial handler.
            image_height: Image height forwarded to the spatial handler.
            region_analyzer: Optional region analyzer forwarded to the spatial handler.

        Returns:
            List[str]: Clauses without trailing punctuation.
        """
        object_clauses = []

        for class_name, group_of_objects in sorted_object_groups:
            count = len(group_of_objects)

            if class_name in self._TRACED_CLASSES:
                self.logger.debug("Final count for %s: %d", class_name, count)

            if count == 0:
                continue

            normalized_class_name = self._normalize_object_class_name(class_name)

            # Prefer the statistics count; fall back to the group size when the
            # class or its "count" entry is missing.
            described_count = count
            if object_statistics and class_name in object_statistics:
                described_count = object_statistics[class_name].get("count", count)

            formatted_name_with_exact_count = self._format_object_count_description(
                normalized_class_name,
                described_count,
                scene_type=scene_type
            )

            if (not formatted_name_with_exact_count
                    or formatted_name_with_exact_count == "no specific objects clearly identified"):
                continue

            # Use the described count (not the group size) so "is"/"are"
            # matches the noun phrase built above.
            location_description_suffix = self._generate_location_description(
                group_of_objects, described_count, image_width, image_height, region_analyzer
            )

            formatted_name_capitalized = (formatted_name_with_exact_count[0].upper()
                                          + formatted_name_with_exact_count[1:])
            object_clauses.append(f"{formatted_name_capitalized} {location_description_suffix}")

        return object_clauses

    def format_object_clauses(self, object_clauses: List[str]) -> str:
        """
        Join object clauses into a single description string.

        The first clause becomes its own sentence; any remaining clauses follow
        after "The scene features:". The input list is not modified.

        Args:
            object_clauses: Clauses to join.

        Returns:
            str: The formatted description, or a fixed message when there are
            no clauses.
        """
        if not object_clauses:
            return "No common objects were confidently identified for detailed description."

        # Unpack instead of pop(0) so the caller's list stays intact.
        first_clause, *remaining_clauses = object_clauses
        result = first_clause + "."

        if remaining_clauses:
            result += " The scene features:"
            joined_object_clauses = ". ".join(remaining_clauses)
            if joined_object_clauses and not joined_object_clauses.endswith("."):
                joined_object_clauses += "."
            result += " " + joined_object_clauses

        return result

    def _generate_location_description(self, group_of_objects: List[Dict], count: int,
                                       image_width: Optional[int], image_height: Optional[int],
                                       region_analyzer: Optional[Any] = None) -> str:
        """
        Generate the location suffix ("is ..." / "are ...") for a group.

        Args:
            group_of_objects: Objects in the group.
            count: Number of objects being described; ``1`` selects singular wording.
            image_width: Image width forwarded to the spatial handler.
            image_height: Image height forwarded to the spatial handler.
            region_analyzer: Optional region analyzer forwarded to the spatial handler.

        Returns:
            str: Location suffix starting with "is" or "are".
        """
        singular = count == 1

        if singular and group_of_objects:
            first_object = group_of_objects[0]
            if self.spatial_handler:
                spatial_desc = self.spatial_handler.generate_spatial_description(
                    first_object, image_width, image_height, region_analyzer
                )
            else:
                spatial_desc = self._get_spatial_description_phrase(first_object.get("region", ""))
            if spatial_desc:
                return f"is {spatial_desc}"
            # No precise description available: fall through to region wording.

        distinct_regions = {obj.get("region") for obj in group_of_objects}
        valid_regions = sorted(
            region for region in distinct_regions
            if isinstance(region, str) and region.strip() and region != "unknown"
        )
        verb = "is" if singular else "are"

        if not valid_regions:
            return "is positioned in the scene" if singular else "are visible in the scene"

        if len(valid_regions) == 1:
            if singular:
                spatial_desc = self._get_spatial_description_phrase(valid_regions[0])
                return f"is primarily {spatial_desc}" if spatial_desc else "is positioned in the scene"
            clean_region = valid_regions[0].replace('_', ' ')
            return f"are primarily in the {clean_region} area"

        if len(valid_regions) == 2:
            clean_region1 = valid_regions[0].replace('_', ' ')
            clean_region2 = valid_regions[1].replace('_', ' ')
            return f"{verb} mainly across the {clean_region1} and {clean_region2} areas"

        return f"{verb} distributed in various parts of the scene"

    def _get_spatial_description_phrase(self, region: str) -> str:
        """
        Fallback lookup of a location phrase for a region name.

        Args:
            region: Region string such as "top_left" or "middle center".

        Returns:
            str: The matching phrase, or "" for empty, "unknown" or unmapped regions.
        """
        if not region or region == "unknown":
            return ""

        clean_region = region.replace('_', ' ').strip().lower()
        return self._REGION_PHRASES.get(clean_region, "")

    def _normalize_object_class_name(self, class_name: str) -> str:
        """
        Normalize an object class name.

        Delegates to the text optimizer when one is configured; otherwise
        replaces underscores with spaces, strips and lower-cases the name.

        Args:
            class_name: Raw class name.

        Returns:
            str: Normalized class name ("object" for empty or non-string input
            in the fallback path).
        """
        if self.text_optimizer:
            return self.text_optimizer.normalize_object_class_name(class_name)

        if not class_name or not isinstance(class_name, str):
            return "object"

        return class_name.replace('_', ' ').strip().lower()

    def _format_object_count_description(self, class_name: str, count: int,
                                         scene_type: Optional[str] = None,
                                         detected_objects: Optional[List[Dict]] = None,
                                         avg_confidence: float = 0.0) -> str:
        """
        Format a counted noun phrase such as "a car" or "three buses".

        Delegates to the text optimizer when one is configured; otherwise uses
        a small built-in formatter.

        Args:
            class_name: Normalized class name.
            count: Number of objects.
            scene_type: Scene type (forwarded to the text optimizer only).
            detected_objects: Detected objects of this class (forwarded to the
                text optimizer only).
            avg_confidence: Average detection confidence (forwarded to the text
                optimizer only).

        Returns:
            str: The formatted phrase; "" in the fallback path when ``count <= 0``.
        """
        if self.text_optimizer:
            return self.text_optimizer.format_object_count_description(
                class_name, count, scene_type, detected_objects, avg_confidence
            )

        if count <= 0:
            return ""

        # Guard against an empty name so the article lookup cannot fail.
        name = class_name if class_name else "object"

        if count == 1:
            article = "an" if name[0].lower() in 'aeiou' else "a"
            return f"{article} {name}"

        plural_form = self._pluralize(name)
        if count in self._NUMBER_WORDS:
            return f"{self._NUMBER_WORDS[count]} {plural_form}"
        if count <= 20:
            return f"several {plural_form}"
        return f"numerous {plural_form}"

    def _pluralize(self, class_name: str) -> str:
        """Return a plural form of ``class_name`` by pluralizing its last word."""
        prefix, separator, last_word = class_name.rpartition(" ")
        lowered = last_word.lower()

        if lowered in self._IRREGULAR_PLURALS:
            plural_word = self._IRREGULAR_PLURALS[lowered]
        elif lowered.endswith(("ss", "us", "x", "z", "ch", "sh")):
            # "glass" -> "glasses", "bus" -> "buses", "bench" -> "benches"
            plural_word = last_word + "es"
        elif lowered.endswith("s"):
            # Assume names like "skis" or "scissors" are already plural.
            plural_word = last_word
        elif len(lowered) > 1 and lowered.endswith("y") and lowered[-2] not in "aeiou":
            plural_word = last_word[:-1] + "ies"
        else:
            plural_word = last_word + "s"

        return f"{prefix}{separator}{plural_word}"