# Heating_mantles_SV / risk_model.py
# Last change: "Update risk_model.py" (commit 4e4a014, verified) by neerajkalyank.
import numpy as np
from joblib import load
import os
import logging
# Configure the root logger for INFO-level, timestamped output.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# The serialized model is expected in a "models" folder next to this module.
_MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
MODEL_PATH = os.path.join(_MODELS_DIR, "heating_model_with_risk_score.pkl")
# Load the trained model once at import time; predict_risk() reuses this object.
# NOTE: joblib.load unpickles the file, so only ship/load model files from a trusted source.
try:
    model = load(MODEL_PATH)
except FileNotFoundError as err:
    _missing_model_msg = f"Trained model '{MODEL_PATH}' not found. Please run train_model.py first."
    logger.error(_missing_model_msg)
    # Chain explicitly so the traceback reports the original error as the direct
    # cause, rather than "During handling of the above exception, another ...".
    raise FileNotFoundError(_missing_model_msg) from err
else:
    logger.info("Model loaded successfully from '%s'.", MODEL_PATH)
def predict_risk(temp, duration):
    """
    Predict the risk level, risk score, and alert for one heating-mantle run.

    Uses the module-level ``model`` loaded at import time (described as a
    trained Random Forest classifier). It must expose ``predict``,
    ``predict_proba`` and ``classes_``.

    Args:
        temp (float): The maximum temperature of the heating mantle, in °C.
            Must be within 50-200 inclusive.
        duration (float): How long the mantle is used, in minutes.
            Must be within 5-120 inclusive.

    Returns:
        tuple: (risk_level, risk_score, alert) where ``risk_level`` is the
        predicted class label, ``risk_score`` is the model's probability for
        that label as a percentage rounded to 2 decimals (a Python float),
        and ``alert`` is "Safe", "Risk" or "High Risk" for the labels
        "Low", "Moderate" and "High", or "Unknown" for any other label.

    Raises:
        ValueError: If ``temp`` or ``duration`` is outside its accepted range
            (NaN fails the range comparison and is rejected too).
        Exception: Any error raised by the model is logged with its traceback
            and re-raised unchanged.
    """
    # Input validation: log once and raise with the same message.
    if not (50 <= temp <= 200):
        msg = "Temperature must be between 50 and 200°C."
        logger.error(msg)
        raise ValueError(msg)
    if not (5 <= duration <= 120):
        msg = "Duration must be between 5 and 120 minutes."
        logger.error(msg)
        raise ValueError(msg)

    # One sample as a 2-D row: [[temp, duration]].
    input_data = np.array([[temp, duration]])

    try:
        risk_level = model.predict(input_data)[0]
    except Exception as err:
        # logger.exception keeps the traceback; logger.error(f"...{e}") dropped it.
        logger.exception("Error during prediction: %s", err)
        raise

    try:
        probabilities = model.predict_proba(input_data)[0]
        # predict_proba columns follow the order of model.classes_, so look up
        # the column that belongs to the predicted label.
        class_index = np.flatnonzero(model.classes_ == risk_level)[0]
        risk_score = round(float(probabilities[class_index]) * 100, 2)
    except Exception as err:
        logger.exception("Error calculating risk score: %s", err)
        raise

    alert_map = {"Low": "Safe", "Moderate": "Risk", "High": "High Risk"}
    alert = alert_map.get(risk_level, "Unknown")
    logger.info(
        "Prediction: Risk Level=%s, Risk Score=%s%%, Alert=%s",
        risk_level,
        risk_score,
        alert,
    )
    return risk_level, risk_score, alert