Spaces:
Running
on
Zero
Running
on
Zero
import logging | |
from typing import Dict, List, Tuple, Optional, Any | |
class ObjectGroupProcessor:
    """
    Group detected objects and turn the groups into descriptive clauses.

    Responsibilities: grouping detections by class, dropping spatial
    duplicates, ranking the groups, and producing one English clause per
    group ("Two cars are primarily in the middle left area").

    Detections are plain dicts. The keys read by this class are
    ``class_name``, ``normalized_center``, ``normalized_area`` and ``region``.
    """

    # Classes that are described at a reduced confidence threshold / that get
    # a small ranking boost when at least two of them are present.
    _SMALL_DECOR_CLASSES = ("potted plant", "vase", "clock", "book")
    _FURNITURE_CLASSES = ("dining table", "chair", "sofa", "bed")
    _VEHICLE_CLASSES = ("car", "bus", "truck", "traffic light")
    # Classes whose counts are traced at DEBUG level.
    _TRACED_CLASSES = ("car", "traffic light", "person", "handbag")

    # Manhattan distance (in normalized coordinates) below which two
    # detections are treated as the same physical object.
    _DUPLICATE_DISTANCE = 0.15

    _REGION_PHRASES = {
        "top left": "in the upper left area",
        "top center": "in the upper area",
        "top right": "in the upper right area",
        "middle left": "on the left side",
        "middle center": "in the center",
        "center": "in the center",
        "middle right": "on the right side",
        "bottom left": "in the lower left area",
        "bottom center": "in the lower area",
        "bottom right": "in the lower right area",
    }

    _NUMBER_WORDS = {
        2: "two", 3: "three", 4: "four", 5: "five", 6: "six",
        7: "seven", 8: "eight", 9: "nine", 10: "ten",
        11: "eleven", 12: "twelve",
    }

    _IRREGULAR_PLURALS = {
        "person": "people",
        "man": "men",
        "woman": "women",
        "child": "children",
        "mouse": "mice",
        "knife": "knives",
        "sheep": "sheep",
    }

    def __init__(self, confidence_threshold_for_description: float = 0.25,
                 spatial_handler: Optional[Any] = None,
                 text_optimizer: Optional[Any] = None):
        """
        Initialise the processor.

        Args:
            confidence_threshold_for_description: Base confidence threshold a
                class must reach (on average) to be described.
            spatial_handler: Optional collaborator; when set, its
                ``generate_spatial_description`` is used for single objects.
            text_optimizer: Optional collaborator; when set, its
                ``normalize_object_class_name`` and
                ``format_object_count_description`` replace the built-in
                fallbacks.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.confidence_threshold_for_description = confidence_threshold_for_description
        self.spatial_handler = spatial_handler
        self.text_optimizer = text_optimizer

    def group_objects_by_class(self, confident_objects: List[Dict],
                               object_statistics: Optional[Dict]) -> Dict[str, List[Dict]]:
        """
        Group detections by their ``class_name``.

        When ``object_statistics`` is provided, only classes whose average
        confidence reaches a per-class threshold are kept, and each group is
        capped at the class's ``count`` from the statistics. Without
        statistics every named detection is grouped as-is.

        Args:
            confident_objects: Detections that already passed confidence
                filtering.
            object_statistics: Optional mapping ``class_name -> stats`` where
                ``stats`` may hold ``count`` and ``avg_confidence``.

        Returns:
            Mapping of class name to the detections of that class.
        """
        detections = confident_objects or []
        objects_by_class: Dict[str, List[Dict]] = {}

        if not object_statistics:
            # Fallback path: plain grouping, skipping unnamed detections.
            for obj in detections:
                name = obj.get("class_name", "unknown object")
                if not name or name == "unknown object":
                    continue
                objects_by_class.setdefault(name, []).append(obj)
            return objects_by_class

        base_threshold = self.confidence_threshold_for_description
        for class_name, stats in object_statistics.items():
            count = stats.get("count", 0)
            avg_confidence = stats.get("avg_confidence", 0)

            # Small decor items and well-populated classes are accepted at a
            # lower confidence than the base threshold.
            if class_name in self._SMALL_DECOR_CLASSES:
                threshold = max(0.15, base_threshold * 0.6)
            elif count >= 3:
                threshold = max(0.2, base_threshold * 0.8)
            else:
                threshold = base_threshold

            if count <= 0 or avg_confidence < threshold:
                continue

            matching = [obj for obj in detections if obj.get("class_name") == class_name]
            if not matching:
                continue

            # Never report more objects than the statistics claim.
            objects_by_class[class_name] = matching[:min(count, len(matching))]

            if class_name in self._TRACED_CLASSES:
                self.logger.debug(
                    "Before spatial deduplication: %s: %d objects",
                    class_name, len(objects_by_class[class_name]),
                )

        return objects_by_class

    def remove_duplicate_objects(self, objects_by_class: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
        """
        Drop detections that sit almost on top of an already accepted one.

        Two detections are duplicates when the Manhattan distance between
        their ``normalized_center`` values is below 0.15. Detections without a
        usable ``normalized_center`` cannot be compared and are always kept.

        Args:
            objects_by_class: Detections grouped by class name.

        Returns:
            A new mapping containing only the non-duplicate detections;
            classes left with no detections are omitted.
        """
        deduplicated: Dict[str, List[Dict]] = {}
        # NOTE(review): accepted positions are shared across ALL classes, so a
        # detection can be dropped because of a nearby object of a different
        # class. This matches the previous behaviour - confirm it is intended.
        accepted_positions: List[Tuple[float, float]] = []

        for class_name, group_of_objects in objects_by_class.items():
            unique_objects = []
            for obj in group_of_objects:
                position = self._extract_position(obj)
                if position is None:
                    # No position -> nothing to compare against; keep it.
                    unique_objects.append(obj)
                    continue

                is_duplicate = any(
                    abs(position[0] - seen[0]) + abs(position[1] - seen[1]) < self._DUPLICATE_DISTANCE
                    for seen in accepted_positions
                )
                if not is_duplicate:
                    unique_objects.append(obj)
                    accepted_positions.append(position)

            if unique_objects:
                deduplicated[class_name] = unique_objects

        for class_name in self._TRACED_CLASSES:
            if class_name in deduplicated:
                self.logger.debug(
                    "After spatial deduplication: %s: %d objects",
                    class_name, len(deduplicated[class_name]),
                )

        return deduplicated

    @staticmethod
    def _extract_position(obj: Dict) -> Optional[Tuple[float, float]]:
        """Return ``normalized_center`` as an ``(x, y)`` float pair, or None if unusable."""
        center = obj.get("normalized_center")
        try:
            return float(center[0]), float(center[1])
        except (TypeError, ValueError, IndexError, KeyError):
            return None

    def sort_object_groups(self, objects_by_class: Dict[str, List[Dict]]) -> List[Tuple[str, List[Dict]]]:
        """
        Rank object groups for description.

        Ordering is by priority tier (people, then furniture, then vehicles /
        populous or paired decor groups, then everything else), then by
        descending group size, then by descending average ``normalized_area``.

        Args:
            objects_by_class: Detections grouped by class name.

        Returns:
            ``(class_name, detections)`` pairs in description order.
        """
        def ranking_key(item: Tuple[str, List[Dict]]):
            class_name, group = item
            count = len(group)
            normalized_name = self._normalize_object_class_name(class_name)

            if normalized_name == "person":
                priority = 0
            elif normalized_name in self._FURNITURE_CLASSES:
                priority = 1
            elif normalized_name in self._VEHICLE_CLASSES:
                priority = 2
            elif count >= 3:
                priority = 2
            elif normalized_name in self._SMALL_DECOR_CLASSES and count >= 2:
                priority = 2
            else:
                priority = 3

            total_area = sum((o.get("normalized_area") or 0.0) for o in group)
            avg_area = total_area / count if count else 0.0
            # The group size already orders equal-priority groups; a separate
            # "quantity bonus" derived from the same count would never break
            # a tie, so it is not part of the key.
            return (priority, -count, -avg_area)

        return sorted(objects_by_class.items(), key=ranking_key)

    def generate_object_clauses(self, sorted_object_groups: List[Tuple[str, List[Dict]]],
                                object_statistics: Optional[Dict],
                                scene_type: str,
                                image_width: Optional[int],
                                image_height: Optional[int],
                                region_analyzer: Optional[Any] = None) -> List[str]:
        """
        Build one descriptive clause per object group.

        The number used in the clause comes from ``object_statistics`` when
        the class is present there (falling back to the group size), and the
        same number decides between singular and plural wording of the
        location part so that subject and verb agree.

        Args:
            sorted_object_groups: ``(class_name, detections)`` pairs in the
                order they should be described.
            object_statistics: Optional mapping ``class_name -> stats`` with a
                ``count`` entry.
            scene_type: Scene type, forwarded to the count formatter.
            image_width: Image width, forwarded to the spatial handler.
            image_height: Image height, forwarded to the spatial handler.
            region_analyzer: Optional analyzer forwarded to the spatial handler.

        Returns:
            Capitalised clauses without trailing punctuation.
        """
        object_clauses: List[str] = []

        for class_name, group_of_objects in sorted_object_groups:
            count = len(group_of_objects)

            if class_name in self._TRACED_CLASSES:
                self.logger.debug("Final count for %s: %d", class_name, count)

            if count == 0:
                continue

            normalized_class_name = self._normalize_object_class_name(class_name)

            # Prefer the pre-computed count; the group may have been trimmed
            # by de-duplication.
            described_count = count
            if object_statistics:
                stats = object_statistics.get(class_name)
                if isinstance(stats, dict):
                    described_count = stats.get("count", count)

            count_phrase = self._format_object_count_description(
                normalized_class_name,
                described_count,
                scene_type=scene_type
            )
            if not count_phrase or count_phrase == "no specific objects clearly identified":
                continue

            # Use the described count here too, otherwise "Three cars" could
            # be paired with a singular "is ..." location phrase.
            location_phrase = self._generate_location_description(
                group_of_objects, described_count, image_width, image_height, region_analyzer
            )

            capitalized = count_phrase[0].upper() + count_phrase[1:]
            object_clauses.append(f"{capitalized} {location_phrase}")

        return object_clauses

    def format_object_clauses(self, object_clauses: List[str]) -> str:
        """
        Join clauses into a short paragraph.

        The first clause becomes its own sentence; any remaining clauses are
        introduced by "The scene features:". The input list is not modified.

        Args:
            object_clauses: Clauses as produced by ``generate_object_clauses``.

        Returns:
            The formatted description, or a fixed fallback sentence when there
            are no clauses.
        """
        if not object_clauses:
            return "No common objects were confidently identified for detailed description."

        first_clause, *remaining = object_clauses
        result = first_clause + "."

        if remaining:
            joined = ". ".join(remaining)
            if joined and not joined.endswith("."):
                joined += "."
            result += " The scene features: " + joined

        return result

    @staticmethod
    def _collect_valid_regions(group_of_objects: List[Dict]) -> List[str]:
        """Return the sorted, distinct, known ``region`` names found in the group."""
        regions = {
            obj.get("region") for obj in group_of_objects
            if isinstance(obj.get("region"), str)
        }
        return sorted(r for r in regions if r.strip() and r != "unknown")

    def _generate_location_description(self, group_of_objects: List[Dict], count: int,
                                       image_width: Optional[int], image_height: Optional[int],
                                       region_analyzer: Optional[Any] = None) -> str:
        """
        Describe where a group of objects is located.

        Args:
            group_of_objects: Detections of one class.
            count: Number of objects being described; ``1`` selects singular
                wording ("is ..."), anything else plural ("are ...").
            image_width: Image width, forwarded to the spatial handler.
            image_height: Image height, forwarded to the spatial handler.
            region_analyzer: Optional analyzer forwarded to the spatial handler.

        Returns:
            A verb phrase such as "is in the center" or
            "are mainly across the bottom left and top right areas".
        """
        singular = count == 1
        verb = "is" if singular else "are"

        if not group_of_objects:
            return "is positioned in the scene" if singular else "are visible in the scene"

        if singular:
            first = group_of_objects[0]
            if self.spatial_handler:
                spatial_desc = self.spatial_handler.generate_spatial_description(
                    first, image_width, image_height, region_analyzer
                )
            else:
                spatial_desc = self._get_spatial_description_phrase(first.get("region", ""))
            if spatial_desc:
                return f"is {spatial_desc}"

        valid_regions = self._collect_valid_regions(group_of_objects)

        if not valid_regions:
            return "is positioned in the scene" if singular else "are visible in the scene"

        if len(valid_regions) == 1:
            if singular:
                phrase = self._get_spatial_description_phrase(valid_regions[0])
                return f"is primarily {phrase}" if phrase else "is positioned in the scene"
            clean_region = valid_regions[0].replace('_', ' ')
            return f"are primarily in the {clean_region} area"

        if len(valid_regions) == 2:
            clean_region1 = valid_regions[0].replace('_', ' ')
            clean_region2 = valid_regions[1].replace('_', ' ')
            return f"{verb} mainly across the {clean_region1} and {clean_region2} areas"

        return f"{verb} distributed in various parts of the scene"

    def _get_spatial_description_phrase(self, region: str) -> str:
        """
        Fallback mapping from a region name to a location phrase.

        Args:
            region: Region name such as ``"top_left"`` or ``"center"``.

        Returns:
            The phrase for the region, or ``""`` when the region is empty,
            ``"unknown"``, not a string, or not in the mapping.
        """
        if not region or not isinstance(region, str) or region == "unknown":
            return ""
        clean_region = region.replace('_', ' ').strip().lower()
        return self._REGION_PHRASES.get(clean_region, "")

    def _normalize_object_class_name(self, class_name: str) -> str:
        """
        Normalise a class name for display and comparison.

        Delegates to ``text_optimizer`` when one is configured; otherwise
        replaces underscores with spaces, trims and lower-cases.

        Args:
            class_name: Raw class name.

        Returns:
            The normalised name; ``"object"`` for empty or non-string input in
            the fallback path.
        """
        if self.text_optimizer:
            return self.text_optimizer.normalize_object_class_name(class_name)

        if not class_name or not isinstance(class_name, str):
            return "object"
        return class_name.replace('_', ' ').strip().lower()

    @classmethod
    def _pluralize(cls, noun: str) -> str:
        """Return a simple English plural of ``noun`` (only the last word is inflected)."""
        head, separator, last = noun.rpartition(" ")
        if last in cls._IRREGULAR_PLURALS:
            plural = cls._IRREGULAR_PLURALS[last]
        elif last.endswith(("ss", "us", "x", "ch", "sh")):
            plural = last + "es"
        elif last.endswith("s"):
            # Treated as already plural ("skis", "scissors").
            plural = last
        elif len(last) > 1 and last.endswith("y") and last[-2] not in "aeiou":
            plural = last[:-1] + "ies"
        else:
            plural = last + "s"
        return head + separator + plural

    def _format_object_count_description(self, class_name: str, count: int,
                                         scene_type: Optional[str] = None,
                                         detected_objects: Optional[List[Dict]] = None,
                                         avg_confidence: float = 0.0) -> str:
        """
        Produce a counted noun phrase such as "a car" or "three people".

        Delegates to ``text_optimizer`` when one is configured. The built-in
        fallback spells out counts from 2 to 12, uses "several" up to 20 and
        "numerous" above that.

        Args:
            class_name: Normalised class name.
            count: Number of objects.
            scene_type: Scene type (only used by the text optimizer).
            detected_objects: Detections of this class (only used by the text
                optimizer).
            avg_confidence: Average confidence (only used by the text
                optimizer).

        Returns:
            The noun phrase, or ``""`` when ``count`` is not positive (fallback
            path).
        """
        if self.text_optimizer:
            return self.text_optimizer.format_object_count_description(
                class_name, count, scene_type, detected_objects, avg_confidence
            )

        if count <= 0:
            return ""

        # Guard against an empty name so the article lookup cannot fail.
        noun = class_name if class_name else "object"

        if count == 1:
            article = "an" if noun[0].lower() in 'aeiou' else "a"
            return f"{article} {noun}"

        plural_form = self._pluralize(noun)
        if count in self._NUMBER_WORDS:
            return f"{self._NUMBER_WORDS[count]} {plural_form}"
        if count <= 20:
            return f"several {plural_form}"
        return f"numerous {plural_form}"