#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DrugQA (ZH) — optimized FastAPI LINE webhook (final version).

Integrates the RAG logic: LLM intent detection, sub-query decomposition,
intent-aware retrieval, and cross-encoder reranking. This version focuses on
performance, maintainability, robustness, and user experience.
"""
# ---------- Environment and cache setup (must run before model imports) ----------
import os
import pathlib

os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
for d in (os.getenv("HF_HOME"), os.getenv("SENTENCE_TRANSFORMERS_HOME"), os.getenv("XDG_CACHE_HOME")):
    pathlib.Path(d).mkdir(parents=True, exist_ok=True)
# ---------- Python standard library ----------
import re
import hmac
import base64
import hashlib
import pickle
import logging
import json
import textwrap
import time
import unicodedata
from typing import List, Dict, Any, Optional, Tuple, Union
from functools import lru_cache
from dataclasses import dataclass, field
from contextlib import asynccontextmanager
# ---------- Third-party libraries ----------
import numpy as np
import pandas as pd
from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
import uvicorn
import jieba
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import torch
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_fixed
import requests

# Cap the PyTorch thread count so CPU-only deployments are not oversubscribed.
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
# ==== CONFIG (loaded from environment variables, with defaults) ====
def _require_env(var: str) -> str:
    """Fail fast when a required environment variable is missing."""
    v = os.getenv(var)
    if not v:
        raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
    return v

def _require_llm_config():
    """Validate the LLM-related environment variables at startup."""
    for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
        _require_env(k)
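
# A sample environment for local runs (illustrative placeholders only — the variable
# names are the ones read in this file; the values are not real endpoints or secrets):
#   export LITELLM_BASE_URL="http://localhost:4000"   # any OpenAI-compatible gateway
#   export LITELLM_API_KEY="sk-..."
#   export LM_MODEL="gpt-4o-mini"
#   export CHANNEL_ACCESS_TOKEN="..."                 # LINE Messaging API credentials
#   export CHANNEL_SECRET="..."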
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", "15"))
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", "30"))
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", "30"))
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
RERANKER_MODEL = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
LLM_API_CONFIG = {
    "base_url": os.getenv("LITELLM_BASE_URL"),
    "api_key": os.getenv("LITELLM_API_KEY"),
    "model": os.getenv("LM_MODEL"),
}
LLM_MODEL_CONFIG = {
    "max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", "10000")),
    "max_tokens": int(os.getenv("MAX_TOKENS", "1024")),
    "temperature": float(os.getenv("TEMPERATURE", "0.0")),
}
INTENT_CATEGORIES = [
    "操作 (Administration)", "保存/攜帶 (Storage & Handling)", "副作用/異常 (Side Effects / Issues)",
    "劑型相關 (Dosage Form Concerns)", "時間/併用 (Timing & Interaction)", "劑量調整 (Dosage Adjustment)",
    "禁忌症/適應症 (Contraindications/Indications)",
]
DRUG_NAME_MAPPING = {
    "fentanyl patch": "fentanyl", "spiriva respimat": "spiriva", "augmentin for syrup": "augmentin syrup",
    "nitrostat": "nitroglycerin", "ozempic": "ozempic", "niflec": "niflec",
    "fosamax": "fosamax", "humira": "humira", "premarin": "premarin", "smecta": "smecta",
}
DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
PROMPT_TEMPLATES = {
    "analyze_query": """
請分析以下使用者問題,並完成以下兩個任務:
1. 將問題分解為1-3個核心的子問題。
2. 從清單中選擇所有相關的意圖分類。
請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intents' (字串陣列) 兩個鍵。
範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"]}}
意圖分類清單:
{options}
使用者問題:{query}
""",
    "expand_query": """
請根據以下意圖:{intents},擴展這個查詢,加入相關同義詞或術語。
原始查詢:{query}
請僅輸出擴展後的查詢,不需任何額外的解釋或格式。
""",
    "final_answer": """
你是一位專業且謹慎的台灣藥師。請嚴格根據「參考資料」回答使用者問題,使用繁體中文。
規則:
所有回答內容必須嚴格依據提供的參考資料,禁止任何形式的捏造或引用外部資訊。
若資料不足以回答,請回覆:「根據提供的資料,無法回答您的問題。」
針對原始查詢,以專業、友善的口吻,提供簡潔但資訊完整的中文繁體回答。
回答字數限制在120字以內。
排版格式:
使用條列式分行呈現,排版需適合LINE對話框顯示。
回覆結尾必須加上指定提醒語句:「如有不適請立即就醫。」
{additional_instruction}
---
參考資料:
{context}
---
使用者問題:{query}
請直接輸出最終的答案:
""",
}
# ---------- Logging ----------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

def _norm(s: str) -> str:
    """Normalize a string: NFKC fold, lowercase, strip punctuation and outer whitespace."""
    s = unicodedata.normalize("NFKC", s)
    return re.sub(r"[^\w\s]", "", s.lower()).strip()
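
# Illustrative behavior (inputs invented for this example): full-width and mixed-case
# text collapses to a comparable form, while CJK characters (matched by \w) survive:
#   _norm("Fentanyl 貼片!")   -> "fentanyl 貼片"
#   _norm("SPIRIVA(思力華)")  -> "spiriva思力華"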
@dataclass
class FusedCandidate:
    """A retrieval candidate with its fused (semantic + BM25) score."""
    idx: int
    fused_score: float
    sem_score: float
    bm_score: float

@dataclass
class RerankResult:
    """A candidate after reranking, with its source text and metadata."""
    idx: int
    rerank_score: float
    text: str
    meta: Dict[str, Any] = field(default_factory=dict)
# ---------- Core RAG logic ----------
class RagPipeline:
    def __init__(self):
        if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
            raise ValueError("LLM API Key or Base URL is not configured.")
        self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
        self.model_name = LLM_API_CONFIG["model"]
        self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding")
        self.reranker = self._load_model(CrossEncoder, RERANKER_MODEL, "reranker")
        self.drug_name_to_ids: Dict[str, List[str]] = {}
        self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
        # Simple anonymous namespace holding the loaded index/sentence/BM25 artifacts.
        self.state = type('state', (), {})()

    def _load_model(self, model_class, model_name: str, model_type: str):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log.info(f"載入 {model_type} 模型:{model_name} 至 {device}...")
        try:
            return model_class(model_name, device=device)
        except Exception as e:
            log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
            try:
                return model_class(model_name, device="cpu")
            except Exception as e_cpu:
                log.error(f"切換至 CPU 仍無法載入模型: {model_name}。請確認模型路徑或網路連線。錯誤訊息: {e_cpu}")
                raise RuntimeError(f"模型載入失敗: {model_name}")
    def load_data(self):
        log.info("開始載入資料與模型...")
        # Fail fast if any required artifact is missing.
        for path in [CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL]:
            if not pathlib.Path(path).exists():
                raise FileNotFoundError(f"必要的資料檔案不存在: {path}")
        try:
            self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('')
            # Check required columns, including the name columns consumed by
            # _build_drug_name_to_ids below.
            for col in ("drug_name_norm", "drug_id", "drug_name_zh", "drug_name_en"):
                if col not in self.df_csv.columns:
                    raise KeyError(f"CSV 檔案 '{CSV_PATH}' 中缺少必要欄位: {col}")
            # Build the drug-name -> drug_id lookup and the jieba vocabulary.
            self.drug_name_to_ids = self._build_drug_name_to_ids()
            self._load_drug_name_vocabulary()
            log.info("載入 FAISS 索引與句子資料...")
            self.state.index = faiss.read_index(FAISS_INDEX)
            self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
            if hasattr(self.state.index, "nprobe"):
                self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
            # If the index uses inner product, query vectors are L2-normalized at
            # search time so scores behave like cosine similarity.
            if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
                log.info("FAISS 索引使用內積 (IP) 指標,檢索時將自動進行 L2 正規化以實現餘弦相似度。")
            with open(SENTENCES_PKL, "rb") as f:
                data = pickle.load(f)
            self.state.sentences = data["sentences"]
            self.state.meta = data["meta"]
            log.info("載入 BM25 索引...")
            with open(BM25_PKL, "rb") as f:
                # The pickle stores a dict; the index lives under the 'bm25' key.
                bm25_data = pickle.load(f)
            self.state.bm25 = bm25_data["bm25"]
            if not isinstance(self.state.bm25, BM25Okapi):
                raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
        except (FileNotFoundError, KeyError) as e:
            log.exception(f"資料或索引檔案載入失敗: {e}")
            raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}")
        log.info("所有模型與資料載入完成。")
    def _find_drug_ids_from_name(self, query: str) -> List[str]:
        # Split the normalized query into ASCII and CJK runs and look each one
        # up in the drug-name lexicon.
        q_norm_parts = set(re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', _norm(query)))
        drug_ids = set()
        for part in q_norm_parts:
            if part in self.drug_name_to_ids:
                drug_ids.update(self.drug_name_to_ids[part])
        return sorted(drug_ids)

    def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
        mapping = {}
        for _, row in self.df_csv.iterrows():
            drug_id = row['drug_id']
            # Segment the Chinese name with jieba; split the English name into tokens.
            zh_parts = list(jieba.cut(row['drug_name_zh']))
            en_parts = re.findall(r'[a-zA-Z0-9]+', row['drug_name_en'].lower() if row['drug_name_en'] else '')
            # Run everything through _norm so lexicon keys match query processing.
            norm_parts = re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', _norm(row['drug_name_norm']))
            all_parts = set(zh_parts + en_parts + norm_parts)
            for part in all_parts:
                part = part.strip()
                if part and len(part) > 1:
                    mapping.setdefault(part, []).append(drug_id)
            # Also index the aliases from DRUG_NAME_MAPPING.
            for alias, canonical_name in DRUG_NAME_MAPPING.items():
                if _norm(canonical_name) in _norm(row['drug_name_norm']):
                    mapping.setdefault(_norm(alias), []).append(drug_id)
        for key in mapping:
            mapping[key] = sorted(set(mapping[key]))
        return mapping
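
    # Illustrative example (hypothetical row, not from the real CSV): given
    # drug_id="B002" and drug_name_en="Spiriva Respimat", the English tokens yield
    # {"spiriva": ["B002"], "respimat": ["B002"]}; and because _norm("spiriva")
    # appears in the normalized name, the alias key "spiriva respimat" also maps
    # to ["B002"]. Note that a multi-word alias key can only match a query that
    # normalizes to exactly that string, since _find_drug_ids_from_name splits
    # queries into single ASCII/CJK runs.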
    def _load_drug_name_vocabulary(self):
        log.info("建立藥名詞庫...")
        for _, row in self.df_csv.iterrows():
            norm_name = row['drug_name_norm']
            words = re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', norm_name)
            for word in words:
                if re.search(r'[\u4e00-\u9fff]', word):
                    self.drug_vocab["zh"].add(word)
                else:
                    self.drug_vocab["en"].add(word)
        for alias in DRUG_NAME_MAPPING:
            if re.search(r'[\u4e00-\u9fff]', alias):
                self.drug_vocab["zh"].add(alias)
            else:
                self.drug_vocab["en"].add(alias)
        # Register Chinese drug names with jieba (high frequency keeps them unsplit).
        for word in self.drug_vocab["zh"]:
            try:
                if word not in jieba.dt.FREQ:
                    jieba.add_word(word, freq=2_000_000)
            except Exception:
                pass
    def _llm_call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None, temperature: Optional[float] = None) -> str:
        """Call the LLM API safely, treating an empty completion as a soft failure."""
        log.info(f"LLM 呼叫開始. 模型: {self.model_name}, max_tokens: {max_tokens}, temperature: {temperature}")
        log.info(f"送出的 LLM 提示 (messages): {json.dumps(messages, ensure_ascii=False, indent=2)}")
        start_time = time.time()
        try:
            response = self.llm_client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            end_time = time.time()
            log.info(f"LLM 收到完整回應: {response.model_dump_json(indent=2)}")
            # An empty completion returns "" instead of raising, so callers can
            # substitute a default reply.
            if not response.choices or not response.choices[0].message.content:
                log.warning("LLM 呼叫成功 (200 OK),但回傳內容為空。將回傳空字串。")
                return ""
            content = response.choices[0].message.content
            log.info(f"LLM 呼叫完成,耗時: {end_time - start_time:.2f} 秒。內容長度: {len(content)} 字。")
            return content
        except Exception as e:
            log.error(f"LLM API 呼叫失敗: {e}")
            raise
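
    # The tenacity imports near the top are otherwise unused, which suggests retrying
    # was intended here. A minimal sketch of a retried variant (not wired into the
    # pipeline; the method name is illustrative):
    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2), reraise=True)
    def _llm_call_with_retry(self, messages: List[Dict[str, str]], **kwargs) -> str:
        """Like _llm_call, but retried up to 3 times with a 2-second pause."""
        return self._llm_call(messages, **kwargs)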
    def answer_question(self, q_orig: str) -> str:
        start_time = time.time()
        log.info(f"===== 處理新查詢: '{q_orig}' =====")
        try:
            drug_ids = self._find_drug_ids_from_name(q_orig)
            if not drug_ids:
                log.info("未從查詢中找到相關藥名,直接返回預設訊息。")
                return f"未從查詢中找到相關藥名,無法回答您的問題。\n{DISCLAIMER}"
            log.info(f"步驟 1/5: 找到藥品 ID: {drug_ids},耗時: {time.time() - start_time:.2f} 秒")
            step_start = time.time()
            analysis = self._analyze_query(q_orig)
            sub_queries, intents = analysis.get("sub_queries", [q_orig]), analysis.get("intents", [])
            is_simple_query = self._is_simple_query(sub_queries, intents)
            log.info(f"步驟 2/5: 意圖分析完成。子問題: {sub_queries}, 意圖: {intents}。判定為簡單查詢: {is_simple_query}。耗時: {time.time() - step_start:.2f} 秒")
            step_start = time.time()
            all_candidates = self._retrieve_candidates_for_all_queries(drug_ids, sub_queries, intents)
            log.info(f"步驟 3/5: 檢索完成。所有子查詢共找到 {len(all_candidates)} 個不重複候選 chunks。耗時: {time.time() - step_start:.2f} 秒")
            step_start = time.time()
            if is_simple_query:
                log.info("偵測到簡單查詢,跳過 Reranker 步驟。")
                final_candidates = all_candidates[:TOP_K_SENTENCES]
                reranked_results = [
                    RerankResult(idx=c.idx, rerank_score=c.fused_score, text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
                    for c in final_candidates
                ]
            else:
                log.info("偵測到複雜查詢,執行 Reranker。")
                reranked_results = self._rerank_with_crossencoder(q_orig, all_candidates)
            log.info(f"步驟 4/5: 最終選出 {len(reranked_results)} 個高品質候選。耗時: {time.time() - step_start:.2f} 秒")
            step_start = time.time()
            # Reorder the retrieved chunks by intent before building the prompt.
            prioritized_results = self._prioritize_context(reranked_results, intents)
            context = self._build_context(prioritized_results)
            if not context:
                log.info("沒有足夠的上下文來回答問題。")
                return f"根據提供的資料,無法回答您的問題。{DISCLAIMER}"
            prompt = self._make_final_prompt(q_orig, context, intents)
            answer = self._llm_call([{"role": "user", "content": prompt}])
            # Fall back to a default reply if the LLM returned an empty string.
            if not answer:
                log.warning("LLM 回傳的答案為空,將使用預設回覆。")
                return f"根據提供的資料,無法回答您的問題。{DISCLAIMER}"
            final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
            log.info(f"步驟 5/5: 答案生成完成。答案長度: {len(answer.strip())} 字。耗時: {time.time() - step_start:.2f} 秒")
            log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
            return final_answer
        except Exception as e:
            log.error(f"處理查詢 '{q_orig}' 時發生嚴重錯誤: {e}", exc_info=True)
            return f"處理您的問題時發生內部錯誤,請稍後再試。{DISCLAIMER}"

    def _is_simple_query(self, sub_queries: List[str], intents: List[str]) -> bool:
        # A query is "simple" when analysis yields at most one sub-query and at
        # most one intent category; simple queries skip the reranker.
        return len(sub_queries) <= 1 and len(intents) <= 1
    def _analyze_query(self, query: str) -> Dict[str, Any]:
        prompt = PROMPT_TEMPLATES["analyze_query"].format(
            options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
            query=query
        )
        response_str = self._llm_call([{"role": "user", "content": prompt}], temperature=0)
        return self._safe_json_parse(response_str, default={"sub_queries": [query], "intents": []})
    def _retrieve_candidates_for_all_queries(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]:
        # Restrict retrieval to chunks belonging to the detected drugs, if any.
        drug_ids_set = set(map(str, drug_ids))
        if drug_ids_set:
            relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id", "")) in drug_ids_set}
        else:
            relevant_indices = set(range(len(self.state.meta)))
        if not relevant_indices:
            return []

        def to_similarity(d: float) -> float:
            # Inner-product scores are similarities already; L2 distances are
            # mapped into (0, 1] via 1 / (1 + d).
            if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
                return float(d)
            return 1.0 / (1.0 + float(d))

        def minmax(x: np.ndarray) -> np.ndarray:
            rng = x.max() - x.min()
            return (x - x.min()) / (rng + 1e-8) if rng > 0 else np.zeros_like(x)

        all_fused_candidates: Dict[int, FusedCandidate] = {}
        for sub_q in sub_queries:
            expanded_q = self._expand_query_with_llm(sub_q, tuple(intents))
            # Semantic retrieval via FAISS.
            q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
            if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
                faiss.normalize_L2(q_emb)
            distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
            # Lexical retrieval via BM25, scored only over the relevant chunks.
            tokenized_query = list(jieba.cut(expanded_q))
            bm25_scores = self.state.bm25.get_scores(tokenized_query)
            rel_idx = np.fromiter(relevant_indices, dtype=int)
            rel_scores = bm25_scores[rel_idx]
            top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
            doc_to_bm25_score = {int(i): float(bm25_scores[i]) for i in top_rel}
            candidate_scores: Dict[int, Dict[str, float]] = {}
            for i, dist in zip(sim_indices[0], distances[0]):
                if i in relevant_indices:
                    candidate_scores[int(i)] = {"sem": to_similarity(dist), "bm": 0.0}
            for i, score in doc_to_bm25_score.items():
                if i in relevant_indices:
                    candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
            if not candidate_scores:
                continue
            # Min-max normalize each score family, then fuse with fixed weights.
            keys = list(candidate_scores.keys())
            sem_scores = np.array([candidate_scores[k]['sem'] for k in keys])
            bm_scores = np.array([candidate_scores[k]['bm'] for k in keys])
            sem_n, bm_n = minmax(sem_scores), minmax(bm_scores)
            for idx, k in enumerate(keys):
                fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
                # Across sub-queries, keep the best fused score per chunk.
                if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
                    all_fused_candidates[k] = FusedCandidate(
                        idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
                    )
        return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
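
    # Worked example of the fusion step (numbers invented for illustration): with
    # semantic similarities [0.80, 0.50] and BM25 scores [12.0, 3.0] over the same
    # two candidates, min-max normalization maps both lists to [1.0, 0.0], so the
    # fused scores are 1.0*0.6 + 1.0*0.4 = 1.0 and 0.0. With three candidates,
    # e.g. sem [0.80, 0.65, 0.50] -> [1.0, 0.5, 0.0], the middle chunk keeps a
    # proportional weight instead of collapsing to 0 or 1.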
    def _expand_query_with_llm(self, query: str, intents: tuple) -> str:
        if not intents:
            return query
        prompt = PROMPT_TEMPLATES["expand_query"].format(intents=list(intents), query=query)
        try:
            expanded_query = self._llm_call([{"role": "user", "content": prompt}])
            if expanded_query and expanded_query.strip():
                log.info(f"查詢擴展成功。原始: '{query}', 擴展後: '{expanded_query}'")
                return expanded_query
            log.warning(f"查詢擴展回傳空內容。原始查詢: '{query}'。將使用原始查詢。")
            return query
        except Exception as e:
            log.error(f"查詢擴展失敗: {e}。原始查詢: '{query}'。將使用原始查詢。")
            return query
    def _rerank_with_crossencoder(self, query: str, candidates: List[FusedCandidate]) -> List[RerankResult]:
        if not candidates:
            return []
        top_candidates = candidates[:MAX_RERANK_CANDIDATES]
        pairs = [(query, self.state.sentences[c.idx]) for c in top_candidates]
        scores = self.reranker.predict(pairs, show_progress_bar=False)
        results = [
            RerankResult(idx=c.idx, rerank_score=float(score), text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
            for c, score in zip(top_candidates, scores)
        ]
        return sorted(results, key=lambda x: x.rerank_score, reverse=True)[:TOP_K_SENTENCES]
    def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
        if "副作用/異常 (Side Effects / Issues)" not in intents:
            return results
        # For side-effect queries, surface warning sections first and push the
        # adverse-reaction listings to the end of the context window.
        warnings_and_notes = [res for res in results if res.meta.get("section", "").startswith("警語與注意事項")]
        adverse_reactions = [res for res in results if res.meta.get("section", "").startswith("不良反應")]
        other_results = [res for res in results if res not in warnings_and_notes and res not in adverse_reactions]
        return warnings_and_notes + other_results + adverse_reactions
    def _build_context(self, reranked_results: List[RerankResult]) -> str:
        context = ""
        for res in reranked_results:
            if len(context) + len(res.text) > LLM_MODEL_CONFIG["max_context_chars"]:
                break
            context += res.text + "\n\n"
        return context.strip()
    def _make_final_prompt(self, query: str, context: str, intents: List[str]) -> str:
        add_instr = ""
        if any(i in intents for i in ["劑量調整 (Dosage Adjustment)", "時間/併用 (Timing & Interaction)"]):
            add_instr = "在回答用藥劑量和時間時,務必提醒使用者,醫師開立的藥袋醫囑優先於仿單的一般建議。"
        if "保存/攜帶 (Storage & Handling)" in intents:
            add_instr += "在回答保存與攜帶問題時,除了仿單內容,請根據常識加入實際情境的提醒,例如提醒需冷藏藥品要用保冷袋攜帶。"
        return PROMPT_TEMPLATES["final_answer"].format(
            additional_instruction=add_instr, context=context, query=query
        )
    def _safe_json_parse(self, s: str, default: Any = None) -> Any:
        try:
            return json.loads(s)
        except json.JSONDecodeError:
            log.warning(f"無法解析完整 JSON。嘗試從字串中提取: {s[:200]}...")
            # Greedy match from the first '{' to the last '}' so that nested
            # objects are captured whole.
            m = re.search(r'\{.*\}', s, re.DOTALL)
            if m:
                try:
                    return json.loads(m.group(0))
                except json.JSONDecodeError:
                    log.warning(f"提取的 JSON 仍無法解析: {m.group(0)[:100]}...")
        return default
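
    # Illustrative recovery (inputs invented for this example): a chatty reply like
    #   '好的,以下是分析結果:{"sub_queries": ["如何使用"], "intents": ["操作 (Administration)"]}'
    # fails json.loads as a whole, but the brace-extraction fallback recovers the
    # embedded object and returns the parsed dict.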
# ---------- FastAPI events and routes ----------
class AppConfig:
    CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
    CHANNEL_SECRET = _require_env("CHANNEL_SECRET")

rag_pipeline: Optional[RagPipeline] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    _require_llm_config()
    global rag_pipeline
    rag_pipeline = RagPipeline()
    rag_pipeline.load_data()
    log.info("啟動完成,服務準備就緒。")
    yield
    log.info("服務關閉中。")

app = FastAPI(lifespan=lifespan)
@app.post("/webhook")  # Route path is an assumption; point the LINE webhook URL here.
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
    signature = request.headers.get("X-Line-Signature")
    if not signature:
        raise HTTPException(status_code=400, detail="Missing X-Line-Signature")
    if not AppConfig.CHANNEL_SECRET:
        log.error("CHANNEL_SECRET is not configured.")
        raise HTTPException(status_code=500, detail="Server configuration error")
    body = await request.body()
    try:
        mac = hmac.new(AppConfig.CHANNEL_SECRET.encode('utf-8'), body, hashlib.sha256)
        expected_signature = base64.b64encode(mac.digest()).decode('utf-8')
    except Exception as e:
        log.error(f"Failed to generate signature: {e}")
        raise HTTPException(status_code=500, detail="Signature generation error")
    if not hmac.compare_digest(expected_signature, signature):
        raise HTTPException(status_code=403, detail="Invalid signature")
    try:
        data = json.loads(body.decode('utf-8'))
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    for event in data.get("events", []):
        if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
            reply_token = event.get("replyToken")
            user_text = event.get("message", {}).get("text", "").strip()
            source = event.get("source", {})
            stype = source.get("type")
            target_id = source.get("userId") or source.get("groupId") or source.get("roomId")
            if reply_token and user_text and target_id:
                # Acknowledge immediately, then answer via push from a background task.
                line_reply(reply_token, "收到您的問題,正在查詢資料庫,請稍候...")
                background_tasks.add_task(process_user_query, stype, target_id, user_text)
    return Response(status_code=status.HTTP_200_OK)
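
# For local testing only (a sketch; `make_test_signature` is an illustrative name,
# not part of the LINE SDK): LINE signs the raw request body with HMAC-SHA256 keyed
# by the channel secret, so a valid X-Line-Signature header for curl can be
# reproduced with the same computation handle_webhook performs above.
def make_test_signature(channel_secret: str, body: bytes) -> str:
    """Return the base64 HMAC-SHA256 digest that handle_webhook expects."""
    digest = hmac.new(channel_secret.encode("utf-8"), body, hashlib.sha256).digest()
    return base64.b64encode(digest).decode("utf-8")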
def process_user_query(source_type: str, target_id: str, user_text: str):
    try:
        if rag_pipeline:
            answer = rag_pipeline.answer_question(user_text)
        else:
            answer = "系統正在啟動中,請稍後再試。"
        line_push_generic(source_type, target_id, answer)
    except Exception as e:
        log.error(f"背景處理 target_id={target_id} 發生錯誤: {e}", exc_info=True)
        line_push_generic(source_type, target_id, f"抱歉,處理時發生未預期的錯誤。{DISCLAIMER}")
def line_api_call(endpoint: str, data: Dict):
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}"
    }
    try:
        response = requests.post(f"https://api.line.me/v2/bot/message/{endpoint}", headers=headers, json=data, timeout=10)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Compare against None: a requests.Response is falsy for 4xx/5xx statuses,
        # which are exactly the cases where the body should be logged.
        body_text = e.response.text if e.response is not None else "N/A"
        log.error(f"LINE API ({endpoint}) 呼叫失敗: {e} | Response: {body_text}")
        raise
def line_reply(reply_token: str, text: str):
    messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
    line_api_call("reply", {"replyToken": reply_token, "messages": messages})

def line_push_generic(source_type: str, target_id: str, text: str):
    messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
    line_api_call("push", {"to": target_id, "messages": messages})
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
    candidates = set()
    q_norm = _norm(query)
    for word in re.findall(r"[a-z0-9]+", q_norm):
        if word in drug_vocab["en"]:
            candidates.add(word)
    for token in jieba.cut(q_norm):
        if token in drug_vocab["zh"]:
            candidates.add(token)
    return list(candidates)
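
# ---------- Offline index build (illustrative sketch) ----------
# load_data() implies a specific artifact layout: a FAISS index over sentence
# embeddings, a pickle holding {"sentences": [...], "meta": [...]}, and a pickle
# holding {"bm25": BM25Okapi}. A minimal build script might look like the sketch
# below, assuming each CSV row carries its chunk text in a "sentence" column and a
# "section" label (both column names are assumptions — adjust to the real schema).
# It is defined here for reference only and is never invoked by the service.
def build_indexes_sketch(csv_path: str = CSV_PATH):
    df = pd.read_csv(csv_path, dtype=str).fillna("")
    sentences = df["sentence"].tolist()  # assumed column name
    meta = df[["drug_id", "section"]].to_dict("records")  # keys read by the pipeline
    model = SentenceTransformer(EMBEDDING_MODEL)
    emb = model.encode(sentences, convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(emb)  # normalized vectors + inner product == cosine similarity
    index = faiss.IndexFlatIP(emb.shape[1])
    index.add(emb)
    faiss.write_index(index, FAISS_INDEX)
    with open(SENTENCES_PKL, "wb") as f:
        pickle.dump({"sentences": sentences, "meta": meta}, f)
    bm25 = BM25Okapi([list(jieba.cut(s)) for s in sentences])
    with open(BM25_PKL, "wb") as f:
        pickle.dump({"bm25": bm25}, f)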
# ---------- Entry point ----------
if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)