File size: 20,890 Bytes
5e0b9df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
# ------------------------------------------------------------------------
# HOTR official code : main.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import copy
import numpy as np
import itertools
from torch import nn

from hotr.util import box_ops
from hotr.util.misc import (accuracy, get_world_size, is_dist_avail_and_initialized)

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR and the HOTR interaction heads.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    HOI losses are computed from a second (pair) matching produced by ``HOI_matcher``.
    When augmentation paths are configured (``args.augpath_name``) and ``args.use_consis`` is
    set, consistency losses additionally pull together predictions from different paths that
    were matched to the same ground-truth interaction.
    """
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, num_actions=None, HOI_losses=None, HOI_matcher=None, args=None):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            num_actions: number of action classes; used to build ``valid_ids`` for 'hico-det'
            HOI_losses: list of HOI losses to apply (see get_HOI_loss), or None to skip them
            HOI_matcher: module returning ``(hoi_indices, hoi_targets)`` for the HOI heads
            args: options namespace. Reads ``use_consis``, ``augpath_name``, ``hoi_eos_coef``,
                ``dataset_file`` and, per dataset, ``invalid_ids``/``valid_ids`` ('vcoco') or
                ``valid_obj_ids`` ('hico-det'). May be None; then only the DETR losses are usable.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.eos_coef = eos_coef

        self.HOI_losses = HOI_losses
        self.HOI_matcher = HOI_matcher

        # Defaults used when no options are given (DETR-only usage).
        self.use_consis = False
        self.num_path = 1
        self.dataset_file = None
        if args is not None:
            num_aug_paths = len(args.augpath_name)
            # Number of prediction paths: the original one plus one per augmentation.
            self.num_path = 1 + num_aug_paths
            # Consistency needs at least two paths. Spelled with `and`: the former
            # `use_consis & len(...) > 0` parsed as `(use_consis & len(...)) > 0`,
            # which is False for any even number of augmentation paths.
            self.use_consis = bool(args.use_consis) and num_aug_paths > 0
            self.dataset_file = args.dataset_file
            self.HOI_eos_coef = args.hoi_eos_coef
            if args.dataset_file == 'vcoco':
                self.invalid_ids = args.invalid_ids
                self.valid_ids = np.concatenate((args.valid_ids, [-1]), axis=0)  # -1: no interaction
            elif args.dataset_file == 'hico-det':
                self.invalid_ids = []
                self.valid_ids = list(range(num_actions)) + [-1]  # -1: no interaction

                # Class weights for the interaction-target (object) classifier.
                self.num_tgt_classes = len(args.valid_obj_ids)
                tgt_empty_weight = torch.ones(self.num_tgt_classes + 1)
                tgt_empty_weight[-1] = self.HOI_eos_coef
                self.register_buffer('tgt_empty_weight', tgt_empty_weight)

        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = eos_coef
        self.register_buffer('empty_weight', empty_weight)

    #######################################################################################################################
    # * DETR Losses
    #######################################################################################################################
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        Unmatched queries are supervised with the no-object class (index ``num_classes``).
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        return {'cardinality_error': card_err}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses


    #######################################################################################################################
    # * HOTR Losses
    #######################################################################################################################
    # >>> HOI Losses 1 : HO Pointer
    def loss_pair_labels(self, outputs, targets, hoi_indices, num_boxes, use_consis, log=False):
        """HO pointer loss, plus optional cross-path consistency.

        ``pred_hidx`` / ``pred_oidx`` score, for each HOI query, which index is the human /
        object of the pair. Matched queries are supervised with ``h_labels`` / ``o_labels``;
        unmatched queries get -1 and are ignored.

        ``hoi_indices`` is nested as ``hoi_indices[batch][path] = (src_idx, tgt_idx)``.
        ``num_boxes`` and ``log`` are accepted for interface compatibility and are unused.
        """
        assert ('pred_hidx' in outputs and 'pred_oidx' in outputs)
        # Rows are reordered to batch-major so that row b * num_path + p lines up with the
        # flattened hoi_indices below.
        src_hidx = self._paths_to_batch_major(outputs['pred_hidx'])
        src_oidx = self._paths_to_batch_major(outputs['pred_oidx'])
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        target_hidx_classes = torch.full(src_hidx.shape[:2], -1, dtype=torch.int64, device=src_hidx.device)
        target_oidx_classes = torch.full(src_oidx.shape[:2], -1, dtype=torch.int64, device=src_oidx.device)

        # H Pointer targets
        target_hidx_classes[idx] = torch.cat([t["h_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        # O Pointer targets
        target_oidx_classes[idx] = torch.cat([t["o_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])

        loss_h = F.cross_entropy(src_hidx.transpose(1, 2), target_hidx_classes, ignore_index=-1)
        loss_o = F.cross_entropy(src_oidx.transpose(1, 2), target_oidx_classes, ignore_index=-1)
        losses = {'loss_hidx': loss_h, 'loss_oidx': loss_o}

        if use_consis:
            # View as (bs, num_path, q, C) so each batch element can be indexed by (path, query).
            losses['loss_h_consistency'] = self._path_consistency(
                src_hidx.reshape(-1, self.num_path, *src_hidx.shape[1:]), hoi_indices, self._symmetric_kl)
            losses['loss_o_consistency'] = self._path_consistency(
                src_oidx.reshape(-1, self.num_path, *src_oidx.shape[1:]), hoi_indices, self._symmetric_kl)
        return losses

    # >>> HOI Losses 2 : pair actions
    def loss_pair_actions(self, outputs, targets, hoi_indices, num_boxes, use_consis):
        """Multi-label action loss: sigmoid focal loss (alpha=0.25, gamma=2) over ``valid_ids``.

        The last column of ``pred_actions`` is the no-interaction slot. Its target is 1 for
        unmatched queries and 0 for matched ones. The loss is normalized by the number of
        positive action labels (at least 1).
        NOTE(review): the consistency branch indexes ``pred_actions[batch][path, query]``, so
        it is assumed to be laid out as (bs, num_path, q, A) — confirm against the model.
        """
        assert 'pred_actions' in outputs
        src_actions = outputs['pred_actions'].flatten(end_dim=1)
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        # Construct Target ---------------------------------------------------------------------
        target_classes_o = torch.cat([t["pair_actions"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        target_classes = torch.zeros(src_actions.shape, dtype=torch.float32, device=src_actions.device)
        target_classes[..., -1] = 1  # unmatched queries: "no interaction"

        pos_classes = torch.zeros(target_classes[idx].shape, dtype=torch.float32, device=src_actions.device)
        pos_classes[:, :-1] = target_classes_o.float()  # matched queries: no-interaction slot stays 0
        target_classes[idx] = pos_classes

        # Focal BCE loss -----------------------------------------------------------------------
        valid = self.valid_ids
        probs = src_actions.sigmoid()[..., valid]
        tgt = target_classes[..., valid]
        loss_bce = F.binary_cross_entropy(probs, tgt, reduction='none')
        p_t = probs * tgt + (1 - probs) * (1 - tgt)
        alpha_t = 0.25 * tgt + (1 - 0.25) * (1 - tgt)
        loss_focal = alpha_t * ((1 - p_t) ** 2 * loss_bce)
        num_pos = target_classes[..., valid[:-1]].sum().clamp(min=1)
        losses = {'loss_act': loss_focal.sum() / num_pos}

        if use_consis:
            losses['loss_act_consistency'] = self._path_consistency(
                outputs['pred_actions'], hoi_indices,
                lambda a, b: F.mse_loss(F.logsigmoid(a), F.logsigmoid(b)))
        return losses

    # HOI Losses 3 : action targets
    def loss_pair_targets(self, outputs, targets, hoi_indices, num_interactions, use_consis, log=True):
        """Classification loss for the interaction target (object class) of each HOI query.

        Only matched queries are supervised; unmatched ones are padded with -1 and ignored.
        The class weights come from ``tgt_empty_weight`` (only registered for 'hico-det').
        """
        assert 'pred_obj_logits' in outputs
        src_logits = outputs['pred_obj_logits']
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        target_classes_o = torch.cat([t['pair_targets'][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        pad_tgt = -1
        target_classes = torch.full(src_logits.shape[:2], pad_tgt, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_obj_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.tgt_empty_weight, ignore_index=pad_tgt)
        losses = {'loss_tgt': loss_obj_ce}

        if use_consis:
            losses['loss_tgt_label_consistency'] = self._path_consistency(
                src_logits.reshape(-1, self.num_path, *src_logits.shape[1:]), hoi_indices, self._symmetric_kl)
        if log:
            valid_tgt = (target_classes_o != -1)
            # Accuracy over real classes only (last column is the no-object slot).
            losses['obj_class_error'] = 100 - accuracy(src_logits[idx][valid_tgt, :-1], target_classes_o[valid_tgt])[0]
        return losses

    def _paths_to_batch_major(self, x):
        """Reorder rows of a (num_path * bs, q, C) tensor from path-major to batch-major order."""
        n, q = x.shape[:2]
        return x.reshape(self.num_path, n // self.num_path, q, -1).transpose(0, 1).flatten(0, 1)

    @staticmethod
    def _symmetric_kl(logits_a, logits_b):
        """Symmetric KL between softmax(logits_a) and softmax(logits_b) ('batchmean').

        Each direction detaches its target distribution. log_softmax is used instead of
        softmax(...).log(), which yields -inf (and an infinite loss) when probabilities underflow.
        """
        log_p_a = F.log_softmax(logits_a, dim=-1)
        log_p_b = F.log_softmax(logits_b, dim=-1)
        kl_ab = F.kl_div(log_p_a, log_p_b.detach().exp(), reduction='batchmean')
        kl_ba = F.kl_div(log_p_b, log_p_a.detach().exp(), reduction='batchmean')
        return 0.5 * (kl_ab + kl_ba)

    def _path_consistency(self, per_path, hoi_indices, pair_loss):
        """Mean over batch elements of ``pair_loss`` between cross-path predictions of one target.

        per_path: tensor indexed as ``per_path[b][path, query]``.
        Batch elements with no cross-path pair are skipped; they would otherwise give an empty
        ``torch.cat`` or a NaN 'batchmean'. Returns 0 when no element has any pair.
        """
        per_sample = []
        for b, hoi_indice in enumerate(hoi_indices):
            idx_a, idx_b = self._get_consistency_src_permutation_idx(hoi_indice)
            if idx_a[0].numel() == 0:
                continue
            per_sample.append(pair_loss(per_path[b][idx_a], per_path[b][idx_b]))
        if not per_sample:
            return per_path.new_zeros(())
        return torch.stack(per_sample).mean()

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_consistency_src_permutation_idx(self, indices):
        """Pair up matches from different paths that share the same target.

        indices: per-path list of ``(src_idx, tgt_idx)`` tensors for one batch element.
        Returns ``((path_a, query_a), (path_b, query_b))``: int64 tensors of equal length with
        one entry per unordered pair of matches of the same target. Pairs are grouped by
        target, in path order. Returns empty tensors when there is no such pair.
        """
        pairs = []
        if len(indices) > 0:
            all_tgt = torch.cat([tgt for (_, tgt) in indices]).unique()
            for tgt_id in all_tgt:
                # (path, query) of every match of this target, in path order.
                hits = [(path, query)
                        for path, (src, tgt) in enumerate(indices)
                        for query in src[tgt == tgt_id].tolist()]
                if len(hits) < 2:
                    continue
                hits = torch.tensor(hits, dtype=torch.int64)
                combo = torch.combinations(torch.arange(len(hits)))
                pairs.append(torch.cat([hits[combo[:, 0]], hits[combo[:, 1]]], dim=1))
        if not pairs:
            empty = torch.empty(0, dtype=torch.int64)
            return (empty, empty), (empty, empty)
        pairs = torch.cat(pairs)
        return (pairs[:, 0], pairs[:, 1]), (pairs[:, 2], pairs[:, 3])

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    # *****************************************************************************
    # >>> DETR Losses
    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch a DETR loss by name ('labels', 'cardinality' or 'boxes')."""
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    # >>> HOTR Losses
    def get_HOI_loss(self, loss, outputs, targets, indices, num_boxes, use_consis, **kwargs):
        """Dispatch an HOI loss by name ('pair_labels', 'pair_actions', and 'pair_targets' for hico-det)."""
        loss_map = {
            'pair_labels': self.loss_pair_labels,
            'pair_actions': self.loss_pair_actions
        }
        if self.dataset_file == 'hico-det':
            loss_map['pair_targets'] = self.loss_pair_targets
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, use_consis, **kwargs)
    # *****************************************************************************

    def forward(self, outputs, targets, log=False):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
             log: forwarded to every HOI_matcher call
        Returns:
             dict of loss name -> value; auxiliary-layer entries carry a ``_{i}`` suffix.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k not in ('aux_outputs', 'hoi_aux_outputs')}

        # Matching between the outputs of the last layer and the targets. `indices` must keep
        # this final-layer matching: it is passed to every HOI_matcher call below.
        indices = self.matcher(outputs_without_aux, targets)

        if self.HOI_losses is not None:
            input_targets = [copy.deepcopy(target) for target in targets]
            hoi_indices, hoi_targets = self.HOI_matcher(outputs_without_aux, input_targets, indices, log)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                # Separate name: overwriting `indices` here made the auxiliary HOI matching
                # below use the last auxiliary decoder layer's matching.
                aux_indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    # Logging is enabled only for the last layer
                    kwargs = {'log': False} if loss == 'labels' else {}
                    l_dict = self.get_loss(loss, aux_outputs, targets, aux_indices, num_boxes, **kwargs)
                    losses.update({k + f'_{i}': v for k, v in l_dict.items()})

        # HOI detection losses
        if self.HOI_losses is not None:
            for loss in self.HOI_losses:
                losses.update(self.get_HOI_loss(loss, outputs, hoi_targets, hoi_indices, num_boxes, self.use_consis))

            if 'hoi_aux_outputs' in outputs:
                for i, aux_outputs in enumerate(outputs['hoi_aux_outputs']):
                    input_targets = [copy.deepcopy(target) for target in targets]
                    aux_hoi_indices, aux_hoi_targets = self.HOI_matcher(aux_outputs, input_targets, indices, log)
                    for loss in self.HOI_losses:
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False} if loss == 'pair_targets' else {}
                        # Pair the auxiliary matching with the targets returned alongside it.
                        l_dict = self.get_HOI_loss(loss, aux_outputs, aux_hoi_targets, aux_hoi_indices, num_boxes, self.use_consis, **kwargs)
                        losses.update({k + f'_{i}': v for k, v in l_dict.items()})

        return losses