update

- app.py +164 -172
- models/CSDI/tiffusion.py +1 -59
app.py
CHANGED
@@ -43,8 +43,11 @@ class TimeSeriesEditor:
         # Add frequency band multipliers
         self.freq_bands = np.ones(5)  # 5 frequency bands, initially all set to 1.0
         self.function_parser = FunctionParser()
-        self.trending_controls = [
-        [1 removed line not shown in this view]
+        self.trending_controls = [
+            (200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05)
+            # 200,250,0,sin(2*pi*x),0.05
+        ]
+
     def format_value(self, value: float, feature_idx: int) -> str:
         """Format value with appropriate units and notation"""
         if self.show_normalized:
@@ -377,7 +380,7 @@ class TimeSeriesEditor:
         peak_alpha: float,
         auc_weight: float,
         peak_weight: float,
-        enable_trending: bool =
+        enable_trending: bool = True,
         enable_trending_with_diff: bool = False,
         trending_params: str = ""
     ) -> Tuple[List[go.Figure], str, str, Dict]:
@@ -436,15 +439,16 @@ class TimeSeriesEditor:
         # model_control_signal["selected_areas"] = areas

         # Run prediction
-        [9 removed lines not shown in this view]
+        with torch.no_grad():
+            sample = self.trainer.predict_weighted_points(
+                observed_points,      # (seq_length, feature_dim)
+                observed_mask,        # (seq_length, feature_dim)
+                self.coef,            # fixed
+                self.stepsize,        # fixed
+                self.sampling_steps,  # fixed
+                # model_control_signal=model_control_signal,
+                gradient_control_signal=gradient_control_signal
+            )

         # Store latest results
         self.latest_sample = sample
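The inline comments above state that `observed_points` and `observed_mask` are both `(seq_length, feature_dim)`. A minimal sketch of how such inputs could be assembled from anchor points before the call; the anchor tuples reuse the dataframe defaults from the control panel below, but the assembly itself is an assumption, not code from the app:

```python
import torch

seq_length, feature_dim = 365, 3

# Anchor points as (time_index, feature_index, value), matching the app's dataframe rows.
anchors = [(0, 0, 0.04), (2, 0, 0.58), (6, 0, 0.27), (58, 0, 1.0), (60, 0, 0.5)]

observed_points = torch.zeros(seq_length, feature_dim)  # values at the pinned positions
observed_mask = torch.zeros(seq_length, feature_dim)    # 1.0 wherever a value is fixed

for t, f, v in anchors:
    observed_points[t, f] = v
    observed_mask[t, f] = 1.0
```

The prediction itself is wrapped in `torch.no_grad()` in the hunk above, since sampling for display does not need autograd graphs.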
@@ -610,10 +614,10 @@ class TimeSeriesEditor:
 def create_gradio_interface(editor: TimeSeriesEditor):
     with gr.Blocks() as app:
         gr.Markdown("# Time Series Editor")
-        gr.Markdown("## Instruction: Scroll Down + Click
+        gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~20s-30s] [Running on CPU...]")

         metrics_display = gr.JSON(label="Metrics", value={})
-
+
         with gr.Row():
             with gr.Column(scale=1):
                 # with Tab():
@@ -642,103 +646,102 @@ def create_gradio_interface(editor: TimeSeriesEditor):

                 # TS Section
                 gr.Markdown("## Time Series Control Panel")
-                with gr.Accordion("Open for More Detail"):
-                [86 removed lines not shown in this view]
+                # with gr.Accordion("Open for More Detail"):
+                with gr.Group():
+                    gr.Markdown("### Fixed Point Control")
+                    data_points_df = gr.Dataframe(
+                        headers=["time", "feature", "value"],
+                        datatype=["number", "number", "number"],
+                        # label="Anchor Point Control",
+                        value=[[0, 0, 0.04], [2, 0, 0.58], [6, 0, 0.27], [58, 0, 1.0], [60, 0, 0.5]],
+                        col_count=(3, "fixed"),  # Fix number of columns
+                        interactive=True
+                    )
+                    add_data_point_btn = gr.Button("Add Data Point")
+
+                    def add_data_point(df):
+                        new_row = pd.DataFrame([[None, 0, None]],
+                                               columns=["time", "feature", "value"])
+                        return pd.concat([df, new_row], ignore_index=True)
+
+                    add_data_point_btn.click(
+                        fn=add_data_point,
+                        inputs=[data_points_df],
+                        outputs=[data_points_df]
+                    )
+
+                with gr.Group():
+                    gr.Markdown("### Group of Anchor Point Control with Confidence")
+                    point_groups_df = gr.Dataframe(
+                        headers=["start", "end", "interval", "feature", "value", "weight"],
+                        datatype=["number", "number", "number", "number", "number", "number"],
+                        # label="Group of Anchor Point Control",
+                        value=[[0, 50, 10, 0, 0.5, 0.1], [100, 150, 50, 0, 0.1, 0.5]],
+                        col_count=(6, "fixed"),  # Fix number of columns
+                        interactive=True
+                    )
+                    add_point_group_btn = gr.Button("Add Point Group")
+
+                    def add_point_group(df):
+                        new_row = pd.DataFrame([[None, None, None, 0, None, None]],
+                                               columns=["start", "end", "interval", "feature", "value", "weight"])
+                        return pd.concat([df, new_row], ignore_index=True)
+
+                    add_point_group_btn.click(
+                        fn=add_point_group,
+                        inputs=[point_groups_df],
+                        outputs=[point_groups_df]
+                    )
+
+                with gr.Group():
+                    # with gr.Tab("Trending Control"):
+                    gr.Markdown("### Trending Control")
+                    gr.Markdown("""
+                    Enter trending control parameters in the format:
+                    ```
+                    start_time,end_time,feature,function,confidence
+                    ```
+                    Examples:
+                    - Linear trend: `0,100,0,x`
+                    - Sine wave: `0,100,0,sin(2*pi*x)`
+                    - Exponential: `0,100,0,exp(-x)`
+
+                    Separate multiple trends with semicolons.
+                    """)
+                    enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=True)
+                    enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False)
+                    trending_control = gr.Textbox(
+                        label="Trending Control Parameters",
+                        lines=2,
+                        placeholder="Enter parameters: start_time,end_time,feature,function,confidence; separated by semicolons",
+                        value="200,250,0,sin(2*pi*x),0.05"
+                    )
+
+                # Area Control Parameters
+                with gr.Group(visible=False):
+                    gr.Markdown("### Area Control")
+                    enable_area_control = gr.Checkbox(label="Enable Area Control", value=False)
+                    area_selections = gr.Textbox(
+                        label="Area Selections (format: start_time,end_time,feature,target_value)",
+                        lines=2,
+                        placeholder="Enter areas: start,end,feature,target; separated by semicolons",
+                    )
+
+                # AUC Parameters
+                gr.Markdown("### Statistics Control")
+                enable_auc = gr.Checkbox(label="Enable Total Sum Control", value=True)
+                auc_input = gr.Number(label="Target Sum Value", value=-150)
+                auc_weight_input = gr.Number(label="Sum Weight", value=10.0)
+
+                # Peak Parameters
+                with gr.Group(visible=False):
+                    gr.Markdown("### Peak Control")
+                    enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
+                    peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
+                    peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
+                    peak_weight_input = gr.Number(label="Peak Weight", value=1.0)

-
-                with gr.Group(visible=False):
-                    gr.Markdown("### Peak Control")
-                    enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
-                    peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
-                    peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
-                    peak_weight_input = gr.Number(label="Peak Weight", value=1.0)
-
-                update_model_btn = gr.Button("Update Figure")
+                update_model_btn = gr.Button("Update Figure")

         gr.Markdown("## Extend Edit", visible=False)
         with gr.Tab("Range Shift", visible=False):
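The markdown added in this hunk documents the trending-control format (`start_time,end_time,feature,function,confidence`, entries separated by semicolons). A rough sketch of parsing that format; `make_function` is a hypothetical stand-in for the app's `FunctionParser.string_to_function`, and the whole snippet is an illustration rather than the app's parser:

```python
import numpy as np

def make_function(expr: str):
    """Hypothetical stand-in for FunctionParser.string_to_function:
    turn an expression in x into a callable using a few numpy names."""
    return lambda x: eval(expr, {"x": x, "sin": np.sin, "exp": np.exp, "pi": np.pi})

def parse_trending_params(text: str):
    """Parse 'start,end,feature,function,confidence; ...' into control tuples."""
    controls = []
    for entry in text.split(";"):
        entry = entry.strip()
        if not entry:
            continue
        start, end, feature, func, conf = entry.split(",", 4)
        controls.append((int(start), int(end), int(feature),
                         make_function(func), float(conf)))
    return controls

# The textbox default value from the diff above:
print(parse_trending_params("200,250,0,sin(2*pi*x),0.05"))
```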
@@ -905,26 +908,26 @@ def create_gradio_interface(editor: TimeSeriesEditor):
             outputs=[*plots, metrics_display]
         )

-        [20 removed lines not shown in this view]
+        app.load(
+            fn=update_model_callback,
+            inputs=[
+                data_points_df,
+                point_groups_df,
+                enable_area_control,
+                area_selections,
+                enable_auc,
+                auc_input,
+                auc_weight_input,
+                enable_peaks,
+                peak_points_input,
+                peak_alpha_input,
+                peak_weight_input,
+                enable_trending_control,
+                enable_trending_control_with_diff,
+                trending_control
+            ],
+            outputs=[*plots, metrics_display]
+        )

     return app

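`Blocks.load` registers a callback that Gradio runs when the page is first opened, so the block added above renders an initial figure from the default control values without requiring a click on [Update Figure]. A tiny self-contained illustration of the same pattern (the component and function names here are placeholders, not from the app):

```python
import gradio as gr

def initial_value():
    # In the app this role is played by update_model_callback, which returns plots + metrics.
    return "loaded"

with gr.Blocks() as demo:
    box = gr.Textbox(label="status")
    demo.load(fn=initial_value, inputs=None, outputs=[box])  # runs once per page load

# demo.launch()
```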
@@ -1059,63 +1062,52 @@ class FunctionParser:

         except Exception as e:
             print(f"Error: {str(e)}")
-
 # Example usage:
 if __name__ == "__main__":
-    # Initialize with example data points
-    # example_data_points = "0,0,0.04;2,0,0.58;6,0,0.27;58,0,1.0;-1,0,0.05"
-
     import os
     import torch
     import numpy as np
-    from engine.solver import Trainer
-    from utils.io_utils import load_yaml_config, instantiate_from_config

     # assert torch.cuda.is_available(), "CUDA must be available"
-    class Parameters:
-        def __init__(self) -> None:
-            self.gpu = 0
-            self.config_path = "./config/modified/revenue-baseline-365.yaml"
-            # self.config_path = "config/modified/96/fmri.yaml"
-            # self.config_path = "./config/control/revenue-baseline-sine.yaml"
-            # self.save_dir = (
-            #     "../../../data/" + os.path.basename(self.config_path).split(".")[0]
-            # )
-            self.mode = "infill"
-            self.missing_ratio = 0.95
-            self.milestone = "10"
-    # os.makedirs(self.save_dir, exist_ok=True)
-
     os.environ["WANDB_ENABLED"] = "false"
-    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-    # print working directory
     print(os.getcwd())
-    [17 removed lines not shown in this view]
+
+    device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu"
+
+    from models.Tiffusion import tiffusion
+
+    model = tiffusion.Tiffusion(
+        seq_length=365,
+        feature_size=3,
+        n_layer_enc=6,
+        n_layer_dec=4,
+        d_model=128,
+        timesteps=500,
+        sampling_timesteps=200,
+        loss_type='l1',
+        beta_schedule='cosine',
+        n_heads=8,
+        mlp_hidden_times=4,
+        attn_pd=0.0,
+        resid_pd=0.0,
+        kernel_size=1,
+        padding_size=0,
+        control_signal=[]
+    ).to(device)
+
+    model.load_state_dict(torch.load("./weight/checkpoint-10.pt", map_location=device, weights_only=True)["model"])
+
+    coef = 1.0e-2
+    stepsize = 5.0e-2
+    sampling_steps = 100  # Adjustable between 100-500 for speed/accuracy tradeoff
+    seq_length = 365
     feature_dim = 3
     print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")

-
-    editor = TimeSeriesEditor(seq_length, feature_dim, trainer)
+    editor = TimeSeriesEditor(seq_length, feature_dim, model)
     editor.coef = coef
     editor.stepsize = stepsize
     editor.sampling_steps = sampling_steps

     app = create_gradio_interface(editor)
-    # app.launch(server_name="0.0.0.0", server_port=8888, share=True)
     app.launch(show_api=False)
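The `__main__` block loads the checkpoint with `map_location=device` and `weights_only=True`, which is what lets the Space fall back to CPU when CUDA is unavailable. The same pattern in isolation, with the path and the `"model"` key taken from the diff but otherwise treated as assumptions:

```python
import torch

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# weights_only=True restricts unpickling to plain tensors/containers (safer loading);
# map_location remaps CUDA-saved tensors onto whichever device is actually available.
checkpoint = torch.load("./weight/checkpoint-10.pt", map_location=device, weights_only=True)
state_dict = checkpoint["model"]    # the diff reads the weights from the "model" key
# model.load_state_dict(state_dict)  # `model` built as in the __main__ block above
```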
models/CSDI/tiffusion.py
CHANGED
@@ -33,7 +33,7 @@ def cosine_beta_schedule(timesteps, s=0.008):
     return torch.clip(betas, 0, 0.999)


-class Tiffusion(CSDI_base):
+class Tiffusion(nn.Module):
     def __init__(
         self,
         seq_length,
@@ -111,12 +111,9 @@ class Tiffusion(CSDI_base):
             config_diff["beta_start"], config_diff["beta_end"], self.num_steps
         )

-
         self.alpha_hat = 1 - self.beta
         self.alpha = np.cumprod(self.alpha_hat)
         self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)
-        # self.beta = torch.from_numpy(self.beta).float().to(self.device)
-        # self.alpha = torch.from_numpy(self.alpha).float().to(self.device)

         self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
         if self.is_unconditional == False:
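The hunk header references `cosine_beta_schedule(timesteps, s=0.008)`, and the kept lines derive `alpha_hat = 1 - beta` and `alpha = cumprod(alpha_hat)`. For reference, a common formulation of that cosine schedule (in the style of Nichol & Dhariwal); this is a generic sketch, not necessarily identical to the file's implementation:

```python
import numpy as np
import torch

def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Cosine noise schedule: betas derived from a squared-cosine cumulative alpha."""
    steps = timesteps + 1
    t = np.linspace(0, timesteps, steps) / timesteps
    alphas_cumprod = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(torch.from_numpy(betas), 0, 0.999)

beta = cosine_beta_schedule(500).numpy()
alpha_hat = 1 - beta           # as in the diff: self.alpha_hat = 1 - self.beta
alpha = np.cumprod(alpha_hat)  # as in the diff: self.alpha = np.cumprod(self.alpha_hat)
print(beta.shape, alpha[-1])
```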
@@ -127,63 +124,8 @@ class Tiffusion(CSDI_base):
             num_embeddings=self.target_dim
             , embedding_dim=self.emb_feature_dim
         )
-        # self.model: Transformer = Transformer(
-        #     n_feat=feature_size,
-        #     n_channel=seq_length,
-        #     n_layer_enc=n_layer_enc,
-        #     n_layer_dec=n_layer_dec,
-        #     n_heads=n_heads,
-        #     attn_pdrop=attn_pd,
-        #     resid_pdrop=resid_pd,
-        #     mlp_hidden_times=mlp_hidden_times,
-        #     max_len=seq_length,
-        #     n_embd=d_model,
-        #     conv_params=[kernel_size, padding_size],
-        #     **kwargs,
-        # )
-        class Config:
-            def __init__(self, **kwargs):
-                self.__dict__.update(kwargs)

-        # type: CSDI
-        # layers: 3
-        # channels: 64
-        # nheads: 8
-        # diffusion_embedding_dim: 128
-        # is_linear: False  # linear transformer
-
-        # beta_start: 0.0001
-        # beta_end: 0.5
-        # schedule: "quad"
-
-        # num_steps: 50
-
-        # # edit
-        # edit_steps: 50  # the number of steps to perform editing
-        # bootstrap_ratio: 0.5  # [0,1]
-
-        # is_attr_proj: False
-
-        # side:
-        # num_var: 1
-        # var_emb: 16
-        # time_emb: 128
-
-        # attrs:
-        # attr_emb: 64
-        # config_diff["side_dim"] = self.emb_total_dim
         self.diffmodel = diff_CSDI(
-            # config=Config(
-            #     layers=3,
-            #     channels=64,
-            #     nheads=8,
-            #     diffusion_embedding_dim=128,
-            #     is_linear=False,
-            #     beta_start=0.0001,
-            #     beta_end=0.5,
-            #     schedule="quad",
-            #     num_steps=50,
-            # )
             {
                 "layers": 3,
                 "channels": 64,