-
Notifications
You must be signed in to change notification settings - Fork 295
Expand file tree
/
Copy pathtest_fetcher_ng.py
More file actions
284 lines (239 loc) · 10.2 KB
/
test_fetcher_ng.py
File metadata and controls
284 lines (239 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# Copyright 2021, New York University and the TUF contributors
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Unit test for Urllib3Fetcher."""
import io
import logging
import math
import os
import sys
import tempfile
import unittest
from typing import ClassVar
from unittest.mock import Mock, patch
import urllib3
from tests import utils
from tuf.api import exceptions
from tuf.ngclient import Urllib3Fetcher
logger = logging.getLogger(__name__)
class TestFetcher(unittest.TestCase):
    """Test Urllib3Fetcher class.

    A small static-file HTTP server (tests.utils.TestServerProcess) is
    started once for the whole class; individual tests either fetch a
    real temporary file from it or patch urllib3.PoolManager.request to
    simulate network failures.
    """

    # Populated in setUpClass; annotated here so type checkers know the
    # attribute exists on the class.
    server_process_handler: ClassVar[utils.TestServerProcess]

    @classmethod
    def setUpClass(cls) -> None:
        """
        Create a temporary file and launch a simple server in the
        current working directory.
        """
        cls.server_process_handler = utils.TestServerProcess(log=logger)

        # Known payload served to the tests; its length drives the
        # chunking and length-mismatch assertions below.
        cls.file_contents = b"junk data"
        cls.file_length = len(cls.file_contents)
        # delete=False: the file must outlive this `with` block so the
        # server can serve it; it is removed in tearDownClass.
        with tempfile.NamedTemporaryFile(
            dir=os.getcwd(), delete=False
        ) as cls.target_file:
            cls.target_file.write(cls.file_contents)

        cls.url_prefix = (
            f"http://{utils.TEST_HOST_ADDRESS}:"
            f"{cls.server_process_handler.port!s}"
        )
        target_filename = os.path.basename(cls.target_file.name)
        cls.url = f"{cls.url_prefix}/{target_filename}"

    @classmethod
    def tearDownClass(cls) -> None:
        # Stop server process and perform clean up.
        cls.server_process_handler.clean()
        os.remove(cls.target_file.name)

    def setUp(self) -> None:
        # Instantiate a concrete instance of FetcherInterface
        self.fetcher = Urllib3Fetcher()

    # Simple fetch.
    def test_fetch(self) -> None:
        """Fetch the served file and verify the downloaded bytes."""
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)

            temp_file.seek(0)
            self.assertEqual(self.file_contents, temp_file.read())

    # URL data downloaded in more than one chunk
    def test_fetch_in_chunks(self) -> None:
        """Verify both content and chunk count with a small chunk size."""
        # Set a smaller chunk size to ensure that the file will be downloaded
        # in more than one chunk
        self.fetcher.chunk_size = 4

        # expected_chunks_count: 3 (depends on length of self.file_length)
        expected_chunks_count = math.ceil(
            self.file_length / self.fetcher.chunk_size
        )
        # Sanity check: 9 bytes / 4-byte chunks -> 3 chunks.
        self.assertEqual(expected_chunks_count, 3)

        chunks_count = 0
        with tempfile.TemporaryFile() as temp_file:
            for chunk in self.fetcher.fetch(self.url):
                temp_file.write(chunk)
                chunks_count += 1

            temp_file.seek(0)
            self.assertEqual(self.file_contents, temp_file.read())
            # Check that we calculate chunks as expected
            self.assertEqual(chunks_count, expected_chunks_count)

    # Incorrect URL parsing
    def test_url_parsing(self) -> None:
        """An unresolvable host must surface as DownloadError."""
        with self.assertRaises(exceptions.DownloadError):
            self.fetcher.fetch("http://invalid/")

    # File not found error
    def test_http_error(self) -> None:
        """A 404 response must surface as DownloadHTTPError with the code."""
        with self.assertRaises(exceptions.DownloadHTTPError) as cm:
            self.url = f"{self.url_prefix}/non-existing-path"
            self.fetcher.fetch(self.url)
        self.assertEqual(cm.exception.status_code, 404)

    # Response read timeout error
    @patch.object(urllib3.PoolManager, "request")
    def test_response_read_timeout(self, mock_session_get: Mock) -> None:
        """A timeout while streaming the body raises SlowRetrievalError."""
        # The request itself succeeds (status 200) but reading the body
        # fails with a retry-exhausted timeout.
        mock_response = Mock()
        mock_response.status = 200
        attr = {
            "stream.side_effect": urllib3.exceptions.MaxRetryError(
                urllib3.connectionpool.ConnectionPool("localhost"),
                "",
                urllib3.exceptions.TimeoutError(),
            )
        }
        mock_response.configure_mock(**attr)
        mock_session_get.return_value = mock_response

        with self.assertRaises(exceptions.SlowRetrievalError):
            next(self.fetcher.fetch(self.url))

        mock_response.stream.assert_called_once()

    # Read/connect session timeout error
    @patch.object(
        urllib3.PoolManager,
        "request",
        side_effect=urllib3.exceptions.MaxRetryError(
            urllib3.connectionpool.ConnectionPool("localhost"),
            "",
            urllib3.exceptions.TimeoutError(),
        ),
    )
    def test_session_get_timeout(self, mock_session_get: Mock) -> None:
        """A timeout on the request itself raises SlowRetrievalError."""
        with self.assertRaises(exceptions.SlowRetrievalError):
            self.fetcher.fetch(self.url)

        mock_session_get.assert_called_once()

    # Test retry on ReadTimeoutError during streaming
    @patch.object(urllib3.PoolManager, "request")
    def test_download_bytes_retry_on_streaming_timeout(
        self, mock_request: Mock
    ) -> None:
        """Test that download_bytes retries when ReadTimeoutError occurs during streaming."""
        # First two attempts time out while streaming the body.
        mock_response_fail = Mock()
        mock_response_fail.status = 200
        mock_response_fail.stream.side_effect = (
            urllib3.exceptions.ReadTimeoutError(
                urllib3.connectionpool.ConnectionPool("localhost"),
                "",
                "Read timed out",
            )
        )

        # Third attempt streams the full payload in two chunks.
        mock_response_success = Mock()
        mock_response_success.status = 200
        mock_response_success.stream.return_value = iter(
            [self.file_contents[:4], self.file_contents[4:]]
        )

        mock_request.side_effect = [
            mock_response_fail,
            mock_response_fail,
            mock_response_success,
        ]

        data = self.fetcher.download_bytes(self.url, self.file_length)

        self.assertEqual(self.file_contents, data)
        # Two failures + one success means three requests were issued.
        self.assertEqual(mock_request.call_count, 3)

    # Test retry exhaustion
    @patch.object(urllib3.PoolManager, "request")
    def test_download_bytes_retry_exhaustion(self, mock_request: Mock) -> None:
        """Test that download_bytes fails after exhausting all retries."""
        # All attempts fail
        mock_response = Mock()
        mock_response.status = 200
        mock_response.stream.side_effect = urllib3.exceptions.ReadTimeoutError(
            urllib3.connectionpool.ConnectionPool("localhost"),
            "",
            "Read timed out",
        )
        mock_request.return_value = mock_response

        with self.assertRaises(exceptions.SlowRetrievalError):
            self.fetcher.download_bytes(self.url, self.file_length)

        # Should have been called 3 times (max_retries=3)
        self.assertEqual(mock_request.call_count, 3)

    # Test retry on ProtocolError during streaming
    @patch.object(urllib3.PoolManager, "request")
    def test_download_bytes_retry_on_protocol_error(
        self, mock_request: Mock
    ) -> None:
        """Test that download_bytes retries when ProtocolError occurs during streaming."""
        # First attempt fails with protocol error, second succeeds
        mock_response_fail = Mock()
        mock_response_fail.status = 200
        mock_response_fail.stream.side_effect = (
            urllib3.exceptions.ProtocolError("Connection broken")
        )

        mock_response_success = Mock()
        mock_response_success.status = 200
        mock_response_success.stream.return_value = iter(
            [self.file_contents[:4], self.file_contents[4:]]
        )

        mock_request.side_effect = [
            mock_response_fail,
            mock_response_success,
        ]

        data = self.fetcher.download_bytes(self.url, self.file_length)

        self.assertEqual(self.file_contents, data)
        self.assertEqual(mock_request.call_count, 2)

    # Test that non-timeout errors are not retried
    @patch.object(urllib3.PoolManager, "request")
    def test_download_bytes_no_retry_on_http_error(
        self, mock_request: Mock
    ) -> None:
        """Test that download_bytes does not retry on HTTP errors like 404."""
        mock_response = Mock()
        mock_response.status = 404
        mock_request.return_value = mock_response

        with self.assertRaises(exceptions.DownloadHTTPError):
            self.fetcher.download_bytes(self.url, self.file_length)

        # Should only be called once, no retries
        mock_request.assert_called_once()

    # Test that length mismatch errors are not retried
    def test_download_bytes_no_retry_on_length_mismatch(self) -> None:
        """Test that download_bytes does not retry on length mismatch errors."""
        # Try to download more data than the file contains
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            # File is self.file_length bytes, asking for less should fail
            self.fetcher.download_bytes(self.url, self.file_length - 4)

    # Simple bytes download
    def test_download_bytes(self) -> None:
        """Download the file as bytes and compare to the known contents."""
        data = self.fetcher.download_bytes(self.url, self.file_length)
        self.assertEqual(self.file_contents, data)

    # Download file smaller than required max_length
    def test_download_bytes_upper_length(self) -> None:
        """A max_length larger than the file is acceptable."""
        data = self.fetcher.download_bytes(self.url, self.file_length + 4)
        self.assertEqual(self.file_contents, data)

    # Download a file bigger than expected
    def test_download_bytes_length_mismatch(self) -> None:
        """A max_length smaller than the file raises a mismatch error."""
        with self.assertRaises(exceptions.DownloadLengthMismatchError):
            self.fetcher.download_bytes(self.url, self.file_length - 4)

    # Simple file download
    def test_download_file(self) -> None:
        """Download into a temp file and verify its size."""
        with self.fetcher.download_file(
            self.url, self.file_length
        ) as temp_file:
            # Seek to end to measure the number of bytes written.
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    # Download file smaller than required max_length
    def test_download_file_upper_length(self) -> None:
        """A max_length larger than the file is acceptable."""
        with self.fetcher.download_file(
            self.url, self.file_length + 4
        ) as temp_file:
            temp_file.seek(0, io.SEEK_END)
            self.assertEqual(self.file_length, temp_file.tell())

    # Download a file bigger than expected
    def test_download_file_length_mismatch(self) -> None:
        """A max_length smaller than the file raises a mismatch error."""
        with self.assertRaises(
            exceptions.DownloadLengthMismatchError
        ), self.fetcher.download_file(self.url, self.file_length - 4):
            pass  # we never get here as download_file() raises
# Run unit test.
if __name__ == "__main__":
    # Configure test-suite logging from CLI args before running.
    utils.configure_test_logging(sys.argv)
    unittest.main()