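# Tests for the FQuAD evaluation task: verifies exact-match/F1 results, the warning
# emitted when fewer predictions than ground truths are given, and the error raised
# when more predictions than ground truths are given.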
from src.task.task_factory import Task
from tests.tasks.evaluation.task_test_case import TaskTest


class TaskFQUADTest(TaskTest):
    def setUp(self) -> None:
        self.dataset_size = 400

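    # Fewer predictions (5) than ground truths (400): the metric is computed over the
    # first len(predictions) elements and a warning describing the mismatch is returned.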
    def test_given_a_prediction_smaller_than_corpus_when_compute_then_return_expected_result_and_warning(
        self,
    ):
        a_predictions = [
            "par un mauvais état de santé",
            "par un mauvais état de santé",
            "par un mauvais état de santé",
            "par un mauvais état de santé",
            "par un mauvais état de santé",
        ]
        task = Task(
            task_name="fquad",
            metric="fquad",
        )

        expected_results = {"exact_match": 20.0, "f1": 25.333333}
        expected_warning = (
            f"Your prediction size is of '{len(a_predictions)}', while the ground truths size is "
            f"of '{self.dataset_size}'. We computed the metric over the first {len(a_predictions)}"
            f" elements."
        )

        actual_result, actual_warning = task.compute(predictions=a_predictions)

        self.assertEvalDictEqual(expected_results, actual_result)
        self.assertEqual(expected_warning, actual_warning)

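    # One prediction per ground-truth example: results are returned with no warning.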
    def test_given_a_prediction_when_compute_then_return_expected_result_no_warnings(
        self,
    ):
        a_predictions = ["par un mauvais état de santé"] * self.dataset_size
        task = Task(
            task_name="fquad",
            metric="fquad",
        )

        expected_results = {"exact_match": 0.25, "f1": 5.9855542}
        expected_warning = None

        actual_result, actual_warning = task.compute(predictions=a_predictions)

        self.assertEvalDictEqual(expected_results, actual_result)
        self.assertEqual(expected_warning, actual_warning)

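    # More predictions than ground truths: compute() raises a ValueError.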
    def test_given_a_prediction_larger_than_ground_truth_raise_error(self):
        a_predictions = [1] * (self.dataset_size + 1)
        task = Task(
            task_name="fquad",
            metric="fquad",
        )

        self.assertRaises(ValueError, task.compute, predictions=a_predictions)