Spaces:
Runtime error
Runtime error
File size: 11,871 Bytes
5e0b9df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/hotr_matcher.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn
from hotr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
import hotr.util.misc as utils
import wandb
class HungarianPairMatcher(nn.Module):
    """Hungarian matcher between HOI (pair) queries and ground-truth pairs.

    For every image, each ground-truth human-object pair is expressed through
    the ``k_idx`` entries that ``indices`` associates with its human GT box and
    its object GT box. The cost between every HOI query and every GT pair sums:
    the predicted human-index probability, the predicted object-index
    probability, a multi-label action cost and, for HICO-DET only, an
    object-class cost. The assignment is solved separately for each of the
    ``num_path`` prediction paths with ``scipy.optimize.linear_sum_assignment``.
    """

    def __init__(self, args):
        """Creates the matcher.

        Reads these fields of ``args``:
            set_cost_act: relative weight of the multi-label action
                classification error in the matching cost.
            set_cost_idx: relative weight of the classification error for the
                human index and for the object index (shared by both).
            set_cost_tgt: relative weight of the object-class error (only used
                when ``dataset_file == 'hico-det'``).
            wandb: when truthy, ``forward(..., log=True)`` logs the costs to wandb.
            dataset_file: 'vcoco' or 'hico-det'.
            valid_ids / invalid_ids: V-COCO only; ``invalid_ids`` are the action
                columns zeroed out in both targets and predictions.
        """
        super().__init__()
        self.cost_action = args.set_cost_act
        self.cost_hbox = self.cost_obox = args.set_cost_idx
        self.cost_target = args.set_cost_tgt
        self.log_printer = args.wandb
        self.is_vcoco = (args.dataset_file == 'vcoco')
        self.is_hico = (args.dataset_file == 'hico-det')
        if self.is_vcoco:
            self.valid_ids = args.valid_ids
            self.invalid_ids = args.invalid_ids
        assert self.cost_action != 0 or self.cost_hbox != 0 or self.cost_obox != 0, "all costs cant be 0"

    def reduce_redundant_gt_box(self, tgt_bbox, indices):
        """Filters redundant Ground-Truth Bounding Boxes.

        Due to random crop augmentation, there are cases with multiple
        redundant labels for the exact same bounding box and object class.
        This function collapses those duplicates for smoother HOTR training.

        Args:
            tgt_bbox: GT boxes with their class appended as a last column
                (one row per GT box).
            indices: ``(k_idx, bbox_idx)`` tensors; ``k_idx[i]`` is the index
                associated with GT row ``bbox_idx[i]``.

        Returns:
            ``(tgt_bbox_res, k_idx, bbox_idx)``. When duplicate rows exist,
            ``tgt_bbox_res`` holds the unique rows (as returned by
            ``torch.unique``), ``bbox_idx`` is remapped to those rows, and only
            the first ``k_idx`` entry seen for each unique row is kept.
            Otherwise the inputs are returned unchanged. ``bbox_idx`` is always
            moved to ``tgt_bbox``'s device.
        """
        tgt_bbox_unique, map_idx = torch.unique(tgt_bbox, dim=0, return_inverse=True)

        k_idx, bbox_idx = indices
        if len(tgt_bbox) != len(tgt_bbox_unique):
            unique_row_of = map_idx.tolist()  # GT row -> row in tgt_bbox_unique
            seen, bbox_lst, k_lst = set(), [], []
            for bbox_id, k_id in zip(bbox_idx.tolist(), k_idx.tolist()):
                unique_id = unique_row_of[bbox_id]
                if unique_id not in seen:  # keep the first match per unique box
                    seen.add(unique_id)
                    bbox_lst.append(unique_id)
                    k_lst.append(k_id)
            bbox_idx = torch.tensor(bbox_lst, dtype=torch.int64)
            k_idx = torch.tensor(k_lst, dtype=torch.int64)
            tgt_bbox_res = tgt_bbox_unique
        else:
            tgt_bbox_res = tgt_bbox
        bbox_idx = bbox_idx.to(tgt_bbox.device)

        return tgt_bbox_res, k_idx, bbox_idx

    @torch.no_grad()
    def forward(self, outputs, targets, indices, log=False):
        """Matches the HOI queries of every path to the ground-truth pairs.

        Args:
            outputs: dict of predictions. Reads "pred_actions" (its first three
                dims are taken as ``bs, num_path, num_queries``), "pred_hidx",
                "pred_oidx" and, for HICO-DET, "pred_obj_logits".
            targets: list of per-image target dicts. It is modified in place:
                for V-COCO, the invalid action columns are zeroed and pairs left
                without any action are dropped. An image without pairs gets one
                all-zero dummy pair. "h_labels" / "o_labels" are added.
            indices: per-image ``(k_idx, bbox_idx)`` tensors relating GT boxes
                to the indices used as human/object labels.
            log: when True and wandb logging is enabled, log the matching costs.

        Returns:
            ``(return_list, targets)``. ``return_list[b]`` holds one
            ``(query_idx, pair_idx)`` pair of int64 tensors per path.

        Raises:
            RuntimeError: if the numbers of human-box and object-box matches
                against the GT boxes differ, which makes pairing them ambiguous.
        """
        assert "pred_actions" in outputs, "There is no action output for pair matching"
        bs, num_path, num_queries = outputs["pred_actions"].shape[:3]

        # Flatten every prediction to (bs, num_path * num_queries, -1), path-major.
        # None of this depends on the image, so it is done once, outside the loop.
        outputs_hidx = outputs["pred_hidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
        outputs_oidx = outputs["pred_oidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
        outputs_action = outputs["pred_actions"].view(bs, num_path * num_queries, -1)
        if self.is_hico:
            outputs_obj_logits = outputs["pred_obj_logits"].view(bs, num_path, num_queries, -1).view(bs, num_path * num_queries, -1)

        return_list = []
        if self.log_printer and log:
            log_dict = {'h_cost': [], 'o_cost': [], 'act_cost': []}
            if self.is_hico:
                log_dict['tgt_cost'] = []

        for batch_idx in range(bs):
            tgt_bbox = targets[batch_idx]["boxes"]  # (num_boxes, 4)
            tgt_cls = targets[batch_idx]["labels"]  # (num_boxes)

            if self.is_vcoco:
                # Zero out non-evaluated actions, then drop pairs left without any action.
                targets[batch_idx]["pair_actions"][:, self.invalid_ids] = 0
                keep_idx = (targets[batch_idx]["pair_actions"].sum(dim=-1) != 0)
                targets[batch_idx]["pair_boxes"] = targets[batch_idx]["pair_boxes"][keep_idx]
                targets[batch_idx]["pair_actions"] = targets[batch_idx]["pair_actions"][keep_idx]
                targets[batch_idx]["pair_targets"] = targets[batch_idx]["pair_targets"][keep_idx]

                tgt_pbox = targets[batch_idx]["pair_boxes"]  # (num_pair_boxes, 8)
                tgt_act = targets[batch_idx]["pair_actions"]  # (num_pair_boxes, 29)
                tgt_tgt = targets[batch_idx]["pair_targets"]  # (num_pair_boxes)

                tgt_hbox = tgt_pbox[:, :4]  # (num_pair_boxes, 4)
                tgt_obox = tgt_pbox[:, 4:]  # (num_pair_boxes, 4)
            elif self.is_hico:
                tgt_act = targets[batch_idx]["pair_actions"]  # (num_pair_boxes, 117)
                tgt_tgt = targets[batch_idx]["pair_targets"]  # (num_pair_boxes)

                tgt_hbox = targets[batch_idx]["sub_boxes"]  # (num_pair_boxes, 4)
                tgt_obox = targets[batch_idx]["obj_boxes"]  # (num_pair_boxes, 4)

            # Find which GT boxes match the h, o boxes of each pair. Human boxes are
            # tagged with class 1 (V-COCO) or 0 (HICO-DET) so they can only match GT
            # boxes of that class.
            if self.is_vcoco:
                hbox_with_cls = torch.cat([tgt_hbox, torch.ones((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
            elif self.is_hico:
                hbox_with_cls = torch.cat([tgt_hbox, torch.zeros((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
            obox_with_cls = torch.cat([tgt_obox, tgt_tgt.unsqueeze(-1)], dim=1)
            obox_with_cls[obox_with_cls[:, :4].sum(dim=1) == -4, -1] = -1  # turn the class of occluded objects to -1

            bbox_with_cls = torch.cat([tgt_bbox, tgt_cls.unsqueeze(-1)], dim=1)
            bbox_with_cls, k_idx, bbox_idx = self.reduce_redundant_gt_box(bbox_with_cls, indices[batch_idx])
            # Append a sentinel row of -1s (the last row) so an occluded object whose
            # box and class are all -1 still finds an exact match.
            bbox_with_cls = torch.cat((bbox_with_cls, torch.as_tensor([-1.] * 5).unsqueeze(0).to(tgt_cls.device)), dim=0)

            cost_hbox = torch.cdist(hbox_with_cls, bbox_with_cls, p=1)
            cost_obox = torch.cdist(obox_with_cls, bbox_with_cls, p=1)

            # Exact matches, one row per match: (pair row, GT box row).
            h_match_indices = torch.nonzero(cost_hbox == 0, as_tuple=False)
            o_match_indices = torch.nonzero(cost_obox == 0, as_tuple=False)
            if len(h_match_indices) != len(o_match_indices):
                # zip() below would silently drop pairs; fail loudly instead.
                raise RuntimeError(
                    f"Pair matching failed for image {batch_idx}: {len(h_match_indices)} "
                    f"human-box matches vs {len(o_match_indices)} object-box matches; every "
                    "GT pair needs exactly one matching GT box for its human and its object."
                )

            # Translate every matched GT box into the k_idx entry associated with it.
            tgt_hids, tgt_oids = [], []
            for h_match_idx, o_match_idx in zip(h_match_indices, o_match_indices):
                _, H_bbox_idx = h_match_idx
                _, O_bbox_idx = o_match_idx
                if O_bbox_idx == (len(bbox_with_cls) - 1):  # matched the -1 sentinel row
                    O_bbox_idx = H_bbox_idx  # happens in V-COCO, the target object may not appear

                GT_idx_for_H = (bbox_idx == H_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
                tgt_hids.append(k_idx[GT_idx_for_H])

                GT_idx_for_O = (bbox_idx == O_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
                tgt_oids.append(k_idx[GT_idx_for_O])

            # check if empty
            if len(tgt_hids) == 0:
                tgt_hids.append(torch.as_tensor([-1]))  # we later ignore the label -1
            if len(tgt_oids) == 0:
                tgt_oids.append(torch.as_tensor([-1]))  # we later ignore the label -1

            tgt_sum = (tgt_act.sum(dim=-1)).unsqueeze(0)
            if tgt_act.shape[0] == 0:
                # No GT pair: pad with a single all-zero dummy pair so the cost matrix
                # is not empty.
                tgt_act = torch.zeros((1, tgt_act.shape[1])).to(targets[batch_idx]["pair_actions"].device)
                targets[batch_idx]["pair_actions"] = torch.zeros((1, targets[batch_idx]["pair_actions"].shape[1])).to(targets[batch_idx]["pair_actions"].device)
                if self.is_hico:
                    # -1 selects the last column of out_tgt, whose no-object cost is zeroed below.
                    pad_tgt = -1
                    tgt_tgt = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"])
                    targets[batch_idx]["pair_targets"] = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"].device)
                tgt_sum = (tgt_act.sum(dim=-1) + 1).unsqueeze(0)  # +1 avoids dividing by zero

            # One copy of the GT pair labels per path, in path-major order.
            tgt_hids = torch.cat(tgt_hids).repeat(num_path)
            tgt_oids = torch.cat(tgt_oids).repeat(num_path)

            out_hprob = outputs_hidx[batch_idx].softmax(-1)
            out_oprob = outputs_oidx[batch_idx].softmax(-1)
            out_act = outputs_action[batch_idx].clone()  # cloned: modified in place below
            if self.is_vcoco:
                out_act[..., self.invalid_ids] = 0
            if self.is_hico:
                out_tgt = outputs_obj_logits[batch_idx].softmax(-1)
                out_tgt[..., -1] = 0  # don't get cost for no-object
            # Append an all-zero target column so tgt_act matches out_act's width
            # (assumes the action head has one extra output — confirm against the
            # model), then repeat the targets once per path.
            tgt_act = torch.cat([tgt_act, torch.zeros(tgt_act.shape[0]).unsqueeze(-1).to(tgt_act.device)], dim=-1).repeat(num_path, 1)

            # Costs are (num_path * num_queries, num_path * num_pairs).
            cost_hclass = -out_hprob[:, tgt_hids]
            cost_oclass = -out_oprob[:, tgt_oids]

            # Action cost: mean score over the positive actions of each pair
            # (negated) plus mean score over its negative actions.
            neg_act = (~tgt_act.bool()).float()
            cost_pos_act = (-torch.matmul(out_act, tgt_act.t().float())) / tgt_sum.repeat(1, num_path)
            cost_neg_act = torch.matmul(out_act, neg_act.t()) / neg_act.sum(dim=-1).unsqueeze(0)
            cost_action = cost_pos_act + cost_neg_act

            h_cost = self.cost_hbox * cost_hclass
            o_cost = self.cost_obox * cost_oclass
            act_cost = self.cost_action * cost_action

            C = h_cost + o_cost + act_cost
            if self.is_hico:
                cost_target = -out_tgt[:, tgt_tgt.repeat(num_path)]
                tgt_cost = self.cost_target * cost_target
                C += tgt_cost

            # Solve one assignment per path: path i's queries against path i's copy
            # of the GT pairs.
            C = C.view(num_path, num_queries, -1).cpu()
            sizes = [len(tgt_hids) // num_path] * num_path
            hoi_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return_list.append([(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in hoi_indices])

            targets[batch_idx]["h_labels"] = tgt_hids.to(tgt_hbox.device)
            targets[batch_idx]["o_labels"] = tgt_oids.to(tgt_obox.device)

            if self.log_printer and log:
                # Only the first path's queries are summarized.
                log_dict['h_cost'].append(h_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['o_cost'].append(o_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['act_cost'].append(act_cost[:num_queries].min(dim=0)[0].mean())
                if self.is_hico:
                    log_dict['tgt_cost'].append(tgt_cost[:num_queries].min(dim=0)[0].mean())

        if self.log_printer and log:
            log_dict['h_cost'] = torch.stack(log_dict['h_cost']).mean()
            log_dict['o_cost'] = torch.stack(log_dict['o_cost']).mean()
            log_dict['act_cost'] = torch.stack(log_dict['act_cost']).mean()
            if self.is_hico:
                log_dict['tgt_cost'] = torch.stack(log_dict['tgt_cost']).mean()
            if utils.get_rank() == 0:
                wandb.log(log_dict)

        return return_list, targets
def build_hoi_matcher(args):
    """Build the HOI pair matcher configured from the training ``args``."""
    matcher = HungarianPairMatcher(args)
    return matcher
|