diff --git a/openhtf/core/phase_descriptor.py b/openhtf/core/phase_descriptor.py index 7327c1a9..391949ab 100644 --- a/openhtf/core/phase_descriptor.py +++ b/openhtf/core/phase_descriptor.py @@ -124,6 +124,8 @@ class PhaseOptions(object): timeout will still apply when under the debugger. phase_name_case: Case formatting options for phase name. stop_on_measurement_fail: Whether to stop the test if any measurements fail. + prerequisites: List of phases that must be completed before this phase can + be run. Example Usages: @PhaseOptions(timeout_s=1) def PhaseFunc(test): pass @PhaseOptions(name='Phase({port})') def PhaseFunc(test, port, other_info): pass @@ -140,6 +142,7 @@ def PhaseFunc(test, port, other_info): pass run_under_pdb = attr.ib(type=bool, default=False) phase_name_case = attr.ib(type=PhaseNameCase, default=PhaseNameCase.KEEP) stop_on_measurement_fail = attr.ib(type=bool, default=False) + prerequisites = attr.ib(type=Optional[List[Any]], default=None) def format_strings(self, **kwargs: Any) -> 'PhaseOptions': """String substitution of name.""" @@ -173,6 +176,8 @@ def __call__(self, phase_func: PhaseT) -> 'PhaseDescriptor': phase.options.stop_on_measurement_fail = self.stop_on_measurement_fail if self.phase_name_case: phase.options.phase_name_case = self.phase_name_case + if self.prerequisites is not None: + phase.options.prerequisites = self.prerequisites return phase diff --git a/openhtf/core/phase_executor.py b/openhtf/core/phase_executor.py index 637dab47..5f22303e 100644 --- a/openhtf/core/phase_executor.py +++ b/openhtf/core/phase_executor.py @@ -36,7 +36,7 @@ import time import traceback import types -from typing import Any, Dict, Optional, Text, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Text, Tuple, Type, Union import attr from openhtf import util @@ -170,10 +170,12 @@ def __init__(self, phase_desc: phase_descriptor.PhaseDescriptor, self._phase_desc = phase_desc self._test_state = test_state self._subtest_rec = subtest_rec + self._phase_state = test_state.running_phase_state self._phase_execution_outcome = None # type: Optional[PhaseExecutionOutcome] def _thread_proc(self) -> None: """Execute the encompassed phase and save the result.""" + self._test_state.running_phase_state = self._phase_state # Call the phase, save the return value, or default it to CONTINUE. phase_return = self._phase_desc(self._test_state) if phase_return is None: @@ -239,6 +241,7 @@ def __init__(self, test_state: 'htf_test_state.TestState'): # _execute_phase_once is setting up the next phase thread. self._current_phase_thread_lock = threading.Lock() self._current_phase_thread = None # type: Optional[PhaseExecutorThread] + self._active_phase_threads = set() # type: Set[PhaseExecutorThread] self._stopping = threading.Event() def _should_repeat(self, phase: phase_descriptor.PhaseDescriptor, @@ -320,9 +323,12 @@ def _execute_phase_once( phase_desc.name) return PhaseExecutionOutcome(phase_descriptor.PhaseResult.SKIP), None - override_result = None - with self.test_state.running_phase_context(phase_desc) as phase_state: + if id(phase_desc) in getattr(self.test_state, '_concurrent_nodes', set()): + ctx_mgr = self.test_state.concurrent_running_phase_context + else: + ctx_mgr = self.test_state.running_phase_context + with ctx_mgr(phase_desc) as phase_state: if subtest_rec: self.logger.debug('Executing phase %s under subtest %s (from %s)', phase_desc.name, phase_desc.func_location, @@ -347,6 +353,7 @@ def _execute_phase_once( run_with_profiling, subtest_rec) phase_thread.start() self._current_phase_thread = phase_thread + self._active_phase_threads.add(phase_thread) phase_state.result = phase_thread.join_or_die() if phase_state.result.is_repeat and is_last_repeat: @@ -354,7 +361,10 @@ def _execute_phase_once( phase_state.hit_repeat_limit = True override_result = PhaseExecutionOutcome( phase_descriptor.PhaseResult.STOP) - self._current_phase_thread = None + with self._current_phase_thread_lock: + self._active_phase_threads.discard(phase_thread) + if self._current_phase_thread == phase_thread: + self._current_phase_thread = None # Refresh the result in case a validation for a partially set measurement # or phase diagnoser raised an exception. @@ -433,18 +443,23 @@ def stop( """ self._stopping.set() with self._current_phase_thread_lock: - phase_thread = self._current_phase_thread - if not phase_thread: - return - - if phase_thread.is_alive(): - phase_thread.kill() - - self.logger.debug('Waiting for cancelled phase to exit: %s', phase_thread) - timeout = timeouts.PolledTimeout.from_seconds(timeout_s) - while phase_thread.is_alive() and not timeout.has_expired(): - time.sleep(0.1) - self.logger.debug('Cancelled phase %s exit', - "didn't" if phase_thread.is_alive() else 'did') + threads_to_kill = list(self._active_phase_threads) + + timeout = timeouts.PolledTimeout.from_seconds(timeout_s) + for phase_thread in threads_to_kill: + if phase_thread.is_alive(): + phase_thread.kill() + + for phase_thread in threads_to_kill: + if phase_thread.is_alive(): + self.logger.debug( + 'Waiting for cancelled phase to exit: %s', phase_thread + ) + while phase_thread.is_alive() and not timeout.has_expired(): + time.sleep(0.1) + self.logger.debug( + 'Cancelled phase %s exit', + "didn't" if phase_thread.is_alive() else 'did', + ) # Clear the currently running phase, whether it finished or timed out. self.test_state.stop_running_phase() diff --git a/openhtf/core/phase_graph.py b/openhtf/core/phase_graph.py new file mode 100644 index 00000000..f92b8f21 --- /dev/null +++ b/openhtf/core/phase_graph.py @@ -0,0 +1,229 @@ +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Phase Graph support for OpenHTF. + +PhaseGraph is a PhaseCollectionNode that manages its contained phases via +a topological sort based on their explicit prerequisites. +""" + +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Text, Tuple, Type + +import attr +from openhtf import util +from openhtf.core import base_plugs +from openhtf.core import phase_collections +from openhtf.core import phase_descriptor + + +class CyclicDependencyError(Exception): + """PhaseGraph phases have cyclic dependencies.""" + + +class PhaseUnreachableError(Exception): + """A prerequisite is not defined in the graph.""" + + +class DuplicatePhaseNameError(Exception): + """PhaseGraph phases have duplicate names.""" + + +@attr.s(slots=True, frozen=True) +class PhaseEdge: + """A dependent phase and its prerequisite phases. + + Attributes: + dependent: The phase that depends on the prerequisites. + prerequisites: Phases that must run before the dependent phase. + """ + + dependent = attr.ib(type=phase_descriptor.PhaseCallableOrNodeT) + prerequisites = attr.ib(type=Sequence[phase_descriptor.PhaseCallableOrNodeT]) + + +@attr.s(slots=True, frozen=True, init=False) +class PhaseGraph(phase_collections.PhaseCollectionNode): + """A phase collection whose execution order is defined by a DAG. + + For each phase, the name must be unique within the PhaseGraph. The execution + order is determined by a topological sort of the phases based on their + prerequisites. + + Attributes: + nodes: A tuple of PhaseDescriptor instances in topologically sorted order. + name: An optional name for this PhaseGraph. + """ + + @classmethod + def from_edges( + cls, + edges: Sequence['PhaseEdge'], + name: Optional[Text] = None, + ) -> 'PhaseGraph': + """Constructs a PhaseGraph from explicit PhaseEdge objects.""" + flat_unique_nodes = [] + seen_ids = set() + + def _add(node): + node_id = id(node) + if node_id not in seen_ids: + seen_ids.add(node_id) + flat_unique_nodes.append(node) + + for edge in edges: + _add(edge.dependent) + for prereq in edge.prerequisites: + _add(prereq) + + wrapped_nodes = [] + wrapped_by_orig = {} + for n in flat_unique_nodes: + wrapped = phase_descriptor.PhaseDescriptor.wrap_or_copy(n) + wrapped_nodes.append(wrapped) + wrapped_by_orig[id(n)] = wrapped + wrapped_by_orig[wrapped.name] = wrapped + + for edge in edges: + wrapped_dep = wrapped_by_orig[id(edge.dependent)] + prereq_names = [] + for prereq in edge.prerequisites: + wrapped_prereq = wrapped_by_orig.get(id(prereq)) + if wrapped_prereq: + prereq_names.append(wrapped_prereq.name) + wrapped_dep.options.prerequisites = prereq_names + + return cls(*wrapped_nodes, name=name) + + nodes = attr.ib(type=Tuple[phase_descriptor.PhaseDescriptor, ...]) + name = attr.ib(type=Optional[Text], default=None) + + def __init__( + self, + *args: phase_descriptor.PhaseCallableOrNodeT, + name: Optional[Text] = None, + nodes: Optional[Tuple[phase_descriptor.PhaseDescriptor, ...]] = None, + ): + super(PhaseGraph, self).__init__() + object.__setattr__(self, 'name', name) + + if nodes is not None: + args = args + tuple(nodes) + + flattened = list(phase_collections._recursive_flatten(args)) + # Verify elements are PhaseDescriptor instances for prerequisite matching + ph_desc_list = [] + for n in flattened: + if isinstance(n, phase_descriptor.PhaseDescriptor): + ph_desc_list.append(n) + else: + # Wrap or copy standard callables / nodes + ph_desc_list.append(phase_descriptor.PhaseDescriptor.wrap_or_copy(n)) + + topologically_sorted = self._validate_and_toposort(ph_desc_list) + object.__setattr__(self, 'nodes', tuple(topologically_sorted)) + + def _validate_and_toposort( + self, nodes: List[phase_descriptor.PhaseDescriptor] + ) -> List[phase_descriptor.PhaseDescriptor]: + """Validates the DAG structure and returns topologically sorted nodes.""" + # Verify unique phase names + node_by_name: Dict[str, phase_descriptor.PhaseDescriptor] = {} + for n in nodes: + if n.name in node_by_name: + raise DuplicatePhaseNameError( + f"Duplicate phase name '{n.name}' detected in PhaseGraph." + ) + node_by_name[n.name] = n + + # Match prerequisites to actual nodes + adjacency = {n.name: set() for n in nodes} + for n in nodes: + if n.options.prerequisites is not None: + for pr in n.options.prerequisites: + pr_name = pr if isinstance(pr, str) else getattr(pr, 'name', None) + if not pr_name or pr_name not in node_by_name: + raise PhaseUnreachableError( + f"Prerequisite '{pr_name}' for phase '{n.name}' not found in" + ' PhaseGraph.' + ) + adjacency[n.name].add(pr_name) + + # Perform DFS-based topological sort with cycle detection. + visited = set() + temp_marked = set() + sorted_names = [] + + def _visit(node_name: str): + if node_name in temp_marked: + raise CyclicDependencyError(f"Cycle detected involving '{node_name}'") + if node_name in visited: + return + temp_marked.add(node_name) + for prereq_name in adjacency[node_name]: + _visit(prereq_name) + temp_marked.remove(node_name) + visited.add(node_name) + sorted_names.append(node_name) + + for n in nodes: + if n.name not in visited: + _visit(n.name) + + return [node_by_name[name] for name in sorted_names] + + def _asdict(self) -> Dict[Text, Any]: + return { + 'name': self.name, + 'nodes': [n._asdict() for n in self.nodes], + } + + def with_args(self, **kwargs: Any) -> 'PhaseGraph': + return attr.evolve( + self, + nodes=tuple(n.with_args(**kwargs) for n in self.nodes), + name=util.format_string(self.name, kwargs), + ) + + def with_plugs(self, **subplugs: Type[base_plugs.BasePlug]) -> 'PhaseGraph': + return attr.evolve( + self, + nodes=tuple(n.with_plugs(**subplugs) for n in self.nodes), + name=util.format_string(self.name, subplugs), + ) + + def load_code_info(self) -> 'PhaseGraph': + return attr.evolve( + self, + nodes=tuple(n.load_code_info() for n in self.nodes), + name=self.name, + ) + + def apply_to_all_phases( + self, + func: Callable[ + [phase_descriptor.PhaseDescriptor], phase_descriptor.PhaseDescriptor + ], + ) -> 'PhaseGraph': + return attr.evolve( + self, + nodes=tuple(n.apply_to_all_phases(func) for n in self.nodes), + name=self.name, + ) + + def filter_by_type(self, node_cls: Type[Any]) -> Iterator[Any]: + for node in self.nodes: + if isinstance(node, node_cls): + yield node + if isinstance(node, phase_collections.PhaseCollectionNode): + for sub_n in node.filter_by_type(node_cls): + yield sub_n diff --git a/openhtf/core/test_executor.py b/openhtf/core/test_executor.py index d5a0d2b5..0138efaf 100644 --- a/openhtf/core/test_executor.py +++ b/openhtf/core/test_executor.py @@ -13,15 +13,17 @@ # limitations under the License. """TestExecutor executes tests.""" +import concurrent.futures import contextlib import enum import logging +import multiprocessing import pstats import sys import tempfile import threading import traceback -from typing import Iterator, List, Optional, Text, Type, TYPE_CHECKING +from typing import Iterator, List, Optional, TYPE_CHECKING, Text, Type from openhtf import util from openhtf.core import base_plugs @@ -30,6 +32,7 @@ from openhtf.core import phase_collections from openhtf.core import phase_descriptor from openhtf.core import phase_executor +from openhtf.core import phase_graph from openhtf.core import phase_group from openhtf.core import phase_nodes from openhtf.core import test_record @@ -602,6 +605,82 @@ def _execute_phase_group(self, group: phase_group.PhaseGroup, teardown_ret = _ExecutorReturn.CONTINUE return _more_critical(main_ret, teardown_ret) + def _execute_phase_graph( + self, + graph: phase_graph.PhaseGraph, + subtest_rec: Optional[test_record.SubtestRecord], + in_teardown: bool, + ) -> _ExecutorReturn: + """Executes the phases in a phase graph concurrently according to DAG ordering.""" + + if graph.name: + self.logger.debug('Entering PhaseGraph %s', graph.name) + + nodes = list(graph.nodes) + + completed_phases = set() + failed_phases = set() + running_futures = {} + + with concurrent.futures.ThreadPoolExecutor( + max_workers=min(32, (multiprocessing.cpu_count() or 1) + 4) + ) as pool: + while len(completed_phases) + len(failed_phases) < len(nodes): + if self._abort.is_set() or self._full_abort.is_set(): + for fut in running_futures: + fut.cancel() + return _ExecutorReturn.TERMINAL + + # Submit any unblocked and un-scheduled phase + made_progress = False + for node in nodes: + if ( + node.name in completed_phases + or node.name in failed_phases + or node.name in running_futures.values() + ): + continue + + prerequisite_satisfied = not node.options.prerequisites or all( + (pr if isinstance(pr, str) else getattr(pr, 'name', None)) + in completed_phases + for pr in node.options.prerequisites + ) + + if prerequisite_satisfied: + self.running_test_state.add_concurrent_node(id(node)) + fut = pool.submit( + self._execute_phase, node, subtest_rec, in_teardown + ) + running_futures[fut] = node.name + made_progress = True + + if not running_futures and not made_progress: + # Blocked completely / cycle or missing prerequisites + return _ExecutorReturn.TERMINAL + + # Wait for at least one currently running future to complete + done, _ = concurrent.futures.wait( + running_futures.keys(), + return_when=concurrent.futures.FIRST_COMPLETED, + ) + + for fut in done: + p_name = running_futures.pop(fut) + try: + res = fut.result() + if res == _ExecutorReturn.TERMINAL: + failed_phases.add(p_name) + return _ExecutorReturn.TERMINAL + else: + completed_phases.add(p_name) + except Exception: # pylint: disable=broad-except + self.logger.exception('Phase worker thread raised an exception.') + failed_phases.add(p_name) + return _ExecutorReturn.TERMINAL + + return _ExecutorReturn.CONTINUE + def _execute_node(self, node: phase_nodes.PhaseNode, subtest_rec: Optional[test_record.SubtestRecord], in_teardown: bool) -> _ExecutorReturn: @@ -613,6 +692,8 @@ def _execute_node(self, node: phase_nodes.PhaseNode, return self._execute_sequence(node, subtest_rec, in_teardown) if isinstance(node, phase_group.PhaseGroup): return self._execute_phase_group(node, subtest_rec, in_teardown) + if isinstance(node, phase_graph.PhaseGraph): + return self._execute_phase_graph(node, subtest_rec, in_teardown) if isinstance(node, phase_descriptor.PhaseDescriptor): return self._execute_phase(node, subtest_rec, in_teardown) if isinstance(node, phase_branches.Checkpoint): diff --git a/openhtf/core/test_state.py b/openhtf/core/test_state.py index 89269650..3c311dce 100644 --- a/openhtf/core/test_state.py +++ b/openhtf/core/test_state.py @@ -33,6 +33,7 @@ import os import socket import sys +import threading from typing import Any, Dict, Iterator, List, Optional, Set, TYPE_CHECKING, Text, Tuple, Union import attr @@ -137,6 +138,10 @@ def __init__(self, test_desc: 'test_descriptor.TestDescriptor', test_options: test_options passed through from Test. """ super(TestState, self).__init__() + self.state_lock = threading.RLock() + self._thread_local = threading.local() + self._active_phase_states = list() + self._concurrent_nodes = set() self._status = self.Status.WAITING_FOR_TEST_START # type: TestState.Status self.test_record = test_record.TestRecord( @@ -180,6 +185,32 @@ def logger(self) -> logging.Logger: 'Calling `logger` attribute while phase not running; use state_logger ' 'instead.') + @property + def running_phase_state(self): + state = getattr(self._thread_local, 'phase_state', None) + if state is not None: + return state + + with self.state_lock: + if not self._active_phase_states: + return None + if len(self._active_phase_states) > 1: + raise RuntimeError( + 'Multiple active phase states found when' + ' self._thread_local.phase_state is None. This use case is not' + ' supported.' + ) + return self._active_phase_states[0] + + def add_concurrent_node(self, node_id: int) -> None: + """Add a node to the set of nodes that are running concurrently.""" + with self.state_lock: + self._concurrent_nodes.add(node_id) + + @running_phase_state.setter + def running_phase_state(self, value): + self._thread_local.phase_state = value + @property def test_api(self) -> 'test_descriptor.TestApi': """Create a TestApi for access to this TestState. @@ -196,14 +227,20 @@ def test_api(self) -> 'test_descriptor.TestApi': """ if not self.running_phase_state: raise ValueError('test_api only available when phase is running.') - if not self._running_test_api: - self._running_test_api = openhtf.TestApi( + api = getattr(self._thread_local, 'test_api', None) + if ( + not api + or getattr(api, 'running_phase_state', None) != self.running_phase_state + ): + api = openhtf.TestApi( measurements=measurements.Collection( self.running_phase_state.measurements), running_phase_state=self.running_phase_state, running_test_state=self, ) - return self._running_test_api + self._thread_local.test_api = api + self._running_test_api = api + return api def get_attachment(self, attachment_name: Text) -> Optional[test_record.Attachment]: @@ -287,8 +324,10 @@ def running_phase_context( """ assert not self.running_phase_state, 'Phase already running!' phase_logger = self.state_logger.getChild('phase.' + phase_desc.name) - phase_state = self.running_phase_state = PhaseState.from_descriptor( - phase_desc, self, phase_logger) + phase_state = PhaseState.from_descriptor(phase_desc, self, phase_logger) + self.running_phase_state = phase_state + with self.state_lock: + self._active_phase_states.append(phase_state) self.notify_update() # New phase started. try: yield phase_state @@ -296,19 +335,69 @@ def running_phase_context( phase_state.finalize() self.test_record.add_phase_record(phase_state.phase_record) self.running_phase_state = None - self._running_test_api = None + with self.state_lock: + if phase_state in self._active_phase_states: + self._active_phase_states.remove(phase_state) self.notify_update() # Phase finished. + @contextlib.contextmanager + def concurrent_running_phase_context( + self, + phase_desc: phase_descriptor.PhaseDescriptor) -> Iterator['PhaseState']: + """Create an isolated, thread-safe context for a concurrent PhaseGraph node. + + Args: + phase_desc: openhtf.PhaseDescriptor to start a context for. + + Yields: + PhaseState to track transient state. + """ + phase_logger = self.state_logger.getChild('phase.' + phase_desc.name) + phase_state = PhaseState.from_descriptor(phase_desc, self, phase_logger) + + # Set running_phase_state and _running_test_api for the thread's context + with self.state_lock: + self.running_phase_state = phase_state + self._running_test_api = None + self._active_phase_states.append(phase_state) + self.notify_update() + + try: + yield phase_state + finally: + phase_state.finalize() + with self.state_lock: + self.test_record.add_phase_record(phase_state.phase_record) + if phase_state in self._active_phase_states: + self._active_phase_states.remove(phase_state) + self.notify_update() + self.running_phase_state = None + def as_base_types(self) -> Dict[Text, Any]: """Convert to a dict representation composed exclusively of base types.""" - running_phase_state = None - if self.running_phase_state: - running_phase_state = self.running_phase_state.as_base_types() + with self.state_lock: + running_phase_states = [ + phase.as_base_types() for phase in self._active_phase_states + ] + + if running_phase_states: + running_phase_output = running_phase_states[0] + elif self.running_phase_state: + # Includes manually/legacy set PhaseState. + legacy_basetype = self.running_phase_state.as_base_types() + running_phase_output = legacy_basetype + running_phase_states = [legacy_basetype] + else: + running_phase_output = None + return { 'status': data.convert_to_base_types(self._status), 'test_record': self.test_record.as_base_types(), 'plugs': self.plug_manager.as_base_types(), - 'running_phase_state': running_phase_state, + # TODO(wuchiju): Remove running_phase_state once WebUI migrates to + # running_phase_states. + 'running_phase_state': running_phase_output, + 'running_phase_states': running_phase_states, } def _asdict(self) -> Dict[Text, Any]: @@ -321,7 +410,12 @@ def is_finalized(self) -> bool: def stop_running_phase(self) -> None: """Stops the currently running phase, allowing another phase to run.""" - self.running_phase_state = None + phase_state = self.running_phase_state + if phase_state and phase_state in self._active_phase_states: + with self.state_lock: + if phase_state in self._active_phase_states: + self._active_phase_states.remove(phase_state) + self.running_phase_state = None @property def last_run_phase_name(self) -> Optional[Text]: diff --git a/test/core/phase_graph_test.py b/test/core/phase_graph_test.py new file mode 100644 index 00000000..fb1ad1f7 --- /dev/null +++ b/test/core/phase_graph_test.py @@ -0,0 +1,104 @@ +import unittest + +from openhtf.core import phase_descriptor +from openhtf.core import phase_graph + + +@phase_descriptor.PhaseOptions(name='phase_a') +def phase_a(): + pass + + +@phase_descriptor.PhaseOptions(name='phase_b', prerequisites=['phase_a']) +def phase_b(): + pass + + +@phase_descriptor.PhaseOptions(name='phase_c', prerequisites=['phase_b']) +def phase_c(): + pass + + +@phase_descriptor.PhaseOptions(name='cycle_1', prerequisites=['cycle_2']) +def cycle_1(): + pass + + +@phase_descriptor.PhaseOptions(name='cycle_2', prerequisites=['cycle_1']) +def cycle_2(): + pass + + +@phase_descriptor.PhaseOptions(name='phase_d') +def phase_d(): + pass + + +@phase_descriptor.PhaseOptions(name='phase_e') +def phase_e(): + pass + + +@phase_descriptor.PhaseOptions(name='phase_f') +def phase_f(): + pass + + +@phase_descriptor.PhaseOptions(name='duplicate_phase') +def dup1(): + pass + + +@phase_descriptor.PhaseOptions(name='duplicate_phase', prerequisites=[dup1]) +def dup2(): + pass + + +class PhaseGraphTest(unittest.TestCase): + + def test_topological_sorting(self): + # Provide in random order, must sort to A -> B -> C + graph = phase_graph.PhaseGraph(phase_c, phase_a, phase_b) + self.assertEqual( + [node.name for node in graph.nodes], ['phase_a', 'phase_b', 'phase_c'] + ) + + def test_cyclic_dependency_raises(self): + with self.assertRaises(phase_graph.CyclicDependencyError): + phase_graph.PhaseGraph(cycle_1, cycle_2) + + def test_phase_unreachable_raises(self): + with self.assertRaises(phase_graph.PhaseUnreachableError): + phase_graph.PhaseGraph(phase_c) + + def test_from_edges_construction(self): + # Graph topology: + # + # [A] [B] [C] + # \ / \ / + # v v v v + # [D] [E] + # \ / + # v v + # [F] + graph = phase_graph.PhaseGraph.from_edges([ + phase_graph.PhaseEdge(phase_d, [phase_a, phase_b]), + phase_graph.PhaseEdge(phase_e, [phase_b, phase_c]), + phase_graph.PhaseEdge(phase_f, [phase_d, phase_e]), + ]) + names = [node.name for node in graph.nodes] + self.assertIn('phase_f', names) + self.assertGreater(names.index('phase_d'), names.index('phase_a')) + self.assertGreater(names.index('phase_d'), names.index('phase_b')) + self.assertGreater(names.index('phase_e'), names.index('phase_b')) + self.assertGreater(names.index('phase_e'), names.index('phase_c')) + self.assertGreater(names.index('phase_f'), names.index('phase_d')) + self.assertGreater(names.index('phase_f'), names.index('phase_e')) + + def test_duplicate_phase_names(self): + with self.assertRaises(phase_graph.DuplicatePhaseNameError): + phase_graph.PhaseGraph(dup1, dup2) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_state_test.py b/test/test_state_test.py index 1ec94893..54a0c381 100644 --- a/test/test_state_test.py +++ b/test/test_state_test.py @@ -101,6 +101,7 @@ def test_phase(): 'plug_states': {}, }, 'running_phase_state': copy.deepcopy(PHASE_STATE_BASE_TYPE_INITIAL), + 'running_phase_states': [copy.deepcopy(PHASE_STATE_BASE_TYPE_INITIAL)], } @@ -220,6 +221,8 @@ def test_test_state_cache(self): descriptor_id = basetypes['running_phase_state']['descriptor_id'] expected_initial_basetypes['running_phase_state']['descriptor_id'] = ( descriptor_id) + expected_initial_basetypes['running_phase_states'][0]['descriptor_id'] = ( + descriptor_id) self.assertEqual(expected_initial_basetypes, basetypes) self.running_phase_state._finalize_measurements() self.test_record.add_phase_record(self.running_phase_state.phase_record)