# Prediction / chest_utils.py
# Author: rangerrRed
# Last commit: af3ac6a "Update chest_utils.py" (verified)
import joblib
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
# Ten yes/no questions about symptoms and risk factors. Presumably the UI
# turns the answers into the 10 binary `conditions` values that
# predict_chest() expects (same order) -- confirm against the caller.
condition_questions = [
    "Do you have a cough?",
    "Do you feel shortness of breath?",
    "Are you experiencing chest pain?",
    "Do you smoke?",
    "Do you have a fever?",
    "Do you have fatigue?",
    "Have you had recent respiratory infection?",
    "Do you have a family history of lung issues?",
    "Do you feel wheezing or noisy breathing?",
    "Have you been exposed to pollution or chemicals recently?",
]
# Tabular classifier consumed by predict_chest(); deserialized once, at import.
# NOTE: the relative path is resolved against the current working directory,
# not this module's directory. joblib.load is pickle-based, so this must only
# ever point at a trusted artifact.
_CHEST_MODEL_PATH = "chest_model.joblib"
chest_model = joblib.load(_CHEST_MODEL_PATH)

# OPTIONAL image model for the upload path of predict_chest(). Not loaded yet;
# substitute your own model class / checkpoint, e.g.:
#   cnn_model = torch.load("cnn_model.pth", map_location=torch.device("cpu"))
#   cnn_model.eval()
# ---- Encoding helpers: categorical inputs -> integer model features ----
def encode_gender(gender):
    """Encode *gender* as an integer feature for the tabular model.

    Returns 0 for exactly ``"Male"`` and 1 for every other value, so
    ``"Female"`` maps to 1 -- but so do unexpected labels such as ``"male"``.
    """
    if gender == "Male":
        return 0
    return 1
def encode_view_position(position):
    """Encode an X-ray view-position label as an integer model feature.

    Returns 0 for exactly ``"PA"`` and 1 for every other value (``"AP"``,
    but also lowercase or otherwise unexpected labels).
    """
    if position == "PA":
        return 0
    return 1
def _load_image_tensor(uploaded_image):
    """Decode *uploaded_image* and preprocess it into a batched tensor.

    Returns the torchvision pipeline output with a leading batch dimension of
    size 1 (added by ``unsqueeze(0)``). Lets ``PIL.Image.open`` errors
    propagate if the upload is not a decodable image.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # match your CNN input size
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # update channels if needed
    ])
    # The context manager releases the file handle when PIL opened it itself
    # (i.e. a path was passed). convert() returns a new, fully loaded image,
    # so `image` stays usable after the block exits.
    with Image.open(uploaded_image) as img:
        image = img.convert("RGB")
    return transform(image).unsqueeze(0)


def _build_features(age, gender, view_position, conditions):
    """Return the single-row feature matrix passed to ``chest_model.predict``.

    Columns, in order: age, encoded gender, encoded view position, then the
    10 condition values in the order given.

    Raises:
        ValueError: if *conditions* does not hold exactly 10 values.
    """
    # Materialize first so tuples, numpy arrays and other iterables work:
    # `list + tuple` raises TypeError and `list + ndarray` attempts
    # elementwise broadcasting instead of concatenation.
    conditions = list(conditions)
    if len(conditions) != 10:
        raise ValueError("Expected 10 binary values for conditions.")
    row = [age, encode_gender(gender), encode_view_position(view_position)]
    row.extend(conditions)
    return np.array([row])


# Main prediction function
def predict_chest(age, gender, view_position, conditions, uploaded_image=None):
    """
    Predicts chest disease from tabular data or uploaded image.

    Parameters:
    - age: int
    - gender: 'Male' or 'Female' (any value other than 'Male' encodes as 1)
    - view_position: 'PA' or 'AP' (any value other than 'PA' encodes as 1)
    - conditions: sequence of 10 binary values (0 or 1); lists, tuples and
      numpy arrays are all accepted
    - uploaded_image: image file from Streamlit uploader (optional). When it
      is given, the tabular inputs are ignored.

    Returns:
    - prediction string

    Raises:
    - ValueError: on the tabular path, if conditions does not contain
      exactly 10 values
    """
    # ===== IMAGE-BASED PREDICTION (optional) =====
    if uploaded_image is not None:
        # Preprocess anyway so that an undecodable upload still fails loudly;
        # the tensor is unused until a CNN model is wired in.
        input_tensor = _load_image_tensor(uploaded_image)  # noqa: F841
        # TODO: replace with actual CNN model inference, e.g.
        #   output = cnn_model(input_tensor)
        #   predicted = torch.argmax(output, dim=1).item()
        #   return "Chest Disease Detected" if predicted == 1 else "No Chest Disease Detected"
        # NOTE(review): until then EVERY upload is reported as a positive
        # finding, regardless of the image content.
        return "Chest Disease Detected (image-based)"  # placeholder for now

    # ===== TABULAR PREDICTION =====
    features = _build_features(age, gender, view_position, conditions)
    prediction = chest_model.predict(features)[0]
    return "Chest Disease Detected" if prediction == 1 else "No Chest Disease Detected"