# app.py
"""FastAPI entry point for the Medical Chatbot API.

Wires together: a FAISS index over a QA knowledge base (loaded lazily from
MongoDB GridFS), an in-memory symptom -> diagnosis matcher, a
SentenceTransformer embedder, per-user conversation memory, an optional
VLM image-diagnosis step, and Gemini for answer generation.
"""
import os

# ✅ Reduce Memory usage with optimizers.
# These must be set BEFORE numpy / faiss / sentence_transformers are imported:
# the OpenMP runtime used by faiss reads OMP_NUM_THREADS when the library is
# loaded, so setting it after the imports (as this file used to) came too late.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import logging  # noqa: E402
import sys  # noqa: E402
import time  # noqa: E402

import faiss  # noqa: E402
import gridfs  # noqa: E402
import numpy as np  # noqa: E402
import psutil  # noqa: E402
import uvicorn  # noqa: E402
from fastapi import FastAPI, Request  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.responses import JSONResponse  # noqa: E402
from google import genai  # noqa: E402
from pymongo import MongoClient  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
from sentence_transformers.util import cos_sim  # noqa: E402,F401  (kept for importers of this module)

from memory import MemoryManager  # noqa: E402
from translation import translate_query  # noqa: E402
from vlm import process_medical_image  # noqa: E402

# —————— Silence Noisy Loggers ——————
for _noisy_logger in (
    "uvicorn.error",
    "uvicorn.access",
    "fastapi",
    "starlette",
    "pymongo",
    "gridfs",
    "sentence_transformers",
    "faiss",
    "google",
    "google.auth",
):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s — %(name)s — %(levelname)s — %(message)s",
    force=True,
)
logger = logging.getLogger("medical-chatbot")
# DEBUG also emits the full prompt / response dumps in RAGMedicalChatbot.chat;
# set this to INFO to hide them.
logger.setLevel(logging.DEBUG)

logger.info("🚀 Starting Medical Chatbot API...")

# ✅ Environment Variables
mongo_uri = os.getenv("MONGO_URI")
index_uri = os.getenv("INDEX_URI")
gemini_flash_api_key = os.getenv("FlashAPI")

# Fail fast at import time, naming exactly which variables are absent.
_missing_env = [
    var_name
    for var_name, value in (
        ("FlashAPI", gemini_flash_api_key),
        ("MONGO_URI", mongo_uri),
        ("INDEX_URI", index_uri),
    )
    if not value
]
if _missing_env:
    raise ValueError(
        "❌ Missing API keys! Set them in Hugging Face Secrets: "
        + ", ".join(_missing_env)
    )


def check_system_resources():
    """Log current RAM / CPU / disk usage and warn when any is near capacity.

    Blocks for about one second while psutil samples CPU usage.
    """
    mem = psutil.virtual_memory()
    cpu = psutil.cpu_percent(interval=1)
    disk = psutil.disk_usage("/")
    logger.info(f"[System] 🔍 System Resources - RAM: {mem.percent}%, CPU: {cpu}%, Disk: {disk.percent}%")
    if mem.percent > 85:
        logger.warning("⚠️ High RAM usage detected!")
    if cpu > 90:
        logger.warning("⚠️ High CPU usage detected!")
    if disk.percent > 90:
        logger.warning("⚠️ High Disk usage detected!")


check_system_resources()

# ✅ Initialize FastAPI app
app = FastAPI(title="Medical Chatbot API")
memory = MemoryManager()

# Origins allowed to call this API from a browser.
origins = [
    "http://localhost:5173",  # Vite dev server
    "http://localhost:3000",  # Another vercel local dev
    "https://medical-chatbot-henna.vercel.app",  # ✅ Vercel frontend production URL
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # or ["*"] to allow all
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ✅ Lazy loading for the FAISS index: populated on first query by load_faiss_index().
index = None

# ✅ Load SentenceTransformer Model (halved to fp16 to save memory)
logger.info("[Embedder] 📥 Loading SentenceTransformer Model...")
MODEL_CACHE_DIR = "/app/model_cache"
try:
    embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device="cpu")
    embedding_model = embedding_model.half()  # Reduce memory
    logger.info("✅ Model Loaded Successfully.")
except Exception as e:
    logger.error(f"❌ Model Loading Failed: {e}")
    sys.exit(1)

# In-memory cache of the symptom/diagnosis collection (fine for <10k rows);
# filled on first call to retrieve_diagnosis_from_symptoms().
SYMPTOM_VECTORS = None
SYMPTOM_DOCS = None

# ✅ Setup MongoDB Connection
# QA data
client = MongoClient(mongo_uri)
db = client["MedicalChatbotDB"]
qa_collection = db["qa_data"]

# FAISS Index data (may live on a different cluster)
iclient = MongoClient(index_uri)
idb = iclient["MedicalChatbotDB"]
index_collection = idb["faiss_index_files"]

# Symptom Diagnosis data: same URI as the QA data, so reuse that client's
# connection pool instead of opening a second client to the same cluster.
symptom_client = client
symptom_col = symptom_client["MedicalChatbotDB"]["symptom_diagnosis"]

fs = gridfs.GridFS(idb, collection="faiss_index_files")

# Fallback text returned by gemini_flash_completion when generation fails.
GEMINI_ERROR_MESSAGE = "Error generating response from Gemini."


def _embed(texts):
    """Encode *texts* with the shared embedder and return a float32 array.

    The model was converted with ``.half()``, so ``encode`` may return
    float16; FAISS search expects float32, and fp16 dot products lose
    precision, hence the explicit cast.
    """
    return np.asarray(embedding_model.encode(texts, convert_to_numpy=True), dtype=np.float32)


def _l2_normalize(vectors):
    """Scale vectors to unit length along the last axis (epsilon-guarded)."""
    norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
    return vectors / (norms + 1e-9)


def load_faiss_index():
    """Return the FAISS index, deserializing it from GridFS on first use.

    Returns None (and logs an error) if ``faiss_index.bin`` is not stored;
    in that case the next call retries the lookup.
    """
    global index
    if index is None:
        logger.info("[KB] ⏳ Loading FAISS index from GridFS...")
        existing_file = fs.find_one({"filename": "faiss_index.bin"})
        if existing_file is None:
            logger.error("[KB] ❌ FAISS index not found in GridFS.")
        else:
            index_bytes_np = np.frombuffer(existing_file.read(), dtype="uint8")
            index = faiss.deserialize_index(index_bytes_np)
            logger.info("[KB] ✅ FAISS Index Loaded")
    return index


# ✅ Retrieve Medical Info (256,916 scenario)
def retrieve_medical_info(query, k=5, min_sim=0.9):
    """Return up to *k* de-duplicated doctor answers relevant to *query*.

    Args:
        query: The user's question.
        k: Number of nearest neighbours requested from FAISS.
        min_sim: Minimum FAISS score (default 0.9) a hit must reach. Scores
            are treated as "higher is better", which assumes an
            inner-product index — NOTE(review): confirm the stored index's
            metric.

    Returns:
        A list of answer strings, or ``[""]`` when the index is missing or
        nothing passes the threshold. The caller joins the list, and an
        empty join counts as "no knowledge".
    """
    faiss_index = load_faiss_index()
    if faiss_index is None or k <= 0:
        return [""]

    query_vec = _embed([query])
    scores, ids = faiss_index.search(query_vec, k)
    query_unit = _l2_normalize(query_vec[0])

    # Collect the answers of every hit above the threshold, in rank order.
    candidates = []
    for score, idx in zip(scores[0], ids[0]):
        # FAISS pads the result with id -1 when it has fewer than k neighbours.
        if idx < 0 or score < min_sim:
            continue
        doc = qa_collection.find_one({"i": int(idx)})
        if not doc:
            continue
        answer = doc.get("Doctor")  # Only answers are compared
        if not isinstance(answer, str) or not answer.strip():
            continue
        candidates.append(answer.strip())

    if not candidates:
        return [""]

    # Smart dedup: two answers whose cosine similarity is >= 0.9 are treated
    # as redundant, and only the one closer to the query is kept. All
    # candidates are embedded in a single batch.
    candidate_units = _l2_normalize(_embed(candidates))
    kept, kept_units = [], []
    for answer, unit in zip(candidates, candidate_units):
        for slot, existing in enumerate(kept_units):
            if float(existing @ unit) >= 0.9:  # High semantic similarity
                if float(unit @ query_unit) > float(existing @ query_unit):
                    kept[slot] = answer
                    kept_units[slot] = unit
                break
        else:  # Not similar to anything kept so far
            kept.append(answer)
            kept_units.append(unit)
    return kept


# ✅ Retrieve Sym-Dia Info (4,962 scenario)
def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.5):
    """Return guidance answers for the diagnoses most similar to *symptom_text*.

    On first call the whole ``symptom_diagnosis`` collection is cached in
    memory, with its stored embeddings L2-normalized so that the dot product
    below is a true cosine similarity. Among the *top_k* best matches, those
    scoring at least *min_sim* are kept, with at most one answer per distinct
    ``prognosis``.
    """
    global SYMPTOM_VECTORS, SYMPTOM_DOCS

    # Lazy load
    if SYMPTOM_VECTORS is None:
        projection = {"embedding": 1, "answer": 1, "question": 1, "prognosis": 1}
        docs = [d for d in symptom_col.find({}, projection) if d.get("embedding")]
        vectors = np.array([d["embedding"] for d in docs], dtype=np.float32)
        if len(docs):
            vectors = _l2_normalize(vectors)
        SYMPTOM_DOCS = docs
        SYMPTOM_VECTORS = vectors  # assigned last: it is the "loaded" flag

    if top_k <= 0 or not SYMPTOM_DOCS:
        return []

    qvec = _l2_normalize(_embed(symptom_text))
    sims = SYMPTOM_VECTORS @ qvec  # cosine: both sides are unit length
    # Best top_k indices, highest similarity first.
    ranked = np.argsort(sims)[-top_k:][::-1]

    seen_diag = set()
    final = []
    for i in ranked:
        if sims[i] < min_sim:
            break  # ranked is descending, so nothing later can pass
        doc = SYMPTOM_DOCS[i]
        answer = doc.get("answer")
        label = doc.get("prognosis")
        if not answer or label in seen_diag:
            continue
        seen_diag.add(label)
        final.append(answer)
    return final


_genai_client = None


def _get_genai_client():
    """Return a process-wide Gemini client, creating it on first use."""
    global _genai_client
    if _genai_client is None:
        _genai_client = genai.Client(api_key=gemini_flash_api_key)
    return _genai_client


# ✅ Gemini Flash API Call
def gemini_flash_completion(prompt, model, temperature=0.7):
    """Generate text for *prompt* with Gemini *model* at *temperature*.

    Never raises: on any API error, or when Gemini returns no text, this logs
    the problem and returns ``GEMINI_ERROR_MESSAGE`` instead.
    """
    try:
        response = _get_genai_client().models.generate_content(
            model=model,
            contents=prompt,
            config={"temperature": temperature},
        )
        text = response.text
    except Exception as e:
        logger.error(f"[LLM] ❌ Error calling Gemini API: {e}")
        return GEMINI_ERROR_MESSAGE
    if not text:
        # response.text is None/empty e.g. when the candidate was blocked.
        logger.warning("[LLM] ⚠️ Gemini returned an empty response.")
        return GEMINI_ERROR_MESSAGE
    return text


# ✅ Chatbot Class
class RAGMedicalChatbot:
    """Retrieval-augmented chatbot: gathers context, builds a prompt, asks Gemini."""

    def __init__(self, model_name, retrieve_function):
        self.model_name = model_name  # Gemini model id
        self.retrieve = retrieve_function  # query -> list[str] of KB snippets

    def chat(self, user_id: str, user_query: str, lang: str = "EN", image_diagnosis: str = "") -> str:
        """Answer *user_query* for *user_id* and return the stripped reply.

        Args:
            user_id: Key for conversation memory. When falsy, history is
                neither read nor written.
            user_query: The user's question.
            lang: Answer language code; "VI"/"ZH" queries are translated
                first so retrieval runs on the translated text.
            image_diagnosis: Optional VLM findings to include in the prompt.
        """
        lang = lang or "EN"

        # 0. Translate query if not EN, this helps our RAG system
        if lang.upper() in {"VI", "ZH"}:
            user_query = translate_query(user_query, lang.lower())

        # 1. Fetch knowledge
        ## a. KB for generic QA retrieval
        knowledge_base = "\n".join(self.retrieve(user_query))
        ## b. Diagnosis RAG from symptom query
        diagnosis_guides = retrieve_diagnosis_from_symptoms(user_query)  # smart matcher

        # 2. Hybrid Context Retrieval: RAG + Recent History + Intelligent Selection.
        # Nothing is ever stored without a user id, so there is nothing to read.
        contextual_chunks = memory.get_contextual_chunks(user_id, user_query, lang) if user_id else ""

        # 3. Build prompt parts
        parts = ["You are a medical chatbot, designed to answer medical questions."]
        parts.append("Please format your answer using MarkDown.")
        parts.append("**Bold for titles**, *italic for emphasis*, and clear headings.")

        # 4. Append image diagnosis from VLM
        if image_diagnosis:
            parts.append(
                "A user medical image is diagnosed by our VLM agent:\n"
                f"{image_diagnosis}\n\n"
                "Please incorporate the above findings in your response if medically relevant.\n\n"
            )
        # Append contextual chunks from hybrid approach
        if contextual_chunks:
            parts.append("Relevant context from conversation history:\n" + contextual_chunks)
        # Load up guideline (RAG over medical knowledge base)
        if knowledge_base:
            parts.append(f"Example Q&A medical scenario knowledge-base: {knowledge_base}")
        # Symptom-Diagnosis prediction RAG
        if diagnosis_guides:
            parts.append("Symptom-based diagnosis guidance (if applicable):\n" + "\n".join(diagnosis_guides))
        parts.append(f"User's question: {user_query}")
        parts.append(f"Language to generate answer: {lang}")
        prompt = "\n\n".join(parts)
        # Full prompt contains user medical data: debug level only.
        logger.debug(f"[LLM] Question query in `prompt`: {prompt}")

        answer = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7).strip()

        # Store exchange + chunking, but never persist the error fallback as
        # if it were a real answer.
        if user_id and answer != GEMINI_ERROR_MESSAGE:
            memory.add_exchange(user_id, user_query, answer, lang=lang)

        logger.debug(f"[LLM] Response on `prompt`: {answer}")
        return answer


# ✅ Initialize Chatbot
chatbot = RAGMedicalChatbot(model_name="gemini-2.5-flash", retrieve_function=retrieve_medical_info)

# Upper bound on the base64 image payload, in bytes.
MAX_IMAGE_PAYLOAD_BYTES = 5_000_000
DEFAULT_IMG_DESC = "Describe and investigate any clinical findings from this medical image."


# ✅ Chat Endpoint
@app.post("/chat")
async def chat_endpoint(req: Request):
    """Handle a chat turn, optionally with a base64 medical image.

    Expected JSON keys: ``query``, ``user_id``, ``lang``, ``image_base64``,
    ``img_desc``. All are optional. The reply is ``{"response": ...}`` with
    the elapsed time appended.
    """
    # NOTE(review): this coroutine calls blocking work (Mongo, FAISS, model
    # inference, Gemini, VLM) directly, so it blocks the event loop and
    # requests are served one at a time. Offloading to a threadpool
    # (fastapi.concurrency.run_in_threadpool) would fix that, but first
    # confirm MemoryManager and the lazy-load globals are thread-safe.
    try:
        body = await req.json()
    except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
        return JSONResponse({"response": "⚠️ Invalid request body: expected JSON."}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"response": "⚠️ Invalid request body: expected a JSON object."}, status_code=400)

    # No shared fallback id: a shared "anonymous" bucket would mix different
    # users' medical history. A falsy id disables memory in chat().
    user_id = body.get("user_id") or ""
    query_raw = body.get("query")
    query = query_raw.strip() if isinstance(query_raw, str) else ""
    lang_raw = body.get("lang")
    lang = lang_raw if isinstance(lang_raw, str) and lang_raw.strip() else "EN"
    image_base64 = body.get("image_base64") or None
    img_desc = body.get("img_desc")
    if not isinstance(img_desc, str) or not img_desc.strip():
        img_desc = DEFAULT_IMG_DESC

    start = time.perf_counter()
    image_diagnosis = ""

    if image_base64 is None:
        # LLM Only
        logger.info("[BOT] LLM scenario.")
    else:
        # LLM+VLM: if an image is present, diagnose it first.
        if not isinstance(image_base64, str):
            return JSONResponse({"response": "⚠️ Invalid image: expected a base64 string."}, status_code=400)
        safe_load = len(image_base64.encode("utf-8"))
        if safe_load > MAX_IMAGE_PAYLOAD_BYTES:  # Img size safe processor
            return JSONResponse({"response": "⚠️ Image too large. Please upload smaller images (<5MB)."})
        logger.info("[BOT] VLM+LLM scenario.")
        logger.info(f"[VLM] Process medical image size: {safe_load}, desc: {img_desc}, {lang}.")
        image_diagnosis = process_medical_image(image_base64, img_desc, lang) or ""

    answer = chatbot.chat(user_id, query, lang, image_diagnosis)
    elapsed = time.perf_counter() - start
    return JSONResponse({"response": f"{answer}\n\n(Response time: {elapsed:.2f}s)"})


# ✅ Run Uvicorn
if __name__ == "__main__":
    logger.info("[System] ✅ Starting FastAPI Server...")
    try:
        uvicorn.run(app, host="0.0.0.0", port=7860, log_level="debug")
    except Exception as e:
        logger.error(f"❌ Server Startup Failed: {e}")
        sys.exit(1)