File size: 2,724 Bytes
2ed073c
246f1ab
922edb8
 
 
2ed073c
2a22fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
922edb8
246f1ab
2ed073c
922edb8
 
 
 
2a22fb3
af3ac6a
 
246f1ab
 
 
 
2a22fb3
922edb8
246f1ab
922edb8
246f1ab
 
 
 
 
922edb8
246f1ab
2a22fb3
246f1ab
922edb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246f1ab
 
922edb8
246f1ab
 
 
 
 
922edb8
af3ac6a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from pathlib import Path

import joblib
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

# 10 condition-based yes/no questions shown to the user.
# NOTE(review): the answers are presumably passed to predict_chest as
# `conditions` in this same order; the model consumes them positionally, so
# do not reorder or add questions without retraining -- confirm with caller.
condition_questions = [
    "Do you have a cough?",
    "Do you feel shortness of breath?",
    "Are you experiencing chest pain?",
    "Do you smoke?",
    "Do you have a fever?",
    "Do you have fatigue?",
    "Have you had a recent respiratory infection?",
    "Do you have a family history of lung issues?",
    "Do you feel wheezing or noisy breathing?",
    "Have you been exposed to pollution or chemicals recently?",
]

# Load the tabular ML model.
# Prefer the copy that sits next to this module, so loading does not depend
# on the directory the app was launched from; otherwise fall back to the bare
# filename resolved against the current working directory (the old behavior).
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only ever load model files from a trusted source.
_CHEST_MODEL_FILENAME = "chest_model.joblib"
_chest_model_path = Path(__file__).resolve().with_name(_CHEST_MODEL_FILENAME)
chest_model = joblib.load(
    str(_chest_model_path) if _chest_model_path.is_file() else _CHEST_MODEL_FILENAME
)

# OPTIONAL: Load CNN model for image (replace with your model class)
# NOTE(review): torch.load of a whole pickled model needs weights_only=False
# on PyTorch >= 2.6 and, like joblib, unpickles -- trusted files only.
# cnn_model = torch.load("cnn_model.pth", map_location=torch.device('cpu'))
# cnn_model.eval()

# Encoding helpers
def encode_gender(gender):
    """Return the numeric gender feature: 0 for "Male", 1 for anything else.

    The comparison is exact and case-sensitive, so every value other than the
    literal string "Male" (e.g. "Female", "male", None) encodes as 1.
    """
    if gender == "Male":
        return 0
    return 1

def encode_view_position(position):
    """Return the numeric view-position feature: 0 for "PA", 1 otherwise.

    Only the exact, case-sensitive string "PA" maps to 0; "AP" and any other
    value map to 1.
    """
    is_pa = position == "PA"
    return 0 if is_pa else 1

def _prepare_image_tensor(uploaded_image):
    """Decode *uploaded_image* and preprocess it for the (future) CNN.

    Returns a float tensor of shape (1, 3, 224, 224): the image converted to
    RGB, resized to 224x224, scaled to [0, 1] by ToTensor, then normalized
    with mean 0.5 / std 0.5 (i.e. mapped to [-1, 1]).
    """
    # `convert` returns a new image with its pixel data already loaded, so the
    # handle PIL opened can be released right away instead of lingering until
    # garbage collection. PIL only closes handles it opened itself, so a
    # caller-supplied file object (e.g. a Streamlit upload) stays open.
    with Image.open(uploaded_image) as opened:
        image = opened.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),      # match your CNN input size
        transforms.ToTensor(),
        # A single mean/std value broadcasts across all three RGB channels.
        transforms.Normalize([0.5], [0.5]),
    ])
    return transform(image).unsqueeze(0)  # add batch dimension


# Main prediction function
def predict_chest(age, gender, view_position, conditions, uploaded_image=None):
    """
    Predicts chest disease from tabular data or uploaded image.

    When ``uploaded_image`` is given, the image branch is used and the
    tabular arguments are ignored.

    Parameters:
    - age: int
    - gender: 'Male' or 'Female' (any value other than 'Male' is encoded
      like 'Female')
    - view_position: 'PA' or 'AP' (any value other than 'PA' is encoded
      like 'AP')
    - conditions: 10 binary values (0 or 1) -- any sequence or iterable,
      e.g. a list, tuple or 1-D NumPy array
    - uploaded_image: image file from Streamlit uploader (optional)
    Returns:
    - prediction string
    Raises:
    - ValueError: if ``conditions`` does not hold exactly 10 values
    """

    # ===== IMAGE-BASED PREDICTION (optional) =====
    if uploaded_image is not None:
        # Decoding also rejects uploads that are not valid images (PIL raises).
        input_tensor = _prepare_image_tensor(uploaded_image)  # noqa: F841

        # ===== Example Placeholder Logic =====
        # Replace this with actual CNN model inference, e.g.:
        # output = cnn_model(input_tensor)
        # predicted = torch.argmax(output, dim=1).item()
        # return "Chest Disease Detected" if predicted == 1 else "No Chest Disease Detected"
        # NOTE(review): until then EVERY uploaded image is reported as
        # diseased, regardless of its content.
        return "Chest Disease Detected (image-based)"  # placeholder for now

    # ===== TABULAR PREDICTION =====
    # Materialize as a list: `[...] + conditions` raises TypeError for a tuple
    # and attempts an element-wise (failing) broadcast for a NumPy array.
    conditions = list(conditions)
    if len(conditions) != 10:
        raise ValueError("Expected 10 binary values for conditions.")

    # NOTE(review): this column order (age, gender code, view-position code,
    # then the condition answers) presumably mirrors the training data --
    # keep it fixed unless the model is retrained.
    features = np.array(
        [[age, encode_gender(gender), encode_view_position(view_position)] + conditions]
    )
    prediction = chest_model.predict(features)[0]

    return "Chest Disease Detected" if prediction == 1 else "No Chest Disease Detected"