""" | |
SuperKart Sales Prediction Flask API | |
This Flask application provides a REST API for predicting product sales using a pre-trained | |
Random Forest model. The API accepts product and store features and returns predicted sales revenue. | |
""" | |
import os
import json  # used to serialize Pydantic validation errors in API responses
import joblib
import pandas as pd
from flask import Flask, request, jsonify
from flask_cors import CORS
import logging
from typing import Any, Dict
from pydantic import BaseModel, ValidationError, field_validator
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for frontend integration

# Global variables for model and preprocessing pipeline
model = None
feature_columns = None

# Define user input features (what the user provides)
USER_INPUT_FEATURES = [
    "Product_Weight",
    "Product_Sugar_Content",
    "Product_Allocated_Area",
    "Product_Type",
    "Product_MRP",
    "Store_Establishment_Year",
    "Store_Size",
    "Store_Location_City_Type",
    "Store_Type",
]

# Define model features (what the model expects after preprocessing)
MODEL_FEATURES = [
    "Product_Weight",
    "Product_Sugar_Content",
    "Product_Allocated_Area",
    "Product_Type",
    "Product_MRP",
    "Store_Size",
    "Store_Location_City_Type",
    "Store_Type",
    "Store_Age",
]


# Pydantic model for input validation
class PredictionInput(BaseModel):
    Product_Weight: float
    Product_Sugar_Content: str
    Product_Allocated_Area: float
    Product_Type: str
    Product_MRP: float
    Store_Establishment_Year: int
    Store_Size: str
    Store_Location_City_Type: str
    Store_Type: str

    @field_validator("Product_Weight")
    @classmethod
    def validate_product_weight(cls, v: float) -> float:
        if v <= 0:
            raise ValueError("Product_Weight must be greater than 0")
        if v < 4.0 or v > 22.0:
            raise ValueError("Product_Weight must be between 4.0 and 22.0")
        return v

    @field_validator("Product_Allocated_Area")
    @classmethod
    def validate_allocated_area(cls, v: float) -> float:
        if v < 0 or v > 1:
            raise ValueError("Product_Allocated_Area must be between 0 and 1")
        return v

    @field_validator("Product_MRP")
    @classmethod
    def validate_mrp(cls, v: float) -> float:
        if v <= 0:
            raise ValueError("Product_MRP must be greater than 0")
        if v < 31.0 or v > 266.0:
            raise ValueError("Product_MRP must be between 31.0 and 266.0")
        return v

    @field_validator("Store_Establishment_Year")
    @classmethod
    def validate_establishment_year(cls, v: int) -> int:
        valid_years = [1987, 1998, 1999, 2009]
        if v not in valid_years:
            raise ValueError(f"Store_Establishment_Year must be one of: {valid_years}")
        return v

    @field_validator("Product_Sugar_Content")
    @classmethod
    def validate_sugar_content(cls, v: str) -> str:
        valid = ["Low Sugar", "Regular", "No Sugar"]
        if v not in valid:
            raise ValueError(f"Product_Sugar_Content must be one of: {valid}")
        return v

    @field_validator("Product_Type")
    @classmethod
    def validate_product_type(cls, v: str) -> str:
        valid = [
            "Dairy",
            "Soft Drinks",
            "Meat",
            "Fruits and Vegetables",
            "Household",
            "Baking Goods",
            "Snack Foods",
            "Frozen Foods",
            "Breakfast",
            "Health and Hygiene",
            "Hard Drinks",
            "Canned",
            "Bread",
            "Starchy Foods",
            "Others",
            "Seafood",
        ]
        if v not in valid:
            raise ValueError(f"Product_Type must be one of: {valid}")
        return v

    @field_validator("Store_Size")
    @classmethod
    def validate_store_size(cls, v: str) -> str:
        valid = ["Small", "Medium", "High"]
        if v not in valid:
            raise ValueError(f"Store_Size must be one of: {valid}")
        return v

    @field_validator("Store_Location_City_Type")
    @classmethod
    def validate_city_type(cls, v: str) -> str:
        valid = ["Tier 1", "Tier 2", "Tier 3"]
        if v not in valid:
            raise ValueError(f"Store_Location_City_Type must be one of: {valid}")
        return v

    @field_validator("Store_Type")
    @classmethod
    def validate_store_type(cls, v: str) -> str:
        valid = [
            "Supermarket Type1",
            "Supermarket Type2",
            "Supermarket Type3",
            "Departmental Store",
            "Food Mart",
        ]
        if v not in valid:
            raise ValueError(f"Store_Type must be one of: {valid}")
        return v
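
# Note: with the field validators above enabled, constructing PredictionInput
# directly raises pydantic.ValidationError for out-of-range or unknown values,
# e.g. a Product_Weight of 2.0 or a Store_Size of "Huge" is rejected before any
# prediction is attempted.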


def load_model(model_path: str):
    """
    Load the trained model from the specified path.

    Args:
        model_path (str): Path to the model file.

    Returns:
        bool: True if model loaded successfully, False otherwise.
    """
    global model, feature_columns
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at: {model_path}")

        # Load the trained model (which includes preprocessing pipeline)
        model = joblib.load(model_path)
        logger.info(f"Model loaded successfully from: {model_path}")

        # Set feature columns
        feature_columns = MODEL_FEATURES
        logger.info(f"Model features: {MODEL_FEATURES}")
        logger.info(f"User input features: {USER_INPUT_FEATURES}")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return False


def convert_establishment_year_to_age(data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert Store_Establishment_Year to Store_Age."""
    # Create a copy to avoid modifying the original
    converted_data = data.copy()

    # Get current year
    current_year = datetime.now().year

    # Convert establishment year to age
    if "Store_Establishment_Year" in converted_data:
        establishment_year = converted_data.pop("Store_Establishment_Year")
        converted_data["Store_Age"] = current_year - establishment_year

    return converted_data


def preprocess_input(data: Dict[str, Any]) -> pd.DataFrame:
    """Convert input data to the DataFrame format expected by the model."""
    # First convert establishment year to age
    converted_data = convert_establishment_year_to_age(data)

    # Create a single-row DataFrame ordered by the model's feature columns
    df = pd.DataFrame([converted_data])
    df = df[MODEL_FEATURES]
    return df
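
# For example, an input containing "Store_Establishment_Year": 2009 becomes
# "Store_Age": <current year> - 2009 before the single-row DataFrame is built,
# so the exact value depends on the date the server is running.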


@app.route("/", methods=["GET"])  # route path assumed; adjust to your deployment
def health_check():
    """Health check endpoint."""
    return jsonify(
        {
            "status": "healthy",
            "message": "SuperKart Sales Prediction API is running",
            "model_loaded": model is not None,
        }
    )


@app.route("/predict", methods=["POST"])  # route path assumed
def predict():
    """Predict sales for the given product and store features."""
    if model is None:
        return jsonify({"error": "Model not loaded. Please check server logs."}), 500

    try:
        # Get JSON data from the request
        data = request.get_json()
        if not data:
            return jsonify(
                {
                    "error": "No data provided. Please send JSON data in the request body."
                }
            ), 400

        # Validate the input using Pydantic
        try:
            validated = PredictionInput(**data)
        except ValidationError as ve:
            # ve.json() produces a JSON-serializable view of the validation errors
            return jsonify(
                {"error": "Input validation failed", "details": json.loads(ve.json())}
            ), 400

        # Preprocess the input data
        input_df = preprocess_input(validated.model_dump())

        # Make the prediction
        prediction = model.predict(input_df)
        predicted_sales = float(prediction[0])

        # Prepare the response
        response = {
            "predicted_sales": round(predicted_sales, 2),
            "currency": "USD",
            "input_features": validated.model_dump(),
            "status": "success",
        }
        logger.info(f"Prediction successful: ${predicted_sales:.2f}")
        return jsonify(response)

    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        return jsonify({"error": f"Prediction failed: {str(e)}"}), 500


@app.route("/features", methods=["GET"])  # route path assumed
def get_features():
    """Get information about the expected input features."""
    feature_info = {
        "required_features": USER_INPUT_FEATURES,
        "feature_descriptions": {
            "Product_Weight": "Weight of the product (4.0-22.0 kg)",
            "Product_Sugar_Content": "Sugar content (Low Sugar, Regular, No Sugar)",
            "Product_Allocated_Area": "Allocated display area ratio (0.0-1.0)",
            "Product_Type": "Product category (16 types: Dairy, Soft Drinks, Meat, etc.)",
            "Product_MRP": "Maximum retail price (31.0-266.0 USD)",
            "Store_Establishment_Year": "Year the store was established (1987, 1998, 1999, 2009)",
            "Store_Size": "Store size (Small, Medium, High)",
            "Store_Location_City_Type": "City type (Tier 1, Tier 2, Tier 3)",
            "Store_Type": "Store type (Supermarket Type1/2/3, Departmental Store, Food Mart)",
        },
        "example_input": {
            "Product_Weight": 12.66,
            "Product_Sugar_Content": "Low Sugar",
            "Product_Allocated_Area": 0.027,
            "Product_Type": "Frozen Foods",
            "Product_MRP": 117.08,
            "Store_Establishment_Year": 2009,
            "Store_Size": "Medium",
            "Store_Location_City_Type": "Tier 2",
            "Store_Type": "Supermarket Type2",
        },
    }
    return jsonify(feature_info)


@app.route("/predict_batch", methods=["POST"])  # route path assumed
def predict_batch():
    """Predict sales for multiple products at once."""
    if model is None:
        return jsonify({"error": "Model not loaded. Please check server logs."}), 500

    try:
        # Get JSON data from the request
        data = request.get_json()
        if not data or "predictions" not in data:
            return jsonify(
                {
                    "error": 'No data provided. Please send JSON with a "predictions" array.'
                }
            ), 400

        predictions_data = data["predictions"]
        if not isinstance(predictions_data, list):
            return jsonify({"error": "Predictions must be an array of objects."}), 400

        results = []
        errors = []
        for i, item in enumerate(predictions_data):
            try:
                # Validate the input using Pydantic
                try:
                    validated = PredictionInput(**item)
                except ValidationError as ve:
                    # ve.json() keeps the error payload JSON-serializable
                    errors.append(
                        {"index": i, "error": json.loads(ve.json()), "input": item}
                    )
                    continue

                # Preprocess and predict
                input_df = preprocess_input(validated.model_dump())
                prediction = model.predict(input_df)
                predicted_sales = float(prediction[0])

                results.append(
                    {
                        "index": i,
                        "predicted_sales": round(predicted_sales, 2),
                        "input_features": validated.model_dump(),
                    }
                )
            except Exception as e:
                errors.append({"index": i, "error": str(e), "input": item})

        response = {
            "successful_predictions": len(results),
            "failed_predictions": len(errors),
            "results": results,
            "errors": errors,
            "status": "completed",
        }
        logger.info(
            f"Batch prediction completed: {len(results)} successful, {len(errors)} failed"
        )
        return jsonify(response)

    except Exception as e:
        logger.error(f"Batch prediction error: {str(e)}")
        return jsonify({"error": f"Batch prediction failed: {str(e)}"}), 500


# Load model on module import (for Gunicorn compatibility)
if not load_model("./superkart_model.joblib"):
    logger.error("Failed to load model. Application may not work properly.")


if __name__ == "__main__":
    # This runs only when the script is executed directly (not imported by Gunicorn)
    logger.info("Starting SuperKart Sales Prediction API...")
    app.run(host="0.0.0.0", port=7860, debug=True)
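
# A typical Gunicorn launch for this module might look like the following
# (the "app:app" target assumes this file is named app.py; adjust to match):
#
#   gunicorn --bind 0.0.0.0:7860 --workers 2 app:app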