"""Evaluate generated time series against the real data with the project metrics.

The script calls ``experiment.run`` on a YAML config and scores its
unconditional, sum-controlled and anchor-controlled samples with the enabled
metrics. It appends one summary line per metric to ``log <dataset>.txt``.

Usage (one config per process, from this script's directory)::

    python run.py ./config/modified/energy.yaml
    python run.py ./config/modified/fmri.yaml
    python run.py ./config/modified/sines.yaml
    python run.py ./config/modified/revenue-baseline-365.yaml
    python run.py ./config/modified/96/energy.yaml
    python run.py ./config/modified/192/energy.yaml
    python run.py ./config/modified/384/energy.yaml
    python run.py ./config/modified/96/fmri.yaml
    python run.py ./config/modified/192/fmri.yaml
    python run.py ./config/modified/384/fmri.yaml
    python run.py ./config/modified/96/sines.yaml
    python run.py ./config/modified/192/sines.yaml
    python run.py ./config/modified/384/sines.yaml
    python run.py ./config/modified/96/revenue.yaml
    python run.py ./config/modified/192/revenue.yaml
    python run.py ./config/modified/384/revenue.yaml
"""
import os
import sys
import warnings

# Suppress Python warnings and TensorFlow's C++ log output. The environment
# variable must be set before tensorflow is imported below.
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import torch
import tensorflow as tf

# Make the parent of this script's directory importable. The previous code
# passed the literal string '__file__' to dirname(), which always returned '',
# so the path was resolved against the current working directory instead.
try:
    _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # __file__ is undefined in interactive sessions; fall back to the working
    # directory, which is what the old code effectively used.
    _SCRIPT_DIR = os.getcwd()
sys.path.append(os.path.join(_SCRIPT_DIR, '..'))

# Local project modules; these imports rely on the sys.path update above.
from experiment import run
from utils.context_fid import Context_FID
from utils.cross_correlation import CrossCorrelLoss
from utils.metric_utils import display_scores
from utils.discriminative_metric import discriminative_score_metrics
from utils.predictive_metric import predictive_score_metrics

# Let TensorFlow grow GPU memory on demand instead of reserving it up front.
# This leaves room for PyTorch, which this script also imports.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when memory growth is set after the GPUs were already initialised.
        print(e)

# Number of repetitions per metric. The mean and spread of these runs are logged.
iterations = 5
# Metric toggles (truthy = enabled).
enable_fid = False   # Context-FID
enable_corr = False  # cross-correlation loss
enable_dis = True    # discriminative score
enable_pred = True   # predictive score
# Human-readable dataset labels used in log lines and log file names.
ds_name_display = {
    "sines": "Sine",
    "revenue": "Revenue",
    "energy": "ETTh",
    "fmri": "fMRI",
}


def random_choice(size, num_select=100):
    """Return ``num_select`` random integer indices in ``[0, size)``.

    Indices are drawn with replacement via ``np.random.randint``, so duplicates
    are possible (unlike ``np.random.choice(..., replace=False)``).
    """
    select_idx = np.random.randint(low=0, high=size, size=(num_select,))
    return select_idx


def _append_log(metric, data_name, key, data_len, scores):
    """Summarise ``scores`` with ``display_scores`` and append one line to ``log <data_name>.txt``."""
    mean, sigma = display_scores(scores)
    content = f'{metric} {data_name} {key} {data_len} {mean} {sigma}'
    with open(f'log {data_name}.txt', 'a', encoding='utf-8') as file:
        file.write(content + '\n')


def compute_metrics(ori_data, fake_data, iterations=5, data_name='sines', data_len=24, key="unconditional"):
    """Score ``fake_data`` against ``ori_data`` with every enabled metric.

    The module-level flags ``enable_dis``, ``enable_pred``, ``enable_fid`` and
    ``enable_corr`` select the metrics. Each enabled metric runs ``iterations``
    times and prints every per-iteration value. The summary from
    ``display_scores`` is appended to ``log <data_name>.txt`` as
    ``"<metric> <data_name> <key> <data_len> <mean> <sigma>"``.

    Args:
        ori_data: Real samples as a numpy array. The cross-correlation branch
            indexes it as ``[idx, :, :]``, so it is assumed to be 3-D
            (presumably ``(n_samples, seq_len, n_features)`` — confirm).
        fake_data: Generated samples in the same layout. The discriminative,
            predictive and Context-FID metrics use only its first
            ``ori_data.shape[0]`` entries.
        iterations: Repetitions per metric; must be >= 1.
        data_name: Dataset label used in the log line and the log file name.
        data_len: Sequence length; only written to the log.
        key: Identifier of the sample set being scored; only written to the log.

    Raises:
        ValueError: If ``iterations`` is less than 1.
    """
    if iterations < 1:
        # Zero iterations would summarise an empty score list, and the
        # cross-correlation branch would divide by zero.
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    n_real = ori_data.shape[0]
    # Truncate so the sample-based metrics compare equally many real and fake series.
    fake_matched = fake_data[:n_real]

    if enable_dis:
        scores = []
        for i in range(iterations):
            temp_disc, fake_acc, real_acc = discriminative_score_metrics(ori_data, fake_matched)
            scores.append(temp_disc)
            print(f'Iter {i}: ', temp_disc, ',', fake_acc, ',', real_acc, '\n')
        _append_log('disc', data_name, key, data_len, scores)

    if enable_pred:
        scores = []
        for i in range(iterations):
            temp_pred = predictive_score_metrics(ori_data, fake_matched)
            scores.append(temp_pred)
            print(i, ' epoch: ', temp_pred, '\n')
        _append_log('pred', data_name, key, data_len, scores)

    if enable_fid:
        scores = []
        for i in range(iterations):
            context_fid = Context_FID(ori_data, fake_matched)
            scores.append(context_fid)
            print(f'Iter {i}: ', 'context-fid =', context_fid, '\n')
        _append_log('fid', data_name, key, data_len, scores)

    if enable_corr:
        x_real = torch.from_numpy(ori_data)
        x_fake = torch.from_numpy(fake_data)
        # Each iteration compares a fresh random subsample of about
        # n_real / iterations series. It is clamped to at least one so the loss
        # is never computed on an empty batch. Fake indices are drawn from the
        # full fake set, not the truncated one.
        size = max(1, n_real // iterations)
        scores = []
        for i in range(iterations):
            real_idx = random_choice(x_real.shape[0], size)
            fake_idx = random_choice(x_fake.shape[0], size)
            corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')
            loss = corr.compute(x_fake[fake_idx, :, :])
            scores.append(loss.item())
            print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\n')
        _append_log('corr', data_name, key, data_len, scores)


def _build_parser():
    """Create the CLI parser: an optional config path and a ``--gpu`` index."""
    import argparse

    parser = argparse.ArgumentParser(description="Run experiment.run for a config and score its samples.")
    # nargs='?' makes the positional optional. Without it argparse ignores
    # `default` for positionals and the argument is required.
    parser.add_argument('config_path', type=str, nargs='?', default="./config/modified/sines.yaml",
                        help="YAML experiment config (default: %(default)s)")
    parser.add_argument('--gpu', type=int, default=0,
                        help="GPU index; forwarded to run() in the parsed arguments")
    return parser


def main(argv=None):
    """Run the experiment for the CLI config and evaluate all of its sample sets.

    Example::

        python run.py ./config/modified/96/energy.yaml --gpu 0
    """
    args = _build_parser().parse_args(argv)
    results, dataset_name, seq_length = run(args)

    ori_data = results["ori_data"]
    # Fall back to the raw name so an unlisted dataset does not crash after the
    # (expensive) run has already finished.
    display_name = ds_name_display.get(dataset_name, dataset_name)

    # Score the unconditional samples first, then every sum- and anchor-controlled variant.
    sample_sets = [("unconditional", results["unconditional"])]
    sample_sets.extend(results["sum_controlled"].items())
    sample_sets.extend(results["anchor_controlled"].items())

    for key, fake_data in sample_sets:
        compute_metrics(ori_data, fake_data, iterations=iterations, data_name=display_name,
                        data_len=seq_length, key=key)


if __name__ == "__main__":
    main()