File size: 1,025 Bytes
5e1cd25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List, Tuple
import torch
import torch.nn as nn
from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
class BCELoss(BaseVadLoss):
    """
    Binary Cross-Entropy Loss, BCE Loss.

    Thin wrapper around ``torch.nn.BCELoss`` for VAD outputs. ``inputs`` must
    already be probabilities (i.e. after a sigmoid activation); for raw logits
    ``nn.BCEWithLogitsLoss`` would be the appropriate loss instead.
    """

    def __init__(self,
                 reduction: str = "mean",
                 ):
        """
        :param reduction: how the element-wise losses are reduced; forwarded
            unchanged to ``nn.BCELoss`` ("none", "mean" or "sum").
        """
        super(BCELoss, self).__init__()
        self.reduction = reduction
        self.bce_loss_fn = nn.BCELoss(reduction=reduction)

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        :param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
        :param targets: shape as `inputs`.
        :return: the BCE loss; a scalar tensor for "mean"/"sum" reduction, or
            an element-wise tensor shaped like `inputs` for "none".
        """
        # Invoke the module itself (not ``.forward``) so that any hooks
        # registered on the inner loss module are honoured.
        loss = self.bce_loss_fn(inputs, targets)
        return loss
def main():
    """Smoke test: evaluate the loss on an all-zero [1, 198, 1] batch and print it."""
    inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
    loss_fn = BCELoss()
    # Predictions equal the targets (all zeros), so the printed loss is
    # expected to be 0.
    # Call the module rather than ``.forward`` so registered hooks would run.
    loss = loss_fn(inputs, inputs)
    print(loss)
    return


if __name__ == "__main__":
    main()
|