# TSEditor/utils/context_fid.py
# Last commit 2875fe6 ("update") by PeterYu; file size 1.38 kB.
import numpy as np
import scipy
import scipy.linalg

from models.ts2vec.ts2vec import TS2Vec
def calculate_fid(act1, act2, eps=1e-6):
    """Compute the Frechet distance between two sets of feature vectors.

    Each input is a sample matrix with one observation per row and one
    feature per column. Both sets are summarised by their mean and
    covariance, and the score is

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Args:
        act1: Array-like of shape (n1, d) holding the first set of features.
        act2: Array-like of shape (n2, d) holding the second set of features.
            n1 and n2 may differ; d must match.
        eps: Diagonal offset added to both covariances, only when the matrix
            square root of their product comes back non-finite (e.g. a
            singular product), to stabilise the computation.

    Returns:
        The Frechet distance (non-negative up to numerical error).
    """
    act1 = np.asarray(act1, dtype=np.float64)
    act2 = np.asarray(act2, dtype=np.float64)

    mu1, mu2 = act1.mean(axis=0), act2.mean(axis=0)
    # np.cov squeezes its output, so a single feature column yields a 0-d
    # array; atleast_2d restores the (d, d) shape that sqrtm/trace require.
    sigma1 = np.atleast_2d(np.cov(act1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(act2, rowvar=False))

    # Squared Euclidean distance between the means.
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Matrix square root of the covariance product.
    covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Singular product: retry with a small diagonal offset.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # sqrtm may return a complex result whose imaginary part is numerical
    # noise; only the real part enters the score.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid
def Context_FID(ori_data, generated_data, *, device=0):
    """Compute the Context-FID score between real and generated series.

    A TS2Vec encoder is trained on ``ori_data`` only. Both datasets are then
    encoded with ``encoding_window="full_series"``, and the Frechet distance
    between the two sets of representations is returned.

    Args:
        ori_data: Real series. Assumed to be a numpy array of shape
            (n_series, length, n_features) as TS2Vec expects. TODO: confirm
            against TS2Vec.fit.
        generated_data: Generated series with the same trailing feature
            dimension as ``ori_data``. The number of series may differ.
        device: Device handed to TS2Vec. Keyword-only; defaults to ``0``
            (the previously hard-coded value).

    Returns:
        The Frechet distance computed by ``calculate_fid``.

    Raises:
        ValueError: If the two datasets have different feature dimensions.
    """
    # Fail fast, before the (expensive) encoder training.
    if generated_data.shape[-1] != ori_data.shape[-1]:
        raise ValueError(
            "ori_data and generated_data must have the same number of "
            f"features, got {ori_data.shape[-1]} and {generated_data.shape[-1]}"
        )

    model = TS2Vec(
        input_dims=ori_data.shape[-1],
        device=device,
        batch_size=8,
        lr=0.001,
        output_dims=320,
        max_train_length=3000,
    )
    model.fit(ori_data, verbose=False)
    # presumably one representation vector per series with "full_series";
    # verify against TS2Vec.encode.
    ori_representation = model.encode(ori_data, encoding_window="full_series")
    gen_representation = model.encode(generated_data, encoding_window="full_series")

    # Subsample both sets to a common size with independent random
    # permutations so that differing sample counts are handled. When the
    # counts are equal this is only a reordering of rows, which leaves the
    # mean/covariance statistics unchanged.
    n_ori = len(ori_representation)
    n_gen = len(gen_representation)
    n_common = min(n_ori, n_gen)
    ori_representation = ori_representation[np.random.permutation(n_ori)[:n_common]]
    gen_representation = gen_representation[np.random.permutation(n_gen)[:n_common]]

    return calculate_fid(ori_representation, gen_representation)