Spaces:
Sleeping
Sleeping
# train_model.py | |
import sys

import pandas as pd
from joblib import dump
from sklearn.ensemble import RandomForestClassifier
# Load the enhanced dataset (path is relative to the current working directory).
# On failure, exit with a non-zero status so shells/CI can detect it; a bare
# `exit()` would report success (status 0).
DATASET_PATH = "enhanced_mantle_training.csv"
try:
    df = pd.read_csv(DATASET_PATH)
except FileNotFoundError:
    print(f"Error: '{DATASET_PATH}' not found. Please check the file path.", file=sys.stderr)
    sys.exit(1)
print("Dataset loaded successfully!")
print(df.head())  # Show the first few rows for a quick visual sanity check.
# Ensure that the columns used below as features and target are present.
# Report exactly which ones are missing, and abort with a non-zero status.
required_columns = ["temperature", "duration", "risk_level"]
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    print(
        f"Error: Missing one or more required columns. Ensure the dataset contains {required_columns}. "
        f"Missing: {missing_columns}.",
        file=sys.stderr,
    )
    sys.exit(1)
# Check for missing values, but only in the columns the model actually uses:
# a row with a gap in some unrelated column is still a valid training example,
# so both the check and the drop are restricted to `required_columns`.
if df[required_columns].isnull().any().any():
    print("Warning: Dataset contains missing values. Please clean the data.")
    # Drop incomplete rows (alternatively, impute them with df.fillna()).
    df = df.dropna(subset=required_columns)
    print("Missing values have been dropped.")
    # Training on an empty frame would fail later with a less helpful error.
    if df.empty:
        print("Error: No complete rows remain after dropping missing values.", file=sys.stderr)
        sys.exit(1)
# Prepare the features (temperature, duration) and the target (risk_level).
X = df[["temperature", "duration"]]
# Encode risk_level as integers for training: Low=0, Moderate=1, High=2.
RISK_LEVEL_CODES = {"Low": 0, "Moderate": 1, "High": 2}
y = df["risk_level"].map(RISK_LEVEL_CODES)
# Series.map yields NaN for any label missing from the mapping (e.g. "low" or
# "High " with a trailing space). Report such labels explicitly instead of
# letting model.fit() reject the NaN target later with an opaque error.
unknown_labels = df.loc[y.isna(), "risk_level"].unique()
if len(unknown_labels) > 0:
    print(
        f"Error: Unrecognized risk_level values {list(unknown_labels)}; "
        f"expected one of {list(RISK_LEVEL_CODES)}.",
        file=sys.stderr,
    )
    sys.exit(1)
# Confirm how much data is going into training.
print(f"Prepared {len(X)} rows for training.")
# Initialize the Random Forest model. A fixed random_state makes bootstrap
# sampling and per-split feature selection reproducible, so retraining on the
# same CSV produces the same saved model.
model = RandomForestClassifier(random_state=42)
# Train the model. This is the script's top-level boundary, so any exception
# is reported on stderr and the script aborts with a non-zero status.
try:
    model.fit(X, y)
except Exception as e:
    print(f"Error during model training: {e}", file=sys.stderr)
    sys.exit(1)
print("Model trained successfully!")
# Save the trained model to a file with joblib. Note: joblib/pickle files can
# execute arbitrary code when loaded, so only load this file from a trusted source.
MODEL_PATH = "heating_model_with_risk_score.pkl"
try:
    dump(model, MODEL_PATH)
except Exception as e:
    # Report on stderr and exit non-zero so callers see that saving failed.
    print(f"Error saving the model: {e}", file=sys.stderr)
    sys.exit(1)
print(f"Model training complete! Model saved as '{MODEL_PATH}'.")