# "Spaces: Running on Zero" -- Hugging Face Spaces page banner captured by the
# scrape; kept as a comment so the module remains valid Python.
import numpy as np
import torch
from einops import rearrange
from PIL import Image

from ncut_pytorch import NCUT, kway_ncut, rgb_from_tsne_3d, convert_to_lab_color
from ncut_pytorch.affinity_gamma import find_gamma_by_degree

from dino_clip_featextract import img_transform_inv
from my_ipadapter_model import image_grid
def ncut_tsne_multiple_images(image_embeds, n_eig=50, gamma=0.5, degree=0.5):
    """Joint NCUT eigenvectors + t-SNE colors for a batch of images.

    All images' tokens are stacked into one affinity problem so colors are
    comparable across images. Runs on CUDA.

    Args:
        image_embeds: (b, l, c) tensor of token embeddings.
        n_eig: number of eigenvectors to compute.
        gamma: RBF affinity focal gamma; if None it is estimated from `degree`.
        degree: target affinity degree used only when gamma is None.

    Returns:
        (eigvec, rgb): eigenvectors reshaped to (b, l, n_eig) and LAB-adjusted
        RGB colors reshaped to (b, l, 3).
    """
    batch = image_embeds.shape[0]
    tokens = image_embeds.flatten(end_dim=-2)  # (b*l, c)
    if gamma is None:
        gamma = find_gamma_by_degree(tokens, degree, distance='rbf')
    model = NCUT(n_eig, affinity_focal_gamma=gamma, distance='rbf', device='cuda')
    eigvec, _eigval = model.fit_transform(tokens)
    _x3d, rgb = rgb_from_tsne_3d(eigvec, device='cuda', perplexity=50)
    rgb = convert_to_lab_color(rgb)
    rgb = rearrange(rgb, '(b l) c -> b l c', b=batch)
    eigvec = rearrange(eigvec, '(b l) c -> b l c', b=batch)
    return eigvec, rgb
def _kway_cluster_one_image(image_embeds, n_cluster, gamma=0.5, degree=0.5):
    """K-way NCUT clustering of a single image's tokens.

    Args:
        image_embeds: (l, c) tensor of token embeddings for one image.
        n_cluster: number of clusters.
        gamma: RBF affinity focal gamma; estimated from `degree` when None.
        degree: target affinity degree used only when gamma is None.

    Returns:
        (l, n_cluster) continuous k-way eigenvectors; argmax over the last
        dim yields hard cluster labels.
    """
    tokens = image_embeds.flatten(end_dim=-2)
    if gamma is None:
        gamma = find_gamma_by_degree(tokens, degree, distance='rbf')
    n_tokens = tokens.shape[0]
    # a few extra eigenvectors beyond n_cluster, capped by token count
    n_eig = min(n_cluster * 2 + 6, n_tokens // 2 - 1)
    num_samples = min(1000, n_tokens // 2)
    ncut = NCUT(n_eig, num_sample=num_samples,
                affinity_focal_gamma=gamma, distance='rbf', device='cuda')
    eigvec, _eigval = ncut.fit_transform(tokens)
    return kway_ncut(eigvec[:, :n_cluster], return_continuous=True)
def kway_cluster_per_image(image_embeds, n_cluster, gamma=0.5, degree=0.5):
    """Cluster each image independently and stack the results.

    Args:
        image_embeds: (b, l, c) tensor of per-image token embeddings.

    Returns:
        (b, l, n_cluster) tensor of continuous k-way eigenvectors.
    """
    per_image = [
        _kway_cluster_one_image(embed, n_cluster, gamma, degree)
        for embed in image_embeds
    ]
    return torch.stack(per_image)
def kway_cluster_multiple_images(image_embeds, n_cluster, gamma=0.5, degree=0.5):
    """Jointly cluster the tokens of all images with one k-way NCUT.

    Unlike `kway_cluster_per_image`, tokens from every image share one
    affinity problem, so cluster indices are consistent across images.

    Args:
        image_embeds: (b, l, c) tensor of token embeddings.

    Returns:
        (b, l, n_cluster) tensor of continuous k-way eigenvectors.
    """
    batch = image_embeds.shape[0]
    tokens = image_embeds.flatten(end_dim=-2)  # (b*l, c)
    if gamma is None:
        gamma = find_gamma_by_degree(tokens, degree, distance='rbf')
    n_tokens = tokens.shape[0]
    # a few extra eigenvectors beyond n_cluster, capped by token count
    n_eig = min(n_cluster * 2 + 6, n_tokens // 2 - 1)
    num_samples = min(1000, n_tokens // 2)
    ncut = NCUT(n_eig, num_sample=num_samples,
                affinity_focal_gamma=gamma, distance='rbf', device='cuda')
    eigvec, _eigval = ncut.fit_transform(tokens)
    continuous = kway_ncut(eigvec[:, :n_cluster], return_continuous=True)
    return rearrange(continuous, '(b l) c -> b l c', b=batch)
def get_single_multi_discrete_rgbs(joint_rgbs, single_eigvecs):
    """Quantize per-token RGB colors to one mean color per cluster.

    Args:
        joint_rgbs: (b, l, 3) float array of RGB values, assumed in [0, 1]
            (values are scaled by 255 and cast to uint8 on return).
        single_eigvecs: (b, l, n_cluster) tensor; argmax over the last dim
            gives each token's cluster label.

    Returns:
        (b, l, 3) uint8 array where every token carries the mean color of
        its cluster within its own image.
    """
    n_cluster = single_eigvecs.shape[-1]
    discrete_rgbs = np.zeros_like(joint_rgbs)
    for i_img in range(joint_rgbs.shape[0]):
        _rgb = joint_rgbs[i_img]
        _cluster_labels = single_eigvecs[i_img].cpu().numpy().argmax(-1)
        _discrete_rgb = np.zeros_like(_rgb)
        for i_cluster in range(n_cluster):
            mask = _cluster_labels == i_cluster
            # Skip empty clusters: np.mean over an empty slice emits a
            # RuntimeWarning and yields NaN (previously assigned to an empty
            # selection -- a noisy no-op).
            if not mask.any():
                continue
            _discrete_rgb[mask] = _rgb[mask].mean(0)
        discrete_rgbs[i_img] = _discrete_rgb
    return (discrete_rgbs * 255).astype(np.uint8)
def get_center_features(image_embeds, cluster_labels, n_cluster):
    """Compute the mean embedding (center) of each cluster.

    Args:
        image_embeds: (l, c) tensor of token embeddings.
        cluster_labels: length-l integer labels (numpy array or tensor).
        n_cluster: number of clusters.

    Returns:
        (n_cluster, c) tensor of cluster centers on `image_embeds`' device.
        Empty clusters get a huge sentinel (114514) so they are never the
        nearest match in distance-based matching.
    """
    # Allocate on the embeddings' device: the rest of this file runs on CUDA,
    # and assigning a CUDA row into a CPU tensor would raise.
    center_features = torch.zeros(
        (n_cluster, image_embeds.shape[-1]), device=image_embeds.device
    )
    for i_cluster in range(n_cluster):
        mask = cluster_labels == i_cluster
        if mask.sum() > 0:
            center_features[i_cluster] = image_embeds[mask].mean(0)
        else:
            # far-away sentinel for empty clusters
            center_features[i_cluster] = torch.ones_like(image_embeds[0]) * 114514
    return center_features
def cosine_similarity(A, B):
    """Row-wise cosine similarity matrix: result[i, j] = cos(A[i], B[j])."""
    a_unit = A / torch.linalg.vector_norm(A, dim=-1, keepdim=True)
    b_unit = B / torch.linalg.vector_norm(B, dim=-1, keepdim=True)
    return torch.matmul(a_unit, b_unit.T)
from scipy.optimize import linear_sum_assignment
def hungarian_match_centers(center_features1, center_features2):
    """Optimal one-to-one matching of cluster centers by Euclidean distance.

    Returns the column indices: entry i is the index in `center_features2`
    matched to row i of `center_features1`.
    """
    cost = torch.cdist(center_features1, center_features2)
    cost = cost.detach().cpu().numpy()
    _row_ind, assignment = linear_sum_assignment(cost)
    return assignment
def argmin_matching(center_features1, center_features2):
    """Greedy matching: for each row of `center_features1`, the index of its
    nearest row in `center_features2` (not necessarily one-to-one)."""
    cost = torch.cdist(center_features1, center_features2)
    return cost.detach().cpu().numpy().argmin(axis=-1)
def match_centers(image_embed1, image_embed2, eigvec1, eigvec2, match_method='hungarian'):
    """Match cluster centers of image 1 to cluster centers of image 2.

    Hard labels come from argmax over each image's k-way eigenvectors; the
    centers are mean embeddings per cluster.

    Args:
        image_embed1, image_embed2: (l, c) token embeddings per image.
        eigvec1, eigvec2: (l, n_cluster) k-way eigenvectors per image.
        match_method: 'hungarian' (one-to-one) or 'argmin' (greedy nearest).

    Returns:
        Array of length n_cluster mapping cluster i of image 1 to a cluster
        index of image 2.

    Raises:
        ValueError: for an unknown `match_method` (previously this fell
        through to an UnboundLocalError).
    """
    cluster_label1 = eigvec1.argmax(-1).cpu().numpy()
    cluster_label2 = eigvec2.argmax(-1).cpu().numpy()
    n_cluster = eigvec1.shape[-1]
    center_features1 = get_center_features(image_embed1, cluster_label1, n_cluster=n_cluster)
    center_features2 = get_center_features(image_embed2, cluster_label2, n_cluster=n_cluster)
    if match_method == 'hungarian':
        return hungarian_match_centers(center_features1, center_features2)
    if match_method == 'argmin':
        return argmin_matching(center_features1, center_features2)
    raise ValueError(f"unknown match_method: {match_method!r}")
def match_centers_three_images(image_embeds, eigvecs, match_method='hungarian'):
    """Chain cluster matchings across three images.

    image_embeds: (b, l, c) with b = 3, ordered A2, A1, B1.
    eigvecs: per-image k-way eigenvectors in the same order.
    Returns (A2_to_A1, A1_to_B1) index mappings.
    """
    mappings = [
        match_centers(image_embeds[src], image_embeds[dst],
                      eigvecs[src], eigvecs[dst], match_method=match_method)
        for src, dst in ((0, 1), (1, 2))
    ]
    A2_to_A1, A1_to_B1 = mappings
    return A2_to_A1, A1_to_B1
def match_centers_two_images(image_embed1, image_embed2, eigvec1, eigvec2, match_method='hungarian'):
    """Thin alias for `match_centers` kept for naming symmetry with the
    three-image variant."""
    return match_centers(image_embed1, image_embed2, eigvec1, eigvec2,
                         match_method=match_method)
def plot_clusters(image, eigvec, cluster_order, hw=16):
    """Render one highlight image per cluster.

    For each cluster id in `cluster_order`, pixels belonging to the cluster
    keep their original color while the rest of the image is dimmed (x0.1).
    Token 0 is dropped before reshaping to (hw, hw); presumably it is a
    CLS-style global token -- TODO confirm against the feature extractor.

    Returns a list of 128x128 PIL images, one per entry of `cluster_order`.
    """
    base = img_transform_inv(image).resize((128, 128), resample=Image.Resampling.NEAREST)
    base_float = np.array(base).astype(np.float32) / 255
    labels = eigvec.argmax(-1)
    cluster_images = []
    for idx_cluster in cluster_order:
        hit = (labels == idx_cluster).cpu().numpy()[1:].reshape(hw, hw)
        hit_img = Image.fromarray((hit * 255).astype(np.uint8)).resize(
            (128, 128), resample=Image.Resampling.NEAREST)
        weight = np.array(hit_img).astype(np.float32) / 255
        weight = np.stack([weight] * 3, axis=-1)
        weight[weight == 0] = 0.1  # dim, don't erase, the background
        shaded = (base_float * weight * 255).astype(np.uint8)
        cluster_images.append(Image.fromarray(shaded))
    return cluster_images
def grid_one_image(image, eigvec, cluster_order, discrete_rgb, hw=16, n_cols=10):
    """Lay out one image's cluster tiles into rows.

    Each row is [original image, discrete ncut color map, then n_cols
    per-cluster highlight tiles]; the tile list is padded with black tiles
    to a multiple of n_cols. Returns a list of rows (lists of PIL images).
    """
    cluster_images = plot_clusters(image, eigvec, cluster_order, hw)
    img = img_transform_inv(image).resize((128, 128), resample=Image.Resampling.NEAREST)
    # token 0 dropped to match plot_clusters' (hw, hw) token grid
    ncut_image = Image.fromarray(discrete_rgb[1:].reshape(hw, hw, 3)).resize(
        (128, 128), resample=Image.Resampling.NEAREST)
    remainder = len(cluster_images) % n_cols
    if remainder:
        filler = Image.fromarray(np.zeros((128, 128, 3), dtype=np.uint8))
        cluster_images = cluster_images + [filler] * (n_cols - remainder)
    rows = []
    for start in range(0, len(cluster_images), n_cols):
        rows.append([img, ncut_image] + cluster_images[start:start + n_cols])
    return rows
def grid_multiple_images(images, eigvecs, cluster_orders, discrete_rgbs, hw=16, n_cols=10):
    """Build per-image grids and interleave them row-by-row so corresponding
    rows of different images appear next to each other."""
    per_image_grids = [
        grid_one_image(img, vec, order, rgb, hw, n_cols)
        for img, vec, order, rgb in zip(images, eigvecs, cluster_orders, discrete_rgbs)
    ]
    interleaved = []
    for row_idx in range(len(per_image_grids[0])):
        for grid in per_image_grids:
            interleaved.append(grid[row_idx])
    return interleaved
def get_correspondence_plot(images, eigvecs, cluster_orders, discrete_rgbs, hw=16, n_cols=10):
    """Assemble the full correspondence visualization into one grid image.

    Caps n_cols at the cluster count, interleaves the per-image rows, then
    flattens everything into a single `image_grid` call.
    """
    n_cols = min(n_cols, eigvecs.shape[-1])
    rows = grid_multiple_images(images, eigvecs, cluster_orders, discrete_rgbs, hw, n_cols)
    tiles = [tile for row in rows for tile in row]
    return image_grid(tiles, len(rows), len(rows[0]))