Spaces:
Running
on
Zero
Running
on
Zero
# mostly from: https://github.com/cvg/glue-factory/blob/main/gluefactory/geometry/epipolar.py | |
import numpy as np | |
import torch | |
def angle_error_mat(R1, R2):
    """Return the angle, in degrees, between two rotation matrices.

    The angle is recovered from ``cos(theta) = (trace(R1^T @ R2) - 1) / 2``,
    a formula that assumes 3x3 rotation matrices.

    Args:
        R1: Tensor whose last two dimensions form a square matrix. Any number
            of leading batch dimensions is accepted.
        R2: Tensor with the same trailing matrix shape as ``R1``; leading
            dimensions must be broadcastable against those of ``R1``.

    Returns:
        Tensor of angles in degrees, one per leading batch entry (a 0-dim
        tensor when both inputs are plain 2-D matrices).
    """
    # Transpose only the matrix dimensions: ``R1.T`` would reverse every
    # dimension of a batched tensor (and is deprecated for ndim != 2).
    relative = torch.einsum("...ij, ...jk -> ...ik", R1.transpose(-2, -1), R2)
    # ``torch.trace`` only accepts 2-D input, so sum the diagonal explicitly
    # over the last two dimensions to support batches.
    trace = relative.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
    cos = (trace - 1) / 2
    cos = torch.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
    return torch.rad2deg(torch.abs(torch.arccos(cos)))
def angle_error_vec(v1, v2, eps=1e-10):
    """Return the angle, in degrees, between vectors along the last dimension.

    Args:
        v1: Tensor whose last dimension holds the vector components.
        v2: Tensor combined element-wise with ``v1``.
        eps: Lower bound applied to the product of the two norms so that the
            division stays finite for (near-)zero vectors.

    Returns:
        Tensor of angles in degrees with the last dimension reduced away.
    """
    # dot product in the last dimension
    dot = torch.sum(v1 * v2, dim=-1)
    # product of the two lengths, kept away from zero
    norm_product = (v1.norm(dim=-1) * v2.norm(dim=-1)).clamp(min=eps)
    # rounding can push the cosine slightly outside [-1, 1]
    cosine = (dot / norm_product).clamp(-1.0, 1.0)
    return torch.rad2deg(torch.acos(cosine))
def relative_pose_error(R_gt, t_gt, R, t, ignore_gt_t_thr=0.0, eps=1e-10):
    """Compute translation and rotation angular errors of a relative pose.

    Args:
        R_gt: Reference rotation, forwarded to ``angle_error_mat``.
        t_gt: Reference translation, forwarded to ``angle_error_vec``.
        R: Estimated rotation.
        t: Estimated translation.
        ignore_gt_t_thr: When the norm of ``t_gt`` is below this value the
            translation error is reported as zero.
        eps: Passed through to ``angle_error_vec``.

    Returns:
        Tuple ``(t_err, r_err)`` of angular errors in degrees.
    """
    # angle between the estimated and the reference translation
    raw_t_angle = angle_error_vec(t, t_gt, eps)
    # handle E ambiguity: an angle and its supplement count as the same error
    t_err = torch.minimum(raw_t_angle, 180 - raw_t_angle)

    # pure rotation is challenging: skip the translation error in that case
    gt_translation_negligible = t_gt.norm() < ignore_gt_t_thr
    if gt_translation_negligible:
        t_err = torch.zeros_like(t_err)

    # angle between the estimated and the reference rotation
    r_err = angle_error_mat(R, R_gt)
    return t_err, r_err
def cal_error_auc(errors, thresholds): | |
sort_idx = np.argsort(errors) | |
errors = np.array(errors.copy())[sort_idx] | |
recall = (np.arange(len(errors)) + 1) / len(errors) | |
errors = np.r_[0.0, errors] | |
recall = np.r_[0.0, recall] | |
aucs = [] | |
for t in thresholds: | |
last_index = np.searchsorted(errors, t) | |
r = np.r_[recall[:last_index], recall[last_index - 1]] | |
e = np.r_[errors[:last_index], t] | |
aucs.append(np.round((np.trapz(r, x=e) / t), 4)) | |
return aucs | |
class AUCMetric:
    """Accumulate error values and report their AUC at given thresholds."""

    def __init__(self, thresholds, elements=None):
        """Create the metric.

        Args:
            thresholds: A single threshold or a list/tuple of thresholds.
            elements: Optional list of already collected error values. When
                given, this same list object is extended by ``update``.
        """
        # Default to a fresh list: keeping ``None`` made ``update`` and
        # ``compute`` raise TypeError on a metric built without elements.
        self._elements = [] if elements is None else elements
        if isinstance(thresholds, (list, tuple)):
            self.thresholds = list(thresholds)
        else:
            # a scalar threshold is wrapped so that it can be iterated over
            self.thresholds = [thresholds]

    def update(self, tensor):
        """Append the values of a 1-D tensor to the collected errors."""
        assert tensor.dim() == 1
        # detach so that tensors still attached to an autograd graph convert
        self._elements += tensor.detach().cpu().numpy().tolist()

    def compute(self):
        """Return one AUC per threshold, or NaN when nothing was collected."""
        if len(self._elements) == 0:
            return np.nan
        return cal_error_auc(self._elements, self.thresholds)