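"""Training loss for DeepSVG-style SVG sequence models."""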
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepsvg.difflib.tensor import SVGTensor
from .utils import _get_padding_mask, _get_visibility_mask
from .config import _DefaultConfig

class SVGLoss(nn.Module):
    """Sequence-level loss: command classification, argument regression or
    classification, optional two-stage visibility, and an optional KL term."""

    def __init__(self, cfg: _DefaultConfig):
        super().__init__()
        self.cfg = cfg
        # Relative targets need twice the coordinate classes (signed offsets);
        # otherwise one extra class absorbs the +1 shift of the -1 PAD_VAL.
        self.args_dim = 2 * cfg.args_dim if cfg.rel_targets else cfg.args_dim + 1
        # Per-command mask of which argument slots are meaningful.
        self.register_buffer("cmd_args_mask", SVGTensor.CMD_ARGS_MASK)
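        # E.g., with cfg.args_dim == 256 and rel_targets disabled, args_logits
        # need 257 classes: the shifted PAD value maps to class 0 and a
        # quantized coordinate v maps to class v + 1 (numbers illustrative).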

    def forward(self, output, labels, weights):
        # `labels` is accepted for interface compatibility but unused here.
        loss = 0.
        res = {}
        # VAE: KL divergence between the approximate posterior N(mu, sigma^2)
        # and the unit Gaussian prior ("logsigma" is used as the log-variance).
        # The clamp keeps the term from dropping below kl_tolerance
        # (free-bits style), so the posterior is not rewarded for collapsing.
        if self.cfg.use_vae:
            mu, logsigma = output["mu"], output["logsigma"]
            loss_kl = -0.5 * torch.mean(1 + logsigma - mu.pow(2) - torch.exp(logsigma))
            loss_kl = loss_kl.clamp(min=weights["kl_tolerance"])
loss += weights["loss_kl_weight"] * loss_kl
res["loss_kl"] = loss_kl
        # Commitment loss removed (VQ-VAE path disabled):
# if self.cfg.use_vqvae:
# vqvae_loss = output["vqvae_loss"].mean()
# loss += vqvae_loss
# res["vqvae_loss"] = vqvae_loss
# Target & predictions
# tgt_commands.shape [batch_size, max_num_groups, max_seq_len + 2]
# tgt_args.shape [batch_size, max_num_groups, max_seq_len + 2, n_args]
tgt_commands, tgt_args = output["tgt_commands"], output["tgt_args"]
        # visibility_mask: 1 for path groups that contain real commands,
        # 0 for all-padding groups.
        visibility_mask = _get_visibility_mask(tgt_commands, seq_dim=-1)
        # padding_mask: 1 at non-PAD timesteps, restricted to visible groups.
        padding_mask = _get_padding_mask(tgt_commands, seq_dim=-1, extended=True) * visibility_mask.unsqueeze(-1)
command_logits, args_logits = output["command_logits"], output["args_logits"]
        # Two-stage decoding: the first stage predicts per-group visibility as
        # a binary classification supervised by visibility_mask.
if self.cfg.decode_stages == 2:
visibility_logits = output["visibility_logits"]
loss_visibility = F.cross_entropy(visibility_logits.reshape(-1, 2), visibility_mask.reshape(-1).long())
loss += weights["loss_visibility_weight"] * loss_visibility
res["loss_visibility"] = loss_visibility
        # Commands & args: drop the first timestep so the targets align with
        # the next-token predictions.
        if self.cfg.bin_targets:  # with bin_targets every coordinate is an 8-bit vector, hence one extra dimension
            tgt_args = tgt_args[..., 1:, :, :]
        else:
            tgt_args = tgt_args[..., 1:, :]
tgt_commands, padding_mask = tgt_commands[..., 1:], padding_mask[..., 1:]
        # mask.shape [batch_size, 8, 31, 11]
        # Alternative (disabled): additionally require the command itself to be
        # predicted correctly. Multiplying by True leaves the cmd_args_mask row
        # unchanged; multiplying by False zeroes it, so the args of mispredicted
        # commands would not be counted in the loss.
        # pred_cmd = torch.argmax(command_logits, dim=-1)
        # mask = self.cmd_args_mask[tgt_commands.long()] * (pred_cmd == tgt_commands).unsqueeze(-1)
        mask = self.cmd_args_mask[tgt_commands.long()]
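        # E.g., if tgt_commands[b, g, t] were command index 4, mask[b, g, t]
        # would be row CMD_ARGS_MASK[4], keeping only the argument slots that
        # command actually uses (the index 4 is purely illustrative).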
        # padding_mask.shape   [batch_size, num_path, num_commands + 1]
        # command_logits.shape [batch_size, num_path, num_commands + 1, n_commands]
        # Boolean indexing with padding_mask filters out the PAD positions,
        # leaving a flat [-1, n_commands] tensor for the cross-entropy.
        loss_cmd = F.cross_entropy(
            command_logits[padding_mask.bool()].reshape(-1, self.cfg.n_commands),
            tgt_commands[padding_mask.bool()].reshape(-1).long()
        )
        if self.cfg.abs_targets or self.cfg.bin_targets:
            # Continuous (absolute) or bitwise coordinates: L2 regression,
            # which performs better than L1 loss here.
            loss_args = F.mse_loss(
                args_logits[mask.bool()].reshape(-1),
                tgt_args[mask.bool()].reshape(-1).float()
            )
        else:
            # Quantized coordinates: classification over self.args_dim classes,
            # with targets shifted by +1 because PAD_VAL is -1.
            loss_args = F.cross_entropy(
                args_logits[mask.bool()].reshape(-1, self.args_dim),
                tgt_args[mask.bool()].reshape(-1).long() + 1
            )
loss += weights["loss_cmd_weight"] * loss_cmd \
+ weights["loss_args_weight"] * loss_args
res.update({
"loss": loss,
"loss_cmd": loss_cmd,
"loss_args": loss_args
})
return res
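

# --- Usage sketch (illustrative only, not part of this module) ---
# A minimal sketch of driving SVGLoss, assuming a model whose forward pass
# returns the dict consumed above (command_logits, args_logits, tgt_commands,
# tgt_args, plus visibility_logits / mu / logsigma depending on cfg). The
# weight values and the `model(batch)` call are hypothetical placeholders.
#
#     cfg = _DefaultConfig()
#     criterion = SVGLoss(cfg)
#     output = model(batch)  # hypothetical model producing the dict above
#     weights = {
#         "loss_cmd_weight": 1.0,
#         "loss_args_weight": 2.0,
#         "loss_visibility_weight": 1.0,  # read when cfg.decode_stages == 2
#         "loss_kl_weight": 0.01,         # read when cfg.use_vae
#         "kl_tolerance": 0.1,            # read when cfg.use_vae
#     }
#     res = criterion(output, labels=None, weights=weights)  # labels is unused
#     res["loss"].backward()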