#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch


class VadAccuracy(object):
    """Streaming element-wise accuracy for voice-activity-detection (VAD) outputs.

    Each call binarizes the predicted probabilities with ``threshold``, compares
    them element-wise with the gold labels and accumulates the number of matching
    elements and the total number of elements. ``get_metric`` returns the running
    ratio ``correct / total``.
    """

    def __init__(self, threshold: float = 0.5) -> None:
        """
        :param threshold: probabilities strictly greater than this value become
            1.0, everything else becomes 0.0.
        """
        self.threshold = threshold
        # Kept as plain Python floats (not tensors) so the metric does not keep
        # device memory alive between batches.
        self.correct_count = 0.
        self.total_count = 0.

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 ) -> None:
        """Accumulate accuracy statistics for one batch.

        :param predictions: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
        :param gold_labels: torch.Tensor, same shape as ``predictions``. Since
            predictions are binarized to 0.0/1.0 before comparison, only labels
            equal to 0 or 1 can ever count as correct.
        :raises ValueError: if ``predictions`` and ``gold_labels`` differ in shape.
            Without this check, e.g. [b, t, 1] vs [b, t] would silently broadcast
            to [b, t, t], and the accuracy could exceed 1.
        """
        if predictions.shape != gold_labels.shape:
            raise ValueError(
                f"predictions and gold_labels must have the same shape, "
                f"got {tuple(predictions.shape)} and {tuple(gold_labels.shape)}"
            )
        predictions = (predictions > self.threshold).float()
        correct = predictions.eq(gold_labels).float()
        # .item() turns the 0-dim tensor into a Python number, so the running
        # count stays a float and holds no reference to device memory.
        self.correct_count += correct.sum().item()
        self.total_count += gold_labels.numel()

    def get_metric(self, reset: bool = False) -> dict:
        """
        Returns
        -------
        The accumulated accuracy as ``{'accuracy': float}``; 0.0 if nothing has
        been accumulated yet. If ``reset`` is True, the counters are cleared
        after the value is computed.
        """
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            accuracy = 0.0
        if reset:
            self.reset()
        return {'accuracy': accuracy}

    def reset(self) -> None:
        """Clear the accumulated counters."""
        self.correct_count = 0.0
        self.total_count = 0.0


def main():
    """Smoke test: all-zero probabilities against all-zero labels -> accuracy 1.0."""
    inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
    targets = torch.zeros(size=(1, 198, 1), dtype=torch.float32)

    metric_fn = VadAccuracy()
    metric_fn(inputs, targets)
    metrics = metric_fn.get_metric()
    print(metrics)
    return


if __name__ == "__main__":
    main()