Spaces:
Running
on
L4
Running
on
L4
File size: 1,438 Bytes
c1ce505 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import numpy as np
from .utils import *
def chamfer_loss(x, y):
    """Symmetric Chamfer distance between two point sets.

    Args:
        x: Points of shape ``(..., N, D)``.
        y: Points of shape ``(..., M, D)``; any leading batch dimensions
            must be compatible as required by ``torch.cdist``.

    Returns:
        0-d tensor: the mean distance from each point of ``y`` to its
        nearest point in ``x`` plus the mean distance from each point of
        ``x`` to its nearest point in ``y`` (means also run over any batch
        dimensions).
    """
    d = torch.cdist(x, y)  # (..., N, M) pairwise Euclidean distances
    # Reduce over the two point axes (-2 / -1) rather than 0 / 1 so that a
    # leading batch dimension is not mistaken for a point axis. For 2-D
    # inputs this is exactly the original dim=0 / dim=1 reduction.
    return d.min(dim=-2).values.mean() + d.min(dim=-1).values.mean()
def continuity_loss(x):
    """Mean Euclidean distance between consecutive points.

    Args:
        x: Points indexed along dim 0, with coordinates on the last dim
            (e.g. shape ``(N, D)``).

    Returns:
        0-d tensor with the mean L2 distance between ``x[k]`` and
        ``x[k + 1]``. With fewer than two points there are no consecutive
        pairs, so a zero tensor is returned instead of the NaN that
        ``mean()`` of an empty tensor would produce.
    """
    if len(x) < 2:
        return x.new_zeros(())
    d = (x[1:] - x[:-1]).norm(dim=-1, p=2)
    return d.mean()
def svg_length_loss(p_pred, p_target):
    """Relative absolute error between predicted and target path lengths.

    Both lengths come from ``get_length`` (defined outside this block,
    presumably in ``.utils``). The result is
    ``|L_target - L_pred| / L_target``.

    NOTE(review): a zero target length makes this divide by zero.
    """
    pred_length = get_length(p_pred)
    target_length = get_length(p_target)
    abs_error = (target_length - pred_length).abs()
    return abs_error / target_length
def _best_rotation(p_pred, p_target_sub):
    """Return the shift k in ``range(len(p_pred))`` minimising the mean
    point distance between ``p_pred`` and ``reorder(p_target_sub, k)``.

    Only the index is used by the caller, so the search runs without
    autograd and keeps its costs as a torch tensor; this avoids converting
    (possibly CUDA or grad-requiring) tensors to NumPy.
    """
    with torch.no_grad():
        costs = torch.stack([
            torch.norm(p_pred - reorder(p_target_sub, k), dim=-1).mean()
            for k in range(len(p_pred))
        ])
    # torch.argmin returns the first minimal index on ties, like np.argmin.
    return int(costs.argmin())


def svg_emd_loss(p_pred, p_target,
                 first_point_weight=False, return_matched_indices=False):
    """EMD-style loss between a predicted and a target point sequence.

    The target is made clockwise with ``make_clockwise`` and resampled to
    ``len(p_pred)`` points by matching ``n`` evenly spaced values in
    [0, 1] against ``get_length_distribution(p_target, normalize=True)``.
    The resampled target is then re-indexed with ``reorder`` (presumably a
    cyclic shift -- confirm in ``.utils``) using the shift that minimises
    the mean distance to ``p_pred``, and the per-point distances are
    averaged.

    Args:
        p_pred: Predicted points indexed along dim 0.
        p_target: Target points indexed along dim 0.
        first_point_weight: If True, the first point's distance is
            multiplied by 10 before averaging.
        return_matched_indices: If True, also return
            ``(p_pred, clockwise p_target, reorder(matching, shift))``.

    Returns:
        The loss as a 0-d tensor, or the float ``0.`` when ``p_pred`` is
        empty. With ``return_matched_indices`` the result is
        ``(loss, (p_pred, p_target, indices))``.
        NOTE(review): the empty case returns bare ``0.`` even when
        ``return_matched_indices`` is True.
    """
    n = len(p_pred)
    if n == 0:
        return 0.

    # Make target point lists clockwise
    p_target = make_clockwise(p_target)

    # Resample the target: for each of n evenly spaced values in [0, 1],
    # take the target point whose normalized length-distribution value is
    # closest.
    distr_pred = torch.linspace(0., 1., n).to(p_pred.device)
    distr_target = get_length_distribution(p_target, normalize=True)
    d = torch.cdist(distr_pred.unsqueeze(-1), distr_target.unsqueeze(-1))
    matching = d.argmin(dim=-1)
    p_target_sub = p_target[matching]

    # EMD: align the resampled target with the best shift, then recompute
    # the distances with autograd so gradients flow to p_pred.
    shift = _best_rotation(p_pred, p_target_sub)
    losses = torch.norm(p_pred - reorder(p_target_sub, shift), dim=-1)
    if first_point_weight:
        weights = torch.ones_like(losses)
        weights[0] = 10.
        losses = losses * weights
    if return_matched_indices:
        return losses.mean(), (p_pred, p_target, reorder(matching, shift))
    return losses.mean()
|