MoodSpace / my_intrinsic_dim.py
huzey's picture
commit
456aee9
raw
history blame contribute delete
805 Bytes
import numpy as np
import torch
import skdim
from ncut_pytorch.ncut_pytorch import farthest_point_sampling
import logging
def get_intrinsic_dim(feats, max_sample=2000):
    """Estimate the intrinsic dimension of a set of feature vectors.

    All leading axes of ``feats`` are flattened, so every vector along the
    last axis is treated as one sample. If there are more than
    ``max_sample`` samples, a subset of ``max_sample`` rows is selected with
    ``ncut_pytorch``'s ``farthest_point_sampling`` before estimation.
    The estimate itself comes from ``skdim.id.MLE``.

    Args:
        feats: ``torch.Tensor`` or array-like whose last axis is the feature
            dimension. Tensors are detached and moved to CPU; half-precision
            or non-floating tensors are upcast to float32.
        max_sample: Maximum number of samples passed to the estimator.

    Returns:
        The global MLE intrinsic-dimension estimate (``dimension_``). If that
        estimate is 0 or non-finite, the NaN-ignoring mean of the pointwise
        estimates (``dimension_pw_``) is returned instead and a warning is
        logged.

    Raises:
        ValueError: If ``feats`` contains no samples.
    """
    if isinstance(feats, torch.Tensor):
        # Detach before the device copy so no autograd history is carried over.
        feats = feats.detach().cpu()
    else:
        feats = torch.tensor(feats)

    # NumPy has no bfloat16, so `.numpy()` would fail on such tensors.
    # float16 / integer inputs are also upcast so sampling and the
    # estimator work in full float precision.
    if feats.dtype in (torch.float16, torch.bfloat16) or not feats.is_floating_point():
        feats = feats.float()

    feats = feats.reshape(-1, feats.shape[-1])
    if feats.shape[0] == 0:
        raise ValueError("get_intrinsic_dim requires at least one sample, got 0")

    if feats.shape[0] > max_sample:
        sample_idx = farthest_point_sampling(feats, max_sample)
        feats = feats[sample_idx]
    data = feats.cpu().numpy()

    id_est = skdim.id.MLE().fit(data)
    dim = id_est.dimension_
    # Treat 0 or a non-finite global estimate as a failed estimate and fall
    # back to the average local estimate, ignoring NaN pointwise values.
    if dim == 0 or not np.isfinite(dim):
        dim = np.nanmean(id_est.dimension_pw_)
        logging.warning(
            "failed to estimate global intrinsic dimension, using average of local intrinsic dimension %s",
            dim,
        )
    return dim