22import lzma
33import sys
44import zlib
5- from dataclasses import dataclass
5+ from collections .abc import AsyncIterator
6+ from dataclasses import dataclass , field
67from enum import StrEnum
78from typing import Any , Callable , Mapping
89
@@ -22,6 +23,55 @@ class Compression(StrEnum):
2223 ZSTD = "zstd"
2324
2425
@dataclass(frozen=True)
class FileSignature:
    """Pairing of a magic-byte prefix with the compression format it marks.

    Instances are immutable (frozen dataclass).
    """

    # Leading bytes expected at the very start of a payload in this format.
    signature: bytes
    # Format reported when a payload begins with ``signature``.
    compression: Compression
32+
33+
# Magic-byte prefixes used to sniff a stream's compression format.
# Entries are tried in order; the escapes must be contiguous (no padding
# bytes) so that they match the real on-disk headers.
# Reference: https://file-extension.net/seeker/
COMPRESSION_SIGNATURES: tuple[FileSignature, ...] = (
    # gzip: ID1, ID2, then CM=8 (deflate)
    FileSignature(b"\x1f\x8b\x08", Compression.GZIP),
    # xz: 0xFD, "7zXZ", 0x00
    FileSignature(b"\xfd\x37\x7a\x58\x5a\x00", Compression.XZ),
    # bzip2: "BZh"
    FileSignature(b"\x42\x5a\x68", Compression.BZ2),
    # zstd: frame magic number 0xFD2FB528, little-endian
    FileSignature(b"\x28\xb5\x2f\xfd", Compression.ZSTD),
)

# Number of leading bytes to buffer before sniffing. It must be at least as
# long as the longest signature above (6 bytes for xz).
SIGNATURE_BUFFER_SIZE = 8
45+
46+
def detect_compression_from_signature(data: bytes) -> Compression | None:
    """Identify the compression format whose magic bytes prefix ``data``.

    Args:
        data: Leading bytes of the file/stream; supplying at least
            SIGNATURE_BUFFER_SIZE bytes is recommended.

    Returns:
        The Compression of the first COMPRESSION_SIGNATURES entry whose
        signature ``data`` starts with, or None when no entry matches
        (uncompressed or unknown content).
    """
    matching = (
        entry.compression
        for entry in COMPRESSION_SIGNATURES
        if data.startswith(entry.signature)
    )
    # The generator is lazy, so scanning stops at the first matching entry.
    return next(matching, None)
60+
61+
62+ def create_decompressor (compression : Compression ) -> Any :
63+ """Create a decompressor object for the given compression type."""
64+ match compression :
65+ case Compression .GZIP :
66+ return zlib .decompressobj (wbits = 47 ) # Auto-detect gzip/zlib
67+ case Compression .XZ :
68+ return lzma .LZMADecompressor ()
69+ case Compression .BZ2 :
70+ return bz2 .BZ2Decompressor ()
71+ case Compression .ZSTD :
72+ return zstd .ZstdDecompressor ()
73+
74+
2575@dataclass (kw_only = True )
2676class CompressedStream (ObjectStream [bytes ]):
2777 stream : AnyByteStream
@@ -99,3 +149,68 @@ def compress_stream(stream: AnyByteStream, compression: Compression | None) -> A
99149 compressor = zstd .ZstdCompressor (),
100150 decompressor = zstd .ZstdDecompressor (),
101151 )
152+
153+
@dataclass(kw_only=True)
class AutoDecompressIterator(AsyncIterator[bytes]):
    """An async iterator that auto-detects and decompresses compressed data.

    This wraps an async iterator of bytes. On first use it buffers chunks
    from ``source`` until at least SIGNATURE_BUFFER_SIZE bytes are available
    (or the source ends), matches them against the known file signatures and,
    if one is recognised, feeds all data through the matching decompressor.
    Data with no recognised signature passes through unchanged.

    When decompressing, input chunks that do not yet produce any output are
    skipped instead of being yielded as ``b""``, so consumers never see an
    empty chunk from the decompression path.

    NOTE(review): a single decompressor is created per iterator and is never
    flushed or replaced, so how concatenated streams or trailing bytes are
    treated depends on the underlying decompressor -- confirm against
    callers' needs.
    """

    source: AsyncIterator[bytes]
    # Decompressor chosen by signature detection; None means passthrough.
    _decompressor: Any = field(init=False, default=None)
    # True once signature detection has run (it runs at most once).
    _detected: bool = field(init=False, default=False)
    # Bytes consumed from ``source`` during detection, not yet emitted.
    _buffer: bytes = field(init=False, default=b"")
    # True once ``source`` has raised StopAsyncIteration.
    _exhausted: bool = field(init=False, default=False)

    async def _next_raw(self) -> bytes | None:
        """Return the next raw chunk from ``source``, or None once it ends."""
        if self._exhausted:
            return None
        try:
            return await self.source.__anext__()
        except StopAsyncIteration:
            # Remember exhaustion so the source is never polled again.
            self._exhausted = True
            return None

    async def _detect_compression(self) -> None:
        """Read enough bytes to detect the compression format."""
        # Buffer data until there is enough for detection or the source ends.
        while len(self._buffer) < SIGNATURE_BUFFER_SIZE:
            chunk = await self._next_raw()
            if chunk is None:
                break
            self._buffer += chunk

        compression = detect_compression_from_signature(self._buffer)
        if compression is not None:
            self._decompressor = create_decompressor(compression)

        self._detected = True

    async def __anext__(self) -> bytes:
        """Return the next chunk of (decompressed) data.

        Raises:
            StopAsyncIteration: When ``source`` is exhausted and no buffered
                data remains.
        """
        # First call: detect compression format.
        if not self._detected:
            await self._detect_compression()

        while True:
            if self._buffer:
                # Emit the bytes consumed during detection first.
                raw, self._buffer = self._buffer, b""
            else:
                raw = await self._next_raw()
                if raw is None:
                    raise StopAsyncIteration

            if self._decompressor is None:
                # Passthrough: hand the chunk back exactly as received.
                return raw

            decoded = self._decompressor.decompress(raw)
            if decoded:
                return decoded
            # The decompressor needs more input before it can produce
            # output; keep reading rather than yielding an empty chunk.

    def __aiter__(self) -> AsyncIterator[bytes]:
        return self
0 commit comments