# Heating_mantles_SV / train_model.py
# Provenance: Hugging Face Hub repo file (uploaded by neerajkalyank,
# commit 57d763d, verified) — header converted to comments so the module parses.
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, classification_report
from joblib import dump
import pandas as pd
import numpy as np
import os
import logging
# --- Logging configuration --------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# --- File-system layout -----------------------------------------------------
# Resolve both paths relative to this script's directory so the pipeline
# works regardless of the current working directory at launch.
_BASE_DIR = os.path.dirname(__file__)
DATA_PATH = os.path.join(_BASE_DIR, "data", "enhanced_mantle_training.csv")
MODEL_PATH = os.path.join(_BASE_DIR, "models", "heating_model_with_risk_score.pkl")
# --- Load the enhanced dataset ----------------------------------------------
# Fix: the try/except suites were indentation-flattened (IndentationError as
# pasted); restored. Also replaced `exit(1)` with `raise SystemExit(1)` —
# `exit` comes from the optional `site` module and is absent under `-S` or in
# frozen builds, while SystemExit is always available.
try:
    df = pd.read_csv(DATA_PATH)
    logger.info("Dataset loaded successfully!")
    logger.info(f"Dataset head:\n{df.head().to_string()}")
except FileNotFoundError:
    logger.error(f"Error: '{DATA_PATH}' not found. Please generate the dataset using generate_data.py.")
    raise SystemExit(1)
# --- Validate the dataset schema --------------------------------------------
# Fix: the `if` suites below were indentation-flattened (IndentationError as
# pasted); restored. `exit(1)` replaced with the always-available
# `raise SystemExit(1)`.
# The model consumes two feature columns plus the target label.
required_columns = ["temperature", "duration", "risk_level"]
if not all(col in df.columns for col in required_columns):
    logger.error(f"Error: Missing one or more required columns. Ensure the dataset contains {required_columns}.")
    raise SystemExit(1)

# Rows with missing values are dropped rather than imputed — a warning is
# logged so shrinking datasets are noticed.
if df.isnull().any().any():
    logger.warning("Dataset contains missing values. Dropping rows with missing data.")
    df = df.dropna()

# --- Feature matrix and target vector ---------------------------------------
X = df[["temperature", "duration"]]
y = df["risk_level"]
# --- Train/test split and model setup ----------------------------------------
# Stratifying on `y` keeps the class proportions of `risk_level` identical in
# both partitions; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)

# Base estimator — its hyperparameters are chosen by the grid search below.
model = RandomForestClassifier(random_state=42)

# Search space explored by GridSearchCV (2 * 3 * 2 * 2 = 24 candidates).
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [None, 10, 20],
    'min_samples_split': [2, 5],
    'min_samples_leaf': [1, 2],
}
# --- Hyperparameter tuning via 5-fold cross-validated grid search ------------
# Fix: the try/except suites were indentation-flattened (IndentationError as
# pasted); restored. `exit(1)` replaced with the always-available
# `raise SystemExit(1)`.
try:
    # n_jobs=-1 parallelizes candidate evaluation across all CPU cores.
    grid_search = GridSearchCV(model, param_grid, cv=5, scoring='accuracy', n_jobs=-1)
    grid_search.fit(X_train, y_train)
    logger.info("Grid search completed successfully!")
    logger.info(f"Best parameters: {grid_search.best_params_}")
    logger.info(f"Best cross-validation accuracy: {grid_search.best_score_:.4f}")
except Exception as e:
    logger.error(f"Error during model training: {e}")
    raise SystemExit(1)
# --- Evaluate the tuned model on the held-out test set ------------------------
# Fix: the `if` suite at the accuracy gate was indentation-flattened
# (IndentationError as pasted); restored.
best_model = grid_search.best_estimator_

y_pred = best_model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
logger.info(f"Test set accuracy: {accuracy:.4f}")
logger.info("\nClassification Report:")
logger.info(classification_report(y_test, y_pred))

# Accuracy gate: warn (but do not abort) below the 95% target so the model is
# still saved and can be inspected.
if accuracy < 0.95:
    logger.warning("Model accuracy is below 95%. Consider generating more data or adjusting model parameters.")
# --- Persist the best model ---------------------------------------------------
# Fix: the try/except suites were indentation-flattened (IndentationError as
# pasted); restored. `exit(1)` replaced with the always-available
# `raise SystemExit(1)`.
try:
    # Create the models/ directory on first run; no-op if it already exists.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    dump(best_model, MODEL_PATH)
    logger.info(f"Model training complete! Model saved as '{MODEL_PATH}'.")
except Exception as e:
    logger.error(f"Error saving the model: {e}")
    raise SystemExit(1)