# Heating_mantles_SV / generate_data.py
# Author: neerajkalyank
# Last commit: "Update generate_data.py"
# Revision: 2e267e9 (verified)
import pandas as pd
import numpy as np
from datetime import datetime
import os
def _generate_class_rows(rng, count, temp_range, duration_range, score_range, risk_level, alert):
    """Return ``count`` synthetic rows for a single risk class.

    Each row is ``[temperature, duration, risk_level, risk_score, alert, timestamp]``.
    ``temp_range`` and ``duration_range`` are ``(low, high)`` pairs for
    ``rng.randint`` (upper bound exclusive); ``score_range`` is a ``(low, high)``
    pair for ``rng.uniform``. A non-positive ``count`` yields no rows.
    """
    rows = []
    for _ in range(count):
        # Draw order (temperature, duration, score) is kept fixed so a given RNG
        # state always yields the same values per row.
        temp = rng.randint(*temp_range)
        duration = rng.randint(*duration_range)
        risk_score = rng.uniform(*score_range)
        # Naive local wall-clock time, second resolution.
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        rows.append([temp, duration, risk_level, risk_score, alert, timestamp])
    return rows


def generate_enhanced_data_v3(num_samples=10000, output_path="enhanced_mantle_training.csv", *, seed=None):
    """Generate a class-balanced synthetic heating-mantle risk dataset and save it as CSV.

    Rows are drawn for three disjoint classes (integer bounds inclusive):

    * Low      -- temperature 50-160 °C,  duration 5-45 min,   score in [0, 40),   alert "Safe"
    * Moderate -- temperature 161-190 °C, duration 46-90 min,  score in [40, 70),  alert "Risk"
    * High     -- temperature 191-200 °C, duration 91-120 min, score in [70, 100), alert "High Risk"

    Note that High rows always exceed BOTH thresholds (temp > 190 and duration
    > 90); mixed cases such as high temperature with short duration are never
    generated.

    Args:
        num_samples: Total number of rows. Low and Moderate each get
            ``num_samples // 3`` rows; High gets the remainder.
        output_path: CSV destination. Missing parent directories are created;
            a bare filename is written to the current working directory.
        seed: Optional seed for reproducible output. When ``None`` (default)
            the global ``np.random`` state is used.

    Returns:
        The shuffled ``pandas.DataFrame`` that was written to ``output_path``,
        with columns temperature, duration, risk_level, risk_score, alert,
        timestamp.
    """
    # Keep using the global RNG when no seed is given so existing callers that
    # seed np.random themselves get exactly the same stream as before.
    rng = np.random if seed is None else np.random.RandomState(seed)

    samples_per_class = num_samples // 3
    # (row count, temp range, duration range, score range, risk level, alert);
    # randint upper bounds are exclusive.
    class_specs = [
        (samples_per_class, (50, 161), (5, 46), (0, 40), "Low", "Safe"),
        (samples_per_class, (161, 191), (46, 91), (40, 70), "Moderate", "Risk"),
        (num_samples - 2 * samples_per_class, (191, 201), (91, 121), (70, 100), "High", "High Risk"),
    ]

    data = []
    for spec in class_specs:
        data.extend(_generate_class_rows(rng, *spec))

    # Shuffle so rows are not grouped by class in the output file.
    rng.shuffle(data)

    df = pd.DataFrame(data, columns=["temperature", "duration", "risk_level", "risk_score", "alert", "timestamp"])

    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # the path actually has one (the default output_path does not).
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output_path, index=False)
    print(f"Data generation complete! Dataset saved as '{output_path}'.")
    return df
if __name__ == "__main__":
    # Resolve the CSV location relative to this script (not the caller's cwd)
    # so the same command works locally and inside a Hugging Face Space.
    script_dir = os.path.dirname(__file__)
    target_csv = os.path.join(script_dir, "data", "enhanced_mantle_training.csv")
    generate_enhanced_data_v3(num_samples=10000, output_path=target_csv)