import os
import torch

os.environ["WANDB_ENABLED"] = "false"

from engine.solver import Trainer
from data.build_dataloader import build_dataloader, build_dataloader_cond
from utils.metric_utils import visualization, save_pdf
from utils.io_utils import load_yaml_config, instantiate_from_config
from models.model_utils import unnormalize_to_zero_to_one
from scipy.signal import find_peaks, peak_prominences

# disable user warnings
import warnings

warnings.simplefilter("ignore", UserWarning)

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
import pickle
from pathlib import Path
def load_cached_results(cache_dir):
    results = {"unconditional": None, "sum_controlled": {}, "anchor_controlled": {}}
    for cache_file in cache_dir.glob("*.pkl"):
        with open(cache_file, "rb") as f:
            key = cache_file.stem
            # if key == "unconditional":
            #     continue
            if key == "unconditional":
                results["unconditional"] = pickle.load(f)
            elif key.startswith("sum_"):
                param = key[4:]  # Remove 'sum_' prefix
                results["sum_controlled"][param] = pickle.load(f)
            elif key.startswith("anchor_"):
                param = key[7:]  # Remove 'anchor_' prefix
                results["anchor_controlled"][param] = pickle.load(f)
    return results
def save_result(cache_dir, key, subkey, data):
    if subkey:
        filename = f"{key}_{subkey}.pkl"
    else:
        filename = f"{key}.pkl"
    with open(cache_dir / filename, "wb") as f:
        pickle.dump(data, f)
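

# Cache layout written by save_result() and read back by load_cached_results()
# (a sketch inferred from the two helpers above; exact names depend on the keys used):
#   <cache_dir>/unconditional.pkl     -> results["unconditional"]
#   <cache_dir>/sum_<key>.pkl         -> results["sum_controlled"][<key>]
#   <cache_dir>/anchor_<key>.pkl      -> results["anchor_controlled"][<key>]
# Each pickle holds the raw sample array returned by trainer.control_sample().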
class Arguments:
    def __init__(self, config_path, gpu=0) -> None:
        self.config_path = config_path
        # self.config_path = "./config/control/revenue-baseline-sine.yaml"
        self.save_dir = (
            "../../../data/" + os.path.basename(self.config_path).split(".")[0]
        )
        self.gpu = gpu
        os.makedirs(self.save_dir, exist_ok=True)
        self.mode = "infill"
        self.missing_ratio = 0.95
        self.milestone = 10
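

# Note: Arguments stands in for the parsed CLI namespace that Trainer and the
# dataloader builders appear to expect (save_dir, gpu, mode, missing_ratio,
# milestone). This is an assumption based on how the object is passed around
# below, not on the Trainer's documented interface.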
def beautiful_text(key, highlight):
    # print(key)
    if "auc" in key:
        auc = key.split("_")[1]
        weight = key.split("_")[3]
        if highlight is None:
            return f"AUC: $\\mathbf{{{auc}}}$ Weight: {weight}"
        else:
            return f"AUC: {auc} Weight: $\\mathbf{{{weight}}}$"
    if "anchor" in key:
        anchor = key.split("_")[1]
        weight = key.split("_")[3]
        return f"anchor: {anchor} Weight: {weight}"
    return key
def get_alpha(idx, n_plots):
    """Generate an alpha value between 0.5 and 0.9 based on the plot index."""
    return 0.5 + (0.4 * idx / (n_plots - 1)) if n_plots > 1 else 0.8
def create_color_gradient(
    sorting_value=None, start_color="#FFFF00", end_color="#00008B"
):
    """Create color gradient using matplotlib color interpolation."""

    def color_fader(c1, c2, mix=0):
        """Fade from color c1 to c2 with mix ratio."""
        c1 = np.array(mpl.colors.to_rgb(c1))
        c2 = np.array(mpl.colors.to_rgb(c2))
        return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)

    if sorting_value is not None:
        # Normalize values between 0 and 1
        values = np.array(list(sorting_value.values()))
        normalized = (values - values.min()) / (values.max() - values.min())
        # Create color mapping
        return {
            key: color_fader(start_color, end_color, mix=norm_val)
            for key, norm_val in zip(sorting_value.keys(), normalized)
        }
    else:
        # Return middle point color
        return color_fader(start_color, end_color, mix=0.5)
# Note: this three-color definition overrides the two-color version above.
def create_color_gradient(
    sorting_value=None,
    start_color="#FFFF00",
    middle_color="#00FF00",
    end_color="#00008B",
):
    """Create color gradient using matplotlib interpolation with a middle color."""

    def color_fader(c1, c2, mix=0):
        """Fade from color c1 to c2 with mix ratio."""
        c1 = np.array(mpl.colors.to_rgb(c1))
        c2 = np.array(mpl.colors.to_rgb(c2))
        return mpl.colors.to_hex((1 - mix) * c1 + mix * c2)

    if sorting_value is not None:
        values = np.array(list(sorting_value.values()))
        normalized = (values - values.min()) / (values.max() - values.min())
        colors = {}
        for key, norm_val in zip(sorting_value.keys(), normalized):
            if norm_val <= 0.5:
                # Interpolate between start and middle
                mix = norm_val * 2  # Scale 0-0.5 to 0-1
                colors[key] = color_fader(start_color, middle_color, mix)
            else:
                # Interpolate between middle and end
                mix = (norm_val - 0.5) * 2  # Scale 0.5-1 to 0-1
                colors[key] = color_fader(middle_color, end_color, mix)
        return colors
    else:
        return middle_color  # Return middle color directly
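

# Example usage (hypothetical values): assign each control setting a color on the
# yellow -> green -> dark-blue gradient according to its target AUC:
#   colors = create_color_gradient({"auc_-100": -100.0, "auc_20": 20.0, "auc_150": 150.0})
#   plt.plot(series, color=colors["auc_20"])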
# for config_path in [
#     # "./config/modified/sines.yaml",
#     # "./config/modified/revenue-baseline-365.yaml",
#     "./config/modified/energy.yaml",
#     "./config/modified/fmri.yaml",
# ]:
import argparse
def parse_args():
    parser = argparse.ArgumentParser(description="Controlled Sampling")
    parser.add_argument(
        "--config_path", type=str, default="./config/modified/energy.yaml"
    )
    parser.add_argument("--gpu", type=int, default=0)
    return parser.parse_args()
def run(run_args):
    args = Arguments(run_args.config_path, run_args.gpu)
    configs = load_yaml_config(args.config_path)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(args.gpu)
    dl_info = build_dataloader(configs, args)
    model = instantiate_from_config(configs["model"]).to(device)
    trainer = Trainer(config=configs, args=args, model=model, dataloader=dl_info)
    trainer.load("10")  # args.milestone
    dataset = dl_info["dataset"]
    test_dl_info = build_dataloader_cond(configs, args)
    test_dataloader, test_dataset = test_dl_info["dataloader"], test_dl_info["dataset"]
    coef = configs["dataloader"]["test_dataset"]["coefficient"]
    stepsize = configs["dataloader"]["test_dataset"]["step_size"]
    sampling_steps = configs["dataloader"]["test_dataset"]["sampling_steps"]
    seq_length, feature_dim = test_dataset.window, test_dataset.var_num
    dataset_name = os.path.basename(args.config_path).split(".")[0].split("-")[0]
    mapper = {
        "sines": "sines",
        "revenue": "revenue",
        "energy": "energy",
        "fmri": "fMRI",
    }
    gap = seq_length // 5
    if seq_length in [96, 192, 384]:
        ori_data = np.load(
            os.path.join(
                "../../../data/train/",
                str(seq_length),
                dataset_name,
                "samples",
                f'{mapper[dataset_name].replace("sines", "sine")}_norm_truth_{seq_length}_train.npy',
            )
        )
        masks = np.load(
            os.path.join(
                "../../../data/train/",
                str(seq_length),
                dataset_name,
                "samples",
                f'{mapper[dataset_name].replace("sines", "sine")}_masking_{seq_length}.npy',
            )
        )
    else:
        ori_data = np.load(
            os.path.join(
                "../../../data/train/",
                dataset_name,
                "samples",
                f"{mapper[dataset_name]}_norm_truth_{seq_length}_train.npy",
            )
        )
        masks = np.load(
            os.path.join(
                "../../../data/train/",
                dataset_name,
                "samples",
                f"{mapper[dataset_name]}_masking_{seq_length}.npy",
            )
        )
    sample_num, _, _ = masks.shape
    # observed = ori_data[:sample_num] * masks
    ori_data = ori_data[:sample_num]
    sampling_size = min(1000, len(test_dataset), sample_num)
    batch_size = 500
    print(f"Sampling size: {sampling_size}, Batch size: {batch_size}")

    ### Cache file path
    cache_dir = Path(f"../../../data/cache/{dataset_name}_{seq_length}")
    if "csdi" in args.config_path:
        cache_dir = Path(f"../../../data/cache/csdi/{dataset_name}_{seq_length}")
    cache_dir.mkdir(parents=True, exist_ok=True)
    results = load_cached_results(cache_dir)
    ### Unconditional sampling
    if results["unconditional"] is None:
        print("Generating unconditional data...")
        results["unconditional"] = trainer.control_sample(
            num=sampling_size,
            size_every=batch_size,
            shape=[seq_length, feature_dim],
            model_kwargs={
                "gradient_control_signal": {},
                "coef": coef,
                "learning_rate": stepsize,
            },
        )
        save_result(cache_dir, "unconditional", "", results["unconditional"])
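    # How control_sample is driven here (inferred from this call site, not from the
    # Trainer source): an empty "gradient_control_signal" yields plain sampling, while
    # "coef" and "learning_rate" (the config's coefficient / step_size) scale the
    # guidance applied during the reverse diffusion steps.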
    ### Different AUC values
    auc_weights = [10]
    auc_values = [-100, 20, 50, 150]  # -200, -150, -100, -50, 0, 20, 30, 50, 100, 150
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key] = trainer.control_sample(
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])
    ### Different AUC weights
    auc_weights = [1, 10, 50, 100]
    auc_values = [-100]
    for auc in auc_values:
        for weight in auc_weights:
            key = f"auc_{auc}_weight_{weight}"
            if key not in results["sum_controlled"]:
                print(f"Generating sum controlled data - AUC: {auc}, Weight: {weight}")
                results["sum_controlled"][key] = trainer.control_sample(
                    num=sampling_size // 2,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {"auc": auc, "auc_weight": weight},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                )
                save_result(cache_dir, "sum", key, results["sum_controlled"][key])
    ### Different AUC segments
    auc_weights = [10]
    auc_values = [150]
    auc_average = 10
    auc_segments = ((gap, 2 * gap), (2 * gap, 3 * gap), (3 * gap, 4 * gap))
    # for auc in auc_values:
    #     for weight in auc_weights:
    #         for segment in auc_segments:
    auc = auc_values[0]
    weight = auc_weights[0]
    # segment = auc_segments[0]
    for segment in auc_segments:
        key = f"auc_{auc}_weight_{weight}_segment_{segment[0]}_{segment[1]}"
        if key not in results["sum_controlled"]:
            print(
                f"Generating sum controlled data - AUC: {auc}, Weight: {weight}, Segment: {segment}"
            )
            results["sum_controlled"][key] = trainer.control_sample(
                num=sampling_size,
                size_every=batch_size,
                shape=[seq_length, feature_dim],
                model_kwargs={
                    "gradient_control_signal": {
                        "auc": auc_average * (segment[1] - segment[0]),  # / seq_length,
                        "auc_weight": weight,
                        "segment": [segment],
                    },
                    "coef": coef,
                    "learning_rate": stepsize,
                },
            )
            save_result(cache_dir, "sum", key, results["sum_controlled"][key])
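    # Segment targets above are sums, not levels: the target AUC for the window
    # [segment[0], segment[1]) is auc_average * (segment[1] - segment[0]). For example,
    # with seq_length = 100 (so gap = 20), the segment (20, 40) gets a target sum of
    # 10 * 20 = 200, i.e. an average level of 10 over those 20 steps. The exact effect
    # of the "segment" entry depends on how control_sample interprets it.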
    # Different anchors
    anchor_values = [-0.8, 0.6, 1.0]
    anchor_weights = [0.01, 0.1, 0.5, 1.0]
    for peak in anchor_values:
        for weight in anchor_weights:
            # Keys must match the "anchor_{value}_weight_{weight}" lookups in ploting()
            key = f"anchor_{peak}_weight_{weight}"
            if key not in results["anchor_controlled"]:
                mask = np.zeros((seq_length, feature_dim), dtype=np.float32)
                mask[gap // 2 :: gap, 0] = weight
                target = np.zeros((seq_length, feature_dim), dtype=np.float32)
                target[gap // 2 :: gap, 0] = peak
                print(f"Anchor controlled data - Peak: {peak}, Weight: {weight}")
                results["anchor_controlled"][key] = trainer.control_sample(
                    num=sampling_size,
                    size_every=batch_size,
                    shape=[seq_length, feature_dim],
                    model_kwargs={
                        "gradient_control_signal": {},  # "auc": -50, "auc_weight": 10.0},
                        "coef": coef,
                        "learning_rate": stepsize,
                    },
                    target=target,
                    partial_mask=mask,
                )
                save_result(cache_dir, "anchor", key, results["anchor_controlled"][key])
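                # Anchor control, as set up above: `target` holds the desired value
                # (peak) and `partial_mask` the guidance weight, both nonzero only for
                # feature 0 at indices gap//2, gap//2 + gap, ... For example, with
                # seq_length = 100 and gap = 20 the anchors sit at t = 10, 30, 50, 70, 90.
                # How strongly the mask weight is enforced is up to control_sample.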
                # plot mask, target, and generated sequence
                # plt.figure(figsize=(6, 3))
                # plt.scatter(
                #     range(gap // 2, seq_length, gap), [weight] * 5, label="Mask"
                # )
                # plt.scatter(
                #     range(gap // 2, seq_length, gap), [peak] * 5, label="Target"
                # )
                # plt.plot(
                #     results["anchor_controlled"][key][0, :, 0],
                #     label="Generated Sequence",
                # )
                # plt.title(f"Anchor Controlled Data - Peak: {peak}, Weight: {weight}")
                # plt.legend()
                # plt.show()
    if dataset.auto_norm:
        for key, data in results.items():
            if isinstance(data, dict):
                for subkey, subdata in data.items():
                    results[key][subkey] = unnormalize_to_zero_to_one(subdata)
            else:
                results[key] = unnormalize_to_zero_to_one(data)
    results["ori_data"] = ori_data
    # Truncate all results to sampling_size
    for key, data in results.items():
        if isinstance(data, dict):
            for subkey, subdata in data.items():
                results[key][subkey] = subdata[:sampling_size]
        else:
            results[key] = data[:sampling_size]
    return results, dataset_name, seq_length
def ploting(results, dataset_name, seq_length):
    gap = seq_length // 5
    ds_name_display = {
        "sines": "Synthetic Sine Waves",
        "revenue": "Revenue",
        "energy": "ETTh",
        "fmri": "fMRI",
    }
    # Unnormalize results if needed
    ori_data = results["ori_data"]
    # Store the results in variables for compatibility with existing code
    unconditional_data = results["unconditional"]
    sum_controled_data = results["sum_controlled"]  # e.g. ['auc_0_weight_10.0'] (default values)
    anchor_controled_data = results["anchor_controlled"]  # e.g. ['anchor_0.8_weight_0.1'] (default values)
    ### Visualization
    def kernel_subplots(data, output_label="", highlight=None, analysis=None, compare=None):
        # `analysis` and `compare` are accepted for call-site compatibility but unused.
        # from scipy import integrate
        # Calculate the area under the curve for each distribution
        def get_auc(data_array):
            return data_array.sum(-1).mean()

        # Get AUC values
        auc_orig = get_auc(data["ori_data"])
        auc_uncond = get_auc(data["Unconditional"])
        # Set up subplots
        keys = [k for k in data.keys() if k not in ["ori_data", "Unconditional"]]
        n_plots = len(keys)
        n_cols = min(4, len(keys))
        n_rows = (len(keys) + n_cols - 1) // n_cols
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows))
        fig.set_dpi(300)
        if n_rows == 1:
            axes = axes.reshape(1, -1)
        for idx, key in enumerate(keys):
            row, col = idx // n_cols, idx % n_cols
            ax = axes[row, col]
            # Plot distributions (sns.distplot requires seaborn < 0.14)
            sns.distplot(
                data["ori_data"],
                hist=False,
                kde=True,
                kde_kws={"linewidth": 2, "alpha": 0.9 - get_alpha(idx, n_plots) * 0.5},
                color="red",
                ax=ax,
                label=f"Original\n$\\overline{{Area}}={auc_orig:.3f}$",
            )
            sns.distplot(
                data["Unconditional"],
                hist=False,
                kde=True,
                kde_kws={
                    "linewidth": 2,
                    "linestyle": "--",
                    "alpha": 0.9 - get_alpha(idx, n_plots) * 0.5,
                },
                color="#15B01A",  # FF4500, GREEN: 15B01A
                ax=ax,
                label=f"Unconditional\n$\\overline{{Area}}= {auc_uncond:.3f}$",
            )
            auc_control = get_auc(data[key])
            sns.distplot(
                data[key],
                hist=False,
                kde=True,
                kde_kws={"linewidth": 2, "alpha": get_alpha(idx, n_plots), "linestyle": "--"},
                color="#9A0EEA",
                ax=ax,
                label=f"{beautiful_text(key, highlight)}\n$\\overline{{Area}}= {auc_control:.3f}$",
            )
            # ax.set_title(f'{beautiful_text(key)}')
            ax.legend()
            # Set labels only for the first column and the last row
            if col == 0:
                ax.set_ylabel("Density")
            else:
                ax.set_ylabel("")
            if row == n_rows - 1:
                ax.set_xlabel("Value")
            else:
                ax.set_xlabel("")
        # fig.suptitle(f"Kernel Density Estimation of {output_label}", fontsize=16)
        plt.tight_layout()
        plt.show()
        # Save as PDF
        # plt.savefig(f"./figures/{output_label}_kde.pdf", bbox_inches="tight")
        save_pdf(fig, f"./figures/{output_label}_kde.pdf")
        plt.close()
    # Sum control
    samples = 1000
    data = {
        "ori_data": ori_data[:samples, :, :1],
        "Unconditional": unconditional_data[:samples, :, :1],
    }
    for key in [
        # "auc_-200_weight_10",
        "auc_-100_weight_10",
        # "auc_0_weight_10",
        "auc_20_weight_10",
        # "auc_30_weight_10",
        "auc_50_weight_10",
        # "auc_100_weight_10",
        "auc_150_weight_10",
    ]:
        data[key] = sum_controled_data[key][:samples, :, :1]
        print(
            key,
            " ==> ",
            sum_controled_data[key][:samples, :, :1].sum()
            / sum_controled_data[key][:samples, :, :1].shape[0],
        )
    # visualization_control(
    #     data=data,
    #     analysis="kernel",
    #     compare=ori_data.shape[0],
    #     output_label="revenue",
    # )
    # Updated
    # kernel_subplots(
    #     data=data,
    #     output_label=f"{ds_name_display[dataset_name]} Dataset with Summation Control",
    # )
    data = {
        "ori_data": ori_data[:samples, :, :1],
        "Unconditional": unconditional_data[:samples, :, :1],
    }
    for key in [
        "auc_-100_weight_1",
        "auc_-100_weight_10",
        "auc_-100_weight_50",
        "auc_-100_weight_100",
    ]:
        data[key] = sum_controled_data[key][:samples, :, :1]
        # Print the mean per-sample sum for this setting
        print(
            key,
            " ==> ",
            sum_controled_data[key][:samples, :, :1].sum()
            / sum_controled_data[key][:samples, :, :1].shape[0],
        )
    kernel_subplots(
        data=data,
        analysis="kernel",
        compare=ori_data.shape[0],
        output_label=f"{ds_name_display[dataset_name]} Dataset with Summation Control",
        highlight="weight",
    )
    # Anchor control
    data = {
        "ori_data": ori_data[:samples, :, :1],
        "Unconditional": unconditional_data[:samples, :, :1],
    }
    # anchor_values = [-0.8, 0.6, 1.0]
    # anchor_weights = [0.01, 0.1, 0.5, 1.0]
    for key in [
        "anchor_-0.8_weight_0.01",
        "anchor_-0.8_weight_0.1",
        "anchor_-0.8_weight_0.5",
        "anchor_-0.8_weight_1.0",
        "anchor_0.6_weight_0.01",
        "anchor_0.6_weight_0.1",
        "anchor_0.6_weight_0.5",
        "anchor_0.6_weight_1.0",
        "anchor_1.0_weight_0.01",
        "anchor_1.0_weight_0.1",
        "anchor_1.0_weight_0.5",
        "anchor_1.0_weight_1.0",
    ]:
        data[key] = anchor_controled_data[key][:samples, :, :1]
        # print anchor
        # print(key, " ==> ", anchor_controled_data[key][:samples, :, :1].max())
    def visualization_control_anchor_subplots(
        data, seq_length, analysis="anchor", compare=100, output_label=""
    ):
        # Extract unique anchors and weights
        anchors = sorted(
            set([float(k.split("_")[1]) for k in data.keys() if "anchor" in k])
        )
        weights = sorted(
            set([float(k.split("_")[3]) for k in data.keys() if "weight" in k])
        )
        # Create subplot grid
        n_rows = len(anchors)
        n_cols = len(weights)
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows))
        fig.set_dpi(300)
        gap = seq_length // 5
        for i, anchor in enumerate(anchors):
            for j, weight in enumerate(weights):
                ax = axes[i][j]
                key = f"anchor_{anchor}_weight_{weight}"
                # Plot distributions
                sns.distplot(
                    data["ori_data"],
                    hist=False,
                    kde=True,
                    kde_kws={"linewidth": 2},
                    color="red",
                    ax=ax,
                    label="Original",
                )
                sns.distplot(
                    data["Unconditional"],
                    hist=False,
                    kde=True,
                    kde_kws={"linewidth": 2, "linestyle": "--"},
                    color="#15B01A",
                    ax=ax,
                    label="Unconditional",
                )
                if key in data:
                    sns.distplot(
                        data[key],
                        hist=False,
                        kde=True,
                        kde_kws={"linewidth": 2, "linestyle": "--"},
                        color="#9A0EEA",
                        ax=ax,
                        label=f"Controlled\n$Target={anchor}, Conf={weight}$",
                    )
                # anchor_point = int(anchor * seq_length)
                anchor_points = np.arange(gap // 2, seq_length, gap)
                for anchor_point in anchor_points:
                    ax.axvline(
                        x=anchor_point / seq_length,
                        color="black",
                        linestyle="--",
                        alpha=0.5,
                    )
                # Labels and titles
                if i == n_rows - 1:
                    ax.set_xlabel("Value")
                if j == 0:
                    ax.set_ylabel("Density")
                ax.set_title(f"anchor={anchor}, Weight={weight}")
                ax.legend()
        plt.tight_layout()
        plt.show()
        # save_pdf(fig, f"./figures/{output_label}_anchor_kde.pdf")
        plt.close()
    # Anchor Control Distribution
    visualization_control_anchor_subplots(
        data=data,
        seq_length=seq_length,
        analysis="anchor",
        compare=ori_data.shape[0],
        output_label=f"{ds_name_display[dataset_name]} Dataset with Anchor Control",
    )
    def evaluate_anchor_detection(
        data, target_anchors, window_size=7, min_distance=5, prominence_threshold=0.1
    ):
        """
        Evaluate anchor detection accuracy by comparing detected anchors with target anchors.

        Parameters:
            data: numpy array of shape (batch_size, seq_length, features)
                The generated sequences to analyze.
            target_anchors: list
                Indices where anchors should occur (e.g., every 7 steps for weekly anchors).
            window_size: int
                Size of the window within which a detected peak counts as an anchor match.
            min_distance: int
                Minimum distance between detected peaks (passed to scipy.signal.find_peaks).
            prominence_threshold: float
                Minimum peak prominence (passed to scipy.signal.find_peaks).
        """
        batch_size, seq_length, features = data.shape
        accuracy_metrics = {}
        # Create figure for visualization
        fig, axes = plt.subplots(2, 2, figsize=(10, 5))
        axes = axes.flatten()
        # Plot the first 4 sequences (first feature only); count matches across all sequences
        overall_matched = 0
        overall_targets = 0
        overall_detected = 0
        for i in range(4):
            sequence = data[i, :, 0]  # sequence i, all timepoints, first feature
            # Find anchors using scipy
            anchors, properties = find_peaks(
                sequence, distance=min_distance, prominence=prominence_threshold
            )
            # Plot original sequence and detected anchors
            axes[i].plot(sequence, label="Generated")
            # Plot target anchor positions
            target_positions = np.asarray(target_anchors)  # e.g. np.arange(0, seq_length, 7) for weekly anchors
            axes[i].plot(
                target_positions,
                sequence[target_positions],
                "o",
                label="Target" if i == 1 else "",
            )
            axes[i].plot(
                anchors, sequence[anchors], "x", label="Detected" if i == 1 else ""
            )
            axes[i].set_title(f"Sequence {i+1}")
            if i == 1:
                axes[i].legend(bbox_to_anchor=(1.05, 1), loc="upper left")
            axes[i].grid(True)
            # Count matches within the window for this sequence
            matched_anchors = 0
            for target in target_positions:
                # Check if any detected anchor is within the window of the target
                matches = np.any(
                    (anchors >= target - window_size // 2)
                    & (anchors <= target + window_size // 2)
                )
                if matches:
                    matched_anchors += 1
            overall_matched += matched_anchors
            overall_targets += len(target_positions)
            overall_detected += len(anchors)
        for i in range(4, batch_size):
            anchors, properties = find_peaks(
                data[i, :, 0], distance=min_distance, prominence=prominence_threshold
            )
            matched_anchors = 0
            for target in target_anchors:
                matches = np.any(
                    (anchors >= target - window_size // 2)
                    & (anchors <= target + window_size // 2)
                )
                if matches:
                    matched_anchors += 1
            overall_matched += matched_anchors
            overall_targets += len(target_anchors)
            overall_detected += len(anchors)
        # Calculate overall metrics across all sequences
        accuracy = overall_matched / overall_targets
        precision = overall_matched / overall_detected if overall_detected > 0 else 0
        accuracy_metrics = {
            "accuracy": accuracy,
            "precision": precision,
            "total_targets": overall_targets,
            "detected_anchors": overall_detected,
            "matched_anchors": overall_matched,
        }
        plt.tight_layout()
        plt.show()
        return accuracy_metrics, anchors
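    # Matching rule, illustrated with assumed numbers: a target anchor at t = 38 with
    # window_size = 9 is "matched" if any detected peak falls in [34, 42]. Accuracy is
    # matched targets / total targets over all sequences; precision here is matched
    # targets / total detected peaks (a simplification, since one peak can match
    # several targets).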
    # Evaluate anchor detection for different control settings
    anchor_accuracies = {}
    for key, data in anchor_controled_data.items():
        print(f"\nEvaluating {key}")
        metrics, anchors = evaluate_anchor_detection(
            data,
            # Anchors were placed at gap//2, gap//2 + gap, ... during generation
            target_anchors=range(gap // 2, seq_length, gap),
            window_size=max(1, gap // 2),
            min_distance=max(1, gap - 1),
        )
        anchor_accuracies[key] = metrics
        print(f"Accuracy: {metrics['accuracy']:.3f}")
        print(f"Precision: {metrics['precision']:.3f}")
        print(
            f"Matched anchors: {metrics['matched_anchors']} / {metrics['total_targets']}"
        )
        print("=" * 50)
if __name__ == "__main__":
    args = parse_args()
    results, dataset_name, seq_length = run(args)
    # Plot and evaluate the sampled results
    ploting(results, dataset_name, seq_length)
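

# Example invocation (script name is a placeholder; paths follow the argparse defaults):
#   python controlled_sampling.py --config_path ./config/modified/energy.yaml --gpu 0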