diff --git a/tree/dataframe/inc/ROOT/RCsvDS.hxx b/tree/dataframe/inc/ROOT/RCsvDS.hxx index 89bf6188edacc..d29e7d44c345c 100644 --- a/tree/dataframe/inc/ROOT/RCsvDS.hxx +++ b/tree/dataframe/inc/ROOT/RCsvDS.hxx @@ -27,7 +27,8 @@ namespace ROOT::Internal::RDF { class R__CLING_PTRCHECK(off) RCsvDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase { void *fValuePtr; - void *GetImpl(Long64_t) final { return fValuePtr; } + void *GetImpl(std::size_t) final { return fValuePtr; } + void LoadImpl(Long64_t, bool) final {} public: RCsvDSColumnReader(void *valuePtr) : fValuePtr(valuePtr) {} diff --git a/tree/dataframe/inc/ROOT/RDF/RAction.hxx b/tree/dataframe/inc/ROOT/RDF/RAction.hxx index 2a9e512f65633..6fae6281fd8d0 100644 --- a/tree/dataframe/inc/ROOT/RDF/RAction.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RAction.hxx @@ -99,31 +99,33 @@ public: } template - auto GetValueChecked(unsigned int slot, std::size_t readerIdx, Long64_t entry) -> ColType & + auto GetValueChecked(unsigned int slot, std::size_t readerIdx, std::size_t idx) -> ColType & { - if (auto *val = fValues[slot][readerIdx]->template TryGet(entry)) + if (auto *val = fValues[slot][readerIdx]->template TryGet(idx)) return *val; throw std::out_of_range{"RDataFrame: Action (" + fHelper.GetActionName() + ") could not retrieve value for column '" + fColumnNames[readerIdx] + "' for entry " + - std::to_string(entry) + + std::to_string(idx) + ". You can use the DefaultValueFor operation to provide a default value, or " "FilterAvailable/FilterMissing to discard/keep entries with missing values instead."}; } template - void CallExec(unsigned int slot, Long64_t entry, TypeList, std::index_sequence) + void CallExec(unsigned int slot, std::size_t idx, TypeList, std::index_sequence) { ROOT::Internal::RDF::CallGuaranteedOrder{[&](auto &&...args) { return fHelper.Exec(slot, args...); }, - GetValueChecked(slot, S, entry)...}; - (void)entry; // avoid unused parameter warning (gcc 12.1) + GetValueChecked(slot, S, idx)...}; + (void)idx; // avoid unused parameter warning (gcc 12.1) } void Run(unsigned int slot, Long64_t entry) final { - // check if entry passes all filters - if (fPrevNode.CheckFilters(slot, entry)) - CallExec(slot, entry, ColumnTypes_t{}, TypeInd_t{}); + const auto mask = fPrevNode.CheckFilters(slot, entry); + std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry, mask](auto *v) { v->Load(entry, mask); }); + + if (mask) + CallExec(slot, /*idx=*/0u, ColumnTypes_t{}, TypeInd_t{}); } void TriggerChildrenCount() final { fPrevNode.IncrChildrenCount(); } diff --git a/tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx b/tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx index 510b8b5ec273b..3e30013461e7f 100644 --- a/tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx @@ -166,27 +166,27 @@ public: fHelper.InitTask(r, slot); } - void *GetValue(unsigned int slot, std::size_t readerIdx, Long64_t entry) + void *GetValue(unsigned int slot, std::size_t readerIdx, std::size_t idx) { assert(slot < fValues.size()); assert(readerIdx < fValues[slot].size()); - if (auto *val = fValues[slot][readerIdx]->template TryGet(entry)) + if (auto *val = fValues[slot][readerIdx]->template TryGet(idx)) return val; throw std::out_of_range{"RDataFrame: Action (" + fHelper.GetActionName() + ") could not retrieve value for column '" + fColumnNames[readerIdx] + "' for entry " + - std::to_string(entry) + + std::to_string(idx) + ". You can use the DefaultValueFor operation to provide a default value, or " "FilterAvailable/FilterMissing to discard/keep entries with missing values instead."}; } - void CallExec(unsigned int slot, Long64_t entry) + void CallExec(unsigned int slot, std::size_t idx) { std::vector untypedValues; auto nReaders = fValues[slot].size(); untypedValues.reserve(nReaders); for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++) - untypedValues.push_back(GetValue(slot, readerIdx, entry)); + untypedValues.push_back(GetValue(slot, readerIdx, idx)); fHelper.Exec(slot, untypedValues); } @@ -207,14 +207,17 @@ public: std::vector untypedValues; auto nReaders = fValues[slot].size(); untypedValues.reserve(nReaders); + std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry](auto *v) { v->Load(entry, true); }); for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++) - untypedValues.push_back(GetValue(slot, readerIdx, entry)); + untypedValues.push_back(GetValue(slot, readerIdx, /*idx=*/0u)); fHelper.Exec(slot, untypedValues, filterPassed); } } else { - if (fPrevNodes.front()->CheckFilters(slot, entry)) - CallExec(slot, entry); + const auto mask = fPrevNodes.front()->CheckFilters(slot, entry); + std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry, mask](auto *v) { v->Load(entry, mask); }); + if (mask) + CallExec(slot, /*idx=*/0u); } } diff --git a/tree/dataframe/inc/ROOT/RDF/RColumnReaderBase.hxx b/tree/dataframe/inc/ROOT/RDF/RColumnReaderBase.hxx index 29aeee4f306e7..a129a9f261e87 100644 --- a/tree/dataframe/inc/ROOT/RDF/RColumnReaderBase.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RColumnReaderBase.hxx @@ -26,9 +26,23 @@ This pure virtual class provides a common base class for the different column re RDSColumnReader. **/ class R__CLING_PTRCHECK(off) RColumnReaderBase { + Long64_t fLoadedEntry = -1; + public: virtual ~RColumnReaderBase() = default; + /// Load the column value for the given entry. + /// \param entry The entry number to load. + /// \param mask The entry mask. Values will be loaded only for entries for which the mask equals true. + void Load(Long64_t entry, bool mask) + { + // For now, as `mask` is just a single boolean, as an optimization we can return early here if `mask == false`. + if (mask) { + fLoadedEntry = entry; + this->LoadImpl(entry, mask); + } + } + /// Return the column value for the given entry. /// \tparam T The column type /// \param entry The entry number @@ -36,13 +50,14 @@ public: /// The caller is responsible for checking that the returned value actually /// exists. template - T *TryGet(Long64_t entry) + T *TryGet(std::size_t idx) { - return static_cast(GetImpl(entry)); + return static_cast(GetImpl(idx)); } private: - virtual void *GetImpl(Long64_t entry) = 0; + virtual void *GetImpl(std::size_t idx) = 0; + virtual void LoadImpl(Long64_t /*entry*/, bool /*mask*/) = 0; }; } // namespace RDF diff --git a/tree/dataframe/inc/ROOT/RDF/RDSColumnReader.hxx b/tree/dataframe/inc/ROOT/RDF/RDSColumnReader.hxx index 2e317573102b5..571154d80497b 100644 --- a/tree/dataframe/inc/ROOT/RDF/RDSColumnReader.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RDSColumnReader.hxx @@ -23,7 +23,8 @@ template class R__CLING_PTRCHECK(off) RDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase { T **fDSValuePtr = nullptr; - void *GetImpl(Long64_t) final { return *fDSValuePtr; } + void *GetImpl(std::size_t) final { return *fDSValuePtr; } + void LoadImpl(Long64_t, bool) final {} public: RDSColumnReader(void *DSValuePtr) : fDSValuePtr(static_cast(DSValuePtr)) {} diff --git a/tree/dataframe/inc/ROOT/RDF/RDefaultValueFor.hxx b/tree/dataframe/inc/ROOT/RDF/RDefaultValueFor.hxx index a03af81745a1b..864a38583ecce 100644 --- a/tree/dataframe/inc/ROOT/RDF/RDefaultValueFor.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RDefaultValueFor.hxx @@ -57,9 +57,9 @@ class R__CLING_PTRCHECK(off) RDefaultValueFor final : public RDefineBase { /// The map key is the full variation name, e.g. "pt:up". std::unordered_map> fVariedDefines; - T &GetValueOrDefault(unsigned int slot, Long64_t entry) + T &GetValueOrDefault(unsigned int slot, std::size_t idx) { - if (auto *value = fValues[slot]->template TryGet(entry)) + if (auto *value = fValues[slot]->template TryGet(idx)) return *value; else return fDefaultValue; @@ -104,12 +104,15 @@ public: } /// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry - void Update(unsigned int slot, Long64_t entry) final + void Update(unsigned int slot, Long64_t entry, bool mask) final { if (entry != fLastCheckedEntry[slot * RDFInternal::CacheLineStep()]) { // evaluate this define expression, cache the result - fLastResults[slot * RDFInternal::CacheLineStep()] = GetValueOrDefault(slot, entry); - fLastCheckedEntry[slot * RDFInternal::CacheLineStep()] = entry; + fValues[slot]->Load(entry, mask); + if (mask) { + fLastResults[slot * RDFInternal::CacheLineStep()] = GetValueOrDefault(slot, /*idx=*/0u); + fLastCheckedEntry[slot * RDFInternal::CacheLineStep()] = entry; + } } } diff --git a/tree/dataframe/inc/ROOT/RDF/RDefine.hxx b/tree/dataframe/inc/ROOT/RDF/RDefine.hxx index 8c9a57351079a..1350e3989505c 100644 --- a/tree/dataframe/inc/ROOT/RDF/RDefine.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RDefine.hxx @@ -71,39 +71,43 @@ class R__CLING_PTRCHECK(off) RDefine final : public RDefineBase { std::unordered_map> fVariedDefines; template - auto GetValueChecked(unsigned int slot, std::size_t readerIdx, Long64_t entry) -> ColType & + auto GetValueChecked(unsigned int slot, std::size_t readerIdx, std::size_t idx) -> ColType & { - if (auto *val = fValues[slot][readerIdx]->template TryGet(entry)) + if (auto *val = fValues[slot][readerIdx]->template TryGet(idx)) return *val; throw std::out_of_range{"RDataFrame: Define could not retrieve value for column '" + fColumnNames[readerIdx] + - "' for entry " + std::to_string(entry) + + "' for entry " + std::to_string(idx) + ". You can use the DefaultValueFor operation to provide a default value, or " "FilterAvailable/FilterMissing to discard/keep entries with missing values instead."}; } template - void UpdateHelper(unsigned int slot, Long64_t entry, TypeList, std::index_sequence, NoneTag) + void UpdateHelper(unsigned int slot, std::size_t idx, Long64_t /*entry*/, TypeList, + std::index_sequence, NoneTag) { fLastResults[slot * RDFInternal::CacheLineStep()] = - fExpression(GetValueChecked(slot, S, entry)...); - (void)entry; // avoid unused parameter warning (gcc 12.1) + fExpression(GetValueChecked(slot, S, idx)...); + (void)idx; // avoid unused parameter warning (gcc 12.1) } template - void UpdateHelper(unsigned int slot, Long64_t entry, TypeList, std::index_sequence, SlotTag) + void UpdateHelper(unsigned int slot, std::size_t idx, Long64_t /*entry*/, TypeList, + std::index_sequence, SlotTag) { fLastResults[slot * RDFInternal::CacheLineStep()] = - fExpression(slot, GetValueChecked(slot, S, entry)...); - (void)entry; // avoid unused parameter warning (gcc 12.1) + fExpression(slot, GetValueChecked(slot, S, idx)...); + (void)idx; // avoid unused parameter warning (gcc 12.1) } template - void - UpdateHelper(unsigned int slot, Long64_t entry, TypeList, std::index_sequence, SlotAndEntryTag) + void UpdateHelper(unsigned int slot, std::size_t idx, Long64_t batchFirstEntry, TypeList, + std::index_sequence, SlotAndEntryTag) { fLastResults[slot * RDFInternal::CacheLineStep()] = - fExpression(slot, entry, GetValueChecked(slot, S, entry)...); + fExpression(slot, batchFirstEntry + idx, GetValueChecked(slot, S, idx)...); + (void)idx; // avoid unused parameter warning (gcc 12.1) + (void)batchFirstEntry; // avoid unused parameter warning (gcc 12.1) } public: @@ -134,12 +138,14 @@ public: } /// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry - void Update(unsigned int slot, Long64_t entry) final + void Update(unsigned int slot, Long64_t entry, bool mask) final { if (entry != fLastCheckedEntry[slot * RDFInternal::CacheLineStep()]) { - // evaluate this define expression, cache the result - UpdateHelper(slot, entry, ColumnTypes_t{}, TypeInd_t{}, ExtraArgsTag{}); - fLastCheckedEntry[slot * RDFInternal::CacheLineStep()] = entry; + std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry, mask](auto *v) { v->Load(entry, mask); }); + if (mask) { + UpdateHelper(slot, /*idx=*/0u, entry, ColumnTypes_t{}, TypeInd_t{}, ExtraArgsTag{}); + fLastCheckedEntry[slot * RDFInternal::CacheLineStep()] = entry; + } } } diff --git a/tree/dataframe/inc/ROOT/RDF/RDefineBase.hxx b/tree/dataframe/inc/ROOT/RDF/RDefineBase.hxx index f20b77e0b3286..058db69fc1d2a 100644 --- a/tree/dataframe/inc/ROOT/RDF/RDefineBase.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RDefineBase.hxx @@ -63,7 +63,7 @@ public: std::string GetName() const; std::string GetTypeName() const; /// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry - virtual void Update(unsigned int slot, Long64_t entry) = 0; + virtual void Update(unsigned int slot, Long64_t entry, bool mask) = 0; /// Update function to be called once per sample, used if the derived type is a RDefinePerSample virtual void Update(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo &/*id*/) {} /// Clean-up operations to be performed at the end of a task. diff --git a/tree/dataframe/inc/ROOT/RDF/RDefinePerSample.hxx b/tree/dataframe/inc/ROOT/RDF/RDefinePerSample.hxx index 0302113345774..6568f3b3d1ef2 100644 --- a/tree/dataframe/inc/ROOT/RDF/RDefinePerSample.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RDefinePerSample.hxx @@ -59,7 +59,7 @@ public: return static_cast(&fLastResults[slot * RDFInternal::CacheLineStep()]); } - void Update(unsigned int, Long64_t) final + void Update(unsigned int, Long64_t, bool) final { // no-op } diff --git a/tree/dataframe/inc/ROOT/RDF/RDefineReader.hxx b/tree/dataframe/inc/ROOT/RDF/RDefineReader.hxx index 754d4cdfcb5fe..6f708dabaabbc 100644 --- a/tree/dataframe/inc/ROOT/RDF/RDefineReader.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RDefineReader.hxx @@ -42,11 +42,9 @@ class R__CLING_PTRCHECK(off) RDefineReader final : public ROOT::Detail::RDF::RCo /// The slot this value belongs to. unsigned int fSlot = std::numeric_limits::max(); - void *GetImpl(Long64_t entry) final - { - fDefine.Update(fSlot, entry); - return fValuePtr; - } + void *GetImpl(std::size_t /*idx*/) final { return fValuePtr; } + + void LoadImpl(Long64_t entry, bool mask) final { fDefine.Update(fSlot, entry, mask); } public: RDefineReader(unsigned int slot, RDFDetail::RDefineBase &define) diff --git a/tree/dataframe/inc/ROOT/RDF/RFilter.hxx b/tree/dataframe/inc/ROOT/RDF/RFilter.hxx index a884fa6b631ba..c572406604300 100644 --- a/tree/dataframe/inc/ROOT/RDF/RFilter.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RFilter.hxx @@ -95,43 +95,47 @@ public: bool CheckFilters(unsigned int slot, Long64_t entry) final { - if (entry != fLastCheckedEntry[slot * RDFInternal::CacheLineStep()]) { - if (!fPrevNode.CheckFilters(slot, entry)) { - // a filter upstream returned false, cache the result - fLastResult[slot * RDFInternal::CacheLineStep()] = false; - } else { - // evaluate this filter, cache the result - auto passed = CheckFilterHelper(slot, entry, ColumnTypes_t{}, TypeInd_t{}); - passed ? ++fAccepted[slot * RDFInternal::CacheLineStep()] + auto &newMask = fLastResult[slot * RDFInternal::CacheLineStep()]; + auto &lastEntry = fLastCheckedEntry[slot * RDFInternal::CacheLineStep()]; + + if (entry != lastEntry) { + newMask = fPrevNode.CheckFilters(slot, entry); + + // evaluate this filter, cache the result + std::for_each(fValues[slot].begin(), fValues[slot].end(), + [entry, newMask](auto *v) { v->Load(entry, newMask); }); + CheckFilterHelper(slot, /*idx=*/0u, newMask, ColumnTypes_t{}, TypeInd_t{}); + + lastEntry = entry; + } + + return newMask; + } + + template + void CheckFilterHelper(unsigned int slot, std::size_t idx, int &entryMask, TypeList, + std::index_sequence) + { + if (entryMask) { + entryMask = fFilter(GetValueChecked(slot, S, idx)...); + entryMask ? ++fAccepted[slot * RDFInternal::CacheLineStep()] : ++fRejected[slot * RDFInternal::CacheLineStep()]; - fLastResult[slot * RDFInternal::CacheLineStep()] = passed; - } - fLastCheckedEntry[slot * RDFInternal::CacheLineStep()] = entry; } - return fLastResult[slot * RDFInternal::CacheLineStep()]; + (void)idx; // avoid unused parameter warning (gcc 12.1) } template - auto GetValueChecked(unsigned int slot, std::size_t readerIdx, Long64_t entry) -> ColType & + auto GetValueChecked(unsigned int slot, std::size_t readerIdx, std::size_t idx) -> ColType & { - if (auto *val = fValues[slot][readerIdx]->template TryGet(entry)) + if (auto *val = fValues[slot][readerIdx]->template TryGet(idx)) return *val; throw std::out_of_range{"RDataFrame: Filter could not retrieve value for column '" + fColumnNames[readerIdx] + - "' for entry " + std::to_string(entry) + + "' for entry " + std::to_string(idx) + ". You can use the DefaultValueFor operation to provide a default value, or " "FilterAvailable/FilterMissing to discard/keep entries with missing values instead."}; } - template - bool CheckFilterHelper(unsigned int slot, Long64_t entry, TypeList, std::index_sequence) - { - return fFilter(GetValueChecked(slot, S, entry)...); - // avoid unused parameter warnings (gcc 12.1) - (void)slot; - (void)entry; - } - void InitSlot(TTreeReader *r, unsigned int slot) final { RDFInternal::RColumnReadersInfo info{fColumnNames, fColRegister, fIsDefine.data(), *fLoopManager}; diff --git a/tree/dataframe/inc/ROOT/RDF/RFilterWithMissingValues.hxx b/tree/dataframe/inc/ROOT/RDF/RFilterWithMissingValues.hxx index 614513a5add6c..d265439206864 100644 --- a/tree/dataframe/inc/ROOT/RDF/RFilterWithMissingValues.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RFilterWithMissingValues.hxx @@ -102,24 +102,30 @@ public: constexpr static auto cacheLineStepint = RDFInternal::CacheLineStep(); constexpr static auto cacheLineStepULong64_t = RDFInternal::CacheLineStep(); - if (entry != fLastCheckedEntry[slot * cacheLineStepLong64_t]) { - if (!fPrevNodePtr->CheckFilters(slot, entry)) { - // a filter upstream returned false, cache the result - fLastResult[slot * cacheLineStepint] = false; - } else { - // evaluate this filter, cache the result - const bool valueIsMissing = fValues[slot]->template TryGet(entry) == nullptr; - if (fDiscardEntryWithMissingValue) { - valueIsMissing ? ++fRejected[slot * cacheLineStepULong64_t] : ++fAccepted[slot * cacheLineStepULong64_t]; - fLastResult[slot * cacheLineStepint] = !valueIsMissing; - } else { - valueIsMissing ? ++fAccepted[slot * cacheLineStepULong64_t] : ++fRejected[slot * cacheLineStepULong64_t]; - fLastResult[slot * cacheLineStepint] = valueIsMissing; - } - } - fLastCheckedEntry[slot * cacheLineStepLong64_t] = entry; + auto &newMask = fLastResult[slot * cacheLineStepint]; + auto &lastEntry = fLastCheckedEntry[slot * cacheLineStepLong64_t]; + + if (entry == lastEntry) + return newMask; + + newMask = fPrevNodePtr->CheckFilters(slot, entry); + if (!newMask) + return false; + + fValues[slot]->Load(entry, newMask); + + const bool valueIsMissing = fValues[slot]->template TryGet(/*idx=*/0u) == nullptr; + if (fDiscardEntryWithMissingValue) { + valueIsMissing ? ++fRejected[slot * cacheLineStepULong64_t] : ++fAccepted[slot * cacheLineStepULong64_t]; + newMask = !valueIsMissing; + } else { + valueIsMissing ? ++fAccepted[slot * cacheLineStepULong64_t] : ++fRejected[slot * cacheLineStepULong64_t]; + newMask = valueIsMissing; } - return fLastResult[slot * cacheLineStepint]; + + lastEntry = entry; + + return newMask; } void InitSlot(TTreeReader *r, unsigned int slot) final diff --git a/tree/dataframe/inc/ROOT/RDF/RJittedDefine.hxx b/tree/dataframe/inc/ROOT/RDF/RJittedDefine.hxx index 2bc79170de8f3..20592787fd2cc 100644 --- a/tree/dataframe/inc/ROOT/RDF/RJittedDefine.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RJittedDefine.hxx @@ -58,7 +58,7 @@ public: void InitSlot(TTreeReader *r, unsigned int slot) final; void *GetValuePtr(unsigned int slot) final; const std::type_info &GetTypeId() const final; - void Update(unsigned int slot, Long64_t entry) final; + void Update(unsigned int slot, Long64_t entry, bool mask) final; void Update(unsigned int slot, const ROOT::RDF::RSampleInfo &id) final; void FinalizeSlot(unsigned int slot) final; void MakeVariations(const std::vector &variations) final; diff --git a/tree/dataframe/inc/ROOT/RDF/RJittedVariation.hxx b/tree/dataframe/inc/ROOT/RDF/RJittedVariation.hxx index 0abd61df5109d..000aab52eb1a5 100644 --- a/tree/dataframe/inc/ROOT/RDF/RJittedVariation.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RJittedVariation.hxx @@ -43,7 +43,7 @@ public: void InitSlot(TTreeReader *r, unsigned int slot) final; void *GetValuePtr(unsigned int slot, const std::string &column, const std::string &variation) final; const std::type_info &GetTypeId() const final; - void Update(unsigned int slot, Long64_t entry) final; + void Update(unsigned int slot, Long64_t entry, bool mask) final; void FinalizeSlot(unsigned int slot) final; }; diff --git a/tree/dataframe/inc/ROOT/RDF/RLazyDSImpl.hxx b/tree/dataframe/inc/ROOT/RDF/RLazyDSImpl.hxx index 9f31d9d4a831e..3fe651520abd4 100644 --- a/tree/dataframe/inc/ROOT/RDF/RLazyDSImpl.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RLazyDSImpl.hxx @@ -27,7 +27,8 @@ namespace ROOT::Internal::RDF { class R__CLING_PTRCHECK(off) RLazyDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase { ROOT::Internal::RDF::TPointerHolder *fPtr; - void *GetImpl(Long64_t) final { return fPtr->GetPointer(); } + void *GetImpl(std::size_t) final { return fPtr->GetPointer(); } + void LoadImpl(Long64_t, bool) final {} public: RLazyDSColumnReader(ROOT::Internal::RDF::TPointerHolder *ptr) : fPtr(ptr) {} diff --git a/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx b/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx index b2dc37c8e85a1..57ee799b1c83b 100644 --- a/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RTreeColumnReader.hxx @@ -38,7 +38,8 @@ namespace RDF { class R__CLING_PTRCHECK(off) RTreeOpaqueColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase { std::unique_ptr fTreeValue; - void *GetImpl(Long64_t) override; + void *GetImpl(std::size_t) override; + void LoadImpl(Long64_t, bool) override; public: /// Construct the RTreeColumnReader. Actual initialization is performed lazily by the Init method. @@ -57,7 +58,8 @@ public: class R__CLING_PTRCHECK(off) RTreeUntypedValueColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase { std::unique_ptr fTreeValue; - void *GetImpl(Long64_t) override; + void *GetImpl(std::size_t) override; + void LoadImpl(Long64_t, bool) override; public: RTreeUntypedValueColumnReader(TTreeReader &r, std::string_view colName, std::string_view typeName); @@ -111,11 +113,14 @@ private: /// Whether we already printed a warning about performing a copy of the TTreeReaderArray contents bool fCopyWarningPrinted = false; - void *GetImpl(Long64_t entry) override; + void *fValuePtr{nullptr}; - void *ReadStdArray(Long64_t entry); - void *ReadStdVector(Long64_t entry); - void *ReadRVec(Long64_t entry); + void *GetImpl(std::size_t) override; + void LoadImpl(Long64_t, bool) override; + + void *LoadStdArray(Long64_t entry); + void *LoadStdVector(Long64_t entry); + void *LoadRVec(Long64_t entry); }; class R__CLING_PTRCHECK(off) RMaskedColumnReader : public ROOT::Detail::RDF::RColumnReaderBase { @@ -123,7 +128,9 @@ class R__CLING_PTRCHECK(off) RMaskedColumnReader : public ROOT::Detail::RDF::RCo std::unique_ptr> fTreeValueMask; unsigned int fMaskIndex = 0; - void *GetImpl(Long64_t) override; + void *fValuePtr{nullptr}; + void *GetImpl(std::size_t) override; + void LoadImpl(Long64_t, bool) override; public: RMaskedColumnReader(TTreeReader &r, std::unique_ptr valueReader, diff --git a/tree/dataframe/inc/ROOT/RDF/RVariation.hxx b/tree/dataframe/inc/ROOT/RDF/RVariation.hxx index 0a3bc702d7e9e..ac7091f12aa7c 100644 --- a/tree/dataframe/inc/ROOT/RDF/RVariation.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RVariation.hxx @@ -160,23 +160,23 @@ class R__CLING_PTRCHECK(off) RVariation final : public RVariationBase { std::vector> fValues; template - auto GetValueChecked(unsigned int slot, std::size_t readerIdx, Long64_t entry) -> ColType & + auto GetValueChecked(unsigned int slot, std::size_t readerIdx, std::size_t idx) -> ColType & { - if (auto *val = fValues[slot][readerIdx]->template TryGet(entry)) + if (auto *val = fValues[slot][readerIdx]->template TryGet(idx)) return *val; throw std::out_of_range{"RDataFrame: Could not retrieve value for variation '" + fColNames[readerIdx] + - "' for entry " + std::to_string(entry) + + "' for entry " + std::to_string(idx) + ". You can use the DefaultValueFor operation to provide a default value, or " "FilterAvailable/FilterMissing to discard/keep entries with missing values instead."}; } template - void UpdateHelper(unsigned int slot, Long64_t entry, TypeList, std::index_sequence) + void UpdateHelper(unsigned int slot, std::size_t idx, TypeList, std::index_sequence) { // fExpression must return an RVec - auto &&results = fExpression(GetValueChecked(slot, S, entry)...); - (void)entry; // avoid unused parameter warnings (gcc 12.1) + auto &&results = fExpression(GetValueChecked(slot, S, idx)...); + (void)idx; // avoid unused parameter warnings (gcc 12.1) if (!ResultsSizeEq(results, fVariationNames.size(), fColNames.size(), std::integral_constant{})) { @@ -230,12 +230,15 @@ public: } /// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry - void Update(unsigned int slot, Long64_t entry) final + void Update(unsigned int slot, Long64_t entry, bool mask) final { if (entry != fLastCheckedEntry[slot * CacheLineStep()]) { // evaluate this filter, cache the result - UpdateHelper(slot, entry, ColumnTypes_t{}, TypeInd_t{}); - fLastCheckedEntry[slot * CacheLineStep()] = entry; + std::for_each(fValues[slot].begin(), fValues[slot].end(), [entry, mask](auto *v) { v->Load(entry, mask); }); + if (mask) { + UpdateHelper(slot, /*idx=*/0u, ColumnTypes_t{}, TypeInd_t{}); + fLastCheckedEntry[slot * CacheLineStep()] = entry; + } } } diff --git a/tree/dataframe/inc/ROOT/RDF/RVariationBase.hxx b/tree/dataframe/inc/ROOT/RDF/RVariationBase.hxx index 34b54a9b0a245..46a9e4ce29e64 100644 --- a/tree/dataframe/inc/ROOT/RDF/RVariationBase.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RVariationBase.hxx @@ -70,7 +70,7 @@ public: const std::string &GetVariationName() const; std::string GetTypeName() const; /// Update the value at the address returned by GetValuePtr with the content corresponding to the given entry - virtual void Update(unsigned int slot, Long64_t entry) = 0; + virtual void Update(unsigned int slot, Long64_t entry, bool mask) = 0; /// Clean-up operations to be performed at the end of a task. virtual void FinalizeSlot(unsigned int slot) = 0; }; diff --git a/tree/dataframe/inc/ROOT/RDF/RVariationReader.hxx b/tree/dataframe/inc/ROOT/RDF/RVariationReader.hxx index a9e6dfacaafd6..a9ca2b9fcfdaf 100644 --- a/tree/dataframe/inc/ROOT/RDF/RVariationReader.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RVariationReader.hxx @@ -32,11 +32,9 @@ class R__CLING_PTRCHECK(off) RVariationReader final : public ROOT::Detail::RDF:: /// The slot this value belongs to. unsigned int fSlot = std::numeric_limits::max(); - void *GetImpl(Long64_t entry) final - { - fVariation->Update(fSlot, entry); - return fValuePtr; - } + void *GetImpl(std::size_t /*idx*/) final { return fValuePtr; } + + void LoadImpl(Long64_t entry, bool mask) final { fVariation->Update(fSlot, entry, mask); } public: RVariationReader(unsigned int slot, const std::string &colName, const std::string &variationName, diff --git a/tree/dataframe/inc/ROOT/RDF/RVariedAction.hxx b/tree/dataframe/inc/ROOT/RDF/RVariedAction.hxx index a14ec6df9ac3b..434f2558c78d2 100644 --- a/tree/dataframe/inc/ROOT/RDF/RVariedAction.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RVariedAction.hxx @@ -141,31 +141,35 @@ public: } template - auto GetValueChecked(unsigned int slot, unsigned int varIdx, std::size_t readerIdx, Long64_t entry) -> ColType & + auto GetValueChecked(unsigned int slot, unsigned int varIdx, std::size_t readerIdx, std::size_t idx) -> ColType & { - if (auto *val = fInputValues[slot][varIdx][readerIdx]->template TryGet(entry)) + if (auto *val = fInputValues[slot][varIdx][readerIdx]->template TryGet(idx)) return *val; throw std::out_of_range{"RDataFrame: Varied action (" + fHelpers[0].GetActionName() + ") could not retrieve value for column '" + fColumnNames[readerIdx] + "' for entry " + - std::to_string(entry) + + std::to_string(idx) + ". You can use the DefaultValueFor operation to provide a default value, or " "FilterAvailable/FilterMissing to discard/keep entries with missing values instead."}; } template - void CallExec(unsigned int slot, unsigned int varIdx, Long64_t entry, TypeList, + void CallExec(unsigned int slot, unsigned int varIdx, std::size_t idx, TypeList, std::index_sequence) { - fHelpers[varIdx].Exec(slot, GetValueChecked(slot, varIdx, ReaderIdxs, entry)...); - (void)entry; + fHelpers[varIdx].Exec(slot, GetValueChecked(slot, varIdx, ReaderIdxs, idx)...); + (void)idx; } void Run(unsigned int slot, Long64_t entry) final { for (auto varIdx = 0u; varIdx < GetVariations().size(); ++varIdx) { - if (fPrevNodes[varIdx]->CheckFilters(slot, entry)) - CallExec(slot, varIdx, entry, ColumnTypes_t{}, TypeInd_t{}); + const auto mask = fPrevNodes[varIdx]->CheckFilters(slot, entry); + std::for_each(fInputValues[slot][varIdx].begin(), fInputValues[slot][varIdx].end(), + [entry, mask](auto *v) { v->Load(entry, mask); }); + + if (mask) + CallExec(slot, varIdx, /*idx=*/0u, ColumnTypes_t{}, TypeInd_t{}); } } diff --git a/tree/dataframe/inc/ROOT/RSqliteDS.hxx b/tree/dataframe/inc/ROOT/RSqliteDS.hxx index 5f46a5139dbda..a7a903eee15d0 100644 --- a/tree/dataframe/inc/ROOT/RSqliteDS.hxx +++ b/tree/dataframe/inc/ROOT/RSqliteDS.hxx @@ -22,7 +22,8 @@ namespace ROOT::Internal::RDF { class R__CLING_PTRCHECK(off) RSqliteDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase { void *fValuePtr; - void *GetImpl(Long64_t) final { return fValuePtr; } + void *GetImpl(std::size_t) final { return fValuePtr; } + void LoadImpl(Long64_t, bool) final {} public: RSqliteDSColumnReader(void *valuePtr) : fValuePtr(valuePtr) {} diff --git a/tree/dataframe/inc/ROOT/RTrivialDS.hxx b/tree/dataframe/inc/ROOT/RTrivialDS.hxx index d7b7de89da425..97f40281b0ead 100644 --- a/tree/dataframe/inc/ROOT/RTrivialDS.hxx +++ b/tree/dataframe/inc/ROOT/RTrivialDS.hxx @@ -17,7 +17,8 @@ namespace ROOT::Internal::RDF { class R__CLING_PTRCHECK(off) RTrivialDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase { ULong64_t *fValuePtr; - void *GetImpl(Long64_t) final { return fValuePtr; } + void *GetImpl(std::size_t) final { return fValuePtr; } + void LoadImpl(Long64_t, bool) final {} public: RTrivialDSColumnReader(ULong64_t *valuePtr) : fValuePtr(valuePtr) {} diff --git a/tree/dataframe/inc/ROOT/RVecDS.hxx b/tree/dataframe/inc/ROOT/RVecDS.hxx index 18f36f6942cd9..d3d27ba46a67c 100644 --- a/tree/dataframe/inc/ROOT/RVecDS.hxx +++ b/tree/dataframe/inc/ROOT/RVecDS.hxx @@ -30,7 +30,8 @@ namespace ROOT::Internal::RDF { class R__CLING_PTRCHECK(off) RVecDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase { TPointerHolder *fPtrHolder; - void *GetImpl(Long64_t) final { return fPtrHolder->GetPointer(); } + void *GetImpl(std::size_t) final { return fPtrHolder->GetPointer(); } + void LoadImpl(Long64_t, bool) final {} public: RVecDSColumnReader(TPointerHolder *ptrHolder) : fPtrHolder(ptrHolder) {} diff --git a/tree/dataframe/src/RJittedDefine.cxx b/tree/dataframe/src/RJittedDefine.cxx index 0ead1a3b0ce1d..b094c3080b30f 100644 --- a/tree/dataframe/src/RJittedDefine.cxx +++ b/tree/dataframe/src/RJittedDefine.cxx @@ -39,10 +39,10 @@ const std::type_info &RJittedDefine::GetTypeId() const "retrieved. This should never happen, please report this as a bug."); } -void RJittedDefine::Update(unsigned int slot, Long64_t entry) +void RJittedDefine::Update(unsigned int slot, Long64_t entry, bool mask) { assert(fConcreteDefine != nullptr); - fConcreteDefine->Update(slot, entry); + fConcreteDefine->Update(slot, entry, mask); } void RJittedDefine::Update(unsigned int slot, const ROOT::RDF::RSampleInfo &id) diff --git a/tree/dataframe/src/RJittedVariation.cxx b/tree/dataframe/src/RJittedVariation.cxx index 50f43c7db64ff..663a87fdfddaf 100644 --- a/tree/dataframe/src/RJittedVariation.cxx +++ b/tree/dataframe/src/RJittedVariation.cxx @@ -34,10 +34,10 @@ const std::type_info &RJittedVariation::GetTypeId() const return fConcreteVariation->GetTypeId(); } -void RJittedVariation::Update(unsigned int slot, Long64_t entry) +void RJittedVariation::Update(unsigned int slot, Long64_t entry, bool mask) { assert(fConcreteVariation != nullptr); - fConcreteVariation->Update(slot, entry); + fConcreteVariation->Update(slot, entry, mask); } void RJittedVariation::FinalizeSlot(unsigned int slot) diff --git a/tree/dataframe/src/RNTupleDS.cxx b/tree/dataframe/src/RNTupleDS.cxx index 82c72ee81102e..0e4ed6ef0e1df 100644 --- a/tree/dataframe/src/RNTupleDS.cxx +++ b/tree/dataframe/src/RNTupleDS.cxx @@ -270,13 +270,14 @@ class RNTupleColumnReader : public ROOT::Detail::RDF::RColumnReaderBase { fLastEntry = -1; } - void *GetImpl(Long64_t entry) final + void *GetImpl(std::size_t /*idx*/) final { return fValue->GetPtr().get(); } + + void LoadImpl(Long64_t entry, bool mask) final { - if (entry != fLastEntry) { + if (entry != fLastEntry && mask) { fValue->Read(entry - fEntryOffset); fLastEntry = entry; } - return fValue->GetPtr().get(); } }; } // namespace ROOT::Internal::RDF diff --git a/tree/dataframe/src/RTreeColumnReader.cxx b/tree/dataframe/src/RTreeColumnReader.cxx index 0786bb1a7c6af..37182a7c80cb9 100644 --- a/tree/dataframe/src/RTreeColumnReader.cxx +++ b/tree/dataframe/src/RTreeColumnReader.cxx @@ -5,11 +5,13 @@ #include -void *ROOT::Internal::RDF::RTreeOpaqueColumnReader::GetImpl(Long64_t) +void *ROOT::Internal::RDF::RTreeOpaqueColumnReader::GetImpl(std::size_t) { return fTreeValue->GetAddress(); } +void ROOT::Internal::RDF::RTreeOpaqueColumnReader::LoadImpl(Long64_t, bool) {} + ROOT::Internal::RDF::RTreeOpaqueColumnReader::RTreeOpaqueColumnReader(TTreeReader &r, std::string_view colName) : fTreeValue(std::make_unique(r, colName.data())) { @@ -17,11 +19,13 @@ ROOT::Internal::RDF::RTreeOpaqueColumnReader::RTreeOpaqueColumnReader(TTreeReade ROOT::Internal::RDF::RTreeOpaqueColumnReader::~RTreeOpaqueColumnReader() = default; -void *ROOT::Internal::RDF::RTreeUntypedValueColumnReader::GetImpl(Long64_t) +void *ROOT::Internal::RDF::RTreeUntypedValueColumnReader::GetImpl(std::size_t) { return fTreeValue->Get(); } +void ROOT::Internal::RDF::RTreeUntypedValueColumnReader::LoadImpl(Long64_t, bool) {} + ROOT::Internal::RDF::RTreeUntypedValueColumnReader::RTreeUntypedValueColumnReader(TTreeReader &r, std::string_view colName, std::string_view typeName) @@ -31,7 +35,7 @@ ROOT::Internal::RDF::RTreeUntypedValueColumnReader::RTreeUntypedValueColumnReade ROOT::Internal::RDF::RTreeUntypedValueColumnReader::~RTreeUntypedValueColumnReader() = default; -void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadStdArray(Long64_t entry) +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::LoadStdArray(Long64_t entry) { if (entry == fLastEntry) return fRVec.data(); // We return the RVec we already created @@ -61,7 +65,7 @@ void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadStdArray(Long64_t return fRVec.data(); } -void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadStdVector(Long64_t entry) +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::LoadStdVector(Long64_t entry) { if (entry == fLastEntry) return &fStdVector; // We return the std::vector we already created @@ -95,7 +99,7 @@ void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadStdVector(Long64_t return &fStdVector; } -void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadRVec(Long64_t entry) +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::LoadRVec(Long64_t entry) { if (entry == fLastEntry) return &fRVec; // We return the RVec we already created @@ -160,15 +164,21 @@ void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::ReadRVec(Long64_t entr return &fRVec; } -void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::GetImpl(Long64_t entry) +void ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::LoadImpl(Long64_t entry, bool mask) { - if (fCollectionType == ECollectionType::kStdArray) - return ReadStdArray(entry); - - if (fCollectionType == ECollectionType::kStdVector) - return ReadStdVector(entry); + if (entry != fLastEntry && mask) { + if (fCollectionType == ECollectionType::kStdArray) + fValuePtr = LoadStdArray(entry); + else if (fCollectionType == ECollectionType::kStdVector) + fValuePtr = LoadStdVector(entry); + else + fValuePtr = LoadRVec(entry); + } +} - return ReadRVec(entry); +void *ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::GetImpl(std::size_t /*idx*/) +{ + return fValuePtr; } ROOT::Internal::RDF::RTreeUntypedArrayColumnReader::RTreeUntypedArrayColumnReader(TTreeReader &r, @@ -193,11 +203,24 @@ ROOT::Internal::RDF::RMaskedColumnReader::RMaskedColumnReader( ROOT::Internal::RDF::RMaskedColumnReader::~RMaskedColumnReader() = default; -void *ROOT::Internal::RDF::RMaskedColumnReader::GetImpl(Long64_t event) +void *ROOT::Internal::RDF::RMaskedColumnReader::GetImpl(std::size_t) { - const std::bitset<64> mask{*fTreeValueMask->Get()}; - if (mask.test(fMaskIndex) == false) - return nullptr; + return fValuePtr; +} + +void ROOT::Internal::RDF::RMaskedColumnReader::LoadImpl(Long64_t entry, bool mask) +{ + if (!mask) { + fValuePtr = nullptr; + return; + } + + const std::bitset<64> treeMask{*fTreeValueMask->Get()}; + if (treeMask.test(fMaskIndex) == false) { + fValuePtr = nullptr; + return; + } - return fValueReader->TryGet(event); + fValueReader->Load(entry, mask); + fValuePtr = fValueReader->TryGet(/*idx=*/0u); } diff --git a/tree/dataframe/test/RArraysDS.hxx b/tree/dataframe/test/RArraysDS.hxx index 5e0e654c6f43d..c21c09a673509 100644 --- a/tree/dataframe/test/RArraysDS.hxx +++ b/tree/dataframe/test/RArraysDS.hxx @@ -12,7 +12,8 @@ class R__CLING_PTRCHECK(off) RArraysDSVarReader final : public ROOT::Detail::RDF::RColumnReaderBase { std::vector *fPtr = nullptr; - void *GetImpl(Long64_t) final { return fPtr; } + void *GetImpl(std::size_t) final { return fPtr; } + void LoadImpl(Long64_t, bool) final {} public: RArraysDSVarReader(std::vector &v) : fPtr(&v) {} @@ -21,11 +22,12 @@ public: class R__CLING_PTRCHECK(off) RArraysDSVarSizeReader final : public ROOT::Detail::RDF::RColumnReaderBase { std::vector *fPtr = nullptr; std::size_t fSize = 0; - void *GetImpl(Long64_t) final + void *GetImpl(std::size_t) final { fSize = fPtr->size(); return &fSize; } + void LoadImpl(Long64_t, bool) final {} public: RArraysDSVarSizeReader(std::vector &v) : fPtr(&v) {}