@@ -428,6 +428,18 @@ def validate_paramspectree(
428428 else :
429429 raise ValueError (f"Invalid { interdep_type_internal } " ) from TypeError (cause )
430430
431+ def _invalid_subsets (
432+ self , paramspecs : Sequence [ParamSpecBase ]
433+ ) -> tuple [set [str ], set [str ]] | None :
434+ subset_nodes = {paramspec .name for paramspec in paramspecs }
435+ for subset_node in subset_nodes :
436+ descendant_nodes_per_subset_node = nx .descendants (self .graph , subset_node )
437+ if missing_nodes := descendant_nodes_per_subset_node .difference (
438+ subset_nodes
439+ ):
440+ return (subset_nodes , missing_nodes )
441+ return None
442+
431443 def validate_subset (self , paramspecs : Sequence [ParamSpecBase ]) -> None :
432444 """
433445 Validate that the given parameters form a valid subset of the
@@ -442,15 +454,11 @@ def validate_subset(self, paramspecs: Sequence[ParamSpecBase]) -> None:
442454 InterdependencyError: If a dependency or inference is missing
443455
444456 """
445- subset_nodes = set ([paramspec .name for paramspec in paramspecs ])
446- for subset_node in subset_nodes :
447- descendant_nodes_per_subset_node = nx .descendants (self .graph , subset_node )
448- if missing_nodes := descendant_nodes_per_subset_node .difference (
449- subset_nodes
450- ):
451- raise IncompleteSubsetError (
452- subset_params = subset_nodes , missing_params = missing_nodes
453- )
457+ invalid_subset = self ._invalid_subsets (paramspecs )
458+ if invalid_subset is not None :
459+ raise IncompleteSubsetError (
460+ subset_params = invalid_subset [0 ], missing_params = invalid_subset [1 ]
461+ )
454462
455463 @classmethod
456464 def _from_graph (cls , graph : nx .DiGraph [str ]) -> InterDependencies_ :
@@ -624,3 +632,161 @@ def paramspec_tree_to_param_name_tree(
624632 return {
625633 key .name : [item .name for item in items ] for key , items in paramspec_tree .items ()
626634 }
635+
636+
637+ class FrozenInterDependencies_ (InterDependencies_ ): # noqa: PLW1641
638+ # todo: not clear if this should implement __hash__.
639+ """
640+ A frozen version of InterDependencies_ that is immutable and caches
641+ expensive lookups. This is used exclusively while running a measurement
642+ to minimize the overhead of dependency lookups for each data operation.
643+
644+ Args:
645+ interdeps: An InterDependencies_ instance to freeze
646+
647+ """
648+
649+ def __init__ (self , interdeps : InterDependencies_ ):
650+ self ._graph = interdeps .graph .copy ()
651+ nx .freeze (self ._graph )
652+ self ._top_level_parameters_cache : tuple [ParamSpecBase , ...] | None = None
653+ self ._dependencies_cache : ParamSpecTree | None = None
654+ self ._inferences_cache : ParamSpecTree | None = None
655+ self ._standalones_cache : frozenset [ParamSpecBase ] | None = None
656+ self ._find_all_parameters_in_tree_cache : dict [
657+ ParamSpecBase , set [ParamSpecBase ]
658+ ] = {}
659+ self ._invalid_subsets_cache : dict [
660+ tuple [ParamSpecBase , ...], tuple [set [str ], set [str ]] | None
661+ ] = {}
662+ self ._id_to_paramspec_cache : dict [str , ParamSpecBase ] | None = None
663+ self ._paramspec_to_id_cache : dict [ParamSpecBase , str ] | None = None
664+
665+ def add_dependencies (self , dependencies : ParamSpecTree | None ) -> None :
666+ raise TypeError ("FrozenInterDependencies_ is immutable" )
667+
668+ def add_inferences (self , inferences : ParamSpecTree | None ) -> None :
669+ raise TypeError ("FrozenInterDependencies_ is immutable" )
670+
671+ def add_standalones (self , standalones : tuple [ParamSpecBase , ...]) -> None :
672+ raise TypeError ("FrozenInterDependencies_ is immutable" )
673+
674+ def add_paramspecs (self , paramspecs : Sequence [ParamSpecBase ]) -> None :
675+ raise TypeError ("FrozenInterDependencies_ is immutable" )
676+
677+ def remove (self , paramspec : ParamSpecBase ) -> InterDependencies_ :
678+ raise TypeError ("FrozenInterDependencies_ is immutable" )
679+
680+ def extend (
681+ self ,
682+ dependencies : ParamSpecTree | None = None ,
683+ inferences : ParamSpecTree | None = None ,
684+ standalones : tuple [ParamSpecBase , ...] = (),
685+ ) -> InterDependencies_ :
686+ """
687+ Create a new :class:`InterDependencies_` object
688+ that is an extension of this instance with the provided input
689+ """
690+ # We need to unfreeze the graph for the new instance
691+ new_graph = nx .DiGraph (self .graph )
692+ new_interdependencies = InterDependencies_ ._from_graph (new_graph )
693+
694+ new_interdependencies .add_dependencies (dependencies )
695+ new_interdependencies .add_inferences (inferences )
696+ new_interdependencies .add_standalones (standalones )
697+ return new_interdependencies
698+
699+ @property
700+ def top_level_parameters (self ) -> tuple [ParamSpecBase , ...]:
701+ if self ._top_level_parameters_cache is None :
702+ self ._top_level_parameters_cache = super ().top_level_parameters
703+ return self ._top_level_parameters_cache
704+
705+ @property
706+ def dependencies (self ) -> ParamSpecTree :
707+ if self ._dependencies_cache is None :
708+ self ._dependencies_cache = super ().dependencies
709+ return self ._dependencies_cache .copy ()
710+
711+ @property
712+ def inferences (self ) -> ParamSpecTree :
713+ if self ._inferences_cache is None :
714+ self ._inferences_cache = super ().inferences
715+ return self ._inferences_cache .copy ()
716+
717+ @property
718+ def standalones (self ) -> frozenset [ParamSpecBase ]:
719+ if self ._standalones_cache is None :
720+ self ._standalones_cache = super ().standalones
721+ return self ._standalones_cache
722+
723+ def find_all_parameters_in_tree (
724+ self , initial_param : ParamSpecBase
725+ ) -> set [ParamSpecBase ]:
726+ if initial_param not in self ._find_all_parameters_in_tree_cache :
727+ self ._find_all_parameters_in_tree_cache [initial_param ] = (
728+ super ().find_all_parameters_in_tree (initial_param )
729+ )
730+ return self ._find_all_parameters_in_tree_cache [initial_param ].copy ()
731+
732+ @classmethod
733+ def _from_dict (cls , ser : InterDependencies_Dict ) -> FrozenInterDependencies_ :
734+ interdeps = InterDependencies_ ._from_dict (ser )
735+ return cls (interdeps )
736+
737+ @classmethod
738+ def _from_graph (cls , graph : nx .DiGraph [str ]) -> FrozenInterDependencies_ :
739+ interdeps = InterDependencies_ ._from_graph (graph )
740+ return cls (interdeps )
741+
742+ def validate_subset (self , paramspecs : Sequence [ParamSpecBase ]) -> None :
743+ paramspecs_tuple = tuple (paramspecs )
744+ if paramspecs_tuple not in self ._invalid_subsets_cache :
745+ self ._invalid_subsets_cache [paramspecs_tuple ] = self ._invalid_subsets (
746+ paramspecs_tuple
747+ )
748+ invalid_subset = self ._invalid_subsets_cache [paramspecs_tuple ]
749+ if invalid_subset is not None :
750+ raise IncompleteSubsetError (
751+ subset_params = invalid_subset [0 ], missing_params = invalid_subset [1 ]
752+ )
753+
754+ @property
755+ def _id_to_paramspec (self ) -> dict [str , ParamSpecBase ]:
756+ if self ._id_to_paramspec_cache is None :
757+ self ._id_to_paramspec_cache = {
758+ node_id : data ["value" ] for node_id , data in self .graph .nodes (data = True )
759+ }
760+ return self ._id_to_paramspec_cache
761+
762+ @property
763+ def _paramspec_to_id (self ) -> dict [ParamSpecBase , str ]:
764+ if self ._paramspec_to_id_cache is None :
765+ self ._paramspec_to_id_cache = {
766+ data ["value" ]: node_id for node_id , data in self .graph .nodes (data = True )
767+ }
768+ return self ._paramspec_to_id_cache
769+
770+ def __repr__ (self ) -> str :
771+ rep = (
772+ f"FrozenInterDependencies_(dependencies={ self .dependencies } , "
773+ f"inferences={ self .inferences } , "
774+ f"standalones={ self .standalones } )"
775+ )
776+ return rep
777+
778+ def __eq__ (self , other : object ) -> bool :
779+ if not isinstance (other , FrozenInterDependencies_ ):
780+ return False
781+ return nx .utils .graphs_equal (self .graph , other .graph )
782+
783+ def to_interdependencies (self ) -> InterDependencies_ :
784+ """
785+ Convert this FrozenInterDependencies_ back to a mutable InterDependencies_ instance.
786+
787+ Returns:
788+ A new InterDependencies_ instance with the same data as this frozen instance.
789+
790+ """
791+ new_graph = nx .DiGraph (self .graph )
792+ return InterDependencies_ ._from_graph (new_graph )
0 commit comments