File size: 11,871 Bytes
5e0b9df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/hotr_matcher.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from hotr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

import hotr.util.misc as utils
import wandb

class HungarianPairMatcher(nn.Module):
    """Hungarian matcher between interaction (pair) predictions and ground-truth pairs.

    For every image, a cost matrix is built between the predicted interaction
    queries of every path and the ground-truth human-object pairs. The cost
    combines human-index, object-index and action costs, plus a target-class
    cost for HICO-DET. One assignment problem per path is then solved with
    ``scipy.optimize.linear_sum_assignment``.
    """

    def __init__(self, args):
        """Create the matcher.

        Params:
            args: configuration namespace. Attributes read here:
                - set_cost_act: relative weight of the multi-label action
                  classification error in the matching cost
                - set_cost_idx: relative weight of the human-index and
                  object-index classification errors (shared by both)
                - set_cost_tgt: relative weight of the target-object class
                  error (only used for 'hico-det')
                - wandb: whether matching costs may be logged to wandb
                - dataset_file: 'vcoco' or 'hico-det'
                - valid_ids / invalid_ids: ('vcoco' only) action column ids;
                  ``invalid_ids`` columns are zeroed in targets and predictions
        """
        super().__init__()
        self.cost_action = args.set_cost_act
        self.cost_hbox = self.cost_obox = args.set_cost_idx
        self.cost_target = args.set_cost_tgt
        self.log_printer = args.wandb
        self.is_vcoco = (args.dataset_file == 'vcoco')
        self.is_hico = (args.dataset_file == 'hico-det')
        if self.is_vcoco:
            self.valid_ids = args.valid_ids
            self.invalid_ids = args.invalid_ids
        assert self.cost_action != 0 or self.cost_hbox != 0 or self.cost_obox != 0, "all costs cant be 0"

    def reduce_redundant_gt_box(self, tgt_bbox, indices):
        """Filter redundant ground-truth bounding boxes.

        Due to random-crop augmentation, the exact same box with the same class
        can appear several times among the ground-truth labels. Duplicate rows
        are collapsed, so that every distinct row maps to a single DETR query.

        Params:
            tgt_bbox: ground-truth rows to de-duplicate (compared row-wise).
            indices: ``(k_idx, bbox_idx)``, the DETR matching for this image.
                ``k_idx[i]`` is the query matched to ground-truth row ``bbox_idx[i]``.

        Returns:
            ``(tgt_bbox_res, k_idx, bbox_idx)``.
            - Without duplicates, the inputs are returned unchanged, except that
              ``bbox_idx`` is moved to ``tgt_bbox``'s device.
            - With duplicates, ``tgt_bbox_res`` holds the unique rows, in the
              sorted order produced by ``torch.unique``, and ``bbox_idx`` indexes
              into them. For each unique row, the query of its first occurrence
              in ``bbox_idx`` is kept.
        """
        tgt_bbox_unique, map_idx = torch.unique(tgt_bbox, dim=0, return_inverse=True)
        k_idx, bbox_idx = indices

        if len(tgt_bbox) == len(tgt_bbox_unique):
            return tgt_bbox, k_idx, bbox_idx.to(tgt_bbox.device)

        # original ground-truth row -> unique row
        orig_to_unique = map_idx.tolist()
        # original ground-truth row -> matched DETR query
        orig_to_query = {int(b): int(k) for b, k in zip(bbox_idx, k_idx)}

        seen = set()
        bbox_lst, k_lst = [], []
        for bbox_id in bbox_idx.tolist():
            unique_id = orig_to_unique[bbox_id]
            if unique_id not in seen:
                seen.add(unique_id)
                bbox_lst.append(unique_id)
                k_lst.append(orig_to_query[bbox_id])

        new_bbox_idx = torch.tensor(bbox_lst, dtype=torch.int64).to(tgt_bbox.device)
        new_k_idx = torch.tensor(k_lst, dtype=torch.int64)
        return tgt_bbox_unique, new_k_idx, new_bbox_idx

    @torch.no_grad()
    def forward(self, outputs, targets, indices, log=False):
        """Match interaction predictions to ground-truth human-object pairs.

        Params:
            outputs: dict of model outputs. Keys read: 'pred_actions',
                'pred_hidx', 'pred_oidx' and, for HICO-DET, 'pred_obj_logits'.
            targets: list with one ground-truth dict per image. NOTE: entries
                are modified in place:
                - V-COCO pairs without a valid action are dropped;
                - images without pairs get a padding pair;
                - 'h_labels' / 'o_labels' are written.
            indices: list with one ``(k_idx, bbox_idx)`` DETR match per image,
                pairing object queries with ground-truth boxes.
            log: if True and wandb logging is enabled, the mean costs are logged.

        Returns:
            ``(return_list, targets)``. ``return_list[b][p]`` is the
            ``(query_idx, pair_idx)`` assignment for image ``b`` and path ``p``.

        Raises:
            RuntimeError: if the pair boxes of an image cannot be aligned
                one-to-one with its annotated boxes.
        """
        assert "pred_actions" in outputs, "There is no action output for pair matching"
        bs, num_path, num_queries = outputs["pred_actions"].shape[:3]
        if bs == 0:
            return [], targets

        # The reshapes depend only on the outputs, not on the image, so do them once.
        # Row r of each per-image matrix corresponds to path r // num_queries.
        outputs_hidx = outputs["pred_hidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
        outputs_oidx = outputs["pred_oidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
        outputs_action = outputs["pred_actions"].view(bs, num_path * num_queries, -1)
        if self.is_hico:
            outputs_obj_logits = outputs["pred_obj_logits"].view(bs, num_path * num_queries, -1)

        return_list = []
        log_costs = self.log_printer and log
        if log_costs:
            log_dict = {'h_cost': [], 'o_cost': [], 'act_cost': []}
            if self.is_hico: log_dict['tgt_cost'] = []

        for batch_idx in range(bs):
            target = targets[batch_idx]
            tgt_bbox = target["boxes"]
            tgt_cls = target["labels"]

            if self.is_vcoco:
                # Zero out the invalid action columns, then drop pairs left with no action.
                target["pair_actions"][:, self.invalid_ids] = 0
                keep_idx = (target["pair_actions"].sum(dim=-1) != 0)
                for key in ("pair_boxes", "pair_actions", "pair_targets"):
                    target[key] = target[key][keep_idx]

                tgt_pbox = target["pair_boxes"]  # human box in [:, :4], object box in [:, 4:]
                tgt_act = target["pair_actions"]
                tgt_tgt = target["pair_targets"]
                tgt_hbox = tgt_pbox[:, :4]
                tgt_obox = tgt_pbox[:, 4:]
                h_cls = 1.  # class id appended to human boxes before matching
            elif self.is_hico:
                tgt_act = target["pair_actions"]
                tgt_tgt = target["pair_targets"]
                tgt_hbox = target["sub_boxes"]
                tgt_obox = target["obj_boxes"]
                h_cls = 0.

            # Find which ground-truth boxes match the human / object of each pair (box + class).
            hbox_with_cls = torch.cat(
                [tgt_hbox, torch.full((tgt_hbox.shape[0], 1), h_cls).to(tgt_hbox.device)], dim=1)
            obox_with_cls = torch.cat([tgt_obox, tgt_tgt.unsqueeze(-1)], dim=1)
            # Object boxes whose coordinates sum to -4 (all -1) get class -1,
            # so that they match the padding row appended below.
            obox_with_cls[obox_with_cls[:, :4].sum(dim=1) == -4, -1] = -1

            bbox_with_cls = torch.cat([tgt_bbox, tgt_cls.unsqueeze(-1)], dim=1)
            bbox_with_cls, k_idx, bbox_idx = self.reduce_redundant_gt_box(bbox_with_cls, indices[batch_idx])
            bbox_with_cls = torch.cat((bbox_with_cls, torch.as_tensor([-1.] * 5).unsqueeze(0).to(tgt_cls.device)), dim=0)
            pad_row = len(bbox_with_cls) - 1

            # Exact matches are the (pair, gt-row) entries with zero L1 distance.
            h_match_indices = torch.nonzero(torch.cdist(hbox_with_cls, bbox_with_cls, p=1) == 0, as_tuple=False)
            o_match_indices = torch.nonzero(torch.cdist(obox_with_cls, bbox_with_cls, p=1) == 0, as_tuple=False)

            if len(h_match_indices) != len(o_match_indices):
                raise RuntimeError(
                    f"HungarianPairMatcher: image {batch_idx} has {len(h_match_indices)} human "
                    f"box matches but {len(o_match_indices)} object box matches; the pair boxes "
                    "cannot be aligned with the annotated boxes"
                )

            # Translate each matched ground-truth row into the DETR query it was assigned to.
            tgt_hids, tgt_oids = [], []
            for (_, h_bbox_idx), (_, o_bbox_idx) in zip(h_match_indices, o_match_indices):
                if o_bbox_idx == pad_row:
                    # The object is absent (class -1), e.g. in V-COCO; point at the human instead.
                    o_bbox_idx = h_bbox_idx
                tgt_hids.append(k_idx[(bbox_idx == h_bbox_idx).nonzero(as_tuple=False).squeeze(-1)])
                tgt_oids.append(k_idx[(bbox_idx == o_bbox_idx).nonzero(as_tuple=False).squeeze(-1)])

            # No pairs: use label -1, which is ignored later.
            if not tgt_hids: tgt_hids.append(torch.as_tensor([-1]))
            if not tgt_oids: tgt_oids.append(torch.as_tensor([-1]))

            # Number of positive actions per pair, used to normalize the positive action cost.
            tgt_sum = tgt_act.sum(dim=-1).unsqueeze(0)
            if tgt_act.shape[0] == 0:
                # Pad images without pairs with a single all-zero pair.
                act_device = target["pair_actions"].device
                tgt_act = torch.zeros((1, tgt_act.shape[1])).to(act_device)
                target["pair_actions"] = torch.zeros((1, target["pair_actions"].shape[1])).to(act_device)
                if self.is_hico:
                    pad_tgt = -1
                    tgt_tgt = torch.tensor([pad_tgt]).to(target["pair_targets"])
                    target["pair_targets"] = torch.tensor([pad_tgt]).to(target["pair_targets"].device)
                tgt_sum = (tgt_act.sum(dim=-1) + 1).unsqueeze(0)

            # Repeat the target labels once per path.
            tgt_hids = torch.cat(tgt_hids).repeat(num_path)
            tgt_oids = torch.cat(tgt_oids).repeat(num_path)

            out_hprob = outputs_hidx[batch_idx].softmax(-1)
            out_oprob = outputs_oidx[batch_idx].softmax(-1)
            out_act = outputs_action[batch_idx].clone()
            if self.is_vcoco: out_act[..., self.invalid_ids] = 0
            if self.is_hico:
                out_tgt = outputs_obj_logits[batch_idx].softmax(-1)
                out_tgt[..., -1] = 0  # don't get cost for no-object

            # Append an all-zero extra action column so that the prediction width matches
            # (NOTE(review): presumably the "no interaction" output — confirm).
            tgt_act = torch.cat(
                [tgt_act, torch.zeros(tgt_act.shape[0]).unsqueeze(-1).to(tgt_act.device)], dim=-1
            ).repeat(num_path, 1)
            neg_act = (~tgt_act.bool()).type(torch.int64)

            cost_hclass = -out_hprob[:, tgt_hids]
            cost_oclass = -out_oprob[:, tgt_oids]
            # Positive cost: minus the mean predicted score over the GT-positive actions.
            cost_pos_act = (-torch.matmul(out_act, tgt_act.t().float())) / tgt_sum.repeat(1, num_path)
            # Negative cost: the mean predicted score over the GT-negative actions
            # (never divides by zero thanks to the extra zero column).
            cost_neg_act = torch.matmul(out_act, neg_act.t().float()) / neg_act.sum(dim=-1).unsqueeze(0)

            h_cost = self.cost_hbox * cost_hclass
            o_cost = self.cost_obox * cost_oclass
            act_cost = self.cost_action * (cost_pos_act + cost_neg_act)
            C = h_cost + o_cost + act_cost
            if self.is_hico:
                tgt_cost = self.cost_target * -out_tgt[:, tgt_tgt.repeat(num_path)]
                C += tgt_cost
            C = C.view(num_path, num_queries, -1).cpu()

            # Solve one assignment per path: path i's queries against path i's copy of the pairs.
            num_pairs = len(tgt_hids) // num_path
            hoi_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split([num_pairs] * num_path, -1))]
            return_list.append([
                (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
                for i, j in hoi_indices
            ])
            target["h_labels"] = tgt_hids.to(tgt_hbox.device)
            target["o_labels"] = tgt_oids.to(tgt_obox.device)

            if log_costs:
                # Log only the first path's rows.
                log_dict['h_cost'].append(h_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['o_cost'].append(o_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['act_cost'].append(act_cost[:num_queries].min(dim=0)[0].mean())
                if self.is_hico: log_dict['tgt_cost'].append(tgt_cost[:num_queries].min(dim=0)[0].mean())

        if log_costs:
            for key in log_dict:
                log_dict[key] = torch.stack(log_dict[key]).mean()
            if utils.get_rank() == 0: wandb.log(log_dict)
        return return_list, targets

def build_hoi_matcher(args):
    """Construct a :class:`HungarianPairMatcher` from the configuration ``args``.

    Params:
        args: configuration namespace, passed unchanged to the matcher's constructor.

    Returns:
        The newly created matcher instance.
    """
    matcher = HungarianPairMatcher(args)
    return matcher