|
|
|
|
|
import torch |
|
|
|
|
|
class VadAccuracy:
    """Streaming frame-level accuracy for binary voice-activity detection.

    Predicted probabilities are binarised with ``threshold`` and compared
    element-wise with the gold labels. Counts accumulate across calls until
    ``reset()`` (or ``get_metric(reset=True)``) clears them.
    """

    def __init__(self, threshold: float = 0.5) -> None:
        # A frame is predicted active (1.) when its probability is strictly
        # greater than this value, otherwise inactive (0.).
        self.threshold = threshold

        # Kept as Python floats (double precision) rather than tensors so that
        # long evaluations do not lose counts to float32 rounding and no device
        # tensor is held between calls.
        self.correct_count = 0.
        self.total_count = 0.

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 ) -> None:
        """Update the running counts with one batch.

        :param predictions: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
        :param gold_labels: torch.Tensor, same shape as ``predictions``. A frame
            counts as correct when its label equals the thresholded prediction
            (0. or 1.).
        :raises ValueError: if ``predictions`` and ``gold_labels`` differ in shape.
        """
        if predictions.shape != gold_labels.shape:
            # Without this check broadcasting (e.g. [b, t, 1] vs [b, t]) would
            # compare every frame with every other frame and inflate the count.
            raise ValueError(
                f"predictions shape {tuple(predictions.shape)} does not match "
                f"gold_labels shape {tuple(gold_labels.shape)}"
            )

        predictions = (predictions > self.threshold).float()
        correct = predictions.eq(gold_labels)

        # .item() converts the batch count to a Python int; accumulating tensors
        # here would keep the total in float32 and drop increments past 2**24.
        self.correct_count += correct.sum().item()
        self.total_count += gold_labels.numel()

    def get_metric(self, reset: bool = False) -> dict:
        """Return the accumulated accuracy.

        :param reset: if True, clear the accumulated counts after reading them.
        :return: ``{'accuracy': correct / total}``, or ``{'accuracy': 0.0}``
            when nothing has been accumulated yet.
        """
        if self.total_count > 1e-12:
            accuracy = float(self.correct_count) / float(self.total_count)
        else:
            accuracy = 0.0

        if reset:
            self.reset()

        return {'accuracy': accuracy}

    def reset(self) -> None:
        """Clear the accumulated counts."""
        self.correct_count = 0.0
        self.total_count = 0.0
|
|
|
|
|
def main():
    """Smoke-test VadAccuracy on an all-zero batch and print the result.

    Predictions and targets are both zeros. Every frame is thresholded to 0.
    and matches its label, so the printed accuracy is 1.0.
    """
    inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
    targets = torch.zeros(size=(1, 198, 1), dtype=torch.float32)

    metric_fn = VadAccuracy()

    # Invoke the metric through the normal call syntax rather than __call__.
    metric_fn(inputs, targets)

    metrics = metric_fn.get_metric()
    print(metrics)
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not on import.
    main()
|
|