from abc import ABC, abstractmethod

import torch
import torch.distributed as dist
from torch import Tensor


class Metric(ABC):
    """Base metric with cross-process synchronization, similar to TorchMetrics."""

    def __init__(self):
        # Maps each state name to its default value, used to restore it on reset().
        self.states = {}

    def add_state(self, name: str, default: Tensor):
        assert name not in self.states, f"state '{name}' already registered"
        # Keep a pristine copy of the default and expose a working copy as an attribute.
        self.states[name] = default.clone()
        setattr(self, name, default.clone())

    def synchronize(self):
        # Sum every state tensor across processes so compute() sees global totals.
        if dist.is_initialized():
            for state in self.states:
                dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)

    def __call__(self, *args, **kwargs):
        # Allow the metric to be called directly: metric(preds, targets).
        self.update(*args, **kwargs)

    def reset(self):
        # Restore every state to its registered default.
        for name, default in self.states.items():
            setattr(self, name, default.clone())

    def compute(self):
        self.synchronize()
        value = self._compute().item()
        self.reset()
        return value

    @abstractmethod
    def _compute(self):
        """Combine the accumulated states into the final metric value."""

    @abstractmethod
    def update(self, preds: Tensor, targets: Tensor):
        """Accumulate statistics from a batch of predictions and targets."""


class MeanAbsoluteError(Metric):
    def __init__(self):
        super().__init__()
        # States live on the GPU so that all_reduce works with the NCCL backend.
        self.add_state('error', torch.tensor(0.0, dtype=torch.float32, device='cuda'))
        self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))

    def update(self, preds: Tensor, targets: Tensor):
        # Detach so the metric never keeps the autograd graph alive.
        preds = preds.detach()
        targets = targets.detach()
        n = preds.shape[0]
        error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
        self.total += n
        self.error += error

    def _compute(self):
        return self.error / self.total
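
As a quick usage sketch (an illustration added here, not part of the original listing): the loop below feeds random CUDA tensors through the metric and reads the epoch value with compute(). It assumes CUDA is available; in a multi-GPU run launched with torchrun, torch.distributed must already be initialized (e.g. via init_process_group) so that compute() sums error and total across ranks before dividing.

# Minimal usage sketch (assumes CUDA is available; random tensors stand in for real data).
mae = MeanAbsoluteError()

for _ in range(10):                              # stand-in for an evaluation loop
    preds = torch.rand(32, 1, device='cuda')
    targets = torch.rand(32, 1, device='cuda')
    mae(preds, targets)                          # __call__ forwards to update()

epoch_mae = mae.compute()                        # synchronizes, returns a Python float, then resets
print(f"epoch MAE: {epoch_mae:.4f}")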
|
|