Skip to content

Commit 33bb0d3

Browse files
committed
CallbackRegistrar: add correlation hook with master record
Adds `register_correlation_hook_with_master_record` hook - identical to existing `register_correlation_hook`, but with an additional master record parameter sent to the callback. Some modules need access to the master record and since it's already loaded during correlation process (snapshot creation), providing it doesn't incur any additional overhead. One such example is the new TS last activity module in the NERD2 system. For backwards compatibility, the new hook type was created instead of modifying the `register_correlation_hook`. Internally, they both share the same implementation.
1 parent 8462863 commit 33bb0d3

4 files changed

Lines changed: 68 additions & 14 deletions

File tree

dp3/common/callback_registrar.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from dp3.common.state import SharedFlag
1414
from dp3.common.task import DataPointTask
1515
from dp3.common.types import ParsedTimedelta
16+
from dp3.common.utils import get_func_name
1617
from dp3.core.updater import Updater
1718
from dp3.snapshots.snapshooter import SnapShooter
1819
from dp3.task_processing.task_executor import TaskExecutor
@@ -365,6 +366,41 @@ def register_correlation_hook(
365366
may_change: each item should specify an attribute that `hook` may change.
366367
specification format is identical to `depends_on`.
367368
369+
Raises:
370+
ValueError: On failure of specification validation.
371+
"""
372+
# Ignore master record for this variant of the hook
373+
hook_name = get_func_name(hook)
374+
self._snap_shooter.register_correlation_hook(lambda e, s, _m: hook(e, s), entity_type, depends_on, may_change, hook_name=hook_name)
375+
376+
def register_correlation_hook_with_master_record(
377+
self,
378+
hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]],
379+
entity_type: str,
380+
depends_on: list[list[str]],
381+
may_change: list[list[str]],
382+
):
383+
"""
384+
Registers passed hook to be called during snapshot creation.
385+
386+
Identical to `register_correlation_hook`, but the hook also receives the master record.
387+
388+
Binds hook to specified entity_type (though same hook can be bound multiple times).
389+
390+
`entity_type` and attribute specifications are validated, `ValueError` is raised on failure.
391+
392+
Args:
393+
hook: `hook` callable should expect entity type as str;
394+
its current values, including linked entities, as dict;
395+
and its master record as dict.
396+
Can optionally return a list of DataPointTask objects to perform.
397+
entity_type: specifies entity type
398+
depends_on: each item should specify an attribute that is depended on
399+
in the form of a path from the specified entity_type to individual attributes
400+
(even on linked entities).
401+
may_change: each item should specify an attribute that `hook` may change.
402+
specification format is identical to `depends_on`.
403+
368404
Raises:
369405
ValueError: On failure of specification validation.
370406
"""

dp3/snapshots/snapshooter.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,33 +185,39 @@ def register_timeseries_hook(
185185

186186
def register_correlation_hook(
187187
self,
188-
hook: Callable[[str, dict], Union[None, list[DataPointTask]]],
188+
hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]],
189189
entity_type: str,
190190
depends_on: list[list[str]],
191191
may_change: list[list[str]],
192+
hook_name: Optional[str] = None,
192193
):
193194
"""
194195
Registers passed hook to be called during snapshot creation.
195196
197+
Common implementation for hooks with and without master record.
198+
196199
Binds hook to specified entity_type (though same hook can be bound multiple times).
197200
198201
`entity_type` and attribute specifications are validated, `ValueError` is raised on failure.
199202
200203
Args:
201204
hook: `hook` callable should expect entity type as str
202-
and its current values, including linked entities, as dict
205+
and its current values, including linked entities, as dict;
206+
and its master record as dict.
203207
Can optionally return a list of DataPointTask objects to perform.
204208
entity_type: specifies entity type
205209
depends_on: each item should specify an attribute that is depended on
206210
in the form of a path from the specified entity_type to individual attributes
207211
(even on linked entities).
208212
may_change: each item should specify an attribute that `hook` may change.
209213
specification format is identical to `depends_on`.
214+
hook_name: Optional custom name for the hook, used for logging purposes. If not
215+
provided, the function name of `hook` will be used.
210216
211217
Raises:
212218
ValueError: On failure of specification validation.
213219
"""
214-
self._correlation_hooks.register(hook, entity_type, depends_on, may_change)
220+
self._correlation_hooks.register(hook, entity_type, depends_on, may_change, hook_name=hook_name)
215221

216222
def register_run_init_hook(self, hook: Callable[[], list[DataPointTask]]):
217223
"""
@@ -457,9 +463,10 @@ def make_linkless_snapshot(self, entity_type: str, master_record: dict, time: da
457463
self.run_timeseries_processing(entity_type, master_record)
458464
values = self.get_values_at_time(entity_type, master_record, time)
459465
self.add_mirrored_links(entity_type, values)
460-
entity_values = {(entity_type, master_record["_id"]): values}
466+
entity_id = master_record["_id"]
467+
entity_values = {(entity_type, entity_id): values}
461468

462-
tasks = self._correlation_hooks.run(entity_values)
469+
tasks = self._correlation_hooks.run(entity_values, {(entity_type, entity_id): master_record})
463470
for task in tasks:
464471
self.task_queue_writer.put_task(task)
465472

@@ -499,6 +506,7 @@ def make_snapshot(self, task: Snapshot):
499506
The resulting snapshots are saved into DB.
500507
"""
501508
entity_values = {}
509+
entity_master_records = {}
502510
for entity_type, entity_id in task.entities:
503511
record = self.db.get_master_record(entity_type, entity_id) or {"_id": entity_id}
504512
if not self.config.keep_empty and len(record) == 1:
@@ -508,9 +516,10 @@ def make_snapshot(self, task: Snapshot):
508516
values = self.get_values_at_time(entity_type, record, task.time)
509517
self.add_mirrored_links(entity_type, values)
510518
entity_values[entity_type, entity_id] = values
519+
entity_master_records[entity_type, entity_id] = record
511520

512521
self.link_loaded_entities(entity_values)
513-
created_tasks = self._correlation_hooks.run(entity_values)
522+
created_tasks = self._correlation_hooks.run(entity_values, entity_master_records)
514523
for created_task in created_tasks:
515524
self.task_queue_writer.put_task(created_task)
516525

dp3/snapshots/snapshot_hooks.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,11 @@ def __init__(self, log: logging.Logger, model_spec: ModelSpec, elog: EventGroupT
8484

8585
def register(
8686
self,
87-
hook: Callable[[str, dict], Union[None, list[DataPointTask]]],
87+
hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]],
8888
entity_type: str,
8989
depends_on: list[list[str]],
9090
may_change: list[list[str]],
91+
hook_name: Union[None, str] = None,
9192
) -> str:
9293
"""
9394
Registers passed hook to be called during snapshot creation.
@@ -97,15 +98,18 @@ def register(
9798
If entity_type and attribute specifications are validated
9899
and ValueError is raised on failure.
99100
Args:
100-
hook: `hook` callable should expect entity type as str
101-
and its current values, including linked entities, as dict.
101+
hook: `hook` callable should expect entity type as str;
102+
its current values, including linked entities, as dict;
103+
and its master record as dict.
102104
Can optionally return a list of DataPointTask objects to perform.
103105
entity_type: specifies entity type
104106
depends_on: each item should specify an attribute that is depended on
105107
in the form of a path from the specified entity_type to individual attributes
106108
(even on linked entities).
107109
may_change: each item should specify an attribute that `hook` may change.
108110
specification format is identical to `depends_on`.
111+
hook_name: Optional custom name for the hook, used for logging purposes. If not
112+
provided, the function name of `hook` will be used.
109113
Returns:
110114
Generated hook id.
111115
"""
@@ -120,7 +124,8 @@ def register(
120124
may_change = self._get_attr_path_destinations(entity_type, may_change)
121125

122126
hook_args = f"({entity_type}, [{','.join(depends_on)}], [{','.join(may_change)}])"
123-
hook_id = f"{get_func_name(hook)}{hook_args}"
127+
hook_name = hook_name if hook_name is not None else get_func_name(hook)
128+
hook_id = f"{hook_name}{hook_args}"
124129
self._short_hook_ids[hook_id] = hook_args
125130
self._dependency_graph.add_hook_dependency(hook_id, depends_on, may_change)
126131

@@ -191,7 +196,7 @@ def _resolve_entities_in_path(self, base_entity: str, path: list[str]) -> list[t
191196
position = entity_attributes[position.relation_to]
192197
return resolved_path
193198

194-
def run(self, entities: dict) -> list[DataPointTask]:
199+
def run(self, entities: dict, entity_master_records: dict) -> list[DataPointTask]:
195200
"""Runs registered hooks."""
196201
entity_types = {etype for etype, _ in entities}
197202
hook_subset = [
@@ -200,18 +205,22 @@ def run(self, entities: dict) -> list[DataPointTask]:
200205
topological_order = self._dependency_graph.topological_order
201206
hook_subset.sort(key=lambda x: topological_order.index(x[0]))
202207
entities_by_etype = defaultdict(dict)
208+
entity_master_records_by_etype = defaultdict(dict)
203209
for (etype, eid), values in entities.items():
204210
entities_by_etype[etype][eid] = values
211+
for (etype, eid), mr in entity_master_records.items():
212+
entity_master_records_by_etype[etype][eid] = mr
205213

206214
created_tasks = []
207215

208216
with task_context(self.model_spec):
209217
for hook_id, hook, etype in hook_subset:
210218
short_id = hook_id if len(hook_id) < 160 else self._short_hook_ids[hook_id]
211219
for eid, entity_values in entities_by_etype[etype].items():
220+
entity_master_record = entity_master_records_by_etype[etype].get(eid, {})
212221
self.log.debug("Running hook %s on entity %s", short_id, eid)
213222
try:
214-
tasks = hook(etype, entity_values)
223+
tasks = hook(etype, entity_values, entity_master_record)
215224
if tasks is not None and tasks:
216225
created_tasks.extend(tasks)
217226
except Exception as e:

tests/test_common/test_snapshots.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dp3.snapshots.snapshot_hooks import SnapshotCorrelationHookContainer
1515

1616

17-
def modify_value(_: str, record: dict, attr: str, value):
17+
def modify_value(_: str, record: dict, _master_record: dict, attr: str, value):
1818
record[attr] = value
1919

2020

@@ -37,7 +37,7 @@ def test_basic_function(self):
3737
hook=dummy_hook_abc, entity_type="A", depends_on=[["data1"]], may_change=[["data2"]]
3838
)
3939
values = {}
40-
self.container.run({("A", "a1"): values})
40+
self.container.run({("A", "a1"): values}, {})
4141
self.assertEqual(values["data2"], "abc")
4242

4343
def test_circular_dependency_error(self):

0 commit comments

Comments
 (0)