"""
Embedded Model Training for HF Spaces
Fixed version with dynamic column mapping for SAP SALT dataset
"""
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import joblib
import json
import streamlit as st
from pathlib import Path
from datetime import datetime


class EmbeddedChurnTrainer:
    """Embedded trainer with dynamic column mapping for real SAP SALT data"""

    def __init__(self):
        self.model_path = Path('models/churn_model_v1.pkl')
        self.metadata_path = Path('models/model_metadata.json')
        self.model = None
        self.label_encoders = {}
        self.feature_columns = []
        self.column_mapping = {}

    def model_exists(self):
        """Check if a trained model already exists"""
        return self.model_path.exists() and self.metadata_path.exists()

    @st.cache_data
    def load_sap_data(_self):
        """Load the real SAP SALT dataset and inspect its structure"""
        try:
            from datasets import load_dataset
            st.info("Loading SAP SALT dataset from Hugging Face...")
            # Load the dataset
            dataset = load_dataset("SAP/SALT", split="train")
            data_df = dataset.to_pandas()
            # Debug: show the actual columns
            st.info(f"Dataset columns: {list(data_df.columns)}")
            st.info(f"Dataset shape: {data_df.shape}")
            # Create the column mapping based on the available columns
            _self.column_mapping = _self._create_column_mapping(data_df.columns)
            st.info(f"Column mapping: {_self.column_mapping}")
            # Add aggregated fields
            data_df = _self._add_aggregated_fields(data_df)
            st.success(f"Loaded {len(data_df)} records from SAP SALT dataset")
            return data_df
        except ImportError:
            st.error("Hugging Face datasets library not available")
            raise RuntimeError("datasets library required")
        except Exception as e:
            if "gated" in str(e).lower() or "authentication" in str(e).lower():
                st.error("**SAP SALT Dataset Access Required**")
                st.info("""
                **To access the SAP SALT dataset:**
                1. Visit: https://huggingface.co/datasets/SAP/SALT
                2. Click "Agree and access repository"
                3. Add an HF token to the Space secrets: `HF_TOKEN`
                4. Restart the Space
                """)
            else:
                st.error(f"Failed to load SAP SALT dataset: {str(e)}")
            raise

    def _create_column_mapping(self, available_columns):
        """Create a mapping from expected columns to the columns actually available"""
        # Work on a plain list so truthiness and indexing are safe for pandas Index inputs
        available_columns = list(available_columns)
        cols = [col.upper() for col in available_columns]  # Uppercase for case-insensitive matching
        available_upper = {col.upper(): col for col in available_columns}
        mapping = {}
        # Map customer identifier
        customer_candidates = ['CUSTOMER', 'SOLDTOPARTY', 'CUSTOMERID', 'CUSTOMER_ID']
        for candidate in customer_candidates:
            if candidate in cols:
                mapping['Customer'] = available_upper[candidate]
                break
        else:
            # Fallback: first available column, or a literal default
            mapping['Customer'] = available_columns[0] if available_columns else 'Customer'
        # Map customer name
        name_candidates = ['CUSTOMERNAME', 'CUSTOMER_NAME', 'NAME', 'COMPANYNAME']
        for candidate in name_candidates:
            if candidate in cols:
                mapping['CustomerName'] = available_upper[candidate]
                break
        else:
            mapping['CustomerName'] = None
        # Map country
        country_candidates = ['COUNTRY', 'COUNTRYKEY', 'COUNTRY_CODE', 'LAND1']
        for candidate in country_candidates:
            if candidate in cols:
                mapping['Country'] = available_upper[candidate]
                break
        else:
            mapping['Country'] = None
        # Map customer group
        group_candidates = ['CUSTOMERGROUP', 'CUSTOMER_GROUP', 'CUSTOMERCLASSIFICATION', 'KTOKD']
        for candidate in group_candidates:
            if candidate in cols:
                mapping['CustomerGroup'] = available_upper[candidate]
                break
        else:
            mapping['CustomerGroup'] = None
        # Map sales document
        doc_candidates = ['SALESDOCUMENT', 'SALES_DOCUMENT', 'VBELN', 'DOCUMENTNUMBER']
        for candidate in doc_candidates:
            if candidate in cols:
                mapping['SalesDocument'] = available_upper[candidate]
                break
        else:
            mapping['SalesDocument'] = None
        # Map creation date
        date_candidates = ['CREATIONDATE', 'CREATION_DATE', 'ERDAT', 'REQUESTEDDELIVERYDATE', 'DATE']
        for candidate in date_candidates:
            if candidate in cols:
                mapping['CreationDate'] = available_upper[candidate]
                break
        else:
            mapping['CreationDate'] = None
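        # Illustrative result only (column names here are hypothetical examples drawn
        # from the candidate lists above, not guaranteed SALT fields):
        #   {'Customer': 'SOLDTOPARTY', 'CustomerName': None, 'Country': None,
        #    'CustomerGroup': 'CUSTOMERCLASSIFICATION', 'SalesDocument': 'SALESDOCUMENT',
        #    'CreationDate': 'CREATIONDATE'}
        # Unmapped fields stay None and are skipped by the downstream steps.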
        return mapping

    def _add_aggregated_fields(self, data):
        """Add customer-level aggregations using the dynamic column mapping"""
        # Resolve the actual column names
        customer_col = self.column_mapping.get('Customer')
        date_col = self.column_mapping.get('CreationDate')
        sales_doc_col = self.column_mapping.get('SalesDocument')
        if not customer_col:
            st.error("No customer identifier column found")
            raise ValueError("Cannot identify customer column")
        # Customer-level aggregations
        agg_dict = {}
        if sales_doc_col:
            agg_dict[sales_doc_col] = 'count'
        if date_col:
            agg_dict[date_col] = ['min', 'max']
        if not agg_dict:
            # No aggregation columns available, so fall back to placeholder values
            data['total_orders'] = 1
            data['first_order_date'] = '2024-01-01'
            data['last_order_date'] = '2024-01-01'
        else:
            customer_aggs = data.groupby(customer_col).agg(agg_dict).reset_index()
            # Flatten column names
            new_cols = [customer_col]
            if sales_doc_col:
                new_cols.append('total_orders')
            if date_col:
                new_cols.extend(['first_order_date', 'last_order_date'])
            customer_aggs.columns = new_cols
            # Merge back onto the original data
            data = data.merge(customer_aggs, on=customer_col, how='left')
        # Standardize column names for downstream processing
        rename_dict = {}
        for standard_name, actual_name in self.column_mapping.items():
            if actual_name and actual_name in data.columns:
                rename_dict[actual_name] = standard_name
        if rename_dict:
            data = data.rename(columns=rename_dict)
        return data

    def train_model_if_needed(self):
        """Train the model with proper error handling"""
        if self.model_exists():
            return self.load_existing_metadata()
        progress_bar = st.progress(0)
        status_text = st.empty()
        try:
            # Step 1: Load SAP SALT data
            status_text.text("Loading SAP SALT dataset...")
            progress_bar.progress(20)
            data = self.load_sap_data()
            # Step 2: Feature engineering
            status_text.text("Engineering features...")
            progress_bar.progress(40)
            features_data = self.engineer_features(data)
            # Step 3: Train the model
            status_text.text("Training ML model...")
            progress_bar.progress(60)
            metrics = self.train_model(features_data)
            # Step 4: Save the model
            status_text.text("Saving model...")
            progress_bar.progress(80)
            self.save_model_artifacts(metrics)
            # Complete
            progress_bar.progress(100)
            status_text.text("Model training complete!")
            return metrics
        except Exception as e:
            st.error(f"Training failed: {str(e)}")
            raise

    def engineer_features(self, data):
        """Feature engineering with dynamic column handling"""
        try:
            # Identify the available columns for customer aggregation
            agg_cols = ['Customer']  # The customer ID is always required
            optional_cols = ['CustomerName', 'Country', 'CustomerGroup']
            for col in optional_cols:
                if col in data.columns and data[col].notna().any():
                    agg_cols.append(col)
            # Customer-level aggregation with only the available columns
            agg_dict = {}
            for col in agg_cols:
                if col != 'Customer':
                    agg_dict[col] = 'first'
            # Add order-related aggregations
            if 'total_orders' in data.columns:
                agg_dict['total_orders'] = 'first'
            if 'first_order_date' in data.columns:
                agg_dict['first_order_date'] = 'first'
            if 'last_order_date' in data.columns:
                agg_dict['last_order_date'] = 'first'
            customer_features = data.groupby('Customer').agg(agg_dict).reset_index()
            # Handle dates safely
            reference_date = pd.to_datetime('2024-12-31')
            if 'last_order_date' in customer_features.columns:
                customer_features['last_order_date'] = pd.to_datetime(customer_features['last_order_date'], errors='coerce')
                customer_features['Recency'] = (reference_date - customer_features['last_order_date']).dt.days
            else:
                customer_features['Recency'] = 100  # Default recency
            if 'first_order_date' in customer_features.columns:
                customer_features['first_order_date'] = pd.to_datetime(customer_features['first_order_date'], errors='coerce')
                customer_features['Tenure'] = (reference_date - customer_features['first_order_date']).dt.days
            else:
                customer_features['Tenure'] = 365  # Default tenure
            # RFM features with safe handling
            customer_features['Recency'] = customer_features['Recency'].fillna(365).clip(0, 3650)
            if 'total_orders' in customer_features.columns:
                customer_features['Frequency'] = customer_features['total_orders'].fillna(1).clip(1, 1000)
            else:
                customer_features['Frequency'] = 1  # Default frequency
            customer_features['Monetary'] = (customer_features['Frequency'] * 500).clip(100, 1000000)
            customer_features['Tenure'] = customer_features['Tenure'].fillna(365).clip(1, 3650)
            # Safe OrderVelocity calculation
            tenure_months = customer_features['Tenure'] / 30 + 1
            customer_features['OrderVelocity'] = (customer_features['Frequency'] / tenure_months).clip(0, 50)
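            # Worked example: Frequency = 12 orders over Tenure = 365 days gives
            # tenure_months = 365/30 + 1 ≈ 13.2, so OrderVelocity ≈ 12 / 13.2 ≈ 0.9
            # orders per month; the "+ 1" keeps the denominator away from zero for
            # brand-new customers.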
            # Categorical encoding only for the available columns
            self.label_encoders = {}
            categorical_features = []
            for col in ['Country', 'CustomerGroup']:
                if col in customer_features.columns and customer_features[col].notna().any():
                    try:
                        self.label_encoders[col] = LabelEncoder()
                        customer_features[f'{col}_encoded'] = self.label_encoders[col].fit_transform(
                            customer_features[col].fillna('Unknown')
                        )
                        categorical_features.append(f'{col}_encoded')
                    except Exception as e:
                        st.warning(f"Could not encode {col}: {str(e)}")
            # Target variable (churn definition)
            customer_features['IsChurned'] = (
                (customer_features['Recency'] > 90) &
                (customer_features['Frequency'] > 1)
            ).astype(int)
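            # Example: a repeat customer (Frequency > 1) whose last order was 200 days
            # before the reference date is labelled churned; single-order customers are
            # never labelled churned under this rule.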
            # Define feature columns
            self.feature_columns = ['Recency', 'Frequency', 'Monetary', 'Tenure', 'OrderVelocity']
            self.feature_columns.extend(categorical_features)
            # Prepare the final dataset
            required_cols = self.feature_columns + ['IsChurned', 'Customer']
            # Add CustomerName if available
            if 'CustomerName' in customer_features.columns:
                required_cols.append('CustomerName')
            # Filter to only the existing columns
            available_cols = [col for col in required_cols if col in customer_features.columns]
            final_data = customer_features[available_cols].copy()
            # Critical: clean all feature values (no infinities or NaNs)
            for col in self.feature_columns:
                if col in final_data.columns:
                    final_data[col] = final_data[col].replace([np.inf, -np.inf], np.nan).fillna(0)
                    final_data[col] = final_data[col].clip(-1e9, 1e9)
            st.info(f"Features engineered: {self.feature_columns}")
            st.info(f"Final dataset shape: {final_data.shape}")
            return final_data
        except Exception as e:
            st.error(f"Feature engineering failed: {str(e)}")
            st.info(f"Available columns: {list(data.columns)}")
            raise

    def train_model(self, data):
        """Train the model with additional validation"""
        try:
            # Ensure all feature columns exist
            missing_features = [col for col in self.feature_columns if col not in data.columns]
            if missing_features:
                st.warning(f"Missing features: {missing_features}")
                # Use only the available features
                self.feature_columns = [col for col in self.feature_columns if col in data.columns]
            if not self.feature_columns:
                raise ValueError("No valid features available for training")
            X = data[self.feature_columns].copy()
            y = data['IsChurned'].copy()
            # Final data cleaning
            if not np.isfinite(X).all().all():
                X = X.replace([np.inf, -np.inf], np.nan).fillna(0)
            # Check data quality
            if len(X) < 50:
                raise ValueError(f"Insufficient training data: {len(X)} samples")
            if y.nunique() < 2:
                st.warning("Creating artificial target variation for training...")
                # Flip a quarter of the labels so both classes are present
                variation_size = len(y) // 4
                y.iloc[:variation_size] = 1 - y.iloc[:variation_size]
            # Train-test split
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42,
                stratify=y if y.nunique() > 1 else None
            )
            # Train the model
            self.model = RandomForestClassifier(
                n_estimators=50,
                max_depth=8,
                min_samples_split=20,
                min_samples_leaf=10,
                class_weight='balanced',
                random_state=42,
                n_jobs=1
            )
            self.model.fit(X_train, y_train)
            # Evaluate
            train_score = self.model.score(X_train, y_train)
            test_score = self.model.score(X_test, y_test)
            # Cast numpy scalars to plain floats so the metrics dict stays JSON-serializable
            metrics = {
                'train_accuracy': float(train_score),
                'test_accuracy': float(test_score),
                'feature_columns': self.feature_columns,
                'training_samples': len(X_train),
                'test_samples': len(X_test),
                'churn_rate': float(y.mean()),
                'feature_importance': {col: float(imp) for col, imp in zip(self.feature_columns, self.model.feature_importances_)},
                'column_mapping': self.column_mapping
            }
            st.success(f"Model trained successfully! Accuracy: {test_score:.3f}")
            return metrics
        except Exception as e:
            st.error(f"Model training failed: {str(e)}")
            raise

    def save_model_artifacts(self, metrics):
        """Save the model and its metadata"""
        Path('models').mkdir(exist_ok=True)
        model_data = {
            'model': self.model,
            'label_encoders': self.label_encoders,
            'feature_columns': self.feature_columns,
            'column_mapping': self.column_mapping,
            'version': 'v1',
            'training_date': datetime.now().isoformat()
        }
        joblib.dump(model_data, self.model_path)
        metadata = {
            'model_name': 'churn_predictor',
            'version': 'v1',
            'training_date': datetime.now().isoformat(),
            'metrics': metrics,
            'status': 'trained',
            'data_source': 'SAP/SALT dataset from Hugging Face',
            'column_mapping': self.column_mapping
        }
        with open(self.metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
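
    # Inference-side loading sketch (illustrative only; the variable names below are
    # hypothetical). The artifact written above is a plain dict, so a serving app
    # could reload and use it like:
    #   artifacts = joblib.load('models/churn_model_v1.pkl')
    #   model = artifacts['model']
    #   churn_proba = model.predict_proba(X[artifacts['feature_columns']])[:, 1]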

    def load_existing_metadata(self):
        """Load existing model metadata"""
        try:
            with open(self.metadata_path, 'r') as f:
                return json.load(f)
        except Exception:
            return None
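

# Minimal usage sketch (illustrative, not part of the trainer class): in the Space this
# module is expected to be imported by the main Streamlit app, but it can also be
# smoke-tested directly with `streamlit run`. Note that a fresh run returns the training
# metrics dict, while an already-trained model returns the saved metadata (with the
# metrics nested under result['metrics']).
if __name__ == "__main__":
    trainer = EmbeddedChurnTrainer()
    result = trainer.train_model_if_needed()
    if result:
        st.json(result)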