HoneyTian's picture
first commit
9829721
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
class VadAccuracy(object):
    """
    Running element-wise accuracy for VAD (voice activity detection) outputs.

    Every call binarizes the predicted probabilities with ``threshold``,
    compares them with the gold labels and adds the outcome to two running
    counters. ``get_metric`` reports ``correct_count / total_count``.
    """

    def __init__(self, threshold: float = 0.5) -> None:
        """
        :param threshold: float. a probability strictly greater than this value is mapped to 1.0,
        everything else to 0.0.
        """
        self.threshold = threshold
        # Both counters are plain python floats (float64): exact for integer
        # counts far beyond what a float32 tensor accumulator can represent.
        self.correct_count = 0.
        self.total_count = 0.

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 ) -> None:
        """
        Accumulate the statistics of one batch.

        :param predictions: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
        :param gold_labels: torch.Tensor, shape: [b, t, 1].
        :return: None. only the internal counters are updated.
        """
        predictions = (predictions > self.threshold).float()
        # boolean mask of matching positions; summing it yields an exact integer count.
        correct = predictions.eq(gold_labels)

        # .item() converts to a python number, so the counters never turn into
        # (float32, possibly on-device) tensors that lose precision past 2**24.
        self.correct_count += float(correct.sum().item())
        # Count the elements that were actually compared. If the two inputs were
        # broadcast against each other this keeps numerator and denominator
        # consistent, so the accuracy cannot exceed 1.
        self.total_count += float(correct.numel())

    def get_metric(self, reset: bool = False) -> dict:
        """
        :param reset: bool. when True the counters are cleared after the accuracy is computed.
        :return: dict. ``{'accuracy': value}``; the value is 0.0 when nothing has been accumulated.
        """
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            # nothing accumulated yet: avoid dividing by zero.
            accuracy = 0.0
        if reset:
            self.reset()
        return {'accuracy': accuracy}

    def reset(self) -> None:
        """Clear both counters so that accumulation starts over."""
        self.correct_count = 0.0
        self.total_count = 0.0
def main():
    """Smoke test: feed all-zero probabilities and all-zero labels to VadAccuracy and print the metric."""
    shape = (1, 198, 1)
    probs = torch.zeros(size=shape, dtype=torch.float32)
    labels = torch.zeros(size=shape, dtype=torch.float32)

    accuracy_metric = VadAccuracy()
    # accumulate a single batch, then read the result back.
    accuracy_metric(probs, labels)
    print(accuracy_metric.get_metric())
# run the demo only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()