import torch
import numpy as np


def sort_block(latent_index, block_size):
    """Reorder coordinate rows so rows falling in the same spatial block are contiguous.

    Columns 1.. of each row are split into a block id (``// block_size``) and an
    in-block offset (``% block_size``). Rows are sorted lexicographically by block
    id (column 1, then 2, then 3) and, within a block, by in-block offset (column
    1, then 2, then 3). Column 0 is not part of the sort key.

    NOTE(review): if column 0 holds a batch index, rows from different batches can
    interleave in the output — confirm this is intended.

    Args:
        latent_index: Tensor whose last dimension has at least 4 entries. It is
            NOT modified.
        block_size: Block edge length applied to columns 1..3.

    Returns:
        The rows of ``latent_index`` gathered in sorted order, on its original device.
    """
    device = latent_index.device
    # On a CPU tensor `.cpu().numpy()` shares memory with the tensor, so every
    # transform below must be out of place to avoid corrupting the caller's input.
    spatial = latent_index.detach().cpu().numpy()[..., 1:]
    block = spatial // block_size
    in_block = spatial % block_size
    # np.lexsort treats the LAST key as the primary key.
    sort_index = np.lexsort((
        in_block[..., 2], in_block[..., 1], in_block[..., 0],
        block[..., 2], block[..., 1], block[..., 0],
    ))
    sort_index = torch.from_numpy(sort_index).to(device)
    return latent_index[sort_index]


def extract_tokens_and_coords(conditions, token_mask, num_cls=1, num_reg=4):
    """Gather class, register and masked patch tokens per sample, with coordinates.

    For each sample ``b`` the output contains, in order: the ``num_cls`` class
    tokens, the ``num_reg`` register tokens, then the patch tokens whose entry in
    ``token_mask[b]`` is nonzero (row-major mask order). Per-sample results are
    concatenated along dim 0.

    Args:
        conditions: Tensor ``[B, num_cls + num_reg + H * W, D]``; patch tokens are
            indexed as a row-major flattening of an ``H x W`` grid.
        token_mask: Tensor ``[B, H, W]``; nonzero entries select patches. Must be
            on the same device as ``conditions``.
        num_cls: Number of leading class tokens per sample.
        num_reg: Number of register tokens following the class tokens.

    Returns:
        ``(selected_tokens, coords)``: ``selected_tokens`` is ``[N, D]`` and
        ``coords`` is an int64 ``[N, 4]`` tensor. Class/register rows get
        ``[b, 0, 0, 1]``; the patch at grid row ``h``, column ``w`` gets
        ``[b, w, h, 0]``.
    """
    device = conditions.device
    batch_size = conditions.size(0)
    num_prefix = num_cls + num_reg
    # Row stride of the flattened patch grid is the mask WIDTH (not its height).
    grid_w = token_mask.size(2)

    if batch_size == 0:
        # torch.cat would reject an empty list; return correctly-shaped empties.
        return (
            conditions.new_zeros((0, conditions.size(2))),
            torch.zeros((0, 4), dtype=torch.long, device=device),
        )

    prefix_tokens = conditions[:, :num_prefix, :]  # class tokens, then register tokens
    patch_tokens = conditions[:, num_prefix:, :]

    selected_tokens_list = []
    coords_list = []
    for batch_idx in range(batch_size):
        # reshape keeps a [0, 4] shape when num_prefix == 0.
        prefix_coords = torch.tensor(
            [[batch_idx, 0, 0, 1]] * num_prefix, dtype=torch.long, device=device
        ).reshape(num_prefix, 4)

        pos = token_mask[batch_idx].nonzero(as_tuple=False)  # [K, 2] of (h, w)
        h, w = pos[:, 0], pos[:, 1]
        patches = patch_tokens[batch_idx][h * grid_w + w]
        # All columns int64 so coords never get silently promoted to float.
        patch_coords = torch.stack(
            [torch.full_like(h, batch_idx), w, h, torch.zeros_like(h)], dim=1
        )

        selected_tokens_list.append(torch.cat([prefix_tokens[batch_idx], patches], dim=0))
        coords_list.append(torch.cat([prefix_coords, patch_coords], dim=0))

    return torch.cat(selected_tokens_list, dim=0), torch.cat(coords_list, dim=0)