from loguru import logger
import torch
import torch.nn as nn
from kornia.geometry.conversions import convert_points_to_homogeneous
from kornia.geometry.epipolar import numeric
class XoFTRLoss(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config # config under the global namespace
self.loss_config = config['xoftr']['loss']
self.pos_w = self.loss_config['pos_weight']
self.neg_w = self.loss_config['neg_weight']
def compute_fine_matching_loss(self, data):
""" Point-wise Focal Loss with 0 / 1 confidence as gt.
Args:
data (dict): {
conf_matrix_fine (torch.Tensor): (N, W_f^2, W_f^2)
conf_matrix_f_gt (torch.Tensor): (N, W_f^2, W_f^2)
}
"""
conf_matrix_fine = data['conf_matrix_fine']
conf_matrix_f_gt = data['conf_matrix_f_gt']
pos_mask, neg_mask = conf_matrix_f_gt > 0, conf_matrix_f_gt == 0
pos_w, neg_w = self.pos_w, self.neg_w
        if not pos_mask.any():  # no positive gt at all: assign a dummy entry and zero its weight
pos_mask[0, 0, 0] = True
pos_w = 0.
if not neg_mask.any():
neg_mask[0, 0, 0] = True
neg_w = 0.
conf_matrix_fine = torch.clamp(conf_matrix_fine, 1e-6, 1-1e-6)
alpha = self.loss_config['focal_alpha']
gamma = self.loss_config['focal_gamma']
loss_pos = - alpha * torch.pow(1 - conf_matrix_fine[pos_mask], gamma) * (conf_matrix_fine[pos_mask]).log()
# loss_pos *= conf_matrix_f_gt[pos_mask]
loss_neg = - alpha * torch.pow(conf_matrix_fine[neg_mask], gamma) * (1 - conf_matrix_fine[neg_mask]).log()
return pos_w * loss_pos.mean() + neg_w * loss_neg.mean()
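    # For reference, the two terms above follow the standard focal-loss form:
    # FL(p) = -alpha * (1 - p)^gamma * log(p) for positive cells and
    # FL(p) = -alpha * p^gamma * log(1 - p) for negative cells, averaged
    # separately and then combined with pos_weight / neg_weight.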
    def _symmetric_epipolar_distance(self, pts0, pts1, E, K0, K1):
        """Squared symmetric epipolar distance.
        This can be seen as a biased estimation of the reprojection error.
        Args:
            pts0 (torch.Tensor): [N, 2] keypoints in image 0, pixel coordinates
            pts1 (torch.Tensor): [N, 2] keypoints in image 1, pixel coordinates
            E (torch.Tensor): [N, 3, 3] essential matrix per match
            K0 (torch.Tensor): [N, 3, 3] intrinsics of camera 0 per match
            K1 (torch.Tensor): [N, 3, 3] intrinsics of camera 1 per match
        Returns:
            torch.Tensor: [N] squared symmetric epipolar distances
        """
        # normalize pixel coordinates with the intrinsics
        # (subtract the principal point, divide by the focal lengths)
        pts0 = (pts0 - K0[:, [0, 1], [2, 2]]) / K0[:, [0, 1], [0, 1]]
        pts1 = (pts1 - K1[:, [0, 1], [2, 2]]) / K1[:, [0, 1], [0, 1]]
pts0 = convert_points_to_homogeneous(pts0)
pts1 = convert_points_to_homogeneous(pts1)
Ep0 = (pts0[:,None,:] @ E.transpose(-2,-1)).squeeze(1) # [N, 3]
p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
Etp1 = (pts1[:,None,:] @ E).squeeze(1) # [N, 3]
d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2 + 1e-9) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2 + 1e-9)) # N
return d
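    # The quantity returned above is the standard squared symmetric epipolar
    # distance, d = (p1^T E p0)^2 * (1/((E p0)_x^2 + (E p0)_y^2) + 1/((E^T p1)_x^2 + (E^T p1)_y^2)),
    # i.e. the squared distances of each point to its epipolar line in both
    # images; the 1e-9 terms guard against division by zero.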
    def compute_sub_pixel_loss(self, data):
        """ Symmetric epipolar distance loss for sub-pixel refinement.
        Args:
            data (dict): {
                m_bids (torch.Tensor): (N)
                T_0to1 (torch.Tensor): (B, 4, 4)
                K0 (torch.Tensor): (B, 3, 3)
                K1 (torch.Tensor): (B, 3, 3)
                mkpts0_f_train (torch.Tensor): (N, 2)
                mkpts1_f_train (torch.Tensor): (N, 2)
            }
        """
Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
E_mat = Tx @ data['T_0to1'][:, :3, :3]
m_bids = data['m_bids']
pts0 = data['mkpts0_f_train']
pts1 = data['mkpts1_f_train']
sym_dist = self._symmetric_epipolar_distance(pts0, pts1, E_mat[m_bids], data['K0'][m_bids], data['K1'][m_bids])
# filter matches with high epipolar error (only train approximately correct fine-level matches)
loss = sym_dist[sym_dist<1e-4]
if len(loss) == 0:
return torch.zeros(1, device=loss.device, requires_grad=False)[0]
return loss.mean()
def compute_coarse_loss(self, data, weight=None):
""" Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
Args:
data (dict): {
conf_matrix_0_to_1 (torch.Tensor): (N, HW0, HW1)
conf_matrix_1_to_0 (torch.Tensor): (N, HW0, HW1)
                conf_matrix_gt (torch.Tensor): (N, HW0, HW1)
}
weight (torch.Tensor): (N, HW0, HW1)
"""
conf_matrix_0_to_1 = data["conf_matrix_0_to_1"]
conf_matrix_1_to_0 = data["conf_matrix_1_to_0"]
conf_gt = data["conf_matrix_gt"]
pos_mask = conf_gt == 1
c_pos_w = self.pos_w
# corner case: no gt coarse-level match at all
        if not pos_mask.any():  # assign a dummy gt and zero its weight
pos_mask[0, 0, 0] = True
if weight is not None:
weight[0, 0, 0] = 0.
c_pos_w = 0.
conf_matrix_0_to_1 = torch.clamp(conf_matrix_0_to_1, 1e-6, 1-1e-6)
conf_matrix_1_to_0 = torch.clamp(conf_matrix_1_to_0, 1e-6, 1-1e-6)
alpha = self.loss_config['focal_alpha']
gamma = self.loss_config['focal_gamma']
loss_pos = - alpha * torch.pow(1 - conf_matrix_0_to_1[pos_mask], gamma) * (conf_matrix_0_to_1[pos_mask]).log()
loss_pos += - alpha * torch.pow(1 - conf_matrix_1_to_0[pos_mask], gamma) * (conf_matrix_1_to_0[pos_mask]).log()
if weight is not None:
loss_pos = loss_pos * weight[pos_mask]
loss_c = (c_pos_w * loss_pos.mean())
return loss_c
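    # Note that, unlike the fine-level loss, no negative term is applied at
    # the coarse level: both confidence matrices (0->1 and 1->0) are only
    # pushed toward 1 at ground-truth positions, weighted by the valid-region
    # mask when one is given.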
@torch.no_grad()
def compute_c_weight(self, data):
""" compute element-wise weights for computing coarse-level loss. """
if 'mask0' in data:
c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
else:
c_weight = None
return c_weight
def forward(self, data):
"""
Update:
data (dict): update{
'loss': [1] the reduced loss across a batch,
'loss_scalars' (dict): loss scalars for tensorboard_record
}
"""
loss_scalars = {}
# 0. compute element-wise loss weight
c_weight = self.compute_c_weight(data)
# 1. coarse-level loss
loss_c = self.compute_coarse_loss(data, weight=c_weight)
loss_c *= self.loss_config['coarse_weight']
loss = loss_c
loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
# 2. fine-level matching loss for windows
loss_f_match = self.compute_fine_matching_loss(data)
loss_f_match *= self.loss_config['fine_weight']
loss = loss + loss_f_match
loss_scalars.update({"loss_f": loss_f_match.clone().detach().cpu()})
# 3. sub-pixel refinement loss
loss_sub = self.compute_sub_pixel_loss(data)
loss_sub *= self.loss_config['sub_weight']
loss = loss + loss_sub
loss_scalars.update({"loss_sub": loss_sub.clone().detach().cpu()})
loss_scalars.update({'loss': loss.clone().detach().cpu()})
data.update({"loss": loss, "loss_scalars": loss_scalars})
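
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the training
    # pipeline): the config keys below match the lookups in this file, but
    # the weight values are placeholders, not XoFTR's published defaults.
    config = {
        "xoftr": {
            "loss": {
                "pos_weight": 1.0,
                "neg_weight": 1.0,
                "focal_alpha": 0.25,
                "focal_gamma": 2.0,
                "coarse_weight": 1.0,
                "fine_weight": 0.3,
                "sub_weight": 1.0,
            }
        }
    }
    loss_fn = XoFTRLoss(config)

    B, HW, WF2, N = 2, 16, 25, 3  # image pairs, coarse cells, fine window^2, matches
    conf_gt = torch.zeros(B, HW, HW)
    conf_gt[:, 0, 0] = 1  # one ground-truth coarse match per pair
    conf_f_gt = torch.zeros(N, WF2, WF2)
    conf_f_gt[:, 12, 12] = 1  # one ground-truth fine match per window

    T_0to1 = torch.eye(4).repeat(B, 1, 1)
    T_0to1[:, 0, 3] = 1.0  # pure translation along x, so E = [t]_x

    # With identity intrinsics and equal y coordinates, the matches lie
    # exactly on their epipolar lines, so the sub-pixel loss is ~0 here.
    pts0 = torch.tensor([[0.1, 0.2], [0.3, -0.1], [0.5, 0.0]])
    pts1 = torch.tensor([[0.4, 0.2], [0.0, -0.1], [0.2, 0.0]])

    data = {
        "conf_matrix_0_to_1": torch.rand(B, HW, HW),
        "conf_matrix_1_to_0": torch.rand(B, HW, HW),
        "conf_matrix_gt": conf_gt,
        "conf_matrix_fine": torch.rand(N, WF2, WF2),
        "conf_matrix_f_gt": conf_f_gt,
        "m_bids": torch.tensor([0, 0, 1]),
        "T_0to1": T_0to1,
        "K0": torch.eye(3).repeat(B, 1, 1),
        "K1": torch.eye(3).repeat(B, 1, 1),
        "mkpts0_f_train": pts0,
        "mkpts1_f_train": pts1,
    }
    loss_fn(data)
    logger.info(f"total loss: {data['loss']:.4f}")
    logger.info(f"scalars: {data['loss_scalars']}")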