Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gradio/certificate.pem +31 -0
- ETTh1.csv +0 -0
- ETTh2.csv +0 -0
- ETTm1.csv +0 -0
- ETTm2.csv +0 -0
- README.md +2 -8
- gradio_modal.py +410 -0
- inference_tutorial.ipynb +0 -0
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
ETTh1.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ETTh2.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ETTm1.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ETTm2.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: yellow
|
5 |
-
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.31.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: timeseries
|
3 |
+
app_file: gradio_modal.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 5.31.0
|
|
|
|
|
6 |
---
|
|
|
|
gradio_modal.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
from datetime import datetime, timedelta
|
8 |
+
import io
|
9 |
+
import base64
|
10 |
+
from typing import Optional, Tuple, Dict, Any
|
11 |
+
import warnings
|
12 |
+
warnings.filterwarnings('ignore')
|
13 |
+
|
14 |
+
# Mock implementations for the original imports
|
15 |
+
# In actual deployment, you'd import these from the original modules
|
16 |
+
class MaskedTimeseries:
|
17 |
+
def __init__(self, series, padding_mask, id_mask, timestamp_seconds, time_interval_seconds):
|
18 |
+
self.series = series
|
19 |
+
self.padding_mask = padding_mask
|
20 |
+
self.id_mask = id_mask
|
21 |
+
self.timestamp_seconds = timestamp_seconds
|
22 |
+
self.time_interval_seconds = time_interval_seconds
|
23 |
+
|
24 |
+
class MockToto:
|
25 |
+
"""Mock Toto model for demonstration"""
|
26 |
+
def __init__(self):
|
27 |
+
self.model = self
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def from_pretrained(cls, model_name):
|
31 |
+
return cls()
|
32 |
+
|
33 |
+
def to(self, device):
|
34 |
+
return self
|
35 |
+
|
36 |
+
def compile(self):
|
37 |
+
return self
|
38 |
+
|
39 |
+
class MockForecaster:
|
40 |
+
"""Mock forecaster for demonstration"""
|
41 |
+
def __init__(self, model):
|
42 |
+
self.model = model
|
43 |
+
|
44 |
+
def forecast(self, inputs, prediction_length, num_samples, samples_per_batch, use_kv_cache=True):
|
45 |
+
# Generate mock forecast data
|
46 |
+
n_variates, context_length = inputs.series.shape
|
47 |
+
|
48 |
+
# Create realistic-looking synthetic forecasts
|
49 |
+
samples = []
|
50 |
+
for _ in range(num_samples):
|
51 |
+
# Use last values as starting point and add some trend/noise
|
52 |
+
last_values = inputs.series[:, -1:]
|
53 |
+
forecast_sample = []
|
54 |
+
|
55 |
+
for t in range(prediction_length):
|
56 |
+
# Add some trend and noise
|
57 |
+
trend = torch.randn(n_variates, 1) * 0.1
|
58 |
+
noise = torch.randn(n_variates, 1) * 0.5
|
59 |
+
next_val = last_values + trend + noise
|
60 |
+
forecast_sample.append(next_val)
|
61 |
+
last_values = next_val
|
62 |
+
|
63 |
+
sample = torch.cat(forecast_sample, dim=1)
|
64 |
+
samples.append(sample)
|
65 |
+
|
66 |
+
# Stack samples along a new dimension
|
67 |
+
forecast_tensor = torch.stack(samples, dim=-1) # shape: (n_variates, prediction_length, num_samples)
|
68 |
+
|
69 |
+
class MockForecast:
|
70 |
+
def __init__(self, samples):
|
71 |
+
self.samples = MockSamples(samples)
|
72 |
+
|
73 |
+
class MockSamples:
|
74 |
+
def __init__(self, tensor):
|
75 |
+
self.tensor = tensor
|
76 |
+
|
77 |
+
def squeeze(self):
|
78 |
+
return self.tensor
|
79 |
+
|
80 |
+
def cpu(self):
|
81 |
+
return self.tensor
|
82 |
+
|
83 |
+
def quantile(self, q, dim):
|
84 |
+
# Calculate quantiles along the specified dimension
|
85 |
+
sorted_tensor = torch.sort(self.tensor, dim=dim)[0]
|
86 |
+
indices = (q.unsqueeze(0).unsqueeze(0) * (self.tensor.shape[dim] - 1)).long()
|
87 |
+
return torch.gather(sorted_tensor, dim, indices.expand(sorted_tensor.shape[0], sorted_tensor.shape[1], -1).permute(2, 0, 1))
|
88 |
+
|
89 |
+
return MockForecast(forecast_tensor)
|
90 |
+
|
91 |
+
# Global variables
|
92 |
+
toto_model = None
|
93 |
+
forecaster = None
|
94 |
+
|
95 |
+
def initialize_model():
|
96 |
+
"""Initialize the Toto model"""
|
97 |
+
global toto_model, forecaster
|
98 |
+
|
99 |
+
if toto_model is None:
|
100 |
+
# In production, replace with: toto_model = Toto.from_pretrained('Datadog/Toto-Open-Base-1.0')
|
101 |
+
toto_model = MockToto()
|
102 |
+
toto_model.to("cpu") # Use CPU for broader compatibility
|
103 |
+
toto_model.compile()
|
104 |
+
|
105 |
+
forecaster = MockForecaster(toto_model.model)
|
106 |
+
|
107 |
+
return toto_model, forecaster
|
108 |
+
|
109 |
+
def load_sample_data():
|
110 |
+
"""Load sample ETT data for demonstration"""
|
111 |
+
# Generate synthetic ETT-like data
|
112 |
+
dates = pd.date_range(start='2020-01-01', end='2020-12-31 23:45:00', freq='15T')
|
113 |
+
n_points = len(dates)
|
114 |
+
|
115 |
+
# Create synthetic multivariate time series
|
116 |
+
t = np.arange(n_points)
|
117 |
+
|
118 |
+
# Base patterns with different frequencies and amplitudes
|
119 |
+
hufl = 5 + 2 * np.sin(2 * np.pi * t / (24 * 4)) + 0.5 * np.sin(2 * np.pi * t / (24 * 4 * 7)) + np.random.normal(0, 0.3, n_points)
|
120 |
+
hull = 4 + 1.5 * np.cos(2 * np.pi * t / (24 * 4)) + 0.3 * np.sin(2 * np.pi * t / (24 * 4 * 30)) + np.random.normal(0, 0.25, n_points)
|
121 |
+
mufl = 6 + 1.8 * np.sin(2 * np.pi * t / (24 * 4)) + 0.4 * np.cos(2 * np.pi * t / (24 * 4 * 7)) + np.random.normal(0, 0.35, n_points)
|
122 |
+
mull = 5.5 + 1.2 * np.cos(2 * np.pi * t / (24 * 4)) + 0.6 * np.sin(2 * np.pi * t / (24 * 4 * 14)) + np.random.normal(0, 0.28, n_points)
|
123 |
+
lufl = 3.5 + 2.2 * np.sin(2 * np.pi * t / (24 * 4)) + 0.8 * np.cos(2 * np.pi * t / (24 * 4 * 21)) + np.random.normal(0, 0.32, n_points)
|
124 |
+
lull = 4.2 + 1.6 * np.cos(2 * np.pi * t / (24 * 4)) + 0.5 * np.sin(2 * np.pi * t / (24 * 4 * 10)) + np.random.normal(0, 0.27, n_points)
|
125 |
+
ot = 25 + 8 * np.sin(2 * np.pi * t / (24 * 4)) + 3 * np.cos(2 * np.pi * t / (24 * 4 * 365)) + np.random.normal(0, 1.2, n_points)
|
126 |
+
|
127 |
+
df = pd.DataFrame({
|
128 |
+
'date': dates,
|
129 |
+
'HUFL': hufl,
|
130 |
+
'HULL': hull,
|
131 |
+
'MUFL': mufl,
|
132 |
+
'MULL': mull,
|
133 |
+
'LUFL': lufl,
|
134 |
+
'LULL': lull,
|
135 |
+
'OT': ot
|
136 |
+
})
|
137 |
+
|
138 |
+
df['timestamp_seconds'] = (df['date'] - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')
|
139 |
+
|
140 |
+
return df
|
141 |
+
|
142 |
+
def prepare_data(df: pd.DataFrame, context_length: int, prediction_length: int) -> Tuple[MaskedTimeseries, pd.DataFrame, pd.DataFrame]:
|
143 |
+
"""Prepare data for Toto model"""
|
144 |
+
feature_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
|
145 |
+
n_variates = len(feature_columns)
|
146 |
+
interval = 60 * 15 # 15-min intervals
|
147 |
+
|
148 |
+
# Ensure we have enough data
|
149 |
+
if len(df) < (context_length + prediction_length):
|
150 |
+
raise ValueError(f"Dataset too small. Need at least {context_length + prediction_length} points, got {len(df)}")
|
151 |
+
|
152 |
+
input_df = df.iloc[-(context_length + prediction_length):-prediction_length].copy()
|
153 |
+
target_df = df.iloc[-prediction_length:].copy()
|
154 |
+
|
155 |
+
input_series = torch.from_numpy(input_df[feature_columns].values.T).to(torch.float)
|
156 |
+
timestamp_seconds = torch.from_numpy(input_df.timestamp_seconds.values).expand((n_variates, context_length))
|
157 |
+
time_interval_seconds = torch.full((n_variates,), interval)
|
158 |
+
|
159 |
+
inputs = MaskedTimeseries(
|
160 |
+
series=input_series,
|
161 |
+
padding_mask=torch.full_like(input_series, True, dtype=torch.bool),
|
162 |
+
id_mask=torch.zeros_like(input_series),
|
163 |
+
timestamp_seconds=timestamp_seconds,
|
164 |
+
time_interval_seconds=time_interval_seconds,
|
165 |
+
)
|
166 |
+
|
167 |
+
return inputs, input_df, target_df
|
168 |
+
|
169 |
+
def create_forecast_plot(input_df: pd.DataFrame, target_df: pd.DataFrame, forecast, feature_columns: list) -> plt.Figure:
|
170 |
+
"""Create forecast visualization"""
|
171 |
+
DARK_GREY = "#1c2b34"
|
172 |
+
BLUE = "#3598ec"
|
173 |
+
PURPLE = "#7463e1"
|
174 |
+
LIGHT_PURPLE = "#d7c3ff"
|
175 |
+
PINK = "#ff0099"
|
176 |
+
|
177 |
+
fig = plt.figure(figsize=(16, 12), dpi=100)
|
178 |
+
fig.suptitle("Toto Time Series Forecasts", fontsize=16, fontweight='bold')
|
179 |
+
|
180 |
+
n_variates = len(feature_columns)
|
181 |
+
|
182 |
+
for i, feature in enumerate(feature_columns):
|
183 |
+
plt.subplot(n_variates, 1, i + 1)
|
184 |
+
|
185 |
+
if i != n_variates - 1:
|
186 |
+
plt.gca().set_xticklabels([])
|
187 |
+
|
188 |
+
plt.gca().tick_params(axis="x", color=DARK_GREY, labelcolor=DARK_GREY)
|
189 |
+
plt.gca().tick_params(axis="y", color=DARK_GREY, labelcolor=DARK_GREY)
|
190 |
+
plt.ylabel(feature, rotation=0, ha='right', va='center')
|
191 |
+
|
192 |
+
# Set x-axis limits
|
193 |
+
context_points = min(960, len(input_df))
|
194 |
+
plt.xlim(input_df.date.iloc[-context_points], target_df.date.iloc[-1])
|
195 |
+
|
196 |
+
# Vertical line separating context and forecast
|
197 |
+
plt.axvline(target_df.date.iloc[0], color=PINK, linestyle=":", alpha=0.8, linewidth=2)
|
198 |
+
|
199 |
+
# Plot historical data
|
200 |
+
plt.plot(input_df["date"].iloc[-context_points:], input_df[feature].iloc[-context_points:],
|
201 |
+
color=BLUE, linewidth=1.5, label='Historical' if i == 0 else None)
|
202 |
+
|
203 |
+
# Plot ground truth in forecast period
|
204 |
+
plt.plot(target_df["date"], target_df[feature], color=BLUE, linewidth=1.5, alpha=0.7,
|
205 |
+
label='Actual' if i == 0 else None)
|
206 |
+
|
207 |
+
# Plot median forecast
|
208 |
+
forecast_median = np.median(forecast.samples.squeeze()[i].cpu().numpy(), axis=-1)
|
209 |
+
plt.plot(target_df["date"], forecast_median, color=PURPLE, linestyle="--", linewidth=2,
|
210 |
+
label='Forecast' if i == 0 else None)
|
211 |
+
|
212 |
+
# Plot confidence intervals
|
213 |
+
alpha = 0.05
|
214 |
+
device = torch.device('cpu')
|
215 |
+
qs = forecast.samples.quantile(q=torch.tensor([alpha, 1 - alpha], device=device), dim=-1)
|
216 |
+
|
217 |
+
plt.fill_between(
|
218 |
+
target_df["date"],
|
219 |
+
qs[0].squeeze()[i].cpu().numpy(),
|
220 |
+
qs[1].squeeze()[i].cpu().numpy(),
|
221 |
+
color=LIGHT_PURPLE,
|
222 |
+
alpha=0.6,
|
223 |
+
label=f'{int((1-2*alpha)*100)}% CI' if i == 0 else None
|
224 |
+
)
|
225 |
+
|
226 |
+
if i == 0:
|
227 |
+
plt.legend(loc='upper left', frameon=True, fancybox=True, shadow=True)
|
228 |
+
|
229 |
+
plt.tight_layout()
|
230 |
+
return fig
|
231 |
+
|
232 |
+
def run_forecast(context_length: int, prediction_length: int, num_samples: int,
|
233 |
+
samples_per_batch: int, use_kv_cache: bool, progress=gr.Progress()) -> Tuple[plt.Figure, str]:
|
234 |
+
"""Run forecasting with given parameters"""
|
235 |
+
try:
|
236 |
+
progress(0.1, desc="Initializing model...")
|
237 |
+
model, forecaster = initialize_model()
|
238 |
+
|
239 |
+
progress(0.2, desc="Loading data...")
|
240 |
+
df = load_sample_data()
|
241 |
+
|
242 |
+
progress(0.3, desc="Preparing data...")
|
243 |
+
inputs, input_df, target_df = prepare_data(df, context_length, prediction_length)
|
244 |
+
|
245 |
+
progress(0.5, desc="Running forecast...")
|
246 |
+
forecast = forecaster.forecast(
|
247 |
+
inputs,
|
248 |
+
prediction_length=prediction_length,
|
249 |
+
num_samples=num_samples,
|
250 |
+
samples_per_batch=min(samples_per_batch, num_samples),
|
251 |
+
use_kv_cache=use_kv_cache,
|
252 |
+
)
|
253 |
+
|
254 |
+
progress(0.8, desc="Creating visualization...")
|
255 |
+
feature_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
|
256 |
+
fig = create_forecast_plot(input_df, target_df, forecast, feature_columns)
|
257 |
+
|
258 |
+
progress(1.0, desc="Complete!")
|
259 |
+
|
260 |
+
# Generate summary statistics
|
261 |
+
forecast_data = forecast.samples.squeeze().cpu().numpy()
|
262 |
+
summary = f"""
|
263 |
+
## Forecast Summary
|
264 |
+
|
265 |
+
**Parameters Used:**
|
266 |
+
- Context Length: {context_length} time steps
|
267 |
+
- Prediction Length: {prediction_length} time steps
|
268 |
+
- Number of Samples: {num_samples}
|
269 |
+
- Samples per Batch: {samples_per_batch}
|
270 |
+
- KV Cache: {'Enabled' if use_kv_cache else 'Disabled'}
|
271 |
+
|
272 |
+
**Results:**
|
273 |
+
- Variables Forecasted: {len(feature_columns)}
|
274 |
+
- Forecast Shape: {forecast_data.shape}
|
275 |
+
- Mean Absolute Forecast Range: {np.mean(np.max(forecast_data, axis=1) - np.min(forecast_data, axis=1)):.3f}
|
276 |
+
|
277 |
+
The plot shows historical data in blue, actual values in the forecast period in light blue,
|
278 |
+
median forecasts as purple dashed lines, and 95% confidence intervals in light purple.
|
279 |
+
"""
|
280 |
+
|
281 |
+
return fig, summary
|
282 |
+
|
283 |
+
except Exception as e:
|
284 |
+
error_msg = f"Error during forecasting: {str(e)}"
|
285 |
+
fig = plt.figure(figsize=(10, 6))
|
286 |
+
plt.text(0.5, 0.5, error_msg, ha='center', va='center', fontsize=12, color='red')
|
287 |
+
plt.axis('off')
|
288 |
+
return fig, error_msg
|
289 |
+
|
290 |
+
# Create Gradio interface
|
291 |
+
def create_interface():
|
292 |
+
with gr.Blocks(title="Toto Time Series Forecasting", theme=gr.themes.Soft()) as demo:
|
293 |
+
gr.Markdown("""
|
294 |
+
# 🔮 Toto Time Series Forecasting
|
295 |
+
|
296 |
+
This app demonstrates zero-shot time series forecasting using the Toto foundation model.
|
297 |
+
Adjust the parameters below to customize your forecast and see how different settings affect the predictions.
|
298 |
+
|
299 |
+
**Note:** This demo uses synthetic ETT-like data for illustration purposes.
|
300 |
+
""")
|
301 |
+
|
302 |
+
with gr.Row():
|
303 |
+
with gr.Column(scale=1):
|
304 |
+
gr.Markdown("### Forecasting Parameters")
|
305 |
+
|
306 |
+
context_length = gr.Slider(
|
307 |
+
minimum=96, maximum=2048, value=512, step=32,
|
308 |
+
label="Context Length",
|
309 |
+
info="Number of historical time steps to use as input"
|
310 |
+
)
|
311 |
+
|
312 |
+
prediction_length = gr.Slider(
|
313 |
+
minimum=24, maximum=720, value=96, step=24,
|
314 |
+
label="Prediction Length",
|
315 |
+
info="Number of time steps to forecast into the future"
|
316 |
+
)
|
317 |
+
|
318 |
+
num_samples = gr.Slider(
|
319 |
+
minimum=8, maximum=512, value=64, step=8,
|
320 |
+
label="Number of Samples",
|
321 |
+
info="More samples = more stable predictions but slower inference"
|
322 |
+
)
|
323 |
+
|
324 |
+
samples_per_batch = gr.Slider(
|
325 |
+
minimum=8, maximum=256, value=32, step=8,
|
326 |
+
label="Samples per Batch",
|
327 |
+
info="Batch size for sample generation (affects memory usage)"
|
328 |
+
)
|
329 |
+
|
330 |
+
use_kv_cache = gr.Checkbox(
|
331 |
+
value=True,
|
332 |
+
label="Use KV Cache",
|
333 |
+
info="Enable key-value caching for faster inference"
|
334 |
+
)
|
335 |
+
|
336 |
+
forecast_btn = gr.Button("🚀 Run Forecast", variant="primary", size="lg")
|
337 |
+
|
338 |
+
with gr.Column(scale=2):
|
339 |
+
gr.Markdown("### Forecast Results")
|
340 |
+
forecast_plot = gr.Plot()
|
341 |
+
forecast_summary = gr.Markdown()
|
342 |
+
|
343 |
+
# Event handlers
|
344 |
+
forecast_btn.click(
|
345 |
+
fn=run_forecast,
|
346 |
+
inputs=[context_length, prediction_length, num_samples, samples_per_batch, use_kv_cache],
|
347 |
+
outputs=[forecast_plot, forecast_summary]
|
348 |
+
)
|
349 |
+
|
350 |
+
# Load initial forecast
|
351 |
+
demo.load(
|
352 |
+
fn=lambda: run_forecast(512, 96, 64, 32, True),
|
353 |
+
outputs=[forecast_plot, forecast_summary]
|
354 |
+
)
|
355 |
+
|
356 |
+
return demo
|
357 |
+
|
358 |
+
# For deployment
|
359 |
+
if __name__ == "__main__":
|
360 |
+
# Create and launch the interface
|
361 |
+
demo = create_interface()
|
362 |
+
|
363 |
+
# For local development
|
364 |
+
if os.getenv("GRADIO_DEV"):
|
365 |
+
demo.launch(debug=True, share=False)
|
366 |
+
else:
|
367 |
+
# For production deployment
|
368 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
369 |
+
|
370 |
+
# For Modal.com deployment, add this:
|
371 |
+
"""
|
372 |
+
# modal_app.py
|
373 |
+
import modal
|
374 |
+
|
375 |
+
image = modal.Image.debian_slim().pip_install([
|
376 |
+
"gradio",
|
377 |
+
"torch",
|
378 |
+
"numpy",
|
379 |
+
"pandas",
|
380 |
+
"matplotlib",
|
381 |
+
"transformers",
|
382 |
+
# Add other required packages
|
383 |
+
])
|
384 |
+
|
385 |
+
app = modal.App("toto-forecasting")
|
386 |
+
|
387 |
+
@app.function(image=image, gpu="T4")
|
388 |
+
def run_gradio():
|
389 |
+
from main import create_interface
|
390 |
+
demo = create_interface()
|
391 |
+
demo.launch(server_name="0.0.0.0", server_port=8000, share=False)
|
392 |
+
|
393 |
+
if __name__ == "__main__":
|
394 |
+
with app.run():
|
395 |
+
run_gradio()
|
396 |
+
"""
|
397 |
+
|
398 |
+
# For Hugging Face Spaces deployment:
|
399 |
+
"""
|
400 |
+
Create these files:
|
401 |
+
1. app.py (this file)
|
402 |
+
2. requirements.txt:
|
403 |
+
gradio
|
404 |
+
torch
|
405 |
+
numpy
|
406 |
+
pandas
|
407 |
+
matplotlib
|
408 |
+
transformers
|
409 |
+
3. README.md with your Space description
|
410 |
+
"""
|
inference_tutorial.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|