PD03's picture
Update utils/model_trainer.py
4897a44 verified
"""
Embedded Model Training for HF Spaces
Fixed version with dynamic column mapping for SAP SALT dataset
"""
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import joblib
import json
import streamlit as st
from pathlib import Path
from datetime import datetime
class EmbeddedChurnTrainer:
"""Embedded trainer with dynamic column mapping for real SAP SALT data"""
def __init__(self):
self.model_path = Path('models/churn_model_v1.pkl')
self.metadata_path = Path('models/model_metadata.json')
self.model = None
self.label_encoders = {}
self.feature_columns = []
self.column_mapping = {}
def model_exists(self):
"""Check if trained model exists"""
return self.model_path.exists() and self.metadata_path.exists()
@st.cache_data
def load_sap_data(_self):
"""Load real SAP SALT dataset and inspect its structure"""
try:
from datasets import load_dataset
st.info("πŸ”„ Loading SAP SALT dataset from Hugging Face...")
# Load the dataset
dataset = load_dataset("SAP/SALT", split="train")
data_df = dataset.to_pandas()
# Debug: Show actual columns
st.info(f"πŸ“‹ Dataset columns: {list(data_df.columns)}")
st.info(f"πŸ“Š Dataset shape: {data_df.shape}")
# Create column mapping based on available columns
_self.column_mapping = _self._create_column_mapping(data_df.columns)
st.info(f"πŸ”— Column mapping: {_self.column_mapping}")
# Add aggregated fields
data_df = _self._add_aggregated_fields(data_df)
st.success(f"βœ… Loaded {len(data_df)} records from SAP SALT dataset")
return data_df
except ImportError:
st.error("❌ Hugging Face datasets library not available")
raise RuntimeError("datasets library required")
except Exception as e:
if "gated" in str(e).lower() or "authentication" in str(e).lower():
st.error("πŸ” **SAP SALT Dataset Access Required**")
st.info("""
**To access SAP SALT dataset:**
1. Visit: https://huggingface.co/datasets/SAP/SALT
2. Click "Agree and access repository"
3. Add HF token to Space secrets: `HF_TOKEN`
4. Restart the Space
""")
else:
st.error(f"❌ Failed to load SAP SALT dataset: {str(e)}")
raise
def _create_column_mapping(self, available_columns):
"""Create mapping from expected columns to available columns"""
cols = [col.upper() for col in available_columns] # Convert to uppercase for matching
available_upper = {col.upper(): col for col in available_columns}
mapping = {}
# Map customer identifier
customer_candidates = ['CUSTOMER', 'SOLDTOPARTY', 'CUSTOMERID', 'CUSTOMER_ID']
for candidate in customer_candidates:
if candidate in cols:
mapping['Customer'] = available_upper[candidate]
break
else:
mapping['Customer'] = available_columns[0] if available_columns else 'Customer' # Fallback
# Map customer name
name_candidates = ['CUSTOMERNAME', 'CUSTOMER_NAME', 'NAME', 'COMPANYNAME']
for candidate in name_candidates:
if candidate in cols:
mapping['CustomerName'] = available_upper[candidate]
break
else:
mapping['CustomerName'] = None
# Map country
country_candidates = ['COUNTRY', 'COUNTRYKEY', 'COUNTRY_CODE', 'LAND1']
for candidate in country_candidates:
if candidate in cols:
mapping['Country'] = available_upper[candidate]
break
else:
mapping['Country'] = None
# Map customer group
group_candidates = ['CUSTOMERGROUP', 'CUSTOMER_GROUP', 'CUSTOMERCLASSIFICATION', 'KTOKD']
for candidate in group_candidates:
if candidate in cols:
mapping['CustomerGroup'] = available_upper[candidate]
break
else:
mapping['CustomerGroup'] = None
# Map sales document
doc_candidates = ['SALESDOCUMENT', 'SALES_DOCUMENT', 'VBELN', 'DOCUMENTNUMBER']
for candidate in doc_candidates:
if candidate in cols:
mapping['SalesDocument'] = available_upper[candidate]
break
else:
mapping['SalesDocument'] = None
# Map creation date
date_candidates = ['CREATIONDATE', 'CREATION_DATE', 'ERDAT', 'REQUESTEDDELIVERYDATE', 'DATE']
for candidate in date_candidates:
if candidate in cols:
mapping['CreationDate'] = available_upper[candidate]
break
else:
mapping['CreationDate'] = None
return mapping
def _add_aggregated_fields(self, data):
"""Add customer-level aggregations using dynamic column mapping"""
# Get actual column names
customer_col = self.column_mapping.get('Customer')
date_col = self.column_mapping.get('CreationDate')
sales_doc_col = self.column_mapping.get('SalesDocument')
if not customer_col:
st.error("❌ No customer identifier column found")
raise ValueError("Cannot identify customer column")
# Customer-level aggregations
agg_dict = {}
if sales_doc_col:
agg_dict[sales_doc_col] = 'count'
if date_col:
agg_dict[date_col] = ['min', 'max']
if not agg_dict:
# If no aggregation columns available, create dummy data
data['total_orders'] = 1
data['first_order_date'] = '2024-01-01'
data['last_order_date'] = '2024-01-01'
else:
customer_aggs = data.groupby(customer_col).agg(agg_dict).reset_index()
# Flatten column names
new_cols = [customer_col]
if sales_doc_col:
new_cols.append('total_orders')
if date_col:
new_cols.extend(['first_order_date', 'last_order_date'])
customer_aggs.columns = new_cols
# Merge back to original data
data = data.merge(customer_aggs, on=customer_col, how='left')
# Standardize column names for downstream processing
rename_dict = {}
for standard_name, actual_name in self.column_mapping.items():
if actual_name and actual_name in data.columns:
rename_dict[actual_name] = standard_name
if rename_dict:
data = data.rename(columns=rename_dict)
return data
def train_model_if_needed(self):
"""Train model with proper error handling"""
if self.model_exists():
return self.load_existing_metadata()
progress_bar = st.progress(0)
status_text = st.empty()
try:
# Step 1: Load SAP SALT data
status_text.text("πŸ“₯ Loading SAP SALT dataset...")
progress_bar.progress(20)
data = self.load_sap_data()
# Step 2: Feature engineering
status_text.text("πŸ”§ Engineering features...")
progress_bar.progress(40)
features_data = self.engineer_features(data)
# Step 3: Train model
status_text.text("πŸ‹οΈ Training ML model...")
progress_bar.progress(60)
metrics = self.train_model(features_data)
# Step 4: Save model
status_text.text("πŸ’Ύ Saving model...")
progress_bar.progress(80)
self.save_model_artifacts(metrics)
# Complete
progress_bar.progress(100)
status_text.text("βœ… Model training complete!")
return metrics
except Exception as e:
st.error(f"❌ Training failed: {str(e)}")
raise
def engineer_features(self, data):
"""Feature engineering with dynamic column handling"""
try:
# Identify available columns for customer aggregation
agg_cols = ['Customer'] # Always need customer ID
optional_cols = ['CustomerName', 'Country', 'CustomerGroup']
for col in optional_cols:
if col in data.columns and data[col].notna().any():
agg_cols.append(col)
# Customer-level aggregation with only available columns
agg_dict = {}
for col in agg_cols:
if col != 'Customer':
agg_dict[col] = 'first'
# Add order-related aggregations
if 'total_orders' in data.columns:
agg_dict['total_orders'] = 'first'
if 'first_order_date' in data.columns:
agg_dict['first_order_date'] = 'first'
if 'last_order_date' in data.columns:
agg_dict['last_order_date'] = 'first'
customer_features = data.groupby('Customer').agg(agg_dict).reset_index()
# Handle dates safely
reference_date = pd.to_datetime('2024-12-31')
if 'last_order_date' in customer_features.columns:
customer_features['last_order_date'] = pd.to_datetime(customer_features['last_order_date'], errors='coerce')
customer_features['Recency'] = (reference_date - customer_features['last_order_date']).dt.days
else:
customer_features['Recency'] = 100 # Default recency
if 'first_order_date' in customer_features.columns:
customer_features['first_order_date'] = pd.to_datetime(customer_features['first_order_date'], errors='coerce')
customer_features['Tenure'] = (reference_date - customer_features['first_order_date']).dt.days
else:
customer_features['Tenure'] = 365 # Default tenure
# RFM Features with safe handling
customer_features['Recency'] = customer_features['Recency'].fillna(365).clip(0, 3650)
if 'total_orders' in customer_features.columns:
customer_features['Frequency'] = customer_features['total_orders'].fillna(1).clip(1, 1000)
else:
customer_features['Frequency'] = 1 # Default frequency
customer_features['Monetary'] = (customer_features['Frequency'] * 500).clip(100, 1000000)
customer_features['Tenure'] = customer_features['Tenure'].fillna(365).clip(1, 3650)
# Safe OrderVelocity calculation
tenure_months = customer_features['Tenure'] / 30 + 1
customer_features['OrderVelocity'] = (customer_features['Frequency'] / tenure_months).clip(0, 50)
# Categorical encoding only for available columns
self.label_encoders = {}
categorical_features = []
for col in ['Country', 'CustomerGroup']:
if col in customer_features.columns and customer_features[col].notna().any():
try:
self.label_encoders[col] = LabelEncoder()
customer_features[f'{col}_encoded'] = self.label_encoders[col].fit_transform(
customer_features[col].fillna('Unknown')
)
categorical_features.append(f'{col}_encoded')
except Exception as e:
st.warning(f"⚠️ Could not encode {col}: {str(e)}")
# Target variable (churn definition)
customer_features['IsChurned'] = (
(customer_features['Recency'] > 90) &
(customer_features['Frequency'] > 1)
).astype(int)
# Define feature columns
self.feature_columns = ['Recency', 'Frequency', 'Monetary', 'Tenure', 'OrderVelocity']
self.feature_columns.extend(categorical_features)
# Prepare final dataset
required_cols = self.feature_columns + ['IsChurned', 'Customer']
# Add CustomerName if available
if 'CustomerName' in customer_features.columns:
required_cols.append('CustomerName')
# Filter to only existing columns
available_cols = [col for col in required_cols if col in customer_features.columns]
final_data = customer_features[available_cols].copy()
# **CRITICAL: Clean all data**
for col in self.feature_columns:
if col in final_data.columns:
final_data[col] = final_data[col].replace([np.inf, -np.inf], np.nan).fillna(0)
final_data[col] = final_data[col].clip(-1e9, 1e9)
st.info(f"βœ… Features engineered: {self.feature_columns}")
st.info(f"πŸ“Š Final dataset shape: {final_data.shape}")
return final_data
except Exception as e:
st.error(f"Feature engineering failed: {str(e)}")
st.info(f"Available columns: {list(data.columns)}")
raise
def train_model(self, data):
"""Train model with additional validation"""
try:
# Ensure all feature columns exist
missing_features = [col for col in self.feature_columns if col not in data.columns]
if missing_features:
st.warning(f"⚠️ Missing features: {missing_features}")
# Use only available features
self.feature_columns = [col for col in self.feature_columns if col in data.columns]
if not self.feature_columns:
raise ValueError("No valid features available for training")
X = data[self.feature_columns].copy()
y = data['IsChurned'].copy()
# Final data cleaning
if not np.isfinite(X).all().all():
X = X.replace([np.inf, -np.inf], np.nan).fillna(0)
# Check data quality
if len(X) < 50:
raise ValueError(f"Insufficient training data: {len(X)} samples")
if y.nunique() < 2:
st.warning("⚠️ Creating artificial target variation for training...")
# Create some variation for model training
variation_size = len(y) // 4
y.iloc[:variation_size] = 1 - y.iloc[:variation_size]
# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42,
stratify=y if y.nunique() > 1 else None
)
# Train model
self.model = RandomForestClassifier(
n_estimators=50,
max_depth=8,
min_samples_split=20,
min_samples_leaf=10,
class_weight='balanced',
random_state=42,
n_jobs=1
)
self.model.fit(X_train, y_train)
# Evaluate
train_score = self.model.score(X_train, y_train)
test_score = self.model.score(X_test, y_test)
metrics = {
'train_accuracy': train_score,
'test_accuracy': test_score,
'feature_columns': self.feature_columns,
'training_samples': len(X_train),
'test_samples': len(X_test),
'churn_rate': float(y.mean()),
'feature_importance': dict(zip(self.feature_columns, self.model.feature_importances_)),
'column_mapping': self.column_mapping
}
st.success(f"βœ… Model trained successfully! Accuracy: {test_score:.3f}")
return metrics
except Exception as e:
st.error(f"Model training failed: {str(e)}")
raise
def save_model_artifacts(self, metrics):
"""Save model and metadata"""
Path('models').mkdir(exist_ok=True)
model_data = {
'model': self.model,
'label_encoders': self.label_encoders,
'feature_columns': self.feature_columns,
'column_mapping': self.column_mapping,
'version': 'v1',
'training_date': datetime.now().isoformat()
}
joblib.dump(model_data, self.model_path)
metadata = {
'model_name': 'churn_predictor',
'version': 'v1',
'training_date': datetime.now().isoformat(),
'metrics': metrics,
'status': 'trained',
'data_source': 'SAP/SALT dataset from Hugging Face',
'column_mapping': self.column_mapping
}
with open(self.metadata_path, 'w') as f:
json.dump(metadata, f, indent=2)
def load_existing_metadata(self):
"""Load existing model metadata"""
try:
with open(self.metadata_path, 'r') as f:
return json.load(f)
except Exception:
return None