import os
import torch
import numpy as np
os.environ["WANDB_ENABLED"] = "false"
from engine.solver import Trainer
from data.build_dataloader import build_dataloader
from utils.metric_utils import visualization, save_pdf
from utils.io_utils import load_yaml_config, instantiate_from_config
from models.model_utils import unnormalize_to_zero_to_one
from scipy.signal import find_peaks, peak_prominences
# Disable user warnings
import warnings
warnings.simplefilter("ignore", UserWarning)
import scipy.stats
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
class Arguments:
    def __init__(self, config_path) -> None:
        self.config_path = config_path
        # self.config_path = "./config/control/revenue-baseline-sine.yaml"
        self.save_dir = (
            "../../../data/" + os.path.basename(self.config_path).split(".")[0]
        )
        self.gpu = 0
        os.makedirs(self.save_dir, exist_ok=True)
        self.mode = "infill"
        self.missing_ratio = 0.95
        self.milestone = 10
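# Usage note (illustrative, not executed): Arguments("./config/modified/sines.yaml")
# resolves save_dir to "../../../data/sines" and creates that directory on construction.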
import matplotlib as mpl
def create_color_gradient(sorting_value=None, start_color='#FFFF00', end_color='#00008B'):
    """Create a color gradient using matplotlib color interpolation."""
    def color_fader(c1, c2, mix=0):
        """Fade from color c1 to color c2 with the given mix ratio."""
        c1 = np.array(mpl.colors.to_rgb(c1))
        c2 = np.array(mpl.colors.to_rgb(c2))
        return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)

    if sorting_value is not None:
        # Normalize values to the range [0, 1]
        values = np.array(list(sorting_value.values()))
        normalized = (values - values.min()) / (values.max() - values.min())
        # Map each key to its interpolated color
        return {
            key: color_fader(start_color, end_color, mix=norm_val)
            for key, norm_val in zip(sorting_value.keys(), normalized)
        }
    else:
        # Return the midpoint color
        return color_fader(start_color, end_color, mix=0.5)
# Three-color variant; this definition supersedes the two-color version above.
def create_color_gradient(sorting_value=None, start_color='#FFFF00', middle_color='#00FF00', end_color='#00008B'):
    """Create a color gradient using matplotlib interpolation with a middle color."""
    def color_fader(c1, c2, mix=0):
        """Fade from color c1 to color c2 with the given mix ratio."""
        c1 = np.array(mpl.colors.to_rgb(c1))
        c2 = np.array(mpl.colors.to_rgb(c2))
        return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)

    if sorting_value is not None:
        values = np.array(list(sorting_value.values()))
        normalized = (values - values.min()) / (values.max() - values.min())
        colors = {}
        for key, norm_val in zip(sorting_value.keys(), normalized):
            if norm_val <= 0.5:
                # Interpolate between start and middle colors (scale 0-0.5 to 0-1)
                mix = norm_val * 2
                colors[key] = color_fader(start_color, middle_color, mix)
            else:
                # Interpolate between middle and end colors (scale 0.5-1 to 0-1)
                mix = (norm_val - 0.5) * 2
                colors[key] = color_fader(middle_color, end_color, mix)
        return colors
    else:
        # Return the middle color directly
        return middle_color
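# Illustrative example only (the dict below is made-up and not produced anywhere in
# this notebook): map each run's score to a color along the yellow -> green -> blue
# gradient; the lowest score maps to start_color.
# example_scores = {"run_a": 0.2, "run_b": 0.7, "run_c": 1.3}
# example_colors = create_color_gradient(example_scores)
# print(example_colors)  # {'run_a': '#ffff00', ...}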
def evaluate_peak_detection(data, target_peaks, window_size=7, min_distance=5, prominence_threshold=0.1):
    """
    Evaluate peak detection accuracy by comparing detected peaks with target peaks.

    Parameters:
        data: numpy array of shape (batch_size, seq_length, features)
            The generated sequences to analyze.
        target_peaks: list or array
            Indices where peaks should occur (e.g., every 7 steps for weekly peaks).
        window_size: int
            Size of the window around a target within which a detected peak counts as a match.
        min_distance: int
            Minimum distance between detected peaks (passed to scipy.signal.find_peaks).
        prominence_threshold: float
            Minimum prominence of detected peaks (passed to scipy.signal.find_peaks).
    """
    batch_size, seq_length, features = data.shape
    target_positions = np.asarray(target_peaks)

    # Create figure for visualization of the first 8 sequences
    fig, axes = plt.subplots(4, 2, figsize=(20, 12))
    axes = axes.flatten()

    overall_matched = 0
    overall_targets = 0
    overall_detected = 0
    for i in range(batch_size):
        sequence = data[i, :, 0]  # batch i, all timepoints, first feature (revenue)
        # Find peaks using scipy
        peaks, properties = find_peaks(sequence,
                                       distance=min_distance,
                                       prominence=prominence_threshold)
        if i < 8:
            # Plot the sequence, detected peaks, and target peak positions
            axes[i].plot(sequence, label='Generated Sequence')
            axes[i].plot(peaks, sequence[peaks], "x", label='Detected Peaks')
            axes[i].plot(target_positions, sequence[target_positions], "o",
                         label='Target Peak Positions')
            axes[i].set_title(f'Sequence {i+1} Peak Detection Analysis')
            axes[i].legend()
            axes[i].grid(True)
        # Count target peaks matched by a detected peak within the window
        matched_peaks = 0
        for target in target_positions:
            matches = np.any((peaks >= target - window_size // 2) &
                             (peaks <= target + window_size // 2))
            if matches:
                matched_peaks += 1
        overall_matched += matched_peaks
        overall_targets += len(target_positions)
        overall_detected += len(peaks)

    # Calculate overall metrics across all sequences
    accuracy = overall_matched / overall_targets
    precision = overall_matched / overall_detected if overall_detected > 0 else 0
    accuracy_metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'total_targets': overall_targets,
        'detected_peaks': overall_detected,
        'matched_peaks': overall_matched
    }
    plt.tight_layout()
    plt.show()
    return accuracy_metrics, peaks
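# Illustrative usage sketch (synthetic data with assumed shapes, not part of the
# pipeline below): inject peaks every 7 steps into a dummy batch and check that
# find_peaks recovers them.
# _demo = np.zeros((8, 28, 1), dtype=np.float32)
# _demo[:, ::7, 0] = 1.0  # weekly spikes
# _metrics, _ = evaluate_peak_detection(_demo, target_peaks=np.arange(0, 28, 7),
#                                       window_size=3, min_distance=5)
# print(_metrics["accuracy"], _metrics["precision"])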
for config_path in [
    "./config/modified/sines.yaml",
    "./config/modified/revenue-baseline-365.yaml",
    "./config/modified/energy.yaml",
    "./config/modified/fmri.yaml",
]:
    args = Arguments(config_path)
    configs = load_yaml_config(args.config_path)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(args.gpu)
    dl_info = build_dataloader(configs, args)
    model = instantiate_from_config(configs["model"]).to(device)
    trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info)
    # trainer.load(args.milestone, from_folder="../../../data/ckpt_baseline_240")
    # trainer.train()

    from data.build_dataloader import build_dataloader_cond
    # args.milestone
    trainer.load("10")
    test_dl_info = build_dataloader_cond(configs, args)
    test_dataloader, test_dataset = test_dl_info["dataloader"], test_dl_info["dataset"]
    coef = configs["dataloader"]["test_dataset"]["coefficient"]
    stepsize = configs["dataloader"]["test_dataset"]["step_size"]
    sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
    seq_length, feature_dim = test_dataset.window, test_dataset.var_num
    # samples, ori_data, masks = trainer.restore(
    #     test_dataloader,
    #     [seq_length, feature_dim],
    #     coef,
    #     stepsize,
    #     sampling_steps,
    #     control_signal={},
    #     # test=
    # )
    # if test_dataset.auto_norm:
    #     samples = unnormalize_to_zero_to_one(samples)
    # ori_data = np.load(os.path.join(dataset.dir, f"sine_ground_truth_{seq_length}_test.npy"))
    dataset_name = os.path.basename(args.config_path).split(".")[0].split("-")[0]
    mapper = {
        "sines": "sines",
        "revenue": "revenue",
        "energy": "energy",
        "fmri": "fMRI",
    }
    gap = seq_length // 5
    ori_data = np.load(
        os.path.join("../../../data/train/", dataset_name, "samples",
                     f"{mapper[dataset_name]}_norm_truth_{seq_length}_train.npy")
    )
    masks = np.load(
        os.path.join("../../../data/train/", dataset_name, "samples",
                     f"{mapper[dataset_name]}_masking_{seq_length}.npy")
    )
    sample_num, seq_len, feat_dim = masks.shape
    observed = ori_data[:sample_num] * masks
    ori_data = ori_data[:sample_num]
    import pickle
    from pathlib import Path

    # Cache directory for generated samples (one per dataset)
    cache_dir = Path(f"../../../data/cache_{dataset_name}")
    cache_dir.mkdir(exist_ok=True)

    def load_cached_results():
        results = {'unconditional': None, 'sum_controlled': {}, 'anchor_controlled': {}}
        for cache_file in cache_dir.glob('*.pkl'):
            with open(cache_file, 'rb') as f:
                key = cache_file.stem
                if key == 'unconditional':
                    results['unconditional'] = pickle.load(f)
                elif key.startswith('sum_'):
                    param = key[4:]  # Remove 'sum_' prefix
                    results['sum_controlled'][param] = pickle.load(f)
                elif key.startswith('anchor_'):
                    param = key[7:]  # Remove 'anchor_' prefix
                    results['anchor_controlled'][param] = pickle.load(f)
        return results

    def save_result(key, subkey, data):
        filename = f"{key}_{subkey}.pkl" if subkey else f"{key}.pkl"
        with open(cache_dir / filename, 'wb') as f:
            pickle.dump(data, f)

    results = load_cached_results()
    dataset = dl_info["dataset"]
    seq_length, feature_dim = dataset.window, dataset.var_num
    coef = configs["dataloader"]["test_dataset"]["coefficient"]
    stepsize = configs["dataloader"]["test_dataset"]["step_size"]
    # Unconditional sampling
    if results['unconditional'] is None:
        print("Generating unconditional data...")
        results['unconditional'] = trainer.sample(
            num=min(1000, len(dataset)), size_every=500, shape=[seq_length, feature_dim]
        )
        save_result('unconditional', None, results['unconditional'])

    # Sum (area-under-curve) control: sweep target AUC values at a fixed weight
    auc_weights = [10]
    auc_values = [-200, -150, -100, 0, 20, 30, 50, 100, 150]
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results['sum_controlled']:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results['sum_controlled'][key] = trainer.control_sample(
                    num=min(1000, len(dataset)), size_every=500, shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize
                    }
                )
                save_result('sum', key, results['sum_controlled'][key])

    # Sum control: sweep control weights at a fixed target AUC
    auc_weights = [1, 10, 50, 100]
    auc_values = [-200]
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results['sum_controlled']:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results['sum_controlled'][key] = trainer.control_sample(
                    num=min(1000, len(dataset)), size_every=500, shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize
                    }
                )
                save_result('sum', key, results['sum_controlled'][key])
    # Anchor (weekly peak) control settings -- generation currently disabled
    peak_values = [0.8, 1.0]
    peak_weights = [0.1, 0.5, 1.0]
    # for peak in peak_values:
    #     for weight in peak_weights:
    #         key = f"peak_{peak}_weight_{weight}"
    #         if key not in results['anchor_controlled']:
    #             mask = np.zeros((seq_length, feature_dim), dtype=np.float32)
    #             mask[::gap, 0] = weight
    #             target = np.zeros((seq_length, feature_dim), dtype=np.float32)
    #             target[::gap, 0] = peak
    #             print(f"Generating anchor controlled data - Peak: {peak}, Weight: {weight}")
    #             results['anchor_controlled'][key] = trainer.control_sample(
    #                 num=min(1000, len(dataset)), size_every=500, shape=[seq_length, feature_dim],
    #                 model_kwargs={
    #                     "gradient_control_signal": {"auc": -50, "auc_weight": 10.0},
    #                     "coef": coef,
    #                     "learning_rate": stepsize
    #                 },
    #                 target=target,
    #                 partial_mask=mask
    #             )
    #             save_result('anchor', key, results['anchor_controlled'][key])
    #             # Plot mask, target, and generated sequence
    #             plt.figure(figsize=(12, 6))
    #             plt.plot(mask[:, 0], label='Mask')
    #             plt.plot(target[:, 0], label='Target')
    #             plt.plot(results['anchor_controlled'][key][0, :, 0], label='Generated Sequence')
    #             plt.title(f"Anchor Controlled Data - Peak: {peak}, Weight: {weight}")
    #             plt.legend()
    #             plt.show()
    # Unnormalize results if needed
    if dataset.auto_norm:
        for key, data in results.items():
            if isinstance(data, dict):
                for subkey, subdata in data.items():
                    results[key][subkey] = unnormalize_to_zero_to_one(subdata)
            else:
                results[key] = unnormalize_to_zero_to_one(data)

    # Store the results in variables for compatibility with existing code
    unconditional_data = results['unconditional']
    sum_controled_data = results['sum_controlled']        # keys like 'auc_0_weight_10'
    anchor_controled_data = results['anchor_controlled']  # keys like 'peak_0.8_weight_0.1'
    # Sum control: compare original, unconditional, and sum-controlled distributions
    samples = 1000
    data = {
        "ori_data": ori_data[:samples, :, :1],
        "Unconditional": unconditional_data[:samples, :, :1],
    }
    # for key, value in sum_controled_data.items():
    #     if "weight_10" in key:
    #         data[key] = value
    #         print(key)
    keys = [
        # "auc_-200_weight_10",
        "auc_-100_weight_10",
        # "auc_0_weight_10",
        "auc_20_weight_10",
        # "auc_30_weight_10",
        "auc_50_weight_10",
        # "auc_100_weight_10",
        "auc_150_weight_10",
    ]
    for key in keys:
        data[key] = sum_controled_data[key][:samples, :, :1]
        # Print the mean per-sequence sum (area under curve) for this setting
        print(key, " ==> ",
              sum_controled_data[key][:samples, :, :1].sum() / sum_controled_data[key][:samples, :, :1].shape[0])
    # visualization_control(
    #     data=data,
    #     analysis="kernel",
    #     compare=ori_data.shape[0],
    #     output_label="revenue"
    # )
    def visualization_control_subplots(data, analysis="kernel", compare=100, output_label="", highlight=None):
        # Calculate the mean area under the curve for each distribution
        def get_auc(data_array):
            return data_array.sum(-1).mean()

        # AUC values for the two reference distributions
        auc_orig = get_auc(data["ori_data"])
        auc_uncond = get_auc(data["Unconditional"])

        # Set up subplots: one panel per controlled setting
        keys = [k for k in data.keys() if k not in ["ori_data", "Unconditional"]]
        l = len(keys)
        n_cols = min(4, len(keys))
        n_rows = (len(keys) + n_cols - 1) // n_cols
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows))
        fig.set_dpi(300)
        if n_rows == 1:
            axes = axes.reshape(1, -1)

        def beautiful_text(key):
            print(key)
            if "auc" in key:
                auc = key.split("_")[1]
                weight = key.split("_")[3]
                if highlight is None:
                    return f"AUC: $\\mathbf{{{auc}}}$ Weight: {weight}"
                else:
                    return f"AUC: {auc} Weight: $\\mathbf{{{weight}}}$"
            if "peak" in key:
                peak = key.split("_")[1]
                weight = key.split("_")[3]
                return f"Peak: {peak} Weight: {weight}"
            return key

        # colors = create_color_gradient({key: get_auc(data[key]) for key in keys}, '#004225', '#F02147', '#4B0082')

        def get_alpha(idx, n_plots):
            """Generate an alpha value between 0.5 and 0.9 based on the plot index."""
            return 0.5 + (0.4 * idx / (n_plots - 1)) if n_plots > 1 else 0.8

        # Plot distributions: original, unconditional, and the controlled setting
        for idx, key in enumerate(keys):
            row, col = idx // n_cols, idx % n_cols
            ax = axes[row, col]
            sns.distplot(data["ori_data"], hist=False, kde=True,
                         kde_kws={"linewidth": 2, "alpha": 0.9 - get_alpha(idx, l) * 0.5}, color='red',
                         ax=ax, label=f'Original\n$\\overline{{Area}}={auc_orig:.3f}$')
            sns.distplot(data["Unconditional"], hist=False, kde=True,
                         kde_kws={"linewidth": 2, "linestyle": "--", "alpha": 0.9 - get_alpha(idx, l) * 0.5},
                         color='#15B01A', ax=ax,
                         label=f'Unconditional\n$\\overline{{Area}}= {auc_uncond:.3f}$')
            auc_control = get_auc(data[key])
            sns.distplot(data[key], hist=False, kde=True,
                         kde_kws={"linewidth": 2, "alpha": get_alpha(idx, l), "linestyle": "--"}, color="#9A0EEA",
                         ax=ax, label=f'{beautiful_text(key)}\n$\\overline{{Area}}= {auc_control:.3f}$')
            ax.legend()
            # Label only the first column and the last row
            if col == 0:
                ax.set_ylabel('Density')
            else:
                ax.set_ylabel('')
            if row == n_rows - 1:
                ax.set_xlabel('Value')
            else:
                ax.set_xlabel('')

        fig.suptitle(f"Kernel Density Estimation of {output_label}", fontsize=16)
        plt.tight_layout()
        plt.show()
        # Save the figure as a PDF
        save_pdf(fig, f"./figures/{output_label}_kde.pdf")
        plt.close()
    ds_name_display = {
        "sines": "Synthetic Sine Waves",
        "revenue": "Revenue",
        "energy": "ETTh",
        "fmri": "fMRI",
    }
    visualization_control_subplots(
        data=data,
        analysis="kernel",
        compare=ori_data.shape[0],
        output_label=f"{ds_name_display[dataset_name]} Dataset with Summation Control"
    )
    # Peak (anchor) control analysis -- kept for reference, currently disabled
    # data = {
    #     "ori_data": ori_data[:samples, :, :1],
    #     "Unconditional": unconditional_data[:samples, :, :1],
    # }
    # keys = [
    #     "peak_0.8_weight_0.1",
    #     "peak_0.8_weight_0.5",
    #     "peak_0.8_weight_1.0",
    #     "peak_1.0_weight_0.1",
    #     "peak_1.0_weight_0.5",
    #     "peak_1.0_weight_1.0",
    # ]
    # for key in keys:
    #     data[key] = anchor_controled_data[key][:samples, :, :1]
    #     # Print the maximum value reached (peak height) for this setting
    #     print(key, " ==> ", anchor_controled_data[key][:samples, :, :1].max())
    # visualization_control(
    #     data=data,
    #     analysis="kernel",
    #     compare=ori_data.shape[0],
    #     output_label="revenue"
    # )
    # Evaluate peak detection accuracy for the different control settings
    # peak_accuracies = {}
    # for key, data in anchor_controled_data.items():
    #     print(f"\nEvaluating {key}")
    #     metrics, peaks = evaluate_peak_detection(
    #         data,
    #         target_peaks=range(0, seq_length, gap),
    #         window_size=max(1, gap // 2),
    #         min_distance=max(1, gap - 1)
    #     )
    #     peak_accuracies[key] = metrics
    #     print(f"Accuracy: {metrics['accuracy']:.3f}")
    #     print(f"Precision: {metrics['precision']:.3f}")
    #     print(f"Matched peaks: {metrics['matched_peaks']} / {metrics['total_targets']}")
    print("=" * 50)