#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List, Tuple

import torch
import torch.nn as nn

from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss


class BCELoss(BaseVadLoss):
    """
    Binary Cross-Entropy Loss (BCE Loss) for VAD probabilities.

    Thin wrapper around `torch.nn.BCELoss`. The inputs must already be
    probabilities (i.e. after a sigmoid), not raw logits.
    """
    def __init__(self,
                 reduction: str = "mean",
                 ):
        """
        :param reduction: str, forwarded unchanged to `nn.BCELoss`
            ("none", "mean" or "sum").
        """
        super().__init__()
        self.reduction = reduction
        self.bce_loss_fn = nn.BCELoss(reduction=reduction)

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        :param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
        :param targets: shape as `inputs`. Non-floating labels (e.g. a bool or
            int 0/1 mask) are accepted and cast to `inputs.dtype`.
        :return: torch.Tensor, the BCE loss as produced by `nn.BCELoss` with the
            configured `reduction` (a scalar for "mean"/"sum", element-wise
            for "none").
        """
        # nn.BCELoss requires targets to share the input's floating dtype;
        # `.to` returns the same tensor when the dtype already matches.
        targets = targets.to(dtype=inputs.dtype)
        # Invoke the module itself (not `.forward`) so any registered hooks run.
        loss = self.bce_loss_fn(inputs, targets)
        return loss


def main():
    # Smoke test: identical all-zero predictions and targets.
    inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)

    loss_fn = BCELoss()

    loss = loss_fn.forward(inputs, inputs)
    print(loss)
    return


if __name__ == "__main__":
    main()