#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DrugQA (ZH) — 優化版 FastAPI LINE Webhook (最終版)
整合 RAG 邏輯,包含 LLM 意圖偵測、子查詢分解、Intent-aware 檢索與 Rerank。
此版本專注於效能、可維護性、健壯性與使用者體驗。
"""
# ---------- 環境與快取設定 (應置於最前) ----------
import os
import pathlib
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
for d in (os.getenv("HF_HOME"), os.getenv("SENTENCE_TRANSFORMERS_HOME"), os.getenv("XDG_CACHE_HOME")):
pathlib.Path(d).mkdir(parents=True, exist_ok=True)
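# Note: these cache variables must be set before transformers/sentence_transformers are
# imported, and /tmp is used because hosted containers (e.g. Hugging Face Spaces) often
# mount a read-only home directory.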
# ---------- Python 標準函式庫 ----------
import re
import hmac
import base64
import hashlib
import pickle
import logging
import json
import textwrap
import time
from typing import List, Dict, Any, Optional, Tuple
from functools import lru_cache
from dataclasses import dataclass, field
from contextlib import asynccontextmanager
import unicodedata
from collections import defaultdict
from types import SimpleNamespace
# ---------- 第三方函式庫 ----------
import numpy as np
import pandas as pd
from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
import uvicorn
import jieba
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import faiss
import torch
from openai import OpenAI
import tenacity
from tenacity import retry, stop_after_attempt, wait_fixed
import requests
from transformers import pipeline
# [MODIFIED] 限制 PyTorch 執行緒數量,避免在 CPU 環境下過度佔用資源
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
# ==== CONFIG (從環境變數載入,或使用預設值) ====
# [MODIFIED] 新增環境變數健檢函式
def _require_env(var: str) -> str:
v = os.getenv(var)
if not v:
raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
return v
# [MODIFIED] 檢查 LLM 相關環境變數
def _require_llm_config():
for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
_require_env(k)
# MedGemma 模型直接硬編碼,如果需要替換,可以在這裡加入檢查
# _require_env("MEDGEMMA_MODEL_NAME") # 如果 MEDGEMMA_MODEL_NAME 是環境變數
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 20))
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
LLM_API_CONFIG = {
"base_url": os.getenv("LITELLM_BASE_URL"),
"api_key": os.getenv("LITELLM_API_KEY"),
"model": os.getenv("LM_MODEL")
}
MEDGEMMA_MODEL_NAME = "google/medgemma-4b-it" # 硬編碼 MedGemma 模型名稱
LLM_MODEL_CONFIG = {
"max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)),
"max_tokens": int(os.getenv("MAX_TOKENS", 1024)),
"temperature": float(os.getenv("TEMPERATURE", 0.0)),
"seed": int(os.getenv("LLM_SEED", 42)), # 新增 seed 以確保重現性
}
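# Example environment setup (illustrative values only):
#   export LITELLM_BASE_URL="https://litellm.example.com/v1"
#   export LITELLM_API_KEY="sk-..."
#   export LM_MODEL="gpt-4o-mini"
#   export MAX_CONTEXT_CHARS=10000 MAX_TOKENS=1024 TEMPERATURE=0.0 LLM_SEED=42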
INTENT_CATEGORIES = [
"操作 (Administration)", "保存/攜帶 (Storage & Handling)", "副作用/異常 (Side Effects / Issues)",
"劑型相關 (Dosage Form Concerns)", "時間/併用 (Timing & Interaction)", "劑量調整 (Dosage Adjustment)",
"禁忌症/適應症 (Contraindications/Indications)"
]
# [新增] 意圖分類 → CSV section 對照表
INTENT_TO_SECTION = {
"操作 (Administration)": ["用法用量", "病人使用須知"],
"保存/攜帶 (Storage & Handling)": ["包裝及儲存"],
"副作用/異常 (Side Effects / Issues)": ["不良反應", "警語與注意事項"],
"劑型相關 (Dosage Form Concerns)": ["劑型", "藥品外觀"],
"時間/併用 (Timing & Interaction)": ["用法用量"],
"劑量調整 (Dosage Adjustment)": ["用法用量"],
"禁忌症/適應症 (Contraindications/Indications)": ["適應症", "禁忌", "警語與注意事項"]
}
# 新增 REFERENCE_MAPPING
REFERENCE_MAPPING = {
"如何用藥?": "病人使用須知、用法用量",
"如何保存與攜帶?": "包裝及儲存",
"可能的副作用?": "警語與注意事項、不良反應",
"每次劑量多少?": "用法用量、藥袋上的醫囑",
"用藥時間?": "用法用量、藥袋上的醫囑",
}
# 新增反向映射,從 sections 找到 intents
SECTION_TO_INTENT = defaultdict(list)
for intent, sections in INTENT_TO_SECTION.items():
for section in sections:
SECTION_TO_INTENT[section].append(intent)
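# e.g. after this loop:
#   SECTION_TO_INTENT["用法用量"] == ["操作 (Administration)",
#       "時間/併用 (Timing & Interaction)", "劑量調整 (Dosage Adjustment)"]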
DRUG_NAME_MAPPING = {
"fentanyl patch": "fentanyl", "spiriva respimat": "spiriva", "augmentin for syrup": "augmentin syrup",
"nitrostat": "nitroglycerin", "ozempic": "ozempic", "niflec": "niflec",
"fosamax": "fosamax", "humira": "humira", "premarin": "premarin", "smecta": "smecta",
}
DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
PROMPT_TEMPLATES = {
"analyze_query": """
請分析以下使用者問題,並完成以下兩個任務:
1. 將問題分解為1-3個核心的子問題。
2. 從清單中選擇所有相關的意圖分類。
請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intents' (字串陣列) 兩個鍵。
範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"]}}
意圖分類清單:
{options}
使用者問題:{query}
""",
"expand_query": """
請根據以下意圖:{intents},擴展這個查詢,加入相關同義詞或術語。
原始查詢:{query}
請僅輸出擴展後的查詢,不需任何額外的解釋或格式。
""",
"final_answer_concise": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
一、 回覆準則
嚴格依據資料: 所有回覆內容都必須完全來自提供的參考資料,禁止任何形式的捏造或引用外部資訊。
資料不足: 若參考資料無法回答使用者的問題,請直接回覆:「根據提供的資料,無法回答您的問題。」
專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。
精簡扼要: 內容需極其簡潔,資訊完整,不要使用*符號,字數請嚴格控制在60字以內。
二、 排版規範
條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。
結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」
範例:
使用者問題: 請問普拿疼可以怎麼吃?
參考資料: 普拿疼成人建議劑量為一次1至2錠,每4至6小時服用一次,每日不超過8錠。
AI回覆範例:
普拿疼成人劑量:
1-2錠/次
每4-6小時
每日≤8錠
如有不適請立即就醫。
{additional_instruction}
---
參考資料:
{context}
---
使用者問題:{query}
請直接輸出最終的答案:
""",
"final_answer_detailed": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
一、 回覆準則
嚴格依據資料: 所有回覆內容都必須完全來自提供的參考資料,禁止任何形式的捏造或引用外部資訊。
資料不足: 若參考資料無法回答使用者的問題,請直接回覆:「根據提供的資料,無法回答您的問題。」
專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。
精簡扼要: 內容需簡潔但資訊完整,不要使用*符號,字數請控制在200字以內,提供更多細節解釋。
二、 排版規範
條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。
結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」
範例:
使用者問題: 請問普拿疼可以怎麼吃?
參考資料: 普拿疼成人建議劑量為一次1至2錠,每4至6小時服用一次,每日不超過8錠。
AI回覆範例:
普拿疼成人建議劑量為:
- 一次服用1至2錠,視疼痛程度調整。
- 每4至6小時服用一次,避免過頻。
- 每日總量不超過8錠,以防副作用。
如有不適請立即就醫。
{additional_instruction}
---
參考資料:
{context}
---
使用者問題:{query}
請直接輸出最終的答案:
""",
"direct_answer_concise": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,直接基於您的知識給予回覆:
一、 回覆準則
專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。
精簡扼要: 內容需極其簡潔,資訊完整,不要使用*符號,字數請嚴格控制在60字以內。
二、 排版規範
條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。
結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」
範例:
使用者問題: 請問普拿疼可以怎麼吃?
AI回覆範例:
普拿疼成人劑量:
1-2錠/次
每4-6小時
每日≤8錠
如有不適請立即就醫。
{additional_instruction}
使用者問題:{query}
請直接輸出最終的答案:
""",
"direct_answer_detailed": """
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,直接基於您的知識給予回覆:
一、 回覆準則
專業且友善: 回覆需使用繁體中文,語氣專業、友善且親切,如同面對面衛教。
精簡扼要: 內容需簡潔但資訊完整,不要使用*符號,字數請控制在200字以內,提供更多細節解釋。
二、 排版規範
條列式呈現: 段落句點後要換行,確保排版在LINE對話框中清晰易讀。
結尾提醒: 所有回覆的最後,都必須加上這句指定的提醒語句:「如有不適請立即就醫。」
範例:
使用者問題: 請問普拿疼可以怎麼吃?
AI回覆範例:
普拿疼成人建議劑量為:
- 一次服用1至2錠,視疼痛程度調整。
- 每4至6小時服用一次,避免過頻。
- 每日總量不超過8錠,以防副作用。
如有不適請立即就醫。
{additional_instruction}
使用者問題:{query}
請直接輸出最終的答案:
"""
}
# ---------- 日誌設定 ----------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
# [新增] 統一字串正規化函式
def _norm(s: str) -> str:
"""統一化字串:NFKC 正規化、轉小寫、移除標點符號與空白。"""
s = unicodedata.normalize("NFKC", s)
return re.sub(r"[^\w\s]", "", s.lower()).strip()
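# Illustrative behaviour: _norm("Augmentin(安滅菌)糖漿!") -> "augmentin安滅菌糖漿"
# (NFKC folds full-width punctuation, the character class then strips it; \w keeps CJK).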
@dataclass
class FusedCandidate:
idx: int
fused_score: float
sem_score: float
bm_score: float
@dataclass
class RerankResult:
idx: int
rerank_score: float
text: str
meta: Dict[str, Any] = field(default_factory=dict)
# ---------- 核心 RAG 邏輯 ----------
class RagPipeline:
def __init__(self):
if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
raise ValueError("LLM API Key or Base URL is not configured.")
# OpenAI client for LITELLM
self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
self.litellm_model_name = LLM_API_CONFIG["model"]
# MedGemma pipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
log.info(f"載入 MedGemma 模型: {MEDGEMMA_MODEL_NAME} 至 {device}...")
try:
self.medgemma_pipe = pipeline(
"text-generation",
model=MEDGEMMA_MODEL_NAME,
torch_dtype=torch.bfloat16,
device=device,
)
log.info("MedGemma 模型載入成功。")
except Exception as e:
log.error(f"載入 MedGemma 模型失敗: {e}")
raise RuntimeError(f"MedGemma 模型載入失敗: {MEDGEMMA_MODEL_NAME}")
self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding")
self.drug_name_to_ids: Dict[str, List[str]] = {}
self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
        self.state = SimpleNamespace()  # lightweight container for index, sentences, meta, bm25
# [新增] 澄清問題的 Prompt Template
self.CLARIFICATION_PROMPT = """
請根據以下使用者問題,生成一個簡潔、禮貌的澄清性提問,以幫助我更精確地回答。問題應引導使用者提供更多細節,例如具體藥名、使用情境等。
範例:
使用者問題:這個藥會怎麼樣?
澄清提問:您好,請問您指的是哪一種藥物呢?
使用者問題:請問這要吃多久?
澄清提問:請問您是想了解該藥品的建議療程長度嗎?
使用者問題:{query}
澄清提問:"""
def _load_model(self, model_class, model_name: str, model_type: str):
device = "cuda" if torch.cuda.is_available() else "cpu"
log.info(f"載入 {model_type} 模型:{model_name} 至 {device}...")
try:
return model_class(model_name, device=device)
except Exception as e:
log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
try:
return model_class(model_name, device="cpu")
except Exception as e_cpu:
log.error(f"切換至 CPU 仍無法載入模型: {model_name}。請確認模型路徑或網路連線。錯誤訊息: {e_cpu}")
raise RuntimeError(f"模型載入失敗: {model_name}")
def load_data(self):
log.info("開始載入資料與模型...")
for path in [CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL]:
if not pathlib.Path(path).exists():
raise FileNotFoundError(f"必要的資料檔案不存在: {path}")
try:
self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('')
            for col in ("drug_name_norm", "drug_id", "drug_name_zh", "drug_name_en"):
if col not in self.df_csv.columns:
raise KeyError(f"CSV 檔案 '{CSV_PATH}' 中缺少必要欄位: {col}")
self.drug_name_to_ids = self._build_drug_name_to_ids()
self._load_drug_name_vocabulary()
log.info("載入 FAISS 索引與句子資料...")
self.state.index = faiss.read_index(FAISS_INDEX)
self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
if hasattr(self.state.index, "nprobe"):
self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
log.info("FAISS 索引使用內積 (IP) 指標,檢索時將自動進行 L2 正規化以實現餘弦相似度。")
with open(SENTENCES_PKL, "rb") as f:
data = pickle.load(f)
self.state.sentences = data["sentences"]
self.state.meta = data["meta"]
log.info("載入 BM25 索引...")
with open(BM25_PKL, "rb") as f:
bm25_data = pickle.load(f)
self.state.bm25 = bm25_data["bm25"]
if not isinstance(self.state.bm25, BM25Okapi):
raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
except (FileNotFoundError, KeyError) as e:
log.exception(f"資料或索引檔案載入失敗: {e}")
raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}")
log.info("所有模型與資料載入完成。")
# [新增] 將 drug_id 轉換為使用者友善的 drug_name_norm
@lru_cache(maxsize=128)
def _get_drug_name_by_id(self, drug_id: str) -> Optional[str]:
"""從 drug_id 查找對應的 drug_name_norm。"""
row = self.df_csv[self.df_csv['drug_id'] == drug_id]
if not row.empty:
return row.iloc[0]['drug_name_norm']
return None
def _find_drug_ids_from_name(self, query: str) -> List[str]:
q_norm_parts = set(re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', _norm(query)))
drug_ids = set()
for part in q_norm_parts:
if part in self.drug_name_to_ids:
drug_ids.update(self.drug_name_to_ids[part])
# [新增] 新增 drug_name_norm 至 drug_id 的反向查找,以支持上下文處理
for drug_name, ids in self.drug_name_to_ids.items():
if drug_name in _norm(query):
drug_ids.update(ids)
return sorted(list(drug_ids))
def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
mapping = {}
for _, row in self.df_csv.iterrows():
drug_id = row['drug_id']
zh_parts = list(jieba.cut(row['drug_name_zh']))
en_parts = re.findall(r'[a-zA-Z0-9]+', row['drug_name_en'].lower() if row['drug_name_en'] else '')
norm_parts = re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', _norm(row['drug_name_norm']))
all_parts = set(zh_parts + en_parts + norm_parts)
for part in all_parts:
part = part.strip()
if part and len(part) > 1:
mapping.setdefault(part, []).append(drug_id)
for alias, canonical_name in DRUG_NAME_MAPPING.items():
if _norm(canonical_name) in _norm(row['drug_name_norm']):
mapping.setdefault(_norm(alias), []).append(drug_id)
for key in mapping:
mapping[key] = sorted(list(set(mapping[key])))
return mapping
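    # Illustrative result of the mapping built above (hypothetical IDs): every zh/en/normalized
    # token longer than one character points at the drug_ids that contain it, e.g.
    # {"fentanyl": ["D001"], "貼片": ["D001"], ...}.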
def _load_drug_name_vocabulary(self):
log.info("建立藥名詞庫...")
for _, row in self.df_csv.iterrows():
norm_name = row['drug_name_norm']
words = list(re.findall(r'[a-z0-9]+|[\u4e00-\u9fff]+', norm_name))
for word in words:
if re.search(r'[\u4e00-\u9fff]', word):
self.drug_vocab["zh"].add(word)
else:
self.drug_vocab["en"].add(word)
for alias in DRUG_NAME_MAPPING:
if re.search(r'[\u4e00-\u9fff]', alias):
self.drug_vocab["zh"].add(alias)
else:
self.drug_vocab["en"].add(alias)
for word in self.drug_vocab["zh"]:
try:
if word not in jieba.dt.FREQ:
jieba.add_word(word, freq=2_000_000)
except Exception:
pass
# [MODIFIED] 兩個獨立的 LLM 調用函式,用於輸出比較
    # Retry on any exception: the original predicate only retried ValueError, which this
    # function never raises, so transient API errors were effectively never retried.
    @tenacity.retry(
        wait=tenacity.wait_fixed(2),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(Exception),
        before_sleep=tenacity.before_sleep_log(log, logging.WARNING),
        after=tenacity.after_log(log, logging.INFO)
    )
def _litellm_call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None, temperature: Optional[float] = None, seed: Optional[int] = None) -> str:
"""安全地呼叫 LITELLM API,並處理可能的回應內容為空錯誤。"""
log.info(f"LITELLM 呼叫開始. 模型: {self.litellm_model_name}, max_tokens: {max_tokens}, temperature: {temperature}, seed: {seed}")
start_time = time.time()
try:
response = self.llm_client.chat.completions.create(
model=self.litellm_model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
seed=seed,
)
end_time = time.time()
log.info(f"LITELLM 收到完整回應: {response.model_dump_json(indent=2)}")
if not response.choices or not response.choices[0].message.content:
log.warning("LITELLM 呼叫成功 (200 OK),但回傳內容為空。將回傳空字串。")
return ""
content = response.choices[0].message.content
log.info(f"LITELLM 呼叫完成,耗時: {end_time - start_time:.2f} 秒。內容長度: {len(content)} 字。")
return content
except Exception as e:
log.error(f"LITELLM API 呼叫失敗: {e}")
raise
    # Same retry policy as _litellm_call.
    @tenacity.retry(
        wait=tenacity.wait_fixed(2),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(Exception),
        before_sleep=tenacity.before_sleep_log(log, logging.WARNING),
        after=tenacity.after_log(log, logging.INFO)
    )
def _medgemma_call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None, temperature: Optional[float] = None, seed: Optional[int] = None) -> str:
"""安全地呼叫 MedGemma 模型,並處理可能的回應內容為空錯誤。"""
log.info(f"MedGemma 呼叫開始. max_tokens: {max_tokens}, temperature: {temperature}, seed: {seed}")
start_time = time.time()
try:
# MedGemma pipeline requires a specific format for messages
converted_messages = []
for msg in messages:
role = msg["role"]
content = [{"type": "text", "text": msg["content"]}]
converted_messages.append({"role": role, "content": content})
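            # e.g. {"role": "user", "content": "如何保存?"} becomes
            #      {"role": "user", "content": [{"type": "text", "text": "如何保存?"}]}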
output = self.medgemma_pipe(
converted_messages,
max_new_tokens=max_tokens or LLM_MODEL_CONFIG["max_tokens"],
temperature=temperature if temperature is not None else LLM_MODEL_CONFIG["temperature"],
# MedGemma pipeline 可能不直接支持seed,如果不支持,可移除或處理
)
end_time = time.time()
if not output or not output[0]["generated_text"] or not output[0]["generated_text"][-1]["content"]:
log.warning("MedGemma 呼叫成功,但回傳內容為空。將回傳空字串。")
return ""
content = output[0]["generated_text"][-1]["content"]
log.info(f"MedGemma 呼叫完成,耗時: {end_time - start_time:.2f} 秒。內容長度: {len(content)} 字。")
return content
except Exception as e:
log.error(f"MedGemma 呼叫失敗: {e}")
raise
# [MODIFIED] 修改 answer_question 函式以回傳四種答案,每種包括簡潔版和詳細版
    def answer_question(self, q_orig: str) -> Tuple[Dict[str, Any], List[str]]:
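        """Return (answers, drug_ids).

        `answers` is one of:
          - four keyed answer sets (LITELLM/MedGemma × with/without RAG), each produced by
            the generation helpers;
          - a clarification pair ("clarification_litellm" / "clarification_medgemma");
          - {"error": ...} when no context is found or processing fails.
        """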
start_time = time.time()
log.info(f"===== 處理新查詢: '{q_orig}' =====")
try:
drug_ids = self._find_drug_ids_from_name(q_orig)
if not drug_ids:
log.info("未從查詢中找到相關藥名,透過兩種 LLM 生成澄清性問題。")
clarification_litellm = self._generate_clarification_query_litellm(q_orig)
clarification_medgemma = self._generate_clarification_query_medgemma(q_orig)
log.info(f"澄清問題比較: LITELLM: '{clarification_litellm}', MedGemma: '{clarification_medgemma}'")
return {"clarification_litellm": clarification_litellm + f"\n{DISCLAIMER}", "clarification_medgemma": clarification_medgemma + f"\n{DISCLAIMER}"}, []
log.info(f"步驟 1/4: 找到藥品 ID: {drug_ids},耗時: {time.time() - start_time:.2f} 秒")
step_start = time.time()
# 分開處理兩種 LLM 的分析
analysis_litellm = self._analyze_query_litellm(q_orig)
analysis_medgemma = self._analyze_query_medgemma(q_orig)
sub_queries_litellm, intents_litellm = analysis_litellm.get("sub_queries", [q_orig]), analysis_litellm.get("intents", [])
sub_queries_medgemma, intents_medgemma = analysis_medgemma.get("sub_queries", [q_orig]), analysis_medgemma.get("intents", [])
log.info(f"意圖分析比較: LITELLM: {analysis_litellm}, MedGemma: {analysis_medgemma}")
if not intents_litellm and not intents_medgemma:
log.info("意圖分析失敗,透過兩種 LLM 生成澄清性問題。")
clarification_litellm = self._generate_clarification_query_litellm(q_orig)
clarification_medgemma = self._generate_clarification_query_medgemma(q_orig)
log.info(f"澄清問題比較: LITELLM: '{clarification_litellm}', MedGemma: '{clarification_medgemma}'")
return {"clarification_litellm": clarification_litellm + f"\n{DISCLAIMER}", "clarification_medgemma": clarification_medgemma + f"\n{DISCLAIMER}"}, drug_ids
log.info(f"步驟 2/4: 意圖分析完成。LITELLM 子問題: {sub_queries_litellm}, 意圖: {intents_litellm}。MedGemma 子問題: {sub_queries_medgemma}, 意圖: {intents_medgemma}。耗時: {time.time() - step_start:.2f} 秒")
step_start = time.time()
# 分開處理兩種 LLM 的檢索流程
all_candidates_litellm = self._retrieve_candidates_for_all_queries_litellm(drug_ids, sub_queries_litellm, intents_litellm)
all_candidates_medgemma = self._retrieve_candidates_for_all_queries_medgemma(drug_ids, sub_queries_medgemma, intents_medgemma)
log.info(f"步驟 3/4: 檢索完成。LITELLM 找到 {len(all_candidates_litellm)} 個, MedGemma 找到 {len(all_candidates_medgemma)} 個。耗時: {time.time() - step_start:.2f} 秒")
step_start = time.time()
# LITELLM RAG
final_candidates_litellm = all_candidates_litellm[:TOP_K_SENTENCES]
reranked_results_litellm = [
RerankResult(idx=c.idx, rerank_score=c.fused_score, text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
for c in final_candidates_litellm
]
prioritized_results_litellm = self._prioritize_context(reranked_results_litellm, intents_litellm)
context_litellm = self._build_context(prioritized_results_litellm)
# MedGemma RAG
final_candidates_medgemma = all_candidates_medgemma[:TOP_K_SENTENCES]
reranked_results_medgemma = [
RerankResult(idx=c.idx, rerank_score=c.fused_score, text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
for c in final_candidates_medgemma
]
prioritized_results_medgemma = self._prioritize_context(reranked_results_medgemma, intents_medgemma)
context_medgemma = self._build_context(prioritized_results_medgemma)
log.info(f"步驟 4/4: 最終選出 LITELLM {len(reranked_results_litellm)} 個, MedGemma {len(reranked_results_medgemma)} 個。耗時: {time.time() - step_start:.2f} 秒")
step_start = time.time()
if not context_litellm and not context_medgemma:
log.info("沒有足夠的上下文來回答問題。")
return {"error": f"根據提供的資料,無法回答您的問題。\n{DISCLAIMER}"}, drug_ids
# 通用參數:temperature=0.0, seed=42 以確保重現性
temp = LLM_MODEL_CONFIG["temperature"]
seed = LLM_MODEL_CONFIG["seed"]
max_tokens_concise = 256 # 簡潔版限制較小
max_tokens_detailed = LLM_MODEL_CONFIG["max_tokens"] # 詳細版使用預設
# 生成答案 - With RAG
answers_with_rag = self._generate_answers_with_rag(q_orig, context_litellm, intents_litellm, context_medgemma, intents_medgemma, temp, seed, max_tokens_concise, max_tokens_detailed)
# 生成答案 - Without RAG (直接用 query)
answers_without_rag = self._generate_answers_without_rag(q_orig, intents_litellm, intents_medgemma, temp, seed, max_tokens_concise, max_tokens_detailed)
log.info(f"答案生成完成。耗時: {time.time() - step_start:.2f} 秒")
log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
# 合併返回
return {
"LITELLM_with_RAG": answers_with_rag["LITELLM"],
"MedGemma_with_RAG": answers_with_rag["MedGemma"],
"LITELLM_without_RAG": answers_without_rag["LITELLM"],
"MedGemma_without_RAG": answers_without_rag["MedGemma"]
}, drug_ids
except Exception as e:
log.error(f"處理查詢 '{q_orig}' 時發生嚴重錯誤: {e}", exc_info=True)
return {"error": f"處理您的問題時發生內部錯誤,請稍後再試。\n{DISCLAIMER}"}, []
def _analyze_query_litellm(self, query: str) -> Dict[str, Any]:
return self._analyze_query_generic(query, self._litellm_call)
def _analyze_query_medgemma(self, query: str) -> Dict[str, Any]:
return self._analyze_query_generic(query, self._medgemma_call)
def _analyze_query_generic(self, query: str, llm_call) -> Dict[str, Any]:
# 先檢查 REFERENCE_MAPPING
norm_query = _norm(query)
for ref_key, ref_value in REFERENCE_MAPPING.items():
if _norm(ref_key) in norm_query:
sections = [s.strip() for s in ref_value.split(',')]
intents = []
for section in sections:
intents.extend(SECTION_TO_INTENT.get(section, []))
intents = list(set(intents)) # 去重
if intents:
log.info(f"匹配 REFERENCE_MAPPING: '{ref_key}' -> intents: {intents}")
return {"sub_queries": [query], "intents": intents}
# 如果不匹配,才進行 LLM 意圖偵測
prompt = PROMPT_TEMPLATES["analyze_query"].format(
options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
query=query
)
response_str = llm_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"])
return self._safe_json_parse(response_str, default={"sub_queries": [query], "intents": []})
def _generate_clarification_query_litellm(self, query: str) -> str:
prompt = self.CLARIFICATION_PROMPT.format(query=query)
return self._litellm_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"]).strip()
def _generate_clarification_query_medgemma(self, query: str) -> str:
prompt = self.CLARIFICATION_PROMPT.format(query=query)
return self._medgemma_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"]).strip()
def _retrieve_candidates_for_all_queries_litellm(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]:
return self._retrieve_candidates_for_all_queries_generic(drug_ids, sub_queries, intents, self._expand_query_with_llm_litellm)
def _retrieve_candidates_for_all_queries_medgemma(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]:
return self._retrieve_candidates_for_all_queries_generic(drug_ids, sub_queries, intents, self._expand_query_with_llm_medgemma)
def _retrieve_candidates_for_all_queries_generic(self, drug_ids: List[str], sub_queries: List[str], intents: List[str], expand_func) -> List[FusedCandidate]:
drug_ids_set = set(map(str, drug_ids))
if drug_ids_set:
relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id", "")) in drug_ids_set}
else:
relevant_indices = set(range(len(self.state.meta)))
if not relevant_indices: return []
all_fused_candidates: Dict[int, FusedCandidate] = {}
for sub_q in sub_queries:
expanded_q = expand_func(sub_q, tuple(intents))
q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
faiss.normalize_L2(q_emb)
distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
tokenized_query = list(jieba.cut(expanded_q))
bm25_scores = self.state.bm25.get_scores(tokenized_query)
rel_idx = np.fromiter(relevant_indices, dtype=int)
rel_scores = bm25_scores[rel_idx]
top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
doc_to_bm25_score = {int(i): float(bm25_scores[i]) for i in top_rel}
candidate_scores: Dict[int, Dict[str, float]] = {}
def to_similarity(d: float) -> float:
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
return float(d)
else:
return 1.0 / (1.0 + float(d))
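            # IP index: the score is already a cosine similarity (query embeddings are
            # L2-normalized above). L2 index: map distance d to a bounded similarity,
            # e.g. d=0 -> 1.0, d=1 -> 0.5.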
for i, dist in zip(sim_indices[0], distances[0]):
if i in relevant_indices:
similarity = to_similarity(dist)
candidate_scores[int(i)] = {"sem": float(similarity), "bm": 0.0}
for i, score in doc_to_bm25_score.items():
if i in relevant_indices:
candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
if not candidate_scores: continue
keys = list(candidate_scores.keys())
sem_scores = np.array([candidate_scores[k]['sem'] for k in keys])
bm_scores = np.array([candidate_scores[k]['bm'] for k in keys])
def norm(x):
rng = x.max() - x.min()
return (x - x.min()) / (rng + 1e-8) if rng > 0 else np.zeros_like(x)
sem_n, bm_n = norm(sem_scores), norm(bm_scores)
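            # Late fusion of the two retrievers after min-max scaling:
            # fused = 0.6 * semantic + 0.4 * BM25, e.g. sem_n=0.9, bm_n=0.5 -> 0.74.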
for idx, k in enumerate(keys):
fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
all_fused_candidates[k] = FusedCandidate(
idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
)
return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
    def _expand_query_with_llm_litellm(self, query: str, intents: tuple) -> str:
        return self._expand_query_generic(query, intents, self._litellm_call, "LITELLM")
    def _expand_query_with_llm_medgemma(self, query: str, intents: tuple) -> str:
        return self._expand_query_generic(query, intents, self._medgemma_call, "MedGemma")
    def _expand_query_generic(self, query: str, intents: tuple, llm_call, label: str) -> str:
        """Shared expansion logic; the two public wrappers only select the backend."""
        if not intents:
            return query
        prompt = PROMPT_TEMPLATES["expand_query"].format(intents=list(intents), query=query)
        try:
            expanded = llm_call([{"role": "user", "content": prompt}], temperature=LLM_MODEL_CONFIG["temperature"], seed=LLM_MODEL_CONFIG["seed"])
            log.info(f"{label} 查詢擴展成功。原始: '{query}', 擴展後: '{expanded}'")
            return expanded.strip() if expanded.strip() else query
        except Exception as e:
            log.error(f"{label} 查詢擴展失敗: {e}。原始查詢: '{query}'。將使用原始查詢。")
            return query
def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
if not intents:
return results
prioritized_sections = set()
for intent in intents:
prioritized_sections.update(INTENT_TO_SECTION.get(intent, []))
if not prioritized_sections:
return results
log.info(f"根據意圖 '{intents}' 優先處理章節: {prioritized_sections}")
prioritized_results = []
other_results = []
for res in results:
section = res.meta.get("section", "")
if section in prioritized_sections:
prioritized_results.append(res)
else:
other_results.append(res)
final_prioritized_list = prioritized_results + other_results
return final_prioritized_list
def _build_context(self, reranked_results: List[RerankResult]) -> str:
context = ""
for res in reranked_results:
if len(context) + len(res.text) > LLM_MODEL_CONFIG["max_context_chars"]: break
context += res.text + "\n\n"
return context.strip()
def _safe_json_parse(self, s: str, default: Any = None) -> Any:
try:
return json.loads(s)
except json.JSONDecodeError:
log.warning(f"無法解析完整 JSON。嘗試從字串中提取: {s[:200]}...")
m = re.search(r'\{.*?\}', s, re.DOTALL)
if m:
try:
return json.loads(m.group(0))
except json.JSONDecodeError:
log.warning(f"提取的 JSON 仍無法解析: {m.group(0)[:100]}...")
return default
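    # Illustrative recovery: for '```json\n{"sub_queries": ["q"], "intents": []}\n```' the regex
    # fallback extracts the inner {...}; the non-greedy match assumes no nested braces, which
    # holds for the analyze_query schema.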
# ---------- FastAPI 事件與路由 ----------
class AppConfig:
CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
CHANNEL_SECRET = _require_env("CHANNEL_SECRET")

# Module-level RAG pipeline; lifespan() below initializes it via `global rag_pipeline`.
# (It was previously a dead AppConfig attribute that the lifespan never wrote to.)
rag_pipeline: Optional[RagPipeline] = None
# [MODIFIED] 全域使用者狀態快取, 儲存更詳細的資訊
USER_STATE_CACHE = defaultdict(lambda: {"last_query": None, "last_drug_ids": []})
@asynccontextmanager
async def lifespan(app: FastAPI):
_require_llm_config()
global rag_pipeline
rag_pipeline = RagPipeline()
rag_pipeline.load_data()
log.info("啟動完成,服務準備就緒。")
yield
log.info("服務關閉中。")
app = FastAPI(lifespan=lifespan)
@app.post("/webhook")
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
signature = request.headers.get("X-Line-Signature")
if not signature:
raise HTTPException(status_code=400, detail="Missing X-Line-Signature")
if not AppConfig.CHANNEL_SECRET:
log.error("CHANNEL_SECRET is not configured.")
raise HTTPException(status_code=500, detail="Server configuration error")
body = await request.body()
try:
        mac = hmac.new(AppConfig.CHANNEL_SECRET.encode('utf-8'), body, hashlib.sha256)  # avoid shadowing built-in hash()
        expected_signature = base64.b64encode(mac.digest()).decode('utf-8')
except Exception as e:
log.error(f"Failed to generate signature: {e}")
raise HTTPException(status_code=500, detail="Signature generation error")
if not hmac.compare_digest(expected_signature, signature):
raise HTTPException(status_code=403, detail="Invalid signature")
try:
data = json.loads(body.decode('utf-8'))
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid JSON body")
for event in data.get("events", []):
if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
user_text = event.get("message", {}).get("text", "").strip()
source = event.get("source", {})
stype = source.get("type")
target_id = source.get("userId") or source.get("groupId") or source.get("roomId")
if user_text and target_id:
background_tasks.add_task(rag_pipeline.process_user_query, stype, target_id, user_text)
return Response(status_code=status.HTTP_200_OK)
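# For local testing, a validly signed request can be forged the same way LINE signs the raw
# body (hypothetical sketch):
#   import hmac, hashlib, base64, requests
#   body = b'{"events": []}'
#   sig = base64.b64encode(hmac.new(secret.encode(), body, hashlib.sha256).digest()).decode()
#   requests.post("http://localhost:7860/webhook", data=body, headers={"X-Line-Signature": sig})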
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def line_api_call(endpoint: str, data: Dict):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}"
}
try:
response = requests.post(f"https://api.line.me/v2/bot/message/{endpoint}", headers=headers, json=data, timeout=10)
response.raise_for_status()
except requests.exceptions.RequestException as e:
log.error(f"LINE API ({endpoint}) 呼叫失敗: {e} | Response: {e.response.text if e.response else 'N/A'}")
raise
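# LINE's Messaging API caps text messages at 5000 characters and a push at 5 message objects;
# chunking at 4800 leaves headroom, and only the first five chunks are sent.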
def line_push_generic(source_type: str, target_id: str, text: str):
messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
endpoint = "push"
data = {"to": target_id, "messages": messages}
line_api_call(endpoint, data)
# ---------- 執行 ----------
if __name__ == "__main__":
port = int(os.getenv("PORT", 7860))
uvicorn.run(app, host="0.0.0.0", port=port) |