"""NCUT clustering, cluster matching and visualisation helpers for image token embeddings."""
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from scipy.optimize import linear_sum_assignment
from ncut_pytorch import NCUT, kway_ncut, rgb_from_tsne_3d, convert_to_lab_color
from ncut_pytorch.affinity_gamma import find_gamma_by_degree

from dino_clip_featextract import img_transform_inv
from my_ipadapter_model import image_grid

# Pixel size of every tile placed in the visualisation grid.
_TILE_SIZE = (128, 128)
# Brightness multiplier applied to pixels outside the highlighted cluster.
_DIM_FACTOR = 0.1
# Feature value assigned to clusters without members, so that they end up far
# away from every populated cluster when centre distances are computed.
_EMPTY_CLUSTER_SENTINEL = 114514


def _check_rank(image_embeds, rank, layout):
    """Raise ValueError unless ``image_embeds`` has exactly ``rank`` dimensions."""
    if len(image_embeds.shape) != rank:
        raise ValueError(
            f"expected image_embeds with shape {layout}, "
            f"got {tuple(image_embeds.shape)}"
        )


def ncut_tsne_multiple_images(image_embeds, n_eig=50, gamma=0.5, degree=0.5,
                              device='cuda'):
    """Run NCUT jointly over all images and colour the tokens via 3D t-SNE.

    Args:
        image_embeds: tensor of shape (b, l, c); tokens of all images are
            flattened to (b * l, c) and processed together.
        n_eig: number of eigenvectors requested from NCUT.
        gamma: value passed as ``affinity_focal_gamma``; when None it is
            chosen by ``find_gamma_by_degree`` using ``degree``.
        degree: only used when ``gamma`` is None.
        device: device forwarded to NCUT and ``rgb_from_tsne_3d``.

    Returns:
        Tuple ``(eigvec, rgb)``, both reshaped back to (b, l, ...).

    Raises:
        ValueError: if ``image_embeds`` is not 3-dimensional.
    """
    _check_rank(image_embeds, 3, "(b, l, c)")
    b = image_embeds.shape[0]
    inp = image_embeds.flatten(end_dim=-2)
    if gamma is None:
        gamma = find_gamma_by_degree(inp, degree, distance='rbf')
    eigvec, _ = NCUT(n_eig, affinity_focal_gamma=gamma, distance='rbf',
                     device=device).fit_transform(inp)
    _, rgb = rgb_from_tsne_3d(eigvec, device=device, perplexity=50)
    rgb = convert_to_lab_color(rgb)
    rgb = rearrange(rgb, '(b l) c -> b l c', b=b)
    eigvec = rearrange(eigvec, '(b l) c -> b l c', b=b)
    return eigvec, rgb


def _kway_ncut_tokens(tokens, n_cluster, gamma, degree, device):
    """Run NCUT followed by k-way NCUT on a flat (n_tokens, c) tensor."""
    if gamma is None:
        gamma = find_gamma_by_degree(tokens, degree, distance='rbf')
    n_tokens = tokens.shape[0]
    # Request a few more eigenvectors than clusters, but never more than the
    # token count allows; the sample count is capped the same way.
    n_eig = min(n_cluster * 2 + 6, n_tokens // 2 - 1)
    num_samples = min(1000, n_tokens // 2)
    eigvec, _ = NCUT(n_eig, num_sample=num_samples, affinity_focal_gamma=gamma,
                     distance='rbf', device=device).fit_transform(tokens)
    return kway_ncut(eigvec[:, :n_cluster], return_continuous=True)


def _kway_cluster_one_image(image_embeds, n_cluster, gamma=0.5, degree=0.5,
                            device='cuda'):
    """Cluster the tokens of a single image with k-way NCUT.

    Args:
        image_embeds: tensor of shape (l, c).
        n_cluster: number of leading eigenvectors handed to ``kway_ncut``.
        gamma: ``affinity_focal_gamma``; None selects it from ``degree``.
        degree: only used when ``gamma`` is None.
        device: device forwarded to NCUT.

    Returns:
        The continuous k-way NCUT output for the l tokens.

    Raises:
        ValueError: if ``image_embeds`` is not 2-dimensional.
    """
    _check_rank(image_embeds, 2, "(l, c)")
    return _kway_ncut_tokens(image_embeds, n_cluster, gamma, degree, device)


def kway_cluster_per_image(image_embeds, n_cluster, gamma=0.5, degree=0.5,
                           device='cuda'):
    """Cluster each image independently and stack the per-image results.

    Args:
        image_embeds: tensor indexed by image along dim 0; each slice must be
            of shape (l, c).
        n_cluster, gamma, degree, device: see ``_kway_cluster_one_image``.

    Returns:
        ``torch.stack`` of the per-image k-way NCUT outputs.
    """
    eigvecs = [
        _kway_cluster_one_image(image_embeds[i], n_cluster, gamma, degree, device)
        for i in range(image_embeds.shape[0])
    ]
    return torch.stack(eigvecs)


def kway_cluster_multiple_images(image_embeds, n_cluster, gamma=0.5, degree=0.5,
                                 device='cuda'):
    """Cluster the tokens of all images jointly with k-way NCUT.

    Args:
        image_embeds: tensor of shape (b, l, c); flattened to (b * l, c) so
            that all images share one set of clusters.
        n_cluster, gamma, degree, device: see ``_kway_cluster_one_image``.

    Returns:
        The continuous k-way NCUT output reshaped to (b, l, ...).

    Raises:
        ValueError: if ``image_embeds`` is not 3-dimensional.
    """
    _check_rank(image_embeds, 3, "(b, l, c)")
    b = image_embeds.shape[0]
    inp = image_embeds.flatten(end_dim=-2)
    eigvec_continues = _kway_ncut_tokens(inp, n_cluster, gamma, degree, device)
    return rearrange(eigvec_continues, '(b l) c -> b l c', b=b)


def get_single_multi_discrete_rgbs(joint_rgbs, single_eigvecs):
    """Replace every token colour by the mean colour of its cluster.

    Args:
        joint_rgbs: per-image token colours, indexed by image along dim 0.
            Assumed to hold values in [0, 1] since the result is scaled by 255
            and cast to uint8 -- TODO confirm against the caller.
        single_eigvecs: per-image cluster scores (tensor); the cluster of a
            token is the argmax over the last dim.

    Returns:
        uint8 NumPy array shaped like ``joint_rgbs``.
    """
    n_cluster = single_eigvecs.shape[-1]
    discrete_rgbs = np.zeros_like(joint_rgbs)
    for i_img in range(joint_rgbs.shape[0]):
        rgb = joint_rgbs[i_img]
        labels = single_eigvecs[i_img].cpu().numpy().argmax(-1)
        discrete_rgb = np.zeros_like(rgb)
        for i_cluster in range(n_cluster):
            member = labels == i_cluster
            # An empty cluster has no pixels to paint; skipping it also avoids
            # taking the mean of an empty selection (NaN + RuntimeWarning).
            if not member.any():
                continue
            discrete_rgb[member] = rgb[member].mean(0)
        discrete_rgbs[i_img] = discrete_rgb
    discrete_rgbs = discrete_rgbs * 255
    return discrete_rgbs.astype(np.uint8)


def get_center_features(image_embeds, cluster_labels, n_cluster):
    """Compute the mean embedding of every cluster.

    Args:
        image_embeds: token embeddings; the last dim is the feature dim.
        cluster_labels: per-token cluster index, used as a boolean mask on
            ``image_embeds``.
        n_cluster: number of clusters.

    Returns:
        Tensor of shape (n_cluster, c) created with ``torch.zeros`` defaults.
        Rows of clusters without members are filled with a large sentinel.
    """
    center_features = torch.zeros((n_cluster, image_embeds.shape[-1]))
    for i_cluster in range(n_cluster):
        mask = cluster_labels == i_cluster
        if mask.any():
            center_features[i_cluster] = image_embeds[mask].mean(0)
        else:
            # Fill the output row directly: building the sentinel in the
            # embedding dtype would overflow to inf for half precision.
            center_features[i_cluster] = _EMPTY_CLUSTER_SENTINEL
    return center_features


def cosine_similarity(A, B):
    """Return the pairwise cosine similarity between the rows of A and B.

    Rows with zero norm produce NaN because no epsilon is applied.
    """
    _A = A / A.norm(dim=-1, keepdim=True)
    _B = B / B.norm(dim=-1, keepdim=True)
    return _A @ _B.T


def hungarian_match_centers(center_features1, center_features2):
    """Match cluster centres one-to-one by minimal total Euclidean distance.

    Returns:
        The column indices from ``linear_sum_assignment``. With equally many
        centres on both sides, entry ``i`` is the centre of the second set
        assigned to centre ``i`` of the first set.
    """
    dist = torch.cdist(center_features1, center_features2)
    dist = dist.cpu().detach().numpy()
    _, col_ind = linear_sum_assignment(dist)
    return col_ind


def argmin_matching(center_features1, center_features2):
    """Map every centre of the first set to its nearest centre in the second.

    Unlike the Hungarian matching, several centres may share one target.
    """
    dist = torch.cdist(center_features1, center_features2)
    dist = dist.cpu().detach().numpy()
    return np.argmin(dist, axis=-1)


def match_centers(image_embed1, image_embed2, eigvec1, eigvec2,
                  match_method='hungarian'):
    """Match the clusters of image 1 to the clusters of image 2.

    Args:
        image_embed1, image_embed2: token embeddings of the two images.
        eigvec1, eigvec2: per-token cluster scores; the cluster of a token is
            the argmax over the last dim, whose size is the cluster count.
        match_method: ``'hungarian'`` (one-to-one) or ``'argmin'`` (nearest).

    Returns:
        Array whose entry ``i`` is the cluster of image 2 matched to cluster
        ``i`` of image 1.

    Raises:
        ValueError: if ``match_method`` is not a supported name.
    """
    if match_method == 'hungarian':
        matcher = hungarian_match_centers
    elif match_method == 'argmin':
        matcher = argmin_matching
    else:
        raise ValueError(
            f"unknown match_method {match_method!r}; "
            "expected 'hungarian' or 'argmin'"
        )
    cluster_label1 = eigvec1.argmax(-1).cpu().numpy()
    cluster_label2 = eigvec2.argmax(-1).cpu().numpy()
    n_cluster = eigvec1.shape[-1]
    center_features1 = get_center_features(image_embed1, cluster_label1,
                                           n_cluster=n_cluster)
    center_features2 = get_center_features(image_embed2, cluster_label2,
                                           n_cluster=n_cluster)
    return matcher(center_features1, center_features2)


def match_centers_three_images(image_embeds, eigvecs, match_method='hungarian'):
    """Chain cluster matches across three images ordered A2, A1, B1.

    Args:
        image_embeds: (b, l, c) with b = 3, in the order A2, A1, B1.
        eigvecs: per-image cluster scores, indexed like ``image_embeds``.
        match_method: see ``match_centers``.

    Returns:
        Tuple ``(A2_to_A1, A1_to_B1)`` of cluster mappings.
    """
    A2_to_A1 = match_centers(image_embeds[0], image_embeds[1],
                             eigvecs[0], eigvecs[1], match_method=match_method)
    A1_to_B1 = match_centers(image_embeds[1], image_embeds[2],
                             eigvecs[1], eigvecs[2], match_method=match_method)
    return A2_to_A1, A1_to_B1


def match_centers_two_images(image_embed1, image_embed2, eigvec1, eigvec2,
                             match_method='hungarian'):
    """Match the clusters of two images; thin alias of ``match_centers``."""
    return match_centers(image_embed1, image_embed2, eigvec1, eigvec2,
                         match_method=match_method)


def plot_clusters(image, eigvec, cluster_order, hw=16):
    """Render one tile per cluster with the cluster region highlighted.

    Args:
        image: input accepted by ``img_transform_inv``.
        eigvec: per-token cluster scores of shape (1 + hw * hw, n_cluster);
            the first token is dropped (presumably the CLS token -- confirm).
        cluster_order: iterable of cluster indices, one tile each, in order.
        hw: side length of the square token grid.

    Returns:
        List of PIL images in which pixels outside the cluster are dimmed.
    """
    img = img_transform_inv(image).resize(_TILE_SIZE,
                                          resample=Image.Resampling.NEAREST)
    # Both values are identical for every cluster, so compute them once.
    img_unit = np.array(img).astype(np.float32) / 255
    labels = eigvec.argmax(-1)

    cluster_images = []
    for idx_cluster in cluster_order:
        mask = (labels == idx_cluster).cpu().numpy()[1:].reshape(hw, hw)
        mask = (mask * 255).astype(np.uint8)
        mask = Image.fromarray(mask).resize(_TILE_SIZE,
                                            resample=Image.Resampling.NEAREST)
        # Superimpose: keep the cluster at full brightness, dim the rest.
        mask = np.array(mask).astype(np.float32) / 255
        mask = np.stack([mask] * 3, axis=-1)
        mask[mask == 0] = _DIM_FACTOR
        tile = (img_unit * mask * 255).astype(np.uint8)
        cluster_images.append(Image.fromarray(tile))
    return cluster_images


def grid_one_image(image, eigvec, cluster_order, discrete_rgb, hw=16, n_cols=10):
    """Build the grid rows for one image.

    Every row consists of the resized image, its discrete NCUT colour map and
    up to ``n_cols`` cluster tiles; the last row is padded with black tiles.

    Args:
        image, eigvec, cluster_order, hw: see ``plot_clusters``.
        discrete_rgb: uint8 colours of shape (1 + hw * hw, 3); the first
            token is dropped like in ``plot_clusters``.
        n_cols: number of cluster tiles per row.

    Returns:
        List of rows, each a list of ``n_cols + 2`` PIL images.
    """
    cluster_images = plot_clusters(image, eigvec, cluster_order, hw)
    img = img_transform_inv(image).resize(_TILE_SIZE,
                                          resample=Image.Resampling.NEAREST)
    ncut_image = discrete_rgb[1:].reshape(hw, hw, 3)
    ncut_image = Image.fromarray(ncut_image).resize(
        _TILE_SIZE, resample=Image.Resampling.NEAREST)

    # Pad with black tiles so that the tile count is a multiple of n_cols.
    num_missing = (-len(cluster_images)) % n_cols
    blank = Image.fromarray(np.zeros(_TILE_SIZE + (3,), dtype=np.uint8))
    cluster_images.extend([blank] * num_missing)

    # Every row starts with the image and its NCUT colour map.
    prepend_images = [img, ncut_image]
    return [
        prepend_images + cluster_images[start:start + n_cols]
        for start in range(0, len(cluster_images), n_cols)
    ]


def grid_multiple_images(images, eigvecs, cluster_orders, discrete_rgbs,
                         hw=16, n_cols=10):
    """Build grid rows for several images and interleave them row by row.

    The result lists row 0 of every image, then row 1 of every image, and so
    on, so that matching clusters of different images sit below each other.
    The row count is taken from the first image.

    Returns:
        List of rows (lists of PIL images); empty when no image is given.
    """
    grid_images = [
        grid_one_image(image, eigvec, cluster_order, discrete_rgb, hw, n_cols)
        for image, eigvec, cluster_order, discrete_rgb
        in zip(images, eigvecs, cluster_orders, discrete_rgbs)
    ]
    if not grid_images:
        return []
    return [
        rows[i_row]
        for i_row in range(len(grid_images[0]))
        for rows in grid_images
    ]


def get_correspondence_plot(images, eigvecs, cluster_orders, discrete_rgbs,
                            hw=16, n_cols=10):
    """Assemble the cluster correspondence figure for several images.

    Args:
        images, eigvecs, cluster_orders, discrete_rgbs: per-image inputs, see
            ``grid_one_image``.
        hw: side length of the square token grid.
        n_cols: maximum number of cluster tiles per row; capped at the number
            of clusters (last dim of ``eigvecs``).

    Returns:
        Whatever ``image_grid`` returns for the flattened tiles.

    Raises:
        ValueError: if there is nothing to plot.
    """
    n_cluster = eigvecs.shape[-1]
    n_cols = min(n_cols, n_cluster)
    interleave_images = grid_multiple_images(images, eigvecs, cluster_orders,
                                             discrete_rgbs, hw, n_cols)
    if not interleave_images:
        raise ValueError("no images or clusters to plot")
    n_row = len(interleave_images)
    n_cols = len(interleave_images[0])
    tiles = [tile for row in interleave_images for tile in row]
    return image_grid(tiles, n_row, n_cols)