# modules/ai_model.py import torch import base64 import requests from io import BytesIO import os from huggingface_hub import login from PIL import Image from transformers import AutoProcessor, Gemma3nForConditionalGeneration from utils.logger import log from typing import Union, Tuple class AIModel: def __init__(self, model_name: str = "google/gemma-3n-e2b-it"): self.model_name = model_name self.model = None self.processor = None # 设置缓存目录 self._setup_cache_dirs() self._initialize_model() def _setup_cache_dirs(self): """设置缓存目录""" cache_dir = "/app/.cache/huggingface" os.makedirs(cache_dir, exist_ok=True) # 设置环境变量 os.environ["HF_HOME"] = cache_dir os.environ["TRANSFORMERS_CACHE"] = cache_dir os.environ["HF_DATASETS_CACHE"] = cache_dir log.info(f"设置缓存目录: {cache_dir}") def _authenticate_hf(self): assitant_token = os.getenv("Assitant_tocken") token_to_use = assitant_token cache_dir = "/app/.cache/huggingface" login(token=token_to_use, add_to_git_credential=False) log.info("✅ HuggingFace 认证成功") return token_to_use def _initialize_model(self): """初始化Gemma模型""" try: log.info(f"正在加载模型: {self.model_name}") token = self._authenticate_hf() if not token: log.error("❌ 无法获取有效token,模型加载失败") self.model = None self.processor = None return cache_dir = "/app/.cache/huggingface" self.model = Gemma3nForConditionalGeneration.from_pretrained( self.model_name, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, token=token, cache_dir=cache_dir ).eval() self.processor = AutoProcessor.from_pretrained( self.model_name, trust_remote_code=True, token=token, cache_dir=cache_dir ) log.info("✅ Gemma AI 模型初始化成功") except Exception as e: log.error(f"❌ Gemma AI 模型初始化失败: {e}", exc_info=True) self.model = None self.processor = None def is_available(self) -> bool: return self.model is not None and self.processor is not None def detect_input_type(self, input_data: str) -> str: if not isinstance(input_data, str): return "text" image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"] if (input_data.startswith(("http://", "https://")) and any(input_data.lower().endswith(ext) for ext in image_extensions)): return "image" elif any(input_data.endswith(ext) for ext in image_extensions): return "image" elif input_data.startswith("data:image/"): return "image" audio_extensions = [".wav", ".mp3", ".m4a", ".ogg", ".flac"] if (input_data.startswith(("http://", "https://")) and any(input_data.lower().endswith(ext) for ext in audio_extensions)): return "audio" elif any(input_data.endswith(ext) for ext in audio_extensions): return "audio" return "text" def transcribe_audio(self, audio_path: str) -> str: """ 使用 Hugging Face Inference API 将音频文件转写为文本。 - 通过环境变量加载 HF_TOKEN 保证安全。 - 包含网络请求超时和状态码检查,增强健壮性。 """ # 1. 从环境变量安全地获取 Token hf_token = os.getenv("Assitant_tocken") API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large" # 建议使用更新的 v3 版本 headers = {"Authorization": f"Bearer {hf_token}"} # 2. 检查音频文件是否存在 if not os.path.exists(audio_path): log.error(f"❌ 音频文件不存在: {audio_path}") raise FileNotFoundError(f"指定的音频文件路径不存在: {audio_path}") try: with open(audio_path, "rb") as f: # 3. 发送请求,并设置较长的超时时间 (例如 60 秒) log.info(f"🎤 正在向 HF API 发送音频数据... (超时设置为60秒)") response = requests.post(API_URL, headers=headers, data=f, timeout=60) # 4. 检查 HTTP 响应状态码,主动抛出错误 response.raise_for_status() # 如果状态码不是 2xx,则会引发 HTTPError result = response.json() log.info("✅ HF API 响应成功。") # 5. 可靠地提取结果或处理错误信息 if "text" in result: return result["text"].strip() else: error_message = result.get("error", "未知的 API 错误结构。") log.error(f"❌ 转录失败,API 返回: {error_message}") # 如果模型正在加载,HuggingFace 会在 error 字段中提示 if isinstance(error_message, dict) and "estimated_time" in error_message: raise RuntimeError(f"模型正在加载中,请稍后重试。预计等待时间: {error_message['estimated_time']:.1f}秒") raise RuntimeError(f"转录失败: {error_message}") except requests.exceptions.Timeout: log.error("❌ 请求超时!API 未在60秒内响应。") raise RuntimeError("语音识别服务请求超时,请稍后再试。") except requests.exceptions.RequestException as e: log.error(f"❌ 网络请求失败: {e}") raise RuntimeError(f"无法连接到语音识别服务: {e}") except Exception as e: # 捕获其他所有可能的异常,例如文件读取错误、JSON解码错误等 log.error(f"❌ 处理音频时发生未知错误: {e}", exc_info=True) raise e def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[str, Image.Image, None]]: if input_type == "image": try: if raw_input.startswith("data:image/"): header, encoded = raw_input.split(",", 1) image_data = base64.b64decode(encoded) image = Image.open(BytesIO(image_data)).convert("RGB") elif raw_input.startswith(("http://", "https://")): response = requests.get(raw_input, timeout=10) response.raise_for_status() image = Image.open(BytesIO(response.content)).convert("RGB") else: image = Image.open(raw_input).convert("RGB") log.info("✅ 图片加载成功") return input_type, image, "请描述这张图片,并基于图片内容提供旅游建议。" except Exception as e: log.error(f"❌ 图片加载失败: {e}") return "text", None, f"图片加载失败,请检查路径或URL。" elif input_type == "audio": try: # --- 音频处理核心 --- # 假设: 您的类中有一个方法 `transcribe_audio` 用于语音转文字。 # 您需要自行实现这个方法, 例如通过调用 Whisper, FunASR 或其他 ASR 服务。 # 它接收音频文件路径 (raw_input) 并返回转写的文本字符串。 log.info(f"🎤 开始处理音频文件: {raw_input}") transcribed_text = self.transcribe_audio(raw_input) log.info(f"✅ 音频转写成功: '{transcribed_text[:50]}...'") # 注意:处理成功后,我们将 input_type 转为 "text", # 因为音频内容已变为文本,后续流程可以统一处理。 return "text", None, transcribed_text except Exception as e: log.error(f"❌ 音频处理失败: {e}", exc_info=True) return "text", None, f"音频处理失败,请检查文件或稍后再试。" else: # text return input_type, None, raw_input def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str,temperature: float = 0.5) -> str: try: inputs = self.processor( text=prompt, return_tensors="pt" ).to(self.model.device, dtype=torch.bfloat16) with torch.inference_mode(): generation_args = { "max_new_tokens": 1024, "pad_token_id": self.processor.tokenizer.eos_token_id, "use_cache": True } # 如果 temperature 接近0,使用贪心解码 (用于分类等确定性任务) if temperature < 1e-6: log.info("▶️ 使用贪心解码 (do_sample=False) 以获得确定性输出。") generation_args["do_sample"] = False # 否则,使用采样解码 (用于创造性生成任务) else: log.info(f"▶️ 使用采样解码 (do_sample=True),temperature={temperature}。") generation_args["do_sample"] = True generation_args["temperature"] = temperature generation_args["top_p"] = 0.9 # top_p 只在采样时有意义 # 使用构建好的参数字典来调用 generate outputs = self.model.generate( **inputs, **generation_args ) input_length = inputs.input_ids.shape[-1] generated_tokens = outputs[0][input_length:] decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip() return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。" except RuntimeError as e: if "shape" in str(e): log.error(f"❌ Tensor形状错误: {e}") return "输入处理遇到问题,请尝试简化您的问题。" raise e except Exception as e: log.error(f"❌ 模型推理失败: {e}", exc_info=True) return "抱歉,处理您的请求时遇到技术问题。" def chat_completion(self, model: str, messages: list, **kwargs) -> str: if not self.is_available(): log.error("模型未就绪,无法执行 chat_completion") if kwargs.get("response_format", {}).get("type") == "json_object": return '{"error": "Model not available"}' return "抱歉,AI 模型当前不可用。" full_prompt = "\n".join([msg.get("content", "") for msg in messages]) temperature = kwargs.get("temperature", 0.6) if kwargs.get("response_format", {}).get("type") == "json_object": # 在 prompt 末尾添加指令,强制模型输出 JSON full_prompt += "\n\n请注意:你的回答必须是一个严格的、不含任何额外解释和代码块标记的 JSON 对象。" # 对于JSON生成任务,使用较低的 temperature 以获得更稳定、确定性的结构 temperature = 0.1 log.debug(f"▶️ 执行 chat_completion (适配器), temperature={temperature}, prompt='{full_prompt[:100]}...'") return self.run_inference( input_type="text", formatted_input=None, prompt=full_prompt, temperature=temperature # 将处理后的 temperature 传递下去 ) def _build_prompt(self, processed_text: str, context: str = "") -> str: if context: return ( f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n" f"--- 背景信息 ---\n{context}\n\n" f"--- 用户问题 ---\n{processed_text}\n\n" f"请提供专业、实用的旅游建议:" ) else: return ( f"你是一个专业的旅游助手。请用中文友好地回答用户的问题。\n\n" f"用户问题:{processed_text}\n\n" f"请提供专业、实用的旅游建议:" ) def generate(self, user_input: str, context: str = "") -> str: """主要的生成方法 - 保持原有逻辑""" if not self.is_available(): return "抱歉,AI 模型当前不可用,请稍后再试。" try: # 1. 检测输入类型 input_type = self.detect_input_type(user_input) log.info(f"检测到输入类型: {input_type}") # 2. 格式化输入 input_type, formatted_data, processed_text = self.format_input(input_type, user_input) # 3. 构建prompt - 使用你的原有结构 prompt = self._build_prompt(processed_text, context) # 4. 执行推理 return self.run_inference("text", formatted_data, prompt) except Exception as e: log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True) return "抱歉,我在思考时遇到了点麻烦,请稍后再试。"