4545from collections import Counter
4646from collections .abc import Generator , Hashable , Iterable , MutableMapping
4747from dataclasses import dataclass , field
48- from typing import Any , Literal
48+ from typing import Any , Literal , cast
4949
5050import yaml
5151
5252from lsst .resources import ResourcePath , ResourcePathExpression
5353from lsst .utils .introspection import find_outside_stacklevel
54+ from lsst .utils import doImportType
5455
5556
5657class PipelineSubsetCtrl (enum .Enum ):
@@ -443,6 +444,34 @@ def __eq__(self, other: object) -> bool:
443444 return all (getattr (self , attr ) == getattr (other , attr ) for attr in ("label" , "klass" , "config" ))
444445
445446
447+ @dataclass
448+ class _AmbigousTask :
449+ """Representation of tasks which may have conflicting task classes."""
450+
451+ tasks : list [TaskIR ]
452+ """TaskIR objects that need to be compaired late."""
453+
454+ def resolve (self ) -> TaskIR :
455+ true_taskIR = self .tasks [0 ]
456+ task_class = doImportType (true_taskIR .klass )
457+ # need to find out if they are all actually the same
458+ for tmp_taskIR in self .tasks [1 :]:
459+ tmp_task_class = doImportType (tmp_taskIR .klass )
460+ if tmp_task_class is task_class :
461+ if tmp_taskIR .config is None :
462+ continue
463+ for config in tmp_taskIR .config :
464+ true_taskIR .add_or_update_config (config )
465+ else :
466+ true_taskIR = tmp_taskIR
467+ task_class = tmp_task_class
468+ return true_taskIR
469+
470+ def to_primitives (self ) -> dict [str , str | list [dict ]]:
471+ true_task = self .resolve ()
472+ return true_task .to_primitives ()
473+
474+
446475@dataclass
447476class ImportIR :
448477 """An intermediate representation of imported pipelines."""
@@ -778,7 +807,7 @@ def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None:
778807 existing in this object.
779808 """
780809 # integrate any imported pipelines
781- accumulate_tasks : dict [str , TaskIR ] = {}
810+ accumulate_tasks : dict [str , TaskIR | _AmbigousTask ] = {}
782811 accumulate_labeled_subsets : dict [str , LabeledSubset ] = {}
783812 accumulated_parameters = ParametersIR ({})
784813 accumulated_steps : dict [str , StepIR ] = {}
@@ -842,17 +871,39 @@ def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None:
842871 for label , task in self .tasks .items ():
843872 if label not in accumulate_tasks :
844873 accumulate_tasks [label ] = task
845- elif accumulate_tasks [label ].klass == task .klass :
846- if task .config is not None :
847- for config in task .config :
848- accumulate_tasks [label ].add_or_update_config (config )
849874 else :
850- accumulate_tasks [label ] = task
851- self .tasks : dict [str , TaskIR ] = accumulate_tasks
875+ match (accumulate_tasks [label ], task ):
876+ case (TaskIR () as taskir_obj , TaskIR () as ctask ) if taskir_obj .klass == ctask .klass :
877+ if ctask .config is not None :
878+ for config in ctask .config :
879+ taskir_obj .add_or_update_config (config )
880+ case (TaskIR (klass = klass ) as taskir_obj , TaskIR () as ctask ) if klass != ctask .klass :
881+ accumulate_tasks [label ] = _AmbigousTask ([taskir_obj , ctask ])
882+ case (_AmbigousTask (ambig_list ), TaskIR () as ctask ):
883+ ambig_list .append (ctask )
884+ case (TaskIR () as taskir_obj , _AmbigousTask (ambig_list )):
885+ accumulate_tasks [label ] = _AmbigousTask ([taskir_obj ] + ambig_list )
886+ case (_AmbigousTask (existing_ambig_list ), _AmbigousTask (new_ambig_list )):
887+ existing_ambig_list .extend (new_ambig_list )
888+
889+ self .tasks : MutableMapping [str , TaskIR | _AmbigousTask ] = accumulate_tasks
852890 accumulated_parameters .update (self .parameters )
853891 self .parameters = accumulated_parameters
854892 self .steps = list (accumulated_steps .values ())
855893
894+ def resolve_task_ambiguity (self ) -> None :
895+ new_tasks : dict [str , TaskIR ] = {}
896+ for label , task in self .tasks .items ():
897+ match task :
898+ case TaskIR ():
899+ new_tasks [label ] = task
900+ case _AmbigousTask ():
901+ new_tasks [label ] = task .resolve ()
902+ # Do a cast here, because within this function body we want the
903+ # protection that all the tasks are TaskIR objects, but for the
904+ # task level variable, it must stay the same mixed dictionary.
905+ self .tasks = cast (dict [str , TaskIR | _AmbigousTask ], new_tasks )
906+
856907 def _read_tasks (self , loaded_yaml : dict [str , Any ]) -> None :
857908 """Process the tasks portion of the loaded yaml document
858909
@@ -870,6 +921,7 @@ def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None:
870921 if "parameters" in tmp_tasks :
871922 raise ValueError ("parameters is a reserved word and cannot be used as a task label" )
872923
924+ definition : str | dict [str , Any ]
873925 for label , definition in tmp_tasks .items ():
874926 if isinstance (definition , str ):
875927 definition = {"class" : definition }
0 commit comments