Skip to content

Commit 9c5f990

Browse files
committed
Fixups from review
1 parent 6830160 commit 9c5f990

4 files changed

Lines changed: 155 additions & 5 deletions

File tree

python/lsst/pipe/base/pipelineIR.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@
4949

5050
import yaml
5151
from lsst.resources import ResourcePath, ResourcePathExpression
52-
from lsst.utils.introspection import find_outside_stacklevel
5352
from lsst.utils import doImportType
53+
from lsst.utils.introspection import find_outside_stacklevel
5454

5555

5656
class PipelineSubsetCtrl(enum.Enum):

python/lsst/pipe/base/tests/pipelineIRTestClasses.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# You should have received a copy of the GNU General Public License
2626
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2727

28-
"""Module defining PipelineIR test classes
28+
"""Module defining PipelineIR test classes.
2929
"""
3030

3131
from __future__ import annotations
@@ -34,7 +34,7 @@
3434

3535

3636
class ModuleA:
37-
"""PipelineIR test class for importing"""
37+
"""PipelineIR test class for importing."""
3838

3939
pass
4040

@@ -43,6 +43,6 @@ class ModuleA:
4343

4444

4545
class ModuleAReplace:
46-
"""PipelineIR test class for importing"""
46+
"""PipelineIR test class for importing."""
4747

4848
pass

python/lsst/pipe/base/tests/simpleQGraph.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,114 @@ def makeTask(
199199
return task
200200

201201

202+
class SubTaskConnections(
203+
PipelineTaskConnections,
204+
dimensions=("instrument", "detector"),
205+
defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
206+
):
207+
"""Connections for SubTask, has one input and two outputs,
208+
plus one init output.
209+
"""
210+
211+
input = cT.Input(
212+
name="add_dataset{in_tmpl}",
213+
dimensions=["instrument", "detector"],
214+
storageClass="NumpyArray",
215+
doc="Input dataset type for this task",
216+
)
217+
output = cT.Output(
218+
name="add_dataset{out_tmpl}",
219+
dimensions=["instrument", "detector"],
220+
storageClass="NumpyArray",
221+
doc="Output dataset type for this task",
222+
)
223+
output2 = cT.Output(
224+
name="add2_dataset{out_tmpl}",
225+
dimensions=["instrument", "detector"],
226+
storageClass="NumpyArray",
227+
doc="Output dataset type for this task",
228+
)
229+
initout = cT.InitOutput(
230+
name="add_init_output{out_tmpl}",
231+
storageClass="NumpyArray",
232+
doc="Init Output dataset type for this task",
233+
)
234+
235+
236+
class SubTaskConfig(PipelineTaskConfig, pipelineConnections=SubTaskConnections):
237+
"""Config for SubTask."""
238+
239+
subtract = pexConfig.Field[int](doc="amount to subtract", default=3)
240+
241+
242+
class SubTask(PipelineTask):
243+
"""Trivial PipelineTask for testing, has some extras useful for specific
244+
unit tests.
245+
"""
246+
247+
ConfigClass = SubTaskConfig
248+
_DefaultName = "sub_task"
249+
250+
initout = numpy.array([999])
251+
"""InitOutputs for this task"""
252+
253+
taskFactory: SubTaskFactoryMock | None = None
254+
"""Factory that makes instances"""
255+
256+
def run(self, input: int) -> Struct:
257+
if self.taskFactory:
258+
# do some bookkeeping
259+
if self.taskFactory.stopAt == self.taskFactory.countExec:
260+
raise RuntimeError("pretend something bad happened")
261+
self.taskFactory.countExec -= 1
262+
263+
self.config = cast(SubTaskConfig, self.config)
264+
self.metadata.add("sub", self.config.subtract)
265+
output = input - self.config.subtract
266+
output2 = output + self.config.subtract
267+
_LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
268+
return Struct(output=output, output2=output2)
269+
270+
271+
class SubTaskFactoryMock(TaskFactory):
272+
"""Special task factory that instantiates AddTask.
273+
274+
It also defines some bookkeeping variables used by SubTask to report
275+
progress to unit tests.
276+
277+
Parameters
278+
----------
279+
stopAt : `int`, optional
280+
Number of times to call `run` before stopping.
281+
"""
282+
283+
def __init__(self, stopAt: int = -1):
284+
self.countExec = 100 # reduced by SubTask
285+
self.stopAt = stopAt # AddTask raises exception at this call to run()
286+
287+
def makeTask(
288+
self,
289+
task_node: TaskDef | TaskNode,
290+
/,
291+
butler: LimitedButler,
292+
initInputRefs: Iterable[DatasetRef] | None,
293+
) -> PipelineTask:
294+
if isinstance(task_node, TaskDef):
295+
# TODO: remove support on DM-40443.
296+
warnings.warn(
297+
"Passing TaskDef to TaskFactory is deprecated and will not be supported after v27.",
298+
FutureWarning,
299+
find_outside_stacklevel("lsst.pipe.base"),
300+
)
301+
task_class = task_node.taskClass
302+
assert task_class is not None
303+
else:
304+
task_class = task_node.task_class
305+
task = task_class(config=task_node.config, initInputs=None, name=task_node.label)
306+
task.taskFactory = self # type: ignore
307+
return task
308+
309+
202310
def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef] | PipelineGraph) -> None:
203311
"""Register all dataset types used by tasks in a registry.
204312

tests/test_pipeline.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import lsst.utils.tests
3636
from lsst.pipe.base import LabelSpecifier, Pipeline, TaskDef
3737
from lsst.pipe.base.pipelineIR import LabeledSubset
38-
from lsst.pipe.base.tests.simpleQGraph import AddTask, makeSimplePipeline
38+
from lsst.pipe.base.tests.simpleQGraph import AddTask, SubTask, makeSimplePipeline
3939

4040

4141
class PipelineTestCase(unittest.TestCase):
@@ -131,6 +131,48 @@ def testMergingPipelines(self):
131131
pipeline1.mergePipeline(pipeline2)
132132
self.assertEqual(pipeline1._pipelineIR.tasks.keys(), {"task0", "task1", "task2", "task3"})
133133

134+
# Test merging pipelines with ambiguous tasks
135+
pipeline1 = makeSimplePipeline(2)
136+
pipeline2 = makeSimplePipeline(2)
137+
pipeline2.addTask(SubTask, "task1")
138+
pipeline2.mergePipeline(pipeline1)
139+
140+
# Now merge in another pipeline with a config applied.
141+
pipeline3 = makeSimplePipeline(2)
142+
pipeline3.addTask(SubTask, "task1")
143+
pipeline3.addConfigOverride("task1", "subtract", 10)
144+
pipeline3.mergePipeline(pipeline2)
145+
graph = pipeline3.to_graph()
146+
# assert equality from the graph to trigger ambiquity resolution
147+
self.assertEqual(graph.tasks["task1"].config.subtract, 10)
148+
149+
# Now change the order of the merging
150+
pipeline1 = makeSimplePipeline(2)
151+
pipeline2 = makeSimplePipeline(2)
152+
pipeline2.addTask(SubTask, "task1")
153+
pipeline3 = makeSimplePipeline(2)
154+
pipeline3.mergePipeline(pipeline2)
155+
pipeline3.mergePipeline(pipeline1)
156+
graph = pipeline3.to_graph()
157+
# assert equality from the graph to trigger ambiquity resolution
158+
self.assertEqual(graph.tasks["task1"].config.addend, 3)
159+
160+
# Now do two ambiguous chains
161+
pipeline1 = makeSimplePipeline(2)
162+
pipeline2 = makeSimplePipeline(2)
163+
pipeline2.addTask(SubTask, "task1")
164+
pipeline2.addConfigOverride("task1", "subtract", 10)
165+
pipeline2.mergePipeline(pipeline1)
166+
167+
pipeline3 = makeSimplePipeline(2)
168+
pipeline4 = makeSimplePipeline(2)
169+
pipeline4.addTask(SubTask, "task1")
170+
pipeline4.addConfigOverride("task1", "subtract", 7)
171+
pipeline4.mergePipeline(pipeline3)
172+
graph = pipeline4.to_graph()
173+
# assert equality from the graph to trigger ambiquity resolution
174+
self.assertEqual(graph.tasks["task1"].config.subtract, 7)
175+
134176
def testFindingSubset(self):
135177
pipeline = makeSimplePipeline(2)
136178
pipeline._pipelineIR.labeled_subsets["test1"] = LabeledSubset("test1", set(), None)

0 commit comments

Comments
 (0)