Skip to content

Commit b8daec0

Browse files
committed
zarr3: Add original_is_structured flag for void access
For void access, the codec handling differs between: - Non-structured types: codec prepared for [chunk_shape] with original dtype Need to decode/encode then reinterpret bytes. - Structured types: codec already prepared for [chunk_shape, bytes_per_elem] with byte dtype. Just decode/encode directly. Add original_is_structured parameter to cache constructors to properly distinguish these cases in DecodeChunk and EncodeChunk. This follows the pattern from zarr v2 (PR google#272) where CreateVoidMetadata() creates a modified metadata for void access.
1 parent 7065b42 commit b8daec0

3 files changed

Lines changed: 56 additions & 25 deletions

File tree

tensorstore/driver/zarr3/chunk_cache.cc

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,12 @@ ZarrChunkCache::~ZarrChunkCache() = default;
7676
ZarrLeafChunkCache::ZarrLeafChunkCache(
7777
kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state,
7878
ZarrDType dtype, internal::CachePool::WeakPtr /*data_cache_pool*/,
79-
bool open_as_void)
79+
bool open_as_void, bool original_is_structured)
8080
: Base(std::move(store)),
8181
codec_state_(std::move(codec_state)),
8282
dtype_(std::move(dtype)),
83-
open_as_void_(open_as_void) {}
83+
open_as_void_(open_as_void),
84+
original_is_structured_(original_is_structured) {}
8485

8586
void ZarrLeafChunkCache::Read(ZarrChunkCache::ReadRequest request,
8687
AnyFlowReceiver<absl::Status, internal::ReadChunk,
@@ -158,15 +159,27 @@ ZarrLeafChunkCache::DecodeChunk(span<const Index> chunk_indices,
158159
const size_t num_fields = dtype_.fields.size();
159160
absl::InlinedVector<SharedArray<const void>, 1> field_arrays(num_fields);
160161

161-
// Special case: void access - decode using original codec shape, then
162-
// reinterpret as bytes with extra dimension.
162+
// Special case: void access - decode and return as bytes.
163163
//
164-
// The codec was prepared for the original dtype and chunk_shape (without
165-
// bytes dimension). We decode to that shape, then view the raw bytes with
166-
// an extra dimension representing the bytes per element.
164+
// For non-structured types: codec was prepared for [chunk_shape] with
165+
// original dtype. We decode to that shape then reinterpret as bytes.
166+
//
167+
// For structured types: codec was already prepared for
168+
// [chunk_shape, bytes_per_elem] with byte dtype. Just decode directly.
167169
if (open_as_void_) {
168-
// The grid's chunk_shape for void has extra bytes dimension - strip it
169-
// to get the original codec shape.
170+
const auto& void_component_shape = grid().components[0].shape();
171+
172+
if (original_is_structured_) {
173+
// Structured types: codec already expects bytes with extra dimension.
174+
// Just decode directly to the void component shape.
175+
TENSORSTORE_ASSIGN_OR_RETURN(
176+
field_arrays[0],
177+
codec_state_->DecodeArray(void_component_shape, std::move(data)));
178+
return field_arrays;
179+
}
180+
181+
// Non-structured types: codec expects original dtype without extra
182+
// dimension. Decode, then reinterpret as bytes.
170183
const auto& void_chunk_shape = grid().chunk_shape;
171184
std::vector<Index> original_chunk_shape(
172185
void_chunk_shape.begin(),
@@ -178,8 +191,6 @@ ZarrLeafChunkCache::DecodeChunk(span<const Index> chunk_indices,
178191
codec_state_->DecodeArray(original_chunk_shape, std::move(data)));
179192

180193
// Reinterpret the decoded array's bytes as [chunk_shape..., bytes_per_elem]
181-
// This creates a view over the same memory but with byte dtype and extra dim
182-
const auto& void_component_shape = grid().components[0].shape();
183194
auto byte_array = AllocateArray(
184195
void_component_shape, c_order, default_init,
185196
dtype_v<tensorstore::dtypes::byte_t>);
@@ -242,12 +253,20 @@ Result<absl::Cord> ZarrLeafChunkCache::EncodeChunk(
242253
span<const SharedArray<const void>> component_arrays) {
243254
assert(component_arrays.size() == 1);
244255

245-
// Special case: void access - reinterpret byte array back to original
246-
// dtype shape before encoding.
256+
// Special case: void access - encode bytes back to original format.
257+
//
258+
// For structured types: codec already expects bytes with extra dimension.
259+
// Just encode directly.
247260
//
248-
// The input has shape [chunk_shape..., bytes_per_elem] of byte_t.
249-
// The codec expects [chunk_shape] of the original dtype.
261+
// For non-structured types: reinterpret byte array as original dtype
262+
// and shape before encoding.
250263
if (open_as_void_) {
264+
if (original_is_structured_) {
265+
// Structured types: codec already expects bytes with extra dimension.
266+
return codec_state_->EncodeArray(component_arrays[0]);
267+
}
268+
269+
// Non-structured types: reinterpret bytes as original dtype/shape.
251270
const auto& byte_array = component_arrays[0];
252271
const Index bytes_per_element = dtype_.bytes_per_outer_element;
253272

@@ -256,7 +275,6 @@ Result<absl::Cord> ZarrLeafChunkCache::EncodeChunk(
256275
std::vector<Index> original_shape(void_shape.begin(), void_shape.end() - 1);
257276

258277
// Create a view over the byte data with original layout
259-
// The codec expects the original dtype's element size for stride calculation
260278
auto encoded_array = SharedArray<const void>(
261279
byte_array.element_pointer(),
262280
StridedLayout<>(c_order, bytes_per_element, original_shape));
@@ -274,12 +292,13 @@ kvstore::Driver* ZarrLeafChunkCache::GetKvStoreDriver() {
274292
ZarrShardedChunkCache::ZarrShardedChunkCache(
275293
kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state,
276294
ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool,
277-
bool open_as_void)
295+
bool open_as_void, bool original_is_structured)
278296
: base_kvstore_(std::move(store)),
279297
codec_state_(std::move(codec_state)),
280298
dtype_(std::move(dtype)),
281-
data_cache_pool_(std::move(data_cache_pool)),
282-
open_as_void_(open_as_void) {}
299+
open_as_void_(open_as_void),
300+
original_is_structured_(original_is_structured),
301+
data_cache_pool_(std::move(data_cache_pool)) {}
283302

284303
Result<IndexTransform<>> TranslateCellToSourceTransformForShard(
285304
IndexTransform<> transform, span<const Index> grid_cell_indices,
@@ -588,7 +607,8 @@ void ZarrShardedChunkCache::Entry::DoInitialize() {
588607
*sharding_state.sub_chunk_codec_chain,
589608
std::move(sharding_kvstore), cache.executor(),
590609
ZarrShardingCodec::PreparedState::Ptr(&sharding_state),
591-
cache.dtype_, cache.data_cache_pool_, cache.open_as_void_);
610+
cache.dtype_, cache.data_cache_pool_, cache.open_as_void_,
611+
cache.original_is_structured_);
592612
zarr_chunk_cache = new_cache.release();
593613
return std::unique_ptr<internal::Cache>(&zarr_chunk_cache->cache());
594614
})

tensorstore/driver/zarr3/chunk_cache.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ class ZarrLeafChunkCache : public internal::KvsBackedChunkCache,
159159
ZarrCodecChain::PreparedState::Ptr codec_state,
160160
ZarrDType dtype,
161161
internal::CachePool::WeakPtr data_cache_pool,
162-
bool open_as_void);
162+
bool open_as_void,
163+
bool original_is_structured);
163164

164165
void Read(ZarrChunkCache::ReadRequest request,
165166
AnyFlowReceiver<absl::Status, internal::ReadChunk,
@@ -188,6 +189,7 @@ class ZarrLeafChunkCache : public internal::KvsBackedChunkCache,
188189
ZarrCodecChain::PreparedState::Ptr codec_state_;
189190
ZarrDType dtype_;
190191
bool open_as_void_;
192+
bool original_is_structured_;
191193
};
192194

193195
/// Chunk cache for a Zarr array where each chunk is a shard.
@@ -199,7 +201,8 @@ class ZarrShardedChunkCache : public internal::Cache, public ZarrChunkCache {
199201
ZarrCodecChain::PreparedState::Ptr codec_state,
200202
ZarrDType dtype,
201203
internal::CachePool::WeakPtr data_cache_pool,
202-
bool open_as_void);
204+
bool open_as_void,
205+
bool original_is_structured);
203206

204207
const ZarrShardingCodec::PreparedState& sharding_codec_state() const {
205208
return static_cast<const ZarrShardingCodec::PreparedState&>(
@@ -250,6 +253,7 @@ class ZarrShardedChunkCache : public internal::Cache, public ZarrChunkCache {
250253
ZarrCodecChain::PreparedState::Ptr codec_state_;
251254
ZarrDType dtype_;
252255
bool open_as_void_;
256+
bool original_is_structured_;
253257

254258
// Data cache pool, if it differs from `this->pool()` (which is equal to the
255259
// metadata cache pool).
@@ -265,12 +269,12 @@ class ZarrShardSubChunkCache : public ChunkCacheImpl {
265269
kvstore::DriverPtr store, Executor executor,
266270
ZarrShardingCodec::PreparedState::Ptr sharding_state,
267271
ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool,
268-
bool open_as_void)
272+
bool open_as_void, bool original_is_structured)
269273
: ChunkCacheImpl(std::move(store),
270274
ZarrCodecChain::PreparedState::Ptr(
271275
sharding_state->sub_chunk_codec_state),
272276
std::move(dtype), std::move(data_cache_pool),
273-
open_as_void),
277+
open_as_void, original_is_structured),
274278
sharding_state_(std::move(sharding_state)),
275279
executor_(std::move(executor)) {}
276280

tensorstore/driver/zarr3/driver.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,11 +913,18 @@ class ZarrDriver::OpenState : public ZarrDriver::OpenStateBase {
913913
/*.num_bytes=*/metadata.data_type.bytes_per_outer_element}},
914914
/*.bytes_per_outer_element=*/metadata.data_type.bytes_per_outer_element};
915915
}
916+
// Determine if original dtype is structured (multiple fields or field with
917+
// outer_shape). This affects how void access handles codec operations.
918+
const bool original_is_structured =
919+
metadata.data_type.fields.size() > 1 ||
920+
(metadata.data_type.fields.size() == 1 &&
921+
!metadata.data_type.fields[0].outer_shape.empty());
922+
916923
return internal_zarr3::MakeZarrChunkCache<DataCacheBase, ZarrDataCache>(
917924
*metadata.codecs, std::move(initializer), spec().store.path,
918925
metadata.codec_state, dtype,
919926
/*data_cache_pool=*/*cache_pool(),
920-
spec().open_as_void);
927+
spec().open_as_void, original_is_structured);
921928
}
922929

923930
Result<size_t> GetComponentIndex(const void* metadata_ptr,

0 commit comments

Comments
 (0)