import os
import joblib
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import google.generativeai as genai
from dotenv import load_dotenv
import firebase_admin
from firebase_admin import credentials, firestore
import datetime
import json
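# The imports above imply these pip packages (an illustrative, unpinned list;
# adjust to your environment): flask, flask-cors, torch, transformers, joblib,
# google-generativeai, python-dotenv, firebase-admin.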
# --- App and Environment Initialization ---
app = Flask(__name__)

@app.route("/", methods=["GET"])
def health_check():
    return jsonify({
        "status": "ok",
        "message": "Health AI Backend is running 🚀"
    }), 200
# --- CORS CONFIGURATION ---
CORS(app, resources={r"/*": {"origins": [
    "https://health-app-lilac.vercel.app",
    "http://localhost:3000"
]}})
load_dotenv()
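# A minimal .env sketch for local development (placeholder values; these are the
# only keys this file reads):
#   GOOGLE_API_KEY=<gemini-api-key>
#   FIREBASE_SERVICE_ACCOUNT=<service-account JSON on one line>  (optional; falls back to firebase_key.json)
#   PORT=7860  (optional; defaults to 7860)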
# --- Firebase Initialization ---
try:
    service_account_str = os.getenv("FIREBASE_SERVICE_ACCOUNT")
    if service_account_str:
        # Credentials supplied as a JSON string in the environment (e.g. on a PaaS).
        service_account_info = json.loads(service_account_str)
        cred = credentials.Certificate(service_account_info)
    else:
        # Fall back to a local key file for development.
        cred = credentials.Certificate("firebase_key.json")
    if not firebase_admin._apps:
        firebase_admin.initialize_app(cred)
    db = firestore.client()
    print("✅ Firebase initialized successfully.")
except Exception as e:
    print(f"❌ Firebase initialization failed: {e}")
    db = None
# --- Gemini API Configuration ---
try:
    genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
    print("✨ Gemini API configured.")
except Exception as e:
    print(f"❌ Gemini configuration failed: {e}")
# --- Global variables to hold the loaded models ---
local_model, tokenizer, label_encoder = None, None, None
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_FOLDER_PATH = "./models/finetuned_model"
LABEL_ENCODER_PATH = "./models/label_encoder.joblib"
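# Assumed on-disk layout (typical of HF `save_pretrained` output plus a joblib dump;
# adjust if your training pipeline saved things differently):
#   models/
#     finetuned_model/       <- config.json, weights, tokenizer files
#     label_encoder.joblib   <- fitted sklearn LabelEncoder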
def load_models():
"""Loads the model, tokenizer, and label encoder."""
global local_model, tokenizer, label_encoder
if local_model is not None:
return None
print("π First request received. Starting model loading process...")
try:
if not os.path.exists(MODEL_FOLDER_PATH) or not os.path.exists(LABEL_ENCODER_PATH):
error_msg = "β Model files not found locally."
print(error_msg)
return error_msg
local_model = AutoModelForSequenceClassification.from_pretrained(MODEL_FOLDER_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_FOLDER_PATH)
label_encoder = joblib.load(LABEL_ENCODER_PATH)
local_model.to(device)
local_model.eval()
print(f"β
Models loaded successfully and moved to {device}.")
return None
except Exception as e:
error_msg = f"β CRITICAL ERROR during model loading: {e}"
print(error_msg)
return error_msg
# --- Prompts and AI Logic ---
def get_doctor_persona_prompt(stage, user_details=None, local_predictions=None):
    """
    Generates a system prompt tailored to the specific stage of the conversation.
    """
    # Use real-time data instead of hardcoded values.
    try:
        from zoneinfo import ZoneInfo
        tz = ZoneInfo("Asia/Kolkata")
        now = datetime.datetime.now(tz)
    except Exception:
        now = datetime.datetime.now()  # Fall back to the server's local time.
    time_str = now.strftime("%A, %B %d, %Y at %I:%M %p %Z")
    location = user_details.get('location', "the user's area") if user_details else "the user's area"

    # --- STAGE 1: Initial Greeting ---
    if stage == 'awaiting_name':
        return """
**SYSTEM INSTRUCTION:** You are Dr. Aether, a friendly and professional AI medical assistant.
Your ONLY task is to greet the user warmly and ask for their first name to begin the consultation.
Example: "Hello, I'm Dr. Aether. To get started, could you please tell me your full name?"
"""

    # --- STAGES 2 & 3: Unified Triage, Inquiry, and Conversation ---
    if stage in ['awaiting_age', 'awaiting_sex', 'awaiting_symptoms', 'chatting']:
        return f"""
**SYSTEM INSTRUCTION: ACT AS DR. AETHER - MEDICAL TRIAGE & INQUIRY**
**Persona:** You are Dr. Aether, an experienced and empathetic AI physician.
**User Context:** The user is in {location}. The current time is {time_str}.
**Your Task:** Conduct a medical consultation. First, gather essential patient information, then ask clarifying medical questions.
**Information Checklist (Review the chat history to see what is missing):**
- [ ] Patient's Name
- [ ] Patient's Age
- [ ] Patient's Biological Sex
- [ ] Primary Symptoms
**CRITICAL Directives & Flow:**
1. **Check the List:** Review the entire chat history to determine the FIRST piece of information from the checklist that is still missing.
2. **Ask for ONE Missing Item:** Your response must ask for ONLY that single missing item.
   - If Age is missing, ask for it using `[CHIPS: ["Under 18", "18-64", "65+"]]`.
   - If Sex is missing, ask for it using `[CHIPS: ["Male", "Female"]]`.
   - If Symptoms are missing, ask for them with an open-ended question.
3. **Transition to Medical Inquiry:** ONLY when all four items on the checklist have been gathered, you must transition to asking focused medical questions to clarify the symptoms (e.g., "How long have you had these symptoms?", "Is the cough dry or productive?").
4. **One Question at a Time:** ALWAYS ask only ONE question per turn.
5. **Conclude Inquiry:** Once you have gathered sufficient details about the symptoms (e.g., duration, type, severity, associated symptoms), your very next response MUST BE ONLY the token `[SUMMARY_READY]`. Do not ask "is there anything else?". Just output the token.
"""

    # --- STAGE 4: Final Summary Generation ---
    if stage == 'process_symptoms':
        predictions_text = "No preliminary analysis was available."
        if local_predictions:
            formatted = [f"- {p.get('disease', 'N/A')} ({p.get('confidence', 0):.0%})" for p in local_predictions]
            predictions_text = "My preliminary analysis suggests the following possibilities:\n" + "\n".join(formatted)
        return f"""
**SYSTEM INSTRUCTION: ACT AS DR. AETHER - FINAL SUMMARY GENERATOR**
**CONTEXT:**
- **User Location:** {location}
- **Local Model Analysis:** {predictions_text}
- **Conversation History:** (The full history will be provided after this prompt)
**YOUR ONLY TASK:**
Review the conversation history and the local model analysis, then generate a final summary in the required JSON format. Your entire output must be ONLY the JSON block, starting with `[SUMMARY:` and ending with `}}]`.
**CRITICAL JSON RULES:**
1. All text inside JSON string values (e.g., in "recap") MUST be a single line. DO NOT use unescaped newline characters.
2. Fill every field with medically sound, empathetic, and relevant information based on the conversation.
**EXAMPLE OUTPUT STRUCTURE:**
[SUMMARY: {{"recap": "Satvik, a user in the 18-64 age range, reports a 3-day history of high fever, mild dry cough, and nightly chills, with an associated sore throat.", "possibilities": "Based on the symptoms and preliminary analysis, this could suggest a viral upper respiratory infection like the common cold or influenza. The local model noted a possibility of pneumonia (48%), which should be considered, especially given the fever and cough.", "homeCare": ["Rest as much as possible to help your body fight the infection.", "Stay hydrated by drinking plenty of water, broth, or tea.", "Use over-the-counter medications like acetaminophen or ibuprofen for fever and aches, following package directions."], "recommendation": "It is recommended you consult a doctor in {location} within the next 24-48 hours for a proper diagnosis, especially to rule out a bacterial infection like pneumonia.", "conclusion": "I hope this information has been helpful. Please follow up with a healthcare professional as recommended. Is there anything else I can assist you with?"}}]
**Now, generate the JSON for the current user's conversation history.**
"""

    return "You are a helpful assistant."  # Fallback
def run_local_prediction(symptoms_text):
"""Runs the local model to get disease predictions."""
if not all([local_model, tokenizer, label_encoder]):
print("β Cannot run local prediction: model components not loaded.")
return None
try:
inputs = tokenizer(symptoms_text, return_tensors="pt", truncation=True, padding=True).to(device)
with torch.no_grad():
outputs = local_model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
top_probs, top_indices = torch.topk(probabilities, 3)
predictions = [
{"disease": label_encoder.classes_[idx], "confidence": float(prob)}
for idx, prob in zip(top_indices.cpu().numpy()[0], top_probs.cpu().numpy()[0])
]
print(f"β
Local model predictions: {predictions}")
return predictions
except Exception as e:
print(f"β Local prediction error: {e}")
return None
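# Illustrative return shape (disease names depend on the label encoder's classes):
#   [{"disease": "Pneumonia", "confidence": 0.48},
#    {"disease": "Influenza", "confidence": 0.31},
#    {"disease": "Common Cold", "confidence": 0.12}]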
def _get_symptoms_from_history(history, model):
"""Uses Gemini to extract a clean symptom string from the chat history."""
print("π Asking Gemini to extract symptom string for local model...")
symptom_extraction_prompt = (
"Review the user messages in this chat history. Extract all medical symptoms, durations, and key descriptors. "
"Consolidate them into a clean, comma-separated string for a machine learning model. "
"Example: 'fever, dry cough, sore throat, duration 1-3 days, chills at night, 102F temperature'. "
"Output ONLY the symptom string."
)
# This is a more robust way to pass history for this task
symptom_history = [{'role': 'user', 'parts': [symptom_extraction_prompt]}] + [msg for msg in history if msg['role'] == 'user']
try:
symptom_response = model.generate_content(symptom_history)
symptoms_summary = symptom_response.text.strip().replace('\n', ', ')
print(f" β
Gemini-summarized symptoms: '{symptoms_summary}'")
return symptoms_summary
except Exception as e:
print(f" β Failed to get symptoms from history: {e}")
return ""
@app.route('/chat', methods=['POST'])
def chat():
    load_error = load_models()
    if load_error:
        return jsonify({"error": f"Model loading failed: {load_error}"}), 503

    data = request.get_json() or {}
    history = data.get('history', [])
    location = data.get('location', 'your area')  # Location supplied by the frontend.
    if not history:
        return jsonify({"error": "Chat history not provided."}), 400

    # Determine the stage from history length; more robust than trusting the client outright.
    stage = 'awaiting_name'
    if len(history) > 1:
        stage = data.get('stage', 'chatting')

    user_details = {"location": location}
    try:
        model = genai.GenerativeModel('gemini-1.5-flash')

        # --- Primary Conversational Call ---
        system_prompt = get_doctor_persona_prompt(stage, user_details)
        conversation_history = [
            {'role': 'user', 'parts': [system_prompt]},
            {'role': 'model', 'parts': ["Understood. I will act as Dr. Aether and follow all instructions for my current stage."]}
        ] + history
        response = model.generate_content(conversation_history)
        response_text = response.text.strip()

        # --- Check for Summary Trigger ---
        if response_text == "[SUMMARY_READY]":
            print("✅ AI signaled summary is ready. Starting final analysis.")
            # 1. Summarize symptoms for the local model.
            symptoms_summary = _get_symptoms_from_history(history, model)
            # 2. Run the local model.
            predictions_result = run_local_prediction(symptoms_summary) if symptoms_summary else []
            # 3. Generate the final summary response.
            final_summary_prompt = get_doctor_persona_prompt('process_symptoms', user_details, predictions_result)
            final_summary_history = [
                {'role': 'user', 'parts': [final_summary_prompt]},
                {'role': 'model', 'parts': ["Understood. I will now generate the final summary in the required JSON format based on the provided history."]}
            ] + history
            final_response = model.generate_content(final_summary_history)

            print("--- DEBUG: Raw Final Summary Response from Gemini ---")
            print(final_response.text)
            print("----------------------------------------------------")

            response_data = {"reply": final_response.text}
            if predictions_result:
                response_data["predictions"] = predictions_result
            return jsonify(response_data)
        else:  # A normal conversational turn.
            return jsonify({"reply": response_text})
    except Exception as e:
        print(f"❌ API error: {e}")
        return jsonify({"error": f"An error occurred with the AI service: {e}"}), 500
@app.route('/get_chats', methods=['POST'])
def get_chats():
    if not db:
        return jsonify({"error": "Firestore not initialized."}), 500
    data = request.get_json() or {}
    user_id = data.get('user_id')
    if not user_id:
        return jsonify({"error": "User ID not provided."}), 400
    try:
        chats_ref = db.collection('users').document(user_id).collection('chats')
        query = chats_ref.order_by("timestamp", direction=firestore.Query.DESCENDING)
        chats = [doc.to_dict() for doc in query.stream()]
        return jsonify(chats)
    except Exception as e:
        print(f"❌ Firestore get_chats error: {e}")
        return jsonify({"error": f"Failed to retrieve chats: {e}"}), 500
@app.route('/save_chat', methods=['POST'])
def save_chat():
    if not db:
        return jsonify({"error": "Firestore is not initialized."}), 500
    data = request.get_json() or {}
    user_id = data.get('userId')
    chat_data = data.get('chatData')
    if not user_id or not chat_data:
        return jsonify({"error": "User ID or chat data is missing."}), 400
    try:
        # Convert the ISO-8601 timestamp string (with a trailing 'Z') to a datetime.
        chat_data['timestamp'] = datetime.datetime.fromisoformat(
            chat_data['timestamp'].replace('Z', '+00:00')
        )
        chat_ref = db.collection('users').document(user_id).collection('chats').document(chat_data['id'])
        chat_ref.set(chat_data, merge=True)
        return jsonify({"success": True, "chatId": chat_data['id']})
    except Exception as e:
        print(f"❌ Firestore save_chat error: {e}")
        return jsonify({"error": f"Failed to save chat: {e}"}), 500
# --- Server Startup Block ---
if __name__ == "__main__":
port = int(os.environ.get('PORT', 7860))
app.run(host='0.0.0.0', port=port)
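# For production, a WSGI server is generally preferred over Flask's dev server,
# e.g. (assuming gunicorn is installed and this module is named app.py):
#   gunicorn -b 0.0.0.0:7860 app:app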