import os
import re
import torch
import logging
import threading
from typing import Dict, Optional, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login

class ModelLoadingError(Exception):
    """Custom exception for model loading failures"""
    pass


class ModelGenerationError(Exception):
    """Custom exception for model generation failures"""
    pass


class LLMModelManager:
    """
    Handles loading of the LLM model, device management, and text generation.
    Manages the model, memory optimization, and device configuration.
    Implemented as a singleton so the model is loaded only once across the application.
    """
    
    _instance = None
    _initialized = False
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """
        Singleton implementation: ensures that only one LLMModelManager instance is created for the entire application.
        """
        if cls._instance is None:
            with cls._lock:
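                # Double-checked locking: re-test inside the lock so two threads
                # racing past the first check still create only one instance.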
                if cls._instance is None:
                    cls._instance = super(LLMModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self,
                 model_path: Optional[str] = None,
                 tokenizer_path: Optional[str] = None,
                 device: Optional[str] = None,
                 max_length: int = 2048,
                 temperature: float = 0.3,
                 top_p: float = 0.85):
        """
        Initialize the model manager (runs only on first instance creation).

        Args:
            model_path: Path to the LLM model or a HuggingFace model name; defaults to Llama 3.2
            tokenizer_path: Path to the tokenizer, usually the same as model_path
            device: Runtime device ('cpu' or 'cuda'); auto-detected when None
            max_length: Maximum length of the input text
            temperature: Sampling temperature for text generation
            top_p: Nucleus sampling probability threshold for text generation
        """
        # Avoid re-initialization
        if self._initialized:
            return
            
        with self._lock:
            if self._initialized:
                return
                
            # set logger
            self.logger = logging.getLogger(self.__class__.__name__)
            if not self.logger.handlers:
                handler = logging.StreamHandler()
                formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
                handler.setFormatter(formatter)
                self.logger.addHandler(handler)
                self.logger.setLevel(logging.INFO)

            # model config
            self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
            self.tokenizer_path = tokenizer_path or self.model_path

            # device management
            self.device = self._detect_device(device)
            self.logger.info(f"Device selected: {self.device}")

            # Generation parameters
            self.max_length = max_length
            self.temperature = temperature
            self.top_p = top_p

            # Model state
            self.model = None
            self.tokenizer = None
            self._model_loaded = False
            self.call_count = 0

            # HuggingFace authentication
            self.hf_token = self._setup_huggingface_auth()
            
            # Mark as initialized
            self._initialized = True
            self.logger.info("LLMModelManager singleton initialized")

    def _detect_device(self, device: Optional[str]) -> str:
        """
        Detect and set the runtime device.

        Args:
            device: User-specified device; auto-detected when None

        Returns:
            str: 'cuda' or 'cpu'
        """
        if device:
            if device == 'cuda' and not torch.cuda.is_available():
                self.logger.warning("CUDA requested but not available, falling back to CPU")
                return 'cpu'
            return device

        detected_device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if detected_device == 'cuda':
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            self.logger.info(f"CUDA detected with {gpu_memory:.2f} GB GPU memory")

        return detected_device

    def _setup_huggingface_auth(self) -> Optional[str]:
        """
        Set up HuggingFace authentication.

        Returns:
            Optional[str]: HuggingFace token if available, otherwise None
        """
        hf_token = os.environ.get("HF_TOKEN")

        if hf_token:
            try:
                login(token=hf_token)
                self.logger.info("Successfully authenticated with HuggingFace")
                return hf_token
            except Exception as e:
                self.logger.error(f"HuggingFace authentication failed: {e}")
                return None
        else:
            self.logger.warning("HF_TOKEN not found. Access to gated models may be limited")
            return None

    def _load_model(self):
        """
        Load the LLM model and tokenizer, using 8-bit quantization to save memory.
        Reinforced state checks ensure the model is loaded only once.

        Raises:
            ModelLoadingError: When model loading fails
        """
        # Full model-state check
        if (self._model_loaded and 
            hasattr(self, 'model') and self.model is not None and
            hasattr(self, 'tokenizer') and self.tokenizer is not None):
            self.logger.info("Model already loaded, skipping reload")
            return

        try:
            self.logger.info(f"Loading model from {self.model_path} with 8-bit quantization")

            # Clear GPU memory
            self._clear_gpu_cache()

            # Configure 8-bit quantization
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True
            )
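            # Int8 weights roughly halve memory use versus fp16; modules that do
            # not fit on the GPU can be offloaded to the CPU in fp32.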

            # Load the tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_path,
                padding_side="left",
                use_fast=False,
                token=self.hf_token
            )

            # Set special tokens
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load the model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                quantization_config=quantization_config,
                device_map="auto",
                low_cpu_mem_usage=True,
                token=self.hf_token
            )

            self._model_loaded = True
            self.logger.info("Model loaded successfully (singleton instance)")

        except Exception as e:
            error_msg = f"Failed to load model: {str(e)}"
            self.logger.error(error_msg)
            raise ModelLoadingError(error_msg) from e

    def _clear_gpu_cache(self):
        """清理GPU記憶體緩存"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.logger.debug("GPU cache cleared")

    def generate_response(self, prompt: str, **generation_kwargs) -> str:
        """
        Generate a text response for the given prompt.

        Args:
            prompt: Input prompt text
            **generation_kwargs: Overrides for the default generation parameters

        Returns:
            str: The cleaned generated response

        Raises:
            ModelGenerationError: When generation fails or the output is too short or empty
        """
        # Ensure the model is loaded
        if not self._model_loaded:
            self._load_model()

        try:
            self.call_count += 1
            self.logger.info(f"Generating response (call #{self.call_count})")

            # # record input prompt
            # self.logger.info(f"DEBUG: Input prompt length: {len(prompt)}")
            # self.logger.info(f"DEBUG: Input prompt preview: {prompt[:200]}...")

            # Clear GPU cache
            self._clear_gpu_cache()

            # Fix the random seed for more consistent outputs
            torch.manual_seed(42)

            # prepare input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length
            ).to(self.device)

            # Prepare generation parameters
            generation_params = self._prepare_generation_params(**generation_kwargs)
            generation_params.update({
                "pad_token_id": self.tokenizer.eos_token_id,
                "attention_mask": inputs.attention_mask,
                "use_cache": True,
            })

            # Generate the response
            with torch.no_grad():
                outputs = self.model.generate(inputs.input_ids, **generation_params)

            # Decode the response
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # # record whole response
            # self.logger.info(f"DEBUG: Full LLM response: {full_response}")

            response = self._extract_generated_response(full_response, prompt)

            # # Log the extracted response
            # self.logger.info(f"DEBUG: Extracted response: {response}")

            if not response or len(response.strip()) < 10:
                raise ModelGenerationError("Generated response is too short or empty")

            self.logger.info(f"Response generated successfully ({len(response)} characters)")
            return response

        except Exception as e:
            error_msg = f"Text generation failed: {str(e)}"
            self.logger.error(error_msg)
            raise ModelGenerationError(error_msg) from e

    def _prepare_generation_params(self, **kwargs) -> Dict[str, Any]:
        """
        Prepare generation parameters, with model-specific optimizations.

        Args:
            **kwargs: User-supplied generation parameters

        Returns:
            Dict[str, Any]: Complete generation-parameter configuration
        """
        # basic parameters
        params = {
            "max_new_tokens": 120,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "do_sample": True,
        }

        # Special optimizations for Llama models
        if "llama" in self.model_path.lower():
            params.update({
                "max_new_tokens": 600,
                "temperature": 0.35, # not too big
                "top_p": 0.75,
                "repetition_penalty": 1.5,
                "num_beams": 5,
                "length_penalty": 1,
                "no_repeat_ngram_size": 3
            })
        else:
            params.update({
                "max_new_tokens": 300,
                "temperature": 0.6,
                "top_p": 0.9,
                "num_beams": 1,
                "repetition_penalty": 1.05
            })

        # User-supplied parameters override the defaults
        params.update(kwargs)

        return params

    def _extract_generated_response(self, full_response: str, prompt: str) -> str:
        """
        Extract the generated portion from the full response.
        """
        # Look for the assistant tag
        assistant_tag = "<|assistant|>"
        if assistant_tag in full_response:
            response = full_response.split(assistant_tag)[-1].strip()

            # Check for an unclosed user tag
            user_tag = "<|user|>"
            if user_tag in response:
                response = response.split(user_tag)[0].strip()
        else:
            # Remove the input prompt
            if full_response.startswith(prompt):
                response = full_response[len(prompt):].strip()
            else:
                response = full_response.strip()

        # Remove unnatural scene-type prefixes
        response = self._remove_scene_type_prefixes(response)

        return response

    def _remove_scene_type_prefixes(self, response: str) -> str:
        """
        Remove scene-type prefixes from the LLM-generated response.

        Args:
            response: Raw LLM response

        Returns:
            str: Response with the prefix removed
        """
        if not response:
            return response

        prefix_patterns = [r'^[A-Za-z]+,\s*']
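        # e.g. a hypothetical response "Kitchen, the scene shows a sink ..." becomes
        # "the scene shows a sink ..." (re-capitalized below).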

        # Apply cleanup patterns
        for pattern in prefix_patterns:
            response = re.sub(pattern, '', response, flags=re.IGNORECASE)

        # Ensure the first letter is capitalized
        if response and response[0].islower():
            response = response[0].upper() + response[1:]

        return response.strip()

    def reset_context(self):
        """重置模型上下文,清理GPU緩存"""
        if self._model_loaded:
            self._clear_gpu_cache()
            self.logger.info("Model context reset (singleton instance)")
        else:
            self.logger.info("Model not loaded, no context to reset")

    def get_current_device(self) -> str:
        """
        Get the current runtime device.

        Returns:
            str: Current device name
        """
        return self.device

    def is_model_loaded(self) -> bool:
        """
        Check whether the model has been loaded.

        Returns:
            bool: Model loading status
        """
        return self._model_loaded

    def get_call_count(self) -> int:
        """
        Get the number of model calls.

        Returns:
            int: Call count
        """
        return self.call_count

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dict[str, Any]: Model path, device, loading status, and related details
        """
        return {
            "model_path": self.model_path,
            "device": self.device,
            "is_loaded": self._model_loaded,
            "call_count": self.call_count,
            "has_hf_token": self.hf_token is not None,
            "is_singleton": True
        }

    @classmethod
    def reset_singleton(cls):
        """
        Reset the singleton instance (for testing or application restarts only).
        Note: this means the model must be reloaded afterwards.
        """
        with cls._lock:
            if cls._instance is not None:
                instance = cls._instance
                if hasattr(instance, 'logger'):
                    instance.logger.info("Resetting singleton instance")
                cls._instance = None
                cls._initialized = False
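

# --- Minimal usage sketch ---
# A hedged example of how this singleton might be exercised: it assumes enough
# GPU/CPU memory for the 8-bit model, an HF_TOKEN environment variable for the
# gated Llama weights, and a purely hypothetical prompt.
if __name__ == "__main__":
    manager_a = LLMModelManager()
    manager_b = LLMModelManager(temperature=0.7)  # ignored: singleton already initialized

    # Both names refer to the same instance.
    assert manager_a is manager_b

    print(manager_a.get_model_info())

    # The first generate_response() call lazily loads the model.
    reply = manager_a.generate_response(
        "Describe the scene in one short paragraph: a busy street market at dusk."
    )
    print(reply)
    print(f"Total model calls: {manager_a.get_call_count()}")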