update

This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +168 -0
- Infer Demo.ipynb +169 -0
- LICENSE +21 -0
- README.md +128 -13
- Train Demo.ipynb +0 -0
- app.py +1121 -0
- config/backup/revenue-1.yaml +76 -0
- config/backup/revenue-2.yaml +77 -0
- config/backup/revenue-3xl.yaml +77 -0
- config/backup/revenue-baseline.yaml +82 -0
- config/backup/revenue-test.yaml +82 -0
- config/backup/revenue.yaml +76 -0
- config/config.yaml +40 -0
- config/control/revenue-baseline-180.yaml +82 -0
- config/control/revenue-baseline-365-ma.yaml +83 -0
- config/control/revenue-baseline-365.yaml +82 -0
- config/control/revenue-baseline-sine.yaml +82 -0
- config/control/revenue-extend.yaml +83 -0
- config/csdi/energy.yaml +75 -0
- config/csdi/fmri.yaml +74 -0
- config/csdi/revenue-baseline-365.yaml +82 -0
- config/csdi/sines.yaml +73 -0
- config/energy.yaml +74 -0
- config/etth.yaml +74 -0
- config/fmri.yaml +74 -0
- config/modified/192/energy.yaml +74 -0
- config/modified/192/fmri.yaml +74 -0
- config/modified/192/revenue.yaml +82 -0
- config/modified/192/sines.yaml +73 -0
- config/modified/384/energy.yaml +74 -0
- config/modified/384/fmri.yaml +74 -0
- config/modified/384/revenue.yaml +82 -0
- config/modified/384/sines.yaml +73 -0
- config/modified/96/energy.yaml +74 -0
- config/modified/96/fmri.yaml +74 -0
- config/modified/96/revenue.yaml +82 -0
- config/modified/96/sines.yaml +73 -0
- config/modified/energy.yaml +74 -0
- config/modified/fmri.yaml +74 -0
- config/modified/revenue-baseline-365.yaml +82 -0
- config/modified/revenue.yaml +82 -0
- config/modified/sines.yaml +73 -0
- config/mujoco.yaml +72 -0
- config/mujoco_sssd.yaml +40 -0
- config/sines.yaml +72 -0
- config/solar.yaml +40 -0
- config/solar_update.yaml +40 -0
- config/stocks.yaml +74 -0
- efficiency.py +319 -0
- engine/logger.py +71 -0
.gitignore
ADDED
@@ -0,0 +1,168 @@
```gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

toy_exp/
Checkpoints*/
Data/datasets/
wandb/
data/
```
Infer Demo.ipynb
ADDED
@@ -0,0 +1,169 @@
```python
# Infer Demo.ipynb — notebook cells shown in percent format
# (kernel: python3 "rag", Python 3.10.14; nbformat 4)

# %% [code]
import os
import torch

torch.cuda.is_available()
os.environ["WANDB_ENABLED"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# %% [code]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from models.Tiffusion import tiffusion
# from models.CSDI import tiffusion

model = tiffusion.Tiffusion(
    seq_length=365,
    feature_size=3,
    n_layer_enc=6,
    n_layer_dec=4,
    d_model=128,
    timesteps=500,
    sampling_timesteps=200,
    loss_type='l1',
    beta_schedule='cosine',
    n_heads=8,
    mlp_hidden_times=4,
    attn_pd=0.0,
    resid_pd=0.0,
    kernel_size=1,
    padding_size=0,
    control_signal=[]
).to(device)

model.load_state_dict(torch.load("./weight/checkpoint-10.pt", map_location='cpu', weights_only=True)["model"])
# model.load_state_dict(torch.load("../../../data/CSDI/ckpt_baseline_365/checkpoint-10.pt", map_location='cpu', weights_only=True)["model"])


coef = 1.0e-2
stepsize = 5.0e-2
sampling_steps = 100  # adjustable: anywhere in 100-500 works; speed vs. accuracy tradeoff
seq_length = 365
feature_dim = 3
print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")

# %% [markdown]
# ## Sampling

# %% [code]
anchor_value = [
    # (time, feature_id, y-value, confidence)
    (0, 0, 0.04, 1.0),
    (2, 0, 0.58, 1.0),
    # (6, 0, 0.27, 0.5),
    # (10, 0, 0.04, 1.0),
    # (12, 0, 0.58, 0.001),
    # (16, 0, 0.27, 0.5),
    # (20, 0, 0.04, 1.0),
    # (22, 0, 0.58, 0.001),
    # (26, 0, 0.27, 0.5),
    # (30, 0, 0.04, 1.0),
    # (32, 0, 0.58, 0.001),
    # (36, 0, 0.27, 0.5),
    # (40, 0, 0.04, 1.0),
    # (42, 0, 0.58, 0.001),
    # (46, 0, 0.27, 0.5),
    # (50, 0, 0.04, 1.0),
    # (52, 0, 0.58, 0.001),
    # (56, 0, 0.27, 0.5),
    # (60, 0, 0.04, 1.0),
    # (62, 0, 0.58, 0.001),
    # (66, 0, 0.27, 0.5),
]

observed_points = torch.zeros((seq_length, feature_dim)).to(device)
observed_mask = torch.zeros((seq_length, feature_dim)).to(device)

for time, feature_id, y_value, confidence in anchor_value:
    observed_points[time, feature_id] = y_value
    observed_mask[time, feature_id] = confidence

auc = -10
auc_weight = 10.0
with torch.no_grad():
    results = model.predict_weighted_points(
        observed_points,   # (seq_length, feature_dim)
        observed_mask,     # (seq_length, feature_dim)
        coef,              # fixed
        stepsize,          # fixed
        sampling_steps,    # fixed
        # model_control_signal=model_control_signal,
        gradient_control_signal={
            "auc": auc, "auc_weight": auc_weight,
        },
    )

# %% [code]
results.shape

# %% [code]
import matplotlib.pyplot as plt

plt.plot(results[:, 0], label="Predicted Feature 0")
for time, feature_id, y_value, confidence in anchor_value:
    plt.scatter(time, y_value, c='r')
plt.show()
plt.plot(results[:, 1], label="Predicted Feature 1")
plt.show()
plt.plot(results[:, 2], label="Predicted Feature 2")
plt.show()
```
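The comment on `sampling_steps` above flags a speed/accuracy tradeoff over the 100-500 range. As a minimal sketch of how one might compare budgets — reusing the `model`, `observed_points`, `observed_mask`, `coef`, `stepsize`, `auc`, and `auc_weight` defined in the cells above, and assuming only the call signature shown there:

```python
import time

# Hypothetical sweep over the step budgets the notebook suggests (100-500);
# larger budgets run slower but typically fit the anchor points more tightly.
for steps in (100, 200, 500):
    t0 = time.time()
    with torch.no_grad():
        out = model.predict_weighted_points(
            observed_points, observed_mask,
            coef, stepsize, steps,
            gradient_control_signal={"auc": auc, "auc_weight": auc_weight},
        )
    print(f"steps={steps}: {time.time() - t0:.1f}s, output shape {tuple(out.shape)}")
```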
LICENSE
ADDED
@@ -0,0 +1,21 @@
```text
MIT License

Copyright (c) 2024 XXX

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
README.md
CHANGED
@@ -1,13 +1,128 @@
````markdown
# Demo for TSEditor
<!--
[](https://github.com/Y-debug-sys/Diffusion-TS/stargazers)
[](https://github.com/Y-debug-sys/Diffusion-TS/network)
[](https://github.com/Y-debug-sys/Diffusion-TS/blob/main/LICENSE)
<img src="https://img.shields.io/badge/python-3.8-blue">
<img src="https://img.shields.io/badge/pytorch-2.0-orange">

> **Abstract:** Denoising diffusion probabilistic models (DDPMs) are becoming the leading paradigm for generative models. They have recently shown breakthroughs in audio synthesis, time series imputation and forecasting. In this paper, we propose Diffusion-TS, a novel diffusion-based framework that generates multivariate time series samples of high quality by using an encoder-decoder transformer with disentangled temporal representations, in which the decomposition technique guides Diffusion-TS to capture the semantic meaning of time series while transformers mine detailed sequential information from the noisy model input. Different from existing diffusion-based approaches, we train the model to directly reconstruct the sample instead of the noise in each diffusion step, combining a Fourier-based loss term. Diffusion-TS is expected to generate time series satisfying both interpretability and realness. In addition, it is shown that the proposed Diffusion-TS can be easily extended to conditional generation tasks, such as forecasting and imputation, without any model changes. This also motivates us to further explore the performance of Diffusion-TS under irregular settings. Finally, through qualitative and quantitative experiments, results show that Diffusion-TS achieves state-of-the-art results on various realistic analyses of time series.

Diffusion-TS is a diffusion-based framework that generates general time series samples both conditionally and unconditionally. As shown in Figure 1, the framework contains two parts: a sequence encoder and an interpretable decoder that decomposes the time series into a seasonal part and a trend part. The trend part contains the polynomial regressor and the extracted mean of each block output. For the seasonal part, we reuse trigonometric representations based on Fourier series. Regarding training, sampling and more details, please refer to [our paper](https://openreview.net/pdf?id=4h1apFjO99) in ICLR 2024.

<p align="center">
  <img src="figures/fig1.jpg" alt="">
  <br>
  <b>Figure 1</b>: Overall Architecture of Diffusion-TS.
</p>


## Dataset Preparation

All four real-world datasets (Stocks, ETTh1, Energy and fMRI) can be obtained from [Google Drive](https://drive.google.com/file/d/11DI22zKWtHjXMnNGPWNUbyGz-JiEtZy6/view?usp=sharing). Please download **dataset.zip**, then unzip and copy it to the folder `./Data` in our repository.


## Running the Code

The code requires conda3 (or miniconda3) and one CUDA-capable GPU. The instructions below guide you through running the code in this repository.

### Environment & Libraries

The full list of libraries is provided as `requirements.txt` in this repo. Please create a virtual environment with `conda` or `venv` and run

~~~bash
(myenv) $ pip install -r requirements.txt
~~~

### Training & Sampling

For training, you can reproduce the experimental results of all benchmarks by running

~~~bash
(myenv) $ python main.py --name {name} --config_file {config.yaml} --gpu 0 --train
~~~

**Note:** We also provide the corresponding `.yml` files (only stocks, sines, mujoco, etth, energy and fmri) under the folder `./Config`, where all possible options can be altered. You may need to change some model parameters for different scenarios. For example, we use the whole dataset to train the model for unconditional evaluation, so *training_ratio* is set to 1 by default. For conditional generation, the dataset needs to be split, so it should be changed to a value < 1.

While training, the script saves checkpoints to the *results* folder after a fixed number of epochs. Once trained, please use the saved model for sampling by running

#### Unconstrained
```bash
(myenv) $ python main.py --name {name} --config_file {config.yaml} --gpu 0 --sample 0 --milestone {checkpoint_number}
```

#### Imputation
```bash
(myenv) $ python main.py --name {name} --config_file {config.yaml} --gpu 0 --sample 1 --milestone {checkpoint_number} --mode infill --missing_ratio {missing_ratio}
```

#### Forecasting
```bash
(myenv) $ python main.py --name {dataset_name} --config_file {config.yaml} --gpu 0 --sample 1 --milestone {checkpoint_number} --mode predict --pred_len {pred_len}
```


## Visualization and Evaluation

After sampling, synthetic data and original data are stored in `.npy` file format under the *output* folder, and can be read directly to calculate quantitative metrics such as discriminative, predictive, correlational and context-FID scores. You can also reproduce the visualization results using t-SNE or kernel plotting; all of this evaluation code can be found in the folder `./Utils`. Please refer to the `.ipynb` tutorial files in this repo for more detailed implementations.

**Note:** All the metrics can be found in the `./Experiments` folder. Additionally, by default, for datasets other than the Sine dataset (which does not need normalization), the normalized forms are saved in `{...}_norm_truth.npy`. Therefore, when you run the Jupyter notebook for a dataset other than Sine, just uncomment and rewrite the corresponding code at the beginning.

### Main Results

#### Standard TS Generation
<p align="center">
  <b>Table 1</b>: Results of 24-length Time-series Generation.
  <br>
  <img src="figures/fig2.jpg" alt="">
</p>

#### Long-term TS Generation
<p align="center">
  <b>Table 2</b>: Results of Long-term Time-series Generation.
  <br>
  <img src="figures/fig3.jpg" alt="">
</p>

#### Conditional TS Generation
<p align="center">
  <img src="figures/fig4.jpg" alt="">
  <br>
  <b>Figure 2</b>: Visualizations of Time-series Imputation and Forecasting.
</p>


## Authors

* Paper Authors : Xinyu Yuan, Yan Qiao

* Code Author : Xinyu Yuan

* Contact : yxy5315@gmail.com
-->

## Citation
If you find this repo useful, please cite our paper via
```bibtex
```


## Acknowledgement

We appreciate the following GitHub repos a lot for their valuable code base:

https://github.com/Y-debug-sys/Diffusion-TS

https://github.com/lucidrains/denoising-diffusion-pytorch

https://github.com/cientgu/VQ-Diffusion

https://github.com/XiangLi1999/Diffusion-LM

https://github.com/philipperemy/n-beats

https://github.com/salesforce/ETSformer

https://github.com/ermongroup/CSDI

https://github.com/jsyoon0823/TimeGAN
````
Train Demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
app.py
ADDED
@@ -0,0 +1,1121 @@
````python
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import numpy as np
import torch
from typing import Dict, List, Tuple
import re
from typing import Callable, Union, Dict


class TimeSeriesEditor:
    def __init__(self, seq_length: int, feature_dim: int, trainer):
        # Existing initialization
        self.seq_length = seq_length
        self.feature_dim = feature_dim
        self.trainer = trainer
        self.coef = None
        self.stepsize = None
        self.sampling_steps = None
        self.feature_names = ["revenue", "download", "daily active user"]  # * 20
        # self.feature_names = [f"Feature {i}" for i in range(self.feature_dim)]

        # Store the latest model output
        self.latest_sample = None
        self.latest_observed_points = None
        self.latest_observed_mask = None
        self.latest_gradient_control_signal = None
        self.latest_model_control_signal = None
        # self.latest_metrics
        # Define scales for each feature
        self.feature_scales = {
            0: 1000000,  # Revenue: $1M per 0.1
            1: 100000,   # Download: 100K downloads per 0.1
            2: 10000     # AU: 10K active users per 0.1
        }
        self.feature_units = {
            0: "$",          # Revenue
            1: "downloads",  # Download
            2: "users"       # AU
        }
        self.show_normalized = True

        # Add frequency band multipliers
        self.freq_bands = np.ones(5)  # 5 frequency bands, initially all set to 1.0
        self.function_parser = FunctionParser()
        self.trending_controls = []  # Store trending controls

    def format_value(self, value: float, feature_idx: int) -> str:
        """Format value with appropriate units and notation"""
        if self.show_normalized:
            return f"{value:.4f}"
        else:
            if feature_idx == 0:  # Revenue
                return f"{self.feature_units[feature_idx]}{value:,.2f}"
            else:  # Downloads and AU
                return f"{value:,.0f} {self.feature_units[feature_idx]}"

    def create_plot(self, sample: np.ndarray, observed_points: torch.Tensor,
                    observed_mask: torch.Tensor,
                    gradient_control_signal: Dict, metrics: Dict) -> List[go.Figure]:
        figures = []
        # Get weights from model_control_signal (will be all 1s if not provided)
        weights = observed_mask

        for feat_idx in range(self.feature_dim):
            fig = go.Figure()

            # Scale values if needed
            scale_factor = self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1

            # Plot predicted line
            predicted_values = sample[:, feat_idx] * scale_factor
            fig.add_trace(go.Scatter(
                x=np.arange(self.seq_length),
                y=predicted_values,
                mode='lines',
                name='Predicted',
                line=dict(color='green', width=2),
                showlegend=True
            ))

            # Calculate and plot confidence bands based on weights
            # Lower weights = larger uncertainty bands
            mask = observed_points[:, feat_idx] > 0
            ox = np.arange(0, self.seq_length)[mask]
            oy = observed_points[mask, feat_idx].numpy() * scale_factor
            weights_masked = 1 - weights[mask, feat_idx].numpy()

            # Calculate error bars - inverse relationship with weight
            # Weight of 1.0 gives minimal uncertainty (0.02)
            # Weight of 0.1 gives larger uncertainty (0.2)
            # error_y = 0.02 / weights_masked
            error_y = weights_masked / 5

            # Plot observed points with error bars - changed symbol to 'cross'
            fig.add_trace(go.Scatter(
                x=ox,
                y=oy,
                mode='markers',
                name='Observed',
                marker=dict(
                    # special red
                    color='rgba(255, 0, 0, 0.5)',
                    # size=10,
                    symbol='x',  # Changed from 'circle' to 'x' for cross symbol
                ),
                error_y=dict(
                    type='data',
                    array=error_y * scale_factor,
                    visible=True,
                    thickness=0.5,
                    width=2,
                    color='blue'
                ),
                showlegend=True
            ))

            # Add shaded confidence bands around the predicted line
            # This shows the general uncertainty in the prediction
            uncertainty = 0.05  # Base uncertainty level
            upper_bound = predicted_values + uncertainty * scale_factor
            lower_bound = predicted_values - uncertainty * scale_factor

            fig.add_trace(go.Scatter(
                x=np.concatenate([np.arange(self.seq_length), np.arange(self.seq_length)[::-1]]),
                y=np.concatenate([upper_bound, lower_bound[::-1]]),
                # fill='toself',
                # fillcolor='rgba(0,100,0,0.1)',
                line=dict(color='rgba(255,255,255,0)'),
                name='Prediction Interval',
                showlegend=True
            ))

            # Add vertical lines for peak points
            if gradient_control_signal.get("peak_points"):
                for peak_point in gradient_control_signal["peak_points"]:
                    fig.add_vline(x=peak_point, line_dash="dash", line_color="red")

            # Add metrics annotations
            total_value = np.sum(sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
            annotations = [dict(
                x=0.02,
                y=1.1,
                xref="paper",
                yref="paper",
                text=f"Total {self.feature_names[feat_idx]}: {self.format_value(total_value, feat_idx)}",
                showarrow=False
            )]

            # Update y-axis title based on feature and scaling
            if self.show_normalized:
                y_title = f'{self.feature_names[feat_idx]} (Normalized)'
            else:
                unit = self.feature_units[feat_idx]
                y_title = f'{self.feature_names[feat_idx]} ({unit})'

            # Create a more informative legend for uncertainty
            legend_text = (
                "Prediction with Confidence Bands<br>"
                "• Blue points: Observed values with uncertainty<br>"
                "• Green line: Predicted values<br>"
                # "• Shaded area: Prediction uncertainty<br>"
                "• Error bars: Observation uncertainty (larger = lower weight)"
            )

            fig.update_layout(
                title=dict(
                    text=f'Feature: {self.feature_names[feat_idx]}',
                    x=0.5,
                    y=0.95
                ),
                xaxis_title='Time',
                yaxis_title=y_title,
                height=400,
                showlegend=True,
                dragmode='select',
                annotations=[
                    *annotations,
                    # dict(
                    #     x=1.15,
                    #     y=0.5,
                    #     xref="paper",
                    #     yref="paper",
                    #     text=legend_text,
                    #     showarrow=False,
                    #     align="left",
                    #     bordercolor="black",
                    #     borderwidth=1,
                    #     borderpad=4,
                    #     bgcolor="white",
                    # )
                ],
                margin=dict(r=200)  # Add right margin for legend
            )

            figures.append(fig)

        return figures

    def update_scaling(self,
                       revenue_scale: float,
                       download_scale: float,
                       au_scale: float,
                       show_normalized: bool) -> Tuple[List[go.Figure], Dict]:
        """Update the scaling parameters and redraw plots"""
        if self.latest_sample is None:
            return [], {}

        # Update scales
        self.feature_scales = {
            0: revenue_scale,
            1: download_scale,
            2: au_scale
        }
        self.show_normalized = show_normalized

        # Calculate metrics
        metrics = {
            'show_normalized': self.show_normalized
        }
        for feat_idx in range(self.feature_dim):
            total = np.sum(self.latest_sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
            metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)

        # Update plots
        figures = self.create_plot(
            self.latest_sample,
            self.latest_observed_points,
            self.latest_observed_mask,
            self.latest_gradient_control_signal,
            metrics
        )

        return figures, metrics

    def parse_data_points(self, df) -> Dict:
        """Parse data points from DataFrame with columns: time,feature,value"""
        data_dict = {}
        if df is None or df.empty:
            return data_dict

        for _, row in df.iterrows():
            # Skip if any required value is NaN
            if pd.isna(row['time']) or pd.isna(row['feature']) or pd.isna(row['value']):
                continue
            try:
                time_idx = int(row['time'])
                feature_idx = int(row['feature'])
                value = float(row['value'])

                if time_idx not in data_dict:
                    data_dict[time_idx] = {}
                data_dict[time_idx][feature_idx] = (value, 1.0)
            except (ValueError, TypeError):
                continue
        return data_dict

    def parse_point_groups(self, df) -> Dict:
        """Parse point groups from DataFrame with columns: start,end,interval,feature,value,weight"""
        data_dict = {}
        if df is None or df.empty:
            return data_dict

        for _, row in df.iterrows():
            # Skip if any required value is NaN
            if pd.isna(row['start']) or pd.isna(row['end']) or pd.isna(row['interval']) or \
               pd.isna(row['feature']) or pd.isna(row['value']):
                continue

            try:
                start = int(row['start'])
                end = int(row['end'])
                interval = int(row['interval'])
                feature = int(row['feature'])
                value = float(row['value'])
                weight = float(row.get('weight', 1.0)) if not pd.isna(row.get('weight')) else 1.0

                for t in range(start, end + 1, interval):
                    if 0 <= t < self.seq_length:
                        if t not in data_dict:
                            data_dict[t] = {}
                        data_dict[t][feature] = (value, weight)
            except (ValueError, TypeError):
                continue

        return data_dict

    def to_tensor(self, observed_points_dict, seq_length, feature_dim):
        observed_points = torch.zeros((seq_length, feature_dim))
        observed_weights = torch.zeros((seq_length, feature_dim))

        for seq, feature_dict in observed_points_dict.items():
            for feature, (value, weight) in feature_dict.items():
                observed_points[seq, feature] = value
                observed_weights[seq, feature] = weight

        return observed_points, observed_weights

    def apply_direct_edits(self, sample: np.ndarray, edit_params: Dict) -> np.ndarray:
        """Apply direct edits to the sample array"""
        edited_sample = sample.copy()

        if edit_params.get("enable_direct_area"):
            areas = self.parse_area_selections(edit_params["direct_areas"])
            for area in areas:
                start, end, feat_idx, target = area
                edited_sample[start:end, feat_idx] += target
            edited_sample = np.clip(edited_sample, 0, 1)
        return edited_sample

    def parse_area_selections(self, area_text: str) -> List[Tuple]:
        """Parse area selection text into (start, end, feature, target) tuples"""
        areas = []
        if not area_text.strip():
            return areas

        area_text = area_text.replace('\n', ';')

        for line in area_text.strip().split(';'):
            if not line.strip():
                continue
            try:
                start, end, feat, target = map(float, line.strip().split(','))
                areas.append((int(start), int(end), int(feat), target))
            except (ValueError, IndexError):
                continue
        return areas

    def apply_trending_mask(self, points: torch.Tensor, mask: torch.Tensor, consider_last_generated=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply trending functions as soft constraints through masks"""
        if not self.trending_controls or self.latest_sample is None:
            return points, mask

        for start, end, feat_idx, func, confidence in self.trending_controls:
            if start < 0 or end > self.seq_length or start >= end:
                continue

            # Generate x values normalized between 0 and 1 for the segment
            x = np.linspace(0, 1, end - start)

            try:
                # Calculate the function values
                y = func(x)

                # Scale the function output to 0-1 range
                y = (y - np.min(y)) / (np.max(y) - np.min(y))
                # points[start:end, feat_idx] = torch.tensor(y, dtype=points.dtype)
                # mask[start:end, feat_idx] = max(mask[start:end, feat_idx], min(1.0, confidence * abs(
                #     self.latest_sample[start:end, feat_idx] - y
                # )))  # Use lower weight for trending constraints

            except Exception as e:
                print(f"Error applying function: {e}")
                continue

            # Apply the trend as soft constraints
            mask_zero = (mask[start:end, feat_idx] == 0)
            points[start:end, feat_idx][mask_zero] = torch.tensor(y, dtype=points.dtype)[mask_zero]
            mask[start:end, feat_idx][mask_zero] = torch.tensor(confidence * np.ones_like(y), dtype=mask.dtype)[mask_zero]

            # mask[start:end, feat_idx][mask_zero] = torch.tensor((confidence * np.abs(self.latest_sample[start:end, feat_idx] - y)), dtype=mask.dtype)[mask_zero]
        mask = mask.clamp(0, 1)

        return points, mask


    def update_model(self,
                     figures: List[go.Figure],
                     data_points: str,
                     point_groups: str,
                     enable_area_control: bool,
                     area_selections: str,
                     enable_auc: bool,
                     auc_value: float,
                     enable_peaks: bool,
                     peak_points: str,
                     peak_alpha: float,
                     auc_weight: float,
                     peak_weight: float,
                     enable_trending: bool = False,
                     enable_trending_with_diff: bool = False,
                     trending_params: str = ""
                     ) -> Tuple[List[go.Figure], str, str, Dict]:

        # Parse both point groups and individual data points
        individual_points_dict = self.parse_data_points(data_points)
        group_points_dict = self.parse_point_groups(point_groups)

        # Merge dictionaries, giving preference to individual points
        combined_points_dict = group_points_dict.copy()
        for t, feat_dict in individual_points_dict.items():
            if t not in combined_points_dict:
                combined_points_dict[t] = {}
            for f, v in feat_dict.items():
                combined_points_dict[t][f] = v

        # Convert to tensor
        observed_points, observed_weights = self.to_tensor(
            combined_points_dict,
            self.seq_length,
            self.feature_dim
        )
        observed_mask = observed_weights

        # Parse peak points
        peak_points_list = []
        if enable_peaks and peak_points:
            try:
                peak_points_list = [int(x.strip()) for x in peak_points.split(',') if x.strip()]
            except ValueError:
                peak_points_list = []

        # Apply trending control if enabled
        if enable_trending and trending_params:
            self.parse_trending_parameters(trending_params)
            observed_points, observed_mask = self.apply_trending_mask(observed_points, observed_mask, consider_last_generated=enable_trending_with_diff)

        # Build gradient control signal
        # IMPORTANT
        gradient_control_signal = {}
        if enable_auc:
            gradient_control_signal["auc"] = auc_value
            gradient_control_signal["auc_weight"] = auc_weight
        if enable_peaks:
            gradient_control_signal.update({
                "peak_points": peak_points_list,
                "peak_alpha": peak_alpha,
                "peak_weight": peak_weight
            })

        # Build model control signal
        model_control_signal = {}
        # if enable_area_control and area_selections:
        #     areas = self.parse_area_selections(area_selections)
        #     if areas:
        #         model_control_signal["selected_areas"] = areas

        # Run prediction
        sample = self.trainer.predict_weighted_points(
            observed_points,      # (seq_length, feature_dim)
            observed_mask,        # (seq_length, feature_dim)
            self.coef,            # fixed
            self.stepsize,        # fixed
            self.sampling_steps,  # fixed
            # model_control_signal=model_control_signal,
            gradient_control_signal=gradient_control_signal
        )

        # Store latest results
        self.latest_sample = sample
        self.latest_observed_points = observed_points
        self.latest_observed_mask = observed_mask
        self.latest_gradient_control_signal = gradient_control_signal
        self.latest_model_control_signal = model_control_signal

        # Calculate metrics
        metrics = {
            'show_normalized': self.show_normalized
        }
        for feat_idx in range(self.feature_dim):
            total = np.sum(sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
            metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)

        # Update plots
        figures = self.create_plot(sample, observed_points, observed_mask, gradient_control_signal, metrics)

        return figures, data_points, point_groups, metrics


    def update_additional_edit(
            self,
            enable_direct_area: bool,
            direct_areas: str):
        # Apply direct edits if enabled
        if enable_direct_area:
            sample = self.apply_direct_edits(self.latest_sample, {
                "enable_direct_area": enable_direct_area,
                "direct_areas": direct_areas
            })
        else:
            sample = self.latest_sample

        # Calculate metrics
        metrics = {
            'show_normalized': self.show_normalized
        }
        for feat_idx in range(self.feature_dim):
            total = np.sum(sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
            metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)

        # Update plots
        figures = self.create_plot(
            sample,
            self.latest_observed_points,
            self.latest_observed_mask,
            self.latest_gradient_control_signal,
            metrics
        )

        return figures, metrics


    def apply_frequency_filter(self, signal: np.ndarray) -> np.ndarray:
        """Apply FFT-based frequency filtering using the current band multipliers"""
        # Get FFT of the signal
        fft = np.fft.fft(signal)
        freqs = np.fft.fftfreq(len(signal))

        # Split frequencies into 5 bands
        # Exclude DC component (0 frequency) from bands
        pos_freqs = freqs[1:len(freqs)//2]
        freq_ranges = np.array_split(pos_freqs, 5)

        # Apply band multipliers
        filtered_fft = fft.copy()

        # Handle DC component separately (lowest frequency)
        filtered_fft[0] *= self.freq_bands[4]  # Apply very low freq multiplier to DC

        # Apply multipliers to each frequency band
        for i, freq_range in enumerate(freq_ranges):
            # Get indices for this frequency band
            band_mask = np.logical_and(
                freqs >= freq_range[0],
                freqs <= freq_range[-1]
            )

            # Apply multiplier to positive and negative frequencies symmetrically
            filtered_fft[band_mask] *= self.freq_bands[4-i]
            filtered_fft[np.flip(band_mask)] *= self.freq_bands[4-i]

        # Convert back to time domain
        filtered_signal = np.real(np.fft.ifft(filtered_fft))

        return filtered_signal


    def update_frequency_bands(self, band_idx: int, value: float) -> Tuple[List[go.Figure], Dict]:
        """Update a frequency band multiplier and recompute the filtered signal"""
        if self.latest_sample is None:
            return [], {}

        # Update the specified band multiplier
        self.freq_bands[band_idx] = value

        # Apply frequency filtering to each feature
        filtered_sample = self.latest_sample.copy()
        for feat_idx in range(self.feature_dim):
            filtered_sample[:, feat_idx] = self.apply_frequency_filter(
                self.latest_sample[:, feat_idx]
            )

        # Ensure values remain in valid range
        filtered_sample = np.clip(filtered_sample, 0, 1)

        # Calculate metrics
        metrics = {
            'show_normalized': self.show_normalized,
            'frequency_bands': self.freq_bands.tolist()
        }
        for feat_idx in range(self.feature_dim):
            total = np.sum(filtered_sample[:, feat_idx]) * (self.feature_scales[feat_idx] * 10 if not self.show_normalized else 1)
            metrics[f'total_{self.feature_names[feat_idx]}'] = self.format_value(total, feat_idx)

        # Update plots
        figures = self.create_plot(
            filtered_sample,
            self.latest_observed_points,
            self.latest_observed_mask,
            self.latest_gradient_control_signal,
            metrics
        )

        return figures, metrics

    def parse_trending_parameters(self, trending_text: str) -> List[Tuple]:
        """Parse trending control parameters into (start, end, feature, function) tuples"""
        trending_params = []
        if not trending_text.strip():
            return trending_params

        trending_text = trending_text.replace('\n', ';')

        for line in trending_text.strip().split(';'):
            if not line.strip():
                continue
            try:
                # Split by comma and handle the function part separately
                parts = line.strip().split(',', 4)
                if len(parts) != 5:
                    continue

                start, end, feat = map(int, parts[:3])
                function_str = parts[3].strip()
                confidence = float(parts[4])
                # Convert the function string to a callable
                try:
                    func = self.function_parser.string_to_function(function_str)
                    trending_params.append((start, end, feat, func, confidence))
                except ValueError as e:
                    print(f"Error parsing function '{function_str}': {e}")
                    continue

            except (ValueError, IndexError):
                continue
        self.trending_controls = trending_params  # Store the parsed parameters
        return trending_params


def create_gradio_interface(editor: TimeSeriesEditor):
    with gr.Blocks() as app:
        gr.Markdown("# Time Series Editor")
        gr.Markdown("## Instruction: Scroll Down + Click `Update Figure` [~20s]")

        metrics_display = gr.JSON(label="Metrics", value={})

        with gr.Row():
            with gr.Column(scale=1):
                # with Tab():
                # Scaling Parameters Section
                # with gr.Group():

                gr.Markdown("## Scaling Parameters")
                with gr.Accordion("Open for More Detail", open=False):
                    revenue_scale = gr.Number(
                        label="Revenue Scale ($ per 0.1 in model)",
                        value=1000000
                    )
                    download_scale = gr.Number(
                        label="Download Scale (downloads per 0.1 in model)",
                        value=100000
                    )
                    au_scale = gr.Number(
                        label="Active Users Scale (users per 0.1 in model)",
                        value=10000
                    )
                    show_normalized = gr.Checkbox(
                        label="Show Normalized Values (0-1 scale)",
                        value=True
                    )
                    update_scaling_btn = gr.Button("Update Scaling")

                # TS Section
                gr.Markdown("## Time Series Control Panel")
                with gr.Accordion("Open for More Detail"):
                    with gr.Group():
                        gr.Markdown("### Anchor Point Control")
                        data_points_df = gr.Dataframe(
                            headers=["time", "feature", "value"],
                            datatype=["number", "number", "number"],
                            # label="Anchor Point Control",
                            value=[[0, 0, 0.04], [2, 0, 0.58], [6, 0, 0.27], [58, 0, 1.0], [60, 0, 0.5]],
                            col_count=(3, "fixed"),  # Fix number of columns
                            interactive=True
                        )
                        add_data_point_btn = gr.Button("Add Data Point")

                        def add_data_point(df):
                            new_row = pd.DataFrame([[None, 0, None]],
                                                   columns=["time", "feature", "value"])
                            return pd.concat([df, new_row], ignore_index=True)

                        add_data_point_btn.click(
                            fn=add_data_point,
                            inputs=[data_points_df],
                            outputs=[data_points_df]
                        )

                    with gr.Group():
                        gr.Markdown("### Group of Anchor Point Control")
                        point_groups_df = gr.Dataframe(
                            headers=["start", "end", "interval", "feature", "value", "weight"],
                            datatype=["number", "number", "number", "number", "number", "number"],
                            # label="Group of Anchor Point Control",
                            value=[[0, 50, 10, 0, 0.5, 0.1], [100, 150, 50, 0, 0.1, 0.5]],
                            col_count=(6, "fixed"),  # Fix number of columns
                            interactive=True
                        )
                        add_point_group_btn = gr.Button("Add Point Group")

                        def add_point_group(df):
                            new_row = pd.DataFrame([[None, None, None, 0, None, None]],
                                                   columns=["start", "end", "interval", "feature", "value", "weight"])
                            return pd.concat([df, new_row], ignore_index=True)

                        add_point_group_btn.click(
                            fn=add_point_group,
                            inputs=[point_groups_df],
                            outputs=[point_groups_df]
                        )

                    with gr.Group():
                        # with gr.Tab("Trending Control"):
                        gr.Markdown("### Trending Control")
                        gr.Markdown("""
                        Enter trending control parameters in the format:
                        ```
                        start_time,end_time,feature,function,confidence
                        ```
                        Examples:
                        - Linear trend: `0,100,0,x`
                        - Sine wave: `0,100,0,sin(2*pi*x)`
                        - Exponential: `0,100,0,exp(-x)`

                        Separate multiple trends with semicolons.
                        """)
                        enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=True)
                        enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False)
                        trending_control = gr.Textbox(
                            label="Trending Control Parameters",
                            lines=2,
                            placeholder="Enter parameters: start_time,end_time,feature,function,confidence; separated by semicolons",
                            value="200,250,0,sin(2*pi*x),0.2"
                        )

                    # Area Control Parameters
                    with gr.Group(visible=False):
                        gr.Markdown("### Area Control")
                        enable_area_control = gr.Checkbox(label="Enable Area Control", value=False)
                        area_selections = gr.Textbox(
                            label="Area Selections (format: start_time,end_time,feature,target_value)",
                            lines=2,
                            placeholder="Enter areas: start,end,feature,target; separated by semicolons",
                        )

                    # AUC Parameters
                    gr.Markdown("### Statistics Control")
                    enable_auc = gr.Checkbox(label="Enable Total Sum Control", value=True)
                    auc_input = gr.Number(label="Target Sum Value", value=-150)
                    auc_weight_input = gr.Number(label="Sum Weight", value=10.0)

                    # Peak Parameters
                    with gr.Group(visible=False):
                        gr.Markdown("### Peak Control")
                        enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
                        peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
                        peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
                        peak_weight_input = gr.Number(label="Peak Weight", value=1.0)

                    update_model_btn = gr.Button("Update Figure")

                gr.Markdown("## Extend Edit", visible=False)
                with gr.Tab("Range Shift", visible=False):
                    # gr.Markdown("### Direct Edit Control")
                    enable_direct_area = gr.Checkbox(label="Enable Direct Edits", value=False)  # range shift
                    direct_areas = gr.Textbox(
                        label="Direct Edit Areas (format: start_time,end_time,feature,delta)",
                        lines=2,
                        placeholder="Enter areas: start,end,feature,delta; separated by semicolons",
                        value="150,200,0,-0.1"
                    )

                    update_additional_btn = gr.Button("Update Additional Edit")

                # with gr.Tab("Trending Control"):
                #     gr.Markdown("### Trending Control")
                #     gr.Markdown("""
                #     Enter trending control parameters in the format:
                #     ```
                #     start_time,end_time,feature,function
                #     ```
                #     Examples:
                #     - Linear trend: `0,100,0,x`
                #     - Sine wave: `0,100,0,sin(2*pi*x)`
                #     - Exponential: `0,100,0,exp(-x)`

                #     Separate multiple trends with semicolons.
                #     """)
                #     enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=False)
                #     enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False)
                #     trending_control = gr.Textbox(
                #         label="Trending Control Parameters",
                #         lines=2,
                #         placeholder="Enter parameters: start_time,end_time,feature,function,confidence; separated by semicolons",
                #         value="0,100,0,sin(2*pi*x),0.3"
                #     )

                # with gr.Tab("Frequency Controls", visible=False):
                with gr.Group(visible=False):
                    gr.Markdown("Adjust multipliers for different frequency bands (0-2)")
                    freq_bands = [
                        gr.Slider(
                            minimum=0, maximum=2, step=0.1, value=1.0,
                            label=f"Band {i+1}: {'Very High' if i==0 else 'High' if i==1 else 'Mid' if i==2 else 'Low' if i==3 else 'Very Low'} Freq",
                        ) for i in range(5)
                    ]

                gr.Markdown("### Feature Index Reference:")
                for idx, name in enumerate(editor.feature_names):
                    gr.Markdown(f"- {idx}: {name}")

            with gr.Column(scale=1.2):
                gr.Markdown("""
                ### Plot Legend
                - **Points with Error Bars**: Observed values where:
                  - Point position = observed value
                  - Error bar size = uncertainty (inversely proportional to weight)
                - **Green Line**: Model prediction
                - **Vertical Red Lines**: Peak points (if enabled)
                """)
                plots = [gr.Plot() for _ in range(editor.feature_dim)]
                # - **Shaded Area**: General prediction uncertainty

        def update_scaling_callback(revenue_scale, download_scale, au_scale, show_normalized):
            figs, metrics = editor.update_scaling(
                revenue_scale,
                download_scale,
                au_scale,
                show_normalized
            )
            return [*figs, metrics]

        def update_model_callback(
                data_points_df,
                point_groups_df,
                enable_area_control,
                area_selections,
                enable_auc,
                auc,
                auc_weight,
                enable_peaks,
                peak_points,
                peak_alpha,
                peak_weight,
                enable_trending,
                enable_trending_with_diff,
                trending_params
        ):
            figs, _, _, metrics = editor.update_model(
                plots,
                data_points_df,
                point_groups_df,
````
+
enable_area_control,
|
835 |
+
area_selections,
|
836 |
+
enable_auc,
|
837 |
+
auc,
|
838 |
+
enable_peaks,
|
839 |
+
peak_points,
|
840 |
+
peak_alpha,
|
841 |
+
auc_weight,
|
842 |
+
peak_weight,
|
843 |
+
enable_trending,
|
844 |
+
enable_trending_with_diff,
|
845 |
+
trending_params
|
846 |
+
)
|
847 |
+
return [*figs, metrics]
|
848 |
+
|
849 |
+
# Update the click handler
|
850 |
+
update_model_btn.click(
|
851 |
+
fn=update_model_callback,
|
852 |
+
inputs=[
|
853 |
+
data_points_df,
|
854 |
+
point_groups_df,
|
855 |
+
enable_area_control,
|
856 |
+
area_selections,
|
857 |
+
enable_auc,
|
858 |
+
auc_input,
|
859 |
+
auc_weight_input,
|
860 |
+
enable_peaks,
|
861 |
+
peak_points_input,
|
862 |
+
peak_alpha_input,
|
863 |
+
peak_weight_input,
|
864 |
+
enable_trending_control,
|
865 |
+
enable_trending_control_with_diff,
|
866 |
+
trending_control
|
867 |
+
],
|
868 |
+
outputs=[*plots, metrics_display]
|
869 |
+
)
|
870 |
+
|
871 |
+
|
872 |
+
def update_additional_callback(enable_direct_area, direct_areas):
|
873 |
+
figs, metrics = editor.update_additional_edit(
|
874 |
+
enable_direct_area,
|
875 |
+
direct_areas
|
876 |
+
)
|
877 |
+
return [*figs, metrics]
|
878 |
+
|
879 |
+
def update_freq_band(band_idx, value):
|
880 |
+
figs, metrics = editor.update_frequency_bands(band_idx, value)
|
881 |
+
return [*figs, metrics]
|
882 |
+
|
883 |
+
update_scaling_btn.click(
|
884 |
+
fn=update_scaling_callback,
|
885 |
+
inputs=[
|
886 |
+
revenue_scale,
|
887 |
+
download_scale,
|
888 |
+
au_scale,
|
889 |
+
show_normalized
|
890 |
+
],
|
891 |
+
outputs=[*plots, metrics_display]
|
892 |
+
)
|
893 |
+
|
894 |
+
update_additional_btn.click(
|
895 |
+
fn=update_additional_callback,
|
896 |
+
inputs=[enable_direct_area, direct_areas],
|
897 |
+
outputs=[*plots, metrics_display]
|
898 |
+
)
|
899 |
+
|
900 |
+
# Add event handlers for frequency band sliders
|
901 |
+
for i, slider in enumerate(freq_bands):
|
902 |
+
slider.change(
|
903 |
+
fn=update_freq_band,
|
904 |
+
inputs=[gr.Number(value=i, visible=False), slider],
|
905 |
+
outputs=[*plots, metrics_display]
|
906 |
+
)
|
907 |
+
|
908 |
+
# app.load(
|
909 |
+
# fn=update_model_callback,
|
910 |
+
# inputs=[
|
911 |
+
# data_points_df,
|
912 |
+
# point_groups_df,
|
913 |
+
# enable_area_control,
|
914 |
+
# area_selections,
|
915 |
+
# enable_auc,
|
916 |
+
# auc_input,
|
917 |
+
# auc_weight_input,
|
918 |
+
# enable_peaks,
|
919 |
+
# peak_points_input,
|
920 |
+
# peak_alpha_input,
|
921 |
+
# peak_weight_input,
|
922 |
+
# enable_trending_control,
|
923 |
+
# enable_trending_control_with_diff,
|
924 |
+
# trending_control
|
925 |
+
# ],
|
926 |
+
# outputs=[*plots, metrics_display]
|
927 |
+
# )
|
928 |
+
|
929 |
+
return app
|
930 |
+
|
931 |
+
|
932 |
+
class FunctionParser:
|
933 |
+
def __init__(self):
|
934 |
+
# Define available mathematical functions and constants
|
935 |
+
self.math_functions = {
|
936 |
+
'sin': np.sin,
|
937 |
+
'cos': np.cos,
|
938 |
+
'tan': np.tan,
|
939 |
+
'exp': np.exp,
|
940 |
+
'log': np.log,
|
941 |
+
'sqrt': np.sqrt,
|
942 |
+
'abs': np.abs,
|
943 |
+
'pow': np.power,
|
944 |
+
'pi': np.pi,
|
945 |
+
'e': np.e,
|
946 |
+
'asin': np.arcsin,
|
947 |
+
'acos': np.arccos,
|
948 |
+
'atan': np.arctan,
|
949 |
+
'sinh': np.sinh,
|
950 |
+
'cosh': np.cosh,
|
951 |
+
'tanh': np.tanh
|
952 |
+
}
|
953 |
+
|
954 |
+
def validate_expression(self, expression: str) -> bool:
|
955 |
+
"""
|
956 |
+
Validate the mathematical expression for basic syntax errors.
|
957 |
+
"""
|
958 |
+
# Check for balanced parentheses
|
959 |
+
if expression.count('(') != expression.count(')'):
|
960 |
+
raise ValueError("Unbalanced parentheses in expression")
|
961 |
+
|
962 |
+
# Check for invalid characters
|
963 |
+
valid_chars = set('0123456789.+-*/()^ xXepi,')
|
964 |
+
valid_chars.update(''.join(self.math_functions.keys()))
|
965 |
+
if not all(c in valid_chars or c.isspace() for c in expression.lower()):
|
966 |
+
raise ValueError("Expression contains invalid characters")
|
967 |
+
|
968 |
+
return True
|
969 |
+
|
970 |
+
def preprocess_expression(self, expression: str) -> str:
|
971 |
+
"""
|
972 |
+
Preprocess the expression to handle various input formats.
|
973 |
+
"""
|
974 |
+
# Remove whitespace
|
975 |
+
expression = expression.replace(' ', '')
|
976 |
+
|
977 |
+
# Convert ^ to ** for exponentiation
|
978 |
+
expression = expression.replace('^', '**')
|
979 |
+
|
980 |
+
# Ensure multiplication is explicit
|
981 |
+
expression = re.sub(r'(\d+)([a-zA-Z])', r'\1*\2', expression)
|
982 |
+
expression = re.sub(r'(\))([\w])', r'\1*\2', expression)
|
983 |
+
|
984 |
+
# Replace X with x for consistency
|
985 |
+
expression = expression.lower()
|
986 |
+
|
987 |
+
return expression
|
988 |
+
|
989 |
+
def string_to_function(self, expression: str) -> Callable[[Union[float, np.ndarray]], Union[float, np.ndarray]]:
|
990 |
+
"""
|
991 |
+
Convert a string mathematical expression to a callable function.
|
992 |
+
|
993 |
+
Args:
|
994 |
+
expression (str): Mathematical expression (e.g., "sin(x) + x^2")
|
995 |
+
|
996 |
+
Returns:
|
997 |
+
Callable: A function that takes x as input and returns the evaluated result
|
998 |
+
|
999 |
+
Example:
|
1000 |
+
>>> f = string_to_function("sin(x) + x^2")
|
1001 |
+
>>> f(0.5)
|
1002 |
+
0.729321...
|
1003 |
+
"""
|
1004 |
+
# Validate and preprocess the expression
|
1005 |
+
self.validate_expression(expression)
|
1006 |
+
processed_expr = self.preprocess_expression(expression)
|
1007 |
+
|
1008 |
+
# Create the function namespace
|
1009 |
+
namespace = self.math_functions.copy()
|
1010 |
+
|
1011 |
+
try:
|
1012 |
+
# Create the lambda function
|
1013 |
+
func = eval(f"lambda x: {processed_expr}", namespace)
|
1014 |
+
|
1015 |
+
# Test the function with a simple input
|
1016 |
+
test_value = 1.0
|
1017 |
+
try:
|
1018 |
+
func(test_value)
|
1019 |
+
except Exception as e:
|
1020 |
+
raise ValueError(f"Invalid function: {str(e)}")
|
1021 |
+
|
1022 |
+
return func
|
1023 |
+
|
1024 |
+
except SyntaxError as e:
|
1025 |
+
raise ValueError(f"Invalid expression syntax: {str(e)}")
|
1026 |
+
except Exception as e:
|
1027 |
+
raise ValueError(f"Error creating function: {str(e)}")
|
1028 |
+
|
1029 |
+
@staticmethod
|
1030 |
+
def demonstrate_usage():
|
1031 |
+
"""
|
1032 |
+
Demonstrate various uses of the function parser.
|
1033 |
+
"""
|
1034 |
+
parser = FunctionParser()
|
1035 |
+
|
1036 |
+
# Test cases
|
1037 |
+
test_expressions = [
|
1038 |
+
"x^2 + 2*x + 1",
|
1039 |
+
"sin(x) + cos(x)",
|
1040 |
+
"exp(-x^2)",
|
1041 |
+
"log(x + 1)",
|
1042 |
+
"sqrt(1 - x^2)",
|
1043 |
+
]
|
1044 |
+
|
1045 |
+
print("Testing various mathematical expressions:")
|
1046 |
+
x_test = 0.5
|
1047 |
+
|
1048 |
+
for expr in test_expressions:
|
1049 |
+
try:
|
1050 |
+
print(f"\nExpression: {expr}")
|
1051 |
+
func = parser.string_to_function(expr)
|
1052 |
+
result = func(x_test)
|
1053 |
+
print(f"f({x_test}) = {result}")
|
1054 |
+
|
1055 |
+
# Test with numpy array
|
1056 |
+
x_array = np.linspace(0, 1, 5)
|
1057 |
+
results = func(x_array)
|
1058 |
+
print(f"f(array) = {results}")
|
1059 |
+
|
1060 |
+
except Exception as e:
|
1061 |
+
print(f"Error: {str(e)}")
|
1062 |
+
|
1063 |
+
# Example usage:
|
1064 |
+
if __name__ == "__main__":
|
1065 |
+
# Initialize with example data points
|
1066 |
+
# example_data_points = "0,0,0.04;2,0,0.58;6,0,0.27;58,0,1.0;-1,0,0.05"
|
1067 |
+
|
1068 |
+
import os
|
1069 |
+
import torch
|
1070 |
+
import numpy as np
|
1071 |
+
from engine.solver import Trainer
|
1072 |
+
from utils.io_utils import load_yaml_config, instantiate_from_config
|
1073 |
+
|
1074 |
+
# assert torch.cuda.is_available(), "CUDA must be available"
|
1075 |
+
class Parameters:
|
1076 |
+
def __init__(self) -> None:
|
1077 |
+
self.gpu = 0
|
1078 |
+
self.config_path = "./config/modified/revenue-baseline-365.yaml"
|
1079 |
+
# self.config_path = "config/modified/96/fmri.yaml"
|
1080 |
+
# self.config_path = "./config/control/revenue-baseline-sine.yaml"
|
1081 |
+
# self.save_dir = (
|
1082 |
+
# "../../../data/" + os.path.basename(self.config_path).split(".")[0]
|
1083 |
+
# )
|
1084 |
+
self.mode = "infill"
|
1085 |
+
self.missing_ratio = 0.95
|
1086 |
+
self.milestone = "10"
|
1087 |
+
# os.makedirs(self.save_dir, exist_ok=True)
|
1088 |
+
|
1089 |
+
os.environ["WANDB_ENABLED"] = "false"
|
1090 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
1091 |
+
# print working directory
|
1092 |
+
print(os.getcwd())
|
1093 |
+
args = Parameters()
|
1094 |
+
configs = load_yaml_config(args.config_path)
|
1095 |
+
# device = torch.device('cpu')
|
1096 |
+
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
|
1097 |
+
|
1098 |
+
# dl_info = build_dataloader_cond(configs, args)
|
1099 |
+
model = instantiate_from_config(configs["model"]).to(device)
|
1100 |
+
trainer = Trainer(config=configs, args=args, model=model, dataloader={
|
1101 |
+
"dataloader": []
|
1102 |
+
})
|
1103 |
+
|
1104 |
+
trainer.load(args.milestone, from_folder="./weight") #, from_folder="../../../data/ckpt_baseline_sine_240"), from_folder="./data/weight_365"
|
1105 |
+
# dataloader, dataset = dl_info["dataloader"], dl_info["dataset"]
|
1106 |
+
coef = configs["dataloader"]["test_dataset"]["coefficient"]
|
1107 |
+
stepsize = configs["dataloader"]["test_dataset"]["step_size"]
|
1108 |
+
sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
|
1109 |
+
seq_length = configs["dataloader"]["test_dataset"]["params"]["window"]
|
1110 |
+
feature_dim = 3
|
1111 |
+
print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")
|
1112 |
+
|
1113 |
+
# Initialize your trainer, configs, and dataset here
|
1114 |
+
editor = TimeSeriesEditor(seq_length, feature_dim, trainer)
|
1115 |
+
editor.coef = coef
|
1116 |
+
editor.stepsize = stepsize
|
1117 |
+
editor.sampling_steps = sampling_steps
|
1118 |
+
|
1119 |
+
app = create_gradio_interface(editor)
|
1120 |
+
# app.launch(server_name="0.0.0.0", server_port=8888, share=True)
|
1121 |
+
app.launch(show_api=False)
|
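For reference, here is a minimal sketch of how one trending-control entry such as `200,250,0,sin(2*pi*x),0.2` can be turned into a callable with the `FunctionParser` defined above. `parse_trend` is a hypothetical helper, not part of app.py, and the comma split assumes the function expression itself contains no commas (true for the examples shown in the UI):

```
import numpy as np

def parse_trend(entry: str):
    # "start_time,end_time,feature,function,confidence" -> typed tuple.
    # Assumes app.py's FunctionParser is in scope and the function
    # expression contains no commas of its own.
    start, end, feature, func_str, confidence = entry.split(",")
    func = FunctionParser().string_to_function(func_str)  # e.g. "sin(2*pi*x)"
    return int(start), int(end), int(feature), func, float(confidence)

start, end, feature, f, conf = parse_trend("200,250,0,sin(2*pi*x),0.2")
x = np.linspace(0, 1, end - start)  # normalized time axis over the range
trend_values = f(x)                 # the shape the generation is steered toward
```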
config/backup/revenue-1.yaml ADDED @@ -0,0 +1,76 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 240
    feature_size: 3
    n_layer_enc: 8
    n_layer_dec: 5
    d_model: 128  # 4 X 16
    timesteps: 1000  # diffusion timesteps
    sampling_timesteps: 200
    loss_type: 'l2'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 1385
  results_folder: ../../../data/Checkpoints_revenue-1
  gradient_accumulate_every: 1
  save_cycle: 277  # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 300
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      distribution: geometric

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 200

  batch_size: 64
  sample_size: 256
  shuffle: True
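Each of these YAML files follows the `target`/`params` convention that `instantiate_from_config` consumes in the `__main__` block of app.py above. The actual helper lives in `utils.io_utils` and is not shown in this diff; the following is only a sketch of the conventional implementation of that pattern, so the real function may differ:

```
import importlib

def instantiate_from_config_sketch(config: dict):
    # Resolve a dotted path such as "models.Tiffusion.tiffusion.Tiffusion"
    # to a class and construct it with the kwargs under "params".
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))
```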
config/backup/revenue-2.yaml ADDED @@ -0,0 +1,77 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 240
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 2770
  results_folder: ../../../data/Checkpoints_revenue-2
  gradient_accumulate_every: 1
  save_cycle: 277  # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 300
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 200

  batch_size: 64
  sample_size: 256
  shuffle: True
config/backup/revenue-3xl.yaml ADDED @@ -0,0 +1,77 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 480
    feature_size: 3
    n_layer_enc: 8
    n_layer_dec: 4
    d_model: 256  # 4 X 16
    timesteps: 2000  # diffusion timesteps
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 16
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 5540
  results_folder: ../../../data/Checkpoints_revenue-3xl
  gradient_accumulate_every: 1
  save_cycle: 554  # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 300
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 480  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 480  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 200

  batch_size: 64
  sample_size: 256
  shuffle: True
config/backup/revenue-baseline.yaml ADDED @@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 240
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230  # 11150
  results_folder: ../../../data/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223  # 1115 # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
config/backup/revenue-test.yaml ADDED @@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 240
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230  # 11150
  results_folder: ../../../data/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223  # 1115 # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 500

  batch_size: 64
  sample_size: 256
  shuffle: True
config/backup/revenue.yaml ADDED @@ -0,0 +1,76 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 240
    feature_size: 3
    n_layer_enc: 5
    n_layer_dec: 2
    d_model: 64  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 2770
  results_folder: ../../../data/Checkpoints_revenue
  gradient_accumulate_every: 2
  save_cycle: 554  # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 2000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 300
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.7  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.9  # rate
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 200

  batch_size: 64
  sample_size: 256
  shuffle: True
config/config.yaml ADDED @@ -0,0 +1,40 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 160
    feature_size: 5
    n_layer_enc: 1
    n_layer_dec: 2
    d_model: 64  # 4 X 16
    timesteps: 200
    sampling_timesteps: 200
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.1
    resid_pd: 0.1
    kernel_size: 5
    padding_size: 2

solver:
  base_lr: 1.0e-5
  max_epochs: 1000
  results_folder: ./Checkpoints_syn
  gradient_accumulate_every: 2
  save_cycle: 100  # max_epochs // 10
  ema:
    decay: 0.99
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False
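Note that `config.yaml` ends after the solver/scheduler section with no `dataloader` block, so it only works where a dataloader is supplied separately, as app.py does with `dataloader={"dataloader": []}`. The `load_yaml_config` helper used to read these files is likewise not part of this diff; a plausible minimal sketch on top of PyYAML (the real `utils.io_utils` function may do more):

```
import yaml

def load_yaml_config_sketch(path: str) -> dict:
    # Parse the config into the nested dict the app indexes into,
    # e.g. configs["dataloader"]["test_dataset"]["coefficient"].
    with open(path, "r") as f:
        return yaml.safe_load(f)
```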
config/control/revenue-baseline-180.yaml ADDED @@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 180
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230  # 11150
  results_folder: ../../../data/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223  # 1115 # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 180  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 180  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
config/control/revenue-baseline-365-ma.yaml ADDED @@ -0,0 +1,83 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 365
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 200
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    moving_average: True
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230  # 11150
  results_folder: ../../../data/ckpt_ma
  gradient_accumulate_every: 2
  save_cycle: 223  # 1115 # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 365  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 365  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
config/control/revenue-baseline-365.yaml ADDED @@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 365
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 200
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230  # 11150
  results_folder: ../../../data/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223  # 1115 # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 365  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 365  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
config/control/revenue-baseline-sine.yaml ADDED @@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 240
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 1.0e-5
  max_epochs: 223  # 11150
  results_folder: ../../../data/ckpt_baseline_sine
  gradient_accumulate_every: 2
  save_cycle: 223  # 1115 # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 20
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 10
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
config/control/revenue-extend.yaml ADDED @@ -0,0 +1,83 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 240
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control
    # - regression-based-sum-control

solver:
  base_lr: 2.0e-5
  max_epochs: 11150
  results_folder: ../../../data/ckpt_baseline_extend
  gradient_accumulate_every: 2
  save_cycle: 1115  # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.ControlRevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 240  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
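The `control_signal` list is empty in all of these baselines; the commented entries (`classifier-based-sum-control`, `classifier-free-sum-control`, `range-wise-peak-control`, and in `revenue-extend` also `regression-based-sum-control`) name the optional guidance modes. A config can also be patched in code before instantiation rather than edited on disk; a hypothetical snippet reusing the imports and `device` from the `__main__` block of app.py above:

```
# Hypothetical usage: enable one of the guidance modes named in the
# config's comments; the actual effect depends on the Tiffusion code.
configs = load_yaml_config("./config/control/revenue-extend.yaml")
configs["model"]["params"]["control_signal"] = ["classifier-free-sum-control"]
model = instantiate_from_config(configs["model"]).to(device)
```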
config/csdi/energy.yaml ADDED @@ -0,0 +1,75 @@
model:
  target: models.CSDI.tiffusion.Tiffusion
  params:
    seq_length: 24
    feature_size: 28
    n_layer_enc: 4
    n_layer_dec: 3
    d_model: 96  # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []

solver:
  base_lr: 1.0e-3
  max_epochs: 25000
  results_folder: ../../../data/CSDI/Checkpoints_energy
  gradient_accumulate_every: 2
  save_cycle: 2500  # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 5000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 1.0  # Set to rate < 1 if training conditional generation
      data_root: ./data/energy_data.csv
      window: 24  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 0.9  # rate
      data_root: ./data/energy_data.csv
      window: 24  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/csdi/fmri.yaml ADDED @@ -0,0 +1,74 @@
model:
  target: models.CSDI.tiffusion.Tiffusion
  params:
    seq_length: 24
    feature_size: 50
    n_layer_enc: 4
    n_layer_dec: 4
    d_model: 96  # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 5
    padding_size: 2

solver:
  base_lr: 1.0e-3
  max_epochs: 15000
  results_folder: ../../../data/CSDI/Checkpoints_fmri
  gradient_accumulate_every: 2
  save_cycle: 1500  # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9  # Set to rate < 1 if training conditional generation
      data_root: ./data/fMRI
      window: 24  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9  # rate
      data_root: ./data/fMRI
      window: 24  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/csdi/revenue-baseline-365.yaml ADDED @@ -0,0 +1,82 @@
model:
  target: models.CSDI.tiffusion.Tiffusion
  params:
    seq_length: 365
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128  # 4 X 16
    timesteps: 500  # diffusion timesteps
    sampling_timesteps: 200
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-3
  max_epochs: 2230  # 11150
  results_folder: ../../../data/CSDI/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223  # 1115 # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 365  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8  # rate
      data_root: ../../../data/daily.csv
      window: 365  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
config/csdi/sines.yaml ADDED @@ -0,0 +1,73 @@
model:
  target: models.CSDI.tiffusion.Tiffusion
  params:
    seq_length: 24
    feature_size: 5
    n_layer_enc: 1
    n_layer_dec: 2
    d_model: 64  # 4 X 16
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []

solver:
  base_lr: 1.0e-3
  max_epochs: 12000
  results_folder: ../../../data/CSDI/Checkpoints_sine
  gradient_accumulate_every: 2
  save_cycle: 1200  # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 10000
      dim: 5
      window: 24  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 1000
      dim: 5
      window: 24  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      style: separate
      period: test
      distribution: geometric
    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 200

  batch_size: 128
  sample_size: 256
  shuffle: True
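The `distribution` key (`geometric` vs. `uniform`) selects how observed/missing masks are drawn for the conditional test split. The dataset classes themselves are not in this diff, so as an illustration only, a CSDI-style sketch of the two strategies: `uniform` drops points independently, while `geometric` produces contiguous missing runs whose expected missing fraction is the given ratio. This is not the repo's exact code:

```
import numpy as np

def make_mask(length: int, missing_ratio: float, distribution: str, rng=np.random):
    # 1 = observed, 0 = missing; illustrative sketch only.
    mask = np.ones(length, dtype=int)
    if distribution == "uniform":
        mask[rng.rand(length) < missing_ratio] = 0       # independent dropout
    elif distribution == "geometric":
        i = 0
        while i < length:
            gap = rng.geometric(1.0 - missing_ratio)     # contiguous missing run
            mask[i:i + gap] = 0
            i += gap + rng.geometric(missing_ratio)      # observed run
    return mask
```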
config/energy.yaml ADDED @@ -0,0 +1,74 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 24
    feature_size: 28
    n_layer_enc: 4
    n_layer_dec: 3
    d_model: 96  # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 25000
  results_folder: ./Checkpoints_energy
  gradient_accumulate_every: 2
  save_cycle: 2500  # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 5000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 1.0  # Set to rate < 1 if training conditional generation
      data_root: ./Data/datasets/energy_data.csv
      window: 24  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 0.9  # rate
      data_root: ./Data/datasets/energy_data.csv
      window: 24  # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
    coefficient: 1.0e-2
    step_size: 5.0e-2
    sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/etth.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 24
    feature_size: 7
    n_layer_enc: 3
    n_layer_dec: 2
    d_model: 64 # 4 X 16
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 9000
  results_folder: ./Checkpoints_etth
  gradient_accumulate_every: 2
  save_cycle: 1800 # max_epochs // 5
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 4000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: etth
      proportion: 1.0 # Set to rate < 1 if training conditional generation
      data_root: ./Data/datasets/ETTh.csv
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: etth
      proportion: 0.9 # rate
      data_root: ./Data/datasets/ETTh.csv
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 200

  batch_size: 128
  sample_size: 256
  shuffle: True
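All of these configs select `beta_schedule: 'cosine'`. Assuming this names the standard cosine noise schedule of Nichol & Dhariwal (2021), the betas can be derived as below; the exact clipping or scaling inside Diffusion_TS may differ, so treat this as an illustrative reference.

import numpy as np

def cosine_beta_schedule(timesteps, s=0.008):
    # Standard cosine schedule; illustrative reference, not the repo's code.
    x = np.linspace(0, timesteps, timesteps + 1)
    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0, 0.999)

betas = cosine_beta_schedule(500)  # timesteps: 500 in this config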
config/fmri.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 24
    feature_size: 50
    n_layer_enc: 4
    n_layer_dec: 4
    d_model: 96 # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 5
    padding_size: 2

solver:
  base_lr: 1.0e-5
  max_epochs: 15000
  results_folder: ./Checkpoints_fmri
  gradient_accumulate_every: 2
  save_cycle: 1500 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 1.0 # Set to rate < 1 if training conditional generation
      data_root: ./Data/datasets/fMRI
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9 # rate
      data_root: ./Data/datasets/fMRI
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
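Unlike the other datasets, the fMRI configs pair `kernel_size: 5` with `padding_size: 2`. For a stride-1 1-D convolution the output length is L + 2*padding - kernel + 1, so this combination preserves the sequence length, as the quick check below confirms (batch size chosen arbitrarily):

import torch
import torch.nn as nn

conv = nn.Conv1d(in_channels=50, out_channels=50, kernel_size=5, padding=2)
x = torch.randn(8, 50, 24)       # (batch, feature_size, seq_length)
assert conv(x).shape == x.shape  # 24 + 2*2 - 5 + 1 == 24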
config/modified/192/energy.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 192
    feature_size: 28
    n_layer_enc: 4
    n_layer_dec: 3
    d_model: 96 # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
solver:
  base_lr: 1.0e-5
  max_epochs: 25000
  results_folder: ../../../data/Checkpoints_energy
  gradient_accumulate_every: 2
  save_cycle: 2500 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 5000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 1.0 # Set to rate < 1 if training conditional generation
      data_root: ./data/energy_data.csv
      window: 192 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 0.9 # rate
      data_root: ./data/energy_data.csv
      window: 192 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/192/fmri.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 192
    feature_size: 50
    n_layer_enc: 4
    n_layer_dec: 4
    d_model: 96 # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 5
    padding_size: 2

solver:
  base_lr: 1.0e-5
  max_epochs: 15000
  results_folder: ../../../data/Checkpoints_fmri
  gradient_accumulate_every: 2
  save_cycle: 1500 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9 # Set to rate < 1 if training conditional generation
      data_root: ./data/fMRI
      window: 192 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9 # rate
      data_root: ./data/fMRI
      window: 192 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/192/revenue.yaml
ADDED
@@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 192
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128 # 8 X 16
    timesteps: 500 # diffusion timesteps
    sampling_timesteps: 200
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230 # 11150
  results_folder: ../../../data/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223 # 1115 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 192 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # rate
      data_root: ../../../data/daily.csv
      window: 192 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/192/sines.yaml
ADDED
@@ -0,0 +1,73 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 192
    feature_size: 5
    n_layer_enc: 1
    n_layer_dec: 2
    d_model: 64 # 4 X 16
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []

solver:
  base_lr: 1.0e-5
  max_epochs: 12000
  results_folder: ../../../data/Checkpoints_sine
  gradient_accumulate_every: 2
  save_cycle: 1200 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 10000
      dim: 5
      window: 192 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 1000
      dim: 5
      window: 192 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      style: separate
      period: test
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 200

  batch_size: 128
  sample_size: 256
  shuffle: True
config/modified/384/energy.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 384
    feature_size: 28
    n_layer_enc: 4
    n_layer_dec: 3
    d_model: 96 # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
solver:
  base_lr: 1.0e-5
  max_epochs: 25000
  results_folder: ../../../data/Checkpoints_energy
  gradient_accumulate_every: 2
  save_cycle: 2500 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 5000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 1.0 # Set to rate < 1 if training conditional generation
      data_root: ./data/energy_data.csv
      window: 384 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 0.9 # rate
      data_root: ./data/energy_data.csv
      window: 384 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/384/fmri.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 384
    feature_size: 50
    n_layer_enc: 4
    n_layer_dec: 4
    d_model: 96 # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 5
    padding_size: 2

solver:
  base_lr: 1.0e-5
  max_epochs: 15000
  results_folder: ../../../data/Checkpoints_fmri
  gradient_accumulate_every: 2
  save_cycle: 1500 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9 # Set to rate < 1 if training conditional generation
      data_root: ./data/fMRI
      window: 384 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9 # rate
      data_root: ./data/fMRI
      window: 384 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/384/revenue.yaml
ADDED
@@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 384
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128 # 8 X 16
    timesteps: 500 # diffusion timesteps
    sampling_timesteps: 200
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230 # 11150
  results_folder: ../../../data/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223 # 1115 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 384 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # rate
      data_root: ../../../data/daily.csv
      window: 384 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/384/sines.yaml
ADDED
@@ -0,0 +1,73 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 384
    feature_size: 5
    n_layer_enc: 1
    n_layer_dec: 2
    d_model: 64 # 4 X 16
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []

solver:
  base_lr: 1.0e-5
  max_epochs: 12000
  results_folder: ../../../data/Checkpoints_sine
  gradient_accumulate_every: 2
  save_cycle: 1200 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 10000
      dim: 5
      window: 384 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 1000
      dim: 5
      window: 384 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      style: separate
      period: test
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 200

  batch_size: 128
  sample_size: 256
  shuffle: True
config/modified/96/energy.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 96
    feature_size: 28
    n_layer_enc: 4
    n_layer_dec: 3
    d_model: 96 # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
solver:
  base_lr: 1.0e-5
  max_epochs: 25000
  results_folder: ../../../data/Checkpoints_energy
  gradient_accumulate_every: 2
  save_cycle: 2500 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 5000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 1.0 # Set to rate < 1 if training conditional generation
      data_root: ./data/energy_data.csv
      window: 96 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 0.9 # rate
      data_root: ./data/energy_data.csv
      window: 96 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/96/fmri.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 96
    feature_size: 50
    n_layer_enc: 4
    n_layer_dec: 4
    d_model: 96 # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 5
    padding_size: 2

solver:
  base_lr: 1.0e-5
  max_epochs: 15000
  results_folder: ../../../data/Checkpoints_fmri
  gradient_accumulate_every: 2
  save_cycle: 1500 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9 # Set to rate < 1 if training conditional generation
      data_root: ./data/fMRI
      window: 96 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9 # rate
      data_root: ./data/fMRI
      window: 96 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/96/revenue.yaml
ADDED
@@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 96
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128 # 8 X 16
    timesteps: 500 # diffusion timesteps
    sampling_timesteps: 200
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230 # 11150
  results_folder: ../../../data/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223 # 1115 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 96 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # rate
      data_root: ../../../data/daily.csv
      window: 96 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
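The revenue test sets replace `distribution: geometric` with `distribution: uniform` plus `missing_ratio: 0.5`. Assuming those keys control the imputation mask the way their names suggest, a uniform mask hides each entry independently with probability `missing_ratio`; the sketch below illustrates that reading and is not the repo's masking code.

import numpy as np

def uniform_mask(seq_length, feature_size, missing_ratio, rng):
    # 1 = observed, 0 = left for the model to infill.
    return (rng.random((seq_length, feature_size)) >= missing_ratio).astype(np.float32)

rng = np.random.default_rng(2024)     # seed: 2024 in these configs
mask = uniform_mask(96, 3, 0.5, rng)  # window 96, feature_size 3
print(mask.mean())                    # roughly 0.5 of entries observed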
config/modified/96/sines.yaml
ADDED
@@ -0,0 +1,73 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 96
    feature_size: 5
    n_layer_enc: 1
    n_layer_dec: 2
    d_model: 64 # 4 X 16
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []

solver:
  base_lr: 1.0e-5
  max_epochs: 12000
  results_folder: ../../../data/Checkpoints_sine
  gradient_accumulate_every: 2
  save_cycle: 1200 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 10000
      dim: 5
      window: 96 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 1000
      dim: 5
      window: 96 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      style: separate
      period: test
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 200

  batch_size: 128
  sample_size: 256
  shuffle: True
config/modified/energy.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 24
    feature_size: 28
    n_layer_enc: 4
    n_layer_dec: 3
    d_model: 96 # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
solver:
  base_lr: 1.0e-5
  max_epochs: 25000
  results_folder: ../../../data/Checkpoints_energy
  gradient_accumulate_every: 2
  save_cycle: 2500 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 5000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 1.0 # Set to rate < 1 if training conditional generation
      data_root: ./data/energy_data.csv
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: energy
      proportion: 0.9 # rate
      data_root: ./data/energy_data.csv
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/fmri.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 24
    feature_size: 50
    n_layer_enc: 4
    n_layer_dec: 4
    d_model: 96 # 4 X 24
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 5
    padding_size: 2

solver:
  base_lr: 1.0e-5
  max_epochs: 15000
  results_folder: ../../../data/Checkpoints_fmri
  gradient_accumulate_every: 2
  save_cycle: 1500 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9 # Set to rate < 1 if training conditional generation
      data_root: ./data/fMRI
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.fMRIDataset
    params:
      name: fMRI
      proportion: 0.9 # rate
      data_root: ./data/fMRI
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/revenue-baseline-365.yaml
ADDED
@@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 365
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128 # 8 X 16
    timesteps: 500 # diffusion timesteps
    sampling_timesteps: 200
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230 # 11150
  results_folder: ../../../data/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223 # 1115 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 365 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # rate
      data_root: ../../../data/daily.csv
      window: 365 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
config/modified/revenue.yaml
ADDED
@@ -0,0 +1,82 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 24
    feature_size: 3
    n_layer_enc: 6
    n_layer_dec: 4
    d_model: 128 # 8 X 16
    timesteps: 500 # diffusion timesteps
    sampling_timesteps: 200
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 8
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []
    # - classifier-based-sum-control
    # - classifier-free-sum-control
    # - range-wise-peak-control

solver:
  base_lr: 2.0e-5
  max_epochs: 2230 # 11150
  results_folder: ../../../data/ckpt_baseline
  gradient_accumulate_every: 2
  save_cycle: 223 # 1115 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.65
      patience: 200
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # Set to rate < 1 if training conditional generation
      # data_root: ./Data/datasets/stock_data.csv
      data_root: ../../../data/daily.csv
      window: 365 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.RevenueDataset
    params:
      name: revenue
      proportion: 0.8 # rate
      data_root: ../../../data/daily.csv
      window: 365 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 2024
      period: test
      style: separate
      # distribution: geometric
      distribution: uniform
      missing_ratio: 0.5

      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 100

  batch_size: 64
  sample_size: 256
  shuffle: True
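Note that this file pairs `seq_length: 24` with `window: 365` even though the window comment still reads `# seq_length`; the diff does not say whether that is intentional. A small check like the sketch below (paths assumed relative to the repo root) would surface such mismatches across the tree:

from pathlib import Path
import yaml

for path in sorted(Path("config/modified").rglob("*.yaml")):
    cfg = yaml.safe_load(path.read_text())
    seq_len = cfg["model"]["params"]["seq_length"]
    for split in ("train_dataset", "test_dataset"):
        window = cfg["dataloader"][split]["params"].get("window")
        if window is not None and window != seq_len:
            print(f"{path}: window={window} != seq_length={seq_len}")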
config/modified/sines.yaml
ADDED
@@ -0,0 +1,73 @@
model:
  target: models.Tiffusion.tiffusion.Tiffusion
  params:
    seq_length: 24
    feature_size: 5
    n_layer_enc: 1
    n_layer_dec: 2
    d_model: 64 # 4 X 16
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0
    control_signal: []

solver:
  base_lr: 1.0e-5
  max_epochs: 12000
  results_folder: ../../../data/Checkpoints_sine
  gradient_accumulate_every: 2
  save_cycle: 1200 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 10000
      dim: 5
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 1000
      dim: 5
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      style: separate
      period: test
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 200

  batch_size: 128
  sample_size: 256
  shuffle: True
config/mujoco.yaml
ADDED
@@ -0,0 +1,72 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 24
    feature_size: 14
    n_layer_enc: 3
    n_layer_dec: 2
    d_model: 64 # 4 X 16
    timesteps: 1000
    sampling_timesteps: 1000
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 14000
  results_folder: ./Checkpoints_mujoco
  gradient_accumulate_every: 2
  save_cycle: 1400 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.mujoco_dataset.MuJoCoDataset
    params:
      num: 10000
      dim: 14
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.mujoco_dataset.MuJoCoDataset
    params:
      num: 1000
      dim: 14
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      style: separate
      period: test
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 250

  batch_size: 128
  sample_size: 256
  shuffle: True
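`ema: {decay: 0.995, update_interval: 10}` appears in every solver block. Under the usual convention, a shadow copy of the weights is blended toward the live weights once per `update_interval` steps; the snippet below sketches that update rule (the repo's EMA class may add warmup or buffer handling):

import copy
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.995):
    # shadow <- decay * shadow + (1 - decay) * live
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)

model = torch.nn.Linear(4, 4)
ema_model = copy.deepcopy(model)
ema_update(ema_model, model)  # call once every update_interval training steps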
config/mujoco_sssd.yaml
ADDED
@@ -0,0 +1,40 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 100
    feature_size: 14
    n_layer_enc: 3
    n_layer_dec: 3
    d_model: 64 # 4 X 16
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0
    resid_pd: 0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 12000
  results_folder: ./Checkpoints_mujoco_sssd
  gradient_accumulate_every: 2
  save_cycle: 1200 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False
config/sines.yaml
ADDED
@@ -0,0 +1,72 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 24
    feature_size: 5
    n_layer_enc: 1
    n_layer_dec: 2
    d_model: 64 # 4 X 16
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 12000
  results_folder: ./Checkpoints_sine
  gradient_accumulate_every: 2
  save_cycle: 1200 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 3000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 10000
      dim: 5
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.sine_dataset.SineDataset
    params:
      num: 1000
      dim: 5
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      style: separate
      period: test
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 200

  batch_size: 128
  sample_size: 256
  shuffle: True
config/solar.yaml
ADDED
@@ -0,0 +1,40 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 192
    feature_size: 128
    n_layer_enc: 4
    n_layer_dec: 4
    d_model: 96 # 4 X 24
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 1500
  results_folder: ./Checkpoints_solar
  gradient_accumulate_every: 2
  save_cycle: 150 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 300
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False
config/solar_update.yaml
ADDED
@@ -0,0 +1,40 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 192
    feature_size: 137
    n_layer_enc: 4
    n_layer_dec: 4
    d_model: 96 # 4 X 24
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.5
    resid_pd: 0.5
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 1000
  results_folder: ./Checkpoints_solar_nips
  gradient_accumulate_every: 2
  save_cycle: 100 # max_epochs // 10
  ema:
    decay: 0.9
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 300
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 100
      verbose: False
config/stocks.yaml
ADDED
@@ -0,0 +1,74 @@
model:
  target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
  params:
    seq_length: 24
    feature_size: 6
    n_layer_enc: 2
    n_layer_dec: 2
    d_model: 64 # 4 X 16
    timesteps: 500
    sampling_timesteps: 500
    loss_type: 'l1'
    beta_schedule: 'cosine'
    n_heads: 4
    mlp_hidden_times: 4
    attn_pd: 0.0
    resid_pd: 0.0
    kernel_size: 1
    padding_size: 0

solver:
  base_lr: 1.0e-5
  max_epochs: 10000
  results_folder: ./Checkpoints_stock
  gradient_accumulate_every: 2
  save_cycle: 1000 # max_epochs // 10
  ema:
    decay: 0.995
    update_interval: 10

  scheduler:
    target: engine.lr_sch.ReduceLROnPlateauWithWarmup
    params:
      factor: 0.5
      patience: 2000
      min_lr: 1.0e-5
      threshold: 1.0e-1
      threshold_mode: rel
      warmup_lr: 8.0e-4
      warmup: 500
      verbose: False

dataloader:
  train_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: stock
      proportion: 1.0 # Set to rate < 1 if training conditional generation
      data_root: ./Data/datasets/stock_data.csv
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: train

  test_dataset:
    target: utils.data_utils.real_datasets.CustomDataset
    params:
      name: stock
      proportion: 0.9 # rate
      data_root: ./Data/datasets/stock_data.csv
      window: 24 # seq_length
      save2npy: True
      neg_one_to_one: True
      seed: 123
      period: test
      style: separate
      distribution: geometric
      coefficient: 1.0e-2
      step_size: 5.0e-2
      sampling_steps: 200

  batch_size: 64
  sample_size: 256
  shuffle: True
efficiency.py
ADDED
@@ -0,0 +1,319 @@
import argparse
import os
import pickle
import time
import warnings
from pathlib import Path

import numpy as np
import torch

# Disable Weights & Biases logging for these timing runs.
os.environ["WANDB_ENABLED"] = "false"

from engine.solver import Trainer
from data.build_dataloader import build_dataloader, build_dataloader_cond
from utils.io_utils import load_yaml_config, instantiate_from_config

warnings.simplefilter("ignore", UserWarning)


def load_cached_results(cache_dir):
    """Load any previously cached sampling results from `cache_dir`."""
    results = {"unconditional": None, "sum_controlled": {}, "anchor_controlled": {}}
    for cache_file in cache_dir.glob("*.pkl"):
        with open(cache_file, "rb") as f:
            key = cache_file.stem
            if key == "unconditional":
                results["unconditional"] = pickle.load(f)
            elif key.startswith("sum_"):
                param = key[4:]  # strip the 'sum_' prefix
                results["sum_controlled"][param] = pickle.load(f)
            elif key.startswith("anchor_"):
                param = key[7:]  # strip the 'anchor_' prefix
                results["anchor_controlled"][param] = pickle.load(f)
    return results


def save_result(cache_dir, key, subkey, data):
    # Caching is intentionally disabled for timing runs; remove this early
    # return to persist results again.
    return

    filename = f"{key}_{subkey}.pkl" if subkey else f"{key}.pkl"
    with open(cache_dir / filename, "wb") as f:
        pickle.dump(data, f)


class Arguments:
    """Mimics the argument object that the Trainer and dataloaders expect."""

    def __init__(self, config_path, gpu=0) -> None:
        self.config_path = config_path
        self.save_dir = (
            "../../../data/" + os.path.basename(self.config_path).split(".")[0]
        )
        self.gpu = gpu
        os.makedirs(self.save_dir, exist_ok=True)

        self.mode = "infill"
        self.missing_ratio = 0.95
        self.milestone = 10


def parse_args():
    parser = argparse.ArgumentParser(description="Controlled Sampling")
    parser.add_argument(
        "--config_path", type=str, default="./config/modified/energy.yaml"
    )
    parser.add_argument("--gpu", type=int, default=0)
    return parser.parse_args()


def run(run_args):
    args = Arguments(run_args.config_path, run_args.gpu)
    configs = load_yaml_config(args.config_path)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)

    dl_info = build_dataloader(configs, args)
    model = instantiate_from_config(configs["model"]).to(device)
    trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info)
    trainer.load("10")  # args.milestone
    dataset = dl_info["dataset"]
    test_dl_info = build_dataloader_cond(configs, args)
    test_dataloader, test_dataset = test_dl_info["dataloader"], test_dl_info["dataset"]
    coef = configs["dataloader"]["test_dataset"]["coefficient"]
    stepsize = configs["dataloader"]["test_dataset"]["step_size"]
    sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
    seq_length, feature_dim = test_dataset.window, test_dataset.var_num
    dataset_name = os.path.basename(args.config_path).split(".")[0].split("-")[0]
    mapper = {
        "sines": "sines",
        "revenue": "revenue",
        "energy": "energy",
        "fmri": "fMRI",
    }
    gap = seq_length // 5

    # Ground-truth windows and the masks saved alongside them during training.
    if seq_length in [96, 192, 384]:
        ori_data = np.load(
            os.path.join(
                "../../../data/train/",
                str(seq_length),
                dataset_name,
                "samples",
                f'{mapper[dataset_name].replace("sines", "sine")}_norm_truth_{seq_length}_train.npy',
            )
        )
        masks = np.load(
            os.path.join(
                "../../../data/train/",
                str(seq_length),
                dataset_name,
                "samples",
                f'{mapper[dataset_name].replace("sines", "sine")}_masking_{seq_length}.npy',
            )
        )
    else:
        ori_data = np.load(
            os.path.join(
                "../../../data/train/",
                dataset_name,
                "samples",
                f"{mapper[dataset_name]}_norm_truth_{seq_length}_train.npy",
            )
        )
        masks = np.load(
            os.path.join(
                "../../../data/train/",
                dataset_name,
                "samples",
                f"{mapper[dataset_name]}_masking_{seq_length}.npy",
            )
        )

    sample_num, _, _ = masks.shape
    # observed = ori_data[:sample_num] * masks
    ori_data = ori_data[:sample_num]

    sampling_size = min(1000, len(test_dataset), sample_num)
    batch_size = 500
    print(f"Sampling size: {sampling_size}, Batch size: {batch_size}")

    ### Cache file path
    cache_dir = Path(f"../../../data/cache/{dataset_name}_{seq_length}")
    cache_dir.mkdir(exist_ok=True)
    # results = load_cached_results(cache_dir)
    results = {"unconditional": None, "sum_controlled": {}, "anchor_controlled": {}}

    def measure_inference_time(func, *args, **kwargs):
        """Return `func(...)`'s result along with its wall-clock duration."""
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        return result, (end_time - start_time)

    timing_results = {}

    ### Unconditional sampling
    if results["unconditional"] is None:
        print("Generating unconditional data...")
        results["unconditional"], timing = measure_inference_time(
            trainer.control_sample,
            num=sampling_size,
            size_every=batch_size,
            shape=[seq_length, feature_dim],
            model_kwargs={
                "gradient_control_signal": {},
                "coef": coef,
                "learning_rate": stepsize,
            },
        )
        timing_results["unconditional"] = timing / sampling_size
        save_result(cache_dir, "unconditional", "", results["unconditional"])

    ### Different AUC values
    auc_weights = [10]
    auc_values = [-100, 20, 50, 150]  # -200, -150, -100, -50, 0, 20, 30, 50, 100, 150

    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key], timing = measure_inference_time(
                    trainer.control_sample,
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                timing_results[f"sum_controlled_{key}"] = timing / sampling_size
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])

    ### Different AUC weights
    auc_weights = [1, 10, 50, 100]
    auc_values = [-100]
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key], timing = measure_inference_time(
                    trainer.control_sample,
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                timing_results[f"sum_controlled_{key}"] = timing / sampling_size
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])

    ### Different AUC segments
    auc_weights = [10]
    auc_values = [150]
    auc_average = 10
    auc_segments = ((gap, 2 * gap), (2 * gap, 3 * gap), (3 * gap, 4 * gap))
    auc = auc_values[0]
    weight = auc_weights[0]
    for segment in auc_segments:
        key = f"auc_{auc}_weight_{weight}_segment_{segment[0]}_{segment[1]}"
        if key not in results["sum_controlled"]:
            print(
                f"Generating sum controlled data - AUC: {auc}, Weight: {weight}, Segment: {segment}"
            )
            results["sum_controlled"][key], timing = measure_inference_time(
                trainer.control_sample,
                num=sampling_size,
                size_every=batch_size,
                shape=[seq_length, feature_dim],
                model_kwargs={
                    "gradient_control_signal": {
                        "auc": auc_average * (segment[1] - segment[0]),  # / seq_length,
                        "auc_weight": weight,
                        "segment": [segment],
                    },
                    "coef": coef,
                    "learning_rate": stepsize,
                },
            )
            timing_results[f"sum_controlled_{key}"] = timing / sampling_size
            save_result(cache_dir, "sum", key, results["sum_controlled"][key])

    # Different anchors
    anchor_values = [-0.8, 0.6, 1.0]
    anchor_weights = [0.01, 0.01, 0.5, 1.0]
    for peak in anchor_values:
        for weight in anchor_weights:
            key = f"peak_{peak}_weight_{weight}"
            if key not in results["anchor_controlled"]:
                # Anchor every `gap` steps on the first feature: `target` holds
                # the desired value, `mask` holds the anchor weight.
                mask = np.zeros((seq_length, feature_dim), dtype=np.float32)
                mask[gap // 2 :: gap, 0] = weight
                target = np.zeros((seq_length, feature_dim), dtype=np.float32)
                target[gap // 2 :: gap, 0] = peak

                print(f"Anchor controlled data - Peak: {peak}, Weight: {weight}")
                results["anchor_controlled"][key], timing = measure_inference_time(
                    trainer.control_sample,
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {},  # "auc": -50, "auc_weight": 10.0},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                    target=target,
                    partial_mask=mask,
                )
                timing_results[f"anchor_controlled_{key}"] = timing / sampling_size
                save_result(cache_dir, "anchor", key, results["anchor_controlled"][key])

    ### Rerun unconditional sampling (no-op unless the first pass was skipped)
    if results["unconditional"] is None:
        print("Generating unconditional data...")
        results["unconditional"], timing = measure_inference_time(
            trainer.control_sample,
            num=sampling_size,
            size_every=batch_size,
            shape=[seq_length, feature_dim],
            model_kwargs={
                "gradient_control_signal": {},
                "coef": coef,
                "learning_rate": stepsize,
            },
        )
        timing_results["unconditional"] = timing / sampling_size
        save_result(cache_dir, "unconditional", "", results["unconditional"])

    # After all sampling is done, print timing results
    print("\nAverage Inference Time per Sample:")
    print("-" * 40)
    for key, time_per_sample in timing_results.items():
        print(f"{key}: {time_per_sample:.4f} seconds")


if __name__ == "__main__":
    args = parse_args()
    run(args)
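One caveat on the timing methodology in efficiency.py: CUDA kernels launch asynchronously, so a plain time.time() pair can under-count GPU work that is still queued when the function returns. A drop-in variant that synchronizes the device first (the synchronize calls are standard PyTorch; the wrapper name is new here):

import time

import torch


def measure_inference_time_sync(func, *args, **kwargs):
    """Like measure_inference_time, but waits for queued CUDA work so the
    measured interval covers the full GPU computation."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.time()
    result = func(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.time() - start_time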
engine/logger.py
ADDED
@@ -0,0 +1,71 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

from torch.utils.tensorboard import SummaryWriter
from utils.io_utils import write_args, save_config_to_yaml


class Logger(object):
    def __init__(self, args):
        self.args = args
        self.save_dir = args.save_dir

        os.makedirs(self.save_dir, exist_ok=True)

        # Save the args and config.
        self.config_dir = os.path.join(self.save_dir, "configs")
        os.makedirs(self.config_dir, exist_ok=True)
        file_name = os.path.join(self.config_dir, "args.txt")
        write_args(args, file_name)

        log_dir = os.path.join(self.save_dir, "logs")
        os.makedirs(log_dir, exist_ok=True)
        self.text_writer = open(os.path.join(log_dir, "log.txt"), "a")
        if args.tensorboard:
            self.log_info("using tensorboard")
            self.tb_writer = SummaryWriter(log_dir=log_dir)
        else:
            self.tb_writer = None

    def save_config(self, config):
        save_config_to_yaml(config, os.path.join(self.config_dir, "config.yaml"))

    def log_info(self, info, check_primary=True):
        print(info)
        info = str(info)
        time_str = time.strftime("%Y-%m-%d-%H-%M")
        info = "{}: {}".format(time_str, info)
        if not info.endswith("\n"):
            info += "\n"
        self.text_writer.write(info)
        self.text_writer.flush()

    def add_scalar(self, **kargs):
        """Log a scalar variable."""
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(**kargs)

    def add_scalars(self, **kargs):
        """Log multiple scalar variables at once."""
        if self.tb_writer is not None:
            self.tb_writer.add_scalars(**kargs)

    def add_image(self, **kargs):
        """Log an image."""
        if self.tb_writer is not None:
            self.tb_writer.add_image(**kargs)

    def add_images(self, **kargs):
        """Log a batch of images."""
        if self.tb_writer is not None:
            self.tb_writer.add_images(**kargs)

    def close(self):
        self.text_writer.close()
        if self.tb_writer is not None:
            self.tb_writer.close()
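A minimal usage sketch for this Logger; the argument object below is a stand-in (Logger reads save_dir and tensorboard in __init__, and write_args is assumed to accept any object with a __dict__):

from types import SimpleNamespace

from engine.logger import Logger

args = SimpleNamespace(save_dir="./OUTPUT/demo", tensorboard=False)

logger = Logger(args)
logger.log_info("training started")
# No-op here because tensorboard is disabled above.
logger.add_scalar(tag="train/loss", scalar_value=0.42, global_step=1)
logger.close()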