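"""Smoke tests for newly added or modified tasks.

The tests are parametrized over the tasks reported by ``new_tasks()`` (falling
back to ``arc_easy`` when nothing has changed) and exercise dataset download,
split configuration, the ``doc_to_text``/``doc_to_target``/``doc_to_choice``
formatting conventions, and request construction on a small sample of documents.
"""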
import os
from itertools import islice

import datasets
import pytest

import lm_eval.tasks as tasks
from lm_eval.api.task import ConfigurableTask
from lm_eval.evaluator_utils import get_task_list

from .utils import new_tasks


datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
os.environ["TOKENIZERS_PARALLELISM"] = "false"

task_manager = tasks.TaskManager()

# Default task, used when no new or modified tasks are detected
TASKS = ["arc_easy"]


def task_class():
    global TASKS
    # CI: new_tasks() reports tasks whose configs have been added or modified
    task_classes = new_tasks()
    # Fall back to the default task list if nothing new was detected
    task_classes = task_classes if task_classes else TASKS
    res = tasks.get_task_dict(task_classes, task_manager)
    res = [x.task for x in get_task_list(res)]
    return res


@pytest.fixture()
def limit() -> int:
    return 10
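
# Note: task_class() runs at pytest collection time, so the parametrized task
# list is resolved before any test executes. To run against a single task,
# something like the following should work (the file path is illustrative):
#   pytest -q tests/<this_file>.py -k arc_easy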


# Tests
@pytest.mark.parametrize("task_class", task_class(), ids=lambda x: f"{x.config.task}")
class TestNewTasks:
    def test_download(self, task_class: ConfigurableTask):
        task_class.download()
        assert task_class.dataset is not None

    def test_has_training_docs(self, task_class: ConfigurableTask):
        assert task_class.has_training_docs() in [True, False]

    def test_check_training_docs(self, task_class: ConfigurableTask):
        if task_class.has_training_docs():
            assert task_class._config["training_split"] is not None

    def test_has_validation_docs(self, task_class):
        assert task_class.has_validation_docs() in [True, False]

    def test_check_validation_docs(self, task_class):
        if task_class.has_validation_docs():
            assert task_class._config["validation_split"] is not None

    def test_has_test_docs(self, task_class):
        assert task_class.has_test_docs() in [True, False]

    def test_check_test_docs(self, task_class):
        task = task_class
        if task.has_test_docs():
            assert task._config["test_split"] is not None

    def test_should_decontaminate(self, task_class):
        task = task_class
        assert task.should_decontaminate() in [True, False]
        if task.should_decontaminate():
            assert task._config["doc_to_decontamination_query"] is not None

    def test_doc_to_text(self, task_class, limit):
        task = task_class
        arr = (
            list(islice(task.test_docs(), limit))
            if task.has_test_docs()
            else list(islice(task.validation_docs(), limit))
        )
        _array = [task.doc_to_text(doc) for doc in arr]
        # Space convention: doc_to_text must not end in whitespace when the target
        # delimiter is itself whitespace. Empty strings are allowed for
        # perplexity-like tasks, since the model appends an <|endoftext|> token.
        target_delimiter: str = task.config.target_delimiter
        if not task.multiple_input:
            for x in _array:
                assert isinstance(x, str)
                assert (
                    (x[-1].isspace() is False if len(x) > 0 else True)
                    if target_delimiter.isspace()
                    else True
                ), "doc_to_text ends in whitespace while the target delimiter is also whitespace"
        else:
            # multiple_input tasks build their own context, so no whitespace check here
            pass
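
    # Illustration of the convention checked above: with target_delimiter == " ",
    # a doc_to_text ending in a space would produce a double space once the
    # delimiter and target are appended, so trailing whitespace is disallowed.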

    def test_create_choices(self, task_class, limit):
        task = task_class
        arr = (
            list(islice(task.test_docs(), limit))
            if task.has_test_docs()
            else list(islice(task.validation_docs(), limit))
        )
        if "multiple_choice" in task._config.output_type:
            _array = [task.doc_to_choice(doc) for doc in arr]
            assert all(isinstance(x, list) for x in _array)
            assert all(isinstance(x[0], str) for x in _array)

    def test_doc_to_target(self, task_class, limit):
        task = task_class
        arr = (
            list(islice(task.test_docs(), limit))
            if task.has_test_docs()
            else list(islice(task.validation_docs(), limit))
        )
        _array_target = [task.doc_to_target(doc) for doc in arr]
        if task._config.output_type == "multiple_choice":
            # TODO<baber>: label can be string or int; add better test conditions
            assert all(isinstance(label, (int, str)) for label in _array_target)

    def test_build_all_requests(self, task_class, limit):
        task_class.build_all_requests(rank=1, limit=limit, world_size=1)
        assert task_class.instances is not None

    # TODO: Add proper testing
    def test_construct_requests(self, task_class, limit):
        task = task_class
        arr = (
            list(islice(task.test_docs(), limit))
            if task.has_test_docs()
            else list(islice(task.validation_docs(), limit))
        )
        # ctx is "" for multiple-input tasks
        requests = [
            task.construct_requests(
                doc=doc, ctx="" if task.multiple_input else task.doc_to_text(doc)
            )
            for doc in arr
        ]
        if limit:
            assert len(requests) == limit
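
# Minimal standalone sketch (not part of the test suite) of the flow these tests
# exercise, assuming network access for the dataset download:
#
#   task_dict = tasks.get_task_dict(["arc_easy"], task_manager)
#   task = get_task_list(task_dict)[0].task
#   docs = task.test_docs() if task.has_test_docs() else task.validation_docs()
#   doc = next(iter(docs))
#   print(task.doc_to_text(doc))
#   print(task.doc_to_target(doc))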