|
import os |
|
from itertools import islice |
|
|
|
import datasets |
|
import pytest |
|
|
|
import lm_eval.tasks as tasks |
|
from lm_eval.api.task import ConfigurableTask |
|
from lm_eval.evaluator_utils import get_task_list |
|
|
|
from .utils import new_tasks |
|
|
|
|
|
# Module-level setup shared by every test in this file.
# NOTE(review): presumably this lets `datasets` run dataset loading scripts
# without an interactive trust prompt — confirm against the `datasets` docs.
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True |
|
# NOTE(review): presumably silences/disables tokenizer parallelism during
# the test run — confirm against the tokenizers documentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
# Single TaskManager instance, passed to tasks.get_task_dict() in task_class().
task_manager = tasks.TaskManager() |
|
|
|
# Default task names, used by task_class() when new_tasks() returns nothing truthy.
TASKS = ["arc_easy"] |
|
|
|
|
|
def task_class():
    """Build the task objects that ``TestNewTasks`` is parametrized over.

    The task names come from ``new_tasks()``; when that returns a falsy value
    (e.g. ``None`` or an empty list) the module-level ``TASKS`` default is used
    instead. The names are resolved with ``tasks.get_task_dict`` using the
    shared ``task_manager``.

    Returns:
        list: the ``.task`` attribute of every entry produced by
        ``get_task_list`` for the resolved task dict.
    """
    # TASKS is only read here, so no ``global`` declaration is required.
    task_names = new_tasks() or TASKS
    task_dict = tasks.get_task_dict(task_names, task_manager)
    return [entry.task for entry in get_task_list(task_dict)]
|
|
|
|
|
@pytest.fixture()
def limit() -> int:
    """Fixture: how many documents a test pulls from a task split."""
    doc_limit = 10
    return doc_limit
|
|
|
|
|
|
|
@pytest.mark.parametrize("task_class", task_class(), ids=lambda x: f"{x.config.task}")
class TestNewTasks:
    """Smoke tests executed once per task object returned by ``task_class()``.

    Every test receives one task object through the ``task_class`` parameter
    (the test id is that task's ``config.task``). Tests that look at documents
    take at most ``limit`` of them, preferring the test split and falling back
    to the validation split.
    """

    @staticmethod
    def _sample_docs(task, limit):
        """Return up to ``limit`` docs from the test split, else the validation split."""
        docs = task.test_docs() if task.has_test_docs() else task.validation_docs()
        return list(islice(docs, limit))

    def test_download(self, task_class: ConfigurableTask):
        """``download()`` must leave a non-None ``dataset`` on the task."""
        task_class.download()
        assert task_class.dataset is not None

    def test_has_training_docs(self, task_class: ConfigurableTask):
        """``has_training_docs()`` must return a boolean-like value."""
        assert task_class.has_training_docs() in [True, False]

    def test_check_training_docs(self, task_class: ConfigurableTask):
        """A task reporting training docs must configure ``training_split``."""
        if task_class.has_training_docs():
            assert task_class._config["training_split"] is not None

    def test_has_validation_docs(self, task_class):
        """``has_validation_docs()`` must return a boolean-like value."""
        assert task_class.has_validation_docs() in [True, False]

    def test_check_validation_docs(self, task_class):
        """A task reporting validation docs must configure ``validation_split``."""
        if task_class.has_validation_docs():
            assert task_class._config["validation_split"] is not None

    def test_has_test_docs(self, task_class):
        """``has_test_docs()`` must return a boolean-like value."""
        assert task_class.has_test_docs() in [True, False]

    def test_check_test_docs(self, task_class):
        """A task reporting test docs must configure ``test_split``."""
        if task_class.has_test_docs():
            assert task_class._config["test_split"] is not None

    def test_should_decontaminate(self, task_class):
        """A decontaminating task must configure ``doc_to_decontamination_query``."""
        assert task_class.should_decontaminate() in [True, False]
        if task_class.should_decontaminate():
            assert task_class._config["doc_to_decontamination_query"] is not None

    def test_doc_to_text(self, task_class, limit):
        """``doc_to_text`` must yield strings that do not double up whitespace.

        For tasks without ``multiple_input``: every rendered text is a ``str``,
        and when the target delimiter is whitespace a non-empty text must not
        itself end in whitespace.
        """
        docs = self._sample_docs(task_class, limit)
        texts = [task_class.doc_to_text(doc) for doc in docs]

        target_delimiter: str = task_class.config.target_delimiter
        if task_class.multiple_input:
            # Nothing further is asserted for multiple-input tasks.
            return

        for text in texts:
            assert isinstance(text, str)
            # Only a whitespace delimiter can clash with trailing whitespace;
            # empty texts have no last character to inspect.
            if target_delimiter.isspace() and text:
                assert not text[-1].isspace(), (
                    "doc_to_text ends in a whitespace and target delimiter also a whitespace"
                )

    def test_create_choices(self, task_class, limit):
        """Multiple-choice tasks must produce a list of choices led by a ``str``."""
        docs = self._sample_docs(task_class, limit)
        if "multiple_choice" in task_class._config.output_type:
            choice_lists = [task_class.doc_to_choice(doc) for doc in docs]
            assert all(isinstance(choices, list) for choices in choice_lists)
            # Only the first choice of each doc is type-checked.
            assert all(isinstance(choices[0], str) for choices in choice_lists)

    def test_doc_to_target(self, task_class, limit):
        """Multiple-choice targets must be ``int`` or ``str`` labels."""
        docs = self._sample_docs(task_class, limit)
        # doc_to_target is exercised for every output type; only the
        # multiple-choice result is type-checked.
        targets = [task_class.doc_to_target(doc) for doc in docs]
        if task_class._config.output_type == "multiple_choice":
            assert all(isinstance(label, (int, str)) for label in targets)

    def test_build_all_requests(self, task_class, limit):
        """``build_all_requests`` must populate ``instances`` on the task."""
        # NOTE(review): rank=1 together with world_size=1 is kept from the
        # original call — confirm whether rank=0 was intended.
        task_class.build_all_requests(rank=1, limit=limit, world_size=1)
        assert task_class.instances is not None

    def test_construct_requests(self, task_class, limit):
        """``construct_requests`` must be callable for every sampled doc.

        Multiple-input tasks are given an empty context; all others receive
        the doc rendered by ``doc_to_text``.
        """
        docs = self._sample_docs(task_class, limit)
        requests = [
            task_class.construct_requests(
                doc=doc,
                ctx="" if task_class.multiple_input else task_class.doc_to_text(doc),
            )
            for doc in docs
        ]
        # The count is only checked when a truthy limit is in effect.
        if limit:
            assert len(requests) == limit
|
|