Skip to content

Commit e197bc6

Browse files
authored
Merge pull request #7 from BrianMichell/v3_open_as_void_validation
Apply changes based on feedback from google#272
2 parents aad0ee0 + 1298bcb commit e197bc6

6 files changed

Lines changed: 1303 additions & 38 deletions

File tree

tensorstore/driver/zarr3/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ tensorstore_cc_library(
221221
srcs = ["chunk_cache.cc"],
222222
hdrs = ["chunk_cache.h"],
223223
deps = [
224+
":metadata",
224225
"//tensorstore:array",
225226
"//tensorstore:array_storage_statistics",
226227
"//tensorstore:batch",

tensorstore/driver/zarr3/chunk_cache.cc

Lines changed: 128 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,13 @@ ZarrChunkCache::~ZarrChunkCache() = default;
7676
ZarrLeafChunkCache::ZarrLeafChunkCache(
7777
kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state,
7878
ZarrDType dtype, internal::CachePool::WeakPtr /*data_cache_pool*/,
79-
bool open_as_void)
79+
bool open_as_void, bool original_is_structured, DataType original_dtype)
8080
: Base(std::move(store)),
8181
codec_state_(std::move(codec_state)),
8282
dtype_(std::move(dtype)),
83-
open_as_void_(open_as_void) {}
83+
open_as_void_(open_as_void),
84+
original_is_structured_(original_is_structured),
85+
original_dtype_(original_dtype) {}
8486

8587
void ZarrLeafChunkCache::Read(ZarrChunkCache::ReadRequest request,
8688
AnyFlowReceiver<absl::Status, internal::ReadChunk,
@@ -158,11 +160,49 @@ ZarrLeafChunkCache::DecodeChunk(span<const Index> chunk_indices,
158160
const size_t num_fields = dtype_.fields.size();
159161
absl::InlinedVector<SharedArray<const void>, 1> field_arrays(num_fields);
160162

161-
// Special case: void access - return raw bytes directly
163+
// Special case: void access - decode and return as bytes.
164+
//
165+
// For non-structured types: codec was prepared for [chunk_shape] with
166+
// original dtype. We decode to that shape then reinterpret as bytes.
167+
//
168+
// For structured types: codec was already prepared for
169+
// [chunk_shape, bytes_per_elem] with byte dtype. Just decode directly.
162170
if (open_as_void_) {
171+
assert(num_fields == 1); // Void access uses a single synthesized field
172+
const auto& void_component_shape = grid().components[0].shape();
173+
174+
if (original_is_structured_) {
175+
// Structured types: codec already expects bytes with extra dimension.
176+
// Just decode directly to the void component shape.
177+
TENSORSTORE_ASSIGN_OR_RETURN(
178+
field_arrays[0],
179+
codec_state_->DecodeArray(void_component_shape, std::move(data)));
180+
return field_arrays;
181+
}
182+
183+
// Non-structured types: codec expects original dtype without extra
184+
// dimension. Decode, then reinterpret as bytes.
185+
const auto& void_chunk_shape = grid().chunk_shape;
186+
std::vector<Index> original_chunk_shape(
187+
void_chunk_shape.begin(),
188+
void_chunk_shape.end() - 1); // Strip bytes dimension
189+
190+
// Decode using original codec shape
163191
TENSORSTORE_ASSIGN_OR_RETURN(
164-
field_arrays[0], codec_state_->DecodeArray(grid().components[0].shape(),
165-
std::move(data)));
192+
auto decoded_array,
193+
codec_state_->DecodeArray(original_chunk_shape, std::move(data)));
194+
195+
// Reinterpret the decoded array's bytes as [chunk_shape..., bytes_per_elem]
196+
auto byte_array = AllocateArray(
197+
void_component_shape, c_order, default_init,
198+
dtype_v<tensorstore::dtypes::byte_t>);
199+
200+
// Copy decoded data to byte array (handles potential layout differences)
201+
std::memcpy(byte_array.data(), decoded_array.data(),
202+
decoded_array.num_elements() *
203+
decoded_array.dtype().size());
204+
205+
field_arrays[0] = std::move(byte_array);
166206
return field_arrays;
167207
}
168208

@@ -213,8 +253,82 @@ ZarrLeafChunkCache::DecodeChunk(span<const Index> chunk_indices,
213253
Result<absl::Cord> ZarrLeafChunkCache::EncodeChunk(
214254
span<const Index> chunk_indices,
215255
span<const SharedArray<const void>> component_arrays) {
216-
assert(component_arrays.size() == 1);
217-
return codec_state_->EncodeArray(component_arrays[0]);
256+
const size_t num_fields = dtype_.fields.size();
257+
258+
// Special case: void access - encode bytes back to original format.
259+
if (open_as_void_) {
260+
assert(component_arrays.size() == 1);
261+
262+
if (original_is_structured_) {
263+
// Structured types: codec already expects bytes with extra dimension.
264+
return codec_state_->EncodeArray(component_arrays[0]);
265+
}
266+
267+
// Non-structured types: reinterpret bytes as original dtype/shape.
268+
const auto& byte_array = component_arrays[0];
269+
const Index bytes_per_element = dtype_.bytes_per_outer_element;
270+
271+
// Build original chunk shape by stripping the bytes dimension
272+
const auto& void_shape = byte_array.shape();
273+
std::vector<Index> original_shape(void_shape.begin(), void_shape.end() - 1);
274+
275+
// Use the original dtype (stored during cache creation) for encoding.
276+
// Create a view over the byte data with original dtype and layout.
277+
// Use the aliasing constructor to share ownership with byte_array but
278+
// interpret the data with the original dtype.
279+
SharedArray<const void> encoded_array;
280+
auto aliased_ptr = std::shared_ptr<const void>(
281+
byte_array.pointer(), // Share ownership with byte_array
282+
byte_array.data()); // But point to the raw data
283+
encoded_array.element_pointer() = SharedElementPointer<const void>(
284+
std::move(aliased_ptr), original_dtype_);
285+
encoded_array.layout() = StridedLayout<>(c_order, bytes_per_element,
286+
original_shape);
287+
288+
return codec_state_->EncodeArray(encoded_array);
289+
}
290+
291+
// For single non-structured field, encode directly
292+
if (num_fields == 1 && dtype_.fields[0].outer_shape.empty()) {
293+
assert(component_arrays.size() == 1);
294+
return codec_state_->EncodeArray(component_arrays[0]);
295+
}
296+
297+
// For structured types, combine multiple field arrays into a single byte array
298+
assert(component_arrays.size() == num_fields);
299+
300+
// Build encode shape: [chunk_dims..., bytes_per_outer_element]
301+
const auto& chunk_shape = grid().chunk_shape;
302+
std::vector<Index> encode_shape(chunk_shape.begin(), chunk_shape.end());
303+
encode_shape.push_back(dtype_.bytes_per_outer_element);
304+
305+
// Calculate number of outer elements
306+
Index num_elements = 1;
307+
for (size_t i = 0; i < chunk_shape.size(); ++i) {
308+
num_elements *= chunk_shape[i];
309+
}
310+
311+
// Allocate byte array for combined fields
312+
auto byte_array = AllocateArray<std::byte>(encode_shape, c_order, value_init);
313+
auto* dst_bytes = byte_array.data();
314+
315+
// Copy each field's data into the byte array at their respective offsets
316+
for (size_t field_i = 0; field_i < num_fields; ++field_i) {
317+
const auto& field = dtype_.fields[field_i];
318+
const auto& field_array = component_arrays[field_i];
319+
const auto* src = static_cast<const std::byte*>(field_array.data());
320+
const Index field_size = field.dtype->size;
321+
322+
// Copy field data to each struct element
323+
for (Index i = 0; i < num_elements; ++i) {
324+
std::memcpy(dst_bytes + i * dtype_.bytes_per_outer_element +
325+
field.byte_offset,
326+
src + i * field_size,
327+
field_size);
328+
}
329+
}
330+
331+
return codec_state_->EncodeArray(byte_array);
218332
}
219333

220334
kvstore::Driver* ZarrLeafChunkCache::GetKvStoreDriver() {
@@ -224,12 +338,14 @@ kvstore::Driver* ZarrLeafChunkCache::GetKvStoreDriver() {
224338
ZarrShardedChunkCache::ZarrShardedChunkCache(
225339
kvstore::DriverPtr store, ZarrCodecChain::PreparedState::Ptr codec_state,
226340
ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool,
227-
bool open_as_void)
341+
bool open_as_void, bool original_is_structured, DataType original_dtype)
228342
: base_kvstore_(std::move(store)),
229343
codec_state_(std::move(codec_state)),
230344
dtype_(std::move(dtype)),
231-
data_cache_pool_(std::move(data_cache_pool)),
232-
open_as_void_(open_as_void) {}
345+
open_as_void_(open_as_void),
346+
original_is_structured_(original_is_structured),
347+
original_dtype_(original_dtype),
348+
data_cache_pool_(std::move(data_cache_pool)) {}
233349

234350
Result<IndexTransform<>> TranslateCellToSourceTransformForShard(
235351
IndexTransform<> transform, span<const Index> grid_cell_indices,
@@ -538,7 +654,8 @@ void ZarrShardedChunkCache::Entry::DoInitialize() {
538654
*sharding_state.sub_chunk_codec_chain,
539655
std::move(sharding_kvstore), cache.executor(),
540656
ZarrShardingCodec::PreparedState::Ptr(&sharding_state),
541-
cache.dtype_, cache.data_cache_pool_, cache.open_as_void_);
657+
cache.dtype_, cache.data_cache_pool_, cache.open_as_void_,
658+
cache.original_is_structured_, cache.original_dtype_);
542659
zarr_chunk_cache = new_cache.release();
543660
return std::unique_ptr<internal::Cache>(&zarr_chunk_cache->cache());
544661
})

tensorstore/driver/zarr3/chunk_cache.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ class ZarrLeafChunkCache : public internal::KvsBackedChunkCache,
159159
ZarrCodecChain::PreparedState::Ptr codec_state,
160160
ZarrDType dtype,
161161
internal::CachePool::WeakPtr data_cache_pool,
162-
bool open_as_void);
162+
bool open_as_void,
163+
bool original_is_structured,
164+
DataType original_dtype);
163165

164166
void Read(ZarrChunkCache::ReadRequest request,
165167
AnyFlowReceiver<absl::Status, internal::ReadChunk,
@@ -188,6 +190,8 @@ class ZarrLeafChunkCache : public internal::KvsBackedChunkCache,
188190
ZarrCodecChain::PreparedState::Ptr codec_state_;
189191
ZarrDType dtype_;
190192
bool open_as_void_;
193+
bool original_is_structured_;
194+
DataType original_dtype_; // Original dtype for void access encoding
191195
};
192196

193197
/// Chunk cache for a Zarr array where each chunk is a shard.
@@ -199,7 +203,9 @@ class ZarrShardedChunkCache : public internal::Cache, public ZarrChunkCache {
199203
ZarrCodecChain::PreparedState::Ptr codec_state,
200204
ZarrDType dtype,
201205
internal::CachePool::WeakPtr data_cache_pool,
202-
bool open_as_void);
206+
bool open_as_void,
207+
bool original_is_structured,
208+
DataType original_dtype);
203209

204210
const ZarrShardingCodec::PreparedState& sharding_codec_state() const {
205211
return static_cast<const ZarrShardingCodec::PreparedState&>(
@@ -250,6 +256,8 @@ class ZarrShardedChunkCache : public internal::Cache, public ZarrChunkCache {
250256
ZarrCodecChain::PreparedState::Ptr codec_state_;
251257
ZarrDType dtype_;
252258
bool open_as_void_;
259+
bool original_is_structured_;
260+
DataType original_dtype_; // Original dtype for void access encoding
253261

254262
// Data cache pool, if it differs from `this->pool()` (which is equal to the
255263
// metadata cache pool).
@@ -265,12 +273,12 @@ class ZarrShardSubChunkCache : public ChunkCacheImpl {
265273
kvstore::DriverPtr store, Executor executor,
266274
ZarrShardingCodec::PreparedState::Ptr sharding_state,
267275
ZarrDType dtype, internal::CachePool::WeakPtr data_cache_pool,
268-
bool open_as_void)
276+
bool open_as_void, bool original_is_structured, DataType original_dtype)
269277
: ChunkCacheImpl(std::move(store),
270278
ZarrCodecChain::PreparedState::Ptr(
271279
sharding_state->sub_chunk_codec_state),
272280
std::move(dtype), std::move(data_cache_pool),
273-
open_as_void),
281+
open_as_void, original_is_structured, original_dtype),
274282
sharding_state_(std::move(sharding_state)),
275283
executor_(std::move(executor)) {}
276284

0 commit comments

Comments (0)