Upload 3 files
- clip_analyzer.py +27 -30
- clip_model_manager.py +18 -24
- clip_zero_shot_classifier.py +4 -4
clip_analyzer.py
CHANGED
@@ -1,5 +1,5 @@
 import torch
-import open_clip
+import clip
 import numpy as np
 from PIL import Image
 from typing import Dict, List, Tuple, Any, Optional, Union
@@ -20,14 +20,13 @@ class CLIPAnalyzer:
    Use Clip to intergrate scene understanding function
    """

-    def __init__(self, model_name: str = "ViT-B…
+    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
        """
-        初始化 CLIP …
+        初始化 CLIP 分析器。

        Args:
-            model_name: …
-            device: …
-            pretrained: 預訓練權重,使用 "laion2b_s34b_b79k"
+            model_name: CLIP Model name, 默認 "ViT-B/16"
+            device: Use GPU if it can use
        """
        # 自動選擇設備
        if device is None:
@@ -35,23 +34,12 @@ class CLIPAnalyzer:
        else:
            self.device = device

-        print(f"Loading …")
+        print(f"Loading CLIP model {model_name} on {self.device}...")
        try:
-            self.model, …
-                …
-                pretrained=pretrained,
-                device=self.device
-            )
-            self.tokenizer = open_clip.get_tokenizer(model_name)
-            print(f"OpenCLIP model loaded successfully.")
-            import gc
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
-            print("Memory cleanup completed after OpenCLIP loading.")
+            self.model, self.preprocess = clip.load(model_name, device=self.device)
+            print(f"CLIP model loaded successfully.")
        except Exception as e:
-            print(f"Error loading …")
+            print(f"Error loading CLIP model: {e}")
            raise

        self.scene_type_prompts = SCENE_TYPE_PROMPTS
@@ -76,7 +64,7 @@ class CLIPAnalyzer:
        if scene_texts:
            self.text_features_cache["scene_type_keys"] = list(self.scene_type_prompts.keys())
            try:
-                self.text_features_cache["scene_type_tokens"] = …
+                self.text_features_cache["scene_type_tokens"] = clip.tokenize(scene_texts).to(self.device)
            except Exception as e:
                print(f"Warning: Error tokenizing scene_type_prompts: {e}")
                self.text_features_cache["scene_type_tokens"] = None # 標記錯誤或空
@@ -94,7 +82,7 @@ class CLIPAnalyzer:
        for scene_type, prompts in self.cultural_scene_prompts.items():
            if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                try:
-                    cultural_tokens_dict_val[scene_type] = …
+                    cultural_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
                except Exception as e:
                    print(f"Warning: Error tokenizing cultural_scene_prompts for {scene_type}: {e}")
                    cultural_tokens_dict_val[scene_type] = None # 標記錯誤或空
@@ -108,7 +96,7 @@ class CLIPAnalyzer:
        if lighting_texts:
            self.text_features_cache["lighting_condition_keys"] = list(self.lighting_condition_prompts.keys())
            try:
-                self.text_features_cache["lighting_tokens"] = …
+                self.text_features_cache["lighting_tokens"] = clip.tokenize(lighting_texts).to(self.device)
            except Exception as e:
                print(f"Warning: Error tokenizing lighting_condition_prompts: {e}")
                self.text_features_cache["lighting_tokens"] = None
@@ -125,7 +113,7 @@ class CLIPAnalyzer:
        for scene_type, prompts in self.specialized_scene_prompts.items():
            if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                try:
-                    specialized_tokens_dict_val[scene_type] = …
+                    specialized_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
                except Exception as e:
                    print(f"Warning: Error tokenizing specialized_scene_prompts for {scene_type}: {e}")
                    specialized_tokens_dict_val[scene_type] = None
@@ -139,7 +127,7 @@ class CLIPAnalyzer:
        if viewpoint_texts:
            self.text_features_cache["viewpoint_keys"] = list(self.viewpoint_prompts.keys())
            try:
-                self.text_features_cache["viewpoint_tokens"] = …
+                self.text_features_cache["viewpoint_tokens"] = clip.tokenize(viewpoint_texts).to(self.device)
            except Exception as e:
                print(f"Warning: Error tokenizing viewpoint_prompts: {e}")
                self.text_features_cache["viewpoint_tokens"] = None
@@ -156,7 +144,7 @@ class CLIPAnalyzer:
        if object_combination_texts:
            self.text_features_cache["object_combination_keys"] = list(self.object_combination_prompts.keys())
            try:
-                self.text_features_cache["object_combination_tokens"] = …
+                self.text_features_cache["object_combination_tokens"] = clip.tokenize(object_combination_texts).to(self.device)
            except Exception as e:
                print(f"Warning: Error tokenizing object_combination_prompts: {e}")
                self.text_features_cache["object_combination_tokens"] = None
@@ -173,7 +161,7 @@ class CLIPAnalyzer:
        if activity_texts:
            self.text_features_cache["activity_keys"] = list(self.activity_prompts.keys())
            try:
-                self.text_features_cache["activity_tokens"] = …
+                self.text_features_cache["activity_tokens"] = clip.tokenize(activity_texts).to(self.device)
            except Exception as e:
                print(f"Warning: Error tokenizing activity_prompts: {e}")
                self.text_features_cache["activity_tokens"] = None
@@ -192,7 +180,7 @@ class CLIPAnalyzer:
        self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
        self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]

-        print("…")
+        print("CLIP text_features_cache prepared.")

    def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
        """
@@ -593,7 +581,16 @@ class CLIPAnalyzer:
        return image_features.cpu().numpy()[0] if self.device == "cuda" else image_features.numpy()[0]

    def text_to_embedding(self, text: str) -> np.ndarray:
-        …
+        """
+        將文本轉換為 CLIP 嵌入表示
+
+        Args:
+            text: 輸入文本
+
+        Returns:
+            np.ndarray: 文本的 CLIP 特徵向量
+        """
+        text_token = clip.tokenize([text]).to(self.device)

        with torch.no_grad():
            text_features = self.model.encode_text(text_token)
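For reference, the new loading and tokenization path follows the standard openai `clip` package API: `clip.load` returns the model together with its image preprocessing transform, and `clip.tokenize` turns prompt strings into token tensors. A minimal, self-contained sketch of that pattern as used above (the prompt strings and image path are illustrative, not taken from this commit):

    import torch
    import clip
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # clip.load returns (model, preprocess); "ViT-B/16" matches the new default model_name
    model, preprocess = clip.load("ViT-B/16", device=device)

    # Tokenize a list of prompts, as in clip.tokenize(scene_texts).to(self.device)
    prompts = ["a photo of a city street", "a photo of a kitchen"]  # illustrative prompts
    text_tokens = clip.tokenize(prompts).to(device)

    # Encode an image and the prompts, then compare them via normalized dot products
    image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # illustrative path
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text_tokens)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)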
clip_model_manager.py
CHANGED
@@ -1,6 +1,6 @@

 import torch
-import open_clip
+import clip
 import numpy as np
 import logging
 import traceback
@@ -12,7 +12,7 @@ class CLIPModelManager:
    專門管理 CLIP 模型相關的操作,包括模型載入、設備管理、圖像和文本的特徵編碼等核心功能
    """

-    def __init__(self, model_name: str = "ViT-B…
+    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
        """
        初始化 CLIP 模型管理器

@@ -22,8 +22,6 @@ class CLIPModelManager:
        """
        self.logger = logging.getLogger(__name__)
        self.model_name = model_name
-        self.pretrained = pretrained
-        self.tokenizer = None

        # 設置運行設備
        if device is None:
@@ -31,32 +29,19 @@ class CLIPModelManager:
        else:
            self.device = device

+        self.model = None
        self.preprocess = None

        self._initialize_model()

    def _initialize_model(self):
        """
-        初始化 …
+        初始化CLIP模型
        """
        try:
-            self.logger.info(f"Initializing …")
-            self.model, …
-                …
-                pretrained=self.pretrained,
-                device=self.device
-            )
-            self.tokenizer = open_clip.get_tokenizer(self.model_name)
-            self.logger.info("Successfully loaded OpenCLIP model")
-
-            # 立即清理 OpenCLIP 載入過程中的記憶體碎片
-            import gc
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
-            self.logger.info("Memory cleanup completed after OpenCLIP loading in CLIPModelManager")
-
+            self.logger.info(f"Initializing CLIP model ({self.model_name}) on {self.device}")
+            self.model, self.preprocess = clip.load(self.model_name, device=self.device)
+            self.logger.info("Successfully loaded CLIP model")
        except Exception as e:
            self.logger.error(f"Error loading CLIP model: {e}")
            self.logger.error(traceback.format_exc())
@@ -102,7 +87,7 @@ class CLIPModelManager:

            for i in range(0, len(text_prompts), batch_size):
                batch_prompts = text_prompts[i:i+batch_size]
-                text_tokens = …
+                text_tokens = clip.tokenize(batch_prompts).to(self.device)
                batch_features = self.model.encode_text(text_tokens)
                batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
                features_list.append(batch_features)
@@ -121,9 +106,18 @@ class CLIPModelManager:
            raise

    def encode_single_text(self, text_prompts: List[str]) -> torch.Tensor:
+        """
+        編碼單個文本批次的特徵
+
+        Args:
+            text_prompts: 文本提示列表
+
+        Returns:
+            torch.Tensor: 標準化後的文本特徵
+        """
        try:
            with torch.no_grad():
-                text_tokens = …
+                text_tokens = clip.tokenize(text_prompts).to(self.device)
                text_features = self.model.encode_text(text_tokens)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                return text_features
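The manager's text-encoding helpers keep the same batching and L2 normalization as before; only the tokenizer call changes to `clip.tokenize`. A minimal sketch of that batched pattern under the new API (the batch size and prompt list are illustrative):

    import torch
    import clip

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/16", device=device)

    def encode_prompts_in_batches(prompts, batch_size=64):
        """Tokenize and encode prompts batch by batch, L2-normalizing each batch."""
        features_list = []
        with torch.no_grad():
            for i in range(0, len(prompts), batch_size):
                batch_prompts = prompts[i:i + batch_size]
                text_tokens = clip.tokenize(batch_prompts).to(device)
                batch_features = model.encode_text(text_tokens)
                batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
                features_list.append(batch_features)
        return torch.cat(features_list, dim=0)

    # Example: encode a couple of illustrative prompts
    all_features = encode_prompts_in_batches(["a photo of a dog", "a photo of a cat"])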
clip_zero_shot_classifier.py
CHANGED
@@ -1,6 +1,6 @@

 import torch
-import open_clip
+import clip
 from PIL import Image
 import numpy as np
 import logging
@@ -21,18 +21,18 @@ class CLIPZeroShotClassifier:
    這是一個總窗口class,協調各個組件的工作以提供統一的對外接口。
    """

-    def __init__(self, model_name: str = "ViT-B…
+    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
        """
        初始化CLIP零樣本分類器

        Args:
-            model_name: …
+            model_name: CLIP模型名稱,默認為"ViT-B/16"
            device: 運行設備,None則自動選擇
        """
        self.logger = logging.getLogger(__name__)

        # 初始化各個組件
-        self.clip_model_manager = CLIPModelManager(model_name, device…
+        self.clip_model_manager = CLIPModelManager(model_name, device)
        self.landmark_data_manager = LandmarkDataManager()
        self.image_analyzer = ImageAnalyzer()
        self.confidence_manager = ConfidenceManager()