#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ DrugQA (ZH) — 優化版 FastAPI LINE Webhook (最終版) 整合 RAG 邏輯,包含 LLM 意圖偵測、子查詢分解、Intent-aware 檢索與 Rerank。 此版本專注於效能、可維護性、健壯性與使用者體驗。 """ # ---------- 環境與快取設定 (應置於最前) ---------- import os import pathlib os.environ.setdefault("HF_HOME", "/tmp/hf") os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers") os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache") for d in (os.getenv("HF_HOME"), os.getenv("SENTENCE_TRANSFORMERS_HOME"), os.getenv("XDG_CACHE_HOME")): pathlib.Path(d).mkdir(parents=True, exist_ok=True) # ---------- Python 標準函式庫 ---------- import re import hmac import base64 import hashlib import pickle import logging import json import textwrap import time import tenacity from typing import List, Dict, Any, Optional, Tuple, Union from functools import lru_cache from dataclasses import dataclass, field from contextlib import asynccontextmanager import unicodedata from collections import defaultdict # ---------- 第三方函式庫 ---------- import numpy as np import pandas as pd from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks import uvicorn import jieba from rank_bm25 import BM25Okapi from sentence_transformers import SentenceTransformer import faiss import torch from openai import OpenAI from tenacity import retry, stop_after_attempt, wait_fixed import requests from transformers import pipeline # [MODIFIED] 限制 PyTorch 執行緒數量,避免在 CPU 環境下過度佔用資源 torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1"))) # ==== CONFIG (從環境變數載入,或使用預設值) ==== # [MODIFIED] 新增環境變數健檢函式 def _require_env(var: str) -> str: v = os.getenv(var) if not v: raise RuntimeError(f"FATAL: Missing required environment variable: {var}") return v # [MODIFIED] 檢查 LLM 相關環境變數 def _require_llm_config(): for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"): _require_env(k) # MedGemma 模型直接硬編碼,如果需要替換,可以在這裡加入檢查 # _require_env("MEDGEMMA_MODEL_NAME") # 如果 MEDGEMMA_MODEL_NAME 是環境變數 CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv") FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index") SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl") BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl") TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 20)) PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30)) MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30)) EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh") LLM_API_CONFIG = { "base_url": os.getenv("LITELLM_BASE_URL"), "api_key": os.getenv("LITELLM_API_KEY"), "model": os.getenv("LM_MODEL") } MEDGEMMA_MODEL_NAME = "google/medgemma-4b-it" # 硬編碼 MedGemma 模型名稱 LLM_MODEL_CONFIG = { "max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)), "max_tokens": int(os.getenv("MAX_TOKENS", 1024)), "temperature": float(os.getenv("TEMPERATURE", 0.0)), "seed": int(os.getenv("LLM_SEED", 42)), # 新增 seed 以確保重現性 } INTENT_CATEGORIES = [ "操作 (Administration)", "保存/攜帶 (Storage & Handling)", "副作用/異常 (Side Effects / Issues)", "劑型相關 (Dosage Form Concerns)", "時間/併用 (Timing & Interaction)", "劑量調整 (Dosage Adjustment)", "禁忌症/適應症 (Contraindications/Indications)" ] # [新增] 意圖分類 → CSV section 對照表 INTENT_TO_SECTION = { "操作 (Administration)": ["用法用量", "病人使用須知"], "保存/攜帶 (Storage & Handling)": ["包裝及儲存"], "副作用/異常 (Side Effects / Issues)": ["不良反應", "警語與注意事項"], "劑型相關 (Dosage Form Concerns)": ["劑型", "藥品外觀"], "時間/併用 (Timing & Interaction)": ["用法用量"], "劑量調整 (Dosage Adjustment)": ["用法用量"], "禁忌症/適應症 (Contraindications/Indications)": ["適應症", "禁忌", "警語與注意事項"] } # 新增 REFERENCE_MAPPING REFERENCE_MAPPING = { "如何用藥?": "病人使用須知、用法用量", "如何保存與攜帶?": "包裝及儲存", "可能的副作用?": "警語與注意事項、不良反應", "每次劑量多少?": "用法用量、藥袋上的醫囑", "用藥時間?": "用法用量、藥袋上的醫囑", } # 新增反向映射,從 sections 找到 intents SECTION_TO_INTENT = defaultdict(list) for intent, sections in INTENT_TO_SECTION.items(): for section in sections: SECTION_TO_INTENT[section].append(intent) DRUG_NAME_MAPPING = { "fentanyl patch": "fentanyl", "spiriva respimat": "spiriva", "augmentin for syrup": "augmentin syrup", "nitrostat": "nitroglycerin", "ozempic": "ozempic", "niflec": "niflec", "fosamax": "fosamax", "humira": "humira", "premarin": "premarin", "smecta": "smecta", } DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務務必諮詢您的醫師或藥師。" PROMPT_TEMPLATES = { "analyze_query": """ 請分析以下使用者問題,並完成以下兩個任務: 1. 將問題分解為1-3個核心的子問題。 2. 從清單中選擇所有相關的意圖分類。 請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intents' (字串陣列) 兩個鍵。 範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"]}} 意圖分類清單: {options} 使用者問題:{query} """, "expand_query": """ 請根據以下意圖:{intents},擴展這個查詢,加入相關同義詞或術語。 原始查詢:{query} 請僅輸出擴展後的查詢,不需任何額外的解釋或格式。 """, "final_answer_concise": """ 您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆: 一、 回覆準則 嚴格依據資料: 所有回覆內容都必須完全來自提供的參考資料,禁止任何形式的捏造或引用外部資訊。 資料不足: 若參考資料無法回答使用者的問題,請直接回覆:「根據提供的資料,無法回答您的問題。」 專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。 精簡扼要: 內容需極其簡潔,資訊完整,不要使用*符號,字數請嚴格控制在60字以內。 二、 排版規範 條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。 結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」 範例: 使用者問題: 請問普拿疼可以怎麼吃? 參考資料: 普拿疼成人建議劑量為一次1至2錠,每4至6小時服用一次,每日不超過8錠。 AI回覆範例: 普拿疼成人劑量: 1-2錠/次 每4-6小時 每日≤8錠 如有不適請立即就醫。 {additional_instruction} --- 參考資料: {context} --- 使用者問題:{query} 請直接輸出最終的答案: """, "final_answer_detailed": """ 您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆: 一、 回覆準則 嚴格依據資料: 所有回覆內容都必須完全來自提供的參考資料,禁止任何形式的捏造或引用外部資訊。 資料不足: 若參考資料無法回答使用者的問題,請直接回覆:「根據提供的資料,無法回答您的問題。」 專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。 精簡扼要: 內容需簡潔但資訊完整,不要使用*符號,字數請控制在200字以內,提供更多細節解釋。 二、 排版規範 條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。 結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」 範例: 使用者問題: 請問普拿疼可以怎麼吃? 參考資料: 普拿疼成人建議劑量為一次1至2錠,每4至6小時服用一次,每日不超過8錠。 AI回覆範例: 普拿疼成人建議劑量為: - 一次服用1至2錠,視疼痛程度調整。 - 每4至6小時服用一次,避免過頻。 - 每日總量不超過8錠,以防副作用。 如有不適請立即就醫。 {additional_instruction} --- 參考資料: {context} --- 使用者問題:{query} 請直接輸出最終的答案: """, "direct_answer_concise": """ 您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,直接基於您的知識給予回覆: 一、 回覆準則 專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。 精簡扼要: 內容需極其簡潔,資訊完整,不要使用*符號,字數請嚴格控制在60字以內。 二、 排版規範 條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。 結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」 範例: 使用者問題: 請問普拿疼可以怎麼吃? AI回覆範例: 普拿疼成人劑量: 1-2錠/次 每4-6小時 每日≤8錠 如有不適請立即就醫。 {additional_instruction} 使用者問題:{query} 請直接輸出最終的答案: """, "direct_answer_detailed": """ 您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,直接基於您的知識給予回覆: 一、 回覆準則 專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。 精簡扼要: 內容需簡潔但資訊完整,不要使用*符號,字數請控制在200字以內,提供更多細節解釋。 二、 排版規範 條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。 結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」 範例: 使用者問題: 請問普拿疼可以怎麼吃? AI回覆範例: 普拿疼成人建議劑量為: - 一次服用1至2錠,視疼痛程度調整。 - 每4至6小時服用一次,避免過頻。 - 每日總量不超過8錠,以防副作用。 如有不適請立即就醫。 {additional_instruction} 使用者問題:{query} 請直接輸出最終的答案: """ } # ---------- 日誌設定 ---------- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') log = logging.getLogger(__name__) # [新增] 統一字串正規化函式 def _norm(s: str) -> str: """統一化字串:NFKC 正規化、轉小寫、移除標點符號與空白。""" s = unicodedata.normalize("NFKC", s) return re.sub(r"[^\w\s]", "", s.lower()).strip() @dataclass class FusedCandidate: idx: int fused_score: float sem_score: float bm_score: float @dataclass class RerankResult: idx: int rerank_score: float text: str meta: Dict[str, Any] = field(default_factory=dict) # ---------- 核心 RAG 邏輯 ---------- class RagPipeline: def __init__(self): if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]: raise ValueError("LLM API Key or Base URL is not configured.") # OpenAI client for LITELLM self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"]) self.litellm_model_name = LLM_API_CONFIG["model"] # MedGemma pipeline device = "cuda" if torch.cuda.is_available() else "cpu" log.info(f"載入 MedGemma 模型: {MEDGEMMA_MODEL_NAME} 至 {device}...") try: self.medgemma_pipe = pipeline( "text-generation", model=MEDGEMMA_MODEL_NAME, torch_dtype=torch.bfloat16, device=device, ) log.info("MedGemma 模型載入成功。") except Exception as e: log.error(f"載入 MedGemma 模型失敗: {e}") raise RuntimeError(f"MedGemma 模型載入失敗: {MEDGEMMA_MODEL_NAME}") self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding") self.drug_name_to_ids: Dict[str, List[str]] = {} self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()} self.state = type('state', (), {})() # [新增] 澄清問題的 Prompt Template self.CLARIFICATION_PROMPT = """ 請根據以下使用者問題,生成一個簡潔、禮貌的澄清性提問,以幫助我更精確地回答。問題應引導使用者提供更多細節,例如具體藥名、使用情境等。 範例: 使用者問題:這個藥會怎麼樣? 澄清提問:您好,請問您指的是哪一種藥物呢? 使用者問題:請問這要吃多久? 澄清提問:請問您是想了解該藥品的建議療程長度嗎? 使用者問題:{query} 澄清提問:""" def _load_model(self, model_class, model_name: str, model_type: str): device = "cuda" if torch.cuda.is_available() else "cpu" log.info(f"載入 {model_type} 模型:{model_name} 至 {device}...") try: return model_class(model_name, device=device) except Exception as e: log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。") try: return model_class(model_name, device="cpu") except Exception as e_cpu: log.error(f"切換至 CPU 仍無法載入模型: {model_name}。請確認模型路徑或網路連線。錯誤訊息: {e_cpu}") raise RuntimeError(f"模型載入失敗: {model_name}") def load_data(self): log.info("開始載入資料與模型...") for path in [CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL]: if not pathlib.Path(path).exists(): raise FileNotFoundError(f"必要的資料檔案不存在: {path}") try: self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('') for col in ("drug_name_norm", "drug_id"): if col not in self.df_csv.columns: raise KeyError(f"CSV 檔案 '{CSV_PATH}' 中缺少必要欄位: {col}") self.drug_name_to_ids = self._build_drug_name_to_ids() self._load_drug_name_vocabulary() log.info("載入 FAISS 索引與句子資料...") self.state.index = faiss.read_index(FAISS_INDEX) self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2) if hasattr(self.state.index, "nprobe"): self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16")) if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT: log.info("FAISS 索引使用內積 (IP) 指標,檢索時將自動進行 L2 正規化以實現餘弦相似度。") with open(SENTENCES_PKL, "rb") as f: data = pickle.load(f) self.state.sentences = data["sentences"] self.state.meta = data["meta"] log.info("載入 BM25 索引...") with open(BM25_PKL, "rb") as f: bm25_data = pickle.load(f) self.state.bm25 = bm25_data["bm25"] if not isinstance(self.state.bm25, BM25Okapi): raise ValueError("Loaded BM25 is not a BM25Okapi instance.") except (FileNotFoundError, KeyError) as e: log.exception(f"資料或索引檔案載入失敗: {e}") raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}") log.info("所有模型與資料載入完成。") # [新增] 將 drug_id 轉換為使用者友善的 drug_name_norm @lru_cache(maxsize=128) def _get_drug_name_by_id(self, drug_id: str) -> Optional[str]: """從 drug_id 查找對應的 drug_name_norm。""" row = self.df_csv[self.df_csv['drug_id'] == drug_id] if not row.empty: return row.iloc[0]['drug_name_norm'] return None def _find_drug_ids_from_name(self, query: str) -> List[str]: q_norm_parts = set(re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', _norm(query))) drug_ids = set() for part in q_norm_parts: if part in self.drug_name_to_ids: drug_ids.update(self.drug_name_to_ids[part]) # [新增] 新增 drug_name_norm 至 drug_id 的反向查找,以支持上下文處理 for drug_name, ids in self.drug_name_to_ids.items(): if drug_name in _norm(query): drug_ids.update(ids) return sorted(list(drug_ids)) def _build_drug_name_to_ids(self) -> Dict[str, List[str]]: mapping = {} for _, row in self.df_csv.iterrows(): drug_id = row['drug_id'] zh_parts = list(jieba.cut(row['drug_name_zh'])) en_parts = re.findall(r'[a-zA-Z0-9]+', row['drug_name_en'].lower() if row['drug_name_en'] else '') norm_parts = re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', _norm(row['drug_name_norm'])) all_parts = set(zh_parts + en_parts + norm_parts) for part in all_parts: part = part.strip() if part and len(part) > 1: mapping.setdefault(part, []).append(drug_id) for alias, canonical_name in DRUG_NAME_MAPPING.items(): if _norm(canonical_name) in _norm(row['drug_name_norm']): mapping.setdefault(_norm(alias), []).append(drug_id) for key in mapping: mapping[key] = sorted(list(set(mapping[key]))) return mapping def _load_drug_name_vocabulary(self): log.info("建立藥名詞庫...") for _, row in self.df_csv.iterrows(): norm_name = row['drug_name_norm'] words = list(re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', norm_name)) for word in words: if re.search(r'[\u4e00-\u9fff]', word): self.drug_vocab["zh"].add(word) else: self.drug_vocab["en"].add(word) for alias in DRUG_NAME_MAPPING: if re.search(r'[\u4e00-\u9fff]', alias): self.drug_vocab["zh"].add(alias) else: self.drug_vocab["en"].add(alias) for word in self.drug_vocab["zh"]: try: if word not in jieba.dt.FREQ: jieba.add_word(word, freq=2_000_000) except Exception: pass # [MODIFIED] 兩個獨立的 LLM 調用函式,用於輸出比較 @tenacity.retry( wait=tenacity.wait_fixed(2), stop=tenacity.stop_after_attempt(3), retry=tenacity.retry_if_exception_type(ValueError), before_sleep=tenacity.before_sleep_log(log, logging.WARNING), after=tenacity.after_log(log, logging.INFO) ) def _litellm_call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None, temperature: Optional[float] = None, seed: Optional[int] = None) -> str: """安全地呼叫 LITELLM API,並處理可能的回應內容為空錯誤。""" log.info(f"LITELLM 呼叫開始. 模型: {self.litellm_model_name}, max_tokens: {max_tokens}, temperature: {temperature}, seed: {seed}") start_time = time.time() try: response = self.llm_client.chat.completions.create( model=self.litellm_model_name, messages=messages, max_tokens=max_tokens, temperature=temperature, seed=seed, ) end_time = time.time() log.info(f"LITELLM 收到完整回應: {response.model_dump_json(indent=2)}") if not response.choices or not response.choices[0].message.content: log.warning("LITELLM 呼叫成功 (200 OK),但回傳內容為空。將回傳空字串。") return "" content = response.choices[0].message.content log.info(f"LITELLM 呼叫完成,耗時: {end_time - start_time:.2f} 秒。內容長度: {len(content)} 字。") return content except Exception as e: log.error(f"LITELLM API 呼叫失敗: {e}") raise @tenacity.retry( wait=tenacity.wait_fixed(2), stop=tenacity.stop_after_attempt(3), retry=tenacity.retry_if_exception_type(ValueError), before_sleep=tenacity.before_sleep_log(log, logging.WARNING), after=tenacity.after_log(log, logging.INFO) ) def _medgemma_call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None, temperature: Optional[float] = None, seed: Optional[int] = None) -> str: """安全地呼叫 MedGemma 模型,並處理可能的回應內容為空錯誤。""" log.info(f"MedGemma 呼叫開始. max_tokens: {max_tokens}, temperature: {temperature}, seed: {seed}") start_time = time.time() try: # MedGemma pipeline requires a specific format for messages converted_messages = [] for msg in messages: role = msg["role"] content = [{"type": "text", "text": msg["content"]}] converted_messages.append({"role": role, "content": content}) output = self.medgemma_pipe( converted_messages, max_new_tokens=max_tokens or LLM_MODEL_CONFIG["max_tokens"], temperature=temperature if temperature is not None else LLM_MODEL_CONFIG["temperature"], # MedGemma pipeline 可能不直接支持seed,如果不支持,可移除或處理 ) end_time = time.time() if not output or not output[0]["generated_text"] or not output[0]["generated_text"][-1]["content"]: log.warning("MedGemma 呼叫成功,但回傳內容為空。將回傳空字串。") return "" content = output[0]["generated_text"][-1]["content"] log.info(f"MedGemma 呼叫完成,耗時: {end_time - start_time:.2f} 秒。內容長度: {len(content)} 字。") return content except Exception as e: log.error(f"MedGemma 呼叫失敗: {e}") raise # [MODIFIED] 修改 answer_question 函式以回傳四種答案,每種包括簡潔版和詳細版 def answer_question(self, q_orig: str) -> Tuple[Dict[str, Dict[str, str]], List[str]]: start_time = time.time() log.info(f"===== 處理新查詢: '{q_orig}' =====") try: drug_ids = self._find_drug_ids_from_name(q_orig) if not drug_ids: log.info("未從查詢中找到相關藥名,透過兩種 LLM 生成澄清性問題。") clarification_litellm = self._generate_clarification_query_litellm(q_orig) clarification_medgemma = self._generate_clarification_query_medgemma(q_orig) log.info(f"澄清問題比較: LITELLM: '{clarification_litellm}', MedGemma: '{clarification_medgemma}'") return {"clarification_litellm": clarification_litellm + f"\n{DISCLAIMER}", "clarification_medgemma": clarification_medgemma + f"\n{DISCLAIMER}"}, [] log.info(f"步驟 1/4: 找到藥品 ID: {drug_ids},耗時: {time.time() - start_time:.2f} 秒") step_start = time.time() # 分開處理兩種 LLM 的分析 analysis_litellm = self._analyze_query_litellm(q_orig) analysis_medgemma = self._analyze_query_medgemma(q_orig) sub_queries_litellm, intents_litellm = analysis_litellm.get("sub_queries", [q_orig]), analysis_litellm.get("intents", []) sub_queries_medgemma, intents_medgemma = analysis_medgemma.get("sub_queries", [q_orig]), analysis_medgemma.get("intents", []) log.info(f"意圖分析比較: LITELLM: {analysis_litellm}, MedGemma: {analysis_medgemma}") if not intents_litellm and not intents_medgemma: log.info("意圖分析失敗,透過兩種 LLM 生成澄清性問題。") clarification_litellm = self._generate_clarification_query_litellm(q_orig) clarification_medgemma = self._generate_clarification_query_medgemma(q_orig) log.info(f"澄清問題比較: LITELLM: '{clarification_litellm}', MedGemma: '{clarification_medgemma}'") return {"clarification_litellm": clarification_litellm + f"\n{DISCLAIMER}", "clarification_medgemma": clarification_medgemma + f"\n{DISCLAIMER}"}, drug_ids log.info(f"步驟 2/4: 意圖分析完成。LITELLM 子問題: {sub_queries_litellm}, 意圖: {intents_litellm}。MedGemma 子問題: {sub_queries_medgemma}, 意圖: {intents_medgemma}。耗時: {time.time() - step_start:.2f} 秒") step_start = time.time() # 分開處理兩種 LLM 的檢索流程 all_candidates_litellm = self._retrieve_candidates_for_all_queries_litellm(drug_ids, sub_queries_litellm, intents_litellm) all_candidates_medgemma = self._retrieve_candidates_for_all_queries_medgemma(drug_ids, sub_queries_medgemma, intents_medgemma) log.info(f"步驟 3/4: 檢索完成。LITELLM 找到 {len(all_candidates_litellm)} 個, MedGemma 找到 {len(all_candidates_medgemma)} 個。耗時: {time.time() - step_start:.2f} 秒") step_start = time.time() # LITELLM RAG final_candidates_litellm = all_candidates_litellm[:TOP_K_SENTENCES] reranked_results_litellm = [ RerankResult(idx=c.idx, rerank_score=c.fused_score, text=self.state.sentences[c.idx], meta=self.state.meta[c.idx]) for c in final_candidates_litellm ] prioritized_results_litellm = self._prioritize_context(reranked_results_litellm, intents_litellm) context_litellm = self._build_context(prioritized_results_litellm) # MedGemma RAG final_candidates_medgemma = all_candidates_medgemma[:TOP_K_SENTENCES] reranked_results_medgemma = [ RerankResult(idx=c.idx, rerank_score=c.fused_score, text=self.state.sentences[c.idx], meta=self.state.meta[c.idx]) for c in final_candidates_medgemma ] prioritized_results_medgemma = self._prioritize_context(reranked_results_medgemma, intents_medgemma) context_medgemma = self._build_context(prioritized_results_medgemma) log.info(f"步驟 4/4: 最終選出 LITELLM {len(reranked_results_litellm)} 個, MedGemma {len(reranked_results_medgemma)} 個。耗時: {time.time() - step_start:.2f} 秒") step_start = time.time() if not context_litellm and not context_medgemma: log.info("沒有足夠的上下文來回答問題。") return {"error": f"根據提供的資料,無法回答您的問題。\n{DISCLAIMER}"}, drug_ids # 通用參數:temperature=0.0, seed=42 以確保重現性 temp = LLM_MODEL_CONFIG["temperature"] seed = LLM_MODEL_CONFIG["seed"] max_tokens_concise = 256 # 簡潔版限制較小 max_tokens_detailed = LLM_MODEL_CONFIG["max_tokens"] # 詳細版使用預設 # 生成答案 - With RAG answers_with_rag = self._generate_answers_with_rag(q_orig, context_litellm, intents_litellm, context_medgemma, intents_medgemma, temp, seed, max_tokens_concise, max_tokens_detailed) # 生成答案 - Without RAG (直接用 query) answers_without_rag = self._generate_answers_without_rag(q_orig, intents_litellm, intents_medgemma, temp, seed, max_tokens_concise, max_tokens_detailed) log.info(f"答案生成完成。耗時: {time.time() - step_start:.2f} 秒") log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====") # 合併返回 return { "LITELLM_with_RAG": answers_with_rag["LITELLM"], "MedGemma_with_RAG": answers_with_rag["MedGemma"], "LITELLM_without_RAG": answers_without_rag["LITELLM"], "MedGemma_without_RAG": answers_without_rag["MedGemma"] }, drug_ids except Exception as e: log.error(f"處理查詢 '{q_orig}' 時發生嚴重錯誤: {e}", exc_info=True) return {"error": f"處理您的問題時發生內部錯誤,請稍後再試。\n{DISCLAIMER}"}, [] def _analyze_query_litellm(self, query: str) -> Dict[str, Any]: return self._analyze_query_generic(query, self._litellm_call) def _analyze_query_medgemma(self, query: str) -> Dict[str, Any]: return self._analyze_query_generic(query, self._medgemma_call) def _analyze_query_generic(self, query: str, llm_call) -> Dict[str, Any]: # 先檢查 REFERENCE_MAPPING norm_query = _norm(query) for ref_key, ref_value in REFERENCE_MAPPING.items(): if _norm(ref_key) in norm_query: sections = [s.strip() for s in ref_value.split(',')] intents = [] for section in sections: intents.extend(SECTION_TO_INTENT.get(section, [])) intents = list(set(intents)) # 去重 if intents: log.info(f"匹配 REFERENCE_MAPPING: '{ref_key}' -> intents: {intents}") return {"sub_queries": [query], "intents": intents} # 如果不匹配,才進行 LLM 意圖偵測 prompt = PROMPT_TEMPLATES["analyze_query"].format( options="\n".join(f"- {c}" for c in INTENT_CATEGORIES), query=query ) response_str = llm_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"]) return self._safe_json_parse(response_str, default={"sub_queries": [query], "intents": []}) def _generate_clarification_query_litellm(self, query: str) -> str: prompt = self.CLARIFICATION_PROMPT.format(query=query) return self._litellm_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"]).strip() def _generate_clarification_query_medgemma(self, query: str) -> str: prompt = self.CLARIFICATION_PROMPT.format(query=query) return self._medgemma_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"]).strip() def _retrieve_candidates_for_all_queries_litellm(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]: return self._retrieve_candidates_for_all_queries_generic(drug_ids, sub_queries, intents, self._expand_query_with_llm_litellm) def _retrieve_candidates_for_all_queries_medgemma(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]: return self._retrieve_candidates_for_all_queries_generic(drug_ids, sub_queries, intents, self._expand_query_with_llm_medgemma) def _retrieve_candidates_for_all_queries_generic(self, drug_ids: List[str], sub_queries: List[str], intents: List[str], expand_func) -> List[FusedCandidate]: drug_ids_set = set(map(str, drug_ids)) if drug_ids_set: relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id", "")) in drug_ids_set} else: relevant_indices = set(range(len(self.state.meta))) if not relevant_indices: return [] all_fused_candidates: Dict[int, FusedCandidate] = {} for sub_q in sub_queries: expanded_q = expand_func(sub_q, tuple(intents)) q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32") if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT: faiss.normalize_L2(q_emb) distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K) tokenized_query = list(jieba.cut(expanded_q)) bm25_scores = self.state.bm25.get_scores(tokenized_query) rel_idx = np.fromiter(relevant_indices, dtype=int) rel_scores = bm25_scores[rel_idx] top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]] doc_to_bm25_score = {int(i): float(bm25_scores[i]) for i in top_rel} candidate_scores: Dict[int, Dict[str, float]] = {} def to_similarity(d: float) -> float: if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT: return float(d) else: return 1.0 / (1.0 + float(d)) for i, dist in zip(sim_indices[0], distances[0]): if i in relevant_indices: similarity = to_similarity(dist) candidate_scores[int(i)] = {"sem": float(similarity), "bm": 0.0} for i, score in doc_to_bm25_score.items(): if i in relevant_indices: candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score if not candidate_scores: continue keys = list(candidate_scores.keys()) sem_scores = np.array([candidate_scores[k]['sem'] for k in keys]) bm_scores = np.array([candidate_scores[k]['bm'] for k in keys]) def norm(x): rng = x.max() - x.min() return (x - x.min()) / (rng + 1e-8) if rng > 0 else np.zeros_like(x) sem_n, bm_n = norm(sem_scores), norm(bm_scores) for idx, k in enumerate(keys): fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4 if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score: all_fused_candidates[k] = FusedCandidate( idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx] ) return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True) def _expand_query_with_llm_litellm(self, query: str, intents: tuple) -> str: if not intents: return query prompt = PROMPT_TEMPLATES["expand_query"].format(intents=list(intents), query=query) try: expanded = self._litellm_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"]) log.info(f"LITELLM 查詢擴展成功。原始: '{query}', 擴展後: '{expanded}'") return expanded.strip() if expanded.strip() else query except Exception as e: log.error(f"LITELLM 查詢擴展失敗: {e}。原始查詢: '{query}'。將使用原始查詢。") return query def _expand_query_with_llm_medgemma(self, query: str, intents: tuple) -> str: if not intents: return query prompt = PROMPT_TEMPLATES["expand_query"].format(intents=list(intents), query=query) try: expanded = self._medgemma_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"]) log.info(f"MedGemma 查詢擴展成功。原始: '{query}', 擴展後: '{expanded}'") return expanded.strip() if expanded.strip() else query except Exception as e: log.error(f"MedGemma 查詢擴展失敗: {e}。原始查詢: '{query}'。將使用原始查詢。") return query def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]: if not intents: return results prioritized_sections = set() for intent in intents: prioritized_sections.update(INTENT_TO_SECTION.get(intent, [])) if not prioritized_sections: return results log.info(f"根據意圖 '{intents}' 優先處理章節: {prioritized_sections}") prioritized_results = [] other_results = [] for res in results: section = res.meta.get("section", "") if section in prioritized_sections: prioritized_results.append(res) else: other_results.append(res) final_prioritized_list = prioritized_results + other_results return final_prioritized_list def _build_context(self, reranked_results: List[RerankResult]) -> str: context = "" for res in reranked_results: if len(context) + len(res.text) > LLM_MODEL_CONFIG["max_context_chars"]: break context += res.text + "\n\n" return context.strip() def _safe_json_parse(self, s: str, default: Any = None) -> Any: try: return json.loads(s) except json.JSONDecodeError: log.warning(f"無法解析完整 JSON。嘗試從字串中提取: {s[:200]}...") m = re.search(r'\{.*?\}', s, re.DOTALL) if m: try: return json.loads(m.group(0)) except json.JSONDecodeError: log.warning(f"提取的 JSON 仍無法解析: {m.group(0)[:100]}...") return default # ---------- FastAPI 事件與路由 ---------- class AppConfig: CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN") CHANNEL_SECRET = _require_env("CHANNEL_SECRET") rag_pipeline: Optional[RagPipeline] = None # [MODIFIED] 全域使用者狀態快取, 儲存更詳細的資訊 USER_STATE_CACHE = defaultdict(lambda: {"last_query": None, "last_drug_ids": []}) @asynccontextmanager async def lifespan(app: FastAPI): _require_llm_config() global rag_pipeline rag_pipeline = RagPipeline() rag_pipeline.load_data() log.info("啟動完成,服務準備就緒。") yield log.info("服務關閉中。") app = FastAPI(lifespan=lifespan) @app.post("/webhook") async def handle_webhook(request: Request, background_tasks: BackgroundTasks): signature = request.headers.get("X-Line-Signature") if not signature: raise HTTPException(status_code=400, detail="Missing X-Line-Signature") if not AppConfig.CHANNEL_SECRET: log.error("CHANNEL_SECRET is not configured.") raise HTTPException(status_code=500, detail="Server configuration error") body = await request.body() try: hash = hmac.new(AppConfig.CHANNEL_SECRET.encode('utf-8'), body, hashlib.sha256) expected_signature = base64.b64encode(hash.digest()).decode('utf-8') except Exception as e: log.error(f"Failed to generate signature: {e}") raise HTTPException(status_code=500, detail="Signature generation error") if not hmac.compare_digest(expected_signature, signature): raise HTTPException(status_code=403, detail="Invalid signature") try: data = json.loads(body.decode('utf-8')) except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid JSON body") for event in data.get("events", []): if event.get("type") == "message" and event.get("message", {}).get("type") == "text": user_text = event.get("message", {}).get("text", "").strip() source = event.get("source", {}) stype = source.get("type") target_id = source.get("userId") or source.get("groupId") or source.get("roomId") if user_text and target_id: background_tasks.add_task(rag_pipeline.process_user_query, stype, target_id, user_text) return Response(status_code=status.HTTP_200_OK) @retry(stop=stop_after_attempt(3), wait=wait_fixed(2)) def line_api_call(endpoint: str, data: Dict): headers = { "Content-Type": "application/json", "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}" } try: response = requests.post(f"https://api.line.me/v2/bot/message/{endpoint}", headers=headers, json=data, timeout=10) response.raise_for_status() except requests.exceptions.RequestException as e: log.error(f"LINE API ({endpoint}) 呼叫失敗: {e} | Response: {e.response.text if e.response else 'N/A'}") raise def line_push_generic(source_type: str, target_id: str, text: str): messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]] endpoint = "push" data = {"to": target_id, "messages": messages} line_api_call(endpoint, data) # ---------- 執行 ---------- if __name__ == "__main__": port = int(os.getenv("PORT", 7860)) uvicorn.run(app, host="0.0.0.0", port=port)