Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import torch | |
import skdim | |
from ncut_pytorch.ncut_pytorch import farthest_point_sampling | |
import logging | |
def get_intrinsic_dim(feats, max_sample=2000):
    """Estimate the intrinsic dimension of a collection of feature vectors.

    Every leading axis of ``feats`` is flattened, so each row is one vector
    of length ``feats.shape[-1]``. If more than ``max_sample`` rows remain,
    ``farthest_point_sampling`` selects ``max_sample`` of them. Then
    ``skdim.id.MLE`` is fitted on the (possibly subsampled) rows.

    Args:
        feats: A ``torch.Tensor`` (any device, may require grad) or any
            array-like accepted by ``torch.tensor``. The last axis is the
            feature axis.
        max_sample: Upper bound on the number of rows passed to the
            estimator.

    Returns:
        The estimator's global ``dimension_``. When that value is exactly 0,
        the mean of the pointwise estimates ``dimension_pw_`` is returned
        instead and a warning is logged.
    """
    if isinstance(feats, torch.Tensor):
        # Detach before the device copy so autograd does not track it, and
        # stay in torch rather than round-tripping through NumPy.
        feats = feats.detach().cpu()
        # NumPy has no bfloat16, so the ``.numpy()`` call below would raise
        # TypeError; upcast such inputs to float32 first.
        if feats.dtype == torch.bfloat16:
            feats = feats.float()
    else:
        feats = torch.tensor(feats)
    feats = feats.reshape(-1, feats.shape[-1])
    if feats.shape[0] > max_sample:
        sample_idx = farthest_point_sampling(feats, max_sample)
        feats = feats[sample_idx]
    data = feats.cpu().numpy()
    id_est = skdim.id.MLE().fit(data)
    dim = id_est.dimension_
    if dim == 0:
        # The global estimate came out as 0; fall back to averaging the
        # per-point estimates instead.
        dim = np.mean(id_est.dimension_pw_)
        logging.warning(f"failed to estimate global intrinsic dimension, using average of local intrinsic dimension {dim}")
    return dim