Skip to content

Commit 166b2b7

Browse files
committed
feat: add zstandard compression support
1 parent 4fdb86c commit 166b2b7

6 files changed

Lines changed: 353 additions & 82 deletions

File tree

poetry.lock

Lines changed: 150 additions & 43 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ python = ">=3.9,<4.0"
3030
aiohttp = ">=3.10.0"
3131
python-dateutil = "^2.8.2"
3232
aiofiles = "^24.1.0"
33+
zstandard = ">=0.19.0"
3334

3435
[tool.poetry.group.dev.dependencies]
3536
pytest = "^8.3.5"

tardis_dev/_http.py

Lines changed: 32 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,11 @@
1414
logger = logging.getLogger(__name__)
1515

1616

17-
async def create_session(api_key: str, timeout: int) -> aiohttp.ClientSession:
17+
async def create_session(api_key: str, timeout: int, accept_encoding: str = "gzip") -> aiohttp.ClientSession:
1818
from tardis_dev import __version__
1919

2020
headers = {
21-
"Accept-Encoding": "gzip",
21+
"Accept-Encoding": accept_encoding,
2222
"User-Agent": f"tardis-dev/{__version__} (+https://github.com/tardis-dev/tardis-python)",
2323
}
2424
if api_key:
@@ -38,14 +38,20 @@ async def reliable_download(
3838
dest_path: str,
3939
http_proxy: Optional[str] = None,
4040
max_attempts: int = 30,
41-
) -> None:
41+
append_content_encoding_extension: bool = False,
42+
) -> str:
4243
attempts = 0
4344

4445
while True:
4546
attempts += 1
4647
try:
47-
await _download(session, _get_retry_url(url, attempts), dest_path, http_proxy)
48-
return
48+
return await _download(
49+
session,
50+
_get_retry_url(url, attempts),
51+
dest_path,
52+
http_proxy,
53+
append_content_encoding_extension=append_content_encoding_extension,
54+
)
4955
except asyncio.CancelledError:
5056
raise
5157
except Exception as exc:
@@ -99,12 +105,30 @@ async def _download(
99105
url: str,
100106
dest_path: str,
101107
http_proxy: Optional[str],
102-
) -> None:
108+
*,
109+
append_content_encoding_extension: bool,
110+
) -> str:
103111
async with session.get(url, proxy=http_proxy) as response:
104112
if response.status != 200:
105113
error_text = await response.text()
106114
raise urllib.error.HTTPError(url, code=response.status, msg=error_text, hdrs=None, fp=None)
107115

116+
final_path = dest_path
117+
if append_content_encoding_extension:
118+
content_encoding = response.headers.get("Content-Encoding")
119+
if content_encoding == "zstd":
120+
final_path = f"{dest_path}.zst"
121+
elif content_encoding == "gzip":
122+
final_path = f"{dest_path}.gz"
123+
else:
124+
raise urllib.error.HTTPError(
125+
url,
126+
code=400,
127+
msg=f"Unsupported data feed content encoding: {content_encoding}",
128+
hdrs=None,
129+
fp=None,
130+
)
131+
108132
pathlib.Path(dest_path).parent.mkdir(parents=True, exist_ok=True)
109133
temp_path = f"{dest_path}{secrets.token_hex(8)}.unconfirmed"
110134

@@ -113,7 +137,8 @@ async def _download(
113137
async for chunk in response.content.iter_any():
114138
await temp_file.write(chunk)
115139

116-
os.replace(temp_path, dest_path)
140+
os.replace(temp_path, final_path)
141+
return final_path
117142
finally:
118143
if os.path.exists(temp_path):
119144
os.remove(temp_path)

tardis_dev/replay.py

Lines changed: 67 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import gzip
33
import hashlib
4+
import io
45
import json as json_module
56
import logging
67
import os
@@ -11,6 +12,7 @@
1112
from typing import Any, AsyncIterator, Dict, List, NamedTuple, Optional, Sequence, Union
1213

1314
import dateutil.parser
15+
import zstandard
1416

1517
from tardis_dev._http import create_session, reliable_download
1618
from tardis_dev._options import DEFAULT_CACHE_DIR, DEFAULT_ENDPOINT
@@ -84,7 +86,15 @@ async def replay(
8486

8587
while current_slice_path is None:
8688
await asyncio.sleep(0)
87-
path_to_check = _get_slice_cache_path(
89+
zstd_path = _get_slice_cache_path(
90+
cache_dir,
91+
exchange,
92+
current_slice_date,
93+
normalized_filters,
94+
filters_hash=filters_hash,
95+
content_encoding="zstd",
96+
)
97+
gzip_path = _get_slice_cache_path(
8898
cache_dir,
8999
exchange,
90100
current_slice_date,
@@ -95,26 +105,27 @@ async def replay(
95105
if fetch_data_task.done() and fetch_data_task.exception():
96106
raise fetch_data_task.exception()
97107

98-
if os.path.isfile(path_to_check):
99-
current_slice_path = path_to_check
108+
if os.path.isfile(zstd_path):
109+
current_slice_path = zstd_path
110+
elif os.path.isfile(gzip_path):
111+
current_slice_path = gzip_path
100112
else:
101113
await asyncio.sleep(0.1)
102114

103-
with gzip.open(current_slice_path, "rb") as file:
104-
for line in file:
105-
if len(line) <= 1:
106-
if with_disconnects and not last_message_was_disconnect:
107-
last_message_was_disconnect = True
108-
yield None
109-
continue
115+
for line in _iterate_slice_lines(current_slice_path):
116+
if len(line) <= 1:
117+
if with_disconnects and not last_message_was_disconnect:
118+
last_message_was_disconnect = True
119+
yield None
120+
continue
110121

111-
last_message_was_disconnect = False
122+
last_message_was_disconnect = False
112123

113-
if decode_response:
114-
timestamp = datetime.fromisoformat(line[0 : DATE_MESSAGE_SPLIT_INDEX - 2].decode("utf-8"))
115-
yield Response(timestamp, json.loads(line[DATE_MESSAGE_SPLIT_INDEX + 1 :]))
116-
else:
117-
yield Response(line[0:DATE_MESSAGE_SPLIT_INDEX], line[DATE_MESSAGE_SPLIT_INDEX + 1 :])
124+
if decode_response:
125+
timestamp = datetime.fromisoformat(line[0 : DATE_MESSAGE_SPLIT_INDEX - 2].decode("utf-8"))
126+
yield Response(timestamp, json.loads(line[DATE_MESSAGE_SPLIT_INDEX + 1 :]))
127+
else:
128+
yield Response(line[0:DATE_MESSAGE_SPLIT_INDEX], line[DATE_MESSAGE_SPLIT_INDEX + 1 :])
118129

119130
if auto_cleanup:
120131
_remove_processed_slice(current_slice_path)
@@ -167,7 +178,7 @@ async def _fetch_data_to_replay(
167178
if minutes_diff <= 0:
168179
return
169180

170-
async with await create_session(api_key, timeout) as session:
181+
async with await create_session(api_key, timeout, "zstd, gzip") as session:
171182
fetch_data_tasks = set()
172183
try:
173184
prefetch_offsets = [minutes_diff - 1]
@@ -231,9 +242,23 @@ async def _fetch_slice_if_not_cached(
231242
filters_hash: str,
232243
) -> None:
233244
slice_date = from_date + timedelta(minutes=offset)
234-
cache_path = _get_slice_cache_path(cache_dir, exchange, slice_date, filters, filters_hash=filters_hash)
245+
cache_zstd_path = _get_slice_cache_path(
246+
cache_dir,
247+
exchange,
248+
slice_date,
249+
filters,
250+
filters_hash=filters_hash,
251+
content_encoding="zstd",
252+
)
253+
cache_gzip_path = _get_slice_cache_path(
254+
cache_dir,
255+
exchange,
256+
slice_date,
257+
filters,
258+
filters_hash=filters_hash,
259+
)
235260

236-
if os.path.isfile(cache_path):
261+
if os.path.isfile(cache_zstd_path) or os.path.isfile(cache_gzip_path):
237262
return
238263

239264
fetch_url = f"{endpoint}/data-feeds/{exchange}?from={_format_replay_query_date(from_date)}&offset={offset}"
@@ -242,11 +267,14 @@ async def _fetch_slice_if_not_cached(
242267
filters_url_encoded = urllib.parse.quote(filters_serialized, safe="~()*!.'")
243268
fetch_url += f"&filters={filters_url_encoded}"
244269

270+
cache_base_path = cache_gzip_path.removesuffix(".gz")
271+
245272
await reliable_download(
246273
session=session,
247274
url=fetch_url,
248-
dest_path=cache_path,
275+
dest_path=cache_base_path,
249276
http_proxy=http_proxy,
277+
append_content_encoding_extension=True,
250278
)
251279

252280

@@ -341,16 +369,14 @@ def _get_slice_cache_path(
341369
filters: Optional[Sequence[Channel]],
342370
*,
343371
filters_hash: Optional[str] = None,
372+
content_encoding: Optional[str] = None,
344373
) -> str:
345-
return (
346-
os.path.join(
347-
cache_dir,
348-
"feeds",
349-
exchange,
350-
filters_hash if filters_hash is not None else _get_filters_hash(filters),
351-
_format_date_to_path(date),
352-
)
353-
+ ".json.gz"
374+
return os.path.join(
375+
cache_dir,
376+
"feeds",
377+
exchange,
378+
filters_hash if filters_hash is not None else _get_filters_hash(filters),
379+
f"{_format_date_to_path(date)}.json{'.zst' if content_encoding == 'zstd' else '.gz'}",
354380
)
355381

356382

@@ -405,6 +431,18 @@ def _remove_processed_slice(path: str) -> None:
405431
os.remove(path)
406432

407433

434+
def _iterate_slice_lines(path: str):
435+
if path.endswith(".zst"):
436+
with open(path, "rb") as compressed_file:
437+
with zstandard.ZstdDecompressor().stream_reader(compressed_file) as file:
438+
with io.BufferedReader(file) as buffered_file:
439+
yield from buffered_file
440+
return
441+
442+
with gzip.open(path, "rb") as file:
443+
yield from file
444+
445+
408446
def _clear_replay_cache_range(
409447
*,
410448
cache_dir: str,

tests/test_http.py

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,26 @@
88

99

1010
@pytest.mark.asyncio
11-
async def test_create_session_omits_authorization_header_when_api_key_missing():
12-
async with await create_session("", 5) as session:
11+
async def test_create_session_uses_requested_accept_encoding_and_omits_authorization_header_when_api_key_missing():
12+
async with await create_session("", 5, "zstd, gzip") as session:
1313
assert "Authorization" not in session.headers
14+
assert session.headers["Accept-Encoding"] == "zstd, gzip"
15+
16+
17+
@pytest.mark.asyncio
18+
async def test_reliable_download_appends_zstd_extension_for_replay_cache(tmp_path: Path):
19+
destination = tmp_path / "slice.json"
20+
url = "https://example.com/data"
21+
22+
with aioresponses() as mocked:
23+
mocked.get(url, body=b"payload", headers={"Content-Encoding": "zstd"})
24+
25+
async with await create_session("", 5) as session:
26+
final_path = await reliable_download(session, url, str(destination), append_content_encoding_extension=True)
27+
28+
assert final_path.endswith(".zst")
29+
assert Path(final_path).read_bytes() == b"payload"
30+
assert not destination.exists()
1431

1532

1633
@pytest.mark.asyncio

tests/test_replay.py

Lines changed: 84 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55

66
import pytest
7+
import zstandard
78

89
from tardis_dev import Channel, replay
910
from tardis_dev.replay import (
@@ -113,6 +114,48 @@ def test_replay_cache_path_uses_normalized_filter_hash():
113114
]
114115

115116

117+
@pytest.mark.asyncio
118+
async def test_replay_prefers_zstd_cache_path_when_available(monkeypatch, tmp_path: Path):
119+
filters = _live_replay_filters()
120+
gzip_path = Path(_get_slice_cache_path(str(tmp_path), LIVE_REPLAY_EXCHANGE, datetime(2019, 5, 1, 0, 0), filters))
121+
zstd_path = Path(
122+
_get_slice_cache_path(
123+
str(tmp_path),
124+
LIVE_REPLAY_EXCHANGE,
125+
datetime(2019, 5, 1, 0, 0),
126+
filters,
127+
content_encoding="zstd",
128+
)
129+
)
130+
gzip_path.parent.mkdir(parents=True, exist_ok=True)
131+
with gzip.open(gzip_path, "wb") as file:
132+
file.write(b'2019-05-01T00:00:00.0000000Z {"table":"trade","source":"gzip"}\n')
133+
zstd_path.write_bytes(
134+
zstandard.ZstdCompressor().compress(
135+
b'2019-05-01T00:00:00.0000000Z {"table":"trade","source":"zstd"}\n'
136+
)
137+
)
138+
139+
async def fake_fetch_data_to_replay(**kwargs):
140+
return None
141+
142+
monkeypatch.setattr(replay_module, "_fetch_data_to_replay", fake_fetch_data_to_replay)
143+
144+
results = []
145+
async for item in replay(
146+
exchange=LIVE_REPLAY_EXCHANGE,
147+
from_date=LIVE_REPLAY_FROM,
148+
to_date=LIVE_REPLAY_TO,
149+
filters=filters,
150+
cache_dir=str(tmp_path),
151+
):
152+
results.append(item)
153+
154+
assert len(results) == 1
155+
assert results[0] is not None
156+
assert results[0].message["source"] == "zstd"
157+
158+
116159
def test_replay_rejects_invalid_filter_items():
117160
async def collect():
118161
async for _ in replay(exchange="bitmex", from_date="2019-06-01", to_date="2019-06-02", filters=["bad"]):
@@ -245,11 +288,51 @@ async def fake_fetch_data_to_replay(**kwargs):
245288
assert results[0].message == b'{"table":"trade","action":"partial","data":[{"symbol":"BTCUSD"}]}\n'
246289

247290

291+
@pytest.mark.asyncio
292+
async def test_replay_reads_zstd_cached_slice(monkeypatch, tmp_path: Path):
293+
cache_dir = tmp_path / "cache"
294+
filters = _live_replay_filters()
295+
slice_path = Path(
296+
_get_slice_cache_path(
297+
str(cache_dir),
298+
LIVE_REPLAY_EXCHANGE,
299+
datetime(2019, 5, 1, 0, 0),
300+
filters,
301+
content_encoding="zstd",
302+
)
303+
)
304+
slice_path.parent.mkdir(parents=True, exist_ok=True)
305+
slice_path.write_bytes(
306+
zstandard.ZstdCompressor().compress(
307+
b'2019-05-01T00:00:00.0000000Z {"table":"trade","action":"partial","data":[{"symbol":"BTCUSD"}]}\n'
308+
)
309+
)
310+
311+
async def fake_fetch_data_to_replay(**kwargs):
312+
return None
313+
314+
monkeypatch.setattr(replay_module, "_fetch_data_to_replay", fake_fetch_data_to_replay)
315+
316+
results = []
317+
async for item in replay(
318+
exchange=LIVE_REPLAY_EXCHANGE,
319+
from_date=LIVE_REPLAY_FROM,
320+
to_date=LIVE_REPLAY_TO,
321+
filters=filters,
322+
cache_dir=str(cache_dir),
323+
):
324+
results.append(item)
325+
326+
assert len(results) == 1
327+
assert results[0] is not None
328+
assert results[0].message["table"] == "trade"
329+
330+
248331
@pytest.mark.asyncio
249332
async def test_fetch_data_to_replay_prefetches_last_then_first(monkeypatch):
250333
offsets = []
251334

252-
async def fake_create_session(api_key: str, timeout: int):
335+
async def fake_create_session(api_key: str, timeout: int, accept_encoding: str):
253336
return _FakeSession()
254337

255338
async def fake_fetch_slice_if_not_cached(**kwargs):

0 commit comments

Comments
 (0)