PeterYu commited on
Commit
a785d5a
·
1 Parent(s): 21db335
Files changed (2) hide show
  1. app.py +164 -172
  2. models/CSDI/tiffusion.py +1 -59
app.py CHANGED
@@ -43,8 +43,11 @@ class TimeSeriesEditor:
43
  # Add frequency band multipliers
44
  self.freq_bands = np.ones(5) # 5 frequency bands, initially all set to 1.0
45
  self.function_parser = FunctionParser()
46
- self.trending_controls = [] # Store trending controls
47
-
 
 
 
48
  def format_value(self, value: float, feature_idx: int) -> str:
49
  """Format value with appropriate units and notation"""
50
  if self.show_normalized:
@@ -377,7 +380,7 @@ class TimeSeriesEditor:
377
  peak_alpha: float,
378
  auc_weight: float,
379
  peak_weight: float,
380
- enable_trending: bool = False,
381
  enable_trending_with_diff: bool = False,
382
  trending_params: str = ""
383
  ) -> Tuple[List[go.Figure], str, str, Dict]:
@@ -436,15 +439,16 @@ class TimeSeriesEditor:
436
  # model_control_signal["selected_areas"] = areas
437
 
438
  # Run prediction
439
- sample = self.trainer.predict_weighted_points(
440
- observed_points, # (seq_length, feature_dim)
441
- observed_mask, # (seq_length, feature_dim)
442
- self.coef, # fixed
443
- self.stepsize, # fixed
444
- self.sampling_steps, # fixed
445
- # model_control_signal=model_control_signal,
446
- gradient_control_signal=gradient_control_signal
447
- )
 
448
 
449
  # Store latest results
450
  self.latest_sample = sample
@@ -610,10 +614,10 @@ class TimeSeriesEditor:
610
  def create_gradio_interface(editor: TimeSeriesEditor):
611
  with gr.Blocks() as app:
612
  gr.Markdown("# Time Series Editor")
613
- gr.Markdown("## Instruction: Scroll Down + Click `Update Figure` [~20s]")
614
 
615
  metrics_display = gr.JSON(label="Metrics", value={})
616
-
617
  with gr.Row():
618
  with gr.Column(scale=1):
619
  # with Tab():
@@ -642,103 +646,102 @@ def create_gradio_interface(editor: TimeSeriesEditor):
642
 
643
  # TS Section
644
  gr.Markdown("## Time Series Control Panel")
645
- with gr.Accordion("Open for More Detail"):
646
- with gr.Group():
647
- gr.Markdown("### Anchor Point Control")
648
- data_points_df = gr.Dataframe(
649
- headers=["time", "feature", "value"],
650
- datatype=["number", "number", "number"],
651
- # label="Anchor Point Control",
652
- value=[[0, 0, 0.04], [2, 0, 0.58], [6, 0, 0.27], [58, 0, 1.0], [60, 0, 0.5]],
653
- col_count=(3, "fixed"), # Fix number of columns
654
- interactive=True
655
- )
656
- add_data_point_btn = gr.Button("Add Data Point")
657
-
658
- def add_data_point(df):
659
- new_row = pd.DataFrame([[None, 0, None]],
660
- columns=["time", "feature", "value"])
661
- return pd.concat([df, new_row], ignore_index=True)
662
-
663
- add_data_point_btn.click(
664
- fn=add_data_point,
665
- inputs=[data_points_df],
666
- outputs=[data_points_df]
667
- )
668
-
669
- with gr.Group():
670
- gr.Markdown("### Group of Anchor Point Control")
671
- point_groups_df = gr.Dataframe(
672
- headers=["start", "end", "interval", "feature", "value", "weight"],
673
- datatype=["number", "number", "number", "number", "number", "number"],
674
- # label="Group of Anchor Point Control",
675
- value=[[0, 50, 10, 0, 0.5, 0.1], [100, 150, 50, 0, 0.1, 0.5]],
676
- col_count=(6, "fixed"), # Fix number of columns
677
- interactive=True
678
- )
679
- add_point_group_btn = gr.Button("Add Point Group")
680
-
681
- def add_point_group(df):
682
- new_row = pd.DataFrame([[None, None, None, 0, None, None]],
683
- columns=["start", "end", "interval", "feature", "value", "weight"])
684
- return pd.concat([df, new_row], ignore_index=True)
685
-
686
- add_point_group_btn.click(
687
- fn=add_point_group,
688
- inputs=[point_groups_df],
689
- outputs=[point_groups_df]
690
- )
691
-
692
- with gr.Group():
693
- # with gr.Tab("Trending Control"):
694
- gr.Markdown("### Trending Control")
695
- gr.Markdown("""
696
- Enter trending control parameters in the format:
697
- ```
698
- start_time,end_time,feature,function,confidence
699
- ```
700
- Examples:
701
- - Linear trend: `0,100,0,x`
702
- - Sine wave: `0,100,0,sin(2*pi*x)`
703
- - Exponential: `0,100,0,exp(-x)`
704
-
705
- Separate multiple trends with semicolons.
706
- """)
707
- enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=True)
708
- enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False)
709
- trending_control = gr.Textbox(
710
- label="Trending Control Parameters",
711
- lines=2,
712
- placeholder="Enter parameters: start_time,end_time,feature,function,condifdence; separated by semicolons",
713
- value="200,250,0,sin(2*pi*x),0.2"
714
- )
715
-
716
- # Area Control Parameters
717
- with gr.Group(visible=False):
718
- gr.Markdown("### Area Control")
719
- enable_area_control = gr.Checkbox(label="Enable Area Control", value=False)
720
- area_selections = gr.Textbox(
721
- label="Area Selections (format: start_time,end_time,feature,target_value)",
722
- lines=2,
723
- placeholder="Enter areas: start,end,feature,target; separated by semicolons",
724
-
725
- )
726
-
727
- # AUC Parameters
728
- gr.Markdown("### Statistics Control")
729
- enable_auc = gr.Checkbox(label="Enable Total Sum Control", value=True)
730
- auc_input = gr.Number(label="Target Sum Value", value=-150)
731
- auc_weight_input = gr.Number(label="Sum Weight", value=10.0)
 
 
 
 
 
 
 
732
 
733
- # Peak Parameters
734
- with gr.Group(visible=False):
735
- gr.Markdown("### Peak Control")
736
- enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
737
- peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
738
- peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
739
- peak_weight_input = gr.Number(label="Peak Weight", value=1.0)
740
-
741
- update_model_btn = gr.Button("Update Figure")
742
 
743
  gr.Markdown("## Extend Edit", visible=False)
744
  with gr.Tab("Range Shift", visible=False):
@@ -905,26 +908,26 @@ def create_gradio_interface(editor: TimeSeriesEditor):
905
  outputs=[*plots, metrics_display]
906
  )
907
 
908
- # app.load(
909
- # fn=update_model_callback,
910
- # inputs=[
911
- # data_points_df,
912
- # point_groups_df,
913
- # enable_area_control,
914
- # area_selections,
915
- # enable_auc,
916
- # auc_input,
917
- # auc_weight_input,
918
- # enable_peaks,
919
- # peak_points_input,
920
- # peak_alpha_input,
921
- # peak_weight_input,
922
- # enable_trending_control,
923
- # enable_trending_control_with_diff,
924
- # trending_control
925
- # ],
926
- # outputs=[*plots, metrics_display]
927
- # )
928
 
929
  return app
930
 
@@ -1059,63 +1062,52 @@ class FunctionParser:
1059
 
1060
  except Exception as e:
1061
  print(f"Error: {str(e)}")
1062
-
1063
  # Example usage:
1064
  if __name__ == "__main__":
1065
- # Initialize with example data points
1066
- # example_data_points = "0,0,0.04;2,0,0.58;6,0,0.27;58,0,1.0;-1,0,0.05"
1067
-
1068
  import os
1069
  import torch
1070
  import numpy as np
1071
- from engine.solver import Trainer
1072
- from utils.io_utils import load_yaml_config, instantiate_from_config
1073
 
1074
  # assert torch.cuda.is_available(), "CUDA must be available"
1075
- class Parameters:
1076
- def __init__(self) -> None:
1077
- self.gpu = 0
1078
- self.config_path = "./config/modified/revenue-baseline-365.yaml"
1079
- # self.config_path = "config/modified/96/fmri.yaml"
1080
- # self.config_path = "./config/control/revenue-baseline-sine.yaml"
1081
- # self.save_dir = (
1082
- # "../../../data/" + os.path.basename(self.config_path).split(".")[0]
1083
- # )
1084
- self.mode = "infill"
1085
- self.missing_ratio = 0.95
1086
- self.milestone = "10"
1087
- # os.makedirs(self.save_dir, exist_ok=True)
1088
-
1089
  os.environ["WANDB_ENABLED"] = "false"
1090
- # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
1091
- # print working directory
1092
  print(os.getcwd())
1093
- args = Parameters()
1094
- configs = load_yaml_config(args.config_path)
1095
- # device = torch.device('cpu')
1096
- device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
1097
-
1098
- # dl_info = build_dataloader_cond(configs, args)
1099
- model = instantiate_from_config(configs["model"]).to(device)
1100
- trainer = Trainer(config=configs, args=args, model=model, dataloader={
1101
- "dataloader": []
1102
- })
1103
-
1104
- trainer.load(args.milestone, from_folder="./weight") #, from_folder="../../../data/ckpt_baseline_sine_240"), from_folder="./data/weight_365"
1105
- # dataloader, dataset = dl_info["dataloader"], dl_info["dataset"]
1106
- coef = configs["dataloader"]["test_dataset"]["coefficient"]
1107
- stepsize = configs["dataloader"]["test_dataset"]["step_size"]
1108
- sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
1109
- seq_length = configs["dataloader"]["test_dataset"]["params"]["window"]
 
 
 
 
 
 
 
 
 
 
 
 
 
1110
  feature_dim = 3
1111
  print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")
1112
 
1113
- # Initialize your trainer, configs, and dataset here
1114
- editor = TimeSeriesEditor(seq_length, feature_dim, trainer)
1115
  editor.coef = coef
1116
  editor.stepsize = stepsize
1117
  editor.sampling_steps = sampling_steps
1118
 
1119
  app = create_gradio_interface(editor)
1120
- # app.launch(server_name="0.0.0.0", server_port=8888, share=True)
1121
  app.launch(show_api=False)
 
43
  # Add frequency band multipliers
44
  self.freq_bands = np.ones(5) # 5 frequency bands, initially all set to 1.0
45
  self.function_parser = FunctionParser()
46
+ self.trending_controls = [
47
+ (200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05)
48
+ # 200,250,0,sin(2*pi*x),0.05
49
+ ]
50
+
51
  def format_value(self, value: float, feature_idx: int) -> str:
52
  """Format value with appropriate units and notation"""
53
  if self.show_normalized:
 
380
  peak_alpha: float,
381
  auc_weight: float,
382
  peak_weight: float,
383
+ enable_trending: bool = True,
384
  enable_trending_with_diff: bool = False,
385
  trending_params: str = ""
386
  ) -> Tuple[List[go.Figure], str, str, Dict]:
 
439
  # model_control_signal["selected_areas"] = areas
440
 
441
  # Run prediction
442
+ with torch.no_grad():
443
+ sample = self.trainer.predict_weighted_points(
444
+ observed_points, # (seq_length, feature_dim)
445
+ observed_mask, # (seq_length, feature_dim)
446
+ self.coef, # fixed
447
+ self.stepsize, # fixed
448
+ self.sampling_steps, # fixed
449
+ # model_control_signal=model_control_signal,
450
+ gradient_control_signal=gradient_control_signal
451
+ )
452
 
453
  # Store latest results
454
  self.latest_sample = sample
 
614
  def create_gradio_interface(editor: TimeSeriesEditor):
615
  with gr.Blocks() as app:
616
  gr.Markdown("# Time Series Editor")
617
+ gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~20s-30s] [Running on CPU...]")
618
 
619
  metrics_display = gr.JSON(label="Metrics", value={})
620
+
621
  with gr.Row():
622
  with gr.Column(scale=1):
623
  # with Tab():
 
646
 
647
  # TS Section
648
  gr.Markdown("## Time Series Control Panel")
649
+ # with gr.Accordion("Open for More Detail"):
650
+ with gr.Group():
651
+ gr.Markdown("### Fixed Point Control")
652
+ data_points_df = gr.Dataframe(
653
+ headers=["time", "feature", "value"],
654
+ datatype=["number", "number", "number"],
655
+ # label="Anchor Point Control",
656
+ value=[[0, 0, 0.04], [2, 0, 0.58], [6, 0, 0.27], [58, 0, 1.0], [60, 0, 0.5]],
657
+ col_count=(3, "fixed"), # Fix number of columns
658
+ interactive=True
659
+ )
660
+ add_data_point_btn = gr.Button("Add Data Point")
661
+
662
+ def add_data_point(df):
663
+ new_row = pd.DataFrame([[None, 0, None]],
664
+ columns=["time", "feature", "value"])
665
+ return pd.concat([df, new_row], ignore_index=True)
666
+
667
+ add_data_point_btn.click(
668
+ fn=add_data_point,
669
+ inputs=[data_points_df],
670
+ outputs=[data_points_df]
671
+ )
672
+
673
+ with gr.Group():
674
+ gr.Markdown("### Group of Anchor Point Control with Confidence")
675
+ point_groups_df = gr.Dataframe(
676
+ headers=["start", "end", "interval", "feature", "value", "weight"],
677
+ datatype=["number", "number", "number", "number", "number", "number"],
678
+ # label="Group of Anchor Point Control",
679
+ value=[[0, 50, 10, 0, 0.5, 0.1], [100, 150, 50, 0, 0.1, 0.5]],
680
+ col_count=(6, "fixed"), # Fix number of columns
681
+ interactive=True
682
+ )
683
+ add_point_group_btn = gr.Button("Add Point Group")
684
+
685
+ def add_point_group(df):
686
+ new_row = pd.DataFrame([[None, None, None, 0, None, None]],
687
+ columns=["start", "end", "interval", "feature", "value", "weight"])
688
+ return pd.concat([df, new_row], ignore_index=True)
689
+
690
+ add_point_group_btn.click(
691
+ fn=add_point_group,
692
+ inputs=[point_groups_df],
693
+ outputs=[point_groups_df]
694
+ )
695
+
696
+ with gr.Group():
697
+ # with gr.Tab("Trending Control"):
698
+ gr.Markdown("### Trending Control")
699
+ gr.Markdown("""
700
+ Enter trending control parameters in the format:
701
+ ```
702
+ start_time,end_time,feature,function,confidence
703
+ ```
704
+ Examples:
705
+ - Linear trend: `0,100,0,x`
706
+ - Sine wave: `0,100,0,sin(2*pi*x)`
707
+ - Exponential: `0,100,0,exp(-x)`
708
+
709
+ Separate multiple trends with semicolons.
710
+ """)
711
+ enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=True)
712
+ enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False)
713
+ trending_control = gr.Textbox(
714
+ label="Trending Control Parameters",
715
+ lines=2,
716
+ placeholder="Enter parameters: start_time,end_time,feature,function,condifdence; separated by semicolons",
717
+ value="200,250,0,sin(2*pi*x),0.05"
718
+ )
719
+
720
+ # Area Control Parameters
721
+ with gr.Group(visible=False):
722
+ gr.Markdown("### Area Control")
723
+ enable_area_control = gr.Checkbox(label="Enable Area Control", value=False)
724
+ area_selections = gr.Textbox(
725
+ label="Area Selections (format: start_time,end_time,feature,target_value)",
726
+ lines=2,
727
+ placeholder="Enter areas: start,end,feature,target; separated by semicolons",
728
+ )
729
+
730
+ # AUC Parameters
731
+ gr.Markdown("### Statistics Control")
732
+ enable_auc = gr.Checkbox(label="Enable Total Sum Control", value=True)
733
+ auc_input = gr.Number(label="Target Sum Value", value=-150)
734
+ auc_weight_input = gr.Number(label="Sum Weight", value=10.0)
735
+
736
+ # Peak Parameters
737
+ with gr.Group(visible=False):
738
+ gr.Markdown("### Peak Control")
739
+ enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
740
+ peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
741
+ peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
742
+ peak_weight_input = gr.Number(label="Peak Weight", value=1.0)
743
 
744
+ update_model_btn = gr.Button("Update Figure")
 
 
 
 
 
 
 
 
745
 
746
  gr.Markdown("## Extend Edit", visible=False)
747
  with gr.Tab("Range Shift", visible=False):
 
908
  outputs=[*plots, metrics_display]
909
  )
910
 
911
+ app.load(
912
+ fn=update_model_callback,
913
+ inputs=[
914
+ data_points_df,
915
+ point_groups_df,
916
+ enable_area_control,
917
+ area_selections,
918
+ enable_auc,
919
+ auc_input,
920
+ auc_weight_input,
921
+ enable_peaks,
922
+ peak_points_input,
923
+ peak_alpha_input,
924
+ peak_weight_input,
925
+ enable_trending_control,
926
+ enable_trending_control_with_diff,
927
+ trending_control
928
+ ],
929
+ outputs=[*plots, metrics_display]
930
+ )
931
 
932
  return app
933
 
 
1062
 
1063
  except Exception as e:
1064
  print(f"Error: {str(e)}")
 
1065
  # Example usage:
1066
  if __name__ == "__main__":
 
 
 
1067
  import os
1068
  import torch
1069
  import numpy as np
 
 
1070
 
1071
  # assert torch.cuda.is_available(), "CUDA must be available"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1072
  os.environ["WANDB_ENABLED"] = "false"
 
 
1073
  print(os.getcwd())
1074
+
1075
+ device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu"
1076
+
1077
+ from models.Tiffusion import tiffusion
1078
+
1079
+ model = tiffusion.Tiffusion(
1080
+ seq_length=365,
1081
+ feature_size=3,
1082
+ n_layer_enc=6,
1083
+ n_layer_dec=4,
1084
+ d_model=128,
1085
+ timesteps=500,
1086
+ sampling_timesteps=200,
1087
+ loss_type='l1',
1088
+ beta_schedule='cosine',
1089
+ n_heads=8,
1090
+ mlp_hidden_times=4,
1091
+ attn_pd=0.0,
1092
+ resid_pd=0.0,
1093
+ kernel_size=1,
1094
+ padding_size=0,
1095
+ control_signal=[]
1096
+ ).to(device)
1097
+
1098
+ model.load_state_dict(torch.load("./weight/checkpoint-10.pt", map_location=device, weights_only=True)["model"])
1099
+
1100
+ coef = 1.0e-2
1101
+ stepsize = 5.0e-2
1102
+ sampling_steps = 100 # Adjustable between 100-500 for speed/accuracy tradeoff
1103
+ seq_length = 365
1104
  feature_dim = 3
1105
  print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")
1106
 
1107
+ editor = TimeSeriesEditor(seq_length, feature_dim, model)
 
1108
  editor.coef = coef
1109
  editor.stepsize = stepsize
1110
  editor.sampling_steps = sampling_steps
1111
 
1112
  app = create_gradio_interface(editor)
 
1113
  app.launch(show_api=False)
models/CSDI/tiffusion.py CHANGED
@@ -33,7 +33,7 @@ def cosine_beta_schedule(timesteps, s=0.008):
33
  return torch.clip(betas, 0, 0.999)
34
 
35
 
36
- class Tiffusion(CSDI_base):
37
  def __init__(
38
  self,
39
  seq_length,
@@ -111,12 +111,9 @@ class Tiffusion(CSDI_base):
111
  config_diff["beta_start"], config_diff["beta_end"], self.num_steps
112
  )
113
 
114
-
115
  self.alpha_hat = 1 - self.beta
116
  self.alpha = np.cumprod(self.alpha_hat)
117
  self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)
118
- # self.beta = torch.from_numpy(self.beta).float().to(self.device)
119
- # self.alpha = torch.from_numpy(self.alpha).float().to(self.device)
120
 
121
  self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
122
  if self.is_unconditional == False:
@@ -127,63 +124,8 @@ class Tiffusion(CSDI_base):
127
  num_embeddings=self.target_dim
128
  , embedding_dim=self.emb_feature_dim
129
  )
130
- # self.model: Transformer = Transformer(
131
- # n_feat=feature_size,
132
- # n_channel=seq_length,
133
- # n_layer_enc=n_layer_enc,
134
- # n_layer_dec=n_layer_dec,
135
- # n_heads=n_heads,
136
- # attn_pdrop=attn_pd,
137
- # resid_pdrop=resid_pd,
138
- # mlp_hidden_times=mlp_hidden_times,
139
- # max_len=seq_length,
140
- # n_embd=d_model,
141
- # conv_params=[kernel_size, padding_size],
142
- # **kwargs,
143
- # )
144
- class Config:
145
- def __init__(self, **kwargs):
146
- self.__dict__.update(kwargs)
147
 
148
- # type: CSDI
149
- # layers: 3
150
- # channels: 64
151
- # nheads: 8
152
- # diffusion_embedding_dim: 128
153
- # is_linear: False # linear transformer
154
-
155
- # beta_start: 0.0001
156
- # beta_end: 0.5
157
- # schedule: "quad"
158
-
159
- # num_steps: 50
160
-
161
- # # edit
162
- # edit_steps: 50 # the number of steps to perform editing
163
- # bootstrap_ratio: 0.5 # [0,1]
164
-
165
- # is_attr_proj: False
166
-
167
- # side:
168
- # num_var: 1
169
- # var_emb: 16
170
- # time_emb: 128
171
-
172
- # attrs:
173
- # attr_emb: 64
174
- # config_diff["side_dim"] = self.emb_total_dim
175
  self.diffmodel = diff_CSDI(
176
- # config=Config(
177
- # layers=3,
178
- # channels=64,
179
- # nheads=8,
180
- # diffusion_embedding_dim=128,
181
- # is_linear=False,
182
- # beta_start=0.0001,
183
- # beta_end=0.5,
184
- # schedule="quad",
185
- # num_steps=50,
186
- # )
187
  {
188
  "layers": 3,
189
  "channels": 64,
 
33
  return torch.clip(betas, 0, 0.999)
34
 
35
 
36
+ class Tiffusion(nn.Module):
37
  def __init__(
38
  self,
39
  seq_length,
 
111
  config_diff["beta_start"], config_diff["beta_end"], self.num_steps
112
  )
113
 
 
114
  self.alpha_hat = 1 - self.beta
115
  self.alpha = np.cumprod(self.alpha_hat)
116
  self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)
 
 
117
 
118
  self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
119
  if self.is_unconditional == False:
 
124
  num_embeddings=self.target_dim
125
  , embedding_dim=self.emb_feature_dim
126
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  self.diffmodel = diff_CSDI(
 
 
 
 
 
 
 
 
 
 
 
129
  {
130
  "layers": 3,
131
  "channels": 64,