import os
import joblib
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import google.generativeai as genai
from dotenv import load_dotenv
import firebase_admin
from firebase_admin import credentials, firestore
import datetime
import json

# --- App and Environment Initialization ---
app = Flask(__name__)

@app.route("/", methods=["GET"])
def health_check():
    return jsonify({
        "status": "ok",
        "message": "Health AI Backend is running πŸš€"
    }), 200
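
# Quick smoke test for the endpoint above (the port matches the default used in
# the startup block at the bottom of this file):
#   curl http://localhost:7860/
#   -> {"status": "ok", "message": "Health AI Backend is running 🚀"}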


# --- CORS CONFIGURATION ---
CORS(app, resources={r"/*": {"origins": [
    "https://health-app-lilac.vercel.app",
    "http://localhost:3000"
]}})

load_dotenv()

# --- Firebase Initialization ---
try:
    service_account_str = os.getenv("FIREBASE_SERVICE_ACCOUNT")
    if service_account_str:
        service_account_info = json.loads(service_account_str)
        cred = credentials.Certificate(service_account_info)
    else:
        cred = credentials.Certificate("firebase_key.json")
    
    if not firebase_admin._apps:
        firebase_admin.initialize_app(cred)
    
    db = firestore.client()
    print("βœ… Firebase initialized successfully.")
except Exception as e:
    print(f"❌ Firebase initialization failed: {e}")
    db = None
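
# FIREBASE_SERVICE_ACCOUNT is assumed to hold the full service-account JSON as a
# single string (fields abbreviated here for illustration):
#   FIREBASE_SERVICE_ACCOUNT='{"type": "service_account", "project_id": "...", "private_key": "..."}'
# If it is unset, the code above falls back to a local firebase_key.json file.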

# --- Gemini API Configuration ---
try:
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY environment variable is not set.")
    genai.configure(api_key=api_key)
    print("✨ Gemini API configured.")
except Exception as e:
    print(f"❌ Gemini configuration failed: {e}")

# --- Global variables to hold the loaded models ---
local_model, tokenizer, label_encoder = None, None, None
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_FOLDER_PATH = "./models/finetuned_model"
LABEL_ENCODER_PATH = "./models/label_encoder.joblib"
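
# Assumed on-disk layout implied by the paths above:
#   models/
#     finetuned_model/        <- Hugging Face-format model + tokenizer files
#     label_encoder.joblib    <- scikit-learn LabelEncoder mapping class ids to disease names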

def load_models():
    """Lazily loads the model, tokenizer, and label encoder.

    Returns None on success, or an error message string on failure.
    """
    global local_model, tokenizer, label_encoder
    if local_model is not None:
        return None

    print("📂 First request received. Starting model loading process...")
    try:
        if not os.path.exists(MODEL_FOLDER_PATH) or not os.path.exists(LABEL_ENCODER_PATH):
            error_msg = "❌ Model files not found locally."
            print(error_msg)
            return error_msg
        
        local_model = AutoModelForSequenceClassification.from_pretrained(MODEL_FOLDER_PATH)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_FOLDER_PATH)
        label_encoder = joblib.load(LABEL_ENCODER_PATH)
        local_model.to(device)
        local_model.eval()
        print(f"βœ… Models loaded successfully and moved to {device}.")
        return None
    except Exception as e:
        error_msg = f"❌ CRITICAL ERROR during model loading: {e}"
        print(error_msg)
        return error_msg
    
# --- Prompts and AI Logic ---
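
# In-band tokens used by the prompts below:
#   [CHIPS: [...]]    - quick-reply options (assumed to be rendered as tappable
#                       chips by the frontend)
#   [SUMMARY_READY]   - emitted alone by the model; the /chat route detects it
#                       and triggers the local-model analysis + final summary
#   [SUMMARY: {...}]  - the final structured summary as JSON (assumed to be
#                       parsed and rendered by the frontend)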

def get_doctor_persona_prompt(stage, user_details=None, local_predictions=None):
    """
    Generates a system prompt tailored to the specific stage of the conversation.
    """
    # Use real-time data instead of hardcoded values
    try:
        from zoneinfo import ZoneInfo
        tz = ZoneInfo("Asia/Kolkata")
        now = datetime.datetime.now(tz)
    except Exception:
        now = datetime.datetime.now() # Fallback to server's local time

    time_str = now.strftime("%A, %B %d, %Y at %I:%M %p %Z")
    location = user_details.get('location', 'the user\'s area') if user_details else 'the user\'s area'

    # --- STAGE 1: Initial Greeting ---
    if stage == 'awaiting_name':
        return """
        **SYSTEM INSTRUCTION:** You are Dr. Aether, a friendly and professional AI medical assistant.
        Your ONLY task is to greet the user warmly and ask for their name to begin the consultation.
        Example: "Hello, I'm Dr. Aether. To get started, could you please tell me your name?"
        """

    # --- STAGE 2 & 3: Unified Triage, Inquiry, and Conversation ---
    if stage in ['awaiting_age', 'awaiting_sex', 'awaiting_symptoms', 'chatting']:
        return f"""
        **SYSTEM INSTRUCTION: ACT AS DR. AETHER - MEDICAL TRIAGE & INQUIRY**
        **Persona:** You are Dr. Aether, an experienced and empathetic AI physician.
        **User Context:** The user is in {location}. The current time is {time_str}.
        **Your Task:** Conduct a medical consultation. First, gather essential patient information, then ask clarifying medical questions.

        **Information Checklist (Review the chat history to see what is missing):**
        - [ ] Patient's Name
        - [ ] Patient's Age
        - [ ] Patient's Biological Sex
        - [ ] Primary Symptoms

        **CRITICAL Directives & Flow:**
        1.  **Check the List:** Review the entire chat history to determine the FIRST piece of information from the checklist that is still missing.
        2.  **Ask for ONE Missing Item:** Your response must ask for ONLY that single missing item.
            - If Age is missing, ask for it using `[CHIPS: ["Under 18", "18-64", "65+"]]`.
            - If Sex is missing, ask for it using `[CHIPS: ["Male", "Female"]]`.
            - If Symptoms are missing, ask for them with an open-ended question.
        3.  **Transition to Medical Inquiry:** ONLY when all four items on the checklist have been gathered, you must transition to asking focused medical questions to clarify the symptoms (e.g., "How long have you had these symptoms?", "Is the cough dry or productive?").
        4.  **One Question at a Time:** ALWAYS ask only ONE question per turn.
        5.  **Conclude Inquiry:** Once you have gathered sufficient details about the symptoms (e.g., duration, type, severity, associated symptoms), your very next response MUST BE ONLY the token `[SUMMARY_READY]`. Do not ask "is there anything else?". Just output the token.
        """
    
    # --- STAGE 4: Final Summary Generation ---
    if stage == 'process_symptoms':
        predictions_text = "No preliminary analysis was available."
        if local_predictions:
            formatted = [f"- {p.get('disease', 'N/A')} ({p.get('confidence', 0):.0%})" for p in local_predictions]
            predictions_text = "My preliminary analysis suggests the following possibilities:\n" + "\n".join(formatted)

        return f"""
        **SYSTEM INSTRUCTION: ACT AS DR. AETHER - FINAL SUMMARY GENERATOR**
        **CONTEXT:**
        - **User Location:** {location}
        - **Local Model Analysis:** {predictions_text}
        - **Conversation History:** (The full history will be provided after this prompt)

        **YOUR ONLY TASK:**
        Review the conversation history and the local model analysis, then generate a final summary in the required JSON format. Your entire output must be ONLY the JSON block, starting with `[SUMMARY:` and ending with `}}]`.

        **CRITICAL JSON RULES:**
        1. All text inside JSON string values (e.g., in "recap") MUST be a single line. DO NOT use unescaped newline characters.
        2. Fill every field with medically sound, empathetic, and relevant information based on the conversation.

        **EXAMPLE OUTPUT STRUCTURE:**
        [SUMMARY: {{"recap": "Satvik, a user in the 18-64 age range, reports a 3-day history of high fever, mild dry cough, and nightly chills, with an associated sore throat.", "possibilities": "Based on the symptoms and preliminary analysis, this could suggest a viral upper respiratory infection like the common cold or influenza. The local model noted a possibility of pneumonia (48%), which should be considered, especially given the fever and cough.", "homeCare": ["Rest as much as possible to help your body fight the infection.", "Stay hydrated by drinking plenty of water, broth, or tea.", "Use over-the-counter medications like acetaminophen or ibuprofen for fever and aches, following package directions."], "recommendation": "It is recommended you consult a doctor in {location} within the next 24-48 hours for a proper diagnosis, especially to rule out a bacterial infection like pneumonia.", "conclusion": "I hope this information has been helpful. Please follow up with a healthcare professional as recommended. Is there anything else I can assist you with?"}}]
        
        **Now, generate the JSON for the current user's conversation history.**
        """
    
    return "You are a helpful assistant." # Fallback

def run_local_prediction(symptoms_text):
    """Runs the local model to get disease predictions."""
    if not all([local_model, tokenizer, label_encoder]):
        print("❌ Cannot run local prediction: model components not loaded.")
        return None
    try:
        inputs = tokenizer(symptoms_text, return_tensors="pt", truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = local_model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        top_probs, top_indices = torch.topk(probabilities, 3)
        predictions = [
            {"disease": label_encoder.classes_[idx], "confidence": float(prob)}
            for idx, prob in zip(top_indices.cpu().numpy()[0], top_probs.cpu().numpy()[0])
        ]
        print(f"βœ… Local model predictions: {predictions}")
        return predictions
    except Exception as e:
        print(f"❌ Local prediction error: {e}")
        return None
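
# Example call (hypothetical output; actual labels depend on the label encoder):
#   run_local_prediction("fever, dry cough, sore throat, chills, duration 3 days")
#   -> [{"disease": "influenza", "confidence": 0.52},
#       {"disease": "common cold", "confidence": 0.31},
#       {"disease": "pneumonia", "confidence": 0.09}]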

def _get_symptoms_from_history(history, model):
    """Uses Gemini to extract a clean symptom string from the chat history."""
    print("πŸ‘‰ Asking Gemini to extract symptom string for local model...")
    symptom_extraction_prompt = (
        "Review the user messages in this chat history. Extract all medical symptoms, durations, and key descriptors. "
        "Consolidate them into a clean, comma-separated string for a machine learning model. "
        "Example: 'fever, dry cough, sore throat, duration 1-3 days, chills at night, 102F temperature'. "
        "Output ONLY the symptom string."
    )
    # Prepend the extraction instruction and pass only the user's own messages,
    # so assistant turns don't pollute the extracted symptom string.
    symptom_history = [{'role': 'user', 'parts': [symptom_extraction_prompt]}] + [msg for msg in history if msg['role'] == 'user']
    
    try:
        symptom_response = model.generate_content(symptom_history)
        symptoms_summary = symptom_response.text.strip().replace('\n', ', ')
        print(f"   βœ… Gemini-summarized symptoms: '{symptoms_summary}'")
        return symptoms_summary
    except Exception as e:
        print(f"   ❌ Failed to get symptoms from history: {e}")
        return ""

@app.route('/chat', methods=['POST'])
def chat():
    load_error = load_models()
    if load_error:
        return jsonify({"error": f"Model loading failed: {load_error}"}), 503

    data = request.get_json() or {}
    history = data.get('history', [])
    location = data.get('location', 'your area') # Use location from frontend
    
    # The very first turn (a single opening message) always starts at
    # 'awaiting_name'; after that, fall back to the stage reported by the client.
    stage = 'awaiting_name'
    if len(history) > 1:
        stage = data.get('stage', 'chatting')
    
    if not history:
        return jsonify({"error": "Chat history not provided."}), 400

    user_details = {"location": location}
    
    try:
        model = genai.GenerativeModel('gemini-1.5-flash')
        
        # --- Primary Conversational Call ---
        system_prompt = get_doctor_persona_prompt(stage, user_details)
        conversation_history = [
            {'role': 'user', 'parts': [system_prompt]},
            {'role': 'model', 'parts': ["Understood. I will act as Dr. Aether and follow all instructions for my current stage."]}
        ] + history
        
        response = model.generate_content(conversation_history)
        response_text = response.text.strip()
        
        # --- Check for Summary Trigger ---
        if response_text == "[SUMMARY_READY]":
            print("βœ… AI signaled summary is ready. Starting final analysis.")
            
            # 1. Summarize symptoms for local model
            symptoms_summary = _get_symptoms_from_history(history, model)
            
            # 2. Run local model
            predictions_result = run_local_prediction(symptoms_summary) if symptoms_summary else []
            
            # 3. Generate the final summary response
            final_summary_prompt = get_doctor_persona_prompt('process_symptoms', user_details, predictions_result)
            final_summary_history = [
                {'role': 'user', 'parts': [final_summary_prompt]},
                {'role': 'model', 'parts': ["Understood. I will now generate the final summary in the required JSON format based on the provided history."]}
            ] + history
            
            final_response = model.generate_content(final_summary_history)
            
            # Debug logging: dump Gemini's raw summary output for inspection
            print("--- DEBUG: Raw Final Summary Response from Gemini ---")
            print(final_response.text)
            print("----------------------------------------------------")
            
            response_data = {"reply": final_response.text}
            if predictions_result:
                response_data["predictions"] = predictions_result
            return jsonify(response_data)
        
        else: # It's a normal conversational turn
            return jsonify({"reply": response_text})

    except Exception as e:
        print(f"❌ API error: {e}")
        return jsonify({"error": f"An error occurred with the AI service: {e}"}), 500
    
@app.route('/get_chats', methods=['POST'])
def get_chats():
    if not db:
        return jsonify({"error": "Firestore not initialized."}), 500
    data = request.get_json() or {}
    user_id = data.get('user_id')
    if not user_id:
        return jsonify({"error": "User ID not provided."}), 400

    try:
        chats_ref = db.collection('users').document(user_id).collection('chats')
        query = chats_ref.order_by("timestamp", direction=firestore.Query.DESCENDING)
        chats = [doc.to_dict() for doc in query.stream()]
        return jsonify(chats)
    except Exception as e:
        print(f"❌ Firestore get_chats error: {e}")
        return jsonify({"error": f"Failed to retrieve chats: {e}"}), 500

@app.route('/save_chat', methods=['POST'])
def save_chat():
    if not db:
        return jsonify({"error": "Firestore is not initialized."}), 500
    data = request.get_json() or {}
    user_id = data.get('userId')
    chat_data = data.get('chatData')
    if not user_id or not chat_data:
        return jsonify({"error": "User ID or chat data is missing."}), 400

    try:
        chat_data['timestamp'] = datetime.datetime.fromisoformat(
            chat_data['timestamp'].replace('Z', '+00:00')
        )
        chat_ref = db.collection('users').document(user_id).collection('chats').document(chat_data['id'])
        chat_ref.set(chat_data, merge=True)
        return jsonify({"success": True, "chatId": chat_data['id']})
    except Exception as e:
        print(f"❌ Firestore save_chat error: {e}")
        return jsonify({"error": f"Failed to save chat: {e}"}), 500

# --- Server Startup Block ---
if __name__ == "__main__":
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)
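
# Note: app.run() starts Flask's built-in development server. For production
# deployments a WSGI server is typically used instead, e.g. (assuming this file
# is named app.py):
#   gunicorn --bind 0.0.0.0:7860 app:app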