File size: 1,288 Bytes
75ec748 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from typing import List
from unittest import TestCase
from src.task.task_factory import tasks_factory
from src.backend.validation_tools import tasks_name
class TasksFactoryTest(TestCase):
    """Tests for ``tasks_factory``, which turns a submission payload into a list of tasks."""

    @staticmethod
    def _make_submission(task_names):
        """Return a submission dict with one ``predictions`` entry per name in *task_names*.

        Both tests build the same payload shape and differ only in which task
        names appear, so the construction lives here once.
        """
        # Dummy predictions; these tests only exercise task construction and
        # name validation, not prediction content. One list object is shared
        # by every task entry, as in the original inline payloads.
        a_prediction_list = [1, 1, 1, 1, 1]
        return {
            "model_name": "a_model_name",
            "model_url": "a_model_url",
            "tasks": [
                {task: {"predictions": a_prediction_list}} for task in task_names
            ],
        }

    def test_task_factory(self):
        """A submission covering every name in ``tasks_name`` yields one task per name."""
        a_submission_json = self._make_submission(tasks_name)

        actual_tasks_list = tasks_factory(a_submission_json)

        # Check against the builtin ``list``: ``typing.List`` is a deprecated
        # alias intended for annotations, not for runtime isinstance checks.
        self.assertIsInstance(actual_tasks_list, list)
        self.assertEqual(len(tasks_name), len(actual_tasks_list))

    def test_task_factory_raise_error_name_not_valid(self):
        """An unrecognized task name makes ``tasks_factory`` raise ``ValueError``."""
        a_submission_json = self._make_submission(["invalid name"])

        with self.assertRaises(ValueError):
            tasks_factory(a_submission_json)
|