Spaces:
Sleeping
Sleeping
import numpy as np
import scipy
import scipy.linalg

from models.ts2vec.ts2vec import TS2Vec
def calculate_fid(act1, act2, eps=1e-6):
    """Return the Frechet distance between two sets of activations.

    Each input is reduced to a mean vector (``mean(axis=0)``) and a
    covariance matrix (``np.cov(..., rowvar=False)``), i.e. rows are treated
    as samples and columns as features. The score is::

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Args:
        act1: 2-D NumPy array of activations, one sample per row.
        act2: 2-D NumPy array of activations with the same number of
            columns as ``act1``.
        eps: Value added to the diagonal of both covariance matrices, used
            only when the matrix square root of their product contains
            non-finite entries.

    Returns:
        The Frechet distance as a NumPy floating-point scalar.
    """
    mu1 = act1.mean(axis=0)
    mu2 = act2.mean(axis=0)
    # np.cov collapses to a 0-d array when there is a single feature column;
    # sqrtm and np.trace both need a 2-D matrix.
    sigma1 = np.atleast_2d(np.cov(act1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(act2, rowvar=False))

    # Squared Euclidean distance between the two mean vectors.
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Matrix square root of the covariance product.
    covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # A (near-)singular product can yield inf/nan; retry with a small
        # diagonal offset so the score stays finite.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # sqrtm may return a complex result; only the real part is used.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
def Context_FID(ori_data, generated_data):
    """Return the FID between TS2Vec encodings of original and generated data.

    A fresh ``TS2Vec`` model is fitted on ``ori_data`` only. Both data sets
    are then encoded with ``encoding_window="full_series"`` and the two sets
    of representations are passed to ``calculate_fid``.

    Args:
        ori_data: Array of original series; ``shape[-1]`` is used as the
            encoder's ``input_dims`` and ``shape[0]`` as the sample count.
            Presumably shaped (samples, time, features) - TODO confirm
            against the TS2Vec implementation.
        generated_data: Array of generated series, encoded with the model
            fitted on ``ori_data``. It may hold a different number of
            samples than ``ori_data``.

    Returns:
        The value returned by ``calculate_fid`` for the two representations.
    """
    # NOTE(review): device=0 is hard-coded; presumably the first CUDA
    # device - confirm before running on a CPU-only host.
    model = TS2Vec(
        input_dims=ori_data.shape[-1],
        device=0,
        batch_size=8,
        lr=0.001,
        output_dims=320,
        max_train_length=3000,
    )
    model.fit(ori_data, verbose=False)
    ori_represenation = model.encode(ori_data, encoding_window="full_series")
    gen_represenation = model.encode(generated_data, encoding_window="full_series")

    n_ori = ori_data.shape[0]
    n_gen = generated_data.shape[0]
    ori_idx = np.random.permutation(n_ori)
    # Reuse the same permutation when the sample counts match (the original
    # behaviour). Otherwise draw one sized for the generated set, so indexing
    # neither raises nor silently drops generated samples.
    gen_idx = ori_idx if n_gen == n_ori else np.random.permutation(n_gen)

    ori_represenation = ori_represenation[ori_idx]
    gen_represenation = gen_represenation[gen_idx]
    return calculate_fid(ori_represenation, gen_represenation)