DawnC committed
Commit 2fa80e1 · verified · 1 Parent(s): ddc065a

Upload 3 files

clip_analyzer.py CHANGED

@@ -1,5 +1,5 @@
 import torch
-import open_clip
+import clip
 import numpy as np
 from PIL import Image
 from typing import Dict, List, Tuple, Any, Optional, Union

@@ -20,14 +20,13 @@ class CLIPAnalyzer:
     Use CLIP to integrate scene understanding functionality.
     """

-    def __init__(self, model_name: str = "ViT-B-16", device: str = None, pretrained: str = "laion2b_s34b_b88k"):
+    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
         """
-        Initialize the CLIP analyzer, using the OpenCLIP implementation.
+        Initialize the CLIP analyzer.

         Args:
-            model_name: OpenCLIP model name, defaults to "ViT-B-16"
-            device: Device to run on
-            pretrained: Pretrained weights, using "laion2b_s34b_b79k"
+            model_name: CLIP model name, defaults to "ViT-B/16"
+            device: Device to run on; GPU is used if available
         """
         # Automatically select the device
         if device is None:

@@ -35,23 +34,12 @@ class CLIPAnalyzer:
         else:
             self.device = device

-        print(f"Loading OpenCLIP model {model_name} with {pretrained} on {self.device}...")
+        print(f"Loading CLIP model {model_name} on {self.device}...")
         try:
-            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
-                model_name,
-                pretrained=pretrained,
-                device=self.device
-            )
-            self.tokenizer = open_clip.get_tokenizer(model_name)
-            print(f"OpenCLIP model loaded successfully.")
-            import gc
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
-            print("Memory cleanup completed after OpenCLIP loading.")
+            self.model, self.preprocess = clip.load(model_name, device=self.device)
+            print(f"CLIP model loaded successfully.")
         except Exception as e:
-            print(f"Error loading OpenCLIP model: {e}")
+            print(f"Error loading CLIP model: {e}")
             raise

         self.scene_type_prompts = SCENE_TYPE_PROMPTS

@@ -76,7 +64,7 @@ class CLIPAnalyzer:
         if scene_texts:
             self.text_features_cache["scene_type_keys"] = list(self.scene_type_prompts.keys())
             try:
-                self.text_features_cache["scene_type_tokens"] = self.tokenizer(scene_texts).to(self.device)
+                self.text_features_cache["scene_type_tokens"] = clip.tokenize(scene_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing scene_type_prompts: {e}")
                 self.text_features_cache["scene_type_tokens"] = None  # mark as failed or empty

@@ -94,7 +82,7 @@ class CLIPAnalyzer:
         for scene_type, prompts in self.cultural_scene_prompts.items():
             if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                 try:
-                    cultural_tokens_dict_val[scene_type] = self.tokenizer(prompts).to(self.device)
+                    cultural_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
                 except Exception as e:
                     print(f"Warning: Error tokenizing cultural_scene_prompts for {scene_type}: {e}")
                     cultural_tokens_dict_val[scene_type] = None  # mark as failed or empty

@@ -108,7 +96,7 @@ class CLIPAnalyzer:
         if lighting_texts:
             self.text_features_cache["lighting_condition_keys"] = list(self.lighting_condition_prompts.keys())
             try:
-                self.text_features_cache["lighting_tokens"] = self.tokenizer(lighting_texts).to(self.device)
+                self.text_features_cache["lighting_tokens"] = clip.tokenize(lighting_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing lighting_condition_prompts: {e}")
                 self.text_features_cache["lighting_tokens"] = None

@@ -125,7 +113,7 @@ class CLIPAnalyzer:
         for scene_type, prompts in self.specialized_scene_prompts.items():
             if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                 try:
-                    specialized_tokens_dict_val[scene_type] = self.tokenizer(prompts).to(self.device)
+                    specialized_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
                 except Exception as e:
                     print(f"Warning: Error tokenizing specialized_scene_prompts for {scene_type}: {e}")
                     specialized_tokens_dict_val[scene_type] = None

@@ -139,7 +127,7 @@ class CLIPAnalyzer:
         if viewpoint_texts:
             self.text_features_cache["viewpoint_keys"] = list(self.viewpoint_prompts.keys())
             try:
-                self.text_features_cache["viewpoint_tokens"] = self.tokenizer(viewpoint_texts).to(self.device)
+                self.text_features_cache["viewpoint_tokens"] = clip.tokenize(viewpoint_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing viewpoint_prompts: {e}")
                 self.text_features_cache["viewpoint_tokens"] = None

@@ -156,7 +144,7 @@ class CLIPAnalyzer:
         if object_combination_texts:
             self.text_features_cache["object_combination_keys"] = list(self.object_combination_prompts.keys())
             try:
-                self.text_features_cache["object_combination_tokens"] = self.tokenizer(object_combination_texts).to(self.device)
+                self.text_features_cache["object_combination_tokens"] = clip.tokenize(object_combination_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing object_combination_prompts: {e}")
                 self.text_features_cache["object_combination_tokens"] = None

@@ -173,7 +161,7 @@ class CLIPAnalyzer:
         if activity_texts:
             self.text_features_cache["activity_keys"] = list(self.activity_prompts.keys())
             try:
-                self.text_features_cache["activity_tokens"] = self.tokenizer(activity_texts).to(self.device)
+                self.text_features_cache["activity_tokens"] = clip.tokenize(activity_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing activity_prompts: {e}")
                 self.text_features_cache["activity_tokens"] = None

@@ -192,7 +180,7 @@ class CLIPAnalyzer:
         self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
         self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]

-        print("OpenCLIP text_features_cache prepared.")
+        print("CLIP text_features_cache prepared.")

     def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
         """

@@ -593,7 +581,16 @@ class CLIPAnalyzer:
         return image_features.cpu().numpy()[0] if self.device == "cuda" else image_features.numpy()[0]

     def text_to_embedding(self, text: str) -> np.ndarray:
-        text_token = self.tokenizer([text]).to(self.device)
+        """
+        Convert text into a CLIP embedding.
+
+        Args:
+            text: Input text
+
+        Returns:
+            np.ndarray: CLIP feature vector for the text
+        """
+        text_token = clip.tokenize([text]).to(self.device)

         with torch.no_grad():
             text_features = self.model.encode_text(text_token)
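
For reference, a minimal sketch of the loading and tokenization path this file now uses, assuming the clip package is OpenAI's reference implementation (installable via pip install git+https://github.com/openai/CLIP.git); the prompt string is purely illustrative:

import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load returns (model, preprocess), replacing open_clip's 3-tuple and separate tokenizer object
model, preprocess = clip.load("ViT-B/16", device=device)
# clip.tokenize is a module-level function, so no tokenizer instance needs to be stored
tokens = clip.tokenize(["a photo of a city street at night"]).to(device)
with torch.no_grad():
    text_features = model.encode_text(tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)  # L2-normalize for cosine-similarity scoring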
clip_model_manager.py CHANGED

@@ -1,6 +1,6 @@

 import torch
-import open_clip
+import clip
 import numpy as np
 import logging
 import traceback

@@ -12,7 +12,7 @@ class CLIPModelManager:
    Manages CLIP model operations, including core functionality such as model loading, device management, and image/text feature encoding.
    """

-    def __init__(self, model_name: str = "ViT-B-16", device: str = None, pretrained: str = "laion2b_s34b_b88k"):
+    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
        """
        Initialize the CLIP model manager.

@@ -22,8 +22,6 @@ class CLIPModelManager:
         """
         self.logger = logging.getLogger(__name__)
         self.model_name = model_name
-        self.pretrained = pretrained
-        self.tokenizer = None

         # Set the device to run on
         if device is None:

@@ -31,32 +29,19 @@ class CLIPModelManager:
         else:
             self.device = device

+        self.model = None
         self.preprocess = None

         self._initialize_model()

     def _initialize_model(self):
         """
-        Initialize the OpenCLIP model
+        Initialize the CLIP model
         """
         try:
-            self.logger.info(f"Initializing OpenCLIP model ({self.model_name}) with {self.pretrained} on {self.device}")
-            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
-                self.model_name,
-                pretrained=self.pretrained,
-                device=self.device
-            )
-            self.tokenizer = open_clip.get_tokenizer(self.model_name)
-            self.logger.info("Successfully loaded OpenCLIP model")
-
-            # Immediately clean up memory fragmentation left by OpenCLIP loading
-            import gc
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
-            self.logger.info("Memory cleanup completed after OpenCLIP loading in CLIPModelManager")
-
+            self.logger.info(f"Initializing CLIP model ({self.model_name}) on {self.device}")
+            self.model, self.preprocess = clip.load(self.model_name, device=self.device)
+            self.logger.info("Successfully loaded CLIP model")
         except Exception as e:
             self.logger.error(f"Error loading CLIP model: {e}")
             self.logger.error(traceback.format_exc())

@@ -102,7 +87,7 @@ class CLIPModelManager:

             for i in range(0, len(text_prompts), batch_size):
                 batch_prompts = text_prompts[i:i+batch_size]
-                text_tokens = self.tokenizer(batch_prompts).to(self.device)
+                text_tokens = clip.tokenize(batch_prompts).to(self.device)
                 batch_features = self.model.encode_text(text_tokens)
                 batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
                 features_list.append(batch_features)

@@ -121,9 +106,18 @@ class CLIPModelManager:
             raise

     def encode_single_text(self, text_prompts: List[str]) -> torch.Tensor:
+        """
+        Encode features for a single batch of text prompts.
+
+        Args:
+            text_prompts: List of text prompts
+
+        Returns:
+            torch.Tensor: Normalized text features
+        """
         try:
             with torch.no_grad():
-                text_tokens = self.tokenizer(text_prompts).to(self.device)
+                text_tokens = clip.tokenize(text_prompts).to(self.device)
                 text_features = self.model.encode_text(text_tokens)
                 text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                 return text_features
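
The batched text-encoding loop above can be read in isolation roughly as follows; this is a sketch with a hypothetical encode_prompts helper and an assumed batch size, not code from the repository:

import torch
import clip

def encode_prompts(model, prompts, device, batch_size=128):
    # Tokenize and encode prompts in batches, L2-normalizing each batch for cosine similarity
    features = []
    with torch.no_grad():
        for i in range(0, len(prompts), batch_size):
            tokens = clip.tokenize(prompts[i:i + batch_size]).to(device)
            batch = model.encode_text(tokens)
            features.append(batch / batch.norm(dim=-1, keepdim=True))
    return torch.cat(features)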
clip_zero_shot_classifier.py CHANGED

@@ -1,6 +1,6 @@

 import torch
-import open_clip
+import clip
 from PIL import Image
 import numpy as np
 import logging

@@ -21,18 +21,18 @@ class CLIPZeroShotClassifier:
    This is a facade class that coordinates the individual components to provide a unified external interface.
    """

-    def __init__(self, model_name: str = "ViT-B-16", device: str = None, pretrained: str = "laion2b_s34b_b88k"):
+    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
        """
        Initialize the CLIP zero-shot classifier.

        Args:
-            model_name: OpenCLIP model name, defaults to "ViT-B-16"
+            model_name: CLIP model name, defaults to "ViT-B/16"
            device: Device to run on; automatically selected if None
        """
        self.logger = logging.getLogger(__name__)

        # Initialize the individual components
-        self.clip_model_manager = CLIPModelManager(model_name, device, pretrained)
+        self.clip_model_manager = CLIPModelManager(model_name, device)
        self.landmark_data_manager = LandmarkDataManager()
        self.image_analyzer = ImageAnalyzer()
        self.confidence_manager = ConfidenceManager()
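
As a usage illustration only (the classifier itself delegates to CLIPModelManager and other components not shown in this diff), zero-shot scoring with the standard CLIP API looks roughly like this; the image path and label set are hypothetical:

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

image = preprocess(Image.open("scene.jpg")).unsqueeze(0).to(device)
labels = ["a city street", "a living room", "a beach"]
tokens = clip.tokenize([f"a photo of {label}" for label in labels]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(tokens)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Scale by 100 before softmax, following the example in the CLIP repository
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print(dict(zip(labels, probs[0].tolist())))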