import logging
import traceback
from typing import List, Optional, Union

import clip
import numpy as np
import torch
from PIL import Image


class CLIPModelManager:
    """
    Manages CLIP model operations, including model loading, device management,
    and the core feature-encoding paths for images and text.
    """

    def __init__(self, model_name: str = "ViT-B/16", device: Optional[str] = None):
        """
        Initialize the CLIP model manager.

        Args:
            model_name: CLIP model name; defaults to "ViT-B/16".
            device: Device to run on; if None, selected automatically.
        """
        self.logger = logging.getLogger(__name__)
        self.model_name = model_name

        # Select the device to run on
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model = None
        self.preprocess = None
        self._initialize_model()

    def _initialize_model(self):
        """
        Initialize the CLIP model.
        """
        try:
            self.logger.info(f"Initializing CLIP model ({self.model_name}) on {self.device}")
            self.model, self.preprocess = clip.load(self.model_name, device=self.device)
            self.logger.info("Successfully loaded CLIP model")
        except Exception as e:
            self.logger.error(f"Error loading CLIP model: {e}")
            self.logger.error(traceback.format_exc())
            raise

    def encode_image(self, image_input: torch.Tensor) -> torch.Tensor:
        """
        Encode image features.

        Args:
            image_input: Preprocessed image tensor.

        Returns:
            torch.Tensor: L2-normalized image features.
        """
        try:
            with torch.no_grad():
                image_features = self.model.encode_image(image_input)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                return image_features
        except Exception as e:
            self.logger.error(f"Error encoding image features: {e}")
            self.logger.error(traceback.format_exc())
            raise

    def encode_text_batch(self, text_prompts: List[str], batch_size: int = 128) -> Optional[torch.Tensor]:
        """
        Encode text features in batches to avoid CUDA out-of-memory issues.

        Args:
            text_prompts: List of text prompts.
            batch_size: Number of prompts per batch.

        Returns:
            Optional[torch.Tensor]: L2-normalized text features, or None if
            text_prompts is empty.
        """
        if not text_prompts:
            return None

        try:
            with torch.no_grad():
                features_list = []
                for i in range(0, len(text_prompts), batch_size):
                    batch_prompts = text_prompts[i:i + batch_size]
                    text_tokens = clip.tokenize(batch_prompts).to(self.device)
                    batch_features = self.model.encode_text(text_tokens)
                    batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
                    features_list.append(batch_features)

                # Concatenate all batches
                if len(features_list) > 1:
                    text_features = torch.cat(features_list, dim=0)
                else:
                    text_features = features_list[0]

                return text_features
        except Exception as e:
            self.logger.error(f"Error encoding text features: {e}")
            self.logger.error(traceback.format_exc())
            raise

    def encode_single_text(self, text_prompts: List[str]) -> torch.Tensor:
        """
        Encode a single batch of text prompts.

        Args:
            text_prompts: List of text prompts.

        Returns:
            torch.Tensor: L2-normalized text features.
        """
        try:
            with torch.no_grad():
                text_tokens = clip.tokenize(text_prompts).to(self.device)
                text_features = self.model.encode_text(text_tokens)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                return text_features
        except Exception as e:
            self.logger.error(f"Error encoding single text batch: {e}")
            self.logger.error(traceback.format_exc())
            raise

    def calculate_similarity(self, image_features: torch.Tensor, text_features: torch.Tensor) -> np.ndarray:
        """
        Compute similarity scores between image and text features.

        Args:
            image_features: Image feature tensor.
            text_features: Text feature tensor.

        Returns:
            np.ndarray: Softmax-normalized similarity scores.
        """
        try:
            with torch.no_grad():
                similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
                # .cpu() is a no-op for CPU tensors, so this is safe on any device
                return similarity.cpu().numpy()
        except Exception as e:
            self.logger.error(f"Error calculating similarity: {e}")
            self.logger.error(traceback.format_exc())
            raise

    def preprocess_image(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Preprocess an image for the CLIP model.

        Args:
            image: PIL Image or numpy array.

        Returns:
            torch.Tensor: Preprocessed image tensor with a batch dimension.
        """
        try:
            if not isinstance(image, Image.Image):
                if isinstance(image, np.ndarray):
                    image = Image.fromarray(image)
                else:
                    raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")

            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            return image_input
        except Exception as e:
            self.logger.error(f"Error preprocessing image: {e}")
            self.logger.error(traceback.format_exc())
            raise

    def process_image_region(self, image: Union[Image.Image, np.ndarray], box: List[float]) -> torch.Tensor:
        """
        Process a specific region of an image.

        Args:
            image: Source image.
            box: Bounding box [x1, y1, x2, y2].

        Returns:
            torch.Tensor: Image features for the cropped region.
        """
        try:
            # Ensure the image is in PIL format
            if not isinstance(image, Image.Image):
                if isinstance(image, np.ndarray):
                    image = Image.fromarray(image)
                else:
                    raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")

            # Crop the region
            x1, y1, x2, y2 = map(int, box)
            cropped_image = image.crop((x1, y1, x2, y2))

            # Preprocess and encode
            image_input = self.preprocess_image(cropped_image)
            image_features = self.encode_image(image_input)

            return image_features
        except Exception as e:
            self.logger.error(f"Error processing image region: {e}")
            self.logger.error(traceback.format_exc())
            raise

    def batch_process_regions(self, image: Union[Image.Image, np.ndarray], boxes: List[List[float]]) -> torch.Tensor:
        """
        Process multiple image regions in a single batch.

        Args:
            image: Source image.
            boxes: List of bounding boxes.

        Returns:
            torch.Tensor: Image features for all regions.
        """
        try:
            # Ensure the image is in PIL format
            if not isinstance(image, Image.Image):
                if isinstance(image, np.ndarray):
                    image = Image.fromarray(image)
                else:
                    raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")

            if not boxes:
                return torch.empty(0)

            # Crop and preprocess every region
            cropped_inputs = []
            for box in boxes:
                x1, y1, x2, y2 = map(int, box)
                cropped_image = image.crop((x1, y1, x2, y2))
                processed_image = self.preprocess(cropped_image).unsqueeze(0)
                cropped_inputs.append(processed_image)

            # Encode all regions as one batch
            batch_tensor = torch.cat(cropped_inputs).to(self.device)
            image_features = self.encode_image(batch_tensor)

            return image_features
        except Exception as e:
            self.logger.error(f"Error batch processing regions: {e}")
            self.logger.error(traceback.format_exc())
            raise

    def is_model_loaded(self) -> bool:
        """
        Check whether the model loaded successfully.

        Returns:
            bool: Model load status.
        """
        return self.model is not None and self.preprocess is not None

    def get_device(self) -> str:
        """
        Get the current device.

        Returns:
            str: Device name.
        """
        return self.device

    def get_model_name(self) -> str:
        """
        Get the model name.

        Returns:
            str: Model name.
        """
        return self.model_name
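

# A minimal usage sketch (an assumed entry point, not part of the manager itself):
# score a local image against a few candidate labels and print the softmax
# similarity per label. "photo.jpg" and the label list below are placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = CLIPModelManager()  # downloads ViT-B/16 weights on first use

    labels = ["a photo of a dog", "a photo of a cat", "a photo of a car"]
    text_features = manager.encode_text_batch(labels)

    # Placeholder path; replace with a real image file
    image = Image.open("photo.jpg").convert("RGB")
    image_input = manager.preprocess_image(image)
    image_features = manager.encode_image(image_input)

    # scores has shape (1, len(labels)); row 0 holds the per-label probabilities
    scores = manager.calculate_similarity(image_features, text_features)
    for label, score in zip(labels, scores[0]):
        print(f"{label}: {score:.3f}")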