Spaces:
Sleeping
Sleeping
import joblib | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torchvision.transforms as transforms | |
# 10 condition-based yes/no questions. Presumably the UI asks these in this
# order and passes the answers to predict_chest() as its `conditions` list
# (which must hold exactly 10 values) — confirm against the caller before
# reordering or adding questions.
condition_questions = [
    "Do you have a cough?",
    "Do you feel shortness of breath?",
    "Are you experiencing chest pain?",
    "Do you smoke?",
    "Do you have a fever?",
    "Do you have fatigue?",
    "Have you had a recent respiratory infection?",
    "Do you have a family history of lung issues?",
    "Do you experience wheezing or noisy breathing?",
    "Have you been exposed to pollution or chemicals recently?",
]
# Location of the serialized tabular classifier. joblib files are
# pickle-based, so only ever load a model file from a trusted source.
_CHEST_MODEL_PATH = "chest_model.joblib"

# The tabular model is loaded once, when this module is imported.
chest_model = joblib.load(_CHEST_MODEL_PATH)

# OPTIONAL: image-based CNN model (swap in your own model class), e.g.
#   cnn_model = torch.load("cnn_model.pth", map_location=torch.device('cpu'))
#   cnn_model.eval()
# Encoding helpers
def encode_gender(gender):
    """Encode a gender label as the integer the tabular model expects.

    Args:
        gender: "Male" or "Female". Matching ignores case and surrounding
            whitespace.

    Returns:
        0 for "Male", 1 for "Female".

    Raises:
        ValueError: if *gender* is neither "Male" nor "Female".
    """
    # Normalise so inputs such as "male" or " Male " are not silently
    # encoded as Female, and reject labels the model was never given.
    normalized = str(gender).strip().casefold()
    if normalized == "male":
        return 0
    if normalized == "female":
        return 1
    raise ValueError(f"gender must be 'Male' or 'Female', got {gender!r}")
def encode_view_position(position):
    """Encode an image view position as the integer the tabular model expects.

    Args:
        position: "PA" or "AP". Matching ignores case and surrounding
            whitespace.

    Returns:
        0 for "PA", 1 for "AP".

    Raises:
        ValueError: if *position* is neither "PA" nor "AP".
    """
    # Normalise so inputs such as "pa" are not silently encoded as AP,
    # and reject positions the model was never given.
    normalized = str(position).strip().upper()
    if normalized == "PA":
        return 0
    if normalized == "AP":
        return 1
    raise ValueError(f"view_position must be 'PA' or 'AP', got {position!r}")
# Main prediction function
def predict_chest(age, gender, view_position, conditions, uploaded_image=None):
    """
    Predicts chest disease from tabular data or uploaded image.

    Parameters:
    - age: int
    - gender: 'Male' or 'Female'
    - view_position: 'PA' or 'AP'
    - conditions: 10 binary values (0 or 1); any iterable works
      (list, tuple, NumPy array)
    - uploaded_image: image file from Streamlit uploader (optional)

    Returns:
    - prediction string

    Raises:
    - ValueError: on the tabular path, if ``conditions`` does not hold exactly
      10 values or holds a value other than 0/1, or if the gender /
      view-position encoders reject their input.

    Note: when ``uploaded_image`` is given, the tabular inputs are ignored and
    the image path currently returns a fixed placeholder string.
    """
    # ===== IMAGE-BASED PREDICTION (optional) =====
    if uploaded_image is not None:
        # Close the source image once converted; convert() returns an
        # independent copy, so the converted image stays usable.
        with Image.open(uploaded_image) as raw_image:
            image = raw_image.convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((224, 224)),  # match your CNN input size
            transforms.ToTensor(),
            # A single mean/std value is broadcast across all RGB channels.
            transforms.Normalize([0.5], [0.5]),
        ])
        input_tensor = transform(image).unsqueeze(0)  # add batch dimension
        # TODO: replace with real CNN inference once a model is loaded, e.g.
        #   with torch.no_grad():
        #       output = cnn_model(input_tensor)
        #   predicted = torch.argmax(output, dim=1).item()
        #   return "Chest Disease Detected" if predicted == 1 else "No Chest Disease Detected"
        # WARNING: until then every uploaded image yields this fixed result,
        # regardless of its content (input_tensor is not used yet).
        return "Chest Disease Detected (image-based)"

    # ===== TABULAR PREDICTION =====
    gender_encoded = encode_gender(gender)
    position_encoded = encode_view_position(view_position)

    # Materialise as a list: `list + tuple` raises TypeError and
    # `list + ndarray` broadcasts arithmetically instead of concatenating.
    conditions = list(conditions)
    if len(conditions) != 10:
        raise ValueError("Expected 10 binary values for conditions.")
    if any(value not in (0, 1) for value in conditions):
        raise ValueError("Condition values must be 0 or 1.")

    # Feature order: age, gender, view position, then the 10 conditions.
    features = np.array([[age, gender_encoded, position_encoded, *conditions]])
    prediction = chest_model.predict(features)[0]
    return "Chest Disease Detected" if prediction == 1 else "No Chest Disease Detected"