Skip to content

Commit 901d328

Browse files
committed
Begin type erasure and local caching optimizations
1 parent 2f2eb4b commit 901d328

2 files changed

Lines changed: 41 additions & 21 deletions

File tree

mdio/coordinate_selector.h

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,11 @@ class CoordinateSelector {
130130
const size_t n = kept_runs_.size();
131131

132132
// 1) Fire off all reads in parallel and gather the key values
133-
std::vector<Future<VariableData<T>>> reads;
133+
std::vector<Future<VariableData<void>>> reads;
134134
reads.reserve(n);
135135
for (auto const& desc : kept_runs_) {
136136
MDIO_ASSIGN_OR_RETURN(auto ds, dataset_.isel(desc));
137-
MDIO_ASSIGN_OR_RETURN(auto var, ds.variables.get<T>(sort_key));
137+
MDIO_ASSIGN_OR_RETURN(auto var, ds.variables.at(sort_key));
138138
reads.push_back(var.Read());
139139
}
140140

@@ -149,9 +149,9 @@ class CoordinateSelector {
149149
// if (!f.status().ok()) return f.status();
150150
// auto data = f.value();
151151
// keys.push_back(data.get_data_accessor().data()[data.get_flattened_offset()]);
152-
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<T>(f));
152+
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<void>(f));
153153
auto data = std::get<0>(resolution);
154-
auto data_ptr = std::get<1>(resolution);
154+
auto data_ptr = static_cast<const T*>(std::get<1>(resolution));
155155
auto offset = std::get<2>(resolution);
156156
// auto n = std::get<3>(resolution); // Not required
157157
keys.push_back(data_ptr[offset]);
@@ -195,13 +195,13 @@ class CoordinateSelector {
195195
#ifdef MDIO_INTERNAL_PROFILING
196196
auto start = std::chrono::high_resolution_clock::now();
197197
#endif
198-
std::vector<Future<VariableData<T>>> reads;
198+
std::vector<Future<VariableData<void>>> reads;
199199
reads.reserve(kept_runs_.size());
200200
std::vector<T> ret;
201201

202202
for (const auto& desc : kept_runs_) {
203203
MDIO_ASSIGN_OR_RETURN(auto ds, dataset_.isel(desc));
204-
MDIO_ASSIGN_OR_RETURN(auto var, ds.variables.get<T>(output_variable));
204+
MDIO_ASSIGN_OR_RETURN(auto var, ds.variables.at(output_variable));
205205
auto fut = var.Read();
206206
reads.push_back(fut);
207207
if (var.rank() == 1) {
@@ -216,9 +216,9 @@ class CoordinateSelector {
216216
#endif
217217

218218
for (auto& f : reads) {
219-
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<T>(f));
219+
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<void>(f));
220220
auto data = std::get<0>(resolution);
221-
auto data_ptr = std::get<1>(resolution);
221+
auto data_ptr = static_cast<const T*>(std::get<1>(resolution));
222222
auto offset = std::get<2>(resolution);
223223
auto n = std::get<3>(resolution);
224224
std::vector<T> buffer(n);
@@ -237,6 +237,7 @@ class CoordinateSelector {
237237
Dataset& dataset_;
238238
tensorstore::IndexDomain<> base_domain_;
239239
std::vector<std::vector<mdio::RangeDescriptor<mdio::Index>>> kept_runs_;
240+
std::map<std::string_view, VariableData<void>> cached_variables_;
240241

241242
template <typename D>
242243
Future<void> _applyOp(D const& op) {
@@ -310,19 +311,35 @@ class CoordinateSelector {
310311

311312
template <typename T>
312313
Future<void> _init_runs(const ValueDescriptor<T>& descriptor) {
313-
using Interval = typename Variable<T>::Interval;
314+
using Interval = typename Variable<void>::Interval;
314315
#ifdef MDIO_INTERNAL_PROFILING
315316
auto start = std::chrono::high_resolution_clock::now();
316317
#endif
317-
MDIO_ASSIGN_OR_RETURN(auto var, dataset_.variables.get<T>(
318+
MDIO_ASSIGN_OR_RETURN(auto var, dataset_.variables.at(
318319
std::string(descriptor.label.label())));
319-
auto fut = var.Read();
320+
321+
const T* data_ptr;
322+
Index offset;
323+
Index n_samples;
320324
MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals());
321-
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<T>(fut));
322-
auto data = std::get<0>(resolution);
323-
auto data_ptr = std::get<1>(resolution);
324-
auto offset = std::get<2>(resolution);
325-
auto n_samples = std::get<3>(resolution);
325+
if (cached_variables_.find(descriptor.label.label()) == cached_variables_.end()) {
326+
// TODO(BrianMichell): Ensure that the domain has not changed.
327+
std::cout << "Reading VariableData" << std::endl;
328+
auto fut = var.Read();
329+
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<void>(fut));
330+
auto dataToCache = std::get<0>(resolution);
331+
cached_variables_.insert_or_assign(descriptor.label.label(), std::move(dataToCache));
332+
}
333+
auto it = cached_variables_.find(descriptor.label.label());
334+
if (it == cached_variables_.end()) {
335+
std::stringstream ss;
336+
ss << "Cached variable not found for coordinate '" << descriptor.label.label() << "'";
337+
return absl::NotFoundError(ss.str());
338+
}
339+
auto& data = it->second;
340+
data_ptr = static_cast<const T*>(data.get_data_accessor().data());
341+
offset = data.get_flattened_offset();
342+
n_samples = data.num_samples();
326343

327344
auto current_pos = intervals;
328345
bool isInRun = false;
@@ -343,7 +360,8 @@ class CoordinateSelector {
343360
// The start of a new run
344361
isInRun = true;
345362
for (auto i = run_idx; i < idx; ++i) {
346-
_current_position_increment<T>(current_pos, intervals);
363+
// _current_position_increment<T>(current_pos, intervals);
364+
_current_position_increment<void>(current_pos, intervals);
347365
}
348366
// _current_position_stride<T>(current_pos, intervals, idx - run_idx);
349367
run_idx = idx;
@@ -358,16 +376,15 @@ class CoordinateSelector {
358376
// Use 1 less than the current index to ensure we get the correct end
359377
// location.
360378
for (auto i = run_idx; i < idx - 1; ++i) {
361-
_current_position_increment<T>(current_pos, intervals);
379+
_current_position_increment<void>(current_pos, intervals);
362380
}
363-
// _current_position_stride<T>(current_pos, intervals, idx - run_idx);
364381
run_idx = idx;
365382
auto& last_run = local_runs.back();
366383
for (auto i = 0; i < current_pos.size(); ++i) {
367384
last_run[i].exclusive_max = current_pos[i].inclusive_min + 1;
368385
}
369386
// We need to advance to the actual current position
370-
_current_position_increment<T>(current_pos, intervals);
387+
_current_position_increment<void>(current_pos, intervals);
371388
} else if (!is_match && !isInRun) {
372389
// No run at all
373390
// do nothing TODO: Remove me
@@ -388,7 +405,7 @@ class CoordinateSelector {
388405
return absl::NotFoundError(ss.str());
389406
}
390407

391-
kept_runs_ = _from_intervals<T>(local_runs);
408+
kept_runs_ = _from_intervals<void>(local_runs);
392409
#ifdef MDIO_INTERNAL_PROFILING
393410
std::cout << "Finalize time... ";
394411
timer(start);

mdio/variable.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,9 @@ Future<Variable<T, R, M>> OpenVariable(const nlohmann::json& json_store,
576576
}
577577
// the negative of this is valid for tensorstore ...
578578

579+
// TODO(BrianMichell): Look into making the recheck_cached_data an open option.
580+
// store_spec["recheck_cached_data"] = false; // This could become problematic if we are doing read/write operations.
581+
579582
auto spec = tensorstore::MakeReadyFuture<::nlohmann::json>(store_spec);
580583

581584
// open a store:

0 commit comments

Comments
 (0)