Skip to content

Commit b0443cd

Browse files
committed
CallbackRegistrar: add correlation hook with master record
Adds `register_correlation_hook_with_master_record` hook - identical to existing `register_correlation_hook`, but with an additional master record parameter sent to the callback. Some modules need access to the master record and since it's already loaded during correlation process (snapshot creation), providing it doesn't incur any additional overhead. One such example is the new TS last activity module in the NERD2 system. For backwards compatibility, the new hook type was created instead of modifying the `register_correlation_hook`. Internally, they both share the same implementation.
1 parent 8462863 commit b0443cd

4 files changed

Lines changed: 74 additions & 14 deletions

File tree

dp3/common/callback_registrar.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from dp3.common.state import SharedFlag
1414
from dp3.common.task import DataPointTask
1515
from dp3.common.types import ParsedTimedelta
16+
from dp3.common.utils import get_func_name
1617
from dp3.core.updater import Updater
1718
from dp3.snapshots.snapshooter import SnapShooter
1819
from dp3.task_processing.task_executor import TaskExecutor
@@ -365,6 +366,43 @@ def register_correlation_hook(
365366
may_change: each item should specify an attribute that `hook` may change.
366367
specification format is identical to `depends_on`.
367368
369+
Raises:
370+
ValueError: On failure of specification validation.
371+
"""
372+
# Ignore master record for this variant of the hook
373+
hook_name = get_func_name(hook)
374+
self._snap_shooter.register_correlation_hook(
375+
lambda e, s, _m: hook(e, s), entity_type, depends_on, may_change, hook_name=hook_name
376+
)
377+
378+
def register_correlation_hook_with_master_record(
379+
self,
380+
hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]],
381+
entity_type: str,
382+
depends_on: list[list[str]],
383+
may_change: list[list[str]],
384+
):
385+
"""
386+
Registers passed hook to be called during snapshot creation.
387+
388+
Identical to `register_correlation_hook`, but the hook also receives the master record.
389+
390+
Binds hook to specified entity_type (though same hook can be bound multiple times).
391+
392+
`entity_type` and attribute specifications are validated, `ValueError` is raised on failure.
393+
394+
Args:
395+
hook: `hook` callable should expect entity type as str;
396+
its current values, including linked entities, as dict;
397+
and its master record as dict.
398+
Can optionally return a list of DataPointTask objects to perform.
399+
entity_type: specifies entity type
400+
depends_on: each item should specify an attribute that is depended on
401+
in the form of a path from the specified entity_type to individual attributes
402+
(even on linked entities).
403+
may_change: each item should specify an attribute that `hook` may change.
404+
specification format is identical to `depends_on`.
405+
368406
Raises:
369407
ValueError: On failure of specification validation.
370408
"""

dp3/snapshots/snapshooter.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,33 +185,41 @@ def register_timeseries_hook(
185185

186186
def register_correlation_hook(
187187
self,
188-
hook: Callable[[str, dict], Union[None, list[DataPointTask]]],
188+
hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]],
189189
entity_type: str,
190190
depends_on: list[list[str]],
191191
may_change: list[list[str]],
192+
hook_name: Optional[str] = None,
192193
):
193194
"""
194195
Registers passed hook to be called during snapshot creation.
195196
197+
Common implementation for hooks with and without master record.
198+
196199
Binds hook to specified entity_type (though same hook can be bound multiple times).
197200
198201
`entity_type` and attribute specifications are validated, `ValueError` is raised on failure.
199202
200203
Args:
201204
hook: `hook` callable should expect entity type as str
202-
and its current values, including linked entities, as dict
205+
and its current values, including linked entities, as dict;
206+
and its master record as dict.
203207
Can optionally return a list of DataPointTask objects to perform.
204208
entity_type: specifies entity type
205209
depends_on: each item should specify an attribute that is depended on
206210
in the form of a path from the specified entity_type to individual attributes
207211
(even on linked entities).
208212
may_change: each item should specify an attribute that `hook` may change.
209213
specification format is identical to `depends_on`.
214+
hook_name: Optional custom name for the hook, used for logging purposes. If not
215+
provided, the function name of `hook` will be used.
210216
211217
Raises:
212218
ValueError: On failure of specification validation.
213219
"""
214-
self._correlation_hooks.register(hook, entity_type, depends_on, may_change)
220+
self._correlation_hooks.register(
221+
hook, entity_type, depends_on, may_change, hook_name=hook_name
222+
)
215223

216224
def register_run_init_hook(self, hook: Callable[[], list[DataPointTask]]):
217225
"""
@@ -457,9 +465,12 @@ def make_linkless_snapshot(self, entity_type: str, master_record: dict, time: da
457465
self.run_timeseries_processing(entity_type, master_record)
458466
values = self.get_values_at_time(entity_type, master_record, time)
459467
self.add_mirrored_links(entity_type, values)
460-
entity_values = {(entity_type, master_record["_id"]): values}
468+
entity_id = master_record["_id"]
469+
entity_values = {(entity_type, entity_id): values}
461470

462-
tasks = self._correlation_hooks.run(entity_values)
471+
tasks = self._correlation_hooks.run(
472+
entity_values, {(entity_type, entity_id): master_record}
473+
)
463474
for task in tasks:
464475
self.task_queue_writer.put_task(task)
465476

@@ -499,6 +510,7 @@ def make_snapshot(self, task: Snapshot):
499510
The resulting snapshots are saved into DB.
500511
"""
501512
entity_values = {}
513+
entity_master_records = {}
502514
for entity_type, entity_id in task.entities:
503515
record = self.db.get_master_record(entity_type, entity_id) or {"_id": entity_id}
504516
if not self.config.keep_empty and len(record) == 1:
@@ -508,9 +520,10 @@ def make_snapshot(self, task: Snapshot):
508520
values = self.get_values_at_time(entity_type, record, task.time)
509521
self.add_mirrored_links(entity_type, values)
510522
entity_values[entity_type, entity_id] = values
523+
entity_master_records[entity_type, entity_id] = record
511524

512525
self.link_loaded_entities(entity_values)
513-
created_tasks = self._correlation_hooks.run(entity_values)
526+
created_tasks = self._correlation_hooks.run(entity_values, entity_master_records)
514527
for created_task in created_tasks:
515528
self.task_queue_writer.put_task(created_task)
516529

dp3/snapshots/snapshot_hooks.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,11 @@ def __init__(self, log: logging.Logger, model_spec: ModelSpec, elog: EventGroupT
8484

8585
def register(
8686
self,
87-
hook: Callable[[str, dict], Union[None, list[DataPointTask]]],
87+
hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]],
8888
entity_type: str,
8989
depends_on: list[list[str]],
9090
may_change: list[list[str]],
91+
hook_name: Union[None, str] = None,
9192
) -> str:
9293
"""
9394
Registers passed hook to be called during snapshot creation.
@@ -97,15 +98,18 @@ def register(
9798
If entity_type and attribute specifications are validated
9899
and ValueError is raised on failure.
99100
Args:
100-
hook: `hook` callable should expect entity type as str
101-
and its current values, including linked entities, as dict.
101+
hook: `hook` callable should expect entity type as str;
102+
its current values, including linked entities, as dict;
103+
and its master record as dict.
102104
Can optionally return a list of DataPointTask objects to perform.
103105
entity_type: specifies entity type
104106
depends_on: each item should specify an attribute that is depended on
105107
in the form of a path from the specified entity_type to individual attributes
106108
(even on linked entities).
107109
may_change: each item should specify an attribute that `hook` may change.
108110
specification format is identical to `depends_on`.
111+
hook_name: Optional custom name for the hook, used for logging purposes. If not
112+
provided, the function name of `hook` will be used.
109113
Returns:
110114
Generated hook id.
111115
"""
@@ -120,7 +124,8 @@ def register(
120124
may_change = self._get_attr_path_destinations(entity_type, may_change)
121125

122126
hook_args = f"({entity_type}, [{','.join(depends_on)}], [{','.join(may_change)}])"
123-
hook_id = f"{get_func_name(hook)}{hook_args}"
127+
hook_name = hook_name if hook_name is not None else get_func_name(hook)
128+
hook_id = f"{hook_name}{hook_args}"
124129
self._short_hook_ids[hook_id] = hook_args
125130
self._dependency_graph.add_hook_dependency(hook_id, depends_on, may_change)
126131

@@ -191,7 +196,7 @@ def _resolve_entities_in_path(self, base_entity: str, path: list[str]) -> list[t
191196
position = entity_attributes[position.relation_to]
192197
return resolved_path
193198

194-
def run(self, entities: dict) -> list[DataPointTask]:
199+
def run(self, entities: dict, entity_master_records: dict) -> list[DataPointTask]:
195200
"""Runs registered hooks."""
196201
entity_types = {etype for etype, _ in entities}
197202
hook_subset = [
@@ -200,18 +205,22 @@ def run(self, entities: dict) -> list[DataPointTask]:
200205
topological_order = self._dependency_graph.topological_order
201206
hook_subset.sort(key=lambda x: topological_order.index(x[0]))
202207
entities_by_etype = defaultdict(dict)
208+
entity_master_records_by_etype = defaultdict(dict)
203209
for (etype, eid), values in entities.items():
204210
entities_by_etype[etype][eid] = values
211+
for (etype, eid), mr in entity_master_records.items():
212+
entity_master_records_by_etype[etype][eid] = mr
205213

206214
created_tasks = []
207215

208216
with task_context(self.model_spec):
209217
for hook_id, hook, etype in hook_subset:
210218
short_id = hook_id if len(hook_id) < 160 else self._short_hook_ids[hook_id]
211219
for eid, entity_values in entities_by_etype[etype].items():
220+
entity_master_record = entity_master_records_by_etype[etype].get(eid, {})
212221
self.log.debug("Running hook %s on entity %s", short_id, eid)
213222
try:
214-
tasks = hook(etype, entity_values)
223+
tasks = hook(etype, entity_values, entity_master_record)
215224
if tasks is not None and tasks:
216225
created_tasks.extend(tasks)
217226
except Exception as e:

tests/test_common/test_snapshots.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dp3.snapshots.snapshot_hooks import SnapshotCorrelationHookContainer
1515

1616

17-
def modify_value(_: str, record: dict, attr: str, value):
17+
def modify_value(_: str, record: dict, _master_record: dict, attr: str, value):
1818
record[attr] = value
1919

2020

@@ -37,7 +37,7 @@ def test_basic_function(self):
3737
hook=dummy_hook_abc, entity_type="A", depends_on=[["data1"]], may_change=[["data2"]]
3838
)
3939
values = {}
40-
self.container.run({("A", "a1"): values})
40+
self.container.run({("A", "a1"): values}, {})
4141
self.assertEqual(values["data2"], "abc")
4242

4343
def test_circular_dependency_error(self):

0 commit comments

Comments
 (0)