# COLE/tests/tasks/evaluation/test_evaluation_piaf.py
# Yurhu's picture
# Initial snapshot upload
# 75ec748 verified
from src.task.task_factory import Task
from tests.tasks.evaluation.task_test_case import TaskTest
class TaskPIAFTest(TaskTest):
    """Exercise ``Task.compute`` for the ``piaf`` task with the ``fquad`` metric.

    The tests feed prediction lists of three sizes relative to
    ``self.dataset_size``: shorter (scores plus a warning are expected),
    equal (scores and no warning) and longer (a ``ValueError`` is expected).
    """

    # NOTE(review): the comment originally placed here said two responses are
    # needed because "correlation fails (nan)" otherwise. No correlation is
    # visible in these tests -- confirm whether the remark still applies.
    def setUp(self) -> None:
        # Size that the tests below treat as the ground-truth count for "piaf".
        self.dataset_size = 384

    def test_given_a_prediction_smaller_than_corpus_when_compute_then_return_expected_result_and_warning(
        self,
    ):
        # Five copies of one answer: far fewer predictions than ground truths.
        answer = "ce partage ne fait pas disparaître l'idée d'un ensemble uni, le Regnum Francorum (Royaume des Francs)."
        a_predictions = [answer] * 5

        expected_warning = (
            f"Your prediction size is of '{len(a_predictions)}', while the ground truths size is "
            f"of '{self.dataset_size}'. We computed the metric over the first {len(a_predictions)}"
            f" elements."
        )
        expected_results = {"exact_match": 20.0, "f1": 20.0}

        piaf_task = Task(task_name="piaf", metric="fquad")
        scores, warning = piaf_task.compute(predictions=a_predictions)

        self.assertEvalDictEqual(expected_results, scores)
        self.assertEqual(expected_warning, warning)

    def test_given_a_prediction_when_compute_then_return_expected_result_no_warnings(
        self,
    ):
        # Exactly ``dataset_size`` predictions, so no warning is expected.
        answer = "ce partage ne fait pas disparaître l'idée d'un ensemble uni, le Regnum Francorum (Royaume des Francs)."
        a_predictions = [answer] * self.dataset_size

        expected_results = {"exact_match": 0.25, "f1": 0.49026015}

        piaf_task = Task(task_name="piaf", metric="fquad")
        scores, warning = piaf_task.compute(predictions=a_predictions)

        self.assertEvalDictEqual(expected_results, scores)
        self.assertEqual(None, warning)

    def test_given_a_prediction_larger_than_ground_truth_raise_error(self):
        # One prediction more than ``dataset_size`` must be rejected.
        answer = "ce partage ne fait pas disparaître l'idée d'un ensemble uni, le Regnum Francorum (Royaume des Francs)."
        a_predictions = [answer] * (self.dataset_size + 1)

        piaf_task = Task(task_name="piaf", metric="fquad")

        with self.assertRaises(ValueError):
            piaf_task.compute(predictions=a_predictions)