import argparse
import os
import pickle
import time
import warnings
from pathlib import Path

import numpy as np
import torch

# Keep wandb disabled and silence UserWarnings before the project modules are imported.
os.environ["WANDB_ENABLED"] = "false"
warnings.simplefilter("ignore", UserWarning)

from engine.solver import Trainer
from data.build_dataloader import build_dataloader, build_dataloader_cond
from utils.io_utils import load_yaml_config, instantiate_from_config


def load_cached_results(cache_dir):
    """Load previously cached sampling results from ``cache_dir`` into a nested dict."""
    results = {"unconditional": None, "sum_controlled": {}, "anchor_controlled": {}}
    for cache_file in cache_dir.glob("*.pkl"):
        with open(cache_file, "rb") as f:
            key = cache_file.stem
            if key == "unconditional":
                results["unconditional"] = pickle.load(f)
            elif key.startswith("sum_"):
                param = key[4:]  # Remove 'sum_' prefix
                results["sum_controlled"][param] = pickle.load(f)
            elif key.startswith("anchor_"):
                param = key[7:]  # Remove 'anchor_' prefix
                results["anchor_controlled"][param] = pickle.load(f)
    return results


def save_result(cache_dir, key, subkey, data):
    """Pickle one sampling result to ``cache_dir``; the filename encodes key and subkey."""
    if subkey:
        filename = f"{key}_{subkey}.pkl"
    else:
        filename = f"{key}.pkl"
    with open(cache_dir / filename, "wb") as f:
        pickle.dump(data, f)
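# Illustrative round trip (not executed here): save_result writes files such as
# "unconditional.pkl" or "sum_auc_-100_weight_10.pkl", and load_cached_results
# rebuilds the same nested dict from those filenames, e.g.:
#   save_result(Path("cache"), "sum", "auc_-100_weight_10", samples)
#   cached = load_cached_results(Path("cache"))
#   cached["sum_controlled"]["auc_-100_weight_10"]  # -> samples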


class Arguments:
    """Runtime options consumed by the Trainer and the dataloader builders."""

    def __init__(self, config_path, gpu=0) -> None:
        self.config_path = config_path
        self.save_dir = (
            "../../../data/" + os.path.basename(self.config_path).split(".")[0]
        )
        self.gpu = gpu
        os.makedirs(self.save_dir, exist_ok=True)
        self.mode = "infill"
        self.missing_ratio = 0.95
        self.milestone = 10
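# For example, config_path "./config/modified/energy.yaml" yields
# save_dir "../../../data/energy" (the config basename without its extension).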


def parse_args():
    parser = argparse.ArgumentParser(description="Controlled Sampling")
    parser.add_argument(
        "--config_path", type=str, default="./config/modified/energy.yaml"
    )
    parser.add_argument("--gpu", type=int, default=0)
    return parser.parse_args()
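# Example invocation (script filename illustrative):
#   python control_sampling.py --config_path ./config/modified/energy.yaml --gpu 0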


def run(run_args):
    args = Arguments(run_args.config_path, run_args.gpu)
    configs = load_yaml_config(args.config_path)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
    dl_info = build_dataloader(configs, args)
    model = instantiate_from_config(configs["model"]).to(device)
    trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info)
    trainer.load(args.milestone)
    dataset = dl_info["dataset"]
    test_dl_info = build_dataloader_cond(configs, args)
    test_dataloader, test_dataset = test_dl_info["dataloader"], test_dl_info["dataset"]
    coef = configs["dataloader"]["test_dataset"]["coefficient"]
    stepsize = configs["dataloader"]["test_dataset"]["step_size"]
    sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
    seq_length, feature_dim = test_dataset.window, test_dataset.var_num
    dataset_name = os.path.basename(args.config_path).split(".")[0].split("-")[0]
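    # e.g. "energy.yaml" -> dataset_name "energy"; "revenue-baseline-sine.yaml" -> "revenue"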
    mapper = {
        "sines": "sines",
        "revenue": "revenue",
        "energy": "energy",
        "fmri": "fMRI",
    }
    gap = seq_length // 5
    if seq_length in [96, 192, 384]:
        ori_data = np.load(
            os.path.join(
                "../../../data/train/",
                str(seq_length),
                dataset_name,
                "samples",
                f'{mapper[dataset_name].replace("sines", "sine")}_norm_truth_{seq_length}_train.npy',
            )
        )
        masks = np.load(
            os.path.join(
                "../../../data/train/",
                str(seq_length),
                dataset_name,
                "samples",
                f'{mapper[dataset_name].replace("sines", "sine")}_masking_{seq_length}.npy',
            )
        )
    else:
        ori_data = np.load(
            os.path.join(
                "../../../data/train/",
                dataset_name,
                "samples",
                f"{mapper[dataset_name]}_norm_truth_{seq_length}_train.npy",
            )
        )
        masks = np.load(
            os.path.join(
                "../../../data/train/",
                dataset_name,
                "samples",
                f"{mapper[dataset_name]}_masking_{seq_length}.npy",
            )
        )
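    # Both arrays are expected to be 3-D with shape (num_samples, seq_length, feature_dim),
    # which is what the unpacking of masks.shape below assumes.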
    sample_num, _, _ = masks.shape
    # observed = ori_data[:sample_num] * masks
    ori_data = ori_data[:sample_num]
    sampling_size = min(1000, len(test_dataset), sample_num)
    batch_size = 500
    print(f"Sampling size: {sampling_size}, Batch size: {batch_size}")

    ### Cache directory for generated samples
    cache_dir = Path(f"../../../data/cache/{dataset_name}_{seq_length}")
    cache_dir.mkdir(parents=True, exist_ok=True)
    # results = load_cached_results(cache_dir)  # uncomment to reuse previously cached samples
    results = {"unconditional": None, "sum_controlled": {}, "anchor_controlled": {}}

    def measure_inference_time(func, *args, **kwargs):
        """Run ``func`` and return its result together with the elapsed wall-clock time."""
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        return result, (end_time - start_time)
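    # measure_inference_time wraps each trainer.control_sample call below; dividing the
    # returned wall-clock time by sampling_size gives the per-sample cost reported at the end.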
    timing_results = {}

    ### Unconditional sampling
    if results["unconditional"] is None:
        print("Generating unconditional data...")
        results["unconditional"], timing = measure_inference_time(
            trainer.control_sample,
            num=sampling_size,
            size_every=batch_size,
            shape=[seq_length, feature_dim],
            model_kwargs={
                "gradient_control_signal": {},
                "coef": coef,
                "learning_rate": stepsize,
            },
        )
        timing_results["unconditional"] = timing / sampling_size
        save_result(cache_dir, "unconditional", "", results["unconditional"])

    ### Different AUC values
    auc_weights = [10]
    auc_values = [-100, 20, 50, 150]  # -200, -150, -100, -50, 0, 20, 30, 50, 100, 150
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key], timing = measure_inference_time(
                    trainer.control_sample,
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                timing_results[f"sum_controlled_{key}"] = timing / sampling_size
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])

    ### Different AUC weights
    auc_weights = [1, 10, 50, 100]
    auc_values = [-100]
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key], timing = measure_inference_time(
                    trainer.control_sample,
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                timing_results[f"sum_controlled_{key}"] = timing / sampling_size
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])

    ### Different AUC segments
    auc_weights = [10]
    auc_values = [150]
    auc_average = 10
    auc_segments = ((gap, 2 * gap), (2 * gap, 3 * gap), (3 * gap, 4 * gap))
    auc = auc_values[0]
    weight = auc_weights[0]
    for segment in auc_segments:
        key = f"auc_{auc}_weight_{weight}_segment_{segment[0]}_{segment[1]}"
        if key not in results["sum_controlled"]:
            print(
                f"Generating sum controlled data - AUC: {auc}, Weight: {weight}, Segment: {segment}"
            )
            results["sum_controlled"][key], timing = measure_inference_time(
                trainer.control_sample,
                num=sampling_size,
                size_every=batch_size,
                shape=[seq_length, feature_dim],
                model_kwargs={
                    "gradient_control_signal": {
                        "auc": auc_average * (segment[1] - segment[0]),
                        "auc_weight": weight,
                        "segment": [segment],
                    },
                    "coef": coef,
                    "learning_rate": stepsize,
                },
            )
            timing_results[f"sum_controlled_{key}"] = timing / sampling_size
            save_result(cache_dir, "sum", key, results["sum_controlled"][key])

    ### Different anchors
    anchor_values = [-0.8, 0.6, 1.0]
    anchor_weights = [0.01, 0.5, 1.0]
    for peak in anchor_values:
        for weight in anchor_weights:
            key = f"peak_{peak}_weight_{weight}"
            if key not in results["anchor_controlled"]:
                mask = np.zeros((seq_length, feature_dim), dtype=np.float32)
                mask[gap // 2 :: gap, 0] = weight
                target = np.zeros((seq_length, feature_dim), dtype=np.float32)
                target[gap // 2 :: gap, 0] = peak
                print(f"Anchor controlled data - Peak: {peak}, Weight: {weight}")
                results["anchor_controlled"][key], timing = measure_inference_time(
                    trainer.control_sample,
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                    target=target,
                    partial_mask=mask,
                )
                timing_results[f"anchor_controlled_{key}"] = timing / sampling_size
                save_result(cache_dir, "anchor", key, results["anchor_controlled"][key])

    # After all sampling is done, print timing results
    print("\nAverage Inference Time per Sample:")
    print("-" * 40)
    for key, time_per_sample in timing_results.items():
        print(f"{key}: {time_per_sample:.4f} seconds")
    # return results, dataset_name, seq_length


if __name__ == "__main__":
    args = parse_args()
    run(args)