Spaces:
Running
on
Zero
Running
on
Zero
import logging | |
import traceback | |
from typing import Dict, List, Any, Optional | |
from model_manager import ModelManager | |
from prompt_template_manager import PromptTemplateManager | |
from response_processor import ResponseProcessor | |
from text_quality_validator import TextQualityValidator | |
from landmark_data import ALL_LANDMARKS | |
class LLMEnhancer: | |
""" | |
LLM增強器的主要窗口,協調模型管理、提示模板、回應處理和品質驗證等組件。 | |
提供統一的接口來處理場景描述增強、檢測結果驗證和無檢測情況處理。 | |
""" | |
def __init__(self, | |
model_path: Optional[str] = None, | |
tokenizer_path: Optional[str] = None, | |
device: Optional[str] = None, | |
max_length: int = 2048, | |
temperature: float = 0.3, | |
top_p: float = 0.85): | |
""" | |
初始化LLM增強器門面 | |
Args: | |
model_path: LLM模型的路徑或HuggingFace模型名稱,預設使用Llama 3.2 | |
tokenizer_path: tokenizer的路徑,通常與model_path相同 | |
device: 運行設備 ('cpu'或'cuda'),None時自動檢測 | |
max_length: 輸入文本的最大長度 | |
temperature: 生成文本的溫度參數 | |
top_p: 生成文本時的核心採樣機率閾值 | |
""" | |
# 設置專屬logger | |
self.logger = logging.getLogger(self.__class__.__name__) | |
if not self.logger.handlers: | |
handler = logging.StreamHandler() | |
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
handler.setFormatter(formatter) | |
self.logger.addHandler(handler) | |
self.logger.setLevel(logging.INFO) | |
try: | |
# 初始化四個核心組件 | |
self.model_manager = ModelManager( | |
model_path=model_path, | |
tokenizer_path=tokenizer_path, | |
device=device, | |
max_length=max_length, | |
temperature=temperature, | |
top_p=top_p | |
) | |
self.prompt_manager = PromptTemplateManager() | |
self.response_processor = ResponseProcessor() | |
self.quality_validator = TextQualityValidator() | |
# 保存模型路徑以供後續使用 | |
self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct" | |
self.logger.info("LLMEnhancer facade initialized successfully") | |
except Exception as e: | |
error_msg = f"Failed to initialize LLMEnhancer facade: {str(e)}" | |
self.logger.error(error_msg) | |
self.logger.error(traceback.format_exc()) | |
raise Exception(error_msg) from e | |
def enhance_description(self, scene_data: Dict[str, Any]) -> str: | |
""" | |
場景描述增強器主要入口方法,整合所有組件來處理場景描述增強 | |
Args: | |
scene_data: 包含場景資訊的字典,包括原始描述、檢測物件 (含 is_landmark)、 | |
場景類型、時間/光線資訊等 | |
Returns: | |
str: 增強後的場景描述 | |
""" | |
try: | |
self.logger.info("Starting scene description enhancement") | |
# 1. 重置模型上下文 | |
self.model_manager.reset_context() | |
# 2. 取出原始描述 | |
original_desc = scene_data.get("original_description", "") | |
if not original_desc: | |
self.logger.warning("No original description provided") | |
return "No original description provided." | |
# 3. 準備物件統計資訊 | |
object_list = self._prepare_object_statistics(scene_data) | |
if not object_list: | |
object_keywords = self.quality_validator.extract_objects_from_description(original_desc) | |
object_list = ", ".join(object_keywords) if object_keywords else "objects visible in the scene" | |
# 4. 檢測地標並準備地標資訊 | |
landmark_info = self._extract_landmark_info(scene_data) | |
# 5. 將地標資訊加入scene_data | |
enhanced_scene_data = scene_data.copy() | |
if landmark_info: | |
enhanced_scene_data["landmark_location_info"] = landmark_info | |
# 6. 生成 prompt | |
prompt = self.prompt_manager.format_enhancement_prompt_with_landmark( | |
scene_data=enhanced_scene_data, | |
object_list=object_list, | |
original_description=original_desc | |
) | |
# 7. 生成 LLM 回應 | |
self.logger.info("Generating LLM response") | |
response = self.model_manager.generate_response(prompt) | |
# 8. 處理不完整回應(重試機制) | |
response = self._handle_incomplete_response(response, prompt, original_desc) | |
# 9. 清理 LLM 回應 | |
model_type = self.model_path | |
raw_cleaned = self.response_processor.clean_response(response, model_type) | |
# 10. 移除解釋性注釋 | |
cleaned_response = self.response_processor.remove_explanatory_notes(raw_cleaned) | |
# 11. 事實準確性驗證 | |
try: | |
cleaned_response = self.quality_validator.verify_factual_accuracy( | |
original_desc, cleaned_response, object_list | |
) | |
except Exception: | |
self.logger.warning("Fact verification failed; using response without verification") | |
# 12. 場景類型一致性確保 | |
scene_type = scene_data.get("scene_type", "unknown scene") | |
word_count = len(cleaned_response.split()) | |
if word_count >= 5 and scene_type.lower() not in cleaned_response.lower(): | |
cleaned_response = self.quality_validator.ensure_scene_type_consistency( | |
cleaned_response, scene_type, original_desc | |
) | |
# 13. 視角一致性處理 | |
perspective = self.quality_validator.extract_perspective_from_description(original_desc) | |
if perspective and perspective.lower() not in cleaned_response.lower(): | |
cleaned_response = f"{perspective}, {cleaned_response[0].lower()}{cleaned_response[1:]}" | |
# 14. 最終驗證:如果結果過短,嘗試fallback | |
final_result = cleaned_response.strip() | |
if not final_result or len(final_result) < 20: | |
self.logger.warning("Enhanced description too short; attempting fallback") | |
# Fallback prompt | |
fallback_scene_data = enhanced_scene_data.copy() | |
fallback_scene_data["is_fallback"] = True | |
fallback_prompt = self.prompt_manager.format_enhancement_prompt_with_landmark( | |
scene_data=fallback_scene_data, | |
object_list=object_list, | |
original_description=original_desc | |
) | |
fallback_resp = self.model_manager.generate_response(fallback_prompt) | |
fallback_cleaned = self.response_processor.clean_response(fallback_resp, model_type) | |
fallback_cleaned = self.response_processor.remove_explanatory_notes(fallback_cleaned) | |
final_result = fallback_cleaned.strip() | |
if not final_result or len(final_result) < 20: | |
self.logger.warning("Fallback also insufficient; returning original") | |
return original_desc | |
# 15. display enhanced description | |
self.logger.info(f"Scene description enhancement completed successfully ({len(final_result)} chars)") | |
return final_result | |
except Exception as e: | |
error_msg = f"Enhancement failed: {str(e)}" | |
self.logger.error(error_msg) | |
self.logger.error(traceback.format_exc()) | |
return scene_data.get("original_description", "Unable to enhance description") | |
def _extract_landmark_info(self, scene_data: Dict[str, Any]) -> Optional[Dict[str, str]]: | |
""" | |
提取地標資訊,但不構建prompt內容 | |
Args: | |
scene_data: 場景資料字典 | |
Returns: | |
Optional[Dict[str, str]]: 地標資訊字典,包含name和location,如果沒有地標則返回None | |
""" | |
try: | |
# 檢查是否有地標 | |
lm_id_in_data = scene_data.get("landmark_id") | |
if not lm_id_in_data: | |
# 從檢測物件中尋找地標 | |
for obj in scene_data.get("detected_objects", []): | |
if obj.get("is_landmark") and obj.get("landmark_id"): | |
lm_id_in_data = obj["landmark_id"] | |
break | |
# 如果沒有檢測到地標,返回None | |
if not lm_id_in_data: | |
return None | |
# 從landmark_data.py提取地標資訊 | |
if lm_id_in_data in ALL_LANDMARKS: | |
lm_info = ALL_LANDMARKS[lm_id_in_data] | |
landmark_name = scene_data.get("scene_name", lm_info.get("name", lm_id_in_data)) | |
landmark_location = lm_info.get("location", "") | |
if landmark_location: | |
return { | |
"name": landmark_name, | |
"location": landmark_location, | |
"landmark_id": lm_id_in_data | |
} | |
return None | |
except Exception as e: | |
self.logger.error(f"Error extracting landmark info: {str(e)}") | |
return None | |
def _prepare_object_statistics(self, scene_data: Dict[str, Any]) -> str: | |
""" | |
準備物件統計資訊用於提示詞生成 | |
Args: | |
scene_data: 場景資料字典 | |
Returns: | |
str: 格式化的物件統計資訊 | |
""" | |
try: | |
# 高信心度閾值 | |
high_confidence_threshold = 0.65 | |
# 優先使用預計算的統計資訊 | |
object_statistics = scene_data.get("object_statistics", {}) | |
object_counts = {} | |
if object_statistics: | |
for class_name, stats in object_statistics.items(): | |
if stats.get("count", 0) > 0 and stats.get("avg_confidence", 0) >= high_confidence_threshold: | |
object_counts[class_name] = stats["count"] | |
else: | |
# 回退到原有的計算方式 | |
detected_objects = scene_data.get("detected_objects", []) | |
filtered_objects = [] | |
for obj in detected_objects: | |
confidence = obj.get("confidence", 0) | |
class_name = obj.get("class_name", "") | |
# 為特殊類別設置更高閾值 | |
special_classes = ["airplane", "helicopter", "boat"] | |
if class_name in special_classes: | |
if confidence < 0.75: | |
continue | |
if confidence >= high_confidence_threshold: | |
filtered_objects.append(obj) | |
for obj in filtered_objects: | |
class_name = obj.get("class_name", "") | |
if class_name not in object_counts: | |
object_counts[class_name] = 0 | |
object_counts[class_name] += 1 | |
# 格式化物件描述 | |
return ", ".join([ | |
f"{count} {obj}{'s' if count > 1 else ''}" | |
for obj, count in object_counts.items() | |
]) | |
except Exception as e: | |
self.logger.error(f"Object statistics preparation failed: {str(e)}") | |
return "objects visible in the scene" | |
def _handle_incomplete_response(self, response: str, prompt: str, original_desc: str) -> str: | |
""" | |
處理不完整的回應,必要時重新生成 | |
Args: | |
response: 原始回應 | |
prompt: 使用的提示詞 | |
original_desc: 原始描述 | |
Returns: | |
str: 處理後的回應 | |
""" | |
try: | |
# 檢查回應完整性 | |
is_complete, issue = self.quality_validator.validate_response_completeness(response) | |
max_retries = 3 | |
attempts = 0 | |
while not is_complete and attempts < max_retries: | |
self.logger.warning(f"Incomplete response detected ({issue}), retrying... Attempt {attempts+1}/{max_retries}") | |
# 重新生成 | |
response = self.model_manager.generate_response(prompt) | |
is_complete, issue = self.quality_validator.validate_response_completeness(response) | |
attempts += 1 | |
if not response or len(response.strip()) < 10: | |
self.logger.warning("Generated response was empty or too short, returning original description") | |
return original_desc | |
return response | |
except Exception as e: | |
self.logger.error(f"Incomplete response handling failed: {str(e)}") | |
return response # 返回原始回應 | |
def verify_detection(self, | |
detected_objects: List[Dict], | |
clip_analysis: Dict[str, Any], | |
scene_type: str, | |
scene_name: str, | |
confidence: float) -> Dict[str, Any]: | |
""" | |
驗證並可能修正YOLO的檢測結果 | |
Args: | |
detected_objects: YOLO檢測到的物體列表 | |
clip_analysis: CLIP分析結果 | |
scene_type: 識別的場景類型 | |
scene_name: 場景名稱 | |
confidence: 場景分類的信心度 | |
Returns: | |
Dict: 包含驗證結果和建議的字典 | |
""" | |
try: | |
self.logger.info("Starting detection verification") | |
# 格式化驗證提示 | |
prompt = self.prompt_manager.format_verification_prompt( | |
detected_objects=detected_objects, | |
clip_analysis=clip_analysis, | |
scene_type=scene_type, | |
scene_name=scene_name, | |
confidence=confidence | |
) | |
# 調用LLM進行驗證 | |
verification_result = self.model_manager.generate_response(prompt) | |
# 清理回應 | |
cleaned_result = self.response_processor.clean_response(verification_result, self.model_path) | |
# 解析驗證結果 | |
result = { | |
"verification_text": cleaned_result, | |
"has_errors": "appear accurate" not in cleaned_result.lower(), | |
"corrected_objects": None | |
} | |
self.logger.info("Detection verification completed") | |
return result | |
except Exception as e: | |
error_msg = f"Detection verification failed: {str(e)}" | |
self.logger.error(error_msg) | |
self.logger.error(traceback.format_exc()) | |
return { | |
"verification_text": "Verification failed", | |
"has_errors": False, | |
"corrected_objects": None | |
} | |
def handle_no_detection(self, clip_analysis: Dict[str, Any]) -> str: | |
""" | |
處理YOLO未檢測到物體的情況 | |
Args: | |
clip_analysis: CLIP分析結果 | |
Returns: | |
str: 生成的場景描述 | |
""" | |
try: | |
self.logger.info("Handling no detection scenario") | |
# 格式化無檢測提示 | |
prompt = self.prompt_manager.format_no_detection_prompt(clip_analysis) | |
# 調用LLM生成描述 | |
description = self.model_manager.generate_response(prompt) | |
# 清理回應 | |
cleaned_description = self.response_processor.clean_response(description, self.model_path) | |
self.logger.info("No detection handling completed") | |
return cleaned_description | |
except Exception as e: | |
error_msg = f"No detection handling failed: {str(e)}" | |
self.logger.error(error_msg) | |
self.logger.error(traceback.format_exc()) | |
return "Unable to generate scene description" | |
def reset_context(self): | |
"""重置LLM模型上下文""" | |
try: | |
self.model_manager.reset_context() | |
self.logger.info("LLM context reset completed") | |
except Exception as e: | |
self.logger.error(f"Context reset failed: {str(e)}") | |
def get_call_count(self) -> int: | |
""" | |
獲取模型調用次數 | |
Returns: | |
int: 調用次數 | |
""" | |
return self.model_manager.get_call_count() | |
def get_model_info(self) -> Dict[str, Any]: | |
""" | |
獲取模型和組件資訊 | |
Returns: | |
Dict[str, Any]: 包含所有組件狀態的綜合資訊 | |
""" | |
try: | |
return { | |
"model_manager": self.model_manager.get_model_info(), | |
"prompt_manager": self.prompt_manager.get_template_info(), | |
"response_processor": self.response_processor.get_processor_info(), | |
"quality_validator": self.quality_validator.get_validator_info(), | |
"facade_status": "initialized" | |
} | |
except Exception as e: | |
self.logger.error(f"Failed to get component info: {str(e)}") | |
return {"facade_status": "error", "error_message": str(e)} | |
def is_model_loaded(self) -> bool: | |
""" | |
檢查模型是否已載入 | |
Returns: | |
bool: 模型載入狀態 | |
""" | |
return self.model_manager.is_model_loaded() | |
def get_current_device(self) -> str: | |
""" | |
獲取當前運行設備 | |
Returns: | |
str: 當前設備名稱 | |
""" | |
return self.model_manager.get_current_device() | |
def _detect_scene_type(self, detected_objects: List[Dict]) -> str: | |
""" | |
基於物件分佈和模式檢測場景類型 | |
Args: | |
detected_objects: 檢測到的物件列表 | |
Returns: | |
str: 檢測到的場景類型 | |
""" | |
try: | |
# 預設場景類型 | |
scene_type = "intersection" | |
# 計算物件數量 | |
object_counts = {} | |
for obj in detected_objects: | |
class_name = obj.get("class_name", "") | |
if class_name not in object_counts: | |
object_counts[class_name] = 0 | |
object_counts[class_name] += 1 | |
# 人數統計 | |
people_count = object_counts.get("person", 0) | |
# 交通工具統計 | |
car_count = object_counts.get("car", 0) | |
bus_count = object_counts.get("bus", 0) | |
truck_count = object_counts.get("truck", 0) | |
total_vehicles = car_count + bus_count + truck_count | |
# 簡單的場景類型檢測邏輯 | |
if people_count > 8 and total_vehicles < 2: | |
scene_type = "pedestrian_crossing" | |
elif people_count > 5 and total_vehicles > 2: | |
scene_type = "busy_intersection" | |
elif people_count < 3 and total_vehicles > 3: | |
scene_type = "traffic_junction" | |
return scene_type | |
except Exception as e: | |
self.logger.error(f"Scene type detection failed: {str(e)}") | |
return "intersection" | |