Skip to content

Commit 6bbb13b

Browse files
authored
Merge pull request #7712 from jenshnielsen/improve_speed
Remove graph related overhead from data storage
2 parents 26ad190 + 41f6b9d commit 6bbb13b

7 files changed

Lines changed: 283 additions & 14 deletions

File tree

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
The `InterDependencies_` class is now frozen during the performance of a measurement so it cannot be modified.
2+
This enables caching of attributes on the class significantly reducing the overhead of measurements.

src/qcodes/dataset/data_set.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,10 @@ def toggle_debug(self) -> None:
566566
self.conn = connect(path_to_db, self._debug)
567567

568568
def set_interdependencies(
569-
self, interdeps: InterDependencies_, shapes: Shapes | None = None
569+
self,
570+
interdeps: InterDependencies_,
571+
shapes: Shapes | None = None,
572+
override: bool = False,
570573
) -> None:
571574
"""
572575
Set the interdependencies object (which holds all added
@@ -579,7 +582,7 @@ def set_interdependencies(
579582
f"Wrong input type. Expected InterDepencies_, got {type(interdeps)}"
580583
)
581584

582-
if not self.pristine:
585+
if not self.pristine and not override:
583586
mssg = "Can not set interdependencies on a DataSet that has been started."
584587
raise RuntimeError(mssg)
585588
self._rundescriber = RunDescriber(interdeps, shapes=shapes)

src/qcodes/dataset/data_set_in_memory.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,10 @@ def _set_parent_dataset_links(self, links: list[Link]) -> None:
748748
self._parent_dataset_links = links
749749

750750
def _set_interdependencies(
751-
self, interdeps: InterDependencies_, shapes: Shapes | None = None
751+
self,
752+
interdeps: InterDependencies_,
753+
shapes: Shapes | None = None,
754+
override: bool = False,
752755
) -> None:
753756
"""
754757
Set the interdependencies object (which holds all added
@@ -761,7 +764,7 @@ def _set_interdependencies(
761764
f"Wrong input type. Expected InterDepencies_, got {type(interdeps)}"
762765
)
763766

764-
if not self.pristine:
767+
if not self.pristine and not override:
765768
mssg = "Can not set interdependencies on a DataSet that has been started."
766769
raise RuntimeError(mssg)
767770
self._rundescriber = RunDescriber(interdeps, shapes=shapes)

src/qcodes/dataset/descriptions/dependencies.py

Lines changed: 175 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,18 @@ def validate_paramspectree(
428428
else:
429429
raise ValueError(f"Invalid {interdep_type_internal}") from TypeError(cause)
430430

431+
def _invalid_subsets(
432+
self, paramspecs: Sequence[ParamSpecBase]
433+
) -> tuple[set[str], set[str]] | None:
434+
subset_nodes = {paramspec.name for paramspec in paramspecs}
435+
for subset_node in subset_nodes:
436+
descendant_nodes_per_subset_node = nx.descendants(self.graph, subset_node)
437+
if missing_nodes := descendant_nodes_per_subset_node.difference(
438+
subset_nodes
439+
):
440+
return (subset_nodes, missing_nodes)
441+
return None
442+
431443
def validate_subset(self, paramspecs: Sequence[ParamSpecBase]) -> None:
432444
"""
433445
Validate that the given parameters form a valid subset of the
@@ -442,15 +454,11 @@ def validate_subset(self, paramspecs: Sequence[ParamSpecBase]) -> None:
442454
InterdependencyError: If a dependency or inference is missing
443455
444456
"""
445-
subset_nodes = set([paramspec.name for paramspec in paramspecs])
446-
for subset_node in subset_nodes:
447-
descendant_nodes_per_subset_node = nx.descendants(self.graph, subset_node)
448-
if missing_nodes := descendant_nodes_per_subset_node.difference(
449-
subset_nodes
450-
):
451-
raise IncompleteSubsetError(
452-
subset_params=subset_nodes, missing_params=missing_nodes
453-
)
457+
invalid_subset = self._invalid_subsets(paramspecs)
458+
if invalid_subset is not None:
459+
raise IncompleteSubsetError(
460+
subset_params=invalid_subset[0], missing_params=invalid_subset[1]
461+
)
454462

455463
@classmethod
456464
def _from_graph(cls, graph: nx.DiGraph[str]) -> InterDependencies_:
@@ -624,3 +632,161 @@ def paramspec_tree_to_param_name_tree(
624632
return {
625633
key.name: [item.name for item in items] for key, items in paramspec_tree.items()
626634
}
635+
636+
637+
class FrozenInterDependencies_(InterDependencies_): # noqa: PLW1641
638+
# todo: not clear if this should implement __hash__.
639+
"""
640+
A frozen version of InterDependencies_ that is immutable and caches
641+
expensive lookups. This is used exclusively while running a measurement
642+
to minimize the overhead of dependency lookups for each data operation.
643+
644+
Args:
645+
interdeps: An InterDependencies_ instance to freeze
646+
647+
"""
648+
649+
def __init__(self, interdeps: InterDependencies_):
650+
self._graph = interdeps.graph.copy()
651+
nx.freeze(self._graph)
652+
self._top_level_parameters_cache: tuple[ParamSpecBase, ...] | None = None
653+
self._dependencies_cache: ParamSpecTree | None = None
654+
self._inferences_cache: ParamSpecTree | None = None
655+
self._standalones_cache: frozenset[ParamSpecBase] | None = None
656+
self._find_all_parameters_in_tree_cache: dict[
657+
ParamSpecBase, set[ParamSpecBase]
658+
] = {}
659+
self._invalid_subsets_cache: dict[
660+
tuple[ParamSpecBase, ...], tuple[set[str], set[str]] | None
661+
] = {}
662+
self._id_to_paramspec_cache: dict[str, ParamSpecBase] | None = None
663+
self._paramspec_to_id_cache: dict[ParamSpecBase, str] | None = None
664+
665+
def add_dependencies(self, dependencies: ParamSpecTree | None) -> None:
666+
raise TypeError("FrozenInterDependencies_ is immutable")
667+
668+
def add_inferences(self, inferences: ParamSpecTree | None) -> None:
669+
raise TypeError("FrozenInterDependencies_ is immutable")
670+
671+
def add_standalones(self, standalones: tuple[ParamSpecBase, ...]) -> None:
672+
raise TypeError("FrozenInterDependencies_ is immutable")
673+
674+
def add_paramspecs(self, paramspecs: Sequence[ParamSpecBase]) -> None:
675+
raise TypeError("FrozenInterDependencies_ is immutable")
676+
677+
def remove(self, paramspec: ParamSpecBase) -> InterDependencies_:
678+
raise TypeError("FrozenInterDependencies_ is immutable")
679+
680+
def extend(
681+
self,
682+
dependencies: ParamSpecTree | None = None,
683+
inferences: ParamSpecTree | None = None,
684+
standalones: tuple[ParamSpecBase, ...] = (),
685+
) -> InterDependencies_:
686+
"""
687+
Create a new :class:`InterDependencies_` object
688+
that is an extension of this instance with the provided input
689+
"""
690+
# We need to unfreeze the graph for the new instance
691+
new_graph = nx.DiGraph(self.graph)
692+
new_interdependencies = InterDependencies_._from_graph(new_graph)
693+
694+
new_interdependencies.add_dependencies(dependencies)
695+
new_interdependencies.add_inferences(inferences)
696+
new_interdependencies.add_standalones(standalones)
697+
return new_interdependencies
698+
699+
@property
700+
def top_level_parameters(self) -> tuple[ParamSpecBase, ...]:
701+
if self._top_level_parameters_cache is None:
702+
self._top_level_parameters_cache = super().top_level_parameters
703+
return self._top_level_parameters_cache
704+
705+
@property
706+
def dependencies(self) -> ParamSpecTree:
707+
if self._dependencies_cache is None:
708+
self._dependencies_cache = super().dependencies
709+
return self._dependencies_cache.copy()
710+
711+
@property
712+
def inferences(self) -> ParamSpecTree:
713+
if self._inferences_cache is None:
714+
self._inferences_cache = super().inferences
715+
return self._inferences_cache.copy()
716+
717+
@property
718+
def standalones(self) -> frozenset[ParamSpecBase]:
719+
if self._standalones_cache is None:
720+
self._standalones_cache = super().standalones
721+
return self._standalones_cache
722+
723+
def find_all_parameters_in_tree(
724+
self, initial_param: ParamSpecBase
725+
) -> set[ParamSpecBase]:
726+
if initial_param not in self._find_all_parameters_in_tree_cache:
727+
self._find_all_parameters_in_tree_cache[initial_param] = (
728+
super().find_all_parameters_in_tree(initial_param)
729+
)
730+
return self._find_all_parameters_in_tree_cache[initial_param].copy()
731+
732+
@classmethod
733+
def _from_dict(cls, ser: InterDependencies_Dict) -> FrozenInterDependencies_:
734+
interdeps = InterDependencies_._from_dict(ser)
735+
return cls(interdeps)
736+
737+
@classmethod
738+
def _from_graph(cls, graph: nx.DiGraph[str]) -> FrozenInterDependencies_:
739+
interdeps = InterDependencies_._from_graph(graph)
740+
return cls(interdeps)
741+
742+
def validate_subset(self, paramspecs: Sequence[ParamSpecBase]) -> None:
743+
paramspecs_tuple = tuple(paramspecs)
744+
if paramspecs_tuple not in self._invalid_subsets_cache:
745+
self._invalid_subsets_cache[paramspecs_tuple] = self._invalid_subsets(
746+
paramspecs_tuple
747+
)
748+
invalid_subset = self._invalid_subsets_cache[paramspecs_tuple]
749+
if invalid_subset is not None:
750+
raise IncompleteSubsetError(
751+
subset_params=invalid_subset[0], missing_params=invalid_subset[1]
752+
)
753+
754+
@property
755+
def _id_to_paramspec(self) -> dict[str, ParamSpecBase]:
756+
if self._id_to_paramspec_cache is None:
757+
self._id_to_paramspec_cache = {
758+
node_id: data["value"] for node_id, data in self.graph.nodes(data=True)
759+
}
760+
return self._id_to_paramspec_cache
761+
762+
@property
763+
def _paramspec_to_id(self) -> dict[ParamSpecBase, str]:
764+
if self._paramspec_to_id_cache is None:
765+
self._paramspec_to_id_cache = {
766+
data["value"]: node_id for node_id, data in self.graph.nodes(data=True)
767+
}
768+
return self._paramspec_to_id_cache
769+
770+
def __repr__(self) -> str:
771+
rep = (
772+
f"FrozenInterDependencies_(dependencies={self.dependencies}, "
773+
f"inferences={self.inferences}, "
774+
f"standalones={self.standalones})"
775+
)
776+
return rep
777+
778+
def __eq__(self, other: object) -> bool:
779+
if not isinstance(other, FrozenInterDependencies_):
780+
return False
781+
return nx.utils.graphs_equal(self.graph, other.graph)
782+
783+
def to_interdependencies(self) -> InterDependencies_:
784+
"""
785+
Convert this FrozenInterDependencies_ back to a mutable InterDependencies_ instance.
786+
787+
Returns:
788+
A new InterDependencies_ instance with the same data as this frozen instance.
789+
790+
"""
791+
new_graph = nx.DiGraph(self.graph)
792+
return InterDependencies_._from_graph(new_graph)

src/qcodes/dataset/measurements.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
ValuesType,
3737
)
3838
from qcodes.dataset.descriptions.dependencies import (
39+
FrozenInterDependencies_,
3940
IncompleteSubsetError,
4041
InterDependencies_,
4142
ParamSpecTree,
@@ -765,6 +766,28 @@ def __exit__(
765766
self._span.record_exception(exception_value)
766767
self.ds.add_metadata("measurement_exception", exception_string)
767768

769+
# for now we set the interdependencies back to the
770+
# not frozen state, so that further modifications are possible
771+
# this is not recommended but we want to minimize the changes for now
772+
773+
if isinstance(self.ds.description.interdeps, FrozenInterDependencies_):
774+
intedeps = self.ds.description.interdeps.to_interdependencies()
775+
else:
776+
intedeps = self.ds.description.interdeps
777+
778+
if isinstance(self.ds, DataSet):
779+
self.ds.set_interdependencies(
780+
shapes=self.ds.description.shapes,
781+
interdeps=intedeps,
782+
override=True,
783+
)
784+
elif isinstance(self.ds, DataSetInMem):
785+
self.ds._set_interdependencies(
786+
shapes=self.ds.description.shapes,
787+
interdeps=intedeps,
788+
override=True,
789+
)
790+
768791
# and finally mark the dataset as closed, thus
769792
# finishing the measurement
770793
# Note that the completion of a dataset entails waiting for the
@@ -1514,7 +1537,7 @@ def run(
15141537
self.experiment,
15151538
station=self.station,
15161539
write_period=self._write_period,
1517-
interdeps=self._interdeps,
1540+
interdeps=FrozenInterDependencies_(self._interdeps),
15181541
name=self.name,
15191542
subscribers=self.subscribers,
15201543
parent_datasets=self._parent_datasets,

tests/dataset/measurement/test_measurement_context_manager.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
import qcodes as qc
2222
import qcodes.validators as vals
2323
from qcodes.dataset.data_set import DataSet, load_by_id
24+
from qcodes.dataset.descriptions.dependencies import (
25+
FrozenInterDependencies_,
26+
InterDependencies_,
27+
)
2428
from qcodes.dataset.experiment_container import new_experiment
2529
from qcodes.dataset.export_config import DataExportType
2630
from qcodes.dataset.measurements import Measurement
@@ -730,6 +734,16 @@ def test_datasaver_scalars(
730734
with pytest.raises(ValueError):
731735
datasaver.add_result((DMM.v1, 0))
732736

737+
ds = datasaver.dataset
738+
assert isinstance(ds, DataSet)
739+
assert isinstance(ds.description.interdeps, InterDependencies_)
740+
assert not isinstance(ds.description.interdeps, FrozenInterDependencies_)
741+
742+
loaded_ds = load_by_id(ds.run_id)
743+
744+
assert isinstance(loaded_ds.description.interdeps, InterDependencies_)
745+
assert not isinstance(loaded_ds.description.interdeps, FrozenInterDependencies_)
746+
733747
# More assertions of setpoints, labels and units in the DB!
734748

735749

tests/dataset/test_dependencies.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from networkx import NetworkXError
77

88
from qcodes.dataset.descriptions.dependencies import (
9+
FrozenInterDependencies_,
910
IncompleteSubsetError,
1011
InterDependencies_,
1112
)
@@ -477,3 +478,60 @@ def test_dependency_on_middle_parameter(
477478
# in both directions, ps4 is actually a member of the tree for ps1
478479
assert idps.top_level_parameters == (ps1,)
479480
assert idps.find_all_parameters_in_tree(ps1) == {ps1, ps2, ps3, ps4}
481+
482+
483+
def test_frozen_interdependencies(some_paramspecbases) -> None:
484+
ps1, ps2, ps3, ps4 = some_paramspecbases
485+
idps = InterDependencies_(dependencies={ps1: (ps2, ps3)}, inferences={ps2: (ps4,)})
486+
487+
frozen = FrozenInterDependencies_(idps)
488+
489+
assert frozen.dependencies == idps.dependencies
490+
assert frozen.inferences == idps.inferences
491+
assert frozen.standalones == idps.standalones
492+
assert frozen.top_level_parameters == idps.top_level_parameters
493+
494+
# Test immutability
495+
with pytest.raises(TypeError, match="FrozenInterDependencies_ is immutable"):
496+
frozen.add_dependencies({ps4: (ps1,)})
497+
498+
with pytest.raises(TypeError, match="FrozenInterDependencies_ is immutable"):
499+
frozen.add_inferences({ps4: (ps1,)})
500+
501+
with pytest.raises(TypeError, match="FrozenInterDependencies_ is immutable"):
502+
frozen.add_standalones((ps4,))
503+
504+
with pytest.raises(TypeError, match="FrozenInterDependencies_ is immutable"):
505+
frozen.remove(ps1)
506+
507+
with pytest.raises(TypeError, match="FrozenInterDependencies_ is immutable"):
508+
frozen.add_paramspecs((ps1,))
509+
510+
# Test extend returns InterDependencies_ (mutable)
511+
ps5 = ParamSpecBase("psb5", "numeric", "number", "")
512+
extended = frozen.extend(standalones=(ps5,))
513+
assert isinstance(extended, InterDependencies_)
514+
assert not isinstance(extended, FrozenInterDependencies_)
515+
assert ps5 in extended.standalones
516+
517+
# Test caching of properties
518+
# Access properties to trigger caching
519+
_ = frozen.dependencies
520+
_ = frozen.inferences
521+
_ = frozen.standalones
522+
_ = frozen.top_level_parameters
523+
524+
assert frozen._dependencies_cache is not None
525+
assert frozen._inferences_cache is not None
526+
assert frozen._standalones_cache is not None
527+
assert frozen._top_level_parameters_cache is not None
528+
529+
530+
def test_frozen_from_dict(some_paramspecbases) -> None:
531+
ps1, ps2, ps3, _ = some_paramspecbases
532+
idps = InterDependencies_(dependencies={ps1: (ps2, ps3)})
533+
ser = idps._to_dict()
534+
535+
frozen = FrozenInterDependencies_._from_dict(ser)
536+
assert isinstance(frozen, FrozenInterDependencies_)
537+
assert frozen == FrozenInterDependencies_(idps)

0 commit comments

Comments
 (0)