# TSEditor / run.py
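"""Evaluation entry point for TSEditor.

Runs an experiment from a YAML config (see the commented invocations below) and
scores the generated series with discriminative, predictive, Context-FID, and
cross-correlation metrics, appending the results to per-dataset log files.
"""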
import os
import sys
import warnings
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
import torch
import tensorflow as tf
# Add the parent directory to the path so local packages (experiment, utils) resolve
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../'))
# Local imports
from experiment import run
from utils.context_fid import Context_FID
from utils.cross_correlation import CrossCorrelLoss
from utils.metric_utils import display_scores
from utils.discriminative_metric import discriminative_score_metrics
from utils.predictive_metric import predictive_score_metrics
# Configure GPU memory growth so TensorFlow does not pre-allocate all GPU memory
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before the GPUs have been initialized
        print(e)
# Global settings
iterations = 5   # metric repetitions per dataset
enable_fid = 0   # Context-FID score
enable_corr = 0  # cross-correlation loss
enable_dis = 1   # discriminative score
enable_pred = 1  # predictive score
# all_results = {}
# for config_path in [
#     "./config/modified/sines.yaml",
#     "./config/modified/revenue-baseline-365.yaml",
#     "./config/modified/energy.yaml",
#     "./config/modified/fmri.yaml",
#     "./config/modified/96/energy.yaml",
#     "./config/modified/192/energy.yaml",
#     "./config/modified/384/energy.yaml",
#     "./config/modified/96/fmri.yaml",
#     "./config/modified/192/fmri.yaml",
#     "./config/modified/384/fmri.yaml",
#     "./config/modified/96/sines.yaml",
#     "./config/modified/192/sines.yaml",
#     "./config/modified/384/sines.yaml",
#     "./config/modified/192/revenue.yaml",
#     "./config/modified/96/revenue.yaml",
#     "./config/modified/384/revenue.yaml",
# ]:
#     class Args:
#         config_path = config_path
#         gpu = 0
#     results, dataset_name, seq_length = run(Args())
#     all_results[config_path] = (results, dataset_name, seq_length)
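# Example single-config invocations: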
# python run.py ./config/modified/energy.yaml
# python run.py ./config/modified/fmri.yaml
# python run.py ./config/modified/sines.yaml
# python run.py ./config/modified/revenue-baseline-365.yaml
# python run.py ./config/modified/96/energy.yaml
# python run.py ./config/modified/192/energy.yaml
# python run.py ./config/modified/384/energy.yaml
# python run.py ./config/modified/96/fmri.yaml
# python run.py ./config/modified/192/fmri.yaml
# python run.py ./config/modified/384/fmri.yaml
# python run.py ./config/modified/96/sines.yaml
# python run.py ./config/modified/192/sines.yaml
# python run.py ./config/modified/384/sines.yaml
# python run.py ./config/modified/192/revenue.yaml
# python run.py ./config/modified/96/revenue.yaml
# python run.py ./config/modified/384/revenue.yaml
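# Map internal dataset keys to the display names used in the log files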
ds_name_display = {
    "sines": "Sine",
    "revenue": "Revenue",
    "energy": "ETTh",
    "fmri": "fMRI",
}
def random_choice(size, num_select=100):
    """Sample num_select indices uniformly at random (with replacement)."""
    select_idx = np.random.randint(low=0, high=size, size=(num_select,))
    return select_idx
def compute_metrics(ori_data, fake_data, iterations=5, data_name='sines', data_len=24, key="unconditional"):
    """Compute the enabled metrics over `iterations` runs and append mean/std to 'log {data_name}.txt'."""
    if enable_dis:
        # Discriminative score: how well a post-hoc classifier separates real from synthetic samples
        discriminative_score = []
        for i in range(iterations):
            temp_disc, fake_acc, real_acc = discriminative_score_metrics(ori_data[:], fake_data[:ori_data.shape[0]])
            discriminative_score.append(temp_disc)
            print(f'Iter {i}: ', temp_disc, ',', fake_acc, ',', real_acc, '\n')
        mean, sigma = display_scores(discriminative_score)
        content = f'disc {data_name} {key} {data_len} {mean} {sigma}'
        with open(f'log {data_name}.txt', 'a+') as file:
            file.write(content + '\n')

    if enable_pred:
        # Predictive score: train a forecaster on synthetic data, evaluate on real data
        predictive_score = []
        for i in range(iterations):
            temp_pred = predictive_score_metrics(ori_data, fake_data[:ori_data.shape[0]])
            predictive_score.append(temp_pred)
            print(i, ' epoch: ', temp_pred, '\n')
        mean, sigma = display_scores(predictive_score)
        content = f'pred {data_name} {key} {data_len} {mean} {sigma}'
        with open(f'log {data_name}.txt', 'a+') as file:
            file.write(content + '\n')

    if enable_fid:
        # Context-FID: Fréchet distance between embeddings of real and synthetic series
        context_fid_score = []
        for i in range(iterations):
            context_fid = Context_FID(ori_data[:], fake_data[:ori_data.shape[0]])
            context_fid_score.append(context_fid)
            print(f'Iter {i}: ', 'context-fid =', context_fid, '\n')
        mean, sigma = display_scores(context_fid_score)
        content = f'fid {data_name} {key} {data_len} {mean} {sigma}'
        with open(f'log {data_name}.txt', 'a+') as file:
            file.write(content + '\n')

    if enable_corr:
        # Correlational score: difference in cross-correlation structure, computed on random subsets
        x_real = torch.from_numpy(ori_data)
        x_fake = torch.from_numpy(fake_data)
        correlational_score = []
        size = int(x_real.shape[0] / iterations)
        for i in range(iterations):
            real_idx = random_choice(x_real.shape[0], size)
            fake_idx = random_choice(x_fake.shape[0], size)
            corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')
            loss = corr.compute(x_fake[fake_idx, :, :])
            correlational_score.append(loss.item())
            print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\n')
        mean, sigma = display_scores(correlational_score)
        content = f'corr {data_name} {key} {data_len} {mean} {sigma}'
        with open(f'log {data_name}.txt', 'a+') as file:
            file.write(content + '\n')
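
# Command-line entry point: parse the config path, run the experiment once,
# then score the unconditional and controlled generations.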
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('config_path', type=str, nargs='?', default="./config/modified/sines.yaml",
                    help="path to the experiment YAML config")
parser.add_argument('--gpu', type=int, default=0, help="GPU index")
# class Args:
# config_path = config_path
# gpu = 0
results, dataset_name, seq_length = run(parser.parse_args())
# config_path = parser.parse_args().config_path
# results, dataset_name, seq_length = all_results[config_path]
ori_data = results["ori_data"]
unconditional_data = results["unconditional"]
sum_controlled_data = results["sum_controlled"]
anchor_controlled_data = results["anchor_controlled"]

compute_metrics(ori_data, unconditional_data, iterations=iterations, data_name=ds_name_display[dataset_name], data_len=seq_length, key="unconditional")
for key, value in sum_controlled_data.items():
    compute_metrics(ori_data, value, iterations=iterations, data_name=ds_name_display[dataset_name], data_len=seq_length, key=key)
for key, value in anchor_controlled_data.items():
    compute_metrics(ori_data, value, iterations=iterations, data_name=ds_name_display[dataset_name], data_len=seq_length, key=key)
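
# Each enabled metric appends one line per call to 'log <data_name>.txt' in the form:
#   <metric> <data_name> <key> <seq_length> <mean> <std>
# e.g. 'disc ETTh unconditional 96 <mean> <std>' (placeholder values; actual numbers depend on the run).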