Spaces:
Running
Running
""" | |
SuperKart Sales Prediction Frontend | |
A Streamlit web application for predicting product sales using the SuperKart ML model. | |
This frontend provides an intuitive interface for users to input product and store features | |
and get sales predictions from the backend API. | |
""" | |
import warnings | |
import streamlit as st | |
import requests | |
import pandas as pd | |
import argparse | |
import os | |
import sys | |
from typing import Dict | |
# Suppress SyntaxWarnings from Streamlit library
# NOTE(review): this filter is installed after the imports at the top of the
# file have already executed, so it can only affect warnings raised from this
# point on -- confirm that is sufficient for the warnings being targeted.
warnings.filterwarnings("ignore", category=SyntaxWarning)

# Page configuration: browser-tab title and icon, wide layout, and the sidebar
# (used for navigation and backend status in main()) expanded on first load.
st.set_page_config(
    page_title="SuperKart Sales Predictor",
    page_icon="๐",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS for better styling.
# Defines the CSS classes referenced by raw HTML snippets rendered elsewhere in
# this file:
#   .main-header     - the page title in main()
#   .prediction-box  - the headline prediction in display_prediction_result()
#   .success-box     - the "prediction completed" banner in main()
#   .error-box       - the "prediction failed" banner in main()
st.markdown(
    """
<style>
    .main-header {
        font-size: 3rem;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .prediction-box {
        background-color: #f0f8ff;
        padding: 20px;
        border-radius: 10px;
        border-left: 5px solid #1f77b4;
        margin: 20px 0;
    }
    .success-box {
        background-color: #d4edda;
        padding: 15px;
        border-radius: 5px;
        border-left: 5px solid #28a745;
        margin: 10px 0;
    }
    .error-box {
        background-color: #f8d7da;
        padding: 15px;
        border-radius: 5px;
        border-left: 5px solid #dc3545;
        margin: 10px 0;
    }
</style>
""",
    # Required so the <style> tag is injected as HTML instead of being escaped.
    unsafe_allow_html=True,
)
def get_backend_url():
    """Resolve the base URL of the backend API.

    Resolution order (first match wins):

    1. ``--backend-url <URL>`` on the command line, e.g.
       ``streamlit run app.py -- --backend-url http://your-backend:5050``
    2. the ``BACKEND_URL`` environment variable
    3. ``http://localhost:7860``

    Returns:
        The backend base URL as a string.
    """
    fallback_url = os.getenv("BACKEND_URL", "http://localhost:7860")

    # add_help=False: a stray -h/--help meant for another tool must not make
    # argparse print usage text and try to exit the app.
    parser = argparse.ArgumentParser(
        description="SuperKart Frontend App", add_help=False
    )
    parser.add_argument(
        "--backend-url",
        type=str,
        default=fallback_url,
        help="Backend API URL (default: http://localhost:7860)",
    )

    # Always parse the script arguments. Under ``streamlit run`` the script
    # sees its own path in sys.argv[0] followed by whatever was passed after
    # ``--``, so gating the parse on "streamlit" appearing in sys.argv[0]
    # would ignore the option for any script whose path lacks that word.
    # parse_known_args() tolerates unrelated arguments.
    try:
        known_args, _ = parser.parse_known_args(sys.argv[1:])
    except (SystemExit, argparse.ArgumentError):
        # Malformed option (e.g. ``--backend-url`` with no value): fall back
        # to the environment variable or default instead of killing the app.
        return fallback_url
    return known_args.backend_url


# Configuration
# Resolved once at import time; every request helper below builds on it.
BACKEND_URL = get_backend_url()
def make_api_request(endpoint: str, data: Dict = None, method: str = "GET") -> Dict:
    """Call the backend API and normalise the outcome into a result dict.

    Args:
        endpoint: Path appended verbatim to ``BACKEND_URL`` (e.g. ``"/predict"``).
        data: JSON-serialisable payload; sent only for ``POST`` requests.
        method: ``"GET"`` or ``"POST"``.

    Returns:
        ``{"success": True, "data": <decoded JSON body>}`` on success, or
        ``{"success": False, "error": <message>}`` on any failure. Request
        failures are never raised to the caller.
    """
    url = f"{BACKEND_URL}{endpoint}"

    try:
        if method == "GET":
            response = requests.get(url, timeout=30)
        elif method == "POST":
            response = requests.post(url, json=data, timeout=30)
        else:
            # Without this branch ``response`` would be unbound below and the
            # caller would get an UnboundLocalError instead of an error dict.
            return {
                "success": False,
                "error": f"API request failed: unsupported HTTP method '{method}'",
            }
        response.raise_for_status()
    except requests.exceptions.ConnectionError:
        return {
            "success": False,
            "error": "Cannot connect to backend API. Please ensure the backend service is running.",
        }
    except requests.exceptions.Timeout:
        return {
            "success": False,
            "error": "Request timeout. The backend service is taking too long to respond.",
        }
    except requests.exceptions.RequestException as e:
        return {"success": False, "error": f"API request failed: {str(e)}"}

    # Decode separately: a 2xx response with a non-JSON body raises a
    # ValueError subclass, which not every requests version also derives from
    # RequestException.
    try:
        payload = response.json()
    except ValueError:
        return {
            "success": False,
            "error": "API request failed: backend returned a response that is not valid JSON.",
        }
    return {"success": True, "data": payload}
def get_feature_info():
    """Fetch the feature metadata served by the backend's ``/features`` endpoint.

    Returns:
        The decoded response payload on success. On failure an error message
        is shown in the UI and ``None`` is returned.
    """
    response = make_api_request("/features")
    if not response["success"]:
        st.error(f"Failed to get feature information: {response['error']}")
        return None
    return response["data"]
def create_input_form():
    """Render the single-prediction form and collect the entered feature values.

    Returns:
        A dict keyed by the backend's feature names when the form was
        submitted on this run; ``None`` when it was not submitted or when the
        backend feature information could not be retrieved.
    """
    st.header("๐ฎ Product Sales Prediction")

    # The form is only offered when the backend's feature metadata is available.
    if not get_feature_info():
        return None

    # Choices offered by the select boxes, in display order.
    sugar_levels = ["Low Sugar", "Regular", "No Sugar"]
    product_categories = [
        "Dairy",
        "Soft Drinks",
        "Meat",
        "Fruits and Vegetables",
        "Household",
        "Baking Goods",
        "Snack Foods",
        "Frozen Foods",
        "Breakfast",
        "Health and Hygiene",
        "Hard Drinks",
        "Canned",
        "Bread",
        "Starchy Foods",
        "Others",
        "Seafood",
    ]
    establishment_years = [1987, 1998, 1999, 2009]
    size_levels = ["Small", "Medium", "High"]
    city_tiers = ["Tier 1", "Tier 2", "Tier 3"]
    store_formats = [
        "Supermarket Type1",
        "Supermarket Type2",
        "Supermarket Type3",
        "Departmental Store",
        "Food Mart",
    ]

    with st.form("prediction_form"):
        product_col, store_col = st.columns(2)

        with product_col:
            st.subheader("๐ฆ Product Features")
            weight = st.number_input(
                "Product Weight (kg)",
                min_value=0.1,
                max_value=100.0,
                value=12.66,
                step=0.1,
                help="Weight of the product in kilograms",
            )
            sugar_content = st.selectbox(
                "Sugar Content",
                options=sugar_levels,
                index=sugar_levels.index("Low Sugar"),
                help="Sugar content level of the product",
            )
            allocated_area = st.number_input(
                "Allocated Display Area (Ratio)",
                min_value=0.0,
                max_value=1.0,
                value=0.027,
                step=0.001,
                format="%.3f",
                help="Ratio of allocated display area (0.0 to 1.0)",
            )
            category = st.selectbox(
                "Product Type",
                options=product_categories,
                index=product_categories.index("Frozen Foods"),
                help="Category of the product",
            )
            mrp = st.number_input(
                "Maximum Retail Price ($)",
                min_value=1.0,
                max_value=1000.0,
                value=117.08,
                step=0.01,
                format="%.2f",
                help="Maximum retail price in USD",
            )

        with store_col:
            st.subheader("๐ช Store Features")
            established = st.selectbox(
                "Store Establishment Year",
                options=establishment_years,
                index=establishment_years.index(2009),
                help="Year when the store was established",
            )
            size = st.selectbox(
                "Store Size",
                options=size_levels,
                index=size_levels.index("Medium"),
                help="Size category of the store",
            )
            city_tier = st.selectbox(
                "City Type",
                options=city_tiers,
                index=city_tiers.index("Tier 2"),
                help="Type of city where the store is located",
            )
            store_format = st.selectbox(
                "Store Type",
                options=store_formats,
                index=store_formats.index("Supermarket Type2"),
                help="Type/format of the store",
            )

        # The submit button has to live inside the form context.
        was_submitted = st.form_submit_button("๐ฏ Predict Sales", type="primary")

    if not was_submitted:
        return None

    # Payload keys are the feature names sent to the /predict endpoint.
    return {
        "Product_Weight": weight,
        "Product_Sugar_Content": sugar_content,
        "Product_Allocated_Area": allocated_area,
        "Product_Type": category,
        "Product_MRP": mrp,
        "Store_Establishment_Year": established,
        "Store_Size": size,
        "Store_Location_City_Type": city_tier,
        "Store_Type": store_format,
    }
def _benchmark_delta_color(pct_diff):
    """Return the ``st.metric`` delta colour for a percentage gap to a benchmark."""
    # Inside the -10%..+10% band the prediction counts as "market average"
    # (the same band the Market Position panel uses), so the delta is shown
    # in neutral grey. Outside it the standard mapping applies: green when
    # above the benchmark, red when below. "inverse" would flip that and
    # paint strong predictions red.
    return "off" if -10 < pct_diff <= 10 else "normal"


def display_prediction_result(prediction_data: Dict):
    """Display the prediction result with EDA-based insights.

    Renders, in order: the headline predicted revenue, four benchmark metrics
    (vs. mean, vs. median, quartile band, z-score band), a colour-coded
    performance summary, and three panels with financial impact, market
    position and suggested actions.

    Args:
        prediction_data: Backend response for a single prediction. Must
            contain a numeric ``"predicted_sales"`` entry.

    Raises:
        KeyError: If ``"predicted_sales"`` is missing from ``prediction_data``.
    """
    predicted_sales = prediction_data["predicted_sales"]

    # Main prediction display.
    # The wrapper <div> and its content are emitted in ONE markdown call:
    # every st.markdown call renders as its own element, so an opening tag in
    # one call and the closing tag in another never wrap what lies between.
    st.markdown(
        f"""
<div class="prediction-box" style="text-align: center;">
    <h2>๐ฐ Predicted Sales Revenue</h2>
    <h1 style="color: #28a745; font-size: 4rem;">${predicted_sales:,.2f}</h1>
</div>
""",
        unsafe_allow_html=True,
    )

    # EDA-based insights and business metrics
    st.subheader("๐ Sales Analysis & Business Insights")

    # Based on EDA: Sales range $33-$8,000, Mean: $3,464, Median: $3,452, Std: $1,066
    sales_mean = 3464
    sales_median = 3452
    sales_std = 1066
    sales_q1 = 2762
    sales_q3 = 4145

    # Percentage gap between the prediction and each benchmark.
    vs_mean = ((predicted_sales - sales_mean) / sales_mean) * 100
    vs_median = ((predicted_sales - sales_median) / sales_median) * 100

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        # Performance vs Mean
        st.metric(
            label="๐ vs Dataset Mean",
            value=f"${predicted_sales:,.2f}",
            delta=f"{vs_mean:+.1f}%",
            delta_color=_benchmark_delta_color(vs_mean),
        )

    with col2:
        # Performance vs Median
        # NOTE(review): this card shows the median itself as its value while
        # the mean card shows the prediction -- confirm which is intended.
        st.metric(
            label="๐ vs Dataset Median",
            value=f"${sales_median:,.2f}",
            delta=f"{vs_median:+.1f}%",
            delta_color=_benchmark_delta_color(vs_median),
        )

    with col3:
        # Percentile ranking based on EDA quartiles
        if predicted_sales <= sales_q1:
            percentile = "Bottom 25%"
            percentile_color = "๐ด"
        elif predicted_sales <= sales_median:
            percentile = "25th-50th"
            percentile_color = "๐ก"
        elif predicted_sales <= sales_q3:
            percentile = "50th-75th"
            percentile_color = "๐ "
        else:
            percentile = "Top 25%"
            percentile_color = "๐ข"
        st.metric(
            label="๐ฏ Performance Percentile",
            value=f"{percentile_color} {percentile}",
            delta=None,
        )

    with col4:
        # Standard deviation analysis: distance from the mean in std units.
        z_score = (predicted_sales - sales_mean) / sales_std
        if abs(z_score) <= 1:
            volatility = "Normal"
            vol_color = "๐ข"
        elif abs(z_score) <= 2:
            volatility = "Moderate"
            vol_color = "๐ก"
        else:
            volatility = "High"
            vol_color = "๐ด"
        st.metric(
            label="๐ Sales Volatility",
            value=f"{vol_color} {volatility}",
            delta=f"ฯ: {z_score:+.1f}",
        )

    # Business insights section
    st.subheader("๐ผ Business Recommendations & Next Steps")

    # Performance summary: level label, accent colour and message per quartile.
    if predicted_sales >= sales_q3:  # Top 25%
        performance_level = "โญ Excellent"
        performance_color = "#28a745"
        summary_message = (
            "This product is predicted to perform in the top 25% of SuperKart sales!"
        )
    elif predicted_sales >= sales_median:  # Above median
        performance_level = "โ Good"
        performance_color = "#17a2b8"
        summary_message = (
            "This product is predicted to perform above the historical average."
        )
    elif predicted_sales >= sales_q1:  # Above bottom quartile
        performance_level = "โ ๏ธ Below Average"
        performance_color = "#ffc107"
        summary_message = (
            "This product may underperform compared to typical SuperKart sales."
        )
    else:  # Bottom 25%
        performance_level = "๐ด Needs Attention"
        performance_color = "#dc3545"
        summary_message = (
            "This product is predicted to be in the bottom 25% of sales performance."
        )

    # Performance summary box ("20" appended to the hex colour is its alpha
    # channel, giving a translucent tint of the accent colour).
    st.markdown(
        f"""
<div style="background-color: {performance_color}20; padding: 20px; border-radius: 10px;
border-left: 5px solid {performance_color}; margin: 15px 0;">
<h4 style="color: {performance_color}; margin: 0 0 10px 0;">
{performance_level} Performance Expected
</h4>
<p style="margin: 0; font-size: 16px;">{summary_message}</p>
</div>
""",
        unsafe_allow_html=True,
    )

    # Three-column layout for insights
    col1, col2, col3 = st.columns(3)

    with col1:
        st.markdown("#### ๐ฐ Financial Impact")
        # Revenue tier classification
        if predicted_sales >= 5000:
            tier = "๐ Premium Tier"
        elif predicted_sales >= 3000:
            tier = "๐ฅ Standard Tier"
        else:
            tier = "๐ฅ Value Tier"
        st.info(f"**Revenue Classification:** {tier}")

        # Financial metrics with clear labels
        profit_margin = 0.2  # 20% profit margin
        estimated_profit = predicted_sales * profit_margin
        st.metric("Predicted Revenue", f"${predicted_sales:,.0f}")
        st.metric("Estimated Profit (20%)", f"${estimated_profit:,.0f}")

    with col2:
        st.markdown("#### ๐ Market Position")
        # Clear market positioning (reuses the vs-mean gap computed above)
        if vs_mean > 10:
            position = "๐ Above Market Average"
        elif vs_mean > -10:
            position = "๐ Market Average"
        else:
            position = "๐ Below Market Average"
        st.success(position)
        st.write(f"**vs Historical Mean:** {vs_mean:+.1f}%")
        # Dollar signs are escaped so the pair is not parsed as inline LaTeX.
        st.write("**Market Range:** \\$33 - \\$8,000")
        st.write(f"**Your Prediction:** ${predicted_sales:,.0f}")

    with col3:
        st.markdown("#### ๐ฏ Action Items")
        # Clear, actionable recommendations
        if predicted_sales < sales_q1:
            st.warning("**Low Performance Risk**")
            st.write("**Immediate Actions:**")
            st.write("โข Launch promotional campaign")
            st.write("โข Review pricing strategy")
            st.write("โข Optimize product placement")
            st.write("โข Analyze competitor offerings")
        elif predicted_sales > sales_q3:
            st.success("**High Performance Opportunity**")
            st.write("**Recommended Actions:**")
            st.write("โข Ensure adequate stock levels")
            st.write("โข Consider premium pricing")
            st.write("โข Expand to similar products")
            st.write("โข Allocate prime shelf space")
        else:
            st.info("**Standard Performance Expected**")
            st.write("**Monitor & Optimize:**")
            st.write("โข Track actual vs predicted")
            st.write("โข A/B test marketing approaches")
            st.write("โข Monitor competitor activity")
            st.write("โข Adjust inventory as needed")
def create_input_summary(input_data: Dict):
    """Show the submitted feature values in a two-column recap.

    Args:
        input_data: Feature dict as returned by ``create_input_form``; every
            key listed below is read from it.
    """
    st.subheader("๐ Input Summary")

    # (line template, feature key) pairs in display order. Each value is
    # looked up only when its line is written.
    product_lines = [
        ("โข Weight: {} kg", "Product_Weight"),
        ("โข Sugar Content: {}", "Product_Sugar_Content"),
        ("โข Display Area: {:.3f}", "Product_Allocated_Area"),
        ("โข Type: {}", "Product_Type"),
        ("โข MRP: ${:.2f}", "Product_MRP"),
    ]
    store_lines = [
        ("โข Establishment Year: {}", "Store_Establishment_Year"),
        ("โข Size: {}", "Store_Size"),
        ("โข City Type: {}", "Store_Location_City_Type"),
        ("โข Store Type: {}", "Store_Type"),
    ]

    # Create two columns for better layout
    product_col, store_col = st.columns(2)
    sections = (
        (product_col, "**Product Information:**", product_lines),
        (store_col, "**Store Information:**", store_lines),
    )
    for column, heading, lines in sections:
        with column:
            st.markdown(heading)
            for template, key in lines:
                st.write(template.format(input_data[key]))
# Cut-offs used to bucket batch predictions into performance bands.
_BATCH_SALES_Q1 = 2762  # below this: bottom 25%
_BATCH_SALES_MEDIAN = 3452
_BATCH_SALES_Q3 = 4145  # at or above this: top 25%


def _batch_records(df):
    """Convert the uploaded frame into JSON-safe row dicts (missing cells -> None)."""
    # Blank CSV cells are read as NaN, which is not valid JSON; sending None
    # (JSON null) lets the backend handle the row instead of the whole
    # request failing during serialisation.
    return [
        {column: (None if pd.isna(value) else value) for column, value in row.items()}
        for row in df.to_dict("records")
    ]


def _batch_performance_category(sales):
    """Label a predicted sales value with its quartile-based performance band."""
    if sales >= _BATCH_SALES_Q3:  # Top 25% (Q3)
        return "๐ข High"
    if sales >= _BATCH_SALES_MEDIAN:  # Above median
        return "๐ก Good"
    if sales >= _BATCH_SALES_Q1:  # Above Q1
        return "๐ Average"
    return "๐ด Low"


def _render_batch_results(batch_results, total_rows):
    """Render counters, the results table, summary metrics, downloads and errors."""
    st.subheader("๐ Batch Prediction Results")
    ok_col, failed_col, total_col = st.columns(3)
    with ok_col:
        st.metric("โ Successful", batch_results["successful_predictions"])
    with failed_col:
        st.metric("โ Failed", batch_results["failed_predictions"])
    with total_col:
        st.metric("๐ Total", total_rows)

    # Show successful predictions
    predictions = batch_results.get("results") or []
    if predictions:
        st.subheader("๐ฏ Successful Predictions")

        # User-friendly table: a few readable inputs plus the prediction.
        display_rows = []
        for item in predictions:
            features = item["input_features"]
            sales = item["predicted_sales"]
            display_rows.append(
                {
                    # Backend indices are used 0-based here; show 1-based rows.
                    "Row": item["index"] + 1,
                    "Product Type": features["Product_Type"],
                    "Weight (kg)": features["Product_Weight"],
                    "MRP ($)": f"${features['Product_MRP']:.2f}",
                    "Store Size": features["Store_Size"],
                    "Store Type": features["Store_Type"],
                    "Predicted Sales": f"${sales:,.2f}",
                    "Performance": _batch_performance_category(sales),
                }
            )
        display_df = pd.DataFrame(display_rows)
        st.dataframe(display_df, use_container_width=True, hide_index=True)

        # Summary statistics
        sales_values = [item["predicted_sales"] for item in predictions]
        total_sales = sum(sales_values)
        high_performers = sum(1 for s in sales_values if s >= _BATCH_SALES_Q3)
        low_performers = sum(1 for s in sales_values if s < _BATCH_SALES_Q1)

        revenue_col, average_col, high_col, low_col = st.columns(4)
        with revenue_col:
            st.metric("๐ฐ Total Revenue", f"${total_sales:,.0f}")
        with average_col:
            # sales_values is non-empty here (guarded by ``if predictions``).
            st.metric("๐ Average Sale", f"${total_sales / len(sales_values):,.0f}")
        with high_col:
            st.metric("๐ข High Performers", f"{high_performers}")
        with low_col:
            st.metric("๐ด Needs Attention", f"{low_performers}")

        # Download options
        summary_col, detailed_col = st.columns(2)
        with summary_col:
            # Download user-friendly results
            st.download_button(
                label="๐ฅ Download Summary Results",
                data=display_df.to_csv(index=False),
                file_name="batch_predictions_summary.csv",
                mime="text/csv",
            )
        with detailed_col:
            # Download detailed results for technical users: every input
            # feature alongside the raw prediction.
            detailed_df = pd.DataFrame(
                [
                    {
                        "row_index": item["index"],
                        "predicted_sales": item["predicted_sales"],
                        **item["input_features"],
                    }
                    for item in predictions
                ]
            )
            st.download_button(
                label="๐ง Download Detailed Results",
                data=detailed_df.to_csv(index=False),
                file_name="batch_predictions_detailed.csv",
                mime="text/csv",
            )

    # Show errors if any
    errors = batch_results.get("errors")
    if errors:
        st.subheader("โ ๏ธ Prediction Errors")
        st.dataframe(pd.DataFrame(errors))


def create_batch_prediction():
    """Create batch prediction interface.

    Lets the user upload a CSV, previews its first rows and, on request,
    posts all rows to ``/predict/batch`` and renders the outcome. Any failure
    while reading the file or rendering the response is reported in the UI
    rather than raised.
    """
    st.header("๐ Batch Prediction")
    st.markdown("""
    Upload a CSV file with multiple products to get batch predictions.
    The CSV should contain all required columns with the same names as in the single prediction form.
    """)

    # File uploader
    uploaded_file = st.file_uploader(
        "Choose a CSV file",
        type="csv",
        help="Upload a CSV file with product and store features",
    )
    if uploaded_file is None:
        return

    # UI boundary: report any parsing/rendering problem instead of crashing.
    try:
        # Read the CSV file
        df = pd.read_csv(uploaded_file)

        # Display the uploaded data
        st.subheader("๐ Uploaded Data")
        st.dataframe(df.head(10))

        if df.empty:
            # Header-only upload: nothing to predict, so skip the backend call.
            st.warning("The uploaded CSV contains no data rows.")
            return

        if st.button("๐ Run Batch Prediction", type="primary"):
            predictions_data = _batch_records(df)

            # Make batch prediction request
            response = make_api_request(
                "/predict/batch", {"predictions": predictions_data}, "POST"
            )
            if response["success"]:
                _render_batch_results(response["data"], len(predictions_data))
            else:
                st.error(f"Batch prediction failed: {response['error']}")
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
def main():
    """Main application function.

    Renders the page header, checks backend health (halting the page with
    setup instructions when the backend is unreachable), draws the sidebar
    with backend status and mode selector, and dispatches to the selected
    mode: single prediction, batch prediction or API documentation.
    """
    # Local import: only needed to escape backend error text shown as HTML.
    import html

    # Title and description
    st.markdown(
        '<h1 class="main-header">๐ SuperKart Sales Predictor</h1>',
        unsafe_allow_html=True,
    )
    st.markdown(
        """
<div style="text-align: center; margin-bottom: 2rem;">
<p style="font-size: 1.2rem; color: #666;">
Predict product sales revenue using machine learning based on product and store characteristics
</p>
</div>
""",
        unsafe_allow_html=True,
    )

    # Check backend health
    health_result = make_api_request("/")
    if not health_result["success"]:
        st.error(
            f"โ ๏ธ Backend API is not available at `{BACKEND_URL}`. Please ensure the backend service is running."
        )
        st.info(
            """
**How to specify a different backend URL:**
1. **Command line argument:**
```
streamlit run app.py -- --backend-url http://your-backend:5050
```
2. **Environment variable:**
```
export BACKEND_URL=http://your-backend:5050
streamlit run app.py
```
"""
        )
        # Halt this script run: nothing below works without the backend.
        st.stop()

    # Sidebar navigation
    st.sidebar.title("๐งญ Navigation")

    # Display current backend URL and connection status
    st.sidebar.markdown("---")
    st.sidebar.markdown("**๐ Backend Configuration**")
    st.sidebar.code(BACKEND_URL, language=None)

    # Show connection status
    if health_result["success"]:
        st.sidebar.success("๐ข Connected")
        # Only inspect the payload when it is a JSON object; a list or string
        # body would otherwise break the key lookup below.
        health_data = health_result.get("data")
        if isinstance(health_data, dict) and "model_loaded" in health_data:
            model_status = (
                "๐ค Model Loaded"
                if health_data["model_loaded"]
                else "โ ๏ธ Model Not Loaded"
            )
            st.sidebar.info(model_status)
    else:
        st.sidebar.error("๐ด Disconnected")

    st.sidebar.markdown("---")
    app_mode = st.sidebar.selectbox(
        "Choose App Mode",
        ["Single Prediction", "Batch Prediction", "API Documentation"],
    )

    if app_mode == "Single Prediction":
        # Single prediction interface
        input_data = create_input_form()
        if input_data:
            # Make prediction
            result = make_api_request("/predict", input_data, "POST")
            if result["success"]:
                prediction_data = result["data"]
                # Display results
                display_prediction_result(prediction_data)
                # Show input summary
                with st.expander("๐ View Input Details", expanded=False):
                    create_input_summary(input_data)
                # Success message
                st.markdown(
                    '<div class="success-box">โ Prediction completed successfully!</div>',
                    unsafe_allow_html=True,
                )
            else:
                # The error text comes from the request layer / backend and is
                # rendered as raw HTML, so escape it to keep it inert.
                error_text = html.escape(str(result["error"]))
                st.markdown(
                    f'<div class="error-box">โ Prediction failed: {error_text}</div>',
                    unsafe_allow_html=True,
                )
    elif app_mode == "Batch Prediction":
        create_batch_prediction()
    elif app_mode == "API Documentation":
        st.header("๐ API Documentation")
        # Get feature information
        feature_info = get_feature_info()
        if feature_info:
            st.subheader("๐ง Required Features")
            features_df = pd.DataFrame(
                [
                    {"Feature": k, "Description": v}
                    for k, v in feature_info["feature_descriptions"].items()
                ]
            )
            st.table(features_df)
            st.subheader("๐ Example Input")
            st.json(feature_info["example_input"])
            st.subheader("๐ API Endpoints")
            st.markdown("""
            - **GET /**: Health check
            - **POST /predict**: Single prediction
            - **POST /predict/batch**: Batch prediction
            - **GET /features**: Get feature information
            """)

    # Footer
    st.markdown("---")
    st.markdown(
        "<div style='text-align: center; color: #666;'>"
        "SuperKart Sales Prediction System | Krishnaswamy Subramanian"
        "</div>",
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()