# TSEditor / experiment.py
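"""Controlled-sampling experiments for TSEditor.

Loads a trained checkpoint, generates unconditional samples as well as
sum-controlled (target area-under-curve) and anchor-controlled samples,
caches the results under ../../../data/cache/, and renders KDE and
anchor-detection figures from the cached results.
"""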
import os
import torch
os.environ["WANDB_ENABLED"] = "false"
from engine.solver import Trainer
from data.build_dataloader import build_dataloader
from utils.metric_utils import visualization, save_pdf
from data.build_dataloader import build_dataloader_cond
# from utils.metric_utils import visualization
from utils.io_utils import load_yaml_config, instantiate_from_config
from models.model_utils import unnormalize_to_zero_to_one
from scipy.signal import find_peaks, peak_prominences
# disable user warnings
import warnings
warnings.simplefilter("ignore", UserWarning)
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import pickle
from pathlib import Path
def load_cached_results(cache_dir):
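    """Load previously cached sampling results from ``cache_dir``.

    File stems decide the bucket: ``unconditional``, ``sum_<key>`` or ``anchor_<key>``.
    """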
results = {"unconditional": None, "sum_controlled": {}, "anchor_controlled": {}}
for cache_file in cache_dir.glob("*.pkl"):
with open(cache_file, "rb") as f:
key = cache_file.stem
# if key=="unconditional":
# continue
if key == "unconditional":
results["unconditional"] = pickle.load(f)
elif key.startswith("sum_"):
param = key[4:] # Remove 'sum_' prefix
results["sum_controlled"][param] = pickle.load(f)
elif key.startswith("anchor_"):
param = key[7:] # Remove 'anchor_' prefix
results["anchor_controlled"][param] = pickle.load(f)
return results
def save_result(cache_dir, key, subkey, data):
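    """Pickle ``data`` into ``cache_dir`` as ``<key>_<subkey>.pkl`` (or ``<key>.pkl`` if no subkey)."""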
if subkey:
filename = f"{key}_{subkey}.pkl"
else:
filename = f"{key}.pkl"
with open(cache_dir / filename, "wb") as f:
pickle.dump(data, f)
class Arguments:
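    """Minimal argument container mimicking the CLI options expected by the Trainer and dataloader builders."""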
def __init__(self, config_path, gpu=0) -> None:
self.config_path = config_path
# self.config_path = "./config/control/revenue-baseline-sine.yaml"
self.save_dir = (
"../../../data/" + os.path.basename(self.config_path).split(".")[0]
)
self.gpu = gpu
os.makedirs(self.save_dir, exist_ok=True)
self.mode = "infill"
self.missing_ratio = 0.95
self.milestone = 10
def beautiful_text(key, highlight):
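    """Format a result key such as ``auc_-100_weight_10`` into a legend label,
    bolding either the AUC value or the weight depending on ``highlight``."""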
# print(key)
if "auc" in key:
auc = key.split("_")[1]
weight = key.split("_")[3]
if highlight is None:
return f"AUC: $\\mathbf{{{auc}}}$ Weight: {weight}"
else:
return f"AUC: {auc} Weight: $\\mathbf{{{weight}}}$"
if "anchor" in key:
anchor = key.split("_")[1]
weight = key.split("_")[3]
return f"anchor: {anchor} Weight: {weight}"
return key
def get_alpha(idx, n_plots):
"""Generate alpha value between 0.3-0.8 based on plot index"""
return 0.5 + (0.4 * idx / (n_plots - 1)) if n_plots > 1 else 0.8
def create_color_gradient(
sorting_value=None, start_color="#FFFF00", end_color="#00008B"
):
"""Create color gradient using matplotlib color interpolation."""
def color_fader(c1, c2, mix=0):
"""Fade from color c1 to c2 with mix ratio."""
c1 = np.array(mpl.colors.to_rgb(c1))
c2 = np.array(mpl.colors.to_rgb(c2))
return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)
if sorting_value is not None:
# Normalize values between 0-1
values = np.array(list(sorting_value.values()))
normalized = (values - values.min()) / (values.max() - values.min())
# Create color mapping
return {
key: color_fader(start_color, end_color, mix=norm_val)
for key, norm_val in zip(sorting_value.keys(), normalized)
}
else:
# Return middle point color
return color_fader(start_color, end_color, mix=0.5)
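# NOTE: the three-color create_color_gradient defined below redefines (and therefore
# supersedes) the two-color version above.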
def create_color_gradient(
sorting_value=None,
start_color="#FFFF00",
middle_color="#00FF00",
end_color="#00008B",
):
"""Create color gradient using matplotlib interpolation with middle color."""
def color_fader(c1, c2, mix=0):
"""Fade from color c1 to c2 with mix ratio."""
c1 = np.array(mpl.colors.to_rgb(c1))
c2 = np.array(mpl.colors.to_rgb(c2))
return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)
if sorting_value is not None:
values = np.array(list(sorting_value.values()))
normalized = (values - values.min()) / (values.max() - values.min())
colors = {}
for key, norm_val in zip(sorting_value.keys(), normalized):
if norm_val <= 0.5:
# Interpolate between start and middle
mix = norm_val * 2 # Scale 0-0.5 to 0-1
colors[key] = color_fader(start_color, middle_color, mix)
else:
# Interpolate between middle and end
mix = (norm_val - 0.5) * 2 # Scale 0.5-1 to 0-1
colors[key] = color_fader(middle_color, end_color, mix)
return colors
else:
return middle_color # Return middle color directly
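# Usage sketch (hypothetical keys/values): map sum-control result keys to a
# yellow -> green -> blue gradient ordered by their target AUC, e.g.
#   colors = create_color_gradient({"auc_-100_weight_10": -100, "auc_150_weight_10": 150})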
# for config_path in [
# # "./config/modified/sines.yaml",
# # "./config/modified/revenue-baseline-365.yaml",
# "./config/modified/energy.yaml",
# "./config/modified/fmri.yaml",
# ]:
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Controlled Sampling")
parser.add_argument(
"--config_path", type=str, default="./config/modified/energy.yaml"
)
parser.add_argument("--gpu", type=int, default=0)
return parser.parse_args()
def run(run_args):
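    """Build the model and dataloaders, then generate (or load from cache) unconditional,
    sum-controlled, and anchor-controlled samples for the configured dataset."""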
args = Arguments(run_args.config_path, run_args.gpu)
configs = load_yaml_config(args.config_path)
device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(args.gpu)
dl_info = build_dataloader(configs, args)
model = instantiate_from_config(configs["model"]).to(device)
trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info)
    trainer.load(str(args.milestone))
dataset = dl_info["dataset"]
test_dl_info = build_dataloader_cond(configs, args)
test_dataloader, test_dataset = test_dl_info["dataloader"], test_dl_info["dataset"]
coef = configs["dataloader"]["test_dataset"]["coefficient"]
stepsize = configs["dataloader"]["test_dataset"]["step_size"]
sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
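    # Control-sampling hyper-parameters from the test-dataset config; `coef` and
    # `step_size` are passed to Trainer.control_sample as `coef` / `learning_rate`.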
seq_length, feature_dim = test_dataset.window, test_dataset.var_num
dataset_name = os.path.basename(args.config_path).split(".")[0].split("-")[0]
mapper = {
"sines": "sines",
"revenue": "revenue",
"energy": "energy",
"fmri": "fMRI",
}
gap = seq_length // 5
if seq_length in [96, 192, 384]:
ori_data = np.load(
os.path.join(
"../../../data/train/",str(seq_length),
dataset_name,
"samples",
f'{mapper[dataset_name].replace("sines", "sine")}_norm_truth_{seq_length}_train.npy',
)
)
masks = np.load(
os.path.join(
"../../../data/train/",str(seq_length),
dataset_name,
"samples",
f'{mapper[dataset_name].replace("sines", "sine")}_masking_{seq_length}.npy',
)
)
else:
ori_data = np.load(
os.path.join(
"../../../data/train/",
dataset_name,
"samples",
f"{mapper[dataset_name]}_norm_truth_{seq_length}_train.npy",
)
)
masks = np.load(
os.path.join(
"../../../data/train/",
dataset_name,
"samples",
f"{mapper[dataset_name]}_masking_{seq_length}.npy",
)
)
sample_num, _, _ = masks.shape
# observed = ori_data[:sample_num] * masks
ori_data = ori_data[:sample_num]
sampling_size = min(1000, len(test_dataset), sample_num)
batch_size = 500
print(f"Sampling size: {sampling_size}, Batch size: {batch_size}")
### Cache file path
cache_dir = Path(f"../../../data/cache/{dataset_name}_{seq_length}")
if "csdi" in args.config_path:
cache_dir = Path(f"../../../data/cache/csdi/{dataset_name}_{seq_length}")
    cache_dir.mkdir(parents=True, exist_ok=True)
results = load_cached_results(cache_dir)
### Unconditional sampling
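    # An empty gradient_control_signal applies no guidance, giving the unconditional baseline.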
if results["unconditional"] is None:
print("Generating unconditional data...")
results["unconditional"] = trainer.control_sample(
num=sampling_size,
size_every=batch_size,
shape=[seq_length, feature_dim],
model_kwargs={
"gradient_control_signal": {},
"coef": coef,
"learning_rate": stepsize,
},
)
save_result(cache_dir, "unconditional", "", results["unconditional"])
### Different AUC values
auc_weights = [10]
auc_values = [-100, 20, 50, 150] # -200, -150, -100, -50, 0, 20, 30, 50, 100, 150
for auc in auc_values:
for weight in auc_weights:
key = f"auc_{auc}_weight_{weight}"
if key not in results["sum_controlled"]:
print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
results["sum_controlled"][key] = trainer.control_sample(
num=sampling_size,
size_every=batch_size,
shape=[seq_length, feature_dim],
model_kwargs={
"gradient_control_signal": {"auc": auc, "auc_weight": weight},
"coef": coef,
"learning_rate": stepsize,
},
)
save_result(cache_dir, "sum", key, results["sum_controlled"][key])
### Different AUC weights
auc_weights = [1, 10, 50, 100]
auc_values = [-100]
for auc in auc_values:
for weight in auc_weights:
key = f"auc_{auc}_weight_{weight}"
if key not in results["sum_controlled"]:
print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
results["sum_controlled"][key] = trainer.control_sample(
num=sampling_size // 2,
size_every=batch_size,
shape=[seq_length, feature_dim],
model_kwargs={
"gradient_control_signal": {"auc": auc, "auc_weight": weight},
"coef": coef,
"learning_rate": stepsize,
},
)
save_result(cache_dir, "sum", key, results["sum_controlled"][key])
### Different AUC segments
auc_weights = [10]
auc_values = [150]
auc_average = 10
auc_segments = ((gap, 2 * gap), (2 * gap, 3 * gap), (3 * gap, 4 * gap))
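    # Per-segment target: the desired per-step average (auc_average) times the segment
    # length; the "segment" entry presumably restricts the control to that window.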
# for auc in auc_values:
# for weight in auc_weights:
# for segment in auc_segments:
auc = auc_values[0]
weight = auc_weights[0]
# segment = auc_segments[0]
for segment in auc_segments:
key = f"auc_{auc}_weight_{weight}_segment_{segment[0]}_{segment[1]}"
if key not in results["sum_controlled"]:
print(
f"Generating sum controlled data - AUC: {auc}, Weight: {weight}, Segment: {segment}"
)
results["sum_controlled"][key] = trainer.control_sample(
num=sampling_size,
size_every=batch_size,
shape=[seq_length, feature_dim],
model_kwargs={
"gradient_control_signal": {
"auc": auc_average * (segment[1] - segment[0]), # / seq_length,
"auc_weight": weight,
"segment": [segment],
},
"coef": coef,
"learning_rate": stepsize,
},
)
save_result(cache_dir, "sum", key, results["sum_controlled"][key])
# Different anchors
anchor_values = [-0.8, 0.6, 1.0]
    anchor_weights = [0.01, 0.1, 0.5, 1.0]
for peak in anchor_values:
for weight in anchor_weights:
key = f"peak_{peak}_weight_{weight}"
if key not in results["anchor_controlled"]:
mask = np.zeros((seq_length, feature_dim), dtype=np.float32)
mask[gap // 2 :: gap, 0] = weight
target = np.zeros((seq_length, feature_dim), dtype=np.float32)
target[gap // 2 :: gap, 0] = peak
print(f"Anchor controlled data - Peak: {peak}, Weight: {weight}")
results["anchor_controlled"][key] = trainer.control_sample(
num=sampling_size,
size_every=batch_size,
shape=[seq_length, feature_dim],
model_kwargs={
"gradient_control_signal": {}, # "auc": -50, "auc_weight": 10.0},
"coef": coef,
"learning_rate": stepsize,
},
target=target,
partial_mask=mask,
)
save_result(cache_dir, "anchor", key, results["anchor_controlled"][key])
# plot mask, target, and generated sequence
# plt.figure(figsize=(6, 3))
# plt.scatter(
# range(gap // 2, seq_length, gap), [weight] * 5, label="Mask"
# )
# plt.scatter(
# range(gap // 2, seq_length, gap), [peak] * 5, label="Target"
# )
# plt.plot(
# results["anchor_controlled"][key][0, :, 0],
# label="Generated Sequence",
# )
# plt.title(f"Anchor Controlled Data - Peak: {peak}, Weight: {weight}")
# plt.legend()
# plt.show()
if dataset.auto_norm:
for key, data in results.items():
if isinstance(data, dict):
for subkey, subdata in data.items():
results[key][subkey] = unnormalize_to_zero_to_one(subdata)
else:
results[key] = unnormalize_to_zero_to_one(data)
results["ori_data"] = ori_data
    # Truncate all result arrays to sampling_size
for key, data in results.items():
if isinstance(data, dict):
for subkey, subdata in data.items():
results[key][subkey] = subdata[:sampling_size]
else:
results[key] = data[:sampling_size]
return results, dataset_name, seq_length
def ploting(results, dataset_name, seq_length):
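    """Render KDE comparisons and anchor-detection diagnostics from the results of run()."""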
gap = seq_length // 5
ds_name_display = {
"sines": "Synthetic Sine Waves",
"revenue": "Revenue",
"energy": "ETTh",
"fmri": "fMRI",
}
# Unnormalize results if needed
ori_data = results["ori_data"]
# Store the results in variables for compatibility with existing code
unconditional_data = results["unconditional"]
sum_controled_data = results["sum_controlled"]
# ['auc_0_weight_10.0'] # default values
anchor_controled_data = results["anchor_controlled"]
# ['anchor_0.8_weight_0.1'] # default values
### Visualization
def kernel_subplots(
data, output_label="", highlight=None
):
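        """Plot a grid of KDEs, one subplot per controlled-result key, each overlaid on the
        original and unconditional distributions and annotated with its mean area."""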
# from scipy import integrate
# Calculate area under curve for each distribution
        def get_auc(data_array):
            # Mean per-sample area: sum over the time axis, then average over samples
            # (arrays here have shape (n_samples, seq_length, 1)).
            return data_array.sum(axis=1).mean()
# Get AUC values
auc_orig = get_auc(data["ori_data"])
auc_uncond = get_auc(data["Unconditional"])
# Setup subplots
keys = [k for k in data.keys() if k not in ["ori_data", "Unconditional"]]
l = len(keys)
n_cols = min(4, len(keys))
n_rows = (len(keys) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows))
fig.set_dpi(300)
if n_rows == 1:
axes = axes.reshape(1, -1)
for idx, key in enumerate(keys):
row, col = idx // n_cols, idx % n_cols
ax = axes[row, col]
# Plot distributions
sns.distplot(
data["ori_data"],
hist=False,
kde=True,
kde_kws={"linewidth": 2, "alpha": 0.9 - get_alpha(idx, l) * 0.5},
color="red",
ax=ax,
label=f"Original\n$\overline{{Area}}={auc_orig:.3f}$",
)
sns.distplot(
data["Unconditional"],
hist=False,
kde=True,
kde_kws={
"linewidth": 2,
"linestyle": "--",
"alpha": 0.9 - get_alpha(idx, l) * 0.5,
},
color="#15B01A",
ax=ax, # FF4500 GREEN:15B01A
label=f"Unconditional\n$\overline{{Area}}= {auc_uncond:.3f}$",
)
auc_control = get_auc(data[key])
sns.distplot(
data[key],
hist=False,
kde=True,
kde_kws={"linewidth": 2, "alpha": get_alpha(idx, l), "linestyle": "--"},
color="#9A0EEA",
ax=ax,
label=f"{beautiful_text(key, highlight)}\n$\overline{{Area}}= {auc_control:.3f})$",
)
# ax.set_title(f'{beautiful_text(key)}')
ax.legend()
# Set labels only for first column and last row
if col == 0:
ax.set_ylabel("Density")
else:
ax.set_ylabel("")
if row == n_rows - 1:
ax.set_xlabel("Value")
else:
ax.set_xlabel("")
# fig.suptitle(f"Kernel Density Estimation of {output_label}", fontsize=16)#, fontweight='bold')
plt.tight_layout()
plt.show()
# save pdf
# plt.savefig(f"./figures/{output_label}_kde.pdf", bbox_inches='tight')
save_pdf(fig, f"./figures/{output_label}_kde.pdf")
plt.close()
# Sum control
samples = 1000
data = {
"ori_data": ori_data[:samples, :, :1],
"Unconditional": unconditional_data[:samples, :, :1],
}
for key in [
# "auc_-200_weight_10",
"auc_-100_weight_10",
# "auc_0_weight_10",
"auc_20_weight_10",
# "auc_30_weight_10",
"auc_50_weight_10",
# "auc_100_weight_10",
"auc_150_weight_10",
]:
data[key] = sum_controled_data[key][:samples, :, :1]
print(
key,
" ==> ",
sum_controled_data[key][:samples, :, :1].sum()
/ sum_controled_data[key][:samples, :, :1].shape[0],
)
# visualization_control(
# data=data,
# analysis="kernel",
# compare=ori_data.shape[0],
# output_label="revenue"
# )
# Updated
# kernel_subplots(
# data=data,
# output_label=f"{ds_name_display[dataset_name]} Dataset with Summation Control"
# )
data = {
"ori_data": ori_data[:samples, :, :1],
"Unconditional": unconditional_data[:samples, :, :1],
}
for key in [
"auc_-100_weight_1",
"auc_-100_weight_10",
"auc_-100_weight_50",
"auc_-100_weight_100",
]:
data[key] = sum_controled_data[key][:samples, :, :1]
# print sum
print(
key,
" ==> ",
sum_controled_data[key][:samples, :, :1].sum()
/ sum_controled_data[key][:samples, :, :1].shape[0],
)
    kernel_subplots(
        data=data,
        output_label=f"{ds_name_display[dataset_name]} Dataset with Summation Control",
        highlight="weight",
    )
# anchor control
data = {
"ori_data": ori_data[:samples, :, :1],
"Unconditional": unconditional_data[:samples, :, :1],
}
# anchor_values = [-0.8, 0.6, 1.0]
# anchor_weights = [0.01, 0.01, 0.5, 1.0]
for key in [
"anchor_-0.8_weight_0.01",
"anchor_-0.8_weight_0.1",
"anchor_-0.8_weight_0.5",
"anchor_-0.8_weight_1.0",
"anchor_0.6_weight_0.01",
"anchor_0.6_weight_0.1",
"anchor_0.6_weight_0.5",
"anchor_0.6_weight_1.0",
"anchor_1.0_weight_0.01",
"anchor_1.0_weight_0.1",
"anchor_1.0_weight_0.5",
"anchor_1.0_weight_1.0",
]:
data[key] = anchor_controled_data[key][:samples, :, :1]
# print anchor
# print(key, " ==> ", anchor_controled_data[key][:samples, :, :1].max())
def visualization_control_anchor_subplots(
data, seq_length, analysis="anchor", compare=100, output_label=""
):
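        """Plot a grid of KDEs, one subplot per (anchor value, weight) pair, with dashed
        vertical lines marking the normalized anchor time positions."""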
# Extract unique anchors and weights
anchors = sorted(
set([float(k.split("_")[1]) for k in data.keys() if "anchor" in k])
)
weights = sorted(
set([float(k.split("_")[3]) for k in data.keys() if "weight" in k])
)
# Create subplot grid
n_rows = len(anchors)
n_cols = len(weights)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows))
fig.set_dpi(300)
gap = seq_length // 5
for i, anchor in enumerate(anchors):
for j, weight in enumerate(weights):
ax = axes[i][j]
key = f"anchor_{anchor}_weight_{weight}"
# Plot distributions
sns.distplot(
data["ori_data"],
hist=False,
kde=True,
kde_kws={"linewidth": 2},
color="red",
ax=ax,
label="Original",
)
sns.distplot(
data["Unconditional"],
hist=False,
kde=True,
kde_kws={"linewidth": 2, "linestyle": "--"},
color="#15B01A",
ax=ax,
label="Unconditional",
)
if key in data:
sns.distplot(
data[key],
hist=False,
kde=True,
kde_kws={"linewidth": 2, "linestyle": "--"},
color="#9A0EEA",
ax=ax,
label=f"Controlled\n$Target={anchor}, Conf={weight}$",
)
# anchor_point = int(anchor * seq_length)
anchor_points = np.arange(gap // 2, seq_length, gap)
for anchor_point in anchor_points:
ax.axvline(
x=anchor_point / seq_length,
color="black",
linestyle="--",
alpha=0.5,
)
# Labels and titles
if i == n_rows - 1:
ax.set_xlabel("Value")
if j == 0:
ax.set_ylabel("Density")
ax.set_title(f"anchor={anchor}, Weight={weight}")
ax.legend()
plt.tight_layout()
plt.show()
# save_pdf(fig, f"./figures/{output_label}_anchor_kde.pdf")
plt.close()
# Anchor Control Distribution
visualization_control_anchor_subplots(
data=data,
seq_length=seq_length,
analysis="anchor",
compare=ori_data.shape[0],
output_label=f"{ds_name_display[dataset_name]} Dataset with Anchor Control",
)
def evaluate_anchor_detection(
data, target_anchors, window_size=7, min_distance=5, prominence_threshold=0.1
):
"""
Evaluate anchor detection accuracy by comparing detected anchors with target anchors.
Parameters:
data: numpy array of shape (batch_size, seq_length, features)
The generated sequences to analyze
The indices where anchors should occur (e.g., every 7 steps for weekly anchors)
target_anchor: list
List of indices where anchors should occur
window_size: int
Size of window to consider a anchor match
"""
batch_size, seq_length, features = data.shape
detected_anchors = []
accuracy_metrics = {}
# Create figure for visualization
fig, axes = plt.subplots(2, 2, figsize=(10, 5))
axes = axes.flatten()
        # Plot the first 4 sequences (first feature) and accumulate match statistics
        overall_matched = 0
        overall_targets = 0
        overall_detected = 0
        for i in range(4):
            sequence = data[i, :, 0]  # sequence i, all timesteps, first feature
# Find anchors using scipy
anchors, properties = find_peaks(
sequence, distance=min_distance, prominence=prominence_threshold
)
# Plot original sequence and detected anchors
axes[i].plot(sequence, label="Generated")
# Plot target anchor positions
target_positions = (
target_anchors # np.arange(0, seq_length, 7) # Weekly anchors
)
axes[i].plot(
target_positions,
sequence[target_positions],
"o",
label="Target" if i == 1 else "",
)
axes[i].plot(
anchors, sequence[anchors], "x", label="Detected" if i == 1 else ""
)
axes[i].set_title(f"Sequence {i+1}")
if i == 1:
axes[i].legend(bbox_to_anchor=(1.05, 1), loc="upper left")
axes[i].grid(True)
# Count matches within window for this sequence
matched_anchors = 0
for target in target_positions:
# Check if any detected anchor is within the window of the target
matches = np.any(
(anchors >= target - window_size // 2)
& (anchors <= target + window_size // 2)
)
if matches:
matched_anchors += 1
            overall_matched += matched_anchors
            overall_targets += len(target_positions)
            overall_detected += len(anchors)
for i in range(4, batch_size):
anchors, properties = find_peaks(
data[i, :, 0], distance=min_distance, prominence=prominence_threshold
)
matched_anchors = 0
for target in target_anchors:
matches = np.any(
(anchors >= target - window_size // 2)
& (anchors <= target + window_size // 2)
)
if matches:
matched_anchors += 1
            overall_matched += matched_anchors
            overall_targets += len(target_anchors)
            overall_detected += len(anchors)
        # Calculate overall metrics across all evaluated sequences
        accuracy = overall_matched / overall_targets if overall_targets > 0 else 0
        precision = overall_matched / overall_detected if overall_detected > 0 else 0
        accuracy_metrics = {
            "accuracy": accuracy,
            "precision": precision,
            "total_targets": overall_targets,
            "detected_anchors": overall_detected,
            "matched_anchors": overall_matched,
        }
plt.tight_layout()
plt.show()
return accuracy_metrics, anchors
# Evaluate anchor detection for different control settings
anchor_accuracies = {}
for key, data in anchor_controled_data.items():
print(f"\nEvaluating {key}")
        metrics, anchors = evaluate_anchor_detection(
            data,
            # Target positions match where anchors were imposed during sampling
            # (every `gap` steps, offset by gap // 2).
            target_anchors=np.arange(gap // 2, seq_length, gap),
            window_size=max(1, gap // 2),
            min_distance=max(1, gap - 1),
        )
anchor_accuracies[key] = metrics
print(f"Accuracy: {metrics['accuracy']:.3f}")
print(f"Precision: {metrics['precision']:.3f}")
print(
f"Matched anchors: {metrics['matched_anchors']} / {metrics['total_targets']}"
)
print("=" * 50)
if __name__ == "__main__":
    args = parse_args()
    results, dataset_name, seq_length = run(args)
    ploting(results, dataset_name, seq_length)