File size: 12,557 Bytes
1b1c183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import pytest
import aiohttp
from aiohttp import ClientResponse
import itertools
import os 
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename

class AsyncIteratorMock:
    """Async iterator that replays the items of an ordinary iterable.

    Stands in for the async iterator returned by aiohttp's
    ``response.content.iter_chunked(...)`` so tests can push fixed chunks
    through an ``async for`` loop.
    """

    # Private marker handed back by next() once the wrapped iterator runs dry.
    _EXHAUSTED = object()

    def __init__(self, seq):
        # One-shot iterator: items are consumed as the mock is iterated.
        self.iter = iter(seq)

    def __aiter__(self):
        # The object is its own async iterator.
        return self

    async def __anext__(self):
        item = next(self.iter, self._EXHAUSTED)
        if item is self._EXHAUSTED:
            # End of data: the async-protocol counterpart of StopIteration.
            raise StopAsyncIteration
        return item

class ContentMock:
    """Stand-in for ``aiohttp.ClientResponse.content`` replaying fixed chunks.

    Only ``iter_chunked`` is provided. Each call returns a fresh
    ``AsyncIteratorMock`` over the stored chunks.
    """

    def __init__(self, chunks):
        # Replayed verbatim, in order, by iter_chunked().
        self.chunks = chunks

    def iter_chunked(self, chunk_size):
        """Return an async iterator over the stored chunks.

        ``chunk_size`` is accepted only to match aiohttp's signature. It is
        ignored, so chunks are yielded exactly as they were supplied.
        """
        replay = AsyncIteratorMock(self.chunks)
        return replay

@pytest.mark.asyncio
async def test_download_model_success():
    """A 200 response streams every chunk to disk and reports PENDING then COMPLETED."""
    payload = [b'a' * 500, b'b' * 300, b'c' * 200]

    response = AsyncMock(spec=aiohttp.ClientResponse)
    response.status = 200
    response.headers = {'Content-Length': '1000'}
    # ContentMock.iter_chunked() hands back an async iterator, as aiohttp does.
    response.content = ContentMock(payload)

    make_request = AsyncMock(return_value=response)
    progress_callback = AsyncMock()

    # open() is replaced by a mock whose context manager yields a recordable handle.
    written_file = MagicMock()
    fake_open = MagicMock()
    fake_open.return_value.__enter__.return_value = written_file

    # Fake clock advancing 0.1s per call so the code under test sees time pass.
    clock = itertools.count(0, 0.1)

    with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
         patch('model_filemanager.check_file_exists', return_value=None), \
         patch('builtins.open', fake_open), \
         patch('time.time', side_effect=clock):
        result = await download_model(
            make_request,
            'model.sft',
            'http://example.com/model.sft',
            'checkpoints',
            progress_callback
        )

    assert isinstance(result, DownloadModelStatus)
    assert result.message == 'Successfully downloaded model.sft'
    assert result.status == 'completed'
    assert result.already_existed is False

    # Expect at least the start notification, one progress update and completion.
    assert progress_callback.call_count >= 3

    progress_callback.assert_any_call(
        'checkpoints/model.sft',
        DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
    )
    progress_callback.assert_any_call(
        'checkpoints/model.sft',
        DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
    )

    # Every streamed chunk must have been written to the (mock) file.
    for chunk in payload:
        written_file.write.assert_any_call(chunk)

    make_request.assert_called_once_with('http://example.com/model.sft')

@pytest.mark.asyncio
async def test_download_model_url_request_failure():
    """A non-200 response produces an ERROR status that carries the HTTP code."""
    failing_response = AsyncMock(spec=ClientResponse)
    failing_response.status = 404  # "Not Found"
    request = AsyncMock(return_value=failing_response)
    progress_callback = AsyncMock()

    # Pretend the target path resolves and no file exists there yet.
    # NOTE(review): the patched relative path ('mock/path/...') differs from the
    # path asserted on the callback below ('mock_directory/...'). Presumably this
    # patch does not reach the name download_model actually calls; confirm the
    # patch target.
    with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')), \
         patch('model_filemanager.check_file_exists', return_value=None):
        result = await download_model(
            request,
            'model.safetensors',
            'http://example.com/model.safetensors',
            'mock_directory',
            progress_callback
        )

    assert isinstance(result, DownloadModelStatus)
    assert result.status == 'error'
    assert result.message == 'Failed to download model.safetensors. Status code: 404'
    assert result.already_existed is False

    pending_status = DownloadModelStatus(
        status=DownloadStatusType.PENDING,
        progress_percentage=0,
        message='Starting download of model.safetensors',
        already_existed=False
    )
    progress_callback.assert_any_call('mock_directory/model.safetensors', pending_status)

    # The most recent callback must be the error report.
    error_status = DownloadModelStatus(
        status=DownloadStatusType.ERROR,
        progress_percentage=0,
        message='Failed to download model.safetensors. Status code: 404',
        already_existed=False
    )
    progress_callback.assert_called_with('mock_directory/model.safetensors', error_status)

    request.assert_called_once_with('http://example.com/model.safetensors')

@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
    """A subdirectory containing '..' is rejected before any request is made."""
    mock_make_request = AsyncMock()
    mock_progress_callback = AsyncMock()

    result = await download_model(
        mock_make_request,
        'model.sft',
        'http://example.com/model.sft',
        '../bad_path',
        mock_progress_callback
    )

    assert isinstance(result, DownloadModelStatus)
    assert result.message == 'Invalid model subdirectory'
    assert result.status == 'error'
    assert result.already_existed is False

    # Path validation is a security guard. A rejected subdirectory must never
    # reach the download request, otherwise the error status is only cosmetic.
    mock_make_request.assert_not_called()


# For create_model_path function
def test_create_model_path(tmp_path, monkeypatch):
    """create_model_path returns the absolute and relative paths and ensures the parent dir exists."""
    models_root = tmp_path / "models"
    monkeypatch.setattr('folder_paths.models_dir', str(models_root))

    subdir, name = "test_dir", "test_model.sft"
    absolute, relative = create_model_path(name, subdir, models_root)

    expected_absolute = str(models_root / subdir / name)
    assert absolute == expected_absolute
    assert relative == f"{subdir}/{name}"
    # The directory that will hold the model must exist after the call.
    assert os.path.exists(os.path.dirname(absolute))


@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
    """An existing file yields a COMPLETED, already-existed status and exactly one callback."""
    target = tmp_path / "existing_model.sft"
    target.touch()  # an empty file is enough for an existence check
    callback = AsyncMock()

    status = await check_file_exists(str(target), "existing_model.sft", callback, "test/existing_model.sft")

    assert status is not None
    assert status.status == "completed"
    assert status.message == "existing_model.sft already exists"
    assert status.already_existed is True

    expected = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
    callback.assert_called_once_with("test/existing_model.sft", expected)

@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
    """A missing file yields None and triggers no progress callback."""
    callback = AsyncMock()
    missing = str(tmp_path / "non_existing_model.sft")

    outcome = await check_file_exists(missing, "non_existing_model.sft", callback, "test/non_existing_model.sft")

    assert outcome is None
    callback.assert_not_called()

@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
    """Without a Content-Length header the download still completes and reports progress."""
    response = AsyncMock(spec=aiohttp.ClientResponse)
    response.headers = {}  # total size unknown
    response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])

    callback = AsyncMock()
    fake_open = MagicMock(return_value=MagicMock())

    with patch('builtins.open', fake_open):
        result = await track_download_progress(
            response, '/mock/path/model.sft', 'model.sft',
            callback, 'models/model.sft', interval=0.1
        )

    assert result.status == "completed"
    # An unknown total is expected to surface as an IN_PROGRESS update at 0%.
    in_progress = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
    callback.assert_any_call('models/model.sft', in_progress)

@pytest.mark.asyncio
async def test_track_download_progress_interval():
    """Progress is reported while downloading and ends with a 100% COMPLETED update."""
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {'Content-Length': '1000'}
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)

    mock_callback = AsyncMock()
    mock_open = MagicMock(return_value=MagicMock())

    # Unbounded fake clock advancing 0.5s per call. A finite side_effect list
    # raises StopIteration once exhausted, and time.time is patched process-wide,
    # so any extra caller could consume values and break the test spuriously.
    mock_time = MagicMock(side_effect=itertools.count(0, 0.5))

    with patch('builtins.open', mock_open), \
         patch('time.time', mock_time):
        await track_download_progress(
            mock_response, '/mock/path/model.sft', 'model.sft',
            mock_callback, 'models/model.sft', interval=1.0
        )

    # (status, percentage) of every callback, attached to assertion failures
    # instead of being printed unconditionally.
    reported = [(c.args[1].status, c.args[1].progress_percentage) for c in mock_callback.call_args_list]

    # At least: the start, one interval update, and the final completion.
    assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}: {reported}"

    first_status = mock_callback.call_args_list[0].args[1]
    assert first_status.status == "in_progress", reported
    # Some initial progress is allowed, but it should be below 50%.
    assert 0 <= first_status.progress_percentage < 50, f"First call progress was {first_status.progress_percentage}%"

    last_status = mock_callback.call_args_list[-1].args[1]
    assert last_status.status == "completed", reported
    assert last_status.progress_percentage == 100, reported

def test_valid_subdirectory():
    """Alphanumerics with a dash form an acceptable subdirectory name."""
    is_valid = validate_model_subdirectory("valid-model123")
    assert is_valid is True

def test_subdirectory_too_long():
    """A 51-character name exceeds the allowed subdirectory length."""
    overlong = "a" * 51
    assert validate_model_subdirectory(overlong) is False

def test_subdirectory_with_double_dots():
    """Parent-directory traversal ('..') is rejected."""
    is_valid = validate_model_subdirectory("model/../unsafe")
    assert is_valid is False

def test_subdirectory_with_slash():
    """Nested paths (containing '/') are rejected."""
    is_valid = validate_model_subdirectory("model/unsafe")
    assert is_valid is False

def test_subdirectory_with_special_characters():
    """Characters outside the allowed set, such as '@', are rejected."""
    is_valid = validate_model_subdirectory("model@unsafe")
    assert is_valid is False

def test_subdirectory_with_underscore_and_dash():
    """Underscores and dashes are both permitted."""
    is_valid = validate_model_subdirectory("valid_model-name")
    assert is_valid is True

def test_empty_subdirectory():
    """An empty subdirectory name is rejected."""
    is_valid = validate_model_subdirectory("")
    assert is_valid is False

@pytest.mark.parametrize(("filename", "expected"), [
    # Accepted names
    ("valid_model.safetensors", True),
    ("valid_model.sft", True),
    ("valid model.safetensors", True),  # embedded space
    ("UPPERCASE_MODEL.SAFETENSORS", True),
    # Rejected names
    ("model_with.multiple.dots.pt", False),
    ("", False),  # empty string
    ("../../../etc/passwd", False),  # path traversal attempt
    ("/etc/passwd", False),  # absolute path
    ("\\windows\\system32\\config\\sam", False),  # Windows-style path
    (".hidden_file.pt", False),  # hidden file
    ("invalid<char>.ckpt", False),  # disallowed character
    ("invalid?.ckpt", False),  # another disallowed character
    ("very" * 100 + ".safetensors", False),  # name too long
    ("\nmodel_with_newline.pt", False),  # newline character
    ("model_with_emoji😊.pt", False),  # emoji in name
])
def test_validate_filename(filename, expected):
    """validate_filename accepts simple model names and rejects unsafe or malformed ones."""
    verdict = validate_filename(filename)
    assert verdict == expected