Spaces:
Running
on
Zero
Running
on
Zero
Upload 8 files
Browse files- clip_analyzer.py +24 -27
- clip_model_manager.py +15 -18
- clip_prompts.py +128 -5
- clip_zero_shot_classifier.py +4 -4
- llm_enhancer.py +2 -2
- llm_model_manager.py +358 -0
- requirements.txt +2 -2
- scene_scoring_engine.py +9 -8
clip_analyzer.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import torch
|
2 |
-
import
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
from typing import Dict, List, Tuple, Any, Optional, Union
|
@@ -20,13 +20,14 @@ class CLIPAnalyzer:
|
|
20 |
Use CLIP to integrate scene understanding functions
|
21 |
"""
|
22 |
|
23 |
-
def __init__(self, model_name: str = "ViT-B
|
24 |
"""
|
25 |
-
初始化 CLIP
|
26 |
|
27 |
Args:
|
28 |
-
model_name:
|
29 |
-
device:
|
|
|
30 |
"""
|
31 |
# 自動選擇設備
|
32 |
if device is None:
|
@@ -34,12 +35,17 @@ class CLIPAnalyzer:
|
|
34 |
else:
|
35 |
self.device = device
|
36 |
|
37 |
-
print(f"Loading
|
38 |
try:
|
39 |
-
self.model, self.preprocess =
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
41 |
except Exception as e:
|
42 |
-
print(f"Error loading
|
43 |
raise
|
44 |
|
45 |
self.scene_type_prompts = SCENE_TYPE_PROMPTS
|
@@ -64,7 +70,7 @@ class CLIPAnalyzer:
|
|
64 |
if scene_texts:
|
65 |
self.text_features_cache["scene_type_keys"] = list(self.scene_type_prompts.keys())
|
66 |
try:
|
67 |
-
self.text_features_cache["scene_type_tokens"] =
|
68 |
except Exception as e:
|
69 |
print(f"Warning: Error tokenizing scene_type_prompts: {e}")
|
70 |
self.text_features_cache["scene_type_tokens"] = None # 標記錯誤或空
|
@@ -82,7 +88,7 @@ class CLIPAnalyzer:
|
|
82 |
for scene_type, prompts in self.cultural_scene_prompts.items():
|
83 |
if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
|
84 |
try:
|
85 |
-
cultural_tokens_dict_val[scene_type] =
|
86 |
except Exception as e:
|
87 |
print(f"Warning: Error tokenizing cultural_scene_prompts for {scene_type}: {e}")
|
88 |
cultural_tokens_dict_val[scene_type] = None # 標記錯誤或空
|
@@ -96,7 +102,7 @@ class CLIPAnalyzer:
|
|
96 |
if lighting_texts:
|
97 |
self.text_features_cache["lighting_condition_keys"] = list(self.lighting_condition_prompts.keys())
|
98 |
try:
|
99 |
-
self.text_features_cache["lighting_tokens"] =
|
100 |
except Exception as e:
|
101 |
print(f"Warning: Error tokenizing lighting_condition_prompts: {e}")
|
102 |
self.text_features_cache["lighting_tokens"] = None
|
@@ -113,7 +119,7 @@ class CLIPAnalyzer:
|
|
113 |
for scene_type, prompts in self.specialized_scene_prompts.items():
|
114 |
if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
|
115 |
try:
|
116 |
-
specialized_tokens_dict_val[scene_type] =
|
117 |
except Exception as e:
|
118 |
print(f"Warning: Error tokenizing specialized_scene_prompts for {scene_type}: {e}")
|
119 |
specialized_tokens_dict_val[scene_type] = None
|
@@ -127,7 +133,7 @@ class CLIPAnalyzer:
|
|
127 |
if viewpoint_texts:
|
128 |
self.text_features_cache["viewpoint_keys"] = list(self.viewpoint_prompts.keys())
|
129 |
try:
|
130 |
-
self.text_features_cache["viewpoint_tokens"] =
|
131 |
except Exception as e:
|
132 |
print(f"Warning: Error tokenizing viewpoint_prompts: {e}")
|
133 |
self.text_features_cache["viewpoint_tokens"] = None
|
@@ -144,7 +150,7 @@ class CLIPAnalyzer:
|
|
144 |
if object_combination_texts:
|
145 |
self.text_features_cache["object_combination_keys"] = list(self.object_combination_prompts.keys())
|
146 |
try:
|
147 |
-
self.text_features_cache["object_combination_tokens"] =
|
148 |
except Exception as e:
|
149 |
print(f"Warning: Error tokenizing object_combination_prompts: {e}")
|
150 |
self.text_features_cache["object_combination_tokens"] = None
|
@@ -161,7 +167,7 @@ class CLIPAnalyzer:
|
|
161 |
if activity_texts:
|
162 |
self.text_features_cache["activity_keys"] = list(self.activity_prompts.keys())
|
163 |
try:
|
164 |
-
self.text_features_cache["activity_tokens"] =
|
165 |
except Exception as e:
|
166 |
print(f"Warning: Error tokenizing activity_prompts: {e}")
|
167 |
self.text_features_cache["activity_tokens"] = None
|
@@ -180,7 +186,7 @@ class CLIPAnalyzer:
|
|
180 |
self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
|
181 |
self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]
|
182 |
|
183 |
-
print("
|
184 |
|
185 |
def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
|
186 |
"""
|
@@ -581,16 +587,7 @@ class CLIPAnalyzer:
|
|
581 |
return image_features.cpu().numpy()[0] if self.device == "cuda" else image_features.numpy()[0]
|
582 |
|
583 |
def text_to_embedding(self, text: str) -> np.ndarray:
|
584 |
-
|
585 |
-
將文本轉換為 CLIP 嵌入表示
|
586 |
-
|
587 |
-
Args:
|
588 |
-
text: 輸入文本
|
589 |
-
|
590 |
-
Returns:
|
591 |
-
np.ndarray: 文本的 CLIP 特徵向量
|
592 |
-
"""
|
593 |
-
text_token = clip.tokenize([text]).to(self.device)
|
594 |
|
595 |
with torch.no_grad():
|
596 |
text_features = self.model.encode_text(text_token)
|
|
|
1 |
import torch
|
2 |
+
import open_clip
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
from typing import Dict, List, Tuple, Any, Optional, Union
|
|
|
20 |
Use CLIP to integrate scene understanding functions
|
21 |
"""
|
22 |
|
23 |
+
def __init__(self, model_name: str = "ViT-B-16", device: str = None, pretrained: str = "laion2b_s34b_b88k"):
|
24 |
"""
|
25 |
+
初始化 CLIP 分析器,使用 OpenCLIP 實現
|
26 |
|
27 |
Args:
|
28 |
+
model_name: OpenCLIP 模型名稱,默認 "ViT-B-16"
|
29 |
+
device: 運行設備
|
30 |
+
pretrained: 預訓練權重,默認 "laion2b_s34b_b88k"
|
31 |
"""
|
32 |
# 自動選擇設備
|
33 |
if device is None:
|
|
|
35 |
else:
|
36 |
self.device = device
|
37 |
|
38 |
+
print(f"Loading OpenCLIP model {model_name} with {pretrained} on {self.device}...")
|
39 |
try:
|
40 |
+
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
|
41 |
+
model_name,
|
42 |
+
pretrained=pretrained,
|
43 |
+
device=self.device
|
44 |
+
)
|
45 |
+
self.tokenizer = open_clip.get_tokenizer(model_name)
|
46 |
+
print(f"OpenCLIP model loaded successfully.")
|
47 |
except Exception as e:
|
48 |
+
print(f"Error loading OpenCLIP model: {e}")
|
49 |
raise
|
50 |
|
51 |
self.scene_type_prompts = SCENE_TYPE_PROMPTS
|
|
|
70 |
if scene_texts:
|
71 |
self.text_features_cache["scene_type_keys"] = list(self.scene_type_prompts.keys())
|
72 |
try:
|
73 |
+
self.text_features_cache["scene_type_tokens"] = self.tokenizer(scene_texts).to(self.device)
|
74 |
except Exception as e:
|
75 |
print(f"Warning: Error tokenizing scene_type_prompts: {e}")
|
76 |
self.text_features_cache["scene_type_tokens"] = None # 標記錯誤或空
|
|
|
88 |
for scene_type, prompts in self.cultural_scene_prompts.items():
|
89 |
if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
|
90 |
try:
|
91 |
+
cultural_tokens_dict_val[scene_type] = self.tokenizer(prompts).to(self.device)
|
92 |
except Exception as e:
|
93 |
print(f"Warning: Error tokenizing cultural_scene_prompts for {scene_type}: {e}")
|
94 |
cultural_tokens_dict_val[scene_type] = None # 標記錯誤或空
|
|
|
102 |
if lighting_texts:
|
103 |
self.text_features_cache["lighting_condition_keys"] = list(self.lighting_condition_prompts.keys())
|
104 |
try:
|
105 |
+
self.text_features_cache["lighting_tokens"] = self.tokenizer(lighting_texts).to(self.device)
|
106 |
except Exception as e:
|
107 |
print(f"Warning: Error tokenizing lighting_condition_prompts: {e}")
|
108 |
self.text_features_cache["lighting_tokens"] = None
|
|
|
119 |
for scene_type, prompts in self.specialized_scene_prompts.items():
|
120 |
if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
|
121 |
try:
|
122 |
+
specialized_tokens_dict_val[scene_type] = self.tokenizer(prompts).to(self.device)
|
123 |
except Exception as e:
|
124 |
print(f"Warning: Error tokenizing specialized_scene_prompts for {scene_type}: {e}")
|
125 |
specialized_tokens_dict_val[scene_type] = None
|
|
|
133 |
if viewpoint_texts:
|
134 |
self.text_features_cache["viewpoint_keys"] = list(self.viewpoint_prompts.keys())
|
135 |
try:
|
136 |
+
self.text_features_cache["viewpoint_tokens"] = self.tokenizer(viewpoint_texts).to(self.device)
|
137 |
except Exception as e:
|
138 |
print(f"Warning: Error tokenizing viewpoint_prompts: {e}")
|
139 |
self.text_features_cache["viewpoint_tokens"] = None
|
|
|
150 |
if object_combination_texts:
|
151 |
self.text_features_cache["object_combination_keys"] = list(self.object_combination_prompts.keys())
|
152 |
try:
|
153 |
+
self.text_features_cache["object_combination_tokens"] = self.tokenizer(object_combination_texts).to(self.device)
|
154 |
except Exception as e:
|
155 |
print(f"Warning: Error tokenizing object_combination_prompts: {e}")
|
156 |
self.text_features_cache["object_combination_tokens"] = None
|
|
|
167 |
if activity_texts:
|
168 |
self.text_features_cache["activity_keys"] = list(self.activity_prompts.keys())
|
169 |
try:
|
170 |
+
self.text_features_cache["activity_tokens"] = self.tokenizer(activity_texts).to(self.device)
|
171 |
except Exception as e:
|
172 |
print(f"Warning: Error tokenizing activity_prompts: {e}")
|
173 |
self.text_features_cache["activity_tokens"] = None
|
|
|
186 |
self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
|
187 |
self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]
|
188 |
|
189 |
+
print("OpenCLIP text_features_cache prepared.")
|
190 |
|
191 |
def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
|
192 |
"""
|
|
|
587 |
return image_features.cpu().numpy()[0] if self.device == "cuda" else image_features.numpy()[0]
|
588 |
|
589 |
def text_to_embedding(self, text: str) -> np.ndarray:
|
590 |
+
text_token = self.tokenizer([text]).to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
591 |
|
592 |
with torch.no_grad():
|
593 |
text_features = self.model.encode_text(text_token)
|
clip_model_manager.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
|
2 |
import torch
|
3 |
-
import
|
4 |
import numpy as np
|
5 |
import logging
|
6 |
import traceback
|
@@ -12,7 +12,7 @@ class CLIPModelManager:
|
|
12 |
專門管理 CLIP 模型相關的操作,包括模型載入、設備管理、圖像和文本的特徵編碼等核心功能
|
13 |
"""
|
14 |
|
15 |
-
def __init__(self, model_name: str = "ViT-B
|
16 |
"""
|
17 |
初始化 CLIP 模型管理器
|
18 |
|
@@ -22,6 +22,8 @@ class CLIPModelManager:
|
|
22 |
"""
|
23 |
self.logger = logging.getLogger(__name__)
|
24 |
self.model_name = model_name
|
|
|
|
|
25 |
|
26 |
# 設置運行設備
|
27 |
if device is None:
|
@@ -29,19 +31,23 @@ class CLIPModelManager:
|
|
29 |
else:
|
30 |
self.device = device
|
31 |
|
32 |
-
self.model = None
|
33 |
self.preprocess = None
|
34 |
|
35 |
self._initialize_model()
|
36 |
|
37 |
def _initialize_model(self):
|
38 |
"""
|
39 |
-
初始化
|
40 |
"""
|
41 |
try:
|
42 |
-
self.logger.info(f"Initializing
|
43 |
-
self.model, self.preprocess =
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
45 |
except Exception as e:
|
46 |
self.logger.error(f"Error loading CLIP model: {e}")
|
47 |
self.logger.error(traceback.format_exc())
|
@@ -87,7 +93,7 @@ class CLIPModelManager:
|
|
87 |
|
88 |
for i in range(0, len(text_prompts), batch_size):
|
89 |
batch_prompts = text_prompts[i:i+batch_size]
|
90 |
-
text_tokens =
|
91 |
batch_features = self.model.encode_text(text_tokens)
|
92 |
batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
|
93 |
features_list.append(batch_features)
|
@@ -106,18 +112,9 @@ class CLIPModelManager:
|
|
106 |
raise
|
107 |
|
108 |
def encode_single_text(self, text_prompts: List[str]) -> torch.Tensor:
|
109 |
-
"""
|
110 |
-
編碼單個文本批次的特徵
|
111 |
-
|
112 |
-
Args:
|
113 |
-
text_prompts: 文本提示列表
|
114 |
-
|
115 |
-
Returns:
|
116 |
-
torch.Tensor: 標準化後的文本特徵
|
117 |
-
"""
|
118 |
try:
|
119 |
with torch.no_grad():
|
120 |
-
text_tokens =
|
121 |
text_features = self.model.encode_text(text_tokens)
|
122 |
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
123 |
return text_features
|
|
|
1 |
|
2 |
import torch
|
3 |
+
import open_clip
|
4 |
import numpy as np
|
5 |
import logging
|
6 |
import traceback
|
|
|
12 |
專門管理 CLIP 模型相關的操作,包括模型載入、設備管理、圖像和文本的特徵編碼等核心功能
|
13 |
"""
|
14 |
|
15 |
+
def __init__(self, model_name: str = "ViT-B-16", device: str = None, pretrained: str = "laion2b_s34b_b88k"):
|
16 |
"""
|
17 |
初始化 CLIP 模型管理器
|
18 |
|
|
|
22 |
"""
|
23 |
self.logger = logging.getLogger(__name__)
|
24 |
self.model_name = model_name
|
25 |
+
self.pretrained = pretrained
|
26 |
+
self.tokenizer = None
|
27 |
|
28 |
# 設置運行設備
|
29 |
if device is None:
|
|
|
31 |
else:
|
32 |
self.device = device
|
33 |
|
|
|
34 |
self.preprocess = None
|
35 |
|
36 |
self._initialize_model()
|
37 |
|
38 |
def _initialize_model(self):
|
39 |
"""
|
40 |
+
初始化OpenCLIP模型
|
41 |
"""
|
42 |
try:
|
43 |
+
self.logger.info(f"Initializing OpenCLIP model ({self.model_name}) with {self.pretrained} on {self.device}")
|
44 |
+
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
|
45 |
+
self.model_name,
|
46 |
+
pretrained=self.pretrained,
|
47 |
+
device=self.device
|
48 |
+
)
|
49 |
+
self.tokenizer = open_clip.get_tokenizer(self.model_name)
|
50 |
+
self.logger.info("Successfully loaded OpenCLIP model")
|
51 |
except Exception as e:
|
52 |
self.logger.error(f"Error loading CLIP model: {e}")
|
53 |
self.logger.error(traceback.format_exc())
|
|
|
93 |
|
94 |
for i in range(0, len(text_prompts), batch_size):
|
95 |
batch_prompts = text_prompts[i:i+batch_size]
|
96 |
+
text_tokens = self.tokenizer(batch_prompts).to(self.device)
|
97 |
batch_features = self.model.encode_text(text_tokens)
|
98 |
batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
|
99 |
features_list.append(batch_features)
|
|
|
112 |
raise
|
113 |
|
114 |
def encode_single_text(self, text_prompts: List[str]) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
try:
|
116 |
with torch.no_grad():
|
117 |
+
text_tokens = self.tokenizer(text_prompts).to(self.device)
|
118 |
text_features = self.model.encode_text(text_tokens)
|
119 |
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
120 |
return text_features
|
clip_prompts.py
CHANGED
@@ -69,7 +69,49 @@ SCENE_TYPE_PROMPTS = {
|
|
69 |
"construction_site": "A photo of a construction site with building materials, equipment and workers.",
|
70 |
"medical_facility": "A photo of a medical facility with healthcare equipment and professional staff.",
|
71 |
"educational_setting": "A photo of an educational setting with learning spaces and academic resources.",
|
72 |
-
"professional_kitchen": "A photo of a professional commercial kitchen with industrial cooking equipment and food preparation stations."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
}
|
74 |
|
75 |
# 文化特定場景提示
|
@@ -151,6 +193,30 @@ COMPARATIVE_PROMPTS = {
|
|
151 |
"A street-level view showing pedestrian perspective and immediate surroundings.",
|
152 |
"A bird's-eye view of city organization and movement patterns from high above.",
|
153 |
"An eye-level perspective showing direct human interaction with urban elements."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
]
|
155 |
}
|
156 |
|
@@ -170,7 +236,16 @@ LIGHTING_CONDITION_PROMPTS = {
|
|
170 |
"mixed_lighting": "A scene with combined natural and artificial light sources creating transition zones.",
|
171 |
"beach_daylight": "A photo taken at a beach with bright natural sunlight and reflections from water.",
|
172 |
"sports_arena_lighting": "A photo of a sports venue illuminated by powerful overhead lighting systems.",
|
173 |
-
"kitchen_task_lighting": "A photo of a kitchen with focused lighting concentrated on work surfaces."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
}
|
175 |
|
176 |
# 針對新場景類型的特殊提示
|
@@ -228,6 +303,29 @@ SPECIALIZED_SCENE_PROMPTS = {
|
|
228 |
"A high-angle view of an intersection showing traffic and pedestrian flow patterns.",
|
229 |
"A drone perspective of urban crossing design viewed from directly above.",
|
230 |
"A vertical view of a street intersection showing crossing infrastructure."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
]
|
232 |
}
|
233 |
|
@@ -239,7 +337,15 @@ VIEWPOINT_PROMPTS = {
|
|
239 |
"bird_eye": "A photo taken from very high above showing a complete overhead perspective.",
|
240 |
"street_level": "A photo taken from the perspective of someone standing on the street.",
|
241 |
"interior": "A photo taken from inside a building showing the internal environment.",
|
242 |
-
"vehicular": "A photo taken from inside or mounted on a moving vehicle."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
}
|
244 |
|
245 |
OBJECT_COMBINATION_PROMPTS = {
|
@@ -250,7 +356,15 @@ OBJECT_COMBINATION_PROMPTS = {
|
|
250 |
"retail_environment": "A scene with merchandise displays, shoppers, and store fixtures.",
|
251 |
"crosswalk_scene": "A scene with street markings, pedestrians crossing, and traffic signals.",
|
252 |
"cooking_area": "A scene with stoves, prep surfaces, cooking utensils, and food items.",
|
253 |
-
"recreational_space": "A scene with sports equipment, play areas, and activity participants."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
}
|
255 |
|
256 |
ACTIVITY_PROMPTS = {
|
@@ -261,5 +375,14 @@ ACTIVITY_PROMPTS = {
|
|
261 |
"exercising": "People engaged in physical activities, using sports equipment, and training.",
|
262 |
"cooking": "People preparing food, using kitchen equipment, and creating meals.",
|
263 |
"crossing_street": "People walking across designated crosswalks and navigating intersections.",
|
264 |
-
"recreational_activity": "People engaged in leisure activities, games, and social recreation."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
}
|
|
|
69 |
"construction_site": "A photo of a construction site with building materials, equipment and workers.",
|
70 |
"medical_facility": "A photo of a medical facility with healthcare equipment and professional staff.",
|
71 |
"educational_setting": "A photo of an educational setting with learning spaces and academic resources.",
|
72 |
+
"professional_kitchen": "A photo of a professional commercial kitchen with industrial cooking equipment and food preparation stations.",
|
73 |
+
|
74 |
+
# 工作空間的多樣化
|
75 |
+
"modern_open_office": "A photo of a modern open office with collaborative workspaces, standing desks and contemporary furniture design.",
|
76 |
+
"traditional_cubicle_office": "A photo of a traditional office with individual cubicles, separated workstations and formal business environment.",
|
77 |
+
"home_office_study": "A photo of a home office or study room with personal workspace setup and residential comfort elements.",
|
78 |
+
"creative_workspace": "A photo of a creative workspace with design materials, artistic tools and inspiring work environment.",
|
79 |
+
"shared_workspace_hub": "A photo of a shared coworking space with flexible seating, community areas and collaborative atmosphere.",
|
80 |
+
|
81 |
+
# 用餐空間的情境化
|
82 |
+
"casual_family_dining": "A photo of a casual family dining area with comfortable seating and everyday meal setup.",
|
83 |
+
"formal_dining_room": "A photo of a formal dining room with elegant table setting and sophisticated dining arrangement.",
|
84 |
+
"breakfast_nook_area": "A photo of a cozy breakfast nook with intimate seating and morning meal atmosphere.",
|
85 |
+
"outdoor_patio_dining": "A photo of an outdoor patio dining area with weather-resistant furniture and al fresco dining setup.",
|
86 |
+
"kitchen_island_dining": "A photo of a kitchen island used for casual dining with bar-style seating and integrated cooking space.",
|
87 |
+
|
88 |
+
# 生活空間的使用情境
|
89 |
+
"family_entertainment_room": "A photo of a family room focused on entertainment with large TV, comfortable seating and recreational atmosphere.",
|
90 |
+
"reading_lounge_area": "A photo of a quiet reading area with comfortable chairs, good lighting and book storage.",
|
91 |
+
"social_gathering_space": "A photo of a living area arranged for social gatherings with multiple seating options and conversation-friendly layout.",
|
92 |
+
"relaxation_living_space": "A photo of a living room designed for relaxation with soft furnishings and calm atmosphere.",
|
93 |
+
|
94 |
+
# 商業空間的服務導向
|
95 |
+
"quick_service_restaurant": "A photo of a quick service restaurant with efficient ordering system and fast-casual dining setup.",
|
96 |
+
"coffee_shop_workspace": "A photo of a coffee shop that doubles as workspace with WiFi-friendly seating and laptop users.",
|
97 |
+
"boutique_retail_space": "A photo of a boutique retail store with curated merchandise display and personalized shopping experience.",
|
98 |
+
"convenience_store_market": "A photo of a convenience store with everyday items, quick shopping layout and accessible product arrangement.",
|
99 |
+
|
100 |
+
# 學習環境的專業化
|
101 |
+
"collaborative_classroom": "A photo of a modern classroom designed for group work with flexible seating and interactive learning setup.",
|
102 |
+
"lecture_hall_setting": "A photo of a traditional lecture hall with tiered seating and formal educational presentation setup.",
|
103 |
+
"study_hall_library": "A photo of a quiet study area in a library with individual study spaces and academic atmosphere.",
|
104 |
+
"computer_lab_classroom": "A photo of a computer lab or digital classroom with technology workstations and learning equipment.",
|
105 |
+
|
106 |
+
# 用時間判斷
|
107 |
+
"morning_routine_kitchen": "A photo of a kitchen during morning routine with breakfast preparation and daily startup activities.",
|
108 |
+
"evening_relaxation_living": "A photo of a living room in evening mode with dim lighting and relaxation activities.",
|
109 |
+
"weekend_leisure_space": "A photo of a living area during weekend with casual activities and relaxed atmosphere.",
|
110 |
+
|
111 |
+
# 活動強度的描述
|
112 |
+
"busy_work_environment": "A photo of an active workplace with multiple people engaged in work tasks and productive atmosphere.",
|
113 |
+
"quiet_study_atmosphere": "A photo of a peaceful study or work environment with focused activity and minimal distractions.",
|
114 |
+
"social_interaction_space": "A photo of a space designed for social interaction with multiple people engaging in conversation."
|
115 |
}
|
116 |
|
117 |
# 文化特定場景提示
|
|
|
193 |
"A street-level view showing pedestrian perspective and immediate surroundings.",
|
194 |
"A bird's-eye view of city organization and movement patterns from high above.",
|
195 |
"An eye-level perspective showing direct human interaction with urban elements."
|
196 |
+
],
|
197 |
+
"modern_vs_traditional_kitchen": [
|
198 |
+
"A modern kitchen with sleek stainless steel appliances, minimalist design and contemporary fixtures.",
|
199 |
+
"A traditional kitchen with classic wooden cabinets, vintage appliances and conventional design elements."
|
200 |
+
],
|
201 |
+
|
202 |
+
"business_vs_leisure_dining": [
|
203 |
+
"A business dining environment with professional atmosphere, formal table settings and corporate meeting setup.",
|
204 |
+
"A leisure dining space with relaxed atmosphere, casual seating and recreational meal environment."
|
205 |
+
],
|
206 |
+
|
207 |
+
"dense_vs_spacious_retail": [
|
208 |
+
"A densely packed retail space with closely arranged merchandise and compact shopping aisles.",
|
209 |
+
"A spacious retail environment with wide aisles, generous display spacing and open shopping layout."
|
210 |
+
],
|
211 |
+
|
212 |
+
"private_vs_shared_workspace": [
|
213 |
+
"A private office space with individual workstation, personal storage and isolated work environment.",
|
214 |
+
"A shared workspace with communal tables, collaborative areas and open interaction zones."
|
215 |
+
],
|
216 |
+
|
217 |
+
"functional_vs_aesthetic_space": [
|
218 |
+
"A purely functional workspace focused on efficiency with practical furniture and utilitarian design.",
|
219 |
+
"An aesthetically designed space emphasizing visual appeal with decorative elements and stylistic choices."
|
220 |
]
|
221 |
}
|
222 |
|
|
|
236 |
"mixed_lighting": "A scene with combined natural and artificial light sources creating transition zones.",
|
237 |
"beach_daylight": "A photo taken at a beach with bright natural sunlight and reflections from water.",
|
238 |
"sports_arena_lighting": "A photo of a sports venue illuminated by powerful overhead lighting systems.",
|
239 |
+
"kitchen_task_lighting": "A photo of a kitchen with focused lighting concentrated on work surfaces.",
|
240 |
+
"photography_studio_lighting": "A photo taken in a photography studio with controlled professional lighting and even illumination.",
|
241 |
+
"retail_display_lighting": "A photo taken in retail environment with strategic product lighting and commercial illumination design.",
|
242 |
+
"conference_room_lighting": "A photo taken in a conference room with balanced meeting lighting and presentation-friendly illumination.",
|
243 |
+
"golden_hour_outdoor": "A photo taken during golden hour with warm, low-angle sunlight creating dramatic shadows and highlights.",
|
244 |
+
"overcast_diffused_light": "A photo taken under overcast sky with soft, even diffused lighting and minimal shadows.",
|
245 |
+
"harsh_midday_sun": "A photo taken under intense midday sunlight with strong contrasts and sharp shadows.",
|
246 |
+
"office_mixed_lighting": "A photo taken in office environment combining natural window light with artificial ceiling illumination.",
|
247 |
+
"restaurant_ambient_lighting": "A photo taken in restaurant with carefully designed ambient lighting combining multiple warm light sources.",
|
248 |
+
"retail_accent_lighting": "A photo taken in retail space with accent lighting highlighting products against general ambient illumination."
|
249 |
}
|
250 |
|
251 |
# 針對新場景類型的特殊提示
|
|
|
303 |
"A high-angle view of an intersection showing traffic and pedestrian flow patterns.",
|
304 |
"A drone perspective of urban crossing design viewed from directly above.",
|
305 |
"A vertical view of a street intersection showing crossing infrastructure."
|
306 |
+
],
|
307 |
+
"medical_waiting_room": [
|
308 |
+
"A medical facility waiting area with comfortable seating, health information displays and patient-focused design.",
|
309 |
+
"A healthcare waiting space with sanitized surfaces, medical equipment visibility and clinical atmosphere.",
|
310 |
+
"A medical office reception area with appointment scheduling setup and healthcare service information."
|
311 |
+
],
|
312 |
+
|
313 |
+
"science_laboratory": [
|
314 |
+
"A science laboratory with experimental equipment, safety features and research workstations.",
|
315 |
+
"A chemistry lab with fume hoods, lab benches and scientific instrument arrangements.",
|
316 |
+
"A biology laboratory with microscopes, specimen storage and life science research setup."
|
317 |
+
],
|
318 |
+
|
319 |
+
"design_studio_workspace": [
|
320 |
+
"A design studio with creative tools, inspiration boards and artistic project development areas.",
|
321 |
+
"An architecture office with drafting tables, model displays and design development workspaces.",
|
322 |
+
"A graphic design workspace with computer workstations, color calibration tools and creative project areas."
|
323 |
+
],
|
324 |
+
|
325 |
+
"maintenance_workshop": [
|
326 |
+
"A maintenance workshop with repair tools, work benches and technical service equipment.",
|
327 |
+
"A mechanical service area with diagnostic equipment, repair stations and automotive maintenance setup.",
|
328 |
+
"A technical workshop with specialized tools, parts storage and equipment maintenance facilities."
|
329 |
]
|
330 |
}
|
331 |
|
|
|
337 |
"bird_eye": "A photo taken from very high above showing a complete overhead perspective.",
|
338 |
"street_level": "A photo taken from the perspective of someone standing on the street.",
|
339 |
"interior": "A photo taken from inside a building showing the internal environment.",
|
340 |
+
"vehicular": "A photo taken from inside or mounted on a moving vehicle.",
|
341 |
+
|
342 |
+
# 較詳細的視角
|
343 |
+
"security_camera_angle": "A photo taken from fixed security camera position showing surveillance perspective of the area.",
|
344 |
+
"drone_inspection_view": "A photo taken from drone perspective for inspection purposes showing detailed overhead examination angle.",
|
345 |
+
"architectural_documentation_view": "A photo taken specifically for architectural documentation showing building features and structural details.",
|
346 |
+
"customer_entering_view": "A photo taken from the perspective of a customer or visitor entering the space for the first time.",
|
347 |
+
"worker_daily_perspective": "A photo taken from the viewpoint of someone who works in this environment on a daily basis.",
|
348 |
+
"maintenance_access_view": "A photo taken from the perspective needed for maintenance or service access to equipment and facilities."
|
349 |
}
|
350 |
|
351 |
OBJECT_COMBINATION_PROMPTS = {
|
|
|
356 |
"retail_environment": "A scene with merchandise displays, shoppers, and store fixtures.",
|
357 |
"crosswalk_scene": "A scene with street markings, pedestrians crossing, and traffic signals.",
|
358 |
"cooking_area": "A scene with stoves, prep surfaces, cooking utensils, and food items.",
|
359 |
+
"recreational_space": "A scene with sports equipment, play areas, and activity participants.",
|
360 |
+
"medical_examination_setup": "A scene with medical examination table, diagnostic equipment, and healthcare monitoring devices.",
|
361 |
+
"laboratory_research_arrangement": "A scene with scientific instruments, sample containers, and research documentation materials.",
|
362 |
+
"technical_repair_station": "A scene with diagnostic tools, replacement parts, and mechanical repair equipment.",
|
363 |
+
"art_creation_workspace": "A scene with artistic supplies, canvases, and creative project materials arranged for art making.",
|
364 |
+
"music_practice_setup": "A scene with musical instruments, sheet music, and sound equipment for music practice.",
|
365 |
+
"craft_workshop_arrangement": "A scene with crafting tools, materials, and project supplies organized for handmade creation.",
|
366 |
+
"language_learning_environment": "A scene with language learning materials, reference books, and communication practice tools.",
|
367 |
+
"science_experiment_setup": "A scene with scientific apparatus, measurement tools, and experimental materials for hands-on learning."
|
368 |
}
|
369 |
|
370 |
ACTIVITY_PROMPTS = {
|
|
|
375 |
"exercising": "People engaged in physical activities, using sports equipment, and training.",
|
376 |
"cooking": "People preparing food, using kitchen equipment, and creating meals.",
|
377 |
"crossing_street": "People walking across designated crosswalks and navigating intersections.",
|
378 |
+
"recreational_activity": "People engaged in leisure activities, games, and social recreation.",
|
379 |
+
"consulting": "People engaged in professional consultation with documents, presentations, and advisory discussions.",
|
380 |
+
"training": "People participating in skill development training with instructional materials and practice exercises.",
|
381 |
+
"maintenance": "People performing maintenance tasks with technical equipment and repair procedures.",
|
382 |
+
"brainstorming": "People engaged in creative brainstorming with idea development tools and collaborative thinking.",
|
383 |
+
"designing": "People working on design projects with creative tools, sketches, and visual development materials.",
|
384 |
+
"prototyping": "People building and testing prototypes with development materials and experimental approaches.",
|
385 |
+
"researching": "People conducting research with reference materials, databases, and investigative methods.",
|
386 |
+
"experimenting": "People performing experiments with scientific equipment and systematic testing procedures.",
|
387 |
+
"practicing": "People engaged in skill practice with repetitive exercises and performance improvement activities."
|
388 |
}
|
clip_zero_shot_classifier.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
|
2 |
import torch
|
3 |
-
import
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
6 |
import logging
|
@@ -21,18 +21,18 @@ class CLIPZeroShotClassifier:
|
|
21 |
這是一個總窗口class,協調各個組件的工作以提供統一的對外接口。
|
22 |
"""
|
23 |
|
24 |
-
def __init__(self, model_name: str = "ViT-B
|
25 |
"""
|
26 |
初始化CLIP零樣本分類器
|
27 |
|
28 |
Args:
|
29 |
-
model_name:
|
30 |
device: 運行設備,None則自動選擇
|
31 |
"""
|
32 |
self.logger = logging.getLogger(__name__)
|
33 |
|
34 |
# 初始化各個組件
|
35 |
-
self.clip_model_manager = CLIPModelManager(model_name, device)
|
36 |
self.landmark_data_manager = LandmarkDataManager()
|
37 |
self.image_analyzer = ImageAnalyzer()
|
38 |
self.confidence_manager = ConfidenceManager()
|
|
|
1 |
|
2 |
import torch
|
3 |
+
import open_clip
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
6 |
import logging
|
|
|
21 |
這是一個總窗口class,協調各個組件的工作以提供統一的對外接口。
|
22 |
"""
|
23 |
|
24 |
+
def __init__(self, model_name: str = "ViT-B-16", device: str = None, pretrained: str = "laion2b_s34b_b88k"):
|
25 |
"""
|
26 |
初始化CLIP零樣本分類器
|
27 |
|
28 |
Args:
|
29 |
+
model_name: OpenCLIP模型名稱,默認為"ViT-B-16"
|
30 |
device: 運行設備,None則自動選擇
|
31 |
"""
|
32 |
self.logger = logging.getLogger(__name__)
|
33 |
|
34 |
# 初始化各個組件
|
35 |
+
self.clip_model_manager = CLIPModelManager(model_name, device, pretrained)
|
36 |
self.landmark_data_manager = LandmarkDataManager()
|
37 |
self.image_analyzer = ImageAnalyzer()
|
38 |
self.confidence_manager = ConfidenceManager()
|
llm_enhancer.py
CHANGED
@@ -3,7 +3,7 @@ import traceback
|
|
3 |
import re
|
4 |
from typing import Dict, List, Any, Optional
|
5 |
|
6 |
-
from
|
7 |
from prompt_template_manager import PromptTemplateManager
|
8 |
from response_processor import ResponseProcessor
|
9 |
from text_quality_validator import TextQualityValidator
|
@@ -44,7 +44,7 @@ class LLMEnhancer:
|
|
44 |
|
45 |
try:
|
46 |
# 初始化四個核心組件
|
47 |
-
self.model_manager =
|
48 |
model_path=model_path,
|
49 |
tokenizer_path=tokenizer_path,
|
50 |
device=device,
|
|
|
3 |
import re
|
4 |
from typing import Dict, List, Any, Optional
|
5 |
|
6 |
+
from llm_model_manager import LLMModelManager
|
7 |
from prompt_template_manager import PromptTemplateManager
|
8 |
from response_processor import ResponseProcessor
|
9 |
from text_quality_validator import TextQualityValidator
|
|
|
44 |
|
45 |
try:
|
46 |
# 初始化四個核心組件
|
47 |
+
self.model_manager = LLMModelManager(
|
48 |
model_path=model_path,
|
49 |
tokenizer_path=tokenizer_path,
|
50 |
device=device,
|
llm_model_manager.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import logging
|
4 |
+
from typing import Dict, Optional, Any
|
5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
6 |
+
from huggingface_hub import login
|
7 |
+
|
8 |
+
class ModelLoadingError(Exception):
    """Signal that the LLM or its tokenizer could not be loaded."""
|
11 |
+
|
12 |
+
|
13 |
+
class ModelGenerationError(Exception):
    """Signal that text generation failed or produced an unusable result."""
|
16 |
+
|
17 |
+
|
18 |
+
class LLMModelManager:
    """
    Load an LLM and its tokenizer, choose the run device, and generate text.

    The model is loaded lazily: nothing is downloaded or placed on a device
    until the first call to ``generate_response``.
    """

    def __init__(self,
                 model_path: Optional[str] = None,
                 tokenizer_path: Optional[str] = None,
                 device: Optional[str] = None,
                 max_length: int = 2048,
                 temperature: float = 0.3,
                 top_p: float = 0.85):
        """
        Initialise the model manager.

        Args:
            model_path: Local path or HuggingFace model id. Falls back to
                "meta-llama/Llama-3.2-3B-Instruct" when not given.
            tokenizer_path: Tokenizer path; falls back to ``model_path``.
            device: Run device ('cpu' or 'cuda'); auto-detected when None.
            max_length: Maximum number of prompt tokens (longer prompts are
                truncated by the tokenizer).
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability threshold.
        """
        # Dedicated logger; only attach a handler once per logger name so
        # repeated instantiation does not duplicate every log line.
        self.logger = logging.getLogger(self.__class__.__name__)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
            self.logger.setLevel(logging.INFO)

        # Model configuration
        self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
        self.tokenizer_path = tokenizer_path or self.model_path

        # Device management
        self.device = self._detect_device(device)
        self.logger.info(f"Device selected: {self.device}")

        # Generation parameters
        self.max_length = max_length
        self.temperature = temperature
        self.top_p = top_p

        # Model state (populated by _load_model)
        self.model = None
        self.tokenizer = None
        self._model_loaded = False
        self.call_count = 0

        # HuggingFace authentication
        self.hf_token = self._setup_huggingface_auth()

    def _detect_device(self, device: Optional[str]) -> str:
        """
        Resolve the device to run on.

        Args:
            device: Device requested by the caller; None means auto-detect.

        Returns:
            str: The requested device, or 'cpu' when a CUDA device was
            requested but CUDA is unavailable; 'cuda' or 'cpu' when
            auto-detecting.
        """
        if device:
            # Match indexed forms such as "cuda:0" as well as plain "cuda",
            # so any CUDA request degrades to CPU when no GPU is present.
            if str(device).startswith('cuda') and not torch.cuda.is_available():
                self.logger.warning("CUDA requested but not available, falling back to CPU")
                return 'cpu'
            return device

        if not torch.cuda.is_available():
            return 'cpu'

        gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        self.logger.info(f"CUDA detected with {gpu_memory:.2f} GB GPU memory")
        return 'cuda'

    def _setup_huggingface_auth(self) -> Optional[str]:
        """
        Authenticate with HuggingFace using the HF_TOKEN environment variable.

        Returns:
            Optional[str]: The token when login succeeded, otherwise None
            (token missing or login raised).
        """
        hf_token = os.environ.get("HF_TOKEN")

        if not hf_token:
            self.logger.warning("HF_TOKEN not found. Access to gated models may be limited")
            return None

        try:
            login(token=hf_token)
            self.logger.info("Successfully authenticated with HuggingFace")
            return hf_token
        except Exception as e:
            # Best effort: a failed login must not prevent construction.
            self.logger.error(f"HuggingFace authentication failed: {e}")
            return None

    def _load_model(self):
        """
        Load the tokenizer and the model with 8-bit quantization.

        Does nothing when the model is already loaded.

        Raises:
            ModelLoadingError: If loading the tokenizer or model fails.
        """
        if self._model_loaded:
            return

        try:
            self.logger.info(f"Loading model from {self.model_path} with 8-bit quantization")

            # Free cached GPU memory before the large allocation
            self._clear_gpu_cache()

            # 8-bit quantization configuration
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True
            )

            # Load the tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_path,
                padding_side="left",
                use_fast=False,
                token=self.hf_token
            )

            # Reuse EOS as the padding token when the tokenizer has none
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load the model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                quantization_config=quantization_config,
                device_map="auto",
                low_cpu_mem_usage=True,
                token=self.hf_token
            )

            self._model_loaded = True
            self.logger.info("Model loaded successfully")

        except Exception as e:
            error_msg = f"Failed to load model: {str(e)}"
            self.logger.error(error_msg)
            raise ModelLoadingError(error_msg) from e

    def _clear_gpu_cache(self):
        """Release cached GPU memory when CUDA is available."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.logger.debug("GPU cache cleared")

    def generate_response(self, prompt: str, **generation_kwargs) -> str:
        """
        Generate a response for ``prompt``.

        Args:
            prompt: The input prompt.
            **generation_kwargs: Extra generation parameters; these override
                the defaults from ``_prepare_generation_params``.

        Returns:
            str: The generated text, without the prompt.

        Raises:
            ModelLoadingError: If the lazy model load fails.
            ModelGenerationError: If generation fails or the response is
                empty or shorter than 10 characters.
        """
        # Make sure the model is loaded (lazy loading)
        if not self._model_loaded:
            self._load_model()

        try:
            self.call_count += 1
            self.logger.info(f"Generating response (call #{self.call_count})")

            self._clear_gpu_cache()

            # Fixed seed so repeated calls sample more consistently
            torch.manual_seed(42)

            # Tokenize the prompt
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length
            ).to(self.device)

            # Assemble generation parameters
            generation_params = self._prepare_generation_params(**generation_kwargs)
            generation_params.update({
                "pad_token_id": self.tokenizer.eos_token_id,
                "attention_mask": inputs.attention_mask,
                "use_cache": True,
            })

            # Run generation without tracking gradients
            with torch.no_grad():
                outputs = self.model.generate(inputs.input_ids, **generation_params)

            # A causal LM returns prompt tokens followed by new tokens. Drop
            # the prompt by token count: matching the decoded text against
            # the prompt string is unreliable once special tokens are
            # skipped or the prompt was truncated.
            prompt_token_count = inputs.input_ids.shape[-1]
            generated_ids = outputs[0][prompt_token_count:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            response = self._extract_generated_response(generated_text, prompt)

            if not response or len(response.strip()) < 10:
                raise ModelGenerationError("Generated response is too short or empty")

            self.logger.info(f"Response generated successfully ({len(response)} characters)")
            return response

        except ModelGenerationError as e:
            # Already carries a precise message; log it and re-raise as is
            # instead of wrapping it in a second ModelGenerationError.
            self.logger.error(f"Text generation failed: {str(e)}")
            raise
        except Exception as e:
            error_msg = f"Text generation failed: {str(e)}"
            self.logger.error(error_msg)
            raise ModelGenerationError(error_msg) from e

    def _prepare_generation_params(self, **kwargs) -> Dict[str, Any]:
        """
        Build the generation parameters, with model-specific tuning.

        Args:
            **kwargs: Caller-supplied generation parameters; applied last so
                they override every default.

        Returns:
            Dict[str, Any]: The complete generation parameter set.
        """
        # Baseline parameters
        # NOTE(review): both branches below overwrite "temperature" and
        # "top_p", so self.temperature / self.top_p currently have no effect
        # on generation. Left as is to keep the existing sampling behaviour;
        # confirm whether the constructor values are meant to win.
        params = {
            "max_new_tokens": 120,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "do_sample": True,
        }

        if "llama" in self.model_path.lower():
            # Llama-specific tuning
            params.update({
                "max_new_tokens": 600,
                "temperature": 0.35,  # kept deliberately low
                "top_p": 0.75,
                "repetition_penalty": 1.5,
                "num_beams": 5,
                "length_penalty": 1,
                "no_repeat_ngram_size": 3
            })
        else:
            params.update({
                "max_new_tokens": 300,
                "temperature": 0.6,
                "top_p": 0.9,
                "num_beams": 1,
                "repetition_penalty": 1.05
            })

        # Caller-supplied values take precedence over the defaults
        params.update(kwargs)

        return params

    def _extract_generated_response(self, full_response: str, prompt: str) -> str:
        """
        Extract the generated part from the model output.

        Args:
            full_response: Decoded model output.
            prompt: The original prompt.

        Returns:
            str: The text after the last "<|assistant|>" tag (cut at the
            next "<|user|>" tag, if any); otherwise the output with a
            leading copy of the prompt removed; otherwise the stripped
            output.
        """
        assistant_tag = "<|assistant|>"
        if assistant_tag in full_response:
            response = full_response.split(assistant_tag)[-1].strip()

            # Cut off a trailing user turn the model may have started
            user_tag = "<|user|>"
            if user_tag in response:
                response = response.split(user_tag)[0].strip()

            return response

        # Remove the echoed prompt when the output starts with it
        if full_response.startswith(prompt):
            return full_response[len(prompt):].strip()

        return full_response.strip()

    def reset_context(self):
        """Clear the GPU cache when a model is loaded; otherwise just log."""
        if self._model_loaded:
            self._clear_gpu_cache()
            self.logger.info("Model context reset")
        else:
            self.logger.info("Model not loaded, no context to reset")

    def get_current_device(self) -> str:
        """
        Return the device selected at construction time.

        Returns:
            str: The device name.
        """
        return self.device

    def is_model_loaded(self) -> bool:
        """
        Report whether the model has been loaded.

        Returns:
            bool: True once ``_load_model`` has completed successfully.
        """
        return self._model_loaded

    def get_call_count(self) -> int:
        """
        Return how many times ``generate_response`` has started generating.

        Returns:
            int: The call count.
        """
        return self.call_count

    def get_model_info(self) -> Dict[str, Any]:
        """
        Summarise the manager state.

        Returns:
            Dict[str, Any]: Model path, device, load state, call count and
            whether a HuggingFace token is available.
        """
        return {
            "model_path": self.model_path,
            "device": self.device,
            "is_loaded": self._model_loaded,
            "call_count": self.call_count,
            "has_hf_token": self.hf_token is not None
        }
|
requirements.txt
CHANGED
@@ -6,7 +6,7 @@ pillow>=9.4.0
|
|
6 |
numpy>=1.23.5
|
7 |
matplotlib>=3.7.0
|
8 |
gradio>=3.32.0
|
9 |
-
|
10 |
yt-dlp>=2023.3.4
|
11 |
requests>=2.28.1
|
12 |
transformers
|
@@ -14,4 +14,4 @@ accelerate
|
|
14 |
bitsandbytes
|
15 |
sentencepiece
|
16 |
huggingface_hub>=0.19.0
|
17 |
-
urllib3>=1.26.0
|
|
|
6 |
numpy>=1.23.5
|
7 |
matplotlib>=3.7.0
|
8 |
gradio>=3.32.0
|
9 |
+
open-clip-torch>=2.20.0
|
10 |
yt-dlp>=2023.3.4
|
11 |
requests>=2.28.1
|
12 |
transformers
|
|
|
14 |
bitsandbytes
|
15 |
sentencepiece
|
16 |
huggingface_hub>=0.19.0
|
17 |
+
urllib3>=1.26.0
|
scene_scoring_engine.py
CHANGED
@@ -249,13 +249,13 @@ class SceneScoringEngine:
|
|
249 |
Returns:
|
250 |
(最佳場景類型, 置信度) 的元組
|
251 |
"""
|
|
|
252 |
if not scene_scores:
|
253 |
return "unknown", 0.0
|
254 |
|
255 |
-
# 檢查地標相關分數是否達到門檻,如果是,直接回傳 "tourist_landmark"
|
256 |
# 假設場景分數 dictionary 中,"tourist_landmark"、"historical_monument"、"natural_landmark" 三個 key
|
257 |
# 分別代表不同類型地標。將它們加總,若總分超過 0.3,就認定為地標場景。
|
258 |
-
# print(f"DEBUG: determine_scene_type input scores: {scene_scores}")
|
259 |
landmark_score = (
|
260 |
scene_scores.get("tourist_landmark", 0.0) +
|
261 |
scene_scores.get("historical_monument", 0.0) +
|
@@ -268,7 +268,7 @@ class SceneScoringEngine:
|
|
268 |
# 找分數最高的那個場景
|
269 |
best_scene = max(scene_scores, key=scene_scores.get)
|
270 |
best_score = scene_scores[best_scene]
|
271 |
-
|
272 |
return best_scene, float(best_score)
|
273 |
|
274 |
def fuse_scene_scores(self, yolo_scene_scores: Dict[str, float],
|
@@ -361,8 +361,9 @@ class SceneScoringEngine:
|
|
361 |
current_yolo_weight = default_yolo_weight
|
362 |
current_clip_weight = default_clip_weight
|
363 |
current_places365_weight = default_places365_weight
|
364 |
-
|
365 |
-
|
|
|
366 |
|
367 |
scene_definition = self.scene_types.get(scene_type, {})
|
368 |
|
@@ -394,8 +395,8 @@ class SceneScoringEngine:
|
|
394 |
"professional_kitchen", "cafe", "library", "gym", "retail_store",
|
395 |
"supermarket", "classroom", "conference_room", "medical_facility",
|
396 |
"educational_setting", "dining_area"]):
|
397 |
-
current_yolo_weight = 0.
|
398 |
-
current_clip_weight = 0.
|
399 |
current_places365_weight = 0.25
|
400 |
|
401 |
# 對於特定室外常見場景(非地標),物體仍然重要
|
@@ -491,7 +492,7 @@ class SceneScoringEngine:
|
|
491 |
fused_scores[scene_type] = min(1.0, max(0.0, fused_score))
|
492 |
|
493 |
return fused_scores
|
494 |
-
|
495 |
|
496 |
def update_enable_landmark_status(self, enable_landmark: bool):
|
497 |
"""
|
|
|
249 |
Returns:
|
250 |
(最佳場景類型, 置信度) 的元組
|
251 |
"""
|
252 |
+
print(f"DEBUG: determine_scene_type input scores: {scene_scores}")
|
253 |
if not scene_scores:
|
254 |
return "unknown", 0.0
|
255 |
|
256 |
+
# 檢查地標相關分數是否達到門檻,如果是,直接回傳 "tourist_landmark"
|
257 |
# 假設場景分數 dictionary 中,"tourist_landmark"、"historical_monument"、"natural_landmark" 三個 key
|
258 |
# 分別代表不同類型地標。將它們加總,若總分超過 0.3,就認定為地標場景。
|
|
|
259 |
landmark_score = (
|
260 |
scene_scores.get("tourist_landmark", 0.0) +
|
261 |
scene_scores.get("historical_monument", 0.0) +
|
|
|
268 |
# 找分數最高的那個場景
|
269 |
best_scene = max(scene_scores, key=scene_scores.get)
|
270 |
best_score = scene_scores[best_scene]
|
271 |
+
print(f"DEBUG: determine_scene_type result: scene={best_scene}, score={best_score}")
|
272 |
return best_scene, float(best_score)
|
273 |
|
274 |
def fuse_scene_scores(self, yolo_scene_scores: Dict[str, float],
|
|
|
361 |
current_yolo_weight = default_yolo_weight
|
362 |
current_clip_weight = default_clip_weight
|
363 |
current_places365_weight = default_places365_weight
|
364 |
+
print(f"DEBUG: Scene {scene_type} - yolo_score: {yolo_score}, clip_score: {clip_score}, places365_score: {places365_score}")
|
365 |
+
print(f"DEBUG: Scene {scene_type} - weights: yolo={current_yolo_weight:.3f}, clip={current_clip_weight:.3f}, places365={current_places365_weight:.3f}")
|
366 |
+
|
367 |
|
368 |
scene_definition = self.scene_types.get(scene_type, {})
|
369 |
|
|
|
395 |
"professional_kitchen", "cafe", "library", "gym", "retail_store",
|
396 |
"supermarket", "classroom", "conference_room", "medical_facility",
|
397 |
"educational_setting", "dining_area"]):
|
398 |
+
current_yolo_weight = 0.50
|
399 |
+
current_clip_weight = 0.25
|
400 |
current_places365_weight = 0.25
|
401 |
|
402 |
# 對於特定室外常見場景(非地標),物體仍然重要
|
|
|
492 |
fused_scores[scene_type] = min(1.0, max(0.0, fused_score))
|
493 |
|
494 |
return fused_scores
|
495 |
+
print(f"DEBUG: fuse_scene_scores final result: {fused_scores}")
|
496 |
|
497 |
def update_enable_landmark_status(self, enable_landmark: bool):
|
498 |
"""
|