import torch


def calculate_confusion_matrix(num_class: int, gt: torch.Tensor, pred: torch.Tensor):
    """Compute a (num_class x num_class) confusion matrix with rows = ground truth, columns = prediction."""
    gt_vector = gt.flatten()
    pred_vector = pred.flatten()
    # Keep only pixels whose ground-truth label is a valid class index.
    mask = (gt_vector >= 0) & (gt_vector < num_class)
    # Encode each (gt, pred) pair as a single index, count occurrences with bincount,
    # and reshape the counts back into a square matrix.
    cm = torch.bincount(
        num_class * gt_vector[mask].to(dtype=torch.long) + pred_vector[mask].to(dtype=torch.long),
        minlength=num_class ** 2,
    ).reshape(num_class, num_class)
    return cm


class SegmentationMetric(object):
    def __init__(self, num_class: int, device: str):
        self.num_class = num_class
        self.device = device
        self.confusion_matrix = torch.zeros((self.num_class, self.num_class)).to(self.device)

    def clear(self):
        """Reset the accumulated confusion matrix (e.g. at the start of each epoch)."""
        self.confusion_matrix = torch.zeros((self.num_class, self.num_class)).to(self.device)

    def update_confusion_matrix(self, gt, pred):
        cm = calculate_confusion_matrix(self.num_class, gt, pred)
        # Move the batch matrix onto the metric's device so accumulation works regardless of where gt/pred live.
        self.confusion_matrix += cm.to(self.device)

    def get_matrix_per_batch(self, gt, pred):
        """Compute scores for a single batch without touching the accumulated confusion matrix."""
        confusion_matrix = calculate_confusion_matrix(self.num_class, gt, pred)
        eps = torch.finfo(torch.float32).eps
        tp = torch.diag(confusion_matrix)
        sum_a1 = torch.sum(confusion_matrix, dim=1)  # ground-truth pixels per class
        sum_a0 = torch.sum(confusion_matrix, dim=0)  # predicted pixels per class
        acc = tp.sum() / (confusion_matrix.sum() + eps)
        recall = tp / (sum_a1 + eps)
        precision = tp / (sum_a0 + eps)
        f1 = (2 * recall * precision) / (recall + precision + eps)
        iou = tp / (sum_a1 + sum_a0 - tp + eps)
        cls_precision = dict(zip(['pre_class[{}]'.format(i) for i in range(self.num_class)], precision))
        cls_recall = dict(zip(['rec_class[{}]'.format(i) for i in range(self.num_class)], recall))
        cls_f1 = dict(zip(['f1_class[{}]'.format(i) for i in range(self.num_class)], f1))
        cls_iou = dict(zip(['iou_class[{}]'.format(i) for i in range(self.num_class)], iou))
        # Per-batch means ignore classes whose score is exactly zero (typically classes absent from the batch).
        mean_precision = precision[precision != 0].mean()
        mean_recall = recall[recall != 0].mean()
        mean_iou = iou[iou != 0].mean()
        mean_f1 = f1[f1 != 0].mean()
        score_dict_batch = {'acc': acc, 'mean_pre': mean_precision, 'mean_rec': mean_recall, 'mIoU': mean_iou, 'mF1': mean_f1}
        score_dict_batch.update(cls_precision)
        score_dict_batch.update(cls_recall)
        score_dict_batch.update(cls_iou)
        score_dict_batch.update(cls_f1)
        return score_dict_batch

    def get_metric_dict_per_epoch(self):
        """Compute scores from the confusion matrix accumulated over all update_confusion_matrix calls."""
        eps = torch.finfo(torch.float32).eps
        tp = torch.diag(self.confusion_matrix)
        sum_a1 = torch.sum(self.confusion_matrix, dim=1)  # ground-truth pixels per class
        sum_a0 = torch.sum(self.confusion_matrix, dim=0)  # predicted pixels per class
        acc = tp.sum() / (self.confusion_matrix.sum() + eps)
        recall = tp / (sum_a1 + eps)
        precision = tp / (sum_a0 + eps)
        f1 = (2 * recall * precision) / (recall + precision + eps)
        iou = tp / (sum_a1 + sum_a0 - tp + eps)
        cls_precision = dict(zip(['Precision_Class[{}]'.format(i) for i in range(self.num_class)], precision))
        cls_recall = dict(zip(['Recall_Class[{}]'.format(i) for i in range(self.num_class)], recall))
        cls_iou = dict(zip(['IoU_Class[{}]'.format(i) for i in range(self.num_class)], iou))
        cls_f1 = dict(zip(['F1_Class[{}]'.format(i) for i in range(self.num_class)], f1))
        # Epoch-level means average over all classes, including those never seen.
        mean_precision = precision.mean()
        mean_recall = recall.mean()
        mean_iou = iou.mean()
        mean_f1 = f1.mean()
        score_dict_epoch = {'Accuracy': acc, 'mean_Precision': mean_precision, 'mean_Recall': mean_recall,
                            'mIoU': mean_iou, 'mF1': mean_f1}
        score_dict_epoch.update(cls_precision)
        score_dict_epoch.update(cls_recall)
        score_dict_epoch.update(cls_iou)
        score_dict_epoch.update(cls_f1)
        return score_dict_epoch


if __name__ == "__main__":
    gt_label = torch.tensor([[0, 1, 2, 3, 1],
                             [1, 2, 2, 3, 4]])
    pre_label = torch.tensor([[0, 1, 2, 3, 1],
                              [5, 1, 2, 1, 4]])
    num_class = 6
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    metric = SegmentationMetric(num_class, device)
    # Per-batch scores do not touch the accumulated matrix, so update it explicitly before the epoch report.
    res = metric.get_matrix_per_batch(gt_label, pre_label)
    metric.update_confusion_matrix(gt_label, pre_label)
    res1 = metric.get_metric_dict_per_epoch()
    print(res)
    print(res1)
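    # For reference (hand-derived from the toy labels above, so treat it as a sanity-check sketch):
    # the confusion matrix for this batch (rows = ground truth, columns = prediction) is
    #   [[1, 0, 0, 0, 0, 0],
    #    [0, 2, 0, 0, 0, 1],
    #    [0, 1, 2, 0, 0, 0],
    #    [0, 1, 0, 1, 0, 0],
    #    [0, 0, 0, 0, 1, 0],
    #    [0, 0, 0, 0, 0, 0]]
    # giving an overall pixel accuracy of 7 / 10 = 0.7.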