import torch
import numpy as np

# Added to every denominator so that classes without any pixels score 0
# instead of producing a division by zero.
_EPS = torch.finfo(torch.float32).eps


def calcuate_confusion_matrix(num_class: int, gt: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Build the confusion matrix of a batch of label maps.

    Both tensors are flattened, so any shape is accepted as long as ``gt``
    and ``pred`` hold the same number of elements.

    Args:
        num_class: Number of classes; valid labels are ``0 .. num_class - 1``.
        gt: Ground-truth labels.
        pred: Predicted labels. They are cast to ``torch.long`` and moved to
            the device of ``gt`` before counting.

    Returns:
        An integer tensor of shape ``(num_class, num_class)`` on the device of
        ``gt`` whose entry ``[i, j]`` counts the elements with ground truth
        ``i`` and prediction ``j``. Elements whose ground truth or prediction
        lies outside ``[0, num_class)`` are ignored.
    """
    gt_vector = gt.flatten()
    # bincount needs integer input on a single device.
    pred_vector = pred.flatten().to(device=gt_vector.device, dtype=torch.long)

    mask = (gt_vector >= 0) & (gt_vector < num_class)
    # An out-of-range prediction would otherwise be counted in another row or
    # push the flat index past num_class ** 2 and break the reshape below.
    mask &= (pred_vector >= 0) & (pred_vector < num_class)

    # Encode each (gt, pred) pair as one flat index: row = gt, column = pred.
    flat_index = num_class * gt_vector[mask].to(dtype=torch.long) + pred_vector[mask]
    return torch.bincount(flat_index, minlength=num_class ** 2).reshape(num_class, num_class)


class segmengtion_metric:
    """Accumulate a confusion matrix and derive segmentation metrics from it.

    Rows of the confusion matrix are indexed by the ground-truth class and
    columns by the predicted class.

    Attributes are documented inline in ``__init__`` and ``clear``.
    """

    def __init__(self, num_class: int, device: str):
        """Create an empty accumulator.

        Args:
            num_class: Number of classes.
            device: Device on which the accumulated matrix is stored.
        """
        self.num_class = num_class
        self.device = device
        self.clear()

    def clear(self):
        """Reset the accumulated confusion matrix to all zeros."""
        # Counts are kept as int64: a float32 accumulator stops being exact
        # once a cell exceeds 2 ** 24.
        self.confusion_matrix = torch.zeros(
            (self.num_class, self.num_class), dtype=torch.int64, device=self.device
        )

    def update_confusion_matrix(self, gt, pred):
        """Add the confusion matrix of one batch to the accumulated matrix.

        Args:
            gt: Ground-truth labels of the batch.
            pred: Predicted labels of the batch.
        """
        cm = calcuate_confusion_matrix(self.num_class, gt, pred)
        # The batch matrix lives on the device of ``gt``, which may differ
        # from the accumulator's device.
        self.confusion_matrix += cm.to(self.confusion_matrix.device)

    @staticmethod
    def _basic_scores(confusion_matrix):
        """Return (accuracy, precision, recall, iou, f1) for a confusion matrix."""
        tp = torch.diag(confusion_matrix)
        sum_a1 = torch.sum(confusion_matrix, dim=1)  # elements per ground-truth class
        sum_a0 = torch.sum(confusion_matrix, dim=0)  # elements per predicted class

        acc = tp.sum() / (confusion_matrix.sum() + _EPS)
        recall = tp / (sum_a1 + _EPS)
        precision = tp / (sum_a0 + _EPS)
        f1 = (2 * recall * precision) / (recall + precision + _EPS)
        iou = tp / (sum_a1 + sum_a0 - tp + _EPS)
        return acc, precision, recall, iou, f1

    def _per_class(self, template, values):
        """Map ``template.format(class_index)`` to the per-class value."""
        return {template.format(i): value for i, value in zip(range(self.num_class), values)}

    @staticmethod
    def _mean_of_nonzero(values):
        """Average the non-zero entries of ``values``; 0 if there are none."""
        nonzero = values[values != 0]
        if nonzero.numel() == 0:
            # The mean of an empty tensor is NaN; report 0 instead.
            return values.new_zeros(())
        return nonzero.mean()

    def get_matrix_per_batch(self, gt, pred):
        """Compute the metrics of a single batch.

        The accumulated confusion matrix is neither read nor modified.

        Args:
            gt: Ground-truth labels of the batch.
            pred: Predicted labels of the batch.

        Returns:
            A dict with the keys ``acc``, ``mean_pre``, ``mean_rec``, ``mIoU``
            and ``mF1`` followed by the per-class entries ``pre_class[i]``,
            ``rec_class[i]``, ``iou_class[i]`` and ``f1_class[i]``. The mean
            values average only the classes whose score is non-zero and are 0
            when every class scores 0.
        """
        confusion_matrix = calcuate_confusion_matrix(self.num_class, gt, pred)
        acc, precision, recall, iou, f1 = self._basic_scores(confusion_matrix)

        score_dict_batch = {
            'acc': acc,
            'mean_pre': self._mean_of_nonzero(precision),
            'mean_rec': self._mean_of_nonzero(recall),
            'mIoU': self._mean_of_nonzero(iou),
            'mF1': self._mean_of_nonzero(f1),
        }
        score_dict_batch.update(self._per_class('pre_class[{}]', precision))
        score_dict_batch.update(self._per_class('rec_class[{}]', recall))
        score_dict_batch.update(self._per_class('iou_class[{}]', iou))
        score_dict_batch.update(self._per_class('f1_class[{}]', f1))
        return score_dict_batch

    def get_metric_dict_per_epoch(self):
        """Compute the metrics of the accumulated confusion matrix.

        Returns:
            A dict with the keys ``Accuracy``, ``mean_Precision``,
            ``mean_Recall``, ``mIoU`` and ``mF1`` followed by the per-class
            entries ``Precision_Class[i]``, ``Recall_Class[i]``,
            ``IoU_Class[i]`` and ``F1_Class[i]``. The mean values average
            over all classes.
        """
        acc, precision, recall, iou, f1 = self._basic_scores(self.confusion_matrix)

        score_dict_epoch = {
            'Accuracy': acc,
            'mean_Precision': precision.mean(),
            'mean_Recall': recall.mean(),
            'mIoU': iou.mean(),
            'mF1': f1.mean(),
        }
        score_dict_epoch.update(self._per_class('Precision_Class[{}]', precision))
        score_dict_epoch.update(self._per_class('Recall_Class[{}]', recall))
        score_dict_epoch.update(self._per_class('IoU_Class[{}]', iou))
        score_dict_epoch.update(self._per_class('F1_Class[{}]', f1))
        return score_dict_epoch


if __name__ == "__main__":
    gt_label = torch.tensor([[0, 1, 2, 3, 1], [1, 2, 2, 3, 4]])
    pre_label = torch.tensor([[0, 1, 2, 3, 1], [5, 1, 2, 1, 4]])
    num_class = 6
    # Fall back to the CPU so the demo also runs on machines without a GPU.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    metric = segmengtion_metric(num_class, device)
    res = metric.get_matrix_per_batch(gt_label, pre_label)
    # Without this update the epoch metrics would be computed from an
    # all-zero matrix.
    metric.update_confusion_matrix(gt_label, pre_label)
    res1 = metric.get_metric_dict_per_epoch()
    print(res)
    print(res1)