File size: 3,517 Bytes
9d5b280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import importlib
import os
import sys
from datetime import datetime
from typing import List, Optional, Tuple

import pytest
import torch

from lm_eval.caching.cache import PATH


# Directory that contains this test module, with symlinks resolved.
MODULE_DIR = os.path.dirname(os.path.realpath(__file__))

# Make ``<MODULE_DIR>/../scripts`` importable so the caching helper script can
# be loaded by module name, then expose its entry point to the tests below.
# NOTE: the script this loads uses simple evaluate.
# TODO: potentially test both the helper script and the normal script.
sys.path.append(f"{MODULE_DIR}/../scripts")
model_loader = importlib.import_module("requests_caching")
run_model_for_task_caching = model_loader.run_model_for_task_caching

# Set at import time, i.e. before any test body runs.
# NOTE(review): presumably allows the datasets used by the tasks to execute
# their remote loading code - confirm against the Hugging Face datasets docs.
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
# Task list fed to every parametrized scenario in this module.
DEFAULT_TASKS = ["lambada_openai", "sciq"]


@pytest.fixture(autouse=True)
def setup_and_teardown():
    """Autouse fixture that prepares a clean slate around a test.

    Setup switches torch's deterministic-algorithms mode off and empties the
    request-cache directory via ``clear_cache``. There is currently no
    teardown work after the test finishes.
    """
    # --- setup ---
    torch.use_deterministic_algorithms(False)
    clear_cache()

    # Hand control over to the test body.
    yield

    # --- teardown: nothing to clean up at present ---


def clear_cache():
    """Remove every entry found directly inside the cache directory ``PATH``.

    Returns without doing anything when ``PATH`` does not exist. Entries are
    removed with ``os.unlink``; the directory itself is left in place.
    """
    if not os.path.exists(PATH):
        return

    for entry in os.listdir(PATH):
        os.unlink(f"{PATH}/{entry}")


def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]:
    """List the cache files in ``PATH`` and the task name encoded in each.

    The task name is the text between the first and second ``-`` of the file
    name, cut off at its first ``.``.

    Args:
        tasks: Currently unused. Kept in the signature to leave room for
            selecting specific task files later.

    Returns:
        ``(cache_files, file_task_names)``: the directory entries of ``PATH``
        and, index-aligned with them, the extracted task names. Both lists
        are empty when ``PATH`` does not exist.

    Raises:
        IndexError: If an entry in ``PATH`` has no ``-`` in its name.
    """
    try:
        cache_files = os.listdir(PATH)
    except FileNotFoundError:
        # A cache directory that was never created holds no cache files;
        # report that instead of crashing (mirrors clear_cache's tolerance).
        return [], []

    file_task_names = [
        file.split("-")[1].split(".")[0] for file in cache_files
    ]

    return cache_files, file_task_names


def assert_created(tasks: List[str], file_task_names: List[str]):
    """Assert that the cached task names match the requested tasks.

    The comparison ignores order but not multiplicity. Neither argument is
    modified: sorted copies are compared, because callers may pass a shared
    list (for example a module-level default) that must keep its order.

    Args:
        tasks: Task names that were requested.
        file_task_names: Task names extracted from the cache file names.

    Raises:
        AssertionError: If the two collections differ.
    """
    assert sorted(tasks) == sorted(file_task_names)


@pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_true(tasks: List[str]):
    """Run with ``cache_requests="true"`` and check the resulting cache files.

    Passes when the task names parsed from the cache file names are exactly
    the requested ``tasks``.

    NOTE(review): the name lacks the ``test_`` prefix, so default pytest
    collection presumably skips it - confirm this is intentional.
    """
    run_model_for_task_caching(tasks=tasks, cache_requests="true")

    # Only the parsed task names matter here, not the raw file names.
    cached_task_names = get_cache_files()[1]
    print(cached_task_names)
    assert_created(tasks=tasks, file_task_names=cached_task_names)


@pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_refresh(tasks: List[str]):
    """Check that ``cache_requests="refresh"`` rewrites every cache file.

    The cache is first populated with ``cache_requests="true"``. A timestamp
    is taken, the run is repeated with ``cache_requests="refresh"``, and every
    file in the cache directory must then have a modification time later than
    that timestamp. Finally the task names parsed from the cache files must
    match ``tasks``.

    NOTE(review): the name lacks the ``test_`` prefix, so default pytest
    collection presumably skips it - confirm this is intentional.
    """
    run_model_for_task_caching(tasks=tasks, cache_requests="true")

    timestamp_before_test = datetime.now().timestamp()

    run_model_for_task_caching(tasks=tasks, cache_requests="refresh")

    cache_files, file_task_names = get_cache_files()

    for file in cache_files:
        modification_time = os.path.getmtime(f"{PATH}/{file}")
        assert modification_time > timestamp_before_test

    # `tasks` is the shared DEFAULT_TASKS list handed in by parametrize, so
    # pass a copy: the helper may then sort freely without reordering the
    # list that the other scenarios receive.
    assert_created(tasks=list(tasks), file_task_names=file_task_names)


@pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_delete(tasks: List[str]):
    """Run with ``cache_requests="delete"`` and expect an empty cache listing.

    NOTE(review): the name lacks the ``test_`` prefix, so default pytest
    collection presumably skips it - confirm this is intentional.
    """
    # The cache is not pre-populated here. Running the "true" scenario first
    # would give additional confidence, but that step is currently disabled.
    run_model_for_task_caching(tasks=tasks, cache_requests="delete")

    remaining_files, _ = get_cache_files()

    assert not remaining_files


# useful for locally running tests through the debugger
if __name__ == "__main__":

    def run_tests():
        """Call each selected scenario with the default tasks.

        The cache is cleared before every scenario, and "Tests pass" is
        printed once all of them have returned without raising.
        """
        # Nothing is selected by default. To debug locally, list the callables
        # to run here, e.g. requests_caching_true, requests_caching_refresh,
        # requests_caching_delete.
        selected = []

        for scenario in selected:
            clear_cache()
            scenario(tasks=DEFAULT_TASKS)

        print("Tests pass")

    run_tests()