HoneyTian's picture
update
5e1cd25
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List, Tuple
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
    """Soft Dice loss over the time axis of per-frame probabilities.

    For each sample in the batch:
        loss = 1 - (2 * sum(inputs * targets) + eps) / (sum(inputs) + sum(targets) + eps)
    The per-sample losses are then reduced across the batch according to `reduction`.
    """

    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-6,
                 ):
        """
        :param reduction: "mean" or "sum" to reduce the per-sample losses over the batch,
            or "none" to return the per-sample losses unreduced (shape [b,]).
        :param eps: smoothing term added to both numerator and denominator; it keeps the
            ratio defined (dice = 1, loss = 0) when inputs and targets are both all zero.
        :raises AssertionError: if `reduction` is not one of "sum", "mean", "none".
        """
        super().__init__()
        if reduction not in ("sum", "mean", "none"):
            # AssertionError (rather than ValueError) is kept for backward compatibility
            # with callers that already catch it.
            raise AssertionError("param reduction must be sum, mean or none.")
        self.reduction = reduction
        self.eps = eps

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        :param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
        :param targets: shape as `inputs`.
        :return: scalar tensor for reduction "mean"/"sum"; tensor of shape [b,] for "none".
        :raises AssertionError: if `self.reduction` was changed to an unsupported value
            after construction.
        """
        # squeeze(dim=-1) only removes the last axis when its size is 1.
        inputs_ = torch.squeeze(inputs, dim=-1)
        targets_ = torch.squeeze(targets, dim=-1)
        # shape: [b, t]

        # "union" here is |X| + |Y| (the Dice denominator), not a set union.
        intersection = (inputs_ * targets_).sum(dim=-1)
        union = (inputs_ + targets_).sum(dim=-1)
        # shape: [b,]

        dice = (2. * intersection + self.eps) / (union + self.eps)
        loss = 1. - dice
        # shape: [b,]

        if self.reduction == "mean":
            return torch.mean(loss)
        if self.reduction == "sum":
            return torch.sum(loss)
        if self.reduction == "none":
            return loss
        # Only reachable if `reduction` was reassigned after __init__ validated it.
        raise AssertionError(f"unsupported reduction: {self.reduction!r}")
def main():
    """Smoke test: print the Dice loss of an all-zero prediction against itself."""
    criterion = DiceLoss()
    silence = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
    print(criterion.forward(silence, silence))


if __name__ == "__main__":
    # Run the smoke test only when this file is executed directly.
    main()