import os
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datetime import datetime, timedelta
import io
import base64
import logging
from typing import Optional, Tuple, Dict, Any
import warnings

warnings.filterwarnings('ignore')

logger = logging.getLogger(__name__)

# Columns of the ETT-style dataset, in the order they are stacked into model inputs.
FEATURE_COLUMNS = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]

# Sampling interval of the dataset (15 minutes), in seconds.
INTERVAL_SECONDS = 60 * 15

# Tail probability on each side of the plotted band: the band spans the CI_ALPHA and
# 1 - CI_ALPHA sample quantiles, i.e. a (1 - 2 * CI_ALPHA) central interval (90% here).
CI_ALPHA = 0.05

# Default UI parameters, shared by the sliders and the forecast run on page load.
DEFAULT_CONTEXT_LENGTH = 512
DEFAULT_PREDICTION_LENGTH = 96
DEFAULT_NUM_SAMPLES = 64
DEFAULT_SAMPLES_PER_BATCH = 32
DEFAULT_USE_KV_CACHE = True


# Mock implementations for the original imports
# In actual deployment, you'd import these from the original modules
class MaskedTimeseries:
    """Plain container for the model inputs built by ``prepare_data``."""

    def __init__(self, series, padding_mask, id_mask, timestamp_seconds, time_interval_seconds):
        self.series = series
        self.padding_mask = padding_mask
        self.id_mask = id_mask
        self.timestamp_seconds = timestamp_seconds
        self.time_interval_seconds = time_interval_seconds


class MockToto:
    """Mock Toto model for demonstration"""

    def __init__(self):
        # The real model object exposes the backbone as ``.model``; the mock is its own backbone.
        self.model = self

    @classmethod
    def from_pretrained(cls, model_name):
        return cls()

    def to(self, device):
        return self

    def compile(self):
        return self


class MockSamples:
    """Minimal stand-in for the sample container of a real forecast.

    Wraps a tensor of shape ``(n_variates, prediction_length, num_samples)`` as
    produced by ``MockForecaster.forecast``.
    """

    def __init__(self, tensor):
        self.tensor = tensor

    def squeeze(self):
        # The wrapped tensor has no singleton axes to drop; return it as-is.
        return self.tensor

    def cpu(self):
        return self.tensor

    def quantile(self, q, dim):
        """Return the ``q`` quantiles of the samples along ``dim``.

        Like ``torch.quantile`` with a 1-D ``q``, the quantile axis comes first:
        the result has shape ``(len(q), *remaining_dims)``. ``q`` is moved to the
        samples' dtype/device so callers may pass a CPU float tensor.
        """
        q = torch.as_tensor(q, dtype=self.tensor.dtype, device=self.tensor.device)
        return torch.quantile(self.tensor, q, dim=dim)


class MockForecast:
    """Minimal stand-in for a forecast result; exposes ``.samples``."""

    def __init__(self, samples):
        self.samples = MockSamples(samples)


class MockForecaster:
    """Mock forecaster for demonstration.

    Generates random-walk sample paths starting from the last observed value of
    each variate. ``samples_per_batch`` and ``use_kv_cache`` are accepted only for
    signature compatibility with the real forecaster and are ignored.
    """

    def __init__(self, model):
        self.model = model

    def forecast(self, inputs, prediction_length, num_samples, samples_per_batch, use_kv_cache=True):
        """Return a ``MockForecast`` whose samples have shape
        ``(n_variates, prediction_length, num_samples)``.
        """
        n_variates, _context_length = inputs.series.shape
        shape = (n_variates, int(prediction_length), int(num_samples))

        # Per-step increments: a small random drift plus noise. Cumulative summation over
        # the time axis turns them into random-walk paths (vectorised; no Python loop).
        increments = torch.randn(shape) * 0.1 + torch.randn(shape) * 0.5
        last_values = inputs.series[:, -1:].unsqueeze(-1)  # (n_variates, 1, 1)
        forecast_tensor = last_values + increments.cumsum(dim=1)
        return MockForecast(forecast_tensor)


# Global variables
toto_model = None
forecaster = None


def initialize_model():
    """Lazily create the (mock) Toto model and forecaster; reuse them on later calls.

    Returns:
        ``(toto_model, forecaster)`` module-level singletons.
    """
    global toto_model, forecaster
    if toto_model is None:
        # In production, replace with: toto_model = Toto.from_pretrained('Datadog/Toto-Open-Base-1.0')
        toto_model = MockToto()
        toto_model.to("cpu")  # Use CPU for broader compatibility
        toto_model.compile()
        forecaster = MockForecaster(toto_model.model)
    return toto_model, forecaster


def load_sample_data(seed: Optional[int] = None) -> pd.DataFrame:
    """Generate a synthetic ETT-like dataset for demonstration.

    One year (2020) of 15-minute samples. Each column is a level plus a daily
    sinusoid, a slower multi-day sinusoid, and Gaussian noise.

    Args:
        seed: Optional seed for the noise generator; ``None`` gives fresh noise each call.

    Returns:
        DataFrame with a ``date`` column, one column per name in ``FEATURE_COLUMNS``,
        and an integer ``timestamp_seconds`` column (seconds since the Unix epoch).
    """
    # '15min' is the current pandas alias; '15T' is deprecated since pandas 2.2.
    dates = pd.date_range(start='2020-01-01', end='2020-12-31 23:45:00', freq='15min')
    n_points = len(dates)
    rng = np.random.default_rng(seed)

    t = np.arange(n_points)
    steps_per_day = 24 * 4
    daily = 2 * np.pi * t / steps_per_day

    def slow(days):
        # Phase of a sinusoid with a period of `days` days.
        return 2 * np.pi * t / (steps_per_day * days)

    def noise(std):
        return rng.normal(0, std, n_points)

    df = pd.DataFrame({
        'date': dates,
        'HUFL': 5 + 2 * np.sin(daily) + 0.5 * np.sin(slow(7)) + noise(0.3),
        'HULL': 4 + 1.5 * np.cos(daily) + 0.3 * np.sin(slow(30)) + noise(0.25),
        'MUFL': 6 + 1.8 * np.sin(daily) + 0.4 * np.cos(slow(7)) + noise(0.35),
        'MULL': 5.5 + 1.2 * np.cos(daily) + 0.6 * np.sin(slow(14)) + noise(0.28),
        'LUFL': 3.5 + 2.2 * np.sin(daily) + 0.8 * np.cos(slow(21)) + noise(0.32),
        'LULL': 4.2 + 1.6 * np.cos(daily) + 0.5 * np.sin(slow(10)) + noise(0.27),
        'OT': 25 + 8 * np.sin(daily) + 3 * np.cos(slow(365)) + noise(1.2),
    })
    df['timestamp_seconds'] = (df['date'] - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')
    return df


def prepare_data(df: pd.DataFrame, context_length: int, prediction_length: int) -> Tuple[MaskedTimeseries, pd.DataFrame, pd.DataFrame]:
    """Split the tail of ``df`` into a model-input window and a held-out target window.

    The last ``prediction_length`` rows become the target; the ``context_length``
    rows immediately before them become the input.

    Returns:
        ``(inputs, input_df, target_df)``. ``inputs.series`` has shape
        ``(n_variates, context_length)``.

    Raises:
        ValueError: if either length is not positive or ``df`` is too short.
    """
    feature_columns = FEATURE_COLUMNS
    n_variates = len(feature_columns)
    context_length = int(context_length)
    prediction_length = int(prediction_length)

    if context_length <= 0 or prediction_length <= 0:
        # A zero prediction_length would make the slices below select the whole frame.
        raise ValueError(
            f"context_length and prediction_length must be positive, "
            f"got {context_length} and {prediction_length}"
        )
    # Ensure we have enough data
    if len(df) < (context_length + prediction_length):
        raise ValueError(f"Dataset too small. Need at least {context_length + prediction_length} points, got {len(df)}")

    input_df = df.iloc[-(context_length + prediction_length):-prediction_length].copy()
    target_df = df.iloc[-prediction_length:].copy()

    # Transpose so each row is one variate's history.
    input_series = torch.from_numpy(input_df[feature_columns].values.T).to(torch.float)
    timestamp_seconds = torch.from_numpy(input_df['timestamp_seconds'].values).expand((n_variates, context_length))
    time_interval_seconds = torch.full((n_variates,), INTERVAL_SECONDS)

    inputs = MaskedTimeseries(
        series=input_series,
        padding_mask=torch.full_like(input_series, True, dtype=torch.bool),
        id_mask=torch.zeros_like(input_series),
        timestamp_seconds=timestamp_seconds,
        time_interval_seconds=time_interval_seconds,
    )
    return inputs, input_df, target_df


def create_forecast_plot(input_df: pd.DataFrame, target_df: pd.DataFrame, forecast, feature_columns: list) -> plt.Figure:
    """Plot history, actuals, median forecast and a quantile band for each variate.

    A standalone ``Figure`` is used instead of pyplot state. Pyplot keeps every
    figure it creates alive until explicitly closed, which leaks memory in a
    long-running server.
    """
    DARK_GREY = "#1c2b34"
    BLUE = "#3598ec"
    PURPLE = "#7463e1"
    LIGHT_PURPLE = "#d7c3ff"
    PINK = "#ff0099"

    n_variates = len(feature_columns)
    horizon = len(target_df)
    samples_obj = forecast.samples

    # Sample paths as (n_variates, horizon, num_samples); reshape rather than index a
    # squeezed array so a single variate or a single step does not collapse an axis.
    samples = np.asarray(samples_obj.squeeze().cpu().numpy()).reshape(n_variates, horizon, -1)
    medians = np.median(samples, axis=-1)

    # Keep q on the same device as the samples (torch requires it).
    device = getattr(samples_obj, "device", torch.device('cpu'))
    qs = samples_obj.quantile(q=torch.tensor([CI_ALPHA, 1 - CI_ALPHA], device=device), dim=-1)
    lower = qs[0].cpu().numpy().reshape(n_variates, horizon)
    upper = qs[1].cpu().numpy().reshape(n_variates, horizon)
    ci_label = f'{int(round((1 - 2 * CI_ALPHA) * 100))}% CI'

    context_points = min(960, len(input_df))
    history_dates = input_df["date"].iloc[-context_points:]
    x_start = input_df["date"].iloc[-context_points]
    x_end = target_df["date"].iloc[-1]
    forecast_start = target_df["date"].iloc[0]

    fig = plt.Figure(figsize=(16, 12), dpi=100)
    fig.suptitle("Toto Time Series Forecasts", fontsize=16, fontweight='bold')

    for i, feature in enumerate(feature_columns):
        ax = fig.add_subplot(n_variates, 1, i + 1)
        first = i == 0
        if i != n_variates - 1:
            # Only the bottom panel shows date labels.
            ax.tick_params(axis="x", labelbottom=False)
        ax.tick_params(axis="x", color=DARK_GREY, labelcolor=DARK_GREY)
        ax.tick_params(axis="y", color=DARK_GREY, labelcolor=DARK_GREY)
        ax.set_ylabel(feature, rotation=0, ha='right', va='center')
        ax.set_xlim(x_start, x_end)

        # Vertical line separating context and forecast
        ax.axvline(forecast_start, color=PINK, linestyle=":", alpha=0.8, linewidth=2)
        ax.plot(history_dates, input_df[feature].iloc[-context_points:],
                color=BLUE, linewidth=1.5, label='Historical' if first else None)
        ax.plot(target_df["date"], target_df[feature],
                color=BLUE, linewidth=1.5, alpha=0.7, label='Actual' if first else None)
        ax.plot(target_df["date"], medians[i],
                color=PURPLE, linestyle="--", linewidth=2, label='Forecast' if first else None)
        ax.fill_between(target_df["date"], lower[i], upper[i],
                        color=LIGHT_PURPLE, alpha=0.6, label=ci_label if first else None)

        if first:
            ax.legend(loc='upper left', frameon=True, fancybox=True, shadow=True)

    fig.tight_layout()
    return fig


def run_forecast(context_length: int, prediction_length: int, num_samples: int, samples_per_batch: int, use_kv_cache: bool, progress=gr.Progress()) -> Tuple[plt.Figure, str]:
    """Run forecasting with the given UI parameters.

    Returns:
        ``(figure, markdown_summary)``. On any failure, a figure showing the error
        text and the error message itself; the traceback is logged.
    """
    try:
        # Slider values may arrive as floats; slicing and tensor shapes need ints.
        context_length = int(context_length)
        prediction_length = int(prediction_length)
        num_samples = int(num_samples)
        samples_per_batch = int(samples_per_batch)

        progress(0.1, desc="Initializing model...")
        _, active_forecaster = initialize_model()

        progress(0.2, desc="Loading data...")
        df = load_sample_data()

        progress(0.3, desc="Preparing data...")
        inputs, input_df, target_df = prepare_data(df, context_length, prediction_length)

        progress(0.5, desc="Running forecast...")
        forecast = active_forecaster.forecast(
            inputs,
            prediction_length=prediction_length,
            num_samples=num_samples,
            samples_per_batch=min(samples_per_batch, num_samples),
            use_kv_cache=use_kv_cache,
        )

        progress(0.8, desc="Creating visualization...")
        feature_columns = FEATURE_COLUMNS
        fig = create_forecast_plot(input_df, target_df, forecast, feature_columns)

        progress(1.0, desc="Complete!")

        # Generate summary statistics
        forecast_data = forecast.samples.squeeze().cpu().numpy()
        # Spread over the forecast horizon (axis 1), averaged over variates and samples.
        mean_range = np.mean(np.max(forecast_data, axis=1) - np.min(forecast_data, axis=1))
        ci_percent = int(round((1 - 2 * CI_ALPHA) * 100))
        summary = "\n".join([
            "## Forecast Summary",
            "",
            "**Parameters Used:**",
            f"- Context Length: {context_length} time steps",
            f"- Prediction Length: {prediction_length} time steps",
            f"- Number of Samples: {num_samples}",
            f"- Samples per Batch: {samples_per_batch}",
            f"- KV Cache: {'Enabled' if use_kv_cache else 'Disabled'}",
            "",
            "**Results:**",
            f"- Variables Forecasted: {len(feature_columns)}",
            f"- Forecast Shape: {forecast_data.shape}",
            f"- Mean Absolute Forecast Range: {mean_range:.3f}",
            "",
            "The plot shows historical data in blue, actual values in the forecast period in "
            "light blue, median forecasts as purple dashed lines, and "
            f"{ci_percent}% confidence intervals in light purple.",
        ])
        return fig, summary

    except Exception as e:
        # UI boundary: report the error in the app, but keep the traceback in the logs.
        logger.exception("Forecasting failed")
        error_msg = f"Error during forecasting: {str(e)}"
        fig = plt.Figure(figsize=(10, 6))
        ax = fig.add_subplot()
        ax.text(0.5, 0.5, error_msg, ha='center', va='center', fontsize=12, color='red')
        ax.axis('off')
        return fig, error_msg


# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI: parameter controls on the left, plot and summary on the right."""
    with gr.Blocks(title="Toto Time Series Forecasting", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🔮 Toto Time Series Forecasting

        This app demonstrates zero-shot time series forecasting using the Toto foundation model. Adjust the parameters below to customize your forecast and see how different settings affect the predictions.

        **Note:** This demo uses synthetic ETT-like data for illustration purposes.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Forecasting Parameters")
                context_length = gr.Slider(
                    minimum=96, maximum=2048, value=DEFAULT_CONTEXT_LENGTH, step=32,
                    label="Context Length",
                    info="Number of historical time steps to use as input"
                )
                prediction_length = gr.Slider(
                    minimum=24, maximum=720, value=DEFAULT_PREDICTION_LENGTH, step=24,
                    label="Prediction Length",
                    info="Number of time steps to forecast into the future"
                )
                num_samples = gr.Slider(
                    minimum=8, maximum=512, value=DEFAULT_NUM_SAMPLES, step=8,
                    label="Number of Samples",
                    info="More samples = more stable predictions but slower inference"
                )
                samples_per_batch = gr.Slider(
                    minimum=8, maximum=256, value=DEFAULT_SAMPLES_PER_BATCH, step=8,
                    label="Samples per Batch",
                    info="Batch size for sample generation (affects memory usage)"
                )
                use_kv_cache = gr.Checkbox(
                    value=DEFAULT_USE_KV_CACHE,
                    label="Use KV Cache",
                    info="Enable key-value caching for faster inference"
                )
                forecast_btn = gr.Button("🚀 Run Forecast", variant="primary", size="lg")

            with gr.Column(scale=2):
                gr.Markdown("### Forecast Results")
                forecast_plot = gr.Plot()
                forecast_summary = gr.Markdown()

        # Event handlers
        forecast_btn.click(
            fn=run_forecast,
            inputs=[context_length, prediction_length, num_samples, samples_per_batch, use_kv_cache],
            outputs=[forecast_plot, forecast_summary]
        )

        # Load initial forecast with the same defaults the controls start at.
        demo.load(
            fn=lambda: run_forecast(
                DEFAULT_CONTEXT_LENGTH,
                DEFAULT_PREDICTION_LENGTH,
                DEFAULT_NUM_SAMPLES,
                DEFAULT_SAMPLES_PER_BATCH,
                DEFAULT_USE_KV_CACHE,
            ),
            outputs=[forecast_plot, forecast_summary]
        )

    return demo


# For deployment
if __name__ == "__main__":
    # Create and launch the interface
    demo = create_interface()

    if os.getenv("GRADIO_DEV"):
        # For local development
        demo.launch(debug=True, share=False)
    else:
        # For production deployment.
        # NOTE(review): share=True additionally opens a public gradio.live tunnel on top of
        # binding 0.0.0.0 — confirm that is intended outside of quick demos.
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True)


# For Modal.com deployment (sketch; verify against current Modal docs), e.g. modal_app.py:
#
#     import modal
#
#     image = modal.Image.debian_slim().pip_install([
#         "gradio", "torch", "numpy", "pandas", "matplotlib", "transformers",
#         # Add other required packages
#     ])
#     app = modal.App("toto-forecasting")
#
#     @app.function(image=image, gpu="T4")
#     def run_gradio():
#         from app import create_interface  # adjust to this file's module name
#         demo = create_interface()
#         demo.launch(server_name="0.0.0.0", server_port=8000, share=False)
#
#     if __name__ == "__main__":
#         with app.run():
#             run_gradio.remote()  # .remote() runs the function in the Modal container
#
# For Hugging Face Spaces deployment, create these files:
#   1. app.py (this file)
#   2. requirements.txt:
#        gradio
#        torch
#        numpy
#        pandas
#        matplotlib
#        transformers
#   3. README.md with your Space description