Skip to content

Commit 3263e1d

Browse files
natelusttimj
authored andcommitted
Allow conflicting module paths to be resolved late
When instantiating a pipeline that contains imports, allow the system to resolve conflicting module paths of tasks late. This allows the same task imported from a different path to still resolve as the same.
1 parent e430572 commit 3263e1d

5 files changed

Lines changed: 141 additions & 14 deletions

File tree

python/lsst/pipe/base/pipeline.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,11 @@ def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
803803
return
804804
if label not in self._pipelineIR.tasks:
805805
raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
806-
self._pipelineIR.tasks[label].add_or_update_config(newConfig)
806+
match self._pipelineIR.tasks[label]:
807+
case pipelineIR.TaskIR() as task:
808+
task.add_or_update_config(newConfig)
809+
case pipelineIR._AmbigousTask() as ambig_task:
810+
ambig_task.tasks[-1].add_or_update_config(newConfig)
807811

808812
def write_to_uri(self, uri: ResourcePathExpression) -> None:
809813
"""Write the pipeline to a file or directory.
@@ -845,6 +849,7 @@ def to_graph(
845849
graph : `pipeline_graph.PipelineGraph`
846850
Representation of the pipeline as a graph.
847851
"""
852+
self._pipelineIR.resolve_task_ambiguity()
848853
instrument_class_name = self._pipelineIR.instrument
849854
data_id = {}
850855
if instrument_class_name is not None:
@@ -888,7 +893,8 @@ def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) ->
888893
"""
889894
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
890895
raise NameError(f"Label {label} does not appear in this pipeline")
891-
taskClass: type[PipelineTask] = doImportType(taskIR.klass)
896+
# type ignore here because all ambiguity should be resolved
897+
taskClass: type[PipelineTask] = doImportType(taskIR.klass) # type: ignore
892898
config = taskClass.ConfigClass()
893899
instrument: PipeBaseInstrument | None = None
894900
if (instrumentName := self._pipelineIR.instrument) is not None:
@@ -897,7 +903,8 @@ def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) ->
897903
config.applyConfigOverrides(
898904
instrument,
899905
getattr(taskClass, "_DefaultName", ""),
900-
taskIR.config,
906+
# type ignore here because all ambiguity should be resolved
907+
taskIR.config, # type: ignore
901908
self._pipelineIR.parameters,
902909
label,
903910
)

python/lsst/pipe/base/pipelineIR.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,13 @@
4545
from collections import Counter
4646
from collections.abc import Generator, Hashable, Iterable, MutableMapping
4747
from dataclasses import dataclass, field
48-
from typing import Any, Literal
48+
from typing import Any, Literal, cast
4949

5050
import yaml
5151

5252
from lsst.resources import ResourcePath, ResourcePathExpression
5353
from lsst.utils.introspection import find_outside_stacklevel
54+
from lsst.utils import doImportType
5455

5556

5657
class PipelineSubsetCtrl(enum.Enum):
@@ -443,6 +444,34 @@ def __eq__(self, other: object) -> bool:
443444
return all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config"))
444445

445446

447+
@dataclass
448+
class _AmbigousTask:
449+
"""Representation of tasks which may have conflicting task classes."""
450+
451+
tasks: list[TaskIR]
452+
"""TaskIR objects that need to be compaired late."""
453+
454+
def resolve(self) -> TaskIR:
455+
true_taskIR = self.tasks[0]
456+
task_class = doImportType(true_taskIR.klass)
457+
# need to find out if they are all actually the same
458+
for tmp_taskIR in self.tasks[1:]:
459+
tmp_task_class = doImportType(tmp_taskIR.klass)
460+
if tmp_task_class is task_class:
461+
if tmp_taskIR.config is None:
462+
continue
463+
for config in tmp_taskIR.config:
464+
true_taskIR.add_or_update_config(config)
465+
else:
466+
true_taskIR = tmp_taskIR
467+
task_class = tmp_task_class
468+
return true_taskIR
469+
470+
def to_primitives(self) -> dict[str, str | list[dict]]:
471+
true_task = self.resolve()
472+
return true_task.to_primitives()
473+
474+
446475
@dataclass
447476
class ImportIR:
448477
"""An intermediate representation of imported pipelines."""
@@ -778,7 +807,7 @@ def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None:
778807
existing in this object.
779808
"""
780809
# integrate any imported pipelines
781-
accumulate_tasks: dict[str, TaskIR] = {}
810+
accumulate_tasks: dict[str, TaskIR | _AmbigousTask] = {}
782811
accumulate_labeled_subsets: dict[str, LabeledSubset] = {}
783812
accumulated_parameters = ParametersIR({})
784813
accumulated_steps: dict[str, StepIR] = {}
@@ -842,17 +871,39 @@ def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None:
842871
for label, task in self.tasks.items():
843872
if label not in accumulate_tasks:
844873
accumulate_tasks[label] = task
845-
elif accumulate_tasks[label].klass == task.klass:
846-
if task.config is not None:
847-
for config in task.config:
848-
accumulate_tasks[label].add_or_update_config(config)
849874
else:
850-
accumulate_tasks[label] = task
851-
self.tasks: dict[str, TaskIR] = accumulate_tasks
875+
match (accumulate_tasks[label], task):
876+
case (TaskIR() as taskir_obj, TaskIR() as ctask) if taskir_obj.klass == ctask.klass:
877+
if ctask.config is not None:
878+
for config in ctask.config:
879+
taskir_obj.add_or_update_config(config)
880+
case (TaskIR(klass=klass) as taskir_obj, TaskIR() as ctask) if klass != ctask.klass:
881+
accumulate_tasks[label] = _AmbigousTask([taskir_obj, ctask])
882+
case (_AmbigousTask(ambig_list), TaskIR() as ctask):
883+
ambig_list.append(ctask)
884+
case (TaskIR() as taskir_obj, _AmbigousTask(ambig_list)):
885+
accumulate_tasks[label] = _AmbigousTask([taskir_obj] + ambig_list)
886+
case (_AmbigousTask(existing_ambig_list), _AmbigousTask(new_ambig_list)):
887+
existing_ambig_list.extend(new_ambig_list)
888+
889+
self.tasks: MutableMapping[str, TaskIR | _AmbigousTask] = accumulate_tasks
852890
accumulated_parameters.update(self.parameters)
853891
self.parameters = accumulated_parameters
854892
self.steps = list(accumulated_steps.values())
855893

894+
def resolve_task_ambiguity(self) -> None:
895+
new_tasks: dict[str, TaskIR] = {}
896+
for label, task in self.tasks.items():
897+
match task:
898+
case TaskIR():
899+
new_tasks[label] = task
900+
case _AmbigousTask():
901+
new_tasks[label] = task.resolve()
902+
# Do a cast here, because within this function body we want the
903+
# protection that all the tasks are TaskIR objects, but for the
904+
# task level variable, it must stay the same mixed dictionary.
905+
self.tasks = cast(dict[str, TaskIR | _AmbigousTask], new_tasks)
906+
856907
def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None:
857908
"""Process the tasks portion of the loaded yaml document
858909
@@ -870,6 +921,7 @@ def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None:
870921
if "parameters" in tmp_tasks:
871922
raise ValueError("parameters is a reserved word and cannot be used as a task label")
872923

924+
definition: str | dict[str, Any]
873925
for label, definition in tmp_tasks.items():
874926
if isinstance(definition, str):
875927
definition = {"class": definition}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# This file is part of pipe_base.
2+
#
3+
# Developed for the LSST Data Management System.
4+
# This product includes software developed by the LSST Project
5+
# (http://www.lsst.org).
6+
# See the COPYRIGHT file at the top-level directory of this distribution
7+
# for details of code ownership.
8+
#
9+
# This software is dual licensed under the GNU General Public License and also
10+
# under a 3-clause BSD license. Recipients may choose which of these licenses
11+
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12+
# respectively. If you choose the GPL option then the following text applies
13+
# (but note that there is still no warranty even if you opt for BSD instead):
14+
#
15+
# This program is free software: you can redistribute it and/or modify
16+
# it under the terms of the GNU General Public License as published by
17+
# the Free Software Foundation, either version 3 of the License, or
18+
# (at your option) any later version.
19+
#
20+
# This program is distributed in the hope that it will be useful,
21+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
22+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23+
# GNU General Public License for more details.
24+
#
25+
# You should have received a copy of the GNU General Public License
26+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
27+
28+
"""Module defining PipelineIR test classes
29+
"""
30+
31+
from __future__ import annotations
32+
33+
__all__ = ("ModuleA", "ModuleAAlias", "ModuleAReplace")
34+
35+
36+
class ModuleA:
37+
"""PipelineIR test class for importing"""
38+
39+
pass
40+
41+
42+
ModuleAAlias = ModuleA
43+
44+
45+
class ModuleAReplace:
46+
"""PipelineIR test class for importing"""
47+
48+
pass

tests/testPipeline2.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ parameters:
44
value3: valueC
55
tasks:
66
modA:
7-
class: "test.moduleA"
7+
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleA"
88
config:
99
value1: 1
1010
subsets:

tests/test_pipelineIR.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,31 @@ def testImportParsing(self):
230230
- $TESTDIR/testPipeline2.yaml
231231
tasks:
232232
modA:
233-
class: "test.moduleA"
233+
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleA"
234234
config:
235235
value2: 2
236236
"""
237237
)
238238
pipeline = PipelineIR.from_string(pipeline_str)
239+
pipeline.resolve_task_ambiguity()
240+
self.assertEqual(pipeline.tasks["modA"].config[0].rest, {"value1": 1, "value2": 2})
241+
242+
# Test that configs are imported when defining the same task again
243+
# that is aliased with the same label
244+
pipeline_str = textwrap.dedent(
245+
"""
246+
description: Test Pipeline
247+
imports:
248+
- $TESTDIR/testPipeline2.yaml
249+
tasks:
250+
modA:
251+
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleAAlias"
252+
config:
253+
value2: 2
254+
"""
255+
)
256+
pipeline = PipelineIR.from_string(pipeline_str)
257+
pipeline.resolve_task_ambiguity()
239258
self.assertEqual(pipeline.tasks["modA"].config[0].rest, {"value1": 1, "value2": 2})
240259

241260
# Test that configs are not imported when redefining the task
@@ -247,12 +266,13 @@ def testImportParsing(self):
247266
- $TESTDIR/testPipeline2.yaml
248267
tasks:
249268
modA:
250-
class: "test.moduleAReplace"
269+
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleAReplace"
251270
config:
252271
value2: 2
253272
"""
254273
)
255274
pipeline = PipelineIR.from_string(pipeline_str)
275+
pipeline.resolve_task_ambiguity()
256276
self.assertEqual(pipeline.tasks["modA"].config[0].rest, {"value2": 2})
257277

258278
# Test that named subsets are imported

0 commit comments

Comments
 (0)