Skip to content

Commit 5fd5b6e

Browse files
committed
refactor: SupportsSyncCodec is generic, like BaseCodec
1 parent 9d01432 commit 5fd5b6e

2 files changed

Lines changed: 45 additions & 38 deletions

File tree

src/zarr/abc/codec.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,19 @@ def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]:
6767

6868

6969
@runtime_checkable
70-
class SupportsSyncCodec(Protocol):
70+
class SupportsSyncCodec[CI: CodecInput, CO: CodecOutput](Protocol):
7171
"""Protocol for codecs that support synchronous encode/decode.
7272
73-
Codecs implementing this protocol provide ``_decode_sync`` and ``_encode_sync``
73+
Codecs implementing this protocol provide `_decode_sync` and `_encode_sync`
7474
methods that perform encoding/decoding without requiring an async event loop.
75+
76+
The type parameters mirror `BaseCodec`: `CI` is the decoded type and `CO` is
77+
the encoded type.
7578
"""
7679

77-
def _decode_sync(
78-
self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec
79-
) -> NDBuffer | Buffer: ...
80+
def _decode_sync(self, chunk_data: CO, chunk_spec: ArraySpec) -> CI: ...
8081

81-
def _encode_sync(
82-
self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec
83-
) -> NDBuffer | Buffer | None: ...
82+
def _encode_sync(self, chunk_data: CI, chunk_spec: ArraySpec) -> CO | None: ...
8483

8584

8685
class BaseCodec[CI: CodecInput, CO: CodecOutput](Metadata):

src/zarr/core/codec_pipeline.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -71,27 +71,29 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any:
7171
class ChunkTransform:
7272
"""A synchronous codec chain bound to an ArraySpec.
7373
74-
Provides ``encode_chunk`` and ``decode_chunk`` for pure-compute
75-
codec operations (no IO, no threading, no batching).
74+
Provides `encode` and `decode` for pure-compute codec operations
75+
(no IO, no threading, no batching).
7676
77-
``shape`` and ``dtype`` reflect the representation **after** all
78-
ArrayArrayCodec transforms i.e. the spec that feeds the
77+
`shape` and `dtype` reflect the representation **after** all
78+
ArrayArrayCodec transforms -- i.e. the spec that feeds the
7979
ArrayBytesCodec.
8080
81-
All codecs must implement ``SupportsSyncCodec``. Construction will
82-
raise ``TypeError`` if any codec does not.
81+
All codecs must implement `SupportsSyncCodec`. Construction will
82+
raise `TypeError` if any codec does not.
8383
"""
8484

8585
codecs: tuple[Codec, ...]
8686
array_spec: ArraySpec
8787

88-
# (ArrayArrayCodec, input_spec) pairs in pipeline order.
89-
_aa_codecs: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = field(
88+
# (sync codec, input_spec) pairs in pipeline order.
89+
_aa_codecs: tuple[tuple[SupportsSyncCodec[NDBuffer, NDBuffer], ArraySpec], ...] = field(
9090
init=False, repr=False, compare=False
9191
)
92-
_ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False)
92+
_ab_codec: SupportsSyncCodec[NDBuffer, Buffer] = field(init=False, repr=False, compare=False)
9393
_ab_spec: ArraySpec = field(init=False, repr=False, compare=False)
94-
_bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False)
94+
_bb_codecs: tuple[SupportsSyncCodec[Buffer, Buffer], ...] = field(
95+
init=False, repr=False, compare=False
96+
)
9597

9698
def __post_init__(self) -> None:
9799
non_sync = [c for c in self.codecs if not isinstance(c, SupportsSyncCodec)]
@@ -103,16 +105,22 @@ def __post_init__(self) -> None:
103105

104106
aa, ab, bb = codecs_from_list(list(self.codecs))
105107

106-
aa_codecs: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = ()
108+
aa_codecs: list[tuple[SupportsSyncCodec[NDBuffer, NDBuffer], ArraySpec]] = []
107109
spec = self.array_spec
108110
for aa_codec in aa:
109-
aa_codecs = (*aa_codecs, (aa_codec, spec))
111+
assert isinstance(aa_codec, SupportsSyncCodec)
112+
aa_codecs.append((aa_codec, spec))
110113
spec = aa_codec.resolve_metadata(spec)
111114

112-
self._aa_codecs = aa_codecs
115+
self._aa_codecs = tuple(aa_codecs)
116+
assert isinstance(ab, SupportsSyncCodec)
113117
self._ab_codec = ab
114118
self._ab_spec = spec
115-
self._bb_codecs = bb
119+
bb_sync: list[SupportsSyncCodec[Buffer, Buffer]] = []
120+
for bb_codec in bb:
121+
assert isinstance(bb_codec, SupportsSyncCodec)
122+
bb_sync.append(bb_codec)
123+
self._bb_codecs = tuple(bb_sync)
116124

117125
@property
118126
def shape(self) -> tuple[int, ...]:
@@ -132,18 +140,16 @@ def decode(
132140
133141
Pure compute -- no IO.
134142
"""
135-
# All codecs are verified to implement SupportsSyncCodec in __post_init__,
136-
# but the stored types (ArrayArrayCodec, etc.) don't reflect this statically.
137-
bb_out: Any = chunk_bytes
143+
data: Buffer = chunk_bytes
138144
for bb_codec in reversed(self._bb_codecs):
139-
bb_out = bb_codec._decode_sync(bb_out, self._ab_spec) # type: ignore[attr-defined]
145+
data = bb_codec._decode_sync(data, self._ab_spec)
140146

141-
ab_out: Any = self._ab_codec._decode_sync(bb_out, self._ab_spec) # type: ignore[attr-defined]
147+
chunk_array: NDBuffer = self._ab_codec._decode_sync(data, self._ab_spec)
142148

143149
for aa_codec, spec in reversed(self._aa_codecs):
144-
ab_out = aa_codec._decode_sync(ab_out, spec) # type: ignore[attr-defined]
150+
chunk_array = aa_codec._decode_sync(chunk_array, spec)
145151

146-
return ab_out # type: ignore[no-any-return]
152+
return chunk_array
147153

148154
def encode(
149155
self,
@@ -153,23 +159,25 @@ def encode(
153159
154160
Pure compute -- no IO.
155161
"""
156-
aa_out: Any = chunk_array
157-
162+
aa_data: NDBuffer = chunk_array
158163
for aa_codec, spec in self._aa_codecs:
159-
if aa_out is None:
164+
aa_result = aa_codec._encode_sync(aa_data, spec)
165+
if aa_result is None:
160166
return None
161-
aa_out = aa_codec._encode_sync(aa_out, spec) # type: ignore[attr-defined]
167+
aa_data = aa_result
162168

163-
if aa_out is None:
169+
ab_result = self._ab_codec._encode_sync(aa_data, self._ab_spec)
170+
if ab_result is None:
164171
return None
165-
bb_out: Any = self._ab_codec._encode_sync(aa_out, self._ab_spec) # type: ignore[attr-defined]
166172

173+
bb_data: Buffer = ab_result
167174
for bb_codec in self._bb_codecs:
168-
if bb_out is None:
175+
bb_result = bb_codec._encode_sync(bb_data, self._ab_spec)
176+
if bb_result is None:
169177
return None
170-
bb_out = bb_codec._encode_sync(bb_out, self._ab_spec) # type: ignore[attr-defined]
178+
bb_data = bb_result
171179

172-
return bb_out # type: ignore[no-any-return]
180+
return bb_data
173181

174182
def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
175183
for codec in self.codecs:

0 commit comments

Comments
 (0)