LeonceNsh committed on
Commit
e888737
·
verified ·
1 Parent(s): aeb4612

Upload folder using huggingface_hub

Browse files
Files changed (8) hide show
  1. .gradio/certificate.pem +31 -0
  2. ETTh1.csv +0 -0
  3. ETTh2.csv +0 -0
  4. ETTm1.csv +0 -0
  5. ETTm2.csv +0 -0
  6. README.md +2 -8
  7. gradio_modal.py +410 -0
  8. inference_tutorial.ipynb +0 -0
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
ETTh1.csv ADDED
The diff for this file is too large to render. See raw diff
 
ETTh2.csv ADDED
The diff for this file is too large to render. See raw diff
 
ETTm1.csv ADDED
The diff for this file is too large to render. See raw diff
 
ETTm2.csv ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Timeseries
3
- emoji: 👁
4
- colorFrom: yellow
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.31.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: timeseries
3
+ app_file: gradio_modal.py
 
 
4
  sdk: gradio
5
  sdk_version: 5.31.0
 
 
6
  ---
 
 
gradio_modal.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from datetime import datetime, timedelta
8
+ import io
9
+ import base64
10
+ from typing import Optional, Tuple, Dict, Any
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
# Mock implementations for the original imports
# In actual deployment, you'd import these from the original modules
class MaskedTimeseries:
    """Lightweight stand-in for the Toto ``MaskedTimeseries`` input container.

    Simply records the series tensor together with the masks and timestamp
    metadata the forecaster expects; no validation is performed.
    """

    def __init__(self, series, padding_mask, id_mask, timestamp_seconds, time_interval_seconds):
        # Store each field verbatim (assignments are order-independent).
        self.series = series
        self.id_mask = id_mask
        self.padding_mask = padding_mask
        self.time_interval_seconds = time_interval_seconds
        self.timestamp_seconds = timestamp_seconds
23
+
24
class MockToto:
    """Mock of the Toto foundation model used for demonstration.

    Mirrors the chainable interface of the real model
    (``from_pretrained`` / ``to`` / ``compile``) without loading any weights.
    """

    def __init__(self):
        # The real wrapper exposes an inner ``.model``; point it at ourselves.
        self.model = self

    @classmethod
    def from_pretrained(cls, model_name):
        """Ignore *model_name* and return a fresh mock instance."""
        return cls()

    def to(self, device):
        """No-op device move; returns ``self`` so calls can be chained."""
        return self

    def compile(self):
        """No-op compile; returns ``self`` so calls can be chained."""
        return self
38
+
39
class MockForecaster:
    """Mock forecaster that fabricates random-walk sample paths for the demo.

    Mimics the real Toto forecaster interface: ``forecast(...)`` returns an
    object whose ``.samples`` supports ``squeeze()``, ``cpu()`` and
    ``quantile()`` as used by the plotting code.
    """

    def __init__(self, model):
        self.model = model

    def forecast(self, inputs, prediction_length, num_samples, samples_per_batch, use_kv_cache=True):
        """Generate ``num_samples`` synthetic continuations of ``inputs.series``.

        Args:
            inputs: object with a ``series`` tensor of shape
                ``(n_variates, context_length)``.
            prediction_length: number of future steps per sample path.
            num_samples: number of Monte-Carlo paths to fabricate.
            samples_per_batch: accepted for interface parity; unused here.
            use_kv_cache: accepted for interface parity; unused here.

        Returns:
            A forecast object whose ``samples.tensor`` has shape
            ``(n_variates, prediction_length, num_samples)``.
        """
        n_variates, context_length = inputs.series.shape

        # Random walk starting at the last observed value: small trend + noise.
        samples = []
        for _ in range(num_samples):
            last_values = inputs.series[:, -1:]
            forecast_sample = []
            for t in range(prediction_length):
                trend = torch.randn(n_variates, 1) * 0.1
                noise = torch.randn(n_variates, 1) * 0.5
                next_val = last_values + trend + noise
                forecast_sample.append(next_val)
                last_values = next_val
            sample = torch.cat(forecast_sample, dim=1)
            samples.append(sample)

        # Stack sample paths along a trailing dimension.
        forecast_tensor = torch.stack(samples, dim=-1)  # (n_variates, prediction_length, num_samples)

        class MockSamples:
            def __init__(self, tensor):
                self.tensor = tensor

            def squeeze(self):
                return self.tensor

            def cpu(self):
                return self.tensor

            def quantile(self, q, dim):
                # BUG FIX: the previous gather/permute implementation indexed
                # the variate axis with the quantile index, yielding constant
                # values taken from variate 0 only. torch.quantile computes the
                # real per-variate quantiles; output shape is
                # (len(q), n_variates, prediction_length) for dim=-1.
                return torch.quantile(self.tensor, q.to(self.tensor.dtype), dim=dim)

        class MockForecast:
            def __init__(self, samples):
                self.samples = MockSamples(samples)

        return MockForecast(forecast_tensor)
90
+
91
# Global variables
# Module-level cache: populated once by initialize_model() and reused across
# requests so the (mock) model is not rebuilt on every forecast.
toto_model = None
forecaster = None
94
+
95
def initialize_model():
    """Lazily construct the (mock) Toto model and its forecaster.

    Uses the module-level ``toto_model`` / ``forecaster`` cache so repeated
    calls return the same pair instead of rebuilding.

    Returns:
        Tuple of ``(toto_model, forecaster)``.
    """
    global toto_model, forecaster

    # Fast path: already initialized.
    if toto_model is not None:
        return toto_model, forecaster

    # In production, replace with: toto_model = Toto.from_pretrained('Datadog/Toto-Open-Base-1.0')
    toto_model = MockToto()
    toto_model.to("cpu")  # Use CPU for broader compatibility
    toto_model.compile()
    forecaster = MockForecaster(toto_model.model)

    return toto_model, forecaster
108
+
109
def load_sample_data():
    """Generate one year of synthetic, ETT-like multivariate data.

    Returns:
        DataFrame with a ``date`` column at 15-minute resolution covering all
        of 2020 (35136 rows), seven feature columns (HUFL..OT) built from
        sinusoidal patterns plus Gaussian noise, and a ``timestamp_seconds``
        column holding Unix epoch seconds for each row.
    """
    # 2020-01-01 00:00 through 2020-12-31 23:45 inclusive, every 15 minutes.
    # FIX: '15T' is a deprecated pandas frequency alias; use '15min'.
    dates = pd.date_range(start='2020-01-01', end='2020-12-31 23:45:00', freq='15min')
    n_points = len(dates)

    # Create synthetic multivariate time series
    t = np.arange(n_points)

    # Base patterns: a daily cycle (24h * 4 steps/hour) modulated by slower
    # weekly/monthly harmonics, each with independent Gaussian noise.
    hufl = 5 + 2 * np.sin(2 * np.pi * t / (24 * 4)) + 0.5 * np.sin(2 * np.pi * t / (24 * 4 * 7)) + np.random.normal(0, 0.3, n_points)
    hull = 4 + 1.5 * np.cos(2 * np.pi * t / (24 * 4)) + 0.3 * np.sin(2 * np.pi * t / (24 * 4 * 30)) + np.random.normal(0, 0.25, n_points)
    mufl = 6 + 1.8 * np.sin(2 * np.pi * t / (24 * 4)) + 0.4 * np.cos(2 * np.pi * t / (24 * 4 * 7)) + np.random.normal(0, 0.35, n_points)
    mull = 5.5 + 1.2 * np.cos(2 * np.pi * t / (24 * 4)) + 0.6 * np.sin(2 * np.pi * t / (24 * 4 * 14)) + np.random.normal(0, 0.28, n_points)
    lufl = 3.5 + 2.2 * np.sin(2 * np.pi * t / (24 * 4)) + 0.8 * np.cos(2 * np.pi * t / (24 * 4 * 21)) + np.random.normal(0, 0.32, n_points)
    lull = 4.2 + 1.6 * np.cos(2 * np.pi * t / (24 * 4)) + 0.5 * np.sin(2 * np.pi * t / (24 * 4 * 10)) + np.random.normal(0, 0.27, n_points)
    ot = 25 + 8 * np.sin(2 * np.pi * t / (24 * 4)) + 3 * np.cos(2 * np.pi * t / (24 * 4 * 365)) + np.random.normal(0, 1.2, n_points)

    df = pd.DataFrame({
        'date': dates,
        'HUFL': hufl,
        'HULL': hull,
        'MUFL': mufl,
        'MULL': mull,
        'LUFL': lufl,
        'LULL': lull,
        'OT': ot
    })

    # Unix epoch seconds for each timestamp (integer floor division).
    df['timestamp_seconds'] = (df['date'] - pd.Timestamp("1970-01-01")) // pd.Timedelta('1s')

    return df
141
+
142
def prepare_data(df: pd.DataFrame, context_length: int, prediction_length: int) -> Tuple[MaskedTimeseries, pd.DataFrame, pd.DataFrame]:
    """Slice the tail of *df* into a model input window and a target window.

    The last ``prediction_length`` rows become the held-out target; the
    ``context_length`` rows immediately before them become the model input.

    Returns:
        ``(inputs, input_df, target_df)`` where *inputs* is a
        ``MaskedTimeseries`` built from the seven feature columns.

    Raises:
        ValueError: if *df* has fewer than
            ``context_length + prediction_length`` rows.
    """
    feature_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
    n_variates = len(feature_columns)
    interval = 60 * 15  # 15-min intervals

    # Guard: the two windows must fit inside the frame.
    if len(df) < (context_length + prediction_length):
        raise ValueError(f"Dataset too small. Need at least {context_length + prediction_length} points, got {len(df)}")

    window = context_length + prediction_length
    input_df = df.iloc[-window:-prediction_length].copy()
    target_df = df.iloc[-prediction_length:].copy()

    # Shape (n_variates, context_length): one row per variate.
    input_series = torch.from_numpy(input_df[feature_columns].values.T).to(torch.float)
    # Broadcast the shared timestamps across every variate row.
    timestamp_seconds = torch.from_numpy(input_df.timestamp_seconds.values).expand((n_variates, context_length))
    time_interval_seconds = torch.full((n_variates,), interval)

    inputs = MaskedTimeseries(
        series=input_series,
        padding_mask=torch.full_like(input_series, True, dtype=torch.bool),
        id_mask=torch.zeros_like(input_series),
        timestamp_seconds=timestamp_seconds,
        time_interval_seconds=time_interval_seconds,
    )

    return inputs, input_df, target_df
168
+
169
def create_forecast_plot(input_df: pd.DataFrame, target_df: pd.DataFrame, forecast, feature_columns: list) -> plt.Figure:
    """Create forecast visualization.

    One stacked subplot per feature: historical context, ground truth over the
    forecast window, the median forecast, and a quantile band.

    Args:
        input_df: context-window rows; must have ``date`` plus each feature.
        target_df: forecast-window rows with the same columns.
        forecast: object whose ``samples`` supports ``squeeze()``/``cpu()``/
            ``quantile()`` — assumes sample tensor shape
            (n_variates, prediction_length, num_samples); TODO confirm.
        feature_columns: ordered feature names, one subplot each.

    Returns:
        The assembled matplotlib Figure.
    """
    DARK_GREY = "#1c2b34"
    BLUE = "#3598ec"
    PURPLE = "#7463e1"
    LIGHT_PURPLE = "#d7c3ff"
    PINK = "#ff0099"

    fig = plt.figure(figsize=(16, 12), dpi=100)
    fig.suptitle("Toto Time Series Forecasts", fontsize=16, fontweight='bold')

    n_variates = len(feature_columns)

    for i, feature in enumerate(feature_columns):
        plt.subplot(n_variates, 1, i + 1)

        # Only the bottom subplot keeps x tick labels.
        if i != n_variates - 1:
            plt.gca().set_xticklabels([])

        plt.gca().tick_params(axis="x", color=DARK_GREY, labelcolor=DARK_GREY)
        plt.gca().tick_params(axis="y", color=DARK_GREY, labelcolor=DARK_GREY)
        plt.ylabel(feature, rotation=0, ha='right', va='center')

        # Set x-axis limits: show at most the last 960 context points.
        context_points = min(960, len(input_df))
        plt.xlim(input_df.date.iloc[-context_points], target_df.date.iloc[-1])

        # Vertical line separating context and forecast
        plt.axvline(target_df.date.iloc[0], color=PINK, linestyle=":", alpha=0.8, linewidth=2)

        # Plot historical data (legend labels only on the first subplot).
        plt.plot(input_df["date"].iloc[-context_points:], input_df[feature].iloc[-context_points:],
                 color=BLUE, linewidth=1.5, label='Historical' if i == 0 else None)

        # Plot ground truth in forecast period
        plt.plot(target_df["date"], target_df[feature], color=BLUE, linewidth=1.5, alpha=0.7,
                 label='Actual' if i == 0 else None)

        # Plot median forecast across the sample dimension.
        forecast_median = np.median(forecast.samples.squeeze()[i].cpu().numpy(), axis=-1)
        plt.plot(target_df["date"], forecast_median, color=PURPLE, linestyle="--", linewidth=2,
                 label='Forecast' if i == 0 else None)

        # Plot confidence intervals (two-sided; alpha here is the tail mass,
        # shadowing the plotting-opacity keyword used above).
        alpha = 0.05
        device = torch.device('cpu')
        qs = forecast.samples.quantile(q=torch.tensor([alpha, 1 - alpha], device=device), dim=-1)

        plt.fill_between(
            target_df["date"],
            qs[0].squeeze()[i].cpu().numpy(),
            qs[1].squeeze()[i].cpu().numpy(),
            color=LIGHT_PURPLE,
            alpha=0.6,
            label=f'{int((1-2*alpha)*100)}% CI' if i == 0 else None
        )

        if i == 0:
            plt.legend(loc='upper left', frameon=True, fancybox=True, shadow=True)

    plt.tight_layout()
    return fig
231
+
232
def run_forecast(context_length: int, prediction_length: int, num_samples: int,
                 samples_per_batch: int, use_kv_cache: bool, progress=gr.Progress()) -> Tuple[plt.Figure, str]:
    """Run forecasting with given parameters.

    End-to-end pipeline for the Gradio button: initialize the model, load the
    synthetic dataset, slice it, forecast, and render a plot plus a markdown
    summary. Any exception is caught and rendered as an error figure instead
    of crashing the UI.

    Returns:
        ``(figure, markdown_summary)`` on success, or
        ``(error_figure, error_message)`` on failure.
    """
    try:
        progress(0.1, desc="Initializing model...")
        # Local `forecaster` shadows the module-level cache on purpose.
        model, forecaster = initialize_model()

        progress(0.2, desc="Loading data...")
        df = load_sample_data()

        progress(0.3, desc="Preparing data...")
        inputs, input_df, target_df = prepare_data(df, context_length, prediction_length)

        progress(0.5, desc="Running forecast...")
        forecast = forecaster.forecast(
            inputs,
            prediction_length=prediction_length,
            num_samples=num_samples,
            # Batch size can never exceed the number of samples requested.
            samples_per_batch=min(samples_per_batch, num_samples),
            use_kv_cache=use_kv_cache,
        )

        progress(0.8, desc="Creating visualization...")
        feature_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
        fig = create_forecast_plot(input_df, target_df, forecast, feature_columns)

        progress(1.0, desc="Complete!")

        # Generate summary statistics
        forecast_data = forecast.samples.squeeze().cpu().numpy()
        summary = f"""
## Forecast Summary

**Parameters Used:**
- Context Length: {context_length} time steps
- Prediction Length: {prediction_length} time steps
- Number of Samples: {num_samples}
- Samples per Batch: {samples_per_batch}
- KV Cache: {'Enabled' if use_kv_cache else 'Disabled'}

**Results:**
- Variables Forecasted: {len(feature_columns)}
- Forecast Shape: {forecast_data.shape}
- Mean Absolute Forecast Range: {np.mean(np.max(forecast_data, axis=1) - np.min(forecast_data, axis=1)):.3f}

The plot shows historical data in blue, actual values in the forecast period in light blue,
median forecasts as purple dashed lines, and 95% confidence intervals in light purple.
"""

        return fig, summary

    except Exception as e:
        # Broad catch is deliberate: surface the error inside the UI output
        # rather than letting the Gradio callback fail silently.
        error_msg = f"Error during forecasting: {str(e)}"
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, error_msg, ha='center', va='center', fontsize=12, color='red')
        plt.axis('off')
        return fig, error_msg
289
+
290
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI: parameter sliders, a run button, and the
    forecast plot/summary outputs. Returns the (unlaunched) `demo` object."""
    with gr.Blocks(title="Toto Time Series Forecasting", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🔮 Toto Time Series Forecasting

        This app demonstrates zero-shot time series forecasting using the Toto foundation model.
        Adjust the parameters below to customize your forecast and see how different settings affect the predictions.

        **Note:** This demo uses synthetic ETT-like data for illustration purposes.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Forecasting Parameters")

                context_length = gr.Slider(
                    minimum=96, maximum=2048, value=512, step=32,
                    label="Context Length",
                    info="Number of historical time steps to use as input"
                )

                prediction_length = gr.Slider(
                    minimum=24, maximum=720, value=96, step=24,
                    label="Prediction Length",
                    info="Number of time steps to forecast into the future"
                )

                num_samples = gr.Slider(
                    minimum=8, maximum=512, value=64, step=8,
                    label="Number of Samples",
                    info="More samples = more stable predictions but slower inference"
                )

                samples_per_batch = gr.Slider(
                    minimum=8, maximum=256, value=32, step=8,
                    label="Samples per Batch",
                    info="Batch size for sample generation (affects memory usage)"
                )

                use_kv_cache = gr.Checkbox(
                    value=True,
                    label="Use KV Cache",
                    info="Enable key-value caching for faster inference"
                )

                forecast_btn = gr.Button("🚀 Run Forecast", variant="primary", size="lg")

            with gr.Column(scale=2):
                gr.Markdown("### Forecast Results")
                forecast_plot = gr.Plot()
                forecast_summary = gr.Markdown()

        # Event handlers
        forecast_btn.click(
            fn=run_forecast,
            inputs=[context_length, prediction_length, num_samples, samples_per_batch, use_kv_cache],
            outputs=[forecast_plot, forecast_summary]
        )

        # Load initial forecast on page load, using the sliders' default values.
        demo.load(
            fn=lambda: run_forecast(512, 96, 64, 32, True),
            outputs=[forecast_plot, forecast_summary]
        )

    return demo
357
+
358
# For deployment
if __name__ == "__main__":
    # Create and launch the interface
    demo = create_interface()

    # For local development
    if os.getenv("GRADIO_DEV"):
        demo.launch(debug=True, share=False)
    else:
        # For production deployment
        # NOTE(review): share=True spins up a public Gradio tunnel — likely
        # unnecessary (and ignored) on Hugging Face Spaces; confirm intent.
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True)

# For Modal.com deployment, add this:
"""
# modal_app.py
import modal

image = modal.Image.debian_slim().pip_install([
    "gradio",
    "torch",
    "numpy",
    "pandas",
    "matplotlib",
    "transformers",
    # Add other required packages
])

app = modal.App("toto-forecasting")

@app.function(image=image, gpu="T4")
def run_gradio():
    from main import create_interface
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=8000, share=False)

if __name__ == "__main__":
    with app.run():
        run_gradio()
"""

# For Hugging Face Spaces deployment:
"""
Create these files:
1. app.py (this file)
2. requirements.txt:
   gradio
   torch
   numpy
   pandas
   matplotlib
   transformers
3. README.md with your Space description
"""
inference_tutorial.ipynb ADDED
The diff for this file is too large to render. See raw diff