File size: 1,556 Bytes
9829721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
class VadAccuracy(object):
    """Running element-wise accuracy for voice-activity-detection outputs.

    Each call binarizes ``predictions`` with ``threshold`` and accumulates how
    many elements equal ``gold_labels``; ``get_metric`` reports the ratio of
    matching elements to all elements seen since the last reset.
    """

    def __init__(self, threshold: float = 0.5) -> None:
        """
        :param threshold: a prediction strictly greater than this value is
            treated as 1.0, anything else as 0.0.
        """
        self.threshold = threshold
        # Both counters are kept as plain Python floats (never tensors) so the
        # accumulator does not hold on to tensor/device memory between calls.
        self.correct_count = 0.
        self.total_count = 0.

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 ) -> None:
        """
        Accumulate the matching/total element counts of one batch.

        :param predictions: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
        :param gold_labels: torch.Tensor, shape: [b, t, 1]. Expected to hold 0/1
            values, since they are compared for equality with the binarized
            predictions.
        :raises ValueError: if ``predictions`` and ``gold_labels`` differ in
            shape. Without this check ``eq`` would broadcast, counting more
            comparisons than ``gold_labels.numel()`` and reporting an accuracy
            that can exceed 1.
        """
        if predictions.shape != gold_labels.shape:
            raise ValueError(
                "predictions and gold_labels must have the same shape, "
                f"got {tuple(predictions.shape)} and {tuple(gold_labels.shape)}"
            )
        predictions = (predictions > self.threshold).float()
        correct = predictions.eq(gold_labels).float()
        # .item() converts the 0-dim tensor to a Python number so the counter
        # stays a float instead of silently becoming a tensor.
        self.correct_count += correct.sum().item()
        self.total_count += gold_labels.numel()

    def get_metric(self, reset: bool = False):
        """
        Returns
        -------
        The accumulated accuracy as ``{'accuracy': float}``; 0.0 if nothing
        has been accumulated yet. When ``reset`` is True the counters are
        cleared after the value is computed.
        """
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            accuracy = 0.0
        if reset:
            self.reset()
        return {'accuracy': accuracy}

    def reset(self):
        """Clear the accumulated counts."""
        self.correct_count = 0.0
        self.total_count = 0.0
def main():
    """Smoke test: score all-zero predictions against all-zero labels and print the metric."""
    inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
    targets = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
    metric_fn = VadAccuracy()
    # Invoke the metric object directly instead of spelling out __call__.
    metric_fn(inputs, targets)
    metrics = metric_fn.get_metric()
    print(metrics)
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    main()
|