from typing import List
from unittest import TestCase

from src.task.task_factory import tasks_factory
from src.backend.validation_tools import tasks_name


def _make_submission_json(task_names) -> dict:
    """Build a submission payload with one dummy prediction entry per task name.

    The payload has the shape
    ``{"model_name": ..., "model_url": ..., "tasks": [{name: {"predictions": [...]}}, ...]}``
    and every task shares the same predictions list, exactly as the tests
    originally built it inline.

    :param task_names: iterable of task names to include in ``"tasks"``.
    :return: the submission dict to pass to ``tasks_factory``.
    """
    a_prediction_list = [1, 1, 1, 1, 1]
    tasks_payload: List[dict] = [
        {task: {"predictions": a_prediction_list}} for task in task_names
    ]
    return {
        "model_name": "a_model_name",
        "model_url": "a_model_url",
        "tasks": tasks_payload,
    }


class TasksFactoryTest(TestCase):
    """Unit tests for ``tasks_factory``."""

    def test_task_factory(self):
        """A submission covering every known task name yields a list with one entry per name."""
        a_submission_json = _make_submission_json(tasks_name)

        actual_tasks_list = tasks_factory(a_submission_json)

        # Check against the builtin ``list``: ``typing.List`` is a deprecated
        # alias and is not meant for runtime isinstance checks.
        self.assertIsInstance(actual_tasks_list, list)
        expected_len = len(tasks_name)
        actual_len = len(actual_tasks_list)
        self.assertEqual(expected_len, actual_len)

    def test_task_factory_raise_error_name_not_valid(self):
        """An unknown task name makes ``tasks_factory`` raise ``ValueError``."""
        a_submission_json = _make_submission_json(["invalid name"])

        with self.assertRaises(ValueError):
            tasks_factory(a_submission_json)