# Heating_mantles / train_model.py
# Last commit: a42bd42 (verified) by Sirivennela — "Update train_model.py"
# train_model.py
import sys

import pandas as pd
from joblib import dump
from sklearn.ensemble import RandomForestClassifier
# Load the enhanced dataset. Any failure here is fatal: exit with a non-zero
# status so shells/CI correctly see the script as failed.
try:
    df = pd.read_csv("enhanced_mantle_training.csv")  # Path is resolved relative to the current working directory
except FileNotFoundError:
    print("Error: 'enhanced_mantle_training.csv' not found. Please check the file path.", file=sys.stderr)
    sys.exit(1)
except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
    # An empty or malformed CSV would otherwise surface as an uncaught traceback.
    print(f"Error: could not parse 'enhanced_mantle_training.csv': {e}", file=sys.stderr)
    sys.exit(1)
print("Dataset loaded successfully!")
print(df.head())  # Show the first few rows for a quick visual sanity check
# Ensure the columns used for training are present, and name exactly which
# ones are absent so the dataset can be fixed without guesswork.
required_columns = ["temperature", "duration", "risk_level"]
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    print(
        f"Error: Missing required column(s) {missing_columns}. "
        f"Ensure the dataset contains {required_columns}.",
        file=sys.stderr,
    )
    sys.exit(1)
# Drop rows lacking a value in any column used for training. Missing values in
# other, unused columns do not affect the model, so those rows are kept.
rows_before = len(df)
df = df.dropna(subset=required_columns)
rows_dropped = rows_before - len(df)
if rows_dropped:
    print("Warning: Dataset contains missing values. Please clean the data.")
    print(f"Missing values have been dropped ({rows_dropped} row(s) removed).")
# Nothing left to train on: fail here with a clear message rather than letting
# the classifier raise an obscure error on empty input.
if df.empty:
    print("Error: No complete rows remain after dropping missing values.", file=sys.stderr)
    sys.exit(1)
# Numeric encoding of the risk_level labels used as the training target.
RISK_LEVEL_CODES = {"Low": 0, "Moderate": 1, "High": 2}

# Prepare the features (temperature, duration) and target (risk_level).
X = df[["temperature", "duration"]]
y = df["risk_level"].map(RISK_LEVEL_CODES)

# Labels not in RISK_LEVEL_CODES (e.g. "low", "High ") map to NaN, which the
# classifier rejects with an unhelpful error; fail early and name them instead.
unknown_labels = df.loc[y.isna(), "risk_level"].unique()
if len(unknown_labels):
    print(
        f"Error: Unrecognized risk_level value(s) {list(unknown_labels)}. "
        f"Expected one of {list(RISK_LEVEL_CODES)}.",
        file=sys.stderr,
    )
    sys.exit(1)
y = y.astype(int)  # All labels mapped, so the target is a clean integer series

print(f"Prepared {len(X)} rows for training.")
# Random forest classifier. A fixed random_state makes repeated runs on the
# same data produce the same model, so results are reproducible.
model = RandomForestClassifier(random_state=42)
try:
    model.fit(X, y)
except Exception as e:  # Top-level script boundary: report any fitting failure and stop
    print(f"Error during model training: {e}", file=sys.stderr)
    sys.exit(1)
print("Model trained successfully!")
# Persist the trained model with joblib. The file is pickle-based: only load it
# back from a trusted location.
try:
    dump(model, "heating_model_with_risk_score.pkl")
except Exception as e:  # Top-level script boundary: report any save failure and stop
    print(f"Error saving the model: {e}", file=sys.stderr)
    sys.exit(1)
print("Model training complete! Model saved as 'heating_model_with_risk_score.pkl'.")