File size: 13,909 Bytes
af60cba
 
 
 
 
a42280a
 
af60cba
 
 
 
 
 
 
 
 
 
a4da6d3
 
 
af60cba
 
8bacbbf
 
 
 
f30e96e
8bacbbf
 
 
 
f30e96e
8bacbbf
 
 
82e8be7
 
 
 
 
 
 
3154fce
82e8be7
 
a42280a
af60cba
a4da6d3
af60cba
 
a4da6d3
 
 
 
 
 
 
 
 
8bacbbf
 
af60cba
 
 
 
a4da6d3
 
2bef76a
af60cba
8bacbbf
af60cba
 
a42280a
a4da6d3
2bef76a
af60cba
8bacbbf
af60cba
8bacbbf
af60cba
 
 
 
 
 
6c0d50f
af60cba
 
 
6c0d50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af60cba
794c23a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af60cba
 
6c0d50f
af60cba
 
 
 
 
 
 
 
 
 
 
 
6c0d50f
 
af60cba
6c0d50f
 
af60cba
 
6c0d50f
 
af60cba
794c23a
 
 
 
 
 
 
 
 
 
 
 
af60cba
794c23a
 
 
af60cba
6c0d50f
 
af60cba
7324283
6c0d50f
af60cba
794c23a
 
 
 
ce08446
af60cba
96512ae
794c23a
96512ae
 
 
 
 
8d69a10
96512ae
 
 
 
 
 
 
 
 
 
af60cba
6c0d50f
96512ae
af60cba
ce08446
 
8d69a10
7324283
6c0d50f
8d69a10
6c0d50f
 
 
 
 
af60cba
 
6c0d50f
3589840
 
 
 
 
 
 
 
 
 
 
794c23a
3589840
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c0d50f
794c23a
 
6c0d50f
 
 
 
 
 
 
 
 
 
 
 
 
af60cba
 
6c0d50f
af60cba
 
6c0d50f
af60cba
 
 
 
6c0d50f
af60cba
 
6c0d50f
 
794c23a
6c0d50f
af60cba
794c23a
6c0d50f
af60cba
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
# modules/ai_model.py
import torch
import base64
import requests
from io import BytesIO
import os
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from utils.logger import log
from typing import Union, Tuple

class AIModel:
    def __init__(self, model_name: str = "google/gemma-3n-e2b-it"):
        self.model_name = model_name
        self.model = None
        self.processor = None
        
        # 设置缓存目录
        self._setup_cache_dirs()
        self._initialize_model()

    def _setup_cache_dirs(self):
        """设置缓存目录"""
        cache_dir = "/app/.cache/huggingface"
        os.makedirs(cache_dir, exist_ok=True)
        
        # 设置环境变量
        os.environ["HF_HOME"] = cache_dir
        os.environ["TRANSFORMERS_CACHE"] = cache_dir
        os.environ["HF_DATASETS_CACHE"] = cache_dir
        
        log.info(f"设置缓存目录: {cache_dir}")

    def _authenticate_hf(self):

        assitant_token = os.getenv("Assitant_tocken")
        token_to_use = assitant_token

        cache_dir = "/app/.cache/huggingface"
        login(token=token_to_use, add_to_git_credential=False)
        log.info("✅ HuggingFace 认证成功")
        return token_to_use



    def _initialize_model(self):
        """初始化Gemma模型"""
        try:
            log.info(f"正在加载模型: {self.model_name}")
            
            token = self._authenticate_hf()
            
            if not token:
                log.error("❌ 无法获取有效token,模型加载失败")
                self.model = None
                self.processor = None
                return
            
            cache_dir = "/app/.cache/huggingface"
            
            self.model = Gemma3nForConditionalGeneration.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                token=token,
                cache_dir=cache_dir
            ).eval()
            
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                token=token,
                cache_dir=cache_dir
            )
            
            log.info("✅ Gemma AI 模型初始化成功")
            
        except Exception as e:
            log.error(f"❌ Gemma AI 模型初始化失败: {e}", exc_info=True)
            self.model = None
            self.processor = None

    def is_available(self) -> bool:
        
        return self.model is not None and self.processor is not None

    def detect_input_type(self, input_data: str) -> str:

        if not isinstance(input_data, str):
            return "text"
    
        image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
        if (input_data.startswith(("http://", "https://")) and 
            any(input_data.lower().endswith(ext) for ext in image_extensions)):
            return "image"
        elif any(input_data.endswith(ext) for ext in image_extensions):
            return "image"
        elif input_data.startswith("data:image/"):
            return "image"

        audio_extensions = [".wav", ".mp3", ".m4a", ".ogg", ".flac"]
        if (input_data.startswith(("http://", "https://")) and 
            any(input_data.lower().endswith(ext) for ext in audio_extensions)):
            return "audio"
        elif any(input_data.endswith(ext) for ext in audio_extensions):
            return "audio"
    
        return "text"
    
    def transcribe_audio(self, audio_path: str) -> str:
        """
        使用 Hugging Face Inference API 将音频文件转写为文本。
        - 通过环境变量加载 HF_TOKEN 保证安全。
        - 包含网络请求超时和状态码检查,增强健壮性。
        """
        # 1. 从环境变量安全地获取 Token
        hf_token = os.getenv("Assitant_tocken")

        API_URL = "https://api-inference.huggingface.co/models/openai/whisper-large" # 建议使用更新的 v3 版本
        headers = {"Authorization": f"Bearer {hf_token}"}
    
        # 2. 检查音频文件是否存在
        if not os.path.exists(audio_path):
            log.error(f"❌ 音频文件不存在: {audio_path}")
            raise FileNotFoundError(f"指定的音频文件路径不存在: {audio_path}")

        try:
            with open(audio_path, "rb") as f:
                # 3. 发送请求,并设置较长的超时时间 (例如 60 秒)
                log.info(f"🎤 正在向 HF API 发送音频数据... (超时设置为60秒)")
                response = requests.post(API_URL, headers=headers, data=f, timeout=60)

            # 4. 检查 HTTP 响应状态码,主动抛出错误
            response.raise_for_status()  # 如果状态码不是 2xx,则会引发 HTTPError

            result = response.json()
            log.info("✅ HF API 响应成功。")

            # 5. 可靠地提取结果或处理错误信息
            if "text" in result:
                return result["text"].strip()
            else:
                error_message = result.get("error", "未知的 API 错误结构。")
                log.error(f"❌ 转录失败,API 返回: {error_message}")
                # 如果模型正在加载,HuggingFace 会在 error 字段中提示
                if isinstance(error_message, dict) and "estimated_time" in error_message:
                    raise RuntimeError(f"模型正在加载中,请稍后重试。预计等待时间: {error_message['estimated_time']:.1f}秒")
                raise RuntimeError(f"转录失败: {error_message}")

        except requests.exceptions.Timeout:
            log.error("❌ 请求超时!API 未在60秒内响应。")
            raise RuntimeError("语音识别服务请求超时,请稍后再试。")
        except requests.exceptions.RequestException as e:
            log.error(f"❌ 网络请求失败: {e}")
            raise RuntimeError(f"无法连接到语音识别服务: {e}")
        except Exception as e:
            # 捕获其他所有可能的异常,例如文件读取错误、JSON解码错误等
            log.error(f"❌ 处理音频时发生未知错误: {e}", exc_info=True)
            raise e


    def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[str, Image.Image, None]]:
    
        if input_type == "image":
            try:
                if raw_input.startswith("data:image/"):
                    header, encoded = raw_input.split(",", 1)
                    image_data = base64.b64decode(encoded)
                    image = Image.open(BytesIO(image_data)).convert("RGB")
                elif raw_input.startswith(("http://", "https://")):
                    response = requests.get(raw_input, timeout=10)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                else:
                
                    image = Image.open(raw_input).convert("RGB")
            
                log.info("✅ 图片加载成功")
                return input_type, image, "请描述这张图片,并基于图片内容提供旅游建议。"
            
            except Exception as e:
                log.error(f"❌ 图片加载失败: {e}")
                return "text", None, f"图片加载失败,请检查路径或URL。"
            
        elif input_type == "audio":
            try:
                # --- 音频处理核心 ---
                # 假设: 您的类中有一个方法 `transcribe_audio` 用于语音转文字。
                # 您需要自行实现这个方法, 例如通过调用 Whisper, FunASR 或其他 ASR 服务。
                # 它接收音频文件路径 (raw_input) 并返回转写的文本字符串。
                log.info(f"🎤 开始处理音频文件: {raw_input}")
                transcribed_text = self.transcribe_audio(raw_input) 
                log.info(f"✅ 音频转写成功: '{transcribed_text[:50]}...'")
            
                # 注意:处理成功后,我们将 input_type 转为 "text",
                # 因为音频内容已变为文本,后续流程可以统一处理。
                return "text", None, transcribed_text
            
            except Exception as e:
                log.error(f"❌ 音频处理失败: {e}", exc_info=True)
                return "text", None, f"音频处理失败,请检查文件或稍后再试。"
        
        else:  # text
            return input_type, None, raw_input

    def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str,temperature: float = 0.5) -> str:
   
        try:
            inputs = self.processor(
                text=prompt,
                return_tensors="pt"
            ).to(self.model.device, dtype=torch.bfloat16)
            
            with torch.inference_mode():
                generation_args = {
                    "max_new_tokens": 1024,
                    "pad_token_id": self.processor.tokenizer.eos_token_id,
                    "use_cache": True
                }

                # 如果 temperature 接近0,使用贪心解码 (用于分类等确定性任务)
                if temperature < 1e-6: 
                    log.info("▶️ 使用贪心解码 (do_sample=False) 以获得确定性输出。")
                    generation_args["do_sample"] = False
                # 否则,使用采样解码 (用于创造性生成任务)
                else:
                    log.info(f"▶️ 使用采样解码 (do_sample=True),temperature={temperature}。")
                    generation_args["do_sample"] = True
                    generation_args["temperature"] = temperature
                    generation_args["top_p"] = 0.9 # top_p 只在采样时有意义

                # 使用构建好的参数字典来调用 generate
                outputs = self.model.generate(
                    **inputs,
                    **generation_args
                )
            input_length = inputs.input_ids.shape[-1]
            generated_tokens = outputs[0][input_length:]
            decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"

        except RuntimeError as e:
            if "shape" in str(e):
                log.error(f"❌ Tensor形状错误: {e}")
                return "输入处理遇到问题,请尝试简化您的问题。"
            raise e
        except Exception as e:
            log.error(f"❌ 模型推理失败: {e}", exc_info=True)
            return "抱歉,处理您的请求时遇到技术问题。"
        
    def chat_completion(self, model: str, messages: list, **kwargs) -> str:
            
        if not self.is_available():
            log.error("模型未就绪,无法执行 chat_completion")
            if kwargs.get("response_format", {}).get("type") == "json_object":
                return '{"error": "Model not available"}'
            return "抱歉,AI 模型当前不可用。"

        full_prompt = "\n".join([msg.get("content", "") for msg in messages])

        temperature = kwargs.get("temperature", 0.6)

        if kwargs.get("response_format", {}).get("type") == "json_object":
            # 在 prompt 末尾添加指令,强制模型输出 JSON
            full_prompt += "\n\n请注意:你的回答必须是一个严格的、不含任何额外解释和代码块标记的 JSON 对象。"
            # 对于JSON生成任务,使用较低的 temperature 以获得更稳定、确定性的结构
            temperature = 0.1 

        log.debug(f"▶️ 执行 chat_completion (适配器), temperature={temperature}, prompt='{full_prompt[:100]}...'")

        return self.run_inference(
            input_type="text",
            formatted_input=None,
            prompt=full_prompt,
            temperature=temperature # 将处理后的 temperature 传递下去
        )


    def _build_prompt(self, processed_text: str, context: str = "") -> str:

        if context:
            return (
                f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
                f"--- 背景信息 ---\n{context}\n\n"
                f"--- 用户问题 ---\n{processed_text}\n\n"
                f"请提供专业、实用的旅游建议:"
            )
        else:
            return (
                f"你是一个专业的旅游助手。请用中文友好地回答用户的问题。\n\n"
                f"用户问题:{processed_text}\n\n"
                f"请提供专业、实用的旅游建议:"
            )

    def generate(self, user_input: str, context: str = "") -> str:
        """主要的生成方法 - 保持原有逻辑"""
        if not self.is_available():
            return "抱歉,AI 模型当前不可用,请稍后再试。"
    
        try:
            # 1. 检测输入类型
            input_type = self.detect_input_type(user_input)
            log.info(f"检测到输入类型: {input_type}")
        
            # 2. 格式化输入
            input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
        
            # 3. 构建prompt - 使用你的原有结构
            prompt = self._build_prompt(processed_text, context)
        
            # 4. 执行推理
            return self.run_inference("text", formatted_data, prompt)
            
        except Exception as e:
            log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
            return "抱歉,我在思考时遇到了点麻烦,请稍后再试。"