Spaces:
Sleeping
Sleeping
import base64
import io
import logging
import os
import warnings
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
# Suppress every warning category for the whole process. Note that this also
# hides library deprecation notices (e.g. pandas/torch FutureWarnings).
warnings.simplefilter("ignore")
# Mock stand-ins for the model-side imports. In an actual deployment these
# names would be imported from the original Toto modules instead.
class MaskedTimeseries:
    """Plain container for the tensors handed to the forecaster.

    Every constructor argument is stored unchanged on an attribute of the
    same name; no validation or conversion is performed.
    """

    def __init__(self, series, padding_mask, id_mask, timestamp_seconds, time_interval_seconds):
        (
            self.series,
            self.padding_mask,
            self.id_mask,
            self.timestamp_seconds,
            self.time_interval_seconds,
        ) = (series, padding_mask, id_mask, timestamp_seconds, time_interval_seconds)
class MockToto:
    """Mock Toto model for demonstration.

    Provides the small surface the app touches: a ``model`` attribute plus
    ``from_pretrained``, ``to`` and ``compile``. None of them do any work;
    ``to``/``compile`` return the instance itself so calls can be chained.
    """

    def __init__(self):
        # Callers read ``.model`` to get the object passed to the forecaster;
        # the mock simply stands in for itself.
        self.model = self

    @classmethod
    def from_pretrained(cls, model_name):
        """Return a fresh mock instance; *model_name* is accepted and ignored.

        Must be a classmethod so ``MockToto.from_pretrained(name)`` works like
        a typical ``Model.from_pretrained(name)`` call. Without the decorator
        the class-level call bound *name* to ``cls`` and raised ``TypeError``.
        """
        return cls()

    def to(self, device):
        """No-op device move; returns ``self``."""
        return self

    def compile(self):
        """No-op compilation step; returns ``self``."""
        return self
class MockForecaster:
    """Mock forecaster for demonstration.

    Produces random-walk sample paths instead of running a model. Only
    ``inputs.series`` is read; ``samples_per_batch`` and ``use_kv_cache`` are
    accepted for call compatibility and otherwise ignored.
    """

    class _Forecast:
        """Minimal result object exposing the generated ``samples`` tensor."""

        def __init__(self, samples):
            self.samples = samples

    def __init__(self, model):
        self.model = model

    def forecast(self, inputs, prediction_length, num_samples, samples_per_batch, use_kv_cache=True):
        """Generate synthetic forecast samples.

        Args:
            inputs: Object with a ``series`` tensor indexed ``[variate, time]``.
            prediction_length: Number of future steps per sample path.
            num_samples: Number of sample paths per variate.
            samples_per_batch: Ignored by the mock.
            use_kv_cache: Ignored by the mock.

        Returns:
            An object whose ``samples`` is a ``torch.Tensor`` of shape
            ``(n_variates, prediction_length, num_samples)``. Each path is a
            Gaussian random walk that starts at the variate's last observed
            value. Its per-step increment is ``N(0, 0.1**2) + N(0, 0.5**2)``,
            matching the original per-step loop.

            Because ``samples`` is a real tensor, ``samples.quantile(q, dim=-1)``
            has standard ``torch.quantile`` semantics: shape
            ``(len(q), n_variates, prediction_length)``, with ``[k]`` being the
            ``q[k]`` quantile. The earlier gather-based wrapper permuted its index
            and read the wrong variate/time entries.
        """
        series = inputs.series
        n_variates = series.shape[0]
        shape = (n_variates, int(prediction_length), int(num_samples))
        # Draw every increment at once instead of looping over samples x steps.
        steps = (torch.randn(shape, device=series.device) * 0.1
                 + torch.randn(shape, device=series.device) * 0.5)
        # Cumulative sum over the time axis turns increments into walk positions.
        start = series[:, -1].reshape(n_variates, 1, 1)
        return self._Forecast(start + steps.cumsum(dim=1))
# Module-level cache: populated on the first initialize_model() call and
# reused by every later call in this process.
toto_model = None
forecaster = None


def initialize_model():
    """Return the cached ``(toto_model, forecaster)`` pair, creating it on first use.

    The first call builds a ``MockToto``, moves it to CPU, compiles it and wraps
    its ``.model`` in a ``MockForecaster``. Later calls return the same objects.
    """
    global toto_model, forecaster
    if toto_model is not None:
        return toto_model, forecaster
    # In production, replace with: toto_model = Toto.from_pretrained('Datadog/Toto-Open-Base-1.0')
    toto_model = MockToto()
    toto_model.to("cpu")  # CPU keeps the demo runnable on machines without a GPU
    toto_model.compile()
    forecaster = MockForecaster(toto_model.model)
    return toto_model, forecaster
def load_sample_data(seed: Optional[int] = None):
    """Generate a synthetic, ETT-like multivariate series for the demo.

    Covers the year 2020 (a leap year) at 15-minute resolution, giving
    366 * 96 = 35,136 rows. Each channel is the sum of four parts:
    - a constant level;
    - a daily sinusoid;
    - a slower sinusoid whose period is a number of days;
    - i.i.d. Gaussian noise.

    Args:
        seed: Optional seed for the noise generator. ``None`` (the default)
            draws fresh noise on every call; pass an int for a reproducible
            dataset.

    Returns:
        DataFrame with columns ``date``, ``HUFL``, ``HULL``, ``MUFL``, ``MULL``,
        ``LUFL``, ``LULL``, ``OT`` and ``timestamp_seconds``. The last is
        integer seconds since 1970-01-01, computed from the naive ``date``.
    """
    # '15min' instead of '15T': the 'T' alias is deprecated since pandas 2.2.
    dates = pd.date_range(start='2020-01-01', end='2020-12-31 23:45:00', freq='15min')
    n_points = len(dates)
    rng = np.random.default_rng(seed)
    t = np.arange(n_points)
    steps_per_day = 24 * 4  # 15-minute samples per day

    def phase(period_days):
        # Angle argument of a sinusoid with a period of `period_days` days.
        return 2 * np.pi * t / (steps_per_day * period_days)

    def noise(sd):
        return rng.normal(0, sd, n_points)

    channels = {
        'HUFL': 5 + 2 * np.sin(phase(1)) + 0.5 * np.sin(phase(7)) + noise(0.3),
        'HULL': 4 + 1.5 * np.cos(phase(1)) + 0.3 * np.sin(phase(30)) + noise(0.25),
        'MUFL': 6 + 1.8 * np.sin(phase(1)) + 0.4 * np.cos(phase(7)) + noise(0.35),
        'MULL': 5.5 + 1.2 * np.cos(phase(1)) + 0.6 * np.sin(phase(14)) + noise(0.28),
        'LUFL': 3.5 + 2.2 * np.sin(phase(1)) + 0.8 * np.cos(phase(21)) + noise(0.32),
        'LULL': 4.2 + 1.6 * np.cos(phase(1)) + 0.5 * np.sin(phase(10)) + noise(0.27),
        'OT': 25 + 8 * np.sin(phase(1)) + 3 * np.cos(phase(365)) + noise(1.2),
    }
    df = pd.DataFrame({'date': dates, **channels})
    df['timestamp_seconds'] = (df['date'] - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')
    return df
def prepare_data(df: pd.DataFrame, context_length: int, prediction_length: int) -> Tuple[MaskedTimeseries, pd.DataFrame, pd.DataFrame]:
    """Split *df* into a model context window and a held-out target window.

    The last ``prediction_length`` rows become ``target_df``. The
    ``context_length`` rows immediately before them become ``input_df`` and
    are packed into a ``MaskedTimeseries``.

    Args:
        df: Frame containing the seven feature columns and ``timestamp_seconds``.
        context_length: Number of history rows fed to the model (must be >= 1).
        prediction_length: Number of held-out rows to forecast (must be >= 1).

    Returns:
        ``(inputs, input_df, target_df)``. ``inputs.series`` is a float tensor
        indexed ``[variate, time]``.

    Raises:
        ValueError: If a length is not positive, or *df* has fewer than
            ``context_length + prediction_length`` rows.
    """
    feature_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
    n_variates = len(feature_columns)
    interval = 60 * 15  # 15-minute sampling interval, in seconds

    # UI sliders may hand over floats; positional slicing needs ints.
    context_length = int(context_length)
    prediction_length = int(prediction_length)
    if context_length < 1 or prediction_length < 1:
        raise ValueError(
            f"context_length and prediction_length must be positive, "
            f"got {context_length} and {prediction_length}"
        )
    if len(df) < (context_length + prediction_length):
        raise ValueError(f"Dataset too small. Need at least {context_length + prediction_length} points, got {len(df)}")

    # Slice with explicit non-negative bounds; the previous `[...:-prediction_length]`
    # form silently broke when prediction_length was 0 (since -0 == 0).
    split = len(df) - prediction_length
    input_df = df.iloc[split - context_length:split].copy()
    target_df = df.iloc[split:].copy()

    input_series = torch.from_numpy(input_df[feature_columns].values.T).to(torch.float)
    # One shared timestamp row, broadcast (as a view) to every variate.
    timestamp_seconds = torch.from_numpy(input_df.timestamp_seconds.values).expand((n_variates, context_length))
    time_interval_seconds = torch.full((n_variates,), interval)

    inputs = MaskedTimeseries(
        series=input_series,
        padding_mask=torch.full_like(input_series, True, dtype=torch.bool),  # every step is observed
        id_mask=torch.zeros_like(input_series),
        timestamp_seconds=timestamp_seconds,
        time_interval_seconds=time_interval_seconds,
    )
    return inputs, input_df, target_df
def create_forecast_plot(input_df: pd.DataFrame, target_df: pd.DataFrame, forecast, feature_columns: list) -> plt.Figure:
    """Plot context, ground truth, median forecast and a 90% band per variate.

    Args:
        input_df: Context window with a ``date`` column and every column named
            in *feature_columns*.
        target_df: Held-out window covering the forecast horizon, same columns.
        forecast: Object whose ``samples.squeeze().cpu().numpy()`` yields an
            array indexed ``[variate, time, sample]``. This is assumed, based on
            how the median and quantiles are read below.
        feature_columns: Column names, in the order of the forecast's variate axis.

    Returns:
        A ``matplotlib.figure.Figure`` with one row per variate. It is built
        directly rather than via ``plt.figure()``, so pyplot keeps no reference
        to it and repeated calls do not accumulate open figures.
    """
    DARK_GREY = "#1c2b34"
    BLUE = "#3598ec"
    PURPLE = "#7463e1"
    LIGHT_PURPLE = "#d7c3ff"
    PINK = "#ff0099"
    alpha = 0.05  # tail mass on each side -> central (1 - 2*alpha) band

    # Summarize the sample paths once for all variates, from the raw samples.
    # This avoids depending on a forecast object's own quantile implementation.
    samples = forecast.samples.squeeze().cpu().numpy()
    median = np.median(samples, axis=-1)
    lower, upper = np.quantile(samples, [alpha, 1 - alpha], axis=-1)

    n_variates = len(feature_columns)
    context_points = min(960, len(input_df))  # show at most the last 960 context steps
    hist_dates = input_df["date"].iloc[-context_points:]
    forecast_dates = target_df["date"]

    fig = plt.Figure(figsize=(16, 12), dpi=100)
    fig.suptitle("Toto Time Series Forecasts", fontsize=16, fontweight='bold')
    # sharex keeps every row on one time axis and labels only the bottom row.
    axes = fig.subplots(n_variates, 1, sharex=True, squeeze=False)[:, 0]

    for i, (ax, feature) in enumerate(zip(axes, feature_columns)):
        ax.tick_params(axis="x", color=DARK_GREY, labelcolor=DARK_GREY)
        ax.tick_params(axis="y", color=DARK_GREY, labelcolor=DARK_GREY)
        ax.set_ylabel(feature, rotation=0, ha='right', va='center')
        ax.set_xlim(hist_dates.iloc[0], forecast_dates.iloc[-1])

        # Vertical line separating context and forecast horizon.
        ax.axvline(forecast_dates.iloc[0], color=PINK, linestyle=":", alpha=0.8, linewidth=2)
        first = i == 0  # legend entries are attached to the top row only
        ax.plot(hist_dates, input_df[feature].iloc[-context_points:],
                color=BLUE, linewidth=1.5, label='Historical' if first else None)
        ax.plot(forecast_dates, target_df[feature], color=BLUE, linewidth=1.5, alpha=0.7,
                label='Actual' if first else None)
        ax.plot(forecast_dates, median[i], color=PURPLE, linestyle="--", linewidth=2,
                label='Forecast' if first else None)
        ax.fill_between(
            forecast_dates,
            lower[i],
            upper[i],
            color=LIGHT_PURPLE,
            alpha=0.6,
            label=f'{int((1 - 2 * alpha) * 100)}% CI' if first else None,
        )
        if first:
            ax.legend(loc='upper left', frameon=True, fancybox=True, shadow=True)

    fig.tight_layout()
    return fig
def _format_forecast_summary(context_length: int, prediction_length: int, num_samples: int,
                             samples_per_batch: int, use_kv_cache: bool, n_variables: int,
                             forecast_data: np.ndarray) -> str:
    """Render the Markdown summary shown next to the forecast plot."""
    # Spread (max - min over the horizon, axis 1) of each (variate, sample)
    # path, averaged over variates and samples.
    mean_range = np.mean(np.max(forecast_data, axis=1) - np.min(forecast_data, axis=1))
    lines = [
        "## Forecast Summary",
        "",
        "**Parameters Used:**",
        "",
        f"- Context Length: {context_length} time steps",
        f"- Prediction Length: {prediction_length} time steps",
        f"- Number of Samples: {num_samples}",
        f"- Samples per Batch: {samples_per_batch}",
        f"- KV Cache: {'Enabled' if use_kv_cache else 'Disabled'}",
        "",
        "**Results:**",
        "",
        f"- Variables Forecasted: {n_variables}",
        f"- Forecast Shape: {forecast_data.shape}",
        f"- Mean Absolute Forecast Range: {mean_range:.3f}",
        "",
        # 90% matches the plot's band (alpha = 0.05 in each tail); the text
        # previously claimed 95%.
        "The plot shows historical data in blue, actual values in the forecast period in lighter blue, "
        "median forecasts as purple dashed lines, and 90% intervals in light purple.",
    ]
    return "\n".join(lines)


def run_forecast(context_length: int, prediction_length: int, num_samples: int,
                 samples_per_batch: int, use_kv_cache: bool, progress=gr.Progress()) -> Tuple[plt.Figure, str]:
    """Run one end-to-end forecast and return ``(figure, markdown_summary)``.

    Used as a Gradio event handler; *progress* is the Gradio progress tracker.
    Any exception is logged with its traceback. It is then reported to the UI
    as a figure containing the message plus the same message as text, so a
    failed run does not crash the handler.
    """
    try:
        # Slider values can arrive as floats; slicing and sampling need ints.
        context_length = int(context_length)
        prediction_length = int(prediction_length)
        num_samples = int(num_samples)
        # Never request a batch larger than the total number of samples.
        effective_batch = min(int(samples_per_batch), num_samples)

        progress(0.1, desc="Initializing model...")
        _, forecaster = initialize_model()
        progress(0.2, desc="Loading data...")
        df = load_sample_data()
        progress(0.3, desc="Preparing data...")
        inputs, input_df, target_df = prepare_data(df, context_length, prediction_length)
        progress(0.5, desc="Running forecast...")
        forecast = forecaster.forecast(
            inputs,
            prediction_length=prediction_length,
            num_samples=num_samples,
            samples_per_batch=effective_batch,
            use_kv_cache=use_kv_cache,
        )

        progress(0.8, desc="Creating visualization...")
        feature_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
        fig = create_forecast_plot(input_df, target_df, forecast, feature_columns)
        forecast_data = forecast.samples.squeeze().cpu().numpy()
        summary = _format_forecast_summary(
            context_length, prediction_length, num_samples, effective_batch,
            use_kv_cache, len(feature_columns), forecast_data,
        )
        progress(1.0, desc="Complete!")
        return fig, summary
    except Exception as e:
        # UI boundary: keep the traceback for operators, show a short message to users.
        logging.getLogger(__name__).exception("Forecast failed")
        error_msg = f"Error during forecasting: {str(e)}"
        # Built without pyplot so error figures are not left registered with pyplot.
        fig = plt.Figure(figsize=(10, 6))
        fig.text(0.5, 0.5, error_msg, ha='center', va='center', fontsize=12, color='red')
        return fig, error_msg
# Create Gradio interface
def create_interface():
    """Build (but do not launch) the Gradio Blocks UI for the forecasting demo.

    The same parameter components feed both the "Run Forecast" button and the
    initial on-load forecast. The first plot therefore always reflects the
    sliders' current values instead of a second, hard-coded copy of the defaults.

    Returns:
        The ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Toto Time Series Forecasting", theme=gr.themes.Soft()) as demo:
        # Built from flush-left pieces so no Markdown line is accidentally
        # indented into a code block.
        gr.Markdown(
            "# 🔮 Toto Time Series Forecasting\n\n"
            "This app demonstrates zero-shot time series forecasting using the Toto foundation model. "
            "Adjust the parameters below to customize your forecast and see how different settings "
            "affect the predictions.\n\n"
            "**Note:** This demo uses synthetic ETT-like data for illustration purposes."
        )
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Forecasting Parameters")
                context_length = gr.Slider(
                    minimum=96, maximum=2048, value=512, step=32,
                    label="Context Length",
                    info="Number of historical time steps to use as input",
                )
                prediction_length = gr.Slider(
                    minimum=24, maximum=720, value=96, step=24,
                    label="Prediction Length",
                    info="Number of time steps to forecast into the future",
                )
                num_samples = gr.Slider(
                    minimum=8, maximum=512, value=64, step=8,
                    label="Number of Samples",
                    info="More samples = more stable predictions but slower inference",
                )
                samples_per_batch = gr.Slider(
                    minimum=8, maximum=256, value=32, step=8,
                    label="Samples per Batch",
                    info="Batch size for sample generation (affects memory usage)",
                )
                use_kv_cache = gr.Checkbox(
                    value=True,
                    label="Use KV Cache",
                    info="Enable key-value caching for faster inference",
                )
                forecast_btn = gr.Button("🚀 Run Forecast", variant="primary", size="lg")
            with gr.Column(scale=2):
                gr.Markdown("### Forecast Results")
                forecast_plot = gr.Plot()
                forecast_summary = gr.Markdown()

        parameter_inputs = [context_length, prediction_length, num_samples, samples_per_batch, use_kv_cache]
        result_outputs = [forecast_plot, forecast_summary]

        # Event handlers
        forecast_btn.click(fn=run_forecast, inputs=parameter_inputs, outputs=result_outputs)
        # Initial forecast on page load, driven by the components' current values.
        # run_forecast is registered directly, so its gr.Progress default is active here too.
        demo.load(fn=run_forecast, inputs=parameter_inputs, outputs=result_outputs)
    return demo
# For deployment
if __name__ == "__main__":
    # Create and launch the interface
    demo = create_interface()
    if os.getenv("GRADIO_DEV"):
        # Local development: verbose debug output, never a public link.
        demo.launch(debug=True, share=False)
    else:
        # Production: listen on all interfaces on port 7860. A public
        # gradio.live share link is opt-in (GRADIO_SHARE=1) rather than always
        # on, because it exposes this unauthenticated app to anyone with the URL.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=os.getenv("GRADIO_SHARE") == "1",
        )
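# For Modal.com deployment, save a separate launcher such as the sketch below
# (kept in a string so it is not executed as part of this app):
"""
# modal_app.py
import modal

image = modal.Image.debian_slim().pip_install([
    "gradio",
    "torch",
    "numpy",
    "pandas",
    "matplotlib",
    "transformers",
    # Add other required packages
])

app = modal.App("toto-forecasting")


@app.function(image=image, gpu="T4")
def run_gradio():
    # Import from this file, saved as app.py; make sure it is also shipped into the image.
    from app import create_interface
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=8000, share=False)


if __name__ == "__main__":
    with app.run():
        # .remote() executes the function on Modal's infrastructure.
        run_gradio.remote()

# Note: to serve the UI at a public URL, Modal needs the function exposed as a
# web endpoint (see Modal's web endpoint documentation).
"""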
# For Hugging Face Spaces deployment: | |
""" | |
Create these files: | |
1. app.py (this file) | |
2. requirements.txt: | |
gradio | |
torch | |
numpy | |
pandas | |
matplotlib | |
transformers | |
3. README.md with your Space description | |
""" |