import os
import joblib
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import google.generativeai as genai
from dotenv import load_dotenv
import firebase_admin
from firebase_admin import credentials, firestore
import datetime
import json
from huggingface_hub import hf_hub_download

# --- App and Environment Initialization ---
app = Flask(__name__)


@app.route("/", methods=["GET"])
def health_check():
    return jsonify({
        "status": "ok",
        "message": "Health AI Backend is running 🚀"
    }), 200


# --- CORS CONFIGURATION ---
CORS(app, resources={r"/*": {"origins": [
    "https://health-app-lilac.vercel.app",
    "http://localhost:3000"
]}})

load_dotenv()

# --- Firebase Initialization ---
try:
    service_account_str = os.getenv("FIREBASE_SERVICE_ACCOUNT")
    if service_account_str:
        service_account_info = json.loads(service_account_str)
        cred = credentials.Certificate(service_account_info)
    else:
        cred = credentials.Certificate("firebase_key.json")
    if not firebase_admin._apps:
        firebase_admin.initialize_app(cred)
    db = firestore.client()
    print("✅ Firebase initialized successfully.")
except Exception as e:
    print(f"❌ Firebase initialization failed: {e}")
    db = None

# --- Gemini API Configuration ---
try:
    genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
    print("✨ Gemini API configured.")
except Exception as e:
    print(f"❌ Gemini configuration failed: {e}")

# --- Global variables to hold the loaded models ---
local_model, tokenizer, label_encoder = None, None, None
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_FOLDER_PATH = "./models/finetuned_model"
LABEL_ENCODER_PATH = "./models/label_encoder.joblib"


def load_models():
    """Loads the model, tokenizer, and label encoder."""
    global local_model, tokenizer, label_encoder
    if local_model is not None:
        return None

    print("📂 First request received. Starting model loading process...")
    try:
        if not os.path.exists(MODEL_FOLDER_PATH) or not os.path.exists(LABEL_ENCODER_PATH):
            error_msg = "❌ Model files not found locally."
            print(error_msg)
            return error_msg

        local_model = AutoModelForSequenceClassification.from_pretrained(MODEL_FOLDER_PATH)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_FOLDER_PATH)
        label_encoder = joblib.load(LABEL_ENCODER_PATH)
        local_model.to(device)
        local_model.eval()
        print(f"✅ Models loaded successfully and moved to {device}.")
        return None
    except Exception as e:
        error_msg = f"❌ CRITICAL ERROR during model loading: {e}"
        print(error_msg)
        return error_msg


# --- Prompts and AI Logic ---
def get_doctor_persona_prompt(stage, user_details=None, local_predictions=None):
    """Generates a system prompt tailored to the specific stage of the conversation."""
    # Use real-time data instead of hardcoded values
    try:
        from zoneinfo import ZoneInfo
        tz = ZoneInfo("Asia/Kolkata")
        now = datetime.datetime.now(tz)
    except Exception:
        now = datetime.datetime.now()  # Fallback to the server's local time

    time_str = now.strftime("%A, %B %d, %Y at %I:%M %p %Z")
    location = user_details.get('location', 'the user\'s area') if user_details else 'the user\'s area'

    # --- STAGE 1: Initial Greeting ---
    if stage == 'awaiting_name':
        return """
**SYSTEM INSTRUCTION:** You are Dr. Aether, a friendly and professional AI medical assistant. Your ONLY task is to greet the user warmly and ask for their name to begin the consultation.
Example: "Hello, I'm Dr. Aether. To get started, could you please tell me your full name?"
"""

    # --- STAGE 2 & 3: Unified Triage, Inquiry, and Conversation ---
    if stage in ['awaiting_age', 'awaiting_sex', 'awaiting_symptoms', 'chatting']:
        return f"""
**SYSTEM INSTRUCTION: ACT AS DR. AETHER - MEDICAL TRIAGE & INQUIRY**

**Persona:** You are Dr. Aether, an experienced and empathetic AI physician.
**User Context:** The user is in {location}. The current time is {time_str}.
**Your Task:** Conduct a medical consultation. First, gather essential patient information, then ask clarifying medical questions.

**Information Checklist (Review the chat history to see what is missing):**
- [ ] Patient's Name
- [ ] Patient's Age
- [ ] Patient's Biological Sex
- [ ] Primary Symptoms

**CRITICAL Directives & Flow:**
1. **Check the List:** Review the entire chat history to determine the FIRST piece of information from the checklist that is still missing.
2. **Ask for ONE Missing Item:** Your response must ask for ONLY that single missing item.
   - If Age is missing, ask for it using `[CHIPS: ["Under 18", "18-64", "65+"]]`.
   - If Sex is missing, ask for it using `[CHIPS: ["Male", "Female"]]`.
   - If Symptoms are missing, ask for them with an open-ended question.
3. **Transition to Medical Inquiry:** ONLY when all four items on the checklist have been gathered, you must transition to asking focused medical questions to clarify the symptoms (e.g., "How long have you had these symptoms?", "Is the cough dry or productive?").
4. **One Question at a Time:** ALWAYS ask only ONE question per turn.
5. **Conclude Inquiry:** Once you have gathered sufficient details about the symptoms (e.g., duration, type, severity, associated symptoms), your very next response MUST BE ONLY the token `[SUMMARY_READY]`. Do not ask "is there anything else?". Just output the token.
"""

    # --- STAGE 4: Final Summary Generation ---
    if stage == 'process_symptoms':
        predictions_text = "No preliminary analysis was available."
        if local_predictions:
            formatted = [f"- {p.get('disease', 'N/A')} ({p.get('confidence', 0):.0%})" for p in local_predictions]
            predictions_text = "My preliminary analysis suggests the following possibilities:\n" + "\n".join(formatted)

        return f"""
**SYSTEM INSTRUCTION: ACT AS DR. AETHER - FINAL SUMMARY GENERATOR**

**CONTEXT:**
- **User Location:** {location}
- **Local Model Analysis:** {predictions_text}
- **Conversation History:** (The full history will be provided after this prompt)

**YOUR ONLY TASK:** Review the conversation history and the local model analysis, then generate a final summary in the required JSON format. Your entire output must be ONLY the JSON block, starting with `[SUMMARY:` and ending with `}}]`.

**CRITICAL JSON RULES:**
1. All text inside JSON string values (e.g., in "recap") MUST be a single line. DO NOT use unescaped newline characters.
2. Fill every field with medically sound, empathetic, and relevant information based on the conversation.

**EXAMPLE OUTPUT STRUCTURE:**
[SUMMARY: {{"recap": "Satvik, a user in the 18-64 age range, reports a 3-day history of high fever, mild dry cough, and nightly chills, with an associated sore throat.", "possibilities": "Based on the symptoms and preliminary analysis, this could suggest a viral upper respiratory infection like the common cold or influenza. The local model noted a possibility of pneumonia (48%), which should be considered, especially given the fever and cough.", "homeCare": ["Rest as much as possible to help your body fight the infection.", "Stay hydrated by drinking plenty of water, broth, or tea.", "Use over-the-counter medications like acetaminophen or ibuprofen for fever and aches, following package directions."], "recommendation": "It is recommended you consult a doctor in {location} within the next 24-48 hours for a proper diagnosis, especially to rule out a bacterial infection like pneumonia.", "conclusion": "I hope this information has been helpful. Please follow up with a healthcare professional as recommended. Is there anything else I can assist you with?"}}]

**Now, generate the JSON for the current user's conversation history.**
"""

    return "You are a helpful assistant."  # Fallback


def run_local_prediction(symptoms_text):
    """Runs the local model to get disease predictions."""
    if not all([local_model, tokenizer, label_encoder]):
        print("❌ Cannot run local prediction: model components not loaded.")
        return None
    try:
        inputs = tokenizer(symptoms_text, return_tensors="pt", truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = local_model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        top_probs, top_indices = torch.topk(probabilities, 3)
        predictions = [
            {"disease": label_encoder.classes_[idx], "confidence": float(prob)}
            for idx, prob in zip(top_indices.cpu().numpy()[0], top_probs.cpu().numpy()[0])
        ]
        print(f"✅ Local model predictions: {predictions}")
        return predictions
    except Exception as e:
        print(f"❌ Local prediction error: {e}")
        return None


def _get_symptoms_from_history(history, model):
    """Uses Gemini to extract a clean symptom string from the chat history."""
    print("👉 Asking Gemini to extract symptom string for local model...")
    symptom_extraction_prompt = (
        "Review the user messages in this chat history. Extract all medical symptoms, durations, and key descriptors. "
        "Consolidate them into a clean, comma-separated string for a machine learning model. "
        "Example: 'fever, dry cough, sore throat, duration 1-3 days, chills at night, 102F temperature'. "
        "Output ONLY the symptom string."
    )
    # This is a more robust way to pass history for this task
    symptom_history = [{'role': 'user', 'parts': [symptom_extraction_prompt]}] + [msg for msg in history if msg['role'] == 'user']
    try:
        symptom_response = model.generate_content(symptom_history)
        symptoms_summary = symptom_response.text.strip().replace('\n', ', ')
        print(f" ✅ Gemini-summarized symptoms: '{symptoms_summary}'")
        return symptoms_summary
    except Exception as e:
        print(f" ❌ Failed to get symptoms from history: {e}")
        return ""


@app.route('/chat', methods=['POST'])
def chat():
    load_error = load_models()
    if load_error:
        return jsonify({"error": f"Model loading failed: {load_error}"}), 503

    data = request.get_json() or {}
    history = data.get('history', [])
    location = data.get('location', 'your area')  # Use location from frontend

    # Determine stage based on history length; more robust than trusting the client
    stage = 'awaiting_name'
    if len(history) > 1:
        stage = data.get('stage', 'chatting')

    if not history:
        return jsonify({"error": "Chat history not provided."}), 400

    user_details = {"location": location}

    try:
        model = genai.GenerativeModel('gemini-1.5-flash')

        # --- Primary Conversational Call ---
        system_prompt = get_doctor_persona_prompt(stage, user_details)
        conversation_history = [
            {'role': 'user', 'parts': [system_prompt]},
            {'role': 'model', 'parts': ["Understood. I will act as Dr. Aether and follow all instructions for my current stage."]}
        ] + history
        response = model.generate_content(conversation_history)
        response_text = response.text.strip()

        # --- Check for Summary Trigger ---
        if response_text == "[SUMMARY_READY]":
            print("✅ AI signaled summary is ready. Starting final analysis.")

            # 1. Summarize symptoms for the local model
            symptoms_summary = _get_symptoms_from_history(history, model)

            # 2. Run the local model
            predictions_result = run_local_prediction(symptoms_summary) if symptoms_summary else []

            # 3. Generate the final summary response
            final_summary_prompt = get_doctor_persona_prompt('process_symptoms', user_details, predictions_result)
            final_summary_history = [
                {'role': 'user', 'parts': [final_summary_prompt]},
                {'role': 'model', 'parts': ["Understood. I will now generate the final summary in the required JSON format based on the provided history."]}
            ] + history
            final_response = model.generate_content(final_summary_history)

            # ADDED DEBUGGING
            print("--- DEBUG: Raw Final Summary Response from Gemini ---")
            print(final_response.text)
            print("----------------------------------------------------")

            response_data = {"reply": final_response.text}
            if predictions_result:
                response_data["predictions"] = predictions_result
            return jsonify(response_data)
        else:
            # It's a normal conversational turn
            return jsonify({"reply": response_text})

    except Exception as e:
        print(f"❌ API error: {e}")
        return jsonify({"error": f"An error occurred with the AI service: {e}"}), 500


@app.route('/get_chats', methods=['POST'])
def get_chats():
    if not db:
        return jsonify({"error": "Firestore not initialized."}), 500

    data = request.get_json() or {}
    user_id = data.get('user_id')
    if not user_id:
        return jsonify({"error": "User ID not provided."}), 400

    try:
        chats_ref = db.collection('users').document(user_id).collection('chats')
        query = chats_ref.order_by("timestamp", direction=firestore.Query.DESCENDING)
        chats = [doc.to_dict() for doc in query.stream()]
        return jsonify(chats)
    except Exception as e:
        print(f"❌ Firestore get_chats error: {e}")
        return jsonify({"error": f"Failed to retrieve chats: {e}"}), 500


@app.route('/save_chat', methods=['POST'])
def save_chat():
    if not db:
        return jsonify({"error": "Firestore is not initialized."}), 500

    data = request.get_json() or {}
    user_id = data.get('userId')
    chat_data = data.get('chatData')
    if not user_id or not chat_data:
        return jsonify({"error": "User ID or chat data is missing."}), 400

    try:
        chat_data['timestamp'] = datetime.datetime.fromisoformat(
            chat_data['timestamp'].replace('Z', '+00:00')
        )
        chat_ref = db.collection('users').document(user_id).collection('chats').document(chat_data['id'])
        chat_ref.set(chat_data, merge=True)
        return jsonify({"success": True, "chatId": chat_data['id']})
    except Exception as e:
        print(f"❌ Firestore save_chat error: {e}")
        return jsonify({"error": f"Failed to save chat: {e}"}), 500


# --- Server Startup Block ---
if __name__ == "__main__":
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)
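# --- Example Request (reference only) ---
# A minimal sketch of how a client might call the /chat endpoint, assuming the
# default port 7860 and the Gemini-style {'role': ..., 'parts': [...]} message
# format this backend expects. The name and location values below are
# placeholders; with a single-message history the server forces the
# 'awaiting_name' stage, so "stage" can be omitted on the first turn.
#
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"history": [{"role": "user", "parts": ["Hi, I am Satvik"]}], "location": "your city"}'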