# app.py
import os
import faiss
import numpy as np
import time
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pymongo import MongoClient
from google import genai
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from memory import MemoryManager
from translation import translate_query
from vlm import process_medical_image
# ✅ Enable Logging for Debugging
import logging
# —————— Silence Noisy Loggers ——————
for name in [
"uvicorn.error", "uvicorn.access",
"fastapi", "starlette",
"pymongo", "gridfs",
"sentence_transformers", "faiss",
"google", "google.auth",
]:
logging.getLogger(name).setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True) # Change INFO to DEBUG for full-ctx JSON loader
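# force=True (Python 3.8+) removes handlers that other libraries may have already attached to the root logger, so this format takes effect.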
logger = logging.getLogger("medical-chatbot")
logger.setLevel(logging.DEBUG)
# Debug Start
logger.info("🚀 Starting Medical Chatbot API...")
# ✅ Environment Variables
mongo_uri = os.getenv("MONGO_URI")
index_uri = os.getenv("INDEX_URI")
gemini_flash_api_key = os.getenv("FlashAPI")
# Validate required environment variables
if not all([gemini_flash_api_key, mongo_uri, index_uri]):
raise ValueError("❌ Missing API keys! Set them in Hugging Face Secrets.")
# logger.info(f"🔎 MongoDB URI: {mongo_uri}")
# logger.info(f"🔎 FAISS Index URI: {index_uri}")
# ✅ Monitor Resources Before Startup
import psutil
def check_system_resources():
memory = psutil.virtual_memory()
cpu = psutil.cpu_percent(interval=1)
disk = psutil.disk_usage("/")
# Log current resource usage and warn on high utilization
logger.info(f"[System] 🔍 System Resources - RAM: {memory.percent}%, CPU: {cpu}%, Disk: {disk.percent}%")
if memory.percent > 85:
logger.warning("⚠️ High RAM usage detected!")
if cpu > 90:
logger.warning("⚠️ High CPU usage detected!")
if disk.percent > 90:
logger.warning("⚠️ High Disk usage detected!")
check_system_resources()
# ✅ Reduce memory and threading overhead
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# ✅ Initialize FastAPI app
app = FastAPI(title="Medical Chatbot API")
memory = MemoryManager()
from fastapi.middleware.cors import CORSMiddleware # Bypassing CORS origin
# Define the origins
origins = [
"http://localhost:5173", # Vite dev server
"http://localhost:3000", # Another vercel local dev
"https://medical-chatbot-henna.vercel.app", # ✅ Vercel frontend production URL
]
# Add the CORS middleware:
app.add_middleware(
CORSMiddleware,
allow_origins=origins, # or ["*"] to allow all (browsers reject a wildcard origin on credentialed requests)
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ✅ Use Lazy Loading for FAISS Index
index = None # Delay FAISS Index loading until first query
# ✅ Load SentenceTransformer Model (half precision to reduce memory)
logger.info("[Embedder] 📥 Loading SentenceTransformer Model...")
MODEL_CACHE_DIR = "/app/model_cache"
try:
embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device="cpu")
embedding_model = embedding_model.half() # Reduce memory
logger.info("✅ Model Loaded Successfully.")
except Exception as e:
logger.error(f"❌ Model Loading Failed: {e}")
exit(1)
# Cache in-memory vectors (optional — useful for <10k rows)
SYMPTOM_VECTORS = None
SYMPTOM_DOCS = None
# ✅ Setup MongoDB Connection
# QA data
client = MongoClient(mongo_uri)
db = client["MedicalChatbotDB"]
qa_collection = db["qa_data"]
# FAISS Index data
iclient = MongoClient(index_uri)
idb = iclient["MedicalChatbotDB"]
index_collection = idb["faiss_index_files"]
# Symptom Diagnosis data
symptom_client = MongoClient(mongo_uri)
symptom_col = symptom_client["MedicalChatbotDB"]["symptom_diagnosis"]
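# Both the QA data and the symptom-diagnosis data use mongo_uri (the same MongoDB deployment); only the FAISS index lives behind index_uri.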
# ✅ Load FAISS Index (Lazy Load)
import gridfs
fs = gridfs.GridFS(idb, collection="faiss_index_files")
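# The serialized FAISS index is stored as a single GridFS file named "faiss_index.bin" and rebuilt in memory via faiss.deserialize_index below.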
def load_faiss_index():
global index
if index is None:
logger.info("[KB] ⏳ Loading FAISS index from GridFS...")
existing_file = fs.find_one({"filename": "faiss_index.bin"})
if existing_file:
stored_index_bytes = existing_file.read()
index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
index = faiss.deserialize_index(index_bytes_np)
logger.info("[KB] ✅ FAISS Index Loaded")
else:
logger.error("[KB] ❌ FAISS index not found in GridFS.")
return index
# ✅ Retrieve Medical Info (256,916-scenario QA knowledge base)
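# NOTE: the scores in D below are treated as cosine similarities (higher is better) when filtered against min_sim.
# This assumes the FAISS index is an inner-product index over normalized embeddings; an L2 index returns distances
# (lower is better), and the filter would not behave as intended.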
def retrieve_medical_info(query, k=5, min_sim=0.9): # Keep only hits with at least 0.9 similarity to the query
global index
index = load_faiss_index()
if index is None:
return [""]
# Embed query
query_vec = embedding_model.encode([query], convert_to_numpy=True)
D, I = index.search(query_vec, k=k)
# Filter by cosine threshold
results = []
kept = []
kept_vecs = []
# Deduplicate semantically similar candidates using a cosine threshold
for score, idx in zip(D[0], I[0]):
if score < min_sim:
continue
# Look up the corresponding QA document by its FAISS row id
doc = qa_collection.find_one({"i": int(idx)})
if not doc:
continue
# Only compare answers
answer = doc.get("Doctor", "").strip()
if not answer:
continue
# Check semantic redundancy among previously kept results
new_vec = embedding_model.encode([answer], convert_to_numpy=True)[0]
is_similar = False
for i, vec in enumerate(kept_vecs):
sim = np.dot(vec, new_vec) / (np.linalg.norm(vec) * np.linalg.norm(new_vec) + 1e-9)
if sim >= 0.9: # High semantic similarity
is_similar = True
# Keep only better match to original query
cur_sim_to_query = np.dot(vec, query_vec[0]) / (np.linalg.norm(vec) * np.linalg.norm(query_vec[0]) + 1e-9)
new_sim_to_query = np.dot(new_vec, query_vec[0]) / (np.linalg.norm(new_vec) * np.linalg.norm(query_vec[0]) + 1e-9)
if new_sim_to_query > cur_sim_to_query:
kept[i] = answer
kept_vecs[i] = new_vec
break
# Candidate is not redundant; keep it
if not is_similar:
kept.append(answer)
kept_vecs.append(new_vec)
# Final
return kept if kept else [""]
# ✅ Retrieve Symptom-Diagnosis Info (4,962-scenario dataset)
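# Assumption: the stored "embedding" vectors are already L2-normalized, so the matrix-vector product below equals cosine similarity.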
def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.5):
global SYMPTOM_VECTORS, SYMPTOM_DOCS
# Lazy load
if SYMPTOM_VECTORS is None:
all_docs = list(symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1, "prognosis": 1}))
SYMPTOM_DOCS = all_docs
SYMPTOM_VECTORS = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)
# Embed input
qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
qvec = qvec / (np.linalg.norm(qvec) + 1e-9)
# Similarity compute
sims = SYMPTOM_VECTORS @ qvec # cosine
sorted_idx = np.argsort(sims)[-top_k:][::-1]
seen_diag = set()
final = [] # Dedup
for i in sorted_idx:
sim = sims[i]
if sim < min_sim:
continue
label = SYMPTOM_DOCS[i]["prognosis"]
if label not in seen_diag:
final.append(SYMPTOM_DOCS[i]["answer"])
seen_diag.add(label)
return final
# ✅ Gemini Flash API Call
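# Note: the temperature parameter is accepted but never forwarded; generate_content is called without a generation
# config, so Gemini's default sampling settings apply.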
def gemini_flash_completion(prompt, model, temperature=0.7):
client_genai = genai.Client(api_key=gemini_flash_api_key)
try:
response = client_genai.models.generate_content(model=model, contents=prompt)
return response.text
except Exception as e:
logger.error(f"[LLM] ❌ Error calling Gemini API: {e}")
return "Error generating response from Gemini."
# ✅ Chatbot Class
class RAGMedicalChatbot:
def __init__(self, model_name, retrieve_function):
self.model_name = model_name
self.retrieve = retrieve_function
def chat(self, user_id: str, user_query: str, lang: str = "EN", image_diagnosis: str = "") -> str:
# 0. Translate the query to English when needed; this helps the RAG retrieval steps below
if lang.upper() in {"VI", "ZH"}:
user_query = translate_query(user_query, lang.lower())
# 1. Fetch knowledge
## a. KB for generic QA retrieval
retrieved_info = self.retrieve(user_query)
knowledge_base = "\n".join(retrieved_info)
## b. Diagnosis RAG from symptom query
diagnosis_guides = retrieve_diagnosis_from_symptoms(user_query) # smart matcher
# 2. Hybrid Context Retrieval: RAG + Recent History + Intelligent Selection
contextual_chunks = memory.get_contextual_chunks(user_id, user_query, lang)
# 3. Build prompt parts
parts = ["You are a medical chatbot, designed to answer medical questions."]
parts.append("Please format your answer using MarkDown.")
parts.append("**Bold for titles**, *italic for emphasis*, and clear headings.")
# 4. Append image diagnosis from VLM
if image_diagnosis:
parts.append(
"A user medical image is diagnosed by our VLM agent:\n"
f"{image_diagnosis}\n\n"
"Please incorporate the above findings in your response if medically relevant.\n\n"
)
# Append contextual chunks from hybrid approach
if contextual_chunks:
parts.append("Relevant context from conversation history:\n" + contextual_chunks)
# Load up guideline (RAG over medical knowledge base)
if knowledge_base:
parts.append(f"Example Q&A medical scenario knowledge-base: {knowledge_base}")
# Symptom-Diagnosis prediction RAG
if diagnosis_guides:
parts.append("Symptom-based diagnosis guidance (if applicable):\n" + "\n".join(diagnosis_guides))
parts.append(f"User's question: {user_query}")
parts.append(f"Language to generate answer: {lang}")
prompt = "\n\n".join(parts)
logger.info(f"[LLM] Question query in `prompt`: {prompt}") # Debug out checking RAG on kb and history
response = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
# Store exchange + chunking
if user_id:
memory.add_exchange(user_id, user_query, response, lang=lang)
logger.info(f"[LLM] Response on `prompt`: {response.strip()}") # Debug out base response
return response.strip()
# ✅ Initialize Chatbot
chatbot = RAGMedicalChatbot(model_name="gemini-2.5-flash", retrieve_function=retrieve_medical_info)
# ✅ Chat Endpoint
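# Example request body (illustrative values):
#   {"user_id": "u123", "query": "I have a sore throat and mild fever", "lang": "EN",
#    "image_base64": null, "img_desc": "Describe and investigate any clinical findings from this medical image."}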
@app.post("/chat")
async def chat_endpoint(req: Request):
body = await req.json()
user_id = body.get("user_id", "anonymous")
query_raw = body.get("query")
query = query_raw.strip() if isinstance(query_raw, str) else ""
lang = body.get("lang", "EN")
image_base64 = body.get("image_base64", None)
img_desc = body.get("img_desc", "Describe and investigate any clinical findings from this medical image.")
start = time.time()
image_diagnosis = ""
# LLM Only
if not image_base64:
logger.info("[BOT] LLM scenario.")
# LLM+VLM
else:
# If image is present → diagnose first
safe_load = len(image_base64.encode("utf-8"))
if safe_load > 5_000_000: # Reject base64 payloads larger than ~5 MB
return JSONResponse({"response": "⚠️ Image too large. Please upload smaller images (<5MB)."})
logger.info("[BOT] VLM+LLM scenario.")
logger.info(f"[VLM] Process medical image size: {safe_load}, desc: {img_desc}, {lang}.")
image_diagnosis = process_medical_image(image_base64, img_desc, lang)
answer = chatbot.chat(user_id, query, lang, image_diagnosis)
elapsed = time.time() - start
# Final
return JSONResponse({"response": f"{answer}\n\n(Response time: {elapsed:.2f}s)"})
# ✅ Run Uvicorn
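# Hugging Face Spaces routes traffic to port 7860 by default.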
if __name__ == "__main__":
logger.info("[System] ✅ Starting FastAPI Server...")
try:
uvicorn.run(app, host="0.0.0.0", port=7860, log_level="debug")
except Exception as e:
logger.error(f"❌ Server Startup Failed: {e}")
exit(1)