File size: 2,634 Bytes
bcb05d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import numpy as np

def sort_block(latent_index, block_size):
    device = latent_index.device
    latent_index_block = latent_index.cpu().numpy()
    latent_index_block[..., 1:] = latent_index_block[..., 1:] // block_size
    latent_index_inblock = latent_index.cpu().numpy()
    latent_index_inblock[..., 1:] = latent_index_inblock[..., 1:] % block_size
    sort_index = np.lexsort((
        latent_index_inblock[..., 3],
        latent_index_inblock[..., 2],
        latent_index_inblock[..., 1],
        latent_index_block[..., 3],
        latent_index_block[..., 2],
        latent_index_block[..., 1])
    )
    sort_index = torch.from_numpy(sort_index).to(device)
    return latent_index[sort_index]

def extract_tokens_and_coords(conditions, token_mask, num_cls=1, num_reg=4):
    device = conditions.device
    B = conditions.size(0)
    patch_size = token_mask.size(1) 
    
    class_tokens = conditions[:, 0:num_cls, :]  # [B, 1, 1024]
    register_tokens = conditions[:, num_cls:num_cls+num_reg, :]  # [B, 4, 1024]
    patch_tokens = conditions[:, num_cls+num_reg:, :]  # [B, 1369, 1024]

    selected_tokens_list = []
    coords_list = []

    for batch_idx in range(B):
        cls_tokens = class_tokens[batch_idx]  # [1, 1024]
        reg_tokens = register_tokens[batch_idx]  # [4, 1024]
        cls_reg_tokens = torch.cat([cls_tokens, reg_tokens], dim=0)  # [5, 1024]

        cls_coord = torch.tensor([[batch_idx, 0, 0, 1]] * num_cls, device=device)
        reg_coords = torch.tensor([[batch_idx, 0, 0, 1]] * num_reg, device=device)
        cls_reg_coords = torch.cat([cls_coord, reg_coords], dim=0)  

        mask = token_mask[batch_idx] 
        pos = mask.nonzero(as_tuple=False) 
        K = pos.size(0)

        if K > 0:
            h, w = pos[:, 0], pos[:, 1]
            indices = h * patch_size + w  #
            patches = patch_tokens[batch_idx][indices] 

            batch_ids = torch.full((K, 1), batch_idx, device=device)
            x = w.unsqueeze(1)  
            y = h.unsqueeze(1)  
            patch_coords = torch.cat([batch_ids, x, y, torch.zeros((K, 1), device=device)], dim=1)

            combined_tokens = torch.cat([cls_reg_tokens, patches], dim=0)
            combined_coords = torch.cat([cls_reg_coords, patch_coords], dim=0)
        else:
            combined_tokens = cls_reg_tokens
            combined_coords = cls_reg_coords

        selected_tokens_list.append(combined_tokens)
        coords_list.append(combined_coords)

    selected_tokens = torch.cat(selected_tokens_list, dim=0)
    coords = torch.cat(coords_list, dim=0)
    
    return selected_tokens, coords