# Train a heating-risk RandomForest classifier with GridSearchCV and persist it with joblib.
"""Train a RandomForest heating-risk classifier with GridSearchCV and save it.

Loads the enhanced mantle training dataset, validates its schema, tunes a
RandomForestClassifier via 5-fold grid search, reports test-set metrics, and
persists the best estimator with joblib. Exits with status 1 on any fatal
error (missing data file, missing columns, training or save failure).
"""
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, classification_report
from joblib import dump
import pandas as pd
import numpy as np
import os
import sys
import logging

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Paths are resolved relative to this file so the script works from any CWD.
DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "enhanced_mantle_training.csv")
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "heating_model_with_risk_score.pkl")

# Columns the training pipeline requires: two features plus the target label.
REQUIRED_COLUMNS = ["temperature", "duration", "risk_level"]


def _load_dataset(path):
    """Load the training CSV, validate required columns, and drop incomplete rows.

    Args:
        path: Filesystem path of the CSV produced by generate_data.py.

    Returns:
        A cleaned pandas DataFrame containing at least REQUIRED_COLUMNS.

    Exits:
        Status 1 if the file is missing, columns are absent, or no rows remain.
    """
    try:
        df = pd.read_csv(path)
    except FileNotFoundError:
        logger.error("Error: '%s' not found. Please generate the dataset using generate_data.py.", path)
        sys.exit(1)
    logger.info("Dataset loaded successfully!")
    logger.info("Dataset head:\n%s", df.head().to_string())

    if not all(col in df.columns for col in REQUIRED_COLUMNS):
        logger.error("Error: Missing one or more required columns. Ensure the dataset contains %s.", REQUIRED_COLUMNS)
        sys.exit(1)

    if df.isnull().any().any():
        logger.warning("Dataset contains missing values. Dropping rows with missing data.")
        df = df.dropna()
    # Guard against a dataset that became empty after dropping incomplete rows;
    # train_test_split would otherwise fail with a confusing error.
    if df.empty:
        logger.error("Error: Dataset is empty after dropping rows with missing data.")
        sys.exit(1)
    return df


def _train_model(X_train, y_train):
    """Tune a RandomForestClassifier with 5-fold GridSearchCV.

    Args:
        X_train: Training feature frame (temperature, duration).
        y_train: Training target series (risk_level).

    Returns:
        The best fitted estimator found by the grid search.

    Exits:
        Status 1 if the grid search raises.
    """
    model = RandomForestClassifier(random_state=42)
    param_grid = {
        'n_estimators': [100, 200],
        'max_depth': [None, 10, 20],
        'min_samples_split': [2, 5],
        'min_samples_leaf': [1, 2],
    }
    try:
        # n_jobs=-1 parallelizes the search across all available cores.
        grid_search = GridSearchCV(model, param_grid, cv=5, scoring='accuracy', n_jobs=-1)
        grid_search.fit(X_train, y_train)
    except Exception as e:
        logger.error("Error during model training: %s", e)
        sys.exit(1)
    logger.info("Grid search completed successfully!")
    logger.info("Best parameters: %s", grid_search.best_params_)
    logger.info("Best cross-validation accuracy: %.4f", grid_search.best_score_)
    return grid_search.best_estimator_


def _evaluate(model, X_test, y_test):
    """Log test-set accuracy and a classification report; warn if accuracy < 95%."""
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    logger.info("Test set accuracy: %.4f", accuracy)
    logger.info("\nClassification Report:")
    logger.info(classification_report(y_test, y_pred))
    if accuracy < 0.95:
        logger.warning("Model accuracy is below 95%. Consider generating more data or adjusting model parameters.")


def _save_model(model, path):
    """Persist the fitted model with joblib, creating the target directory if needed.

    Exits:
        Status 1 if the directory cannot be created or the dump fails.
    """
    try:
        os.makedirs(os.path.dirname(path), exist_ok=True)
        dump(model, path)
        logger.info("Model training complete! Model saved as '%s'.", path)
    except Exception as e:
        logger.error("Error saving the model: %s", e)
        sys.exit(1)


def main():
    """Run the full load → tune → evaluate → save pipeline."""
    df = _load_dataset(DATA_PATH)

    # Prepare features and target.
    X = df[["temperature", "duration"]]
    y = df["risk_level"]

    # stratify=y keeps the class balance identical in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    best_model = _train_model(X_train, y_train)
    _evaluate(best_model, X_test, y_test)
    _save_model(best_model, MODEL_PATH)


if __name__ == "__main__":
    main()