DawnC committed on
Commit 3f1e204 · verified · 1 Parent(s): 01d337c

Delete model_manager.py

Files changed (1)
  1. model_manager.py +0 -358
model_manager.py DELETED
@@ -1,358 +0,0 @@
- import os
- import torch
- import logging
- from typing import Dict, Optional, Any
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
- from huggingface_hub import login
-
- class ModelLoadingError(Exception):
-     """Custom exception for model loading failures"""
-     pass
-
-
- class ModelGenerationError(Exception):
-     """Custom exception for model generation failures"""
-     pass
-
-
- class ModelManager:
-     """
-     Handles loading the LLM, device management, and text generation.
-     Manages the model, memory optimization, and device configuration.
-     """
-
-     def __init__(self,
-                  model_path: Optional[str] = None,
-                  tokenizer_path: Optional[str] = None,
-                  device: Optional[str] = None,
-                  max_length: int = 2048,
-                  temperature: float = 0.3,
-                  top_p: float = 0.85):
-         """
-         Initialize the model manager.
-
-         Args:
-             model_path: Path or HuggingFace name of the LLM; defaults to Llama 3.2
-             tokenizer_path: Path to the tokenizer, usually the same as model_path
-             device: Device to run on ('cpu' or 'cuda'); auto-detected when None
-             max_length: Maximum length of the input text
-             temperature: Sampling temperature for text generation
-             top_p: Nucleus-sampling probability threshold for text generation
-         """
-         # Set up a dedicated logger
-         self.logger = logging.getLogger(self.__class__.__name__)
-         if not self.logger.handlers:
-             handler = logging.StreamHandler()
-             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-             handler.setFormatter(formatter)
-             self.logger.addHandler(handler)
-             self.logger.setLevel(logging.INFO)
-
-         # Model configuration
-         self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
-         self.tokenizer_path = tokenizer_path or self.model_path
-
-         # Device management
-         self.device = self._detect_device(device)
-         self.logger.info(f"Device selected: {self.device}")
-
-         # Generation parameters
-         self.max_length = max_length
-         self.temperature = temperature
-         self.top_p = top_p
-
-         # Model state
-         self.model = None
-         self.tokenizer = None
-         self._model_loaded = False
-         self.call_count = 0
-
-         # HuggingFace authentication
-         self.hf_token = self._setup_huggingface_auth()
-
-     def _detect_device(self, device: Optional[str]) -> str:
-         """
-         Detect and set the device to run on.
-
-         Args:
-             device: User-specified device; auto-detected when None
-
-         Returns:
-             str: ('cuda' or 'cpu')
-         """
-         if device:
-             if device == 'cuda' and not torch.cuda.is_available():
-                 self.logger.warning("CUDA requested but not available, falling back to CPU")
-                 return 'cpu'
-             return device
-
-         detected_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-         if detected_device == 'cuda':
-             gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
-             self.logger.info(f"CUDA detected with {gpu_memory:.2f} GB GPU memory")
-
-         return detected_device
-
-     def _setup_huggingface_auth(self) -> Optional[str]:
-         """
-         Set up HuggingFace authentication.
-
-         Returns:
-             Optional[str]: HuggingFace token, if available
-         """
-         hf_token = os.environ.get("HF_TOKEN")
-
-         if hf_token:
-             try:
-                 login(token=hf_token)
-                 self.logger.info("Successfully authenticated with HuggingFace")
-                 return hf_token
-             except Exception as e:
-                 self.logger.error(f"HuggingFace authentication failed: {e}")
-                 return None
-         else:
-             self.logger.warning("HF_TOKEN not found. Access to gated models may be limited")
-             return None
-
-     def _load_model(self):
-         """
-         Load the LLM and tokenizer, using 8-bit quantization to save memory.
-
-         Raises:
-             ModelLoadingError: When model loading fails
-         """
-         if self._model_loaded:
-             return
-
-         try:
-             self.logger.info(f"Loading model from {self.model_path} with 8-bit quantization")
-
-             # Clear GPU memory
-             self._clear_gpu_cache()
-
-             # Set up the 8-bit quantization configuration
-             quantization_config = BitsAndBytesConfig(
-                 load_in_8bit=True,
-                 llm_int8_enable_fp32_cpu_offload=True
-             )
-
-             # Load the tokenizer
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 self.tokenizer_path,
-                 padding_side="left",
-                 use_fast=False,
-                 token=self.hf_token
-             )
-
-             # Set special tokens
-             if self.tokenizer.pad_token is None:
-                 self.tokenizer.pad_token = self.tokenizer.eos_token
-
-             # Load the model
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 self.model_path,
-                 quantization_config=quantization_config,
-                 device_map="auto",
-                 low_cpu_mem_usage=True,
-                 token=self.hf_token
-             )
-
-             self._model_loaded = True
-             self.logger.info("Model loaded successfully")
-
-         except Exception as e:
-             error_msg = f"Failed to load model: {str(e)}"
-             self.logger.error(error_msg)
-             raise ModelLoadingError(error_msg) from e
-
-     def _clear_gpu_cache(self):
-         """Clear the GPU memory cache."""
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-             self.logger.debug("GPU cache cleared")
-
-     def generate_response(self, prompt: str, **generation_kwargs) -> str:
-         """
-         Generate an LLM response.
-
-         Args:
-             prompt: Input prompt
-             **generation_kwargs: Additional generation parameters that override the defaults
-
-         Returns:
-             str: The generated response text
-
-         Raises:
-             ModelGenerationError: When generation fails
-         """
-         # Make sure the model is loaded
-         if not self._model_loaded:
-             self._load_model()
-
-         try:
-             self.call_count += 1
-             self.logger.info(f"Generating response (call #{self.call_count})")
-
-             # Clear the GPU cache
-             self._clear_gpu_cache()
-
-             # Fix the random seed for more consistent output
-             torch.manual_seed(42)
-
-             # Prepare the input
-             inputs = self.tokenizer(
-                 prompt,
-                 return_tensors="pt",
-                 truncation=True,
-                 max_length=self.max_length
-             ).to(self.device)
-
-             # Prepare the generation parameters
-             generation_params = self._prepare_generation_params(**generation_kwargs)
-             generation_params.update({
-                 "pad_token_id": self.tokenizer.eos_token_id,
-                 "attention_mask": inputs.attention_mask,
-                 "use_cache": True,
-             })
-
-             # Generate the response
-             with torch.no_grad():
-                 outputs = self.model.generate(inputs.input_ids, **generation_params)
-
-             # Decode the response
-             full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-             response = self._extract_generated_response(full_response, prompt)
-
-             if not response or len(response.strip()) < 10:
-                 raise ModelGenerationError("Generated response is too short or empty")
-
-             self.logger.info(f"Response generated successfully ({len(response)} characters)")
-             return response
-
-         except Exception as e:
-             error_msg = f"Text generation failed: {str(e)}"
-             self.logger.error(error_msg)
-             raise ModelGenerationError(error_msg) from e
-
-     def _prepare_generation_params(self, **kwargs) -> Dict[str, Any]:
-         """
-         Prepare generation parameters, with model-specific optimizations.
-
-         Args:
-             **kwargs: User-provided generation parameters
-
-         Returns:
-             Dict[str, Any]: The complete generation parameter configuration
-         """
-         # Basic parameters
-         params = {
-             "max_new_tokens": 120,
-             "temperature": self.temperature,
-             "top_p": self.top_p,
-             "do_sample": True,
-         }
-
-         # Llama-specific optimizations
-         if "llama" in self.model_path.lower():
-             params.update({
-                 "max_new_tokens": 600,
-                 "temperature": 0.35,  # keep the temperature moderate
-                 "top_p": 0.75,
-                 "repetition_penalty": 1.5,
-                 "num_beams": 5,
-                 "length_penalty": 1,
-                 "no_repeat_ngram_size": 3
-             })
-         else:
-             params.update({
-                 "max_new_tokens": 300,
-                 "temperature": 0.6,
-                 "top_p": 0.9,
-                 "num_beams": 1,
-                 "repetition_penalty": 1.05
-             })
-
-         # User parameters override the defaults
-         params.update(kwargs)
-
-         return params
-
-     def _extract_generated_response(self, full_response: str, prompt: str) -> str:
-         """
-         Extract the generated portion from the full response.
-
-         Args:
-             full_response: The model's complete output
-             prompt: The original prompt
-
-         Returns:
-             str: The extracted generated response
-         """
-         # Look for the assistant tag
-         assistant_tag = "<|assistant|>"
-         if assistant_tag in full_response:
-             response = full_response.split(assistant_tag)[-1].strip()
-
-             # Check for an unclosed user tag
-             user_tag = "<|user|>"
-             if user_tag in response:
-                 response = response.split(user_tag)[0].strip()
-
-             return response
-
-         # Remove the input prompt
-         if full_response.startswith(prompt):
-             return full_response[len(prompt):].strip()
-
-         return full_response.strip()
-
-     def reset_context(self):
-         """Reset the model context and clear the GPU cache."""
-         if self._model_loaded:
-             self._clear_gpu_cache()
-             self.logger.info("Model context reset")
-         else:
-             self.logger.info("Model not loaded, no context to reset")
-
-     def get_current_device(self) -> str:
-         """
-         Get the device currently in use.
-
-         Returns:
-             str: The current device name
-         """
-         return self.device
-
-     def is_model_loaded(self) -> bool:
-         """
-         Check whether the model has been loaded.
-
-         Returns:
-             bool: Model loading status
-         """
-         return self._model_loaded
-
-     def get_call_count(self) -> int:
-         """
-         Get the number of times the model has been called.
-
-         Returns:
-             int: The call count
-         """
-         return self.call_count
-
-     def get_model_info(self) -> Dict[str, Any]:
-         """
-         Get model information.
-
-         Returns:
-             Dict[str, Any]: Model path, device, loading status, and related information
-         """
-         return {
-             "model_path": self.model_path,
-             "device": self.device,
-             "is_loaded": self._model_loaded,
-             "call_count": self.call_count,
-             "has_hf_token": self.hf_token is not None
-         }
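
For reference, a minimal usage sketch of the class this commit removes. It is not part of the commit; the prompt text and the HF_TOKEN placeholder are assumptions, and only names defined in the deleted file (ModelManager, generate_response, get_model_info, get_call_count) are used.

import os
from model_manager import ModelManager  # assumes the file as it existed before this deletion

os.environ.setdefault("HF_TOKEN", "<your-token>")  # assumption: token with access to the gated Llama repo

manager = ModelManager()               # defaults to meta-llama/Llama-3.2-3B-Instruct
print(manager.get_model_info())        # "is_loaded" is False here: loading is lazy

reply = manager.generate_response("Describe a Labrador Retriever in two sentences.")
print(reply)
print(manager.get_call_count())        # 1 after the first successful call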