Spaces:
Sleeping
Sleeping
import os
import sys
import warnings

# Silence Python warnings and TensorFlow's C++ log output. The environment
# variable is set before `import tensorflow` below so it is already in effect
# when TensorFlow loads.
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import torch
import tensorflow as tf

# Add the parent of this script's directory to the import path.
# Use the real __file__ here. Passing the string literal '__file__' to
# dirname() always yields '', which made the entry '../' relative to the
# current working directory, so it depended on where the script was launched.
# When __file__ is undefined (e.g. an interactive session), fall back to that
# old cwd-relative entry.
try:
    _script_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _script_dir = ''
sys.path.append(os.path.join(_script_dir, '../'))

# Local imports
from experiment import run
from utils.context_fid import Context_FID
from utils.cross_correlation import CrossCorrelLoss
from utils.metric_utils import display_scores
from utils.discriminative_metric import discriminative_score_metrics
from utils.predictive_metric import predictive_score_metrics
# Configure GPU memory growth: enable it for every GPU TensorFlow can see.
# When no GPU is present the list is empty and the loop is a no-op.
# TensorFlow raises RuntimeError if memory growth is changed after the GPUs
# have been initialised; that message is printed and the script continues.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print(e)
# Global settings: `iterations` is the number of repetitions per metric.
# The 0/1 switches select which metrics compute_metrics() evaluates.
iterations = 5
enable_fid, enable_corr = 0, 0
enable_dis, enable_pred = 1, 1
# Batch alternative: run every config in one process and collect the outputs
# in all_results.
# all_results = {}
# for config_path in [
#     "./config/modified/sines.yaml",
#     "./config/modified/revenue-baseline-365.yaml",
#     "./config/modified/energy.yaml",
#     "./config/modified/fmri.yaml",
#     "./config/modified/96/energy.yaml",
#     "./config/modified/192/energy.yaml",
#     "./config/modified/384/energy.yaml",
#     "./config/modified/96/fmri.yaml",
#     "./config/modified/192/fmri.yaml",
#     "./config/modified/384/fmri.yaml",
#     "./config/modified/96/sines.yaml",
#     "./config/modified/192/sines.yaml",
#     "./config/modified/384/sines.yaml",
#     "./config/modified/192/revenue.yaml",
#     "./config/modified/96/revenue.yaml",
#     "./config/modified/384/revenue.yaml",
# ]:
#     class Args:
#         config_path = config_path
#         gpu = 0
#     results, dataset_name, seq_length = run(Args())
#     all_results[config_path] = (results, dataset_name, seq_length)
# Example invocations (one config per run):
# python run.py ./config/modified/energy.yaml
# python run.py ./config/modified/fmri.yaml
# python run.py ./config/modified/sines.yaml
# python run.py ./config/modified/revenue-baseline-365.yaml
# python run.py ./config/modified/96/energy.yaml
# python run.py ./config/modified/192/energy.yaml
# python run.py ./config/modified/384/energy.yaml
# python run.py ./config/modified/96/fmri.yaml
# python run.py ./config/modified/192/fmri.yaml
# python run.py ./config/modified/384/fmri.yaml
# python run.py ./config/modified/96/sines.yaml
# python run.py ./config/modified/192/sines.yaml
# python run.py ./config/modified/384/sines.yaml
# python run.py ./config/modified/192/revenue.yaml
# python run.py ./config/modified/96/revenue.yaml
# python run.py ./config/modified/384/revenue.yaml
# Maps the dataset_name returned by run() to the label used in metric log
# lines and in the log file names.
ds_name_display = dict(
    sines="Sine",
    revenue="Revenue",
    energy="ETTh",
    fmri="fMRI",
)
def random_choice(size, num_select=100):
    """Draw ``num_select`` random indices from ``[0, size)``, with replacement.

    Uses NumPy's global RNG (``np.random.randint``), so results follow
    ``np.random.seed``. Returns a 1-D integer array of length ``num_select``.
    """
    return np.random.randint(0, size, size=num_select)
def _append_log(metric, data_name, key, data_len, scores):
    """Summarise ``scores`` and append one result line to the dataset's log.

    ``display_scores`` reduces the per-iteration scores to ``(mean, sigma)``.
    The line ``"<metric> <data_name> <key> <data_len> <mean> <sigma>"`` is
    then appended to ``log <data_name>.txt`` in the current working directory.
    """
    mean, sigma = display_scores(scores)
    content = f'{metric} {data_name} {key} {data_len} {mean} {sigma}'
    with open(f'log {data_name}.txt', 'a+') as file:
        file.write(content + '\n')


def compute_metrics(ori_data, fake_data, iterations=5, data_name='sines', data_len=24, key="unconditional"):
    """Score ``fake_data`` against ``ori_data`` with every enabled metric.

    The module-level flags ``enable_dis``, ``enable_pred``, ``enable_fid`` and
    ``enable_corr`` select which metrics run, in that order. Each enabled
    metric is evaluated ``iterations`` times. Per-iteration values are printed,
    and their summary is appended to ``log <data_name>.txt``.

    Args:
        ori_data: Real samples, indexed along axis 0. The discriminative,
            predictive and Context-FID metrics receive at most the first
            ``ori_data.shape[0]`` generated samples. The cross-correlation
            metric converts both inputs with ``torch.from_numpy`` (so they must
            be NumPy arrays) and indexes them as ``[idx, :, :]`` (at least 3-D).
        fake_data: Generated samples with the same layout as ``ori_data``.
        iterations: Number of repetitions per metric.
        data_name: Dataset label used in the log lines and the log file name.
        data_len: Sequence length recorded in each log line.
        key: Label of the generation setting recorded in each log line
            (e.g. "unconditional" or a controlled-generation key).

    Returns:
        None. Results are reported on stdout and in the log file.
    """
    if enable_dis:
        discriminative_score = []
        for i in range(iterations):
            temp_disc, fake_acc, real_acc = discriminative_score_metrics(ori_data[:], fake_data[:ori_data.shape[0]])
            discriminative_score.append(temp_disc)
            print(f'Iter {i}: ', temp_disc, ',', fake_acc, ',', real_acc, '\n')
        _append_log('disc', data_name, key, data_len, discriminative_score)

    if enable_pred:
        predictive_score = []
        for i in range(iterations):
            temp_pred = predictive_score_metrics(ori_data, fake_data[:ori_data.shape[0]])
            predictive_score.append(temp_pred)
            print(i, ' epoch: ', temp_pred, '\n')
        _append_log('pred', data_name, key, data_len, predictive_score)

    if enable_fid:
        context_fid_score = []
        for i in range(iterations):
            context_fid = Context_FID(ori_data[:], fake_data[:ori_data.shape[0]])
            context_fid_score.append(context_fid)
            print(f'Iter {i}: ', 'context-fid =', context_fid, '\n')
        _append_log('fid', data_name, key, data_len, context_fid_score)

    if enable_corr:
        x_real = torch.from_numpy(ori_data)
        x_fake = torch.from_numpy(fake_data)
        correlational_score = []
        # Each iteration compares random index draws (with replacement) of
        # about 1/iterations of the real sample count. Clamp to at least one
        # sample so small datasets do not produce an empty selection.
        size = max(1, x_real.shape[0] // iterations)
        for i in range(iterations):
            real_idx = random_choice(x_real.shape[0], size)
            fake_idx = random_choice(x_fake.shape[0], size)
            corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')
            loss = corr.compute(x_fake[fake_idx, :, :])
            correlational_score.append(loss.item())
            print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\n')
        _append_log('corr', data_name, key, data_len, correlational_score)
import argparse

# Entry point: parse CLI arguments, produce results via run(), then evaluate
# the unconditional samples and every sum-/anchor-controlled set against
# results["ori_data"].
parser = argparse.ArgumentParser()
# nargs='?' makes the positional optional. Without it, argparse requires the
# argument and the default below is never used.
parser.add_argument('config_path', type=str, nargs='?', default="./config/modified/sines.yaml")
parser.add_argument('--gpu', type=int, default=0)
results, dataset_name, seq_length = run(parser.parse_args())
# To reuse results from the commented-out batch loop instead of calling run():
# config_path = parser.parse_args().config_path
# results, dataset_name, seq_length = all_results[config_path]

ori_data = results["ori_data"]
unconditional_data = results["unconditional"]
sum_controled_data = results["sum_controlled"]
anchor_controled_data = results["anchor_controlled"]

# Datasets without a display label are logged under their raw name instead of
# raising KeyError after run() has already finished.
display_name = ds_name_display.get(dataset_name, dataset_name)

compute_metrics(ori_data, unconditional_data, iterations=iterations,
                data_name=display_name, data_len=seq_length, key="unconditional")
# Sum-controlled sets are evaluated first, then anchor-controlled sets.
for controlled_data in (sum_controled_data, anchor_controled_data):
    for key, value in controlled_data.items():
        compute_metrics(ori_data, value, iterations=iterations,
                        data_name=display_name, data_len=seq_length, key=key)