def mean_3class_f1(predictions, references):
    """Map one string prediction/reference pair to integer class ids.

    This is a per-example passthrough: no score is computed here. The
    returned pair has the shape that ``agg_mean_3class_f1`` unpacks with
    ``zip(*items)`` when it computes the macro F1.

    Args:
        predictions: Sequence whose first element is the predicted label
            string. Only ``predictions[0]`` is read.
        references: Sequence whose first element is the gold label string.
            Only ``references[0]`` is read.

    Returns:
        A ``(prediction_id, reference_id)`` tuple of ints, where the ids are
        0 for "entailment", 1 for "contradiction" and 2 for "neutral".

    Raises:
        ValueError: If ``references[0]`` is not one of the three labels.
        IndexError: If either sequence is empty.
    """
    labels = ("entailment", "contradiction", "neutral")

    # A prediction outside the label set is deliberately mapped to class 0
    # instead of raising, so every example still yields a usable class id.
    try:
        prediction_id = labels.index(predictions[0])
    except ValueError:
        prediction_id = 0

    # The gold label, by contrast, must be valid: fail loudly and name it.
    try:
        reference_id = labels.index(references[0])
    except ValueError:
        raise ValueError(
            f"unknown reference label {references[0]!r}; "
            f"expected one of {labels}"
        ) from None

    return (prediction_id, reference_id)
def agg_mean_3class_f1(items):
    """Compute the unweighted (macro) average of the per-class F1 scores.

    Args:
        items: Iterable of ``(prediction, reference)`` pairs. The scores are
            computed over the class labels 0, 1 and 2.

    Returns:
        The value returned by ``sklearn.metrics.fbeta_score`` called with
        ``beta=1``, ``labels=range(3)`` and ``average="macro"``.

    Raises:
        ValueError: If ``items`` is empty, or if its elements are not pairs.
    """
    pairs = list(items)
    if not pairs:
        # zip(*[]) would otherwise fail with an opaque unpacking error.
        raise ValueError(
            "agg_mean_3class_f1() needs at least one "
            "(prediction, reference) pair"
        )
    predictions, references = zip(*pairs)

    # Function-scope import: scikit-learn is only loaded when the metric is
    # actually computed, not when this module is imported.
    import sklearn.metrics

    # scikit-learn expects the ground truth first, then the predictions.
    return sklearn.metrics.fbeta_score(
        references,
        predictions,
        beta=1,
        labels=range(3),
        average="macro",
    )