Spaces:
No application file
No application file
from typing import List, Optional, Union

import torch
import torch.nn as nn
class MultiLayerLoss(nn.Module):
    """Sum a wrapped loss over one or more (output, target) pairs.

    Each output/target is first passed through ``_get_feature`` (which
    subclasses must implement), the wrapped ``loss_fn`` is applied to the
    resulting pair, the result is optionally scaled by a per-layer weight,
    and all per-layer losses are added together.
    """

    def __init__(self, loss_fn: nn.Module, weights: Optional[List[float]] = None) -> None:
        """Store the wrapped loss and the optional per-layer weights.

        Args:
            loss_fn: Callable invoked as ``loss_fn(x, y)`` for every layer.
            weights: One scale factor per layer, or ``None`` to leave every
                per-layer loss unscaled.
        """
        super().__init__()
        self.weights = weights
        self.loss_fn = loss_fn

    def forward(
        self,
        output: Union[torch.Tensor, List[torch.Tensor]],
        target: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Return the (optionally weighted) sum of the per-layer losses.

        Args:
            output: A single prediction or a list of predictions, one per
                layer. Anything that is not a ``list`` is treated as a
                single layer. (The original notes describe a tensor as
                ``b * c * h * w`` -- nothing here enforces that shape.)
            target: A single target or a list of targets, matched with
                ``output`` by position.

        Returns:
            The sum over layers of ``loss_fn(feature(output), feature(target))``,
            each term multiplied by its weight when weights were given.
            For empty lists the initial integer ``0`` is returned.

        Raises:
            AssertionError: If ``output`` and ``target`` differ in length, or
                if weights were given and their count differs from the number
                of layers.
            NotImplementedError: If ``_get_feature`` has not been overridden.
        """
        if not isinstance(output, list):
            output = [output]
        if not isinstance(target, list):
            target = [target]

        assert len(output) == len(target), f"length of x({len(output)}) must be equal to target({len(target)})"
        if self.weights is not None:
            assert len(output) == len(self.weights), f"weights should be None or length of x({len(output)}) must be equal to weights({len(self.weights)})"

        total_loss = 0
        for i, (x, y) in enumerate(zip(output, target)):
            loss = self.loss_fn(self._get_feature(x), self._get_feature(y))
            if self.weights is not None:
                # Out-of-place on purpose: an in-place ``*=`` would mutate the
                # tensor handed back by ``loss_fn`` (corrupting it if it is
                # shared or cached) and fails outright on a leaf tensor that
                # requires grad.
                loss = loss * self.weights[i]
            # Out-of-place for the same reason: never modify a tensor that
            # may already be part of the autograd graph.
            total_loss = total_loss + loss
        return total_loss

    def cal_single_layer_loss(self, x, y):
        """Hook for subclasses; not called by ``forward`` in this class."""
        raise NotImplementedError

    def _get_feature(self, x):
        """Map one output/target to what is fed to ``loss_fn``; subclasses must override."""
        raise NotImplementedError