statsforecast / app.py
fmegahed's picture
Update app.py
01074ec verified
import pandas as pd
import matplotlib.pyplot as plt
import gradio as gr
import tempfile
import os
from datetime import datetime
from statsforecast import StatsForecast
from statsforecast.models import (
HistoricAverage,
Naive,
SeasonalNaive,
WindowAverage,
SeasonalWindowAverage,
AutoETS,
AutoARIMA
)
from utilsforecast.evaluation import evaluate
from utilsforecast.losses import *
# Import for TimeGPT
from nixtla import NixtlaClient
# Function to load and process uploaded CSV
def load_data(file):
if file is None:
return None, "Please upload a CSV file"
try:
df = pd.read_csv(file)
required_cols = ['unique_id', 'ds', 'y']
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
return None, f"Missing required columns: {', '.join(missing_cols)}"
df['ds'] = pd.to_datetime(df['ds'])
df = df[required_cols]
df = df.sort_values(['unique_id', 'ds']).reset_index(drop=True)
# Check for NaN values
if df['y'].isna().any():
return None, "Data contains missing values in the 'y' column"
return df, "Data loaded successfully!"
except Exception as e:
return None, f"Error loading data: {str(e)}"
# Helper function to calculate date offset based on frequency and horizon
def calculate_date_offset(freq, horizon):
"""Calculate a timedelta based on frequency code and horizon"""
if freq == 'H':
return pd.Timedelta(hours=horizon)
elif freq == 'D':
return pd.Timedelta(days=horizon)
elif freq == 'B':
# For business days, use approximately 1.4x multiplier to account for weekends
return pd.Timedelta(days=int(horizon * 1.4))
elif freq == 'WS':
return pd.Timedelta(weeks=horizon)
elif freq == 'MS':
return pd.Timedelta(days=horizon * 30) # Approximate
elif freq == 'QS':
return pd.Timedelta(days=horizon * 90) # Approximate
elif freq == 'YS':
return pd.Timedelta(days=horizon * 365) # Approximate
else:
# Default fallback
return pd.Timedelta(days=horizon)
# Function to generate and return a plot for validation results
def create_forecast_plot(forecast_df, original_df, title="Forecasting Results", horizon=None, freq='D'):
plt.figure(figsize=(12, 7))
unique_ids = forecast_df['unique_id'].unique()
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds', 'cutoff', 'y']]
colors = plt.cm.tab10.colors
# Track min and max dates for x-axis limits
min_cutoff = None
for i, unique_id in enumerate(unique_ids):
original_data = original_df[original_df['unique_id'] == unique_id]
plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Actual)')
forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
# Find the earliest cutoff date if available
if 'cutoff' in forecast_data.columns:
cutoffs = pd.to_datetime(forecast_data['cutoff'].unique())
if len(cutoffs) > 0:
earliest_cutoff = cutoffs.min()
if min_cutoff is None or earliest_cutoff < min_cutoff:
min_cutoff = earliest_cutoff
# Add vertical line at each cutoff
for cutoff in cutoffs:
plt.axvline(x=cutoff, color='gray', linestyle='--', alpha=0.4)
# Plot main prediction lines
for j, col in enumerate(forecast_cols):
if col in forecast_data.columns:
# Clean up model name for legend
model_name = col.replace('_', ' ').title()
if model_name == 'Timegpt':
model_name = 'TimeGPT'
plt.plot(forecast_data['ds'], forecast_data[col],
color=colors[j % len(colors)],
linestyle='--',
linewidth=1.5,
label=f'{model_name}')
plt.title(title, fontsize=16)
plt.xlabel('Date', fontsize=12)
plt.ylabel('Value', fontsize=12)
plt.grid(True, alpha=0.3)
# Better legend with smaller font and outside placement
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, fontsize=10)
plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Adjust layout to make room for legend
# Set x-axis limits based on cutoff and horizon
if min_cutoff is not None and horizon is not None:
# Calculate date offset based on frequency and horizon
date_offset = calculate_date_offset(freq, horizon)
# Calculate start date as 'horizon' units before the first cutoff
start_date = min_cutoff - date_offset
# Find max date from forecast
max_date = forecast_df['ds'].max()
plt.xlim(start_date, max_date)
# Add an annotation for the training/test split
plt.annotate('Training | Test',
xy=(min_cutoff, plt.ylim()[0]),
xytext=(0, -40),
textcoords='offset points',
horizontalalignment='center',
fontsize=10)
# Format date labels better
fig = plt.gcf()
ax = plt.gca()
fig.autofmt_xdate()
return fig
# Function to create a plot for future forecasts
def create_future_forecast_plot(forecast_df, original_df, horizon=None, freq='D'):
plt.figure(figsize=(12, 7))
unique_ids = forecast_df['unique_id'].unique()
forecast_cols = [col for col in forecast_df.columns if col not in ['unique_id', 'ds']]
colors = plt.cm.tab10.colors
# Track the forecast start date (min of forecast data)
forecast_start = None
if not forecast_df.empty:
forecast_start = pd.to_datetime(forecast_df['ds'].min())
for i, unique_id in enumerate(unique_ids):
# Plot historical data
original_data = original_df[original_df['unique_id'] == unique_id]
plt.plot(original_data['ds'], original_data['y'], 'k-', linewidth=2, label=f'{unique_id} (Historical)')
# Plot forecast data with shaded vertical line separator
forecast_data = forecast_df[forecast_df['unique_id'] == unique_id]
# Add vertical line at the forecast start
if not forecast_data.empty and not original_data.empty:
forecast_start = forecast_data['ds'].min()
plt.axvline(x=forecast_start, color='gray', linestyle='--', alpha=0.5)
# Add a shaded area for the forecast period
plt.axvspan(forecast_start, forecast_data['ds'].max(), alpha=0.1, color='blue')
# Annotate the split point
plt.annotate('Historical | Forecast',
xy=(forecast_start, plt.ylim()[0]),
xytext=(0, -40),
textcoords='offset points',
horizontalalignment='center',
fontsize=10)
# Plot main prediction lines
for j, col in enumerate(forecast_cols):
if col in forecast_data.columns:
# Clean up model name for legend
model_name = col.replace('_', ' ').title()
if model_name == 'Timegpt':
model_name = 'TimeGPT'
plt.plot(forecast_data['ds'], forecast_data[col],
color=colors[j % len(colors)],
linestyle='--',
linewidth=1.5,
label=f'{model_name}')
plt.title('Future Forecast', fontsize=16)
plt.xlabel('Date', fontsize=12)
plt.ylabel('Value', fontsize=12)
plt.grid(True, alpha=0.3)
# Better legend with smaller font and outside placement
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, fontsize=10)
plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Adjust layout to make room for legend
# Set x-axis limits based on forecast start and horizon
if forecast_start is not None and horizon is not None:
# Calculate date offset based on frequency and horizon
date_offset = calculate_date_offset(freq, horizon)
# Calculate start date as 'horizon' units before the forecast start
start_date = forecast_start - date_offset
# Get the last date from historical data that's before or at the start_date
historical_dates = pd.to_datetime(original_df['ds'])
historical_dates_before_start = historical_dates[historical_dates <= start_date]
if not historical_dates_before_start.empty:
# Use the last available date in the historical data that's before our calculated start_date
adjusted_start_date = historical_dates_before_start.max()
else:
# Fallback to using the original start_date
adjusted_start_date = start_date
# Set the x-axis limits
plt.xlim(adjusted_start_date, forecast_df['ds'].max())
# Format date labels better
fig = plt.gcf()
ax = plt.gca()
fig.autofmt_xdate()
return fig
# Function to export results to CSV
def export_results(eval_df, cv_results, future_forecasts):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create temp directory if it doesn't exist
temp_dir = tempfile.mkdtemp()
result_files = []
if eval_df is not None:
eval_path = os.path.join(temp_dir, f"evaluation_metrics_{timestamp}.csv")
eval_df.to_csv(eval_path, index=False)
result_files.append(eval_path)
if cv_results is not None:
cv_path = os.path.join(temp_dir, f"cross_validation_results_{timestamp}.csv")
cv_results.to_csv(cv_path, index=False)
result_files.append(cv_path)
if future_forecasts is not None:
forecast_path = os.path.join(temp_dir, f"forecasts_{timestamp}.csv")
future_forecasts.to_csv(forecast_path, index=False)
result_files.append(forecast_path)
return result_files
# Main forecasting logic
def run_forecast(
file,
frequency,
eval_strategy,
horizon,
step_size,
num_windows,
use_historical_avg,
use_naive,
use_seasonal_naive,
seasonality,
use_window_avg,
window_size,
use_seasonal_window_avg,
seasonal_window_size,
use_autoets,
use_autoarima,
use_timegpt,
finetune_loss,
confidence_level,
future_horizon
):
df, message = load_data(file)
if df is None:
return None, None, None, None, None, None, message
# Initialize results
eval_df = None
cv_results = None
future_forecasts = None
# Set up traditional statistical models
models = []
model_aliases = []
if use_historical_avg:
models.append(HistoricAverage(alias='historical_average'))
model_aliases.append('historical_average')
if use_naive:
models.append(Naive(alias='naive'))
model_aliases.append('naive')
if use_seasonal_naive:
models.append(SeasonalNaive(season_length=seasonality, alias='seasonal_naive'))
model_aliases.append('seasonal_naive')
if use_window_avg:
models.append(WindowAverage(window_size=window_size, alias='window_average'))
model_aliases.append('window_average')
if use_seasonal_window_avg:
models.append(SeasonalWindowAverage(season_length=seasonality, window_size=seasonal_window_size, alias='seasonal_window_average'))
model_aliases.append('seasonal_window_average')
if use_autoets:
models.append(AutoETS(alias='autoets', season_length=seasonality))
model_aliases.append('autoets')
if use_autoarima:
models.append(AutoARIMA(alias='autoarima', season_length=seasonality))
model_aliases.append('autoarima')
if not models and not use_timegpt:
return None, None, None, None, None, None, "Please select at least one forecasting model"
try:
# Initialize results with empty DataFrames
combined_eval_df = pd.DataFrame()
combined_cv_results = pd.DataFrame()
combined_future_forecasts = pd.DataFrame()
# Run traditional statistical models if any are selected
if models:
sf = StatsForecast(models=models, freq=frequency, n_jobs=-1)
# Run cross-validation for traditional models
if eval_strategy == "Cross Validation":
cv_results = sf.cross_validation(df=df, h=horizon, step_size=step_size, n_windows=num_windows)
evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
eval_df = pd.DataFrame(evaluation).reset_index()
eval_df = eval_df.round({col: 2 for col in eval_df.columns[3:]})
else: # Fixed window
cv_results = sf.cross_validation(df=df, h=horizon, step_size=10, n_windows=1) # any step size for 1 window
evaluation = evaluate(df=cv_results, metrics=[bias, mae, rmse, mape], models=model_aliases)
eval_df = pd.DataFrame(evaluation).reset_index()
# Generate future forecasts
future_forecasts = sf.forecast(df=df, h=future_horizon)
# Store results
combined_eval_df = eval_df.copy() if eval_df is not None else pd.DataFrame()
combined_cv_results = cv_results.copy() if cv_results is not None else pd.DataFrame()
combined_cv_results = combined_cv_results
combined_future_forecasts = future_forecasts.copy() if future_forecasts is not None else pd.DataFrame()
# Run TimeGPT if selected
if use_timegpt:
try:
# Get API key from environment variables
nixtla_api_key = os.getenv("NIXTLA_API_KEY")
if not nixtla_api_key:
return None, None, None, None, None, None, "TimeGPT API key not found. Please set the NIXTLA_API_KEY environment variable."
# Initialize Nixtla client
nixtla_client = NixtlaClient(api_key=nixtla_api_key)
# Convert confidence_level to list format
level = [float(confidence_level)]
# Run cross-validation for TimeGPT
if eval_strategy == "Cross Validation":
timegpt_cv_df = nixtla_client.cross_validation(
df=df,
h=horizon,
freq=frequency,
n_windows=num_windows,
step_size=step_size
)
timegpt_cv_eval = evaluate(
df=timegpt_cv_df,
metrics=[mape, mae, rmse, bias],
models=['TimeGPT'],
)
timegpt_eval_df = pd.DataFrame(timegpt_cv_eval).reset_index()
timegpt_eval_df = timegpt_eval_df.round({col: 2 for col in timegpt_eval_df.columns[3:]})
else: # Fixed window
timegpt_cv_df = nixtla_client.cross_validation(
df=df,
h=horizon,
freq=frequency,
n_windows=1,
step_size=10
)
timegpt_cv_eval = evaluate(
df=timegpt_cv_df,
metrics=[mape, mae, rmse, bias],
models=['TimeGPT']
)
timegpt_eval_df = pd.DataFrame(timegpt_cv_eval).reset_index()
# Generate future forecasts with TimeGPT
forecast_timegpt = nixtla_client.forecast(
df=df,
h=future_horizon,
freq=frequency,
finetune_loss=finetune_loss
)
# Combine results - using merge instead of concat to avoid duplicate rows
if not combined_eval_df.empty and not timegpt_eval_df.empty:
# Get common columns for the join
join_columns = ['unique_id', 'metric']
# Merge the dataframes on unique_id and metric
combined_eval_df = pd.merge(
combined_eval_df,
timegpt_eval_df,
on=join_columns,
how='outer',
suffixes=('', '_timegpt')
)
# Clean up any duplicated columns from the merge
for col in combined_eval_df.columns:
if col.endswith('_timegpt'):
base_col = col.replace('_timegpt', '')
# Fill NaN values in the original column with values from the _timegpt column
if base_col in combined_eval_df.columns:
combined_eval_df[base_col] = combined_eval_df[base_col].fillna(combined_eval_df[col])
# Remove the _timegpt column
combined_eval_df = combined_eval_df.drop(columns=[col])
else:
combined_eval_df = timegpt_eval_df if not timegpt_eval_df.empty else combined_eval_df
if not combined_cv_results.empty and not timegpt_cv_df.empty:
# Make sure we're not duplicating the 'y' column
if 'y' in combined_cv_results.columns and 'y' in timegpt_cv_df.columns:
timegpt_cv_df_no_y = timegpt_cv_df.drop(columns=['y'])
combined_cv_results = pd.merge(
combined_cv_results,
timegpt_cv_df_no_y,
on=['unique_id', 'ds', 'cutoff'],
how='outer'
)
else:
combined_cv_results = pd.concat([combined_cv_results, timegpt_cv_df], ignore_index=True)
else:
combined_cv_results = timegpt_cv_df if not timegpt_cv_df.empty else combined_cv_results
if not combined_future_forecasts.empty and not forecast_timegpt.empty:
# Make sure we're merging on common columns
combined_future_forecasts = pd.merge(
combined_future_forecasts,
forecast_timegpt,
on=['unique_id', 'ds'],
how='outer'
)
else:
combined_future_forecasts = forecast_timegpt if not forecast_timegpt.empty else combined_future_forecasts
except Exception as e:
return None, None, None, None, None, None, f"Error with TimeGPT: {str(e)}"
# Create plots
if not combined_cv_results.empty:
fig_validation = create_forecast_plot(
combined_cv_results,
df,
f"{eval_strategy} Results"
)
else:
fig_validation = None
if not combined_future_forecasts.empty:
fig_future = create_future_forecast_plot(combined_future_forecasts, df)
else:
fig_future = None
# Export results
export_files = export_results(combined_eval_df, combined_cv_results, combined_future_forecasts)
return combined_eval_df, combined_cv_results, fig_validation, combined_future_forecasts, fig_future, export_files, "Analysis completed successfully!"
except Exception as e:
return None, None, None, None, None, None, f"Error during forecasting: {str(e)}"
# Sample CSV file generation
def download_sample():
sample_data = """unique_id,ds,y
^GSPC,2023-01-03,3824.139892578125
^GSPC,2023-01-04,3852.969970703125
^GSPC,2023-01-05,3808.10009765625
^GSPC,2023-01-06,3895.080078125
^GSPC,2023-01-09,3892.090087890625
^GSPC,2023-01-10,3919.25
^GSPC,2023-01-11,3969.610107421875
^GSPC,2023-01-12,3983.169921875
^GSPC,2023-01-13,3999.090087890625
^GSPC,2023-01-17,3990.969970703125
^GSPC,2023-01-18,3928.860107421875
^GSPC,2023-01-19,3898.85009765625
^GSPC,2023-01-20,3972.610107421875
^GSPC,2023-01-23,4019.81005859375
^GSPC,2023-01-24,4016.949951171875
^GSPC,2023-01-25,4016.219970703125
^GSPC,2023-01-26,4060.429931640625
^GSPC,2023-01-27,4070.56005859375
^GSPC,2023-01-30,4017.77001953125
^GSPC,2023-01-31,4076.60009765625
^GSPC,2023-02-01,4119.2099609375
^GSPC,2023-02-02,4179.759765625
^GSPC,2023-02-03,4136.47998046875
^GSPC,2023-02-06,4111.080078125
^GSPC,2023-02-07,4164
^GSPC,2023-02-08,4117.85986328125
^GSPC,2023-02-09,4081.5
^GSPC,2023-02-10,4090.4599609375
^GSPC,2023-02-13,4137.2900390625
^GSPC,2023-02-14,4136.1298828125
^GSPC,2023-02-15,4147.60009765625
^GSPC,2023-02-16,4090.409912109375
^GSPC,2023-02-17,4079.090087890625
^GSPC,2023-02-21,3997.340087890625
^GSPC,2023-02-22,3991.050048828125
^GSPC,2023-02-23,4012.320068359375
^GSPC,2023-02-24,3970.0400390625
^GSPC,2023-02-27,3982.239990234375
^GSPC,2023-02-28,3970.14990234375
^GSPC,2023-03-01,3951.389892578125
^GSPC,2023-03-02,3981.35009765625
^GSPC,2023-03-03,4045.639892578125
^GSPC,2023-03-06,4048.419921875
^GSPC,2023-03-07,3986.3701171875
^GSPC,2023-03-08,3992.010009765625
^GSPC,2023-03-09,3918.320068359375
^GSPC,2023-03-10,3861.590087890625
^GSPC,2023-03-13,3855.760009765625
^GSPC,2023-03-14,3919.2900390625
^GSPC,2023-03-15,3891.929931640625
^GSPC,2023-03-16,3960.280029296875
^GSPC,2023-03-17,3916.639892578125
^GSPC,2023-03-20,3951.570068359375
^GSPC,2023-03-21,4002.8701171875
^GSPC,2023-03-22,3936.969970703125
^GSPC,2023-03-23,3948.719970703125
^GSPC,2023-03-24,3970.989990234375
^GSPC,2023-03-27,3977.530029296875
^GSPC,2023-03-28,3971.27001953125
^GSPC,2023-03-29,4027.81005859375
^GSPC,2023-03-30,4050.830078125
^GSPC,2023-03-31,4109.31005859375
^GSPC,2023-04-03,4124.509765625
^GSPC,2023-04-04,4100.60009765625
^GSPC,2023-04-05,4090.3798828125
^GSPC,2023-04-06,4105.02001953125
^GSPC,2023-04-10,4109.10986328125
^GSPC,2023-04-11,4108.93994140625
^GSPC,2023-04-12,4091.949951171875
^GSPC,2023-04-13,4146.22021484375
^GSPC,2023-04-14,4137.64013671875
^GSPC,2023-04-17,4151.31982421875
^GSPC,2023-04-18,4154.8701171875
^GSPC,2023-04-19,4154.52001953125
^GSPC,2023-04-20,4129.7900390625
^GSPC,2023-04-21,4133.52001953125
^GSPC,2023-04-24,4137.0400390625
^GSPC,2023-04-25,4071.6298828125
^GSPC,2023-04-26,4055.989990234375
^GSPC,2023-04-27,4135.35009765625
^GSPC,2023-04-28,4169.47998046875
^GSPC,2023-05-01,4167.8701171875
^GSPC,2023-05-02,4119.580078125
^GSPC,2023-05-03,4090.75
^GSPC,2023-05-04,4061.219970703125
^GSPC,2023-05-05,4136.25
^GSPC,2023-05-08,4138.1201171875
^GSPC,2023-05-09,4119.169921875
^GSPC,2023-05-10,4137.64013671875
^GSPC,2023-05-11,4130.6201171875
^GSPC,2023-05-12,4124.080078125
^GSPC,2023-05-15,4136.27978515625
^GSPC,2023-05-16,4109.89990234375
^GSPC,2023-05-17,4158.77001953125
^GSPC,2023-05-18,4198.0498046875
^GSPC,2023-05-19,4191.97998046875
^GSPC,2023-05-22,4192.6298828125
^GSPC,2023-05-23,4145.580078125
^GSPC,2023-05-24,4115.240234375
^GSPC,2023-05-25,4151.27978515625
^GSPC,2023-05-26,4205.4501953125
^GSPC,2023-05-30,4205.52001953125
^GSPC,2023-05-31,4179.830078125
^GSPC,2023-06-01,4221.02001953125
^GSPC,2023-06-02,4282.3701171875
^GSPC,2023-06-05,4273.7900390625
^GSPC,2023-06-06,4283.85009765625
^GSPC,2023-06-07,4267.52001953125
^GSPC,2023-06-08,4293.93017578125
^GSPC,2023-06-09,4298.85986328125
^GSPC,2023-06-12,4338.93017578125
^GSPC,2023-06-13,4369.009765625
^GSPC,2023-06-14,4372.58984375
^GSPC,2023-06-15,4425.83984375
^GSPC,2023-06-16,4409.58984375
^GSPC,2023-06-20,4388.7099609375
^GSPC,2023-06-21,4365.68994140625
^GSPC,2023-06-22,4381.89013671875
^GSPC,2023-06-23,4348.330078125
^GSPC,2023-06-26,4328.81982421875
^GSPC,2023-06-27,4378.41015625
^GSPC,2023-06-28,4376.85986328125
^GSPC,2023-06-29,4396.43994140625
^GSPC,2023-06-30,4450.3798828125
^GSPC,2023-07-03,4455.58984375
^GSPC,2023-07-05,4446.81982421875
^GSPC,2023-07-06,4411.58984375
^GSPC,2023-07-07,4398.9501953125
^GSPC,2023-07-10,4409.52978515625
^GSPC,2023-07-11,4439.259765625
^GSPC,2023-07-12,4472.16015625
^GSPC,2023-07-13,4510.0400390625
^GSPC,2023-07-14,4505.419921875
^GSPC,2023-07-17,4522.7900390625
^GSPC,2023-07-18,4554.97998046875
^GSPC,2023-07-19,4565.72021484375
^GSPC,2023-07-20,4534.8701171875
^GSPC,2023-07-21,4536.33984375
^GSPC,2023-07-24,4554.64013671875
^GSPC,2023-07-25,4567.4599609375
^GSPC,2023-07-26,4566.75
^GSPC,2023-07-27,4537.41015625
^GSPC,2023-07-28,4582.22998046875
^GSPC,2023-07-31,4588.9599609375
^GSPC,2023-08-01,4576.72998046875
^GSPC,2023-08-02,4513.39013671875
^GSPC,2023-08-03,4501.89013671875
^GSPC,2023-08-04,4478.02978515625
^GSPC,2023-08-07,4518.43994140625
^GSPC,2023-08-08,4499.3798828125
^GSPC,2023-08-09,4467.7099609375
^GSPC,2023-08-10,4468.830078125
^GSPC,2023-08-11,4464.0498046875
^GSPC,2023-08-14,4489.72021484375
^GSPC,2023-08-15,4437.85986328125
^GSPC,2023-08-16,4404.330078125
^GSPC,2023-08-17,4370.35986328125
^GSPC,2023-08-18,4369.7099609375
^GSPC,2023-08-21,4399.77001953125
^GSPC,2023-08-22,4387.5498046875
^GSPC,2023-08-23,4436.009765625
^GSPC,2023-08-24,4376.31005859375
^GSPC,2023-08-25,4405.7099609375
^GSPC,2023-08-28,4433.31005859375
^GSPC,2023-08-29,4497.6298828125
^GSPC,2023-08-30,4514.8701171875
^GSPC,2023-08-31,4507.66015625
^GSPC,2023-09-01,4515.77001953125
^GSPC,2023-09-05,4496.830078125
^GSPC,2023-09-06,4465.47998046875
^GSPC,2023-09-07,4451.14013671875
^GSPC,2023-09-08,4457.490234375
^GSPC,2023-09-11,4487.4599609375
^GSPC,2023-09-12,4461.89990234375
^GSPC,2023-09-13,4467.43994140625
^GSPC,2023-09-14,4505.10009765625
^GSPC,2023-09-15,4450.31982421875
^GSPC,2023-09-18,4453.52978515625
^GSPC,2023-09-19,4443.9501953125
^GSPC,2023-09-20,4402.2001953125
^GSPC,2023-09-21,4330
^GSPC,2023-09-22,4320.06005859375
^GSPC,2023-09-25,4337.43994140625
^GSPC,2023-09-26,4273.52978515625
^GSPC,2023-09-27,4274.509765625
^GSPC,2023-09-28,4299.7001953125
^GSPC,2023-09-29,4288.0498046875
^GSPC,2023-10-02,4288.39013671875
^GSPC,2023-10-03,4229.4501953125
^GSPC,2023-10-04,4263.75
^GSPC,2023-10-05,4258.18994140625
^GSPC,2023-10-06,4308.5
^GSPC,2023-10-09,4335.66015625
^GSPC,2023-10-10,4358.240234375
^GSPC,2023-10-11,4376.9501953125
^GSPC,2023-10-12,4349.60986328125
^GSPC,2023-10-13,4327.77978515625
^GSPC,2023-10-16,4373.6298828125
^GSPC,2023-10-17,4373.2001953125
^GSPC,2023-10-18,4314.60009765625
^GSPC,2023-10-19,4278
^GSPC,2023-10-20,4224.16015625
^GSPC,2023-10-23,4217.0400390625
^GSPC,2023-10-24,4247.68017578125
^GSPC,2023-10-25,4186.77001953125
^GSPC,2023-10-26,4137.22998046875
^GSPC,2023-10-27,4117.3701171875
^GSPC,2023-10-30,4166.81982421875
^GSPC,2023-10-31,4193.7998046875
^GSPC,2023-11-01,4237.85986328125
^GSPC,2023-11-02,4317.77978515625
^GSPC,2023-11-03,4358.33984375
^GSPC,2023-11-06,4365.97998046875
^GSPC,2023-11-07,4378.3798828125
^GSPC,2023-11-08,4382.77978515625
^GSPC,2023-11-09,4347.35009765625
^GSPC,2023-11-10,4415.240234375
^GSPC,2023-11-13,4411.5498046875
^GSPC,2023-11-14,4495.7001953125
^GSPC,2023-11-15,4502.8798828125
^GSPC,2023-11-16,4508.240234375
^GSPC,2023-11-17,4514.02001953125
^GSPC,2023-11-20,4547.3798828125
^GSPC,2023-11-21,4538.18994140625
^GSPC,2023-11-22,4556.6201171875
^GSPC,2023-11-24,4559.33984375
^GSPC,2023-11-27,4550.43017578125
^GSPC,2023-11-28,4554.89013671875
^GSPC,2023-11-29,4550.580078125
^GSPC,2023-11-30,4567.7998046875
^GSPC,2023-12-01,4594.6298828125
^GSPC,2023-12-04,4569.77978515625
^GSPC,2023-12-05,4567.18017578125
^GSPC,2023-12-06,4549.33984375
^GSPC,2023-12-07,4585.58984375
^GSPC,2023-12-08,4604.3701171875
^GSPC,2023-12-11,4622.43994140625
^GSPC,2023-12-12,4643.7001953125
^GSPC,2023-12-13,4707.08984375
^GSPC,2023-12-14,4719.5498046875
^GSPC,2023-12-15,4719.18994140625
^GSPC,2023-12-18,4740.56005859375
^GSPC,2023-12-19,4768.3701171875
^GSPC,2023-12-20,4698.35009765625
^GSPC,2023-12-21,4746.75
^GSPC,2023-12-22,4754.6298828125
^GSPC,2023-12-26,4774.75
^GSPC,2023-12-27,4781.580078125
^GSPC,2023-12-28,4783.35009765625
^GSPC,2023-12-29,4769.830078125
^GSPC,2024-01-02,4742.830078125
^GSPC,2024-01-03,4704.81005859375
^GSPC,2024-01-04,4688.68017578125
^GSPC,2024-01-05,4697.240234375
^GSPC,2024-01-08,4763.5400390625
^GSPC,2024-01-09,4756.5
^GSPC,2024-01-10,4783.4501953125
^GSPC,2024-01-11,4780.240234375
^GSPC,2024-01-12,4783.830078125
^GSPC,2024-01-16,4765.97998046875
^GSPC,2024-01-17,4739.2099609375
^GSPC,2024-01-18,4780.93994140625
^GSPC,2024-01-19,4839.81005859375
^GSPC,2024-01-22,4850.43017578125
^GSPC,2024-01-23,4864.60009765625
^GSPC,2024-01-24,4868.5498046875
^GSPC,2024-01-25,4894.16015625
^GSPC,2024-01-26,4890.97021484375
^GSPC,2024-01-29,4927.93017578125
^GSPC,2024-01-30,4924.97021484375
^GSPC,2024-01-31,4845.64990234375
^GSPC,2024-02-01,4906.18994140625
^GSPC,2024-02-02,4958.60986328125
^GSPC,2024-02-05,4942.81005859375
^GSPC,2024-02-06,4954.22998046875
^GSPC,2024-02-07,4995.06005859375
^GSPC,2024-02-08,4997.91015625
^GSPC,2024-02-09,5026.60986328125
^GSPC,2024-02-12,5021.83984375
^GSPC,2024-02-13,4953.169921875
^GSPC,2024-02-14,5000.6201171875
^GSPC,2024-02-15,5029.72998046875
^GSPC,2024-02-16,5005.56982421875
^GSPC,2024-02-20,4975.509765625
^GSPC,2024-02-21,4981.7998046875
^GSPC,2024-02-22,5087.02978515625
^GSPC,2024-02-23,5088.7998046875
^GSPC,2024-02-26,5069.52978515625
^GSPC,2024-02-27,5078.18017578125
^GSPC,2024-02-28,5069.759765625
^GSPC,2024-02-29,5096.27001953125
^GSPC,2024-03-01,5137.080078125
^GSPC,2024-03-04,5130.9501953125
^GSPC,2024-03-05,5078.64990234375
^GSPC,2024-03-06,5104.759765625
^GSPC,2024-03-07,5157.35986328125
^GSPC,2024-03-08,5123.68994140625
^GSPC,2024-03-11,5117.93994140625
^GSPC,2024-03-12,5175.27001953125
^GSPC,2024-03-13,5165.31005859375
^GSPC,2024-03-14,5150.47998046875
^GSPC,2024-03-15,5117.08984375
^GSPC,2024-03-18,5149.419921875
^GSPC,2024-03-19,5178.509765625
^GSPC,2024-03-20,5224.6201171875
^GSPC,2024-03-21,5241.52978515625
^GSPC,2024-03-22,5234.18017578125
^GSPC,2024-03-25,5218.18994140625
^GSPC,2024-03-26,5203.580078125
^GSPC,2024-03-27,5248.490234375
^GSPC,2024-03-28,5254.35009765625
^GSPC,2024-04-01,5243.77001953125
^GSPC,2024-04-02,5205.81005859375
^GSPC,2024-04-03,5211.490234375
^GSPC,2024-04-04,5147.2099609375
^GSPC,2024-04-05,5204.33984375
^GSPC,2024-04-08,5202.39013671875
^GSPC,2024-04-09,5209.91015625
^GSPC,2024-04-10,5160.64013671875
^GSPC,2024-04-11,5199.06005859375
^GSPC,2024-04-12,5123.41015625
^GSPC,2024-04-15,5061.81982421875
^GSPC,2024-04-16,5051.41015625
^GSPC,2024-04-17,5022.2099609375
^GSPC,2024-04-18,5011.1201171875
^GSPC,2024-04-19,4967.22998046875
^GSPC,2024-04-22,5010.60009765625
^GSPC,2024-04-23,5070.5498046875
^GSPC,2024-04-24,5071.6298828125
^GSPC,2024-04-25,5048.419921875
^GSPC,2024-04-26,5099.9599609375
^GSPC,2024-04-29,5116.169921875
^GSPC,2024-04-30,5035.68994140625
^GSPC,2024-05-01,5018.39013671875
^GSPC,2024-05-02,5064.2001953125
^GSPC,2024-05-03,5127.7900390625
^GSPC,2024-05-06,5180.740234375
^GSPC,2024-05-07,5187.7001953125
^GSPC,2024-05-08,5187.669921875
^GSPC,2024-05-09,5214.080078125
^GSPC,2024-05-10,5222.68017578125
^GSPC,2024-05-13,5221.419921875
^GSPC,2024-05-14,5246.68017578125
^GSPC,2024-05-15,5308.14990234375
^GSPC,2024-05-16,5297.10009765625
^GSPC,2024-05-17,5303.27001953125
^GSPC,2024-05-20,5308.1298828125
^GSPC,2024-05-21,5321.41015625
^GSPC,2024-05-22,5307.009765625
^GSPC,2024-05-23,5267.83984375
^GSPC,2024-05-24,5304.72021484375
^GSPC,2024-05-28,5306.0400390625
^GSPC,2024-05-29,5266.9501953125
^GSPC,2024-05-30,5235.47998046875
^GSPC,2024-05-31,5277.509765625
^GSPC,2024-06-03,5283.39990234375
^GSPC,2024-06-04,5291.33984375
^GSPC,2024-06-05,5354.02978515625
^GSPC,2024-06-06,5352.9599609375
^GSPC,2024-06-07,5346.990234375
^GSPC,2024-06-10,5360.7900390625
^GSPC,2024-06-11,5375.31982421875
^GSPC,2024-06-12,5421.02978515625
^GSPC,2024-06-13,5433.740234375
^GSPC,2024-06-14,5431.60009765625
^GSPC,2024-06-17,5473.22998046875
^GSPC,2024-06-18,5487.02978515625
^GSPC,2024-06-20,5473.169921875
^GSPC,2024-06-21,5464.6201171875
^GSPC,2024-06-24,5447.8701171875
^GSPC,2024-06-25,5469.2998046875
^GSPC,2024-06-26,5477.89990234375
^GSPC,2024-06-27,5482.8701171875
^GSPC,2024-06-28,5460.47998046875
^GSPC,2024-07-01,5475.08984375
^GSPC,2024-07-02,5509.009765625
^GSPC,2024-07-03,5537.02001953125
^GSPC,2024-07-05,5567.18994140625
^GSPC,2024-07-08,5572.85009765625
^GSPC,2024-07-09,5576.97998046875
^GSPC,2024-07-10,5633.91015625
^GSPC,2024-07-11,5584.5400390625
^GSPC,2024-07-12,5615.35009765625
^GSPC,2024-07-15,5631.22021484375
^GSPC,2024-07-16,5667.2001953125
^GSPC,2024-07-17,5588.27001953125
^GSPC,2024-07-18,5544.58984375
^GSPC,2024-07-19,5505
^GSPC,2024-07-22,5564.41015625
^GSPC,2024-07-23,5555.740234375
^GSPC,2024-07-24,5427.1298828125
^GSPC,2024-07-25,5399.22021484375
^GSPC,2024-07-26,5459.10009765625
^GSPC,2024-07-29,5463.5400390625
^GSPC,2024-07-30,5436.43994140625
^GSPC,2024-07-31,5522.2998046875
^GSPC,2024-08-01,5446.68017578125
^GSPC,2024-08-02,5346.56005859375
^GSPC,2024-08-05,5186.330078125
^GSPC,2024-08-06,5240.02978515625
^GSPC,2024-08-07,5199.5
^GSPC,2024-08-08,5319.31005859375
^GSPC,2024-08-09,5344.16015625
^GSPC,2024-08-12,5344.39013671875
^GSPC,2024-08-13,5434.43017578125
^GSPC,2024-08-14,5455.2099609375
^GSPC,2024-08-15,5543.22021484375
^GSPC,2024-08-16,5554.25
^GSPC,2024-08-19,5608.25
^GSPC,2024-08-20,5597.1201171875
^GSPC,2024-08-21,5620.85009765625
^GSPC,2024-08-22,5570.64013671875
^GSPC,2024-08-23,5634.60986328125
^GSPC,2024-08-26,5616.83984375
^GSPC,2024-08-27,5625.7998046875
^GSPC,2024-08-28,5592.18017578125
^GSPC,2024-08-29,5591.9599609375
^GSPC,2024-08-30,5648.39990234375
^GSPC,2024-09-03,5528.93017578125
^GSPC,2024-09-04,5520.06982421875
^GSPC,2024-09-05,5503.41015625
^GSPC,2024-09-06,5408.419921875
^GSPC,2024-09-09,5471.0498046875
^GSPC,2024-09-10,5495.52001953125
^GSPC,2024-09-11,5554.1298828125
^GSPC,2024-09-12,5595.759765625
^GSPC,2024-09-13,5626.02001953125
^GSPC,2024-09-16,5633.08984375
^GSPC,2024-09-17,5634.580078125
^GSPC,2024-09-18,5618.259765625
^GSPC,2024-09-19,5713.64013671875
^GSPC,2024-09-20,5702.5498046875
^GSPC,2024-09-23,5718.56982421875
^GSPC,2024-09-24,5732.93017578125
^GSPC,2024-09-25,5722.259765625
^GSPC,2024-09-26,5745.3701171875
^GSPC,2024-09-27,5738.169921875
^GSPC,2024-09-30,5762.47998046875
^GSPC,2024-10-01,5708.75
^GSPC,2024-10-02,5709.5400390625
^GSPC,2024-10-03,5699.93994140625
^GSPC,2024-10-04,5751.06982421875
^GSPC,2024-10-07,5695.93994140625
^GSPC,2024-10-08,5751.1298828125
^GSPC,2024-10-09,5792.0400390625
^GSPC,2024-10-10,5780.0498046875
^GSPC,2024-10-11,5815.02978515625
^GSPC,2024-10-14,5859.85009765625
^GSPC,2024-10-15,5815.259765625
^GSPC,2024-10-16,5842.47021484375
^GSPC,2024-10-17,5841.47021484375
^GSPC,2024-10-18,5864.669921875
^GSPC,2024-10-21,5853.97998046875
^GSPC,2024-10-22,5851.2001953125
^GSPC,2024-10-23,5797.419921875
^GSPC,2024-10-24,5809.85986328125
^GSPC,2024-10-25,5808.1201171875
^GSPC,2024-10-28,5823.52001953125
^GSPC,2024-10-29,5832.919921875
^GSPC,2024-10-30,5813.669921875
^GSPC,2024-10-31,5705.4501953125
^GSPC,2024-11-01,5728.7998046875
^GSPC,2024-11-04,5712.68994140625
^GSPC,2024-11-05,5782.759765625
^GSPC,2024-11-06,5929.0400390625
^GSPC,2024-11-07,5973.10009765625
^GSPC,2024-11-08,5995.5400390625
^GSPC,2024-11-11,6001.35009765625
^GSPC,2024-11-12,5983.990234375
^GSPC,2024-11-13,5985.3798828125
^GSPC,2024-11-14,5949.169921875
^GSPC,2024-11-15,5870.6201171875
^GSPC,2024-11-18,5893.6201171875
^GSPC,2024-11-19,5916.97998046875
^GSPC,2024-11-20,5917.10986328125
^GSPC,2024-11-21,5948.7099609375
^GSPC,2024-11-22,5969.33984375
^GSPC,2024-11-25,5987.3701171875
^GSPC,2024-11-26,6021.6298828125
^GSPC,2024-11-27,5998.740234375
^GSPC,2024-11-29,6032.3798828125
^GSPC,2024-12-02,6047.14990234375
^GSPC,2024-12-03,6049.8798828125
^GSPC,2024-12-04,6086.490234375
^GSPC,2024-12-05,6075.10986328125
^GSPC,2024-12-06,6090.27001953125
^GSPC,2024-12-09,6052.85009765625
^GSPC,2024-12-10,6034.91015625
^GSPC,2024-12-11,6084.18994140625
^GSPC,2024-12-12,6051.25
^GSPC,2024-12-13,6051.08984375
^GSPC,2024-12-16,6074.080078125
^GSPC,2024-12-17,6050.60986328125
^GSPC,2024-12-18,5872.16015625
^GSPC,2024-12-19,5867.080078125
^GSPC,2024-12-20,5930.85009765625
^GSPC,2024-12-23,5974.06982421875
^GSPC,2024-12-24,6040.0400390625
^GSPC,2024-12-26,6037.58984375
^GSPC,2024-12-27,5970.83984375
^GSPC,2024-12-30,5906.93994140625
^GSPC,2024-12-31,5881.6298828125
^GSPC,2025-01-02,5868.5498046875
^GSPC,2025-01-03,5942.47021484375
^GSPC,2025-01-06,5975.3798828125
^GSPC,2025-01-07,5909.02978515625
^GSPC,2025-01-08,5918.25
^GSPC,2025-01-10,5827.0400390625
^GSPC,2025-01-13,5836.22021484375
^GSPC,2025-01-14,5842.91015625
^GSPC,2025-01-15,5949.91015625
^GSPC,2025-01-16,5937.33984375
^GSPC,2025-01-17,5996.66015625
^GSPC,2025-01-21,6049.240234375
^GSPC,2025-01-22,6086.3701171875
^GSPC,2025-01-23,6118.7099609375
^GSPC,2025-01-24,6101.240234375
^GSPC,2025-01-27,6012.27978515625
^GSPC,2025-01-28,6067.7001953125
^GSPC,2025-01-29,6039.31005859375
^GSPC,2025-01-30,6071.169921875
^GSPC,2025-01-31,6040.52978515625
^GSPC,2025-02-03,5994.56982421875
^GSPC,2025-02-04,6037.8798828125
^GSPC,2025-02-05,6061.47998046875
^GSPC,2025-02-06,6083.56982421875
^GSPC,2025-02-07,6025.990234375
^GSPC,2025-02-10,6066.43994140625
^GSPC,2025-02-11,6068.5
^GSPC,2025-02-12,6051.97021484375
^GSPC,2025-02-13,6115.06982421875
^GSPC,2025-02-14,6114.6298828125
^GSPC,2025-02-18,6129.580078125
^GSPC,2025-02-19,6144.14990234375
^GSPC,2025-02-20,6117.52001953125
^GSPC,2025-02-21,6013.1298828125
^GSPC,2025-02-24,5983.25
^GSPC,2025-02-25,5955.25
^GSPC,2025-02-26,5956.06005859375
^GSPC,2025-02-27,5861.56982421875
^GSPC,2025-02-28,5954.5
^GSPC,2025-03-03,5849.72021484375
^GSPC,2025-03-04,5778.14990234375
^GSPC,2025-03-05,5842.6298828125
^GSPC,2025-03-06,5738.52001953125
^GSPC,2025-03-07,5770.2001953125
^GSPC,2025-03-10,5614.56005859375
^GSPC,2025-03-11,5572.06982421875
^GSPC,2025-03-12,5599.2998046875
^GSPC,2025-03-13,5521.52001953125
^GSPC,2025-03-14,5638.93994140625
^GSPC,2025-03-17,5675.1201171875
^GSPC,2025-03-18,5614.66015625
^GSPC,2025-03-19,5675.2900390625
^GSPC,2025-03-20,5662.89013671875
^GSPC,2025-03-21,5667.56005859375
^GSPC,2025-03-24,5767.56982421875
^GSPC,2025-03-25,5776.64990234375
^GSPC,2025-03-26,5712.2001953125
^GSPC,2025-03-27,5693.31005859375
^GSPC,2025-03-28,5580.93994140625
^GSPC,2025-03-31,5611.85009765625
"""
temp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w', newline='')
temp.write(sample_data)
temp.close()
return temp.name
# Global theme
theme = gr.themes.Soft(
primary_hue="blue",
secondary_hue="indigo",
neutral_hue="gray"
)
# Gradio interface
with gr.Blocks(title="Time Series Forecasting App", theme=theme) as app:
gr.Markdown("# 📈 Time Series Forecasting App")
gr.Markdown("Upload a CSV with `unique_id`, `ds`, and `y` columns to apply forecasting models.")
# Disclaimer about external predictors
with gr.Accordion("Disclaimer", open=True):
gr.Markdown("""
**Disclaimer:** For simplicity, this app does not allow the use of external predictors.
However, they can be easily included in the underlying statsforecast (for AutoARIMA)
and the TimeGPT implementation by Nixtla. To use external predictors, you would need to modify
the code to include them in your forecasting models.
""")
with gr.Row():
with gr.Column(scale=2):
file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
download_btn = gr.Button("Download Sample Data", variant="secondary")
download_output = gr.File(label="Click to download", visible=True)
download_btn.click(fn=download_sample, outputs=download_output)
with gr.Accordion("Data & Validation Settings", open=True):
frequency = gr.Dropdown(
choices=[
("Hourly", "H"),
("Business Day", "B"),
("Daily", "D"),
("Weekly", "WS"),
("Monthly", "MS"),
("Quarterly", "QS"),
("Yearly", "YS")
],
label="Data Frequency",
value="D"
)
# Evaluation Strategy
eval_strategy = gr.Radio(
choices=["Fixed Window", "Cross Validation"],
label="Evaluation Strategy",
value="Cross Validation"
)
# Fixed Window settings
with gr.Group(visible=True) as fixed_window_box:
gr.Markdown("### Fixed Window Settings")
horizon = gr.Slider(1, 100, value=10, step=1, label="Validation Horizon (steps ahead to predict)")
# Cross Validation settings
with gr.Group(visible=True) as cv_box:
gr.Markdown("### Cross Validation Settings")
with gr.Row():
step_size = gr.Slider(1, 50, value=10, step=1, label="Step Size (distance between windows)")
num_windows = gr.Slider(1, 20, value=5, step=1, label="Number of Windows")
# Future forecast settings (always visible)
with gr.Group():
gr.Markdown("### Future Forecast Settings")
future_horizon = gr.Slider(1, 100, value=10, step=1, label="Future Forecast Horizon (steps to predict)")
with gr.Accordion("Model Configuration", open=True):
with gr.Tabs() as model_tabs:
# Traditional Statistical Models Tab
with gr.TabItem("Statistical Models"):
gr.Markdown("## Basic Models")
with gr.Row():
use_historical_avg = gr.Checkbox(label="Historical Average", value=True)
use_naive = gr.Checkbox(label="Naive", value=True)
# Common seasonality parameter at the top level
with gr.Group():
gr.Markdown("### Seasonality Configuration")
gr.Markdown("This seasonality period affects Seasonal Naive, Seasonal Window Average, AutoETS, and AutoARIMA models")
seasonality = gr.Number(label="Seasonality Period", value=5)
gr.Markdown("### Seasonal Models")
with gr.Row():
use_seasonal_naive = gr.Checkbox(label="Seasonal Naive", value=True)
gr.Markdown("### Window-based Models")
with gr.Row():
use_window_avg = gr.Checkbox(label="Window Average", value=False)
window_size = gr.Number(label="Window Size", value=10)
with gr.Row():
use_seasonal_window_avg = gr.Checkbox(label="Seasonal Window Average", value=False)
seasonal_window_size = gr.Number(label="Seasonal Window Size", value=2)
gr.Markdown("### Advanced Models (use seasonality from above)")
with gr.Row():
use_autoets = gr.Checkbox(label="AutoETS (Exponential Smoothing)", value=False)
use_autoarima = gr.Checkbox(label="AutoARIMA", value=False)
# Transformer Models Tab (TimeGPT)
with gr.TabItem("Transformer Models"):
gr.Markdown("## TimeGPT Model")
gr.Markdown("TimeGPT uses a transformer architecture for state-of-the-art time series forecasting")
with gr.Row():
use_timegpt = gr.Checkbox(label="Use TimeGPT", value=True)
with gr.Group():
gr.Markdown("### TimeGPT Configuration")
with gr.Row():
finetune_loss = gr.Dropdown(
choices=["mape", "mae", "rmse", "smape"],
label="Finetune Loss Metric",
value="mape"
)
confidence_level = gr.Slider(50, 99, value=95, step=1, label="Confidence Level (%)")
gr.Markdown("""
**Note:** Using TimeGPT requires a valid API key. The API key should
be set as an environment variable named `NIXTLA_API_KEY`. This space uses a trial key, which is rate limited.
""")
with gr.Column(scale=3):
message_output = gr.Textbox(label="Status Message")
with gr.Tabs() as tabs:
with gr.TabItem("Validation Results"):
eval_output = gr.Dataframe(label="Evaluation Metrics")
validation_plot = gr.Plot(label="Validation Plot")
validation_output = gr.Dataframe(label="Validation Data", visible=False)
with gr.Row():
show_data_btn = gr.Button("Show Validation Data")
hide_data_btn = gr.Button("Hide Validation Data", visible=False)
def show_data():
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
def hide_data():
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
show_data_btn.click(
fn=show_data,
outputs=[validation_output, hide_data_btn, show_data_btn]
)
hide_data_btn.click(
fn=hide_data,
outputs=[validation_output, hide_data_btn, show_data_btn]
)
with gr.TabItem("Future Forecast"):
forecast_plot = gr.Plot(label="Future Forecast Plot")
forecast_output = gr.Dataframe(label="Future Forecast Data", visible=False)
with gr.Row():
show_forecast_btn = gr.Button("Show Forecast Data")
hide_forecast_btn = gr.Button("Hide Forecast Data", visible=False)
def show_forecast():
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
def hide_forecast():
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
show_forecast_btn.click(
fn=show_forecast,
outputs=[forecast_output, hide_forecast_btn, show_forecast_btn]
)
hide_forecast_btn.click(
fn=hide_forecast,
outputs=[forecast_output, hide_forecast_btn, show_forecast_btn]
)
with gr.TabItem("Export Results"):
export_files = gr.Files(label="Download Results")
with gr.Row(visible=True) as run_row:
submit_btn = gr.Button("Run Validation and Forecast", variant="primary", size="lg")
# Update visibility of the appropriate box based on evaluation strategy
def update_eval_boxes(strategy):
return (gr.update(visible=strategy == "Fixed Window"),
gr.update(visible=strategy == "Cross Validation"))
eval_strategy.change(
fn=update_eval_boxes,
inputs=[eval_strategy],
outputs=[fixed_window_box, cv_box]
)
# Run forecast when button is clicked
submit_btn.click(
fn=run_forecast,
inputs=[
file_input, frequency, eval_strategy, horizon, step_size, num_windows,
use_historical_avg, use_naive, use_seasonal_naive, seasonality,
use_window_avg, window_size, use_seasonal_window_avg, seasonal_window_size,
use_autoets, use_autoarima, use_timegpt, finetune_loss, confidence_level,
future_horizon
],
outputs=[
eval_output,
validation_output,
validation_plot,
forecast_output,
forecast_plot,
export_files,
message_output]
)
if __name__ == "__main__":
app.launch(share=False)