File size: 2,041 Bytes
9d5b280 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import tempfile
from pathlib import Path
import pytest
from lm_eval.tasks import TaskManager
@pytest.fixture(scope="module")
def custom_task_name():
    """Name of the throw-away Python-backed task defined by this module.

    Also used as the stem of the generated ``.yaml`` and ``.py`` files.
    """
    name = "zzz_my_python_task"
    return name
@pytest.fixture(scope="module")
def custom_task_tag():
    """Tag attached to the custom task in its YAML config."""
    tag = "zzz-tag"
    return tag
@pytest.fixture(scope="module")
def task_yaml(pytestconfig, custom_task_name, custom_task_tag):
    """YAML config for the custom task.

    The config includes the repository's hellaswag task config, renames the
    task, points ``class`` at ``MockPythonTask`` in the generated Python module
    of the same name, and attaches a single tag.
    """
    # No teardown is needed, so return the value rather than yielding it.
    return f"""include: {pytestconfig.rootpath}/lm_eval/tasks/hellaswag/hellaswag.yaml
task: {custom_task_name}
class: !function {custom_task_name}.MockPythonTask
tag:
- {custom_task_tag}
"""
@pytest.fixture(scope="module")
def task_code():
    """Source of the Python module that backs the custom task.

    The text is written to ``<custom_task_name>.py`` and referenced from the
    task YAML as ``class: !function <custom_task_name>.MockPythonTask``. It
    must therefore be a syntactically valid, properly indented module.
    Otherwise importing it raises ``IndentationError``.

    ``MockPythonTask`` removes the ``class`` key from the config it receives
    and forwards everything else to ``ConfigurableTask``. Presumably this is
    because ``class`` is loader metadata rather than a task option; confirm
    against ``TaskConfig``.
    """
    return """
from lm_eval.tasks import ConfigurableTask


class MockPythonTask(ConfigurableTask):
    def __init__(
        self,
        data_dir=None,
        cache_dir=None,
        download_mode=None,
        config=None,
    ) -> None:
        config.pop("class")
        super().__init__(data_dir, cache_dir, download_mode, config)
"""
@pytest.fixture(scope="module")
def custom_task_files_dir(task_yaml, task_code, custom_task_name):
    """Write the custom task's YAML and Python files into a temporary directory.

    Yields the directory path as a ``str``, suitable for ``TaskManager``'s
    ``include_path``. The directory and its contents are deleted when the
    ``TemporaryDirectory`` context exits at module-scope fixture teardown.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        root = Path(temp_dir)
        # Encode explicitly as UTF-8. The YAML embeds the repository root path,
        # which may contain characters the platform default encoding (e.g.
        # cp1252 on Windows) cannot represent.
        (root / f"{custom_task_name}.yaml").write_text(task_yaml, encoding="utf-8")
        (root / f"{custom_task_name}.py").write_text(task_code, encoding="utf-8")
        yield temp_dir
def test_python_task_inclusion(
    custom_task_files_dir: str, custom_task_name: str, custom_task_tag: str
):
    """A Python-class task found via ``include_path`` is indexed and loadable by tag."""
    task_manager = TaskManager(
        verbosity="INFO", include_path=str(custom_task_files_dir)
    )
    # The Python-backed task is registered in the global task index...
    assert custom_task_name in task_manager.task_index
    # ...and appears among the individual (sub)tasks.
    assert custom_task_name in task_manager.all_subtasks
    # The tag declared in its YAML is registered.
    assert custom_task_tag in task_manager.all_tags
    # Loading by tag yields the custom task.
    assert custom_task_name in task_manager.load_task_or_group(custom_task_tag)
|