Spaces:
Runtime error
Runtime error
import numpy as np | |
from joblib import load | |
import os | |
import logging | |
# Logging: INFO level, each record rendered as "<time> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The serialized model is expected in a "models" directory next to this file.
MODEL_PATH = os.path.join(
    os.path.dirname(__file__),
    "models",
    "heating_model_with_risk_score.pkl",
)

# Load the model once at import time so predict_risk() can reuse it.
# NOTE: joblib's load() unpickles the file, so MODEL_PATH must point at a
# trusted artifact.
try:
    model = load(MODEL_PATH)
except FileNotFoundError:
    # Log and re-raise with an actionable message; import fails if the
    # model file is missing.
    _missing_model_msg = f"Trained model '{MODEL_PATH}' not found. Please run train_model.py first."
    logger.error(_missing_model_msg)
    raise FileNotFoundError(_missing_model_msg)
else:
    logger.info(f"Model loaded successfully from '{MODEL_PATH}'.")
def predict_risk(temp, duration):
    """
    Classify a heating-mantle run and report how confident the model is.

    The module-level ``model`` (described by the original author as a
    Random Forest) is queried with a single ``[temp, duration]`` row.

    Args:
        temp (float): Maximum mantle temperature; must lie in [50, 200]
            (degrees Celsius per the validation message).
        duration (float): How long the mantle is used; must lie in
            [5, 120] (minutes per the validation message).

    Returns:
        tuple: ``(risk_level, risk_score, alert)`` where
            - ``risk_level`` is the first element of ``model.predict``;
            - ``risk_score`` is the predicted class's probability from
              ``model.predict_proba``, times 100, rounded to 2 decimals;
            - ``alert`` is "Safe", "Risk" or "High Risk" for the levels
              "Low", "Moderate" and "High", otherwise "Unknown".

    Raises:
        ValueError: If ``temp`` or ``duration`` is outside its range.
        Exception: Any error raised by the model during prediction or
            probability lookup is logged and re-raised unchanged.
    """
    # Range checks, evaluated in order; the first violation is logged and raised.
    # The chained comparison is kept as-is so NaN inputs are rejected as well.
    range_checks = (
        (temp, 50, 200, "Temperature must be between 50 and 200°C."),
        (duration, 5, 120, "Duration must be between 5 and 120 minutes."),
    )
    for value, lower, upper, message in range_checks:
        if lower <= value <= upper:
            continue
        logger.error(message)
        raise ValueError(message)

    # The model receives a 2-D array holding one sample with two features.
    features = np.array([[temp, duration]])

    # Step 1: predicted class label.
    try:
        predictions = model.predict(features)
        risk_level = predictions[0]
    except Exception as exc:
        logger.error(f"Error during prediction: {exc}")
        raise

    # Step 2: probability the model assigns to that predicted class.
    try:
        class_probabilities = model.predict_proba(features)[0]
        known_classes = model.classes_
        # Select the probability entry whose class matches the prediction.
        matching_positions = np.where(known_classes == risk_level)
        predicted_probability = class_probabilities[matching_positions][0]
        risk_score = round(predicted_probability * 100, 2)
    except Exception as exc:
        logger.error(f"Error calculating risk score: {exc}")
        raise

    # Step 3: translate the level into a user-facing alert string.
    alerts_by_level = {"Low": "Safe", "Moderate": "Risk", "High": "High Risk"}
    alert = alerts_by_level.get(risk_level, "Unknown")

    logger.info(f"Prediction: Risk Level={risk_level}, Risk Score={risk_score}%, Alert={alert}")
    return risk_level, risk_score, alert