import os import torch import numpy as np os.environ["WANDB_ENABLED"] = "false" from engine.solver import Trainer from data.build_dataloader import build_dataloader from utils.metric_utils import visualization, save_pdf # from utils.metric_utils import visualization from utils.io_utils import load_yaml_config, instantiate_from_config from models.model_utils import unnormalize_to_zero_to_one from scipy.signal import find_peaks, peak_prominences # disable user warnings import warnings warnings.simplefilter("ignore", UserWarning) import scipy.stats import numpy as np import seaborn as sns import matplotlib.pyplot as plt from sklearn.manifold import TSNE from sklearn.decomposition import PCA class Arguments: def __init__(self, config_path) -> None: self.config_path = config_path # self.config_path = "./config/control/revenue-baseline-sine.yaml" self.save_dir = ( "../../../data/" + os.path.basename(self.config_path).split(".")[0] ) self.gpu = 0 os.makedirs(self.save_dir, exist_ok=True) self.mode = "infill" self.missing_ratio = 0.95 self.milestone = 10 import numpy as np import matplotlib as mpl def create_color_gradient(sorting_value=None, start_color='#FFFF00', end_color='#00008B'): """Create color gradient using matplotlib color interpolation.""" def color_fader(c1, c2, mix=0): """Fade from color c1 to c2 with mix ratio.""" c1 = np.array(mpl.colors.to_rgb(c1)) c2 = np.array(mpl.colors.to_rgb(c2)) return mpl.colors.to_hex((1-mix)*c1 + mix*c2) if sorting_value is not None: # Normalize values between 0-1 values = np.array(list(sorting_value.values())) normalized = (values - values.min()) / (values.max() - values.min()) # Create color mapping return { key: color_fader(start_color, end_color, mix=norm_val) for key, norm_val in zip(sorting_value.keys(), normalized) } else: # Return middle point color return color_fader(start_color, end_color, mix=0.5) def create_color_gradient(sorting_value=None, start_color='#FFFF00', middle_color='#00FF00', end_color='#00008B'): """Create color gradient using matplotlib interpolation with middle color.""" def color_fader(c1, c2, mix=0): """Fade from color c1 to c2 with mix ratio.""" c1 = np.array(mpl.colors.to_rgb(c1)) c2 = np.array(mpl.colors.to_rgb(c2)) return mpl.colors.to_hex((1-mix)*c1 + mix*c2) if sorting_value is not None: values = np.array(list(sorting_value.values())) normalized = (values - values.min()) / (values.max() - values.min()) colors = {} for key, norm_val in zip(sorting_value.keys(), normalized): if norm_val <= 0.5: # Interpolate between start and middle mix = norm_val * 2 # Scale 0-0.5 to 0-1 colors[key] = color_fader(start_color, middle_color, mix) else: # Interpolate between middle and end mix = (norm_val - 0.5) * 2 # Scale 0.5-1 to 0-1 colors[key] = color_fader(middle_color, end_color, mix) return colors else: return middle_color # Return middle color directly def evaluate_peak_detection(data, target_peaks, window_size=7, min_distance=5, prominence_threshold=0.1): """ Evaluate peak detection accuracy by comparing detected peaks with target peaks. Parameters: data: numpy array of shape (batch_size, seq_length, features) The generated sequences to analyze The indices where peaks should occur (e.g., every 7 steps for weekly peaks) target_peak: list List of indices where peaks should occur window_size: int Size of window to consider a peak match """ batch_size, seq_length, features = data.shape detected_peaks = [] accuracy_metrics = {} # Create figure for visualization fig, axes = plt.subplots(4, 2, figsize=(20, 12)) axes = axes.flatten() # Analyze first 8 batches and first feature (revenue) overall_matched = 0 overall_targets = 0 for i in range(8): sequence = data[i, :, 0] # batch i, all timepoints, revenue feature # Find peaks using scipy peaks, properties = find_peaks(sequence, distance=min_distance, prominence=prominence_threshold) # Plot original sequence and detected peaks axes[i].plot(sequence, label='Generated Sequence') axes[i].plot(peaks, sequence[peaks], "x", label='Detected Peaks') # Plot target peak positions target_positions = target_peaks # np.arange(0, seq_length, 7) # Weekly peaks axes[i].plot(target_positions, sequence[target_positions], "o", label='Target Peak Positions') axes[i].set_title(f'Sequence {i+1} Peak Detection Analysis') axes[i].legend() axes[i].grid(True) # Count matches within window for this sequence matched_peaks = 0 for target in target_positions: # Check if any detected peak is within the window of the target matches = np.any((peaks >= target - window_size//2) & (peaks <= target + window_size//2)) if matches: matched_peaks += 1 overall_matched += matched_peaks overall_targets += len(target_positions) for i in range(8, batch_size): peaks, properties = find_peaks(data[i, :, 0], distance=min_distance, prominence=prominence_threshold) matched_peaks = 0 for target in target_peaks: matches = np.any((peaks >= target - window_size//2) & (peaks <= target + window_size//2)) if matches: matched_peaks += 1 overall_matched += matched_peaks overall_targets += len(target_peaks) # Calculate overall metrics accuracy = overall_matched / overall_targets precision = overall_matched / (len(peaks) * 8) if len(peaks) > 0 else 0 accuracy_metrics = { 'accuracy': accuracy, 'precision': precision, 'total_targets': overall_targets, 'detected_peaks': len(peaks) * 8, 'matched_peaks': overall_matched } plt.tight_layout() plt.show() return accuracy_metrics, peaks for config_path in [ "./config/modified/sines.yaml", "./config/modified/revenue-baseline-365.yaml", "./config/modified/energy.yaml", "./config/modified/fmri.yaml", ]: args = Arguments(config_path) configs = load_yaml_config(args.config_path) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") torch.cuda.set_device(args.gpu) dl_info = build_dataloader(configs, args) model = instantiate_from_config(configs["model"]).to(device) trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info) # trainer.load(args.milestone, from_folder="../../../data/ckpt_baseline_240") # trainer.train() from data.build_dataloader import build_dataloader_cond # args.milestone trainer.load("10") test_dl_info = build_dataloader_cond(configs, args) test_dataloader, test_dataset = test_dl_info["dataloader"], test_dl_info["dataset"] coef = configs["dataloader"]["test_dataset"]["coefficient"] stepsize = configs["dataloader"]["test_dataset"]["step_size"] sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"] seq_length, feature_dim = test_dataset.window, test_dataset.var_num # samples, ori_data, masks = trainer.restore( # test_dataloader, # [seq_length, feature_dim], # coef, # stepsize, # sampling_steps, # control_signal={}, # # test= # ) # if test_dataset.auto_norm: # samples = unnormalize_to_zero_to_one(samples) # ori_data = np.load(os.path.join(dataset.dir, f"sine_ground_truth_{seq_length}_test.npy")) dataset_name = os.path.basename(args.config_path).split(".")[0].split("-")[0] mapper = { "sines": "sines", "revenue": "revenue", "energy": "energy", "fmri": "fMRI", } gap = seq_length // 5 ori_data = np.load( os.path.join("../../../data/train/", dataset_name, "samples", f"{mapper[dataset_name]}_norm_truth_{seq_length}_train.npy") ) masks = np.load(os.path.join("../../../data/train/", dataset_name, "samples", f"{mapper[dataset_name]}_masking_{seq_length}.npy")) sample_num, seq_len, feat_dim = masks.shape observed = ori_data[:sample_num] * masks ori_data = ori_data[:sample_num] import pickle from pathlib import Path # Cache file path cache_dir = Path(f"../../../data/cache_{dataset_name}") cache_dir.mkdir(exist_ok=True) def load_cached_results(): results = {'unconditional': None, 'sum_controlled': {}, 'anchor_controlled': {}} for cache_file in cache_dir.glob('*.pkl'): with open(cache_file, 'rb') as f: key = cache_file.stem if key == 'unconditional': results['unconditional'] = pickle.load(f) elif key.startswith('sum_'): param = key[4:] # Remove 'sum_' prefix results['sum_controlled'][param] = pickle.load(f) elif key.startswith('anchor_'): param = key[7:] # Remove 'anchor_' prefix results['anchor_controlled'][param] = pickle.load(f) return results def save_result(key, subkey, data): if subkey: filename = f"{key}_{subkey}.pkl" else: filename = f"{key}.pkl" with open(cache_dir / filename, 'wb') as f: pickle.dump(data, f) results = load_cached_results() dataset = dl_info["dataset"] seq_length, feature_dim = dataset.window, dataset.var_num coef = configs["dataloader"]["test_dataset"]["coefficient"] stepsize = configs["dataloader"]["test_dataset"]["step_size"] # Unconditional sampling if results['unconditional'] is None: print("Generating unconditional data...") results['unconditional'] = trainer.sample( num=min(1000, len(dataset)), size_every=500, shape=[seq_length, feature_dim] ) save_result('unconditional', None, results['unconditional']) # Different AUC weights auc_weights = [10,] auc_values = [-200, -150, -100, 0, 20, 30, 50, 100, 150] for auc in auc_values: for weight in auc_weights: key = f"auc_{auc}_weight_{weight}" if key not in results['sum_controlled']: print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}") results['sum_controlled'][key] = trainer.control_sample( num=min(1000, len(dataset)), size_every=500, shape=[seq_length, feature_dim], model_kwargs={ "gradient_control_signal": {"auc": auc, "auc_weight": weight}, "coef": coef, "learning_rate": stepsize } ) save_result('sum', key, results['sum_controlled'][key]) auc_weights = [1, 10, 50, 100] auc_values = [-200,] for auc in auc_values: for weight in auc_weights: key = f"auc_{auc}_weight_{weight}" if key not in results['sum_controlled']: print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}") results['sum_controlled'][key] = trainer.control_sample( num=min(1000, len(dataset)), size_every=500, shape=[seq_length, feature_dim], model_kwargs={ "gradient_control_signal": {"auc": auc, "auc_weight": weight}, "coef": coef, "learning_rate": stepsize } ) save_result('sum', key, results['sum_controlled'][key]) # Different weekly peaks peak_values = [0.8, 1.0] peak_weights = [0.1, 0.5, 1.0] # import matplotlib.pyplot as plt # for peak in peak_values: # for weight in peak_weights: # key = f"peak_{peak}_weight_{weight}" # if key not in results['anchor_controlled']: # mask = np.zeros((seq_length, feature_dim), dtype=np.float32) # mask[::gap, 0] = weight # target = np.zeros((seq_length, feature_dim), dtype=np.float32) # target[::gap, 0] = peak # print(f"Generating anchor controlled data - Peak: {peak}, Weight: {weight}") # results['anchor_controlled'][key] = trainer.control_sample( # num=min(1000, len(dataset)), size_every=500, shape=[seq_length, feature_dim], # model_kwargs={ # "gradient_control_signal": {"auc": -50, "auc_weight": 10.0}, # "coef": coef, # "learning_rate": stepsize # }, # target=target, # partial_mask=mask # ) # save_result('anchor', key, results['anchor_controlled'][key]) # # plot mask, target, and generated sequence # plt.figure(figsize=(12, 6)) # plt.plot(mask[:, 0], label='Mask') # plt.plot(target[:, 0], label='Target') # plt.plot(results['anchor_controlled'][key][0, :, 0], label='Generated Sequence') # plt.title(f"Anchor Controlled Data - Peak: {peak}, Weight: {weight}") # plt.legend() # plt.show() # Unnormalize results if needed if dataset.auto_norm: for key, data in results.items(): if isinstance(data, dict): for subkey, subdata in data.items(): results[key][subkey] = unnormalize_to_zero_to_one(subdata) else: results[key] = unnormalize_to_zero_to_one(data) # Store the results in variables for compatibility with existing code unconditional_data = results['unconditional'] sum_controled_data = results['sum_controlled']# ['auc_0_weight_10.0'] # default values anchor_controled_data = results['anchor_controlled'] # ['peak_0.8_weight_0.1'] # default values # Sum control samples = 1000 data = { "ori_data": ori_data[:samples, :, :1], "Unconditional": unconditional_data[:samples, :, :1], } # for key, value in sum_controled_data.items(): # if "weight_10" in key: # data[key] = value # print(key) keys = [ # "auc_-200_weight_10", "auc_-100_weight_10", # "auc_0_weight_10", "auc_20_weight_10", # "auc_30_weight_10", "auc_50_weight_10", # "auc_100_weight_10", "auc_150_weight_10", ] for key in keys: data[key] = sum_controled_data[key][:samples, :, :1] # print sum print(key, " ==> ", sum_controled_data[key][:samples, :, :1].sum() / sum_controled_data[key][:samples, :, :1].shape[0]) # visualization_control( # data=data, # analysis="kernel", # compare=ori_data.shape[0], # output_label="revenue" # ) def visualization_control_subplots(data, analysis="kernel", compare=100, output_label="", highlight=None): # from scipy import integrate # Calculate area under curve for each distribution def get_auc(data_array): return data_array.sum(-1).mean() # Get AUC values auc_orig = get_auc(data["ori_data"]) auc_uncond = get_auc(data["Unconditional"]) # Setup subplots keys = [k for k in data.keys() if k not in ["ori_data", "Unconditional"]] l = len(keys) n_cols = min(4, len(keys)) n_rows = (len(keys) + n_cols - 1) // n_cols fig, axes = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 4*n_rows)) fig.set_dpi(300) if n_rows == 1: axes = axes.reshape(1, -1) def beautiful_text(key): print(key) if "auc" in key: auc = key.split("_")[1] weight = key.split("_")[3] if highlight is None: return f"AUC: $\\mathbf{{{auc}}}$ Weight: {weight}" else: return f"AUC: {auc} Weight: $\\mathbf{{{weight}}}$" if "peak" in key: peak = key.split("_")[1] weight = key.split("_")[3] return f"Peak: {peak} Weight: {weight}" return key # Plot distributions # colors = create_color_gradient({key: get_auc(data[key]) for key in keys}, '#004225','#F02147', '#4B0082') def get_alpha(idx, n_plots): """Generate alpha value between 0.3-0.8 based on plot index""" return 0.5 + (0.4 * idx / (n_plots - 1)) if n_plots > 1 else 0.8 for idx, key in enumerate(keys): row, col = idx // n_cols, idx % n_cols ax = axes[row, col] # Plot distributions sns.distplot(data["ori_data"], hist=False, kde=True, kde_kws={"linewidth": 2, "alpha": 0.9 - get_alpha(idx, l) * 0.5}, color='red', ax=ax, label=f'Original\n$\overline{{Area}}={auc_orig:.3f}$') sns.distplot(data["Unconditional"], hist=False, kde=True, kde_kws={"linewidth": 2, "linestyle":"--", "alpha": 0.9 - get_alpha(idx, l) * 0.5}, color='#15B01A', ax=ax, #FF4500 GREEN:15B01A label=f'Unconditional\n$\overline{{Area}}= {auc_uncond:.3f}$') auc_control = get_auc(data[key]) sns.distplot(data[key], hist=False, kde=True, kde_kws={"linewidth": 2, "alpha": get_alpha(idx, l), "linestyle": "--"}, color="#9A0EEA", ax=ax, label=f'{beautiful_text(key)}\n$\overline{{Area}}= {auc_control:.3f})$') # ax.set_title(f'{beautiful_text(key)}') ax.legend() # Set labels only for first column and last row if col == 0: ax.set_ylabel('Density') else: ax.set_ylabel('') if row == n_rows - 1: ax.set_xlabel('Value') else: ax.set_xlabel('') fig.suptitle(f"Kernel Density Estimation of {output_label}", fontsize=16)#, fontweight='bold') plt.tight_layout() plt.show() # save pdf # plt.savefig(f"./figures/{output_label}_kde.pdf", bbox_inches='tight') save_pdf(fig, f"./figures/{output_label}_kde.pdf") plt.close() ds_name_display = { "sines": "Synthetic Sine Waves", "revenue": "Revenue", "energy": "ETTh", "fmri": "fMRI", } visualization_control_subplots( data=data, analysis="kernel", compare=ori_data.shape[0], output_label=f"{ds_name_display[dataset_name]} Dataset with Summation Control" ) # peak control # data = { # "ori_data": ori_data[:samples, :, :1], # "Unconditional": unconditional_data[:samples, :, :1], # } # keys = [ # "peak_0.8_weight_0.1", # "peak_0.8_weight_0.5", # "peak_0.8_weight_1.0", # "peak_1.0_weight_0.1", # "peak_1.0_weight_0.5", # "peak_1.0_weight_1.0", # ] # for key in keys: # data[key] = anchor_controled_data[key][:samples, :, :1] # # print peak # print(key, " ==> ", anchor_controled_data[key][:samples, :, :1].max()) # visualization_control( # data=data, # analysis="kernel", # compare=ori_data.shape[0], # output_label="revenue" # ) # # config_mapping = { # # "sines": { # # } # # "revenue": "revenue", # # "energy": "energy", # # "fmri": "fMRI", # # } # # Evaluate peak detection for different control settings # peak_accuracies = {} # for key, data in anchor_controled_data.items(): # print(f"\nEvaluating {key}") # metrics, peaks = evaluate_peak_detection( # data, # target_peaks=range(0, seq_length, gap), # window_size=max(1, gap//2), # min_distance=max(1, gap - 1) # ) # peak_accuracies[key] = metrics # print(f"Accuracy: {metrics['accuracy']:.3f}") # print(f"Precision: {metrics['precision']:.3f}") # print(f"Matched peaks: {metrics['matched_peaks']} / {metrics['total_targets']}") print("="*50)