# VisionScout / llm_model_manager.py
# Uploaded by DawnC ("Upload llm_model_manager.py")
# Commit d217fb0 (verified)
import os
import re
import torch
import logging
import threading
from typing import Dict, Optional, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login
class ModelLoadingError(Exception):
    """Raised when the model or its tokenizer cannot be loaded."""
class ModelGenerationError(Exception):
    """Raised when text generation fails or yields an unusable response."""
class LLMModelManager:
    """
    Load an LLM, manage its device placement, and generate text with it.

    The class is a process-wide singleton: ``__new__`` always hands back the
    same instance and ``__init__`` only configures it the first time, so the
    constructor arguments of any later call are ignored. The model itself is
    loaded lazily, on the first ``generate_response`` call.
    """

    _instance = None
    _initialized = False
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        # Check, take the lock, then check again: the lock is only contended
        # while no instance exists yet.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(LLMModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self,
                 model_path: Optional[str] = None,
                 tokenizer_path: Optional[str] = None,
                 device: Optional[str] = None,
                 max_length: int = 2048,
                 temperature: float = 0.3,
                 top_p: float = 0.85):
        """
        Configure the manager; only the first construction has any effect.

        Args:
            model_path: Local path or HuggingFace model id. Defaults to
                "meta-llama/Llama-3.2-3B-Instruct".
            tokenizer_path: Tokenizer path or id; defaults to ``model_path``.
            device: Device to run on ('cpu' or 'cuda'); auto-detected when None.
            max_length: Maximum number of prompt tokens; longer prompts are
                truncated by the tokenizer.
            temperature: Sampling temperature stored on the instance.
            top_p: Nucleus-sampling threshold stored on the instance.

        NOTE(review): ``_prepare_generation_params`` replaces ``temperature``
        and ``top_p`` with model-specific presets on every call, so the values
        given here never reach ``generate``. Pass per-call overrides to
        ``generate_response`` instead.
        """
        # Skip repeated initialisation of the shared instance
        if self._initialized:
            return

        with self._lock:
            # Another thread may have finished initialising while we waited
            if self._initialized:
                return

            # Logger (a stream handler is attached only if none exists yet)
            self.logger = logging.getLogger(self.__class__.__name__)
            if not self.logger.handlers:
                handler = logging.StreamHandler()
                formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
                handler.setFormatter(formatter)
                self.logger.addHandler(handler)
                self.logger.setLevel(logging.INFO)

            # Model configuration
            self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
            self.tokenizer_path = tokenizer_path or self.model_path

            # Device management
            self.device = self._detect_device(device)
            self.logger.info(f"Device selected: {self.device}")

            # Generation parameters
            self.max_length = max_length
            self.temperature = temperature
            self.top_p = top_p

            # Model state (the model is loaded lazily)
            self.model = None
            self.tokenizer = None
            self._model_loaded = False
            self.call_count = 0

            # HuggingFace authentication
            self.hf_token = self._setup_huggingface_auth()

            # Mark this instance as initialised
            self._initialized = True
            self.logger.info("LLMModelManager singleton initialized")

    def _detect_device(self, device: Optional[str]) -> str:
        """
        Resolve the device to run on.

        Args:
            device: Device requested by the caller; None triggers auto-detection.

        Returns:
            The requested device, 'cpu' when any CUDA device was requested but
            CUDA is unavailable, or the auto-detected 'cuda' / 'cpu'.
        """
        if device:
            # Match every CUDA spec ('cuda', 'cuda:0', ...), not just the bare
            # string, so an indexed request also falls back cleanly.
            wants_cuda = str(device).startswith('cuda')
            if wants_cuda and not torch.cuda.is_available():
                self.logger.warning("CUDA requested but not available, falling back to CPU")
                return 'cpu'
            return device

        detected_device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if detected_device == 'cuda':
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            self.logger.info(f"CUDA detected with {gpu_memory:.2f} GB GPU memory")

        return detected_device

    def _setup_huggingface_auth(self) -> Optional[str]:
        """
        Log in to HuggingFace with the token from the HF_TOKEN environment variable.

        Returns:
            The token when login succeeds; None when the variable is unset or
            the login call raises.
        """
        hf_token = os.environ.get("HF_TOKEN")

        if not hf_token:
            self.logger.warning("HF_TOKEN not found. Access to gated models may be limited")
            return None

        try:
            login(token=hf_token)
            self.logger.info("Successfully authenticated with HuggingFace")
            return hf_token
        except Exception as e:
            # Best effort: a failed login is logged and treated as "no token"
            self.logger.error(f"HuggingFace authentication failed: {e}")
            return None

    def _load_model(self):
        """
        Load the tokenizer and the model with 8-bit quantization.

        Does nothing when both are already loaded.

        Raises:
            ModelLoadingError: If loading the tokenizer or the model fails.
        """
        # Full state check: the flag alone is not trusted
        already_loaded = (
            self._model_loaded
            and getattr(self, 'model', None) is not None
            and getattr(self, 'tokenizer', None) is not None
        )
        if already_loaded:
            self.logger.info("Model already loaded, skipping reload")
            return

        try:
            self.logger.info(f"Loading model from {self.model_path} with 8-bit quantization")

            # Free cached GPU memory before the large allocation
            self._clear_gpu_cache()

            # 8-bit quantization, with fp32 CPU offload enabled
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True
            )

            # Tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_path,
                padding_side="left",
                use_fast=False,
                token=self.hf_token
            )

            # Fall back to the EOS token when the tokenizer defines no pad token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                quantization_config=quantization_config,
                device_map="auto",
                low_cpu_mem_usage=True,
                token=self.hf_token
            )

            self._model_loaded = True
            self.logger.info("Model loaded successfully (singleton instance)")

        except Exception as e:
            error_msg = f"Failed to load model: {str(e)}"
            self.logger.error(error_msg)
            raise ModelLoadingError(error_msg) from e

    def _clear_gpu_cache(self):
        """Empty the CUDA cache when CUDA is available."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.logger.debug("GPU cache cleared")

    def generate_response(self, prompt: str, **generation_kwargs) -> str:
        """
        Generate a text response for ``prompt``, loading the model on first use.

        Args:
            prompt: Input text; it is truncated to ``max_length`` tokens.
            **generation_kwargs: Overrides for the default generation
                parameters (see ``_prepare_generation_params``).

        Returns:
            The generated text with the prompt and any leading one-word
            "Word, " prefix removed.

        Raises:
            ModelLoadingError: If the model has to be loaded and loading fails.
            ModelGenerationError: If generation fails or the extracted
                response is shorter than 10 characters.
        """
        # Make sure the model is loaded; loading errors propagate unchanged
        if not self._model_loaded:
            self._load_model()

        try:
            self.call_count += 1
            self.logger.info(f"Generating response (call #{self.call_count})")

            # Free cached GPU memory before generating
            self._clear_gpu_cache()

            # Fixed seed so repeated calls sample more consistently
            torch.manual_seed(42)

            # Tokenize the prompt and move the tensors to the selected device
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length
            ).to(self.device)

            # Defaults plus caller overrides; the entries below always win
            generation_params = self._prepare_generation_params(**generation_kwargs)
            generation_params.update({
                "pad_token_id": self.tokenizer.eos_token_id,
                "attention_mask": inputs.attention_mask,
                "use_cache": True,
            })

            with torch.no_grad():
                outputs = self.model.generate(inputs.input_ids, **generation_params)

            # Decode, then cut the prompt and role tags away
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = self._extract_generated_response(full_response, prompt)

            if not response or len(response.strip()) < 10:
                raise ModelGenerationError("Generated response is too short or empty")

            self.logger.info(f"Response generated successfully ({len(response)} characters)")
            return response

        except Exception as e:
            # Every failure, including the too-short check above, surfaces as
            # ModelGenerationError with the original exception chained.
            error_msg = f"Text generation failed: {str(e)}"
            self.logger.error(error_msg)
            raise ModelGenerationError(error_msg) from e

    def _prepare_generation_params(self, **kwargs) -> Dict[str, Any]:
        """
        Build the keyword arguments for ``model.generate``.

        Args:
            **kwargs: Caller-supplied generation parameters; they override
                the presets.

        Returns:
            The complete generation parameter dict.

        NOTE(review): both preset branches overwrite ``temperature`` and
        ``top_p``, so the instance-level values in the base dict never
        survive; only ``do_sample`` does.
        """
        # Base parameters
        params = {
            "max_new_tokens": 120,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "do_sample": True,
        }

        # Llama models get their own preset; every other model the generic one
        if "llama" in self.model_path.lower():
            params.update({
                "max_new_tokens": 600,
                "temperature": 0.35,  # kept low on purpose
                "top_p": 0.75,
                "repetition_penalty": 1.5,
                "num_beams": 5,
                "length_penalty": 1,
                "no_repeat_ngram_size": 3
            })
        else:
            params.update({
                "max_new_tokens": 300,
                "temperature": 0.6,
                "top_p": 0.9,
                "num_beams": 1,
                "repetition_penalty": 1.05
            })

        # Caller-supplied values override the presets
        params.update(kwargs)
        return params

    def _extract_generated_response(self, full_response: str, prompt: str) -> str:
        """
        Cut the newly generated part out of the decoded model output.

        When the output contains an ``<|assistant|>`` tag, the text after the
        last one is used, cut at the first ``<|user|>`` tag if present.
        Otherwise the prompt is removed when the output starts with it.

        Args:
            full_response: Decoded model output.
            prompt: The prompt that was sent to the model.

        Returns:
            The extracted response after scene-type prefix cleanup.
        """
        assistant_tag = "<|assistant|>"

        if assistant_tag in full_response:
            response = full_response.split(assistant_tag)[-1].strip()

            # Drop anything from a following user turn onwards
            user_tag = "<|user|>"
            if user_tag in response:
                response = response.split(user_tag)[0].strip()
        else:
            # No role tags: remove the echoed prompt when it leads the output
            if full_response.startswith(prompt):
                response = full_response[len(prompt):].strip()
            else:
                response = full_response.strip()

        # Remove unnatural scene-type prefixes
        response = self._remove_scene_type_prefixes(response)
        return response

    def _remove_scene_type_prefixes(self, response: str) -> str:
        """
        Remove a leading one-word "Word, " prefix and capitalise the result.

        Args:
            response: Raw LLM response.

        Returns:
            The response without surrounding whitespace or the prefix, with
            its first letter upper-cased.
        """
        if not response:
            return response

        # Strip first: leading whitespace would otherwise defeat both the
        # '^'-anchored pattern and the capitalisation below.
        response = response.strip()

        prefix_patterns = [r'^[A-Za-z]+\,\s*']

        # Apply the cleanup patterns
        for pattern in prefix_patterns:
            response = re.sub(pattern, '', response, flags=re.IGNORECASE)

        # Make sure the first letter is upper case
        if response and response[0].islower():
            response = response[0].upper() + response[1:]

        return response.strip()

    def reset_context(self):
        """Clear the GPU cache when a model is loaded; otherwise only log."""
        if self._model_loaded:
            self._clear_gpu_cache()
            self.logger.info("Model context reset (singleton instance)")
        else:
            self.logger.info("Model not loaded, no context to reset")

    def get_current_device(self) -> str:
        """
        Return the device selected at initialisation.

        Returns:
            The current device name.
        """
        return self.device

    def is_model_loaded(self) -> bool:
        """
        Report whether the model has been loaded.

        Returns:
            The model load state.
        """
        return self._model_loaded

    def get_call_count(self) -> int:
        """
        Return how many times ``generate_response`` started generating.

        Returns:
            The call count.
        """
        return self.call_count

    def get_model_info(self) -> Dict[str, Any]:
        """
        Summarise the manager state.

        Returns:
            A dict with the model path, device, load state, call count,
            whether a HuggingFace token is held, and the singleton marker.
        """
        return {
            "model_path": self.model_path,
            "device": self.device,
            "is_loaded": self._model_loaded,
            "call_count": self.call_count,
            "has_hf_token": self.hf_token is not None,
            "is_singleton": True
        }

    @classmethod
    def reset_singleton(cls):
        """
        Drop the shared instance (intended for tests or application restart).

        The next construction creates and initialises a new instance, which
        has to load the model again.
        """
        with cls._lock:
            if cls._instance is not None:
                instance = cls._instance
                if hasattr(instance, 'logger'):
                    instance.logger.info("Resetting singleton instance")
            cls._instance = None
            cls._initialized = False