File size: 2,600 Bytes
63deadc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import pytest

from fsspec.callbacks import Callback, TqdmCallback


def test_callbacks():
    """Check ``Callback.call`` with no hooks and with one- and two-argument hooks.

    A bare ``Callback()`` returns ``None`` from ``call``; when a hook is
    registered under the requested name, ``call`` returns that hook's result.
    """
    no_hooks = Callback()
    assert no_hooks.call("something", somearg=None) is None

    def add_two(*_, arg=None):
        return arg + 2

    one_arg = Callback(hooks={"something": add_two})
    assert one_arg.call("something", arg=2) == 4

    def add_both(*_, arg1=None, arg2=None):
        return arg1 + arg2

    two_args = Callback(hooks={"something": add_both})
    assert two_args.call("something", arg1=2, arg2=2) == 4


def test_callbacks_as_callback():
    """Check ``Callback.as_callback`` for ``None`` and for a real callback.

    ``as_callback(None)`` must give a callback whose ``call`` returns ``None``
    and must return the same object every time; a real callback passed in
    must still dispatch to its hooks.
    """
    fallback = Callback.as_callback(None)
    assert fallback.call("something", arg="somearg") is None
    # Two separate conversions of None must yield the identical object.
    assert Callback.as_callback(None) is Callback.as_callback(None)

    wrapped = Callback.as_callback(
        Callback(hooks={"something": lambda *_, arg=None: arg + 2})
    )
    assert wrapped.call("something", arg=2) == 4


def test_callbacks_as_context_manager(mocker):
    """Check that a ``Callback`` works in ``with`` and calls ``close`` exactly once."""
    spy_on_close = mocker.spy(Callback, "close")

    manager = Callback()
    with manager as entered:
        # __enter__ should hand back a Callback instance.
        assert isinstance(entered, Callback)

    spy_on_close.assert_called_once()


def test_callbacks_branched():
    """Check that ``branched`` returns a ``Callback`` distinct from its parent."""
    parent = Callback()
    child = parent.branched("path_1", "path_2")

    assert isinstance(child, Callback)
    assert child is not parent


@pytest.mark.asyncio
async def test_callbacks_branch_coro(mocker):
    """Check that ``branch_coro`` routes calls through ``branched``.

    The wrapped coroutine must return the inner function's result. The inner
    function must receive the original arguments plus the callback produced
    by ``branched`` as ``callback=``.
    """
    coro_fn = mocker.AsyncMock(return_value=10)
    parent = Callback()
    wrapped = parent.branch_coro(coro_fn)
    # The spy is attached *after* wrapping, so this test relies on the
    # wrapper resolving ``branched`` when it is called, not when it is built.
    branched_spy = mocker.spy(parent, "branched")

    result = await wrapped("path_1", "path_2", key="value")
    assert result == 10

    branched_spy.assert_called_once_with("path_1", "path_2", key="value")
    child = branched_spy.spy_return
    coro_fn.assert_called_once_with(
        "path_1", "path_2", callback=child, key="value"
    )


def test_callbacks_wrap():
    """Check that iterating ``wrap(range(10))`` records ten increments of 1.

    A subclass overrides ``relative_update`` to record each increment it
    receives.
    """
    increments = []

    class RecordingCallback(Callback):
        def relative_update(self, inc=1):
            increments.append(inc)

    for _ in RecordingCallback().wrap(range(10)):
        pass

    assert increments == [1] * 10


@pytest.mark.parametrize("tqdm_kwargs", [{}, {"desc": "A custom desc"}])
def test_tqdm_callback(tqdm_kwargs, mocker):
    """Check how ``TqdmCallback`` builds and drives its progress bar.

    The bar is built through ``_tqdm_cls`` (mocked here) with ``total=10``
    and the user's ``tqdm_kwargs``. After ``set_size`` and iterating a
    10-item ``wrap``, ``update`` must have been called 11 times.
    """
    pytest.importorskip("tqdm")
    cb = TqdmCallback(tqdm_kwargs=tqdm_kwargs)
    mocker.patch.object(cb, "_tqdm_cls")

    cb.set_size(10)
    for _ in cb.wrap(range(10)):
        pass

    # 11 = 10 wrapped items + 1; presumably the extra update comes from
    # set_size — TODO confirm against TqdmCallback.call.
    assert cb.tqdm.update.call_count == 11
    # An empty tqdm_kwargs unpacks to nothing, so this single assertion
    # covers both parametrized cases.
    cb._tqdm_cls.assert_called_with(total=10, **tqdm_kwargs)