HoneyTian's picture
update
5e1cd25
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List, Tuple
import torch
import torch.nn as nn
from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
class BCELoss(BaseVadLoss):
    """
    Binary Cross-Entropy Loss, BCE Loss.

    Thin wrapper around ``torch.nn.BCELoss`` for VAD outputs. The inputs are
    expected to already be probabilities (sigmoid applied upstream), since
    ``nn.BCELoss`` does not apply a sigmoid itself.
    """
    def __init__(self,
                 reduction: str = "mean",
                 ):
        """
        :param reduction: reduction mode forwarded unchanged to ``torch.nn.BCELoss``
            ("none", "mean" or "sum").
        """
        super(BCELoss, self).__init__()
        self.reduction = reduction
        self.bce_loss_fn = nn.BCELoss(reduction=reduction)

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        :param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
        :param targets: shape as `inputs`.
        :return: the BCE loss as computed by ``torch.nn.BCELoss`` with the configured
            ``reduction``: a scalar tensor for "mean"/"sum", or an element-wise loss
            tensor shaped like ``inputs`` for "none".
        """
        # Invoke the sub-module through __call__ (not .forward()) so that any
        # hooks registered on bce_loss_fn are honored.
        loss = self.bce_loss_fn(inputs, targets)
        return loss
def main():
    """
    Smoke test: compute the BCE loss of an all-zero probability tensor of
    shape [1, 198, 1] against itself and print the result.
    """
    inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
    loss_fn = BCELoss()
    # Call the module directly rather than .forward() so nn.Module hooks run.
    loss = loss_fn(inputs, inputs)
    print(loss)
# Run the smoke test only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()