Skip to content

Commit c79df71

Browse files
natelusttimj
authored andcommitted
Fixups from review
1 parent 3263e1d commit c79df71

4 files changed

Lines changed: 144 additions & 6 deletions

File tree

python/lsst/pipe/base/pipelineIR.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@
5050
import yaml
5151

5252
from lsst.resources import ResourcePath, ResourcePathExpression
53-
from lsst.utils.introspection import find_outside_stacklevel
5453
from lsst.utils import doImportType
54+
from lsst.utils.introspection import find_outside_stacklevel
5555

5656

5757
class PipelineSubsetCtrl(enum.Enum):

python/lsst/pipe/base/tests/pipelineIRTestClasses.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,15 @@
2525
# You should have received a copy of the GNU General Public License
2626
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2727

28-
"""Module defining PipelineIR test classes
29-
"""
28+
"""Module defining PipelineIR test classes."""
3029

3130
from __future__ import annotations
3231

3332
__all__ = ("ModuleA", "ModuleAAlias", "ModuleAReplace")
3433

3534

3635
class ModuleA:
37-
"""PipelineIR test class for importing"""
36+
"""PipelineIR test class for importing."""
3837

3938
pass
4039

@@ -43,6 +42,6 @@ class ModuleA:
4342

4443

4544
class ModuleAReplace:
46-
"""PipelineIR test class for importing"""
45+
"""PipelineIR test class for importing."""
4746

4847
pass

python/lsst/pipe/base/tests/simpleQGraph.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,103 @@ def makeTask(
188188
return task
189189

190190

191+
class SubTaskConnections(
192+
PipelineTaskConnections,
193+
dimensions=("instrument", "detector"),
194+
defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
195+
):
196+
"""Connections for SubTask, has one input and two outputs,
197+
plus one init output.
198+
"""
199+
200+
input = cT.Input(
201+
name="add_dataset{in_tmpl}",
202+
dimensions=["instrument", "detector"],
203+
storageClass="NumpyArray",
204+
doc="Input dataset type for this task",
205+
)
206+
output = cT.Output(
207+
name="add_dataset{out_tmpl}",
208+
dimensions=["instrument", "detector"],
209+
storageClass="NumpyArray",
210+
doc="Output dataset type for this task",
211+
)
212+
output2 = cT.Output(
213+
name="add2_dataset{out_tmpl}",
214+
dimensions=["instrument", "detector"],
215+
storageClass="NumpyArray",
216+
doc="Output dataset type for this task",
217+
)
218+
initout = cT.InitOutput(
219+
name="add_init_output{out_tmpl}",
220+
storageClass="NumpyArray",
221+
doc="Init Output dataset type for this task",
222+
)
223+
224+
225+
class SubTaskConfig(PipelineTaskConfig, pipelineConnections=SubTaskConnections):
226+
"""Config for SubTask."""
227+
228+
subtract = pexConfig.Field[int](doc="amount to subtract", default=3)
229+
230+
231+
class SubTask(PipelineTask):
232+
"""Trivial PipelineTask for testing, has some extras useful for specific
233+
unit tests.
234+
"""
235+
236+
ConfigClass = SubTaskConfig
237+
_DefaultName = "sub_task"
238+
239+
initout = numpy.array([999])
240+
"""InitOutputs for this task"""
241+
242+
taskFactory: SubTaskFactoryMock | None = None
243+
"""Factory that makes instances"""
244+
245+
def run(self, input: int) -> Struct:
246+
if self.taskFactory:
247+
# do some bookkeeping
248+
if self.taskFactory.stopAt == self.taskFactory.countExec:
249+
raise RuntimeError("pretend something bad happened")
250+
self.taskFactory.countExec -= 1
251+
252+
self.config = cast(SubTaskConfig, self.config)
253+
self.metadata.add("sub", self.config.subtract)
254+
output = input - self.config.subtract
255+
output2 = output + self.config.subtract
256+
_LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
257+
return Struct(output=output, output2=output2)
258+
259+
260+
class SubTaskFactoryMock(TaskFactory):
261+
"""Special task factory that instantiates AddTask.
262+
263+
It also defines some bookkeeping variables used by SubTask to report
264+
progress to unit tests.
265+
266+
Parameters
267+
----------
268+
stopAt : `int`, optional
269+
Number of times to call `run` before stopping.
270+
"""
271+
272+
def __init__(self, stopAt: int = -1):
273+
self.countExec = 100 # reduced by SubTask
274+
self.stopAt = stopAt # AddTask raises exception at this call to run()
275+
276+
def makeTask(
277+
self,
278+
task_node: TaskNode,
279+
/,
280+
butler: LimitedButler,
281+
initInputRefs: Iterable[DatasetRef] | None,
282+
) -> PipelineTask:
283+
task = task_node.task_class(config=task_node.config, initInputs=None, name=task_node.label)
284+
task.taskFactory = self # type: ignore
285+
return task
286+
287+
191288
def registerDatasetTypes(registry: Registry, pipeline: Pipeline | PipelineGraph) -> None:
192289
"""Register all dataset types used by tasks in a registry.
193290

tests/test_pipeline.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import lsst.utils.tests
3535
from lsst.pipe.base import LabelSpecifier, Pipeline, TaskDef
3636
from lsst.pipe.base.pipelineIR import LabeledSubset
37-
from lsst.pipe.base.tests.simpleQGraph import AddTask, makeSimplePipeline
37+
from lsst.pipe.base.tests.simpleQGraph import AddTask, SubTask, makeSimplePipeline
3838

3939

4040
class PipelineTestCase(unittest.TestCase):
@@ -130,6 +130,48 @@ def testMergingPipelines(self):
130130
pipeline1.mergePipeline(pipeline2)
131131
self.assertEqual(pipeline1._pipelineIR.tasks.keys(), {"task0", "task1", "task2", "task3"})
132132

133+
# Test merging pipelines with ambiguous tasks
134+
pipeline1 = makeSimplePipeline(2)
135+
pipeline2 = makeSimplePipeline(2)
136+
pipeline2.addTask(SubTask, "task1")
137+
pipeline2.mergePipeline(pipeline1)
138+
139+
# Now merge in another pipeline with a config applied.
140+
pipeline3 = makeSimplePipeline(2)
141+
pipeline3.addTask(SubTask, "task1")
142+
pipeline3.addConfigOverride("task1", "subtract", 10)
143+
pipeline3.mergePipeline(pipeline2)
144+
graph = pipeline3.to_graph()
145+
# assert equality from the graph to trigger ambiquity resolution
146+
self.assertEqual(graph.tasks["task1"].config.subtract, 10)
147+
148+
# Now change the order of the merging
149+
pipeline1 = makeSimplePipeline(2)
150+
pipeline2 = makeSimplePipeline(2)
151+
pipeline2.addTask(SubTask, "task1")
152+
pipeline3 = makeSimplePipeline(2)
153+
pipeline3.mergePipeline(pipeline2)
154+
pipeline3.mergePipeline(pipeline1)
155+
graph = pipeline3.to_graph()
156+
# assert equality from the graph to trigger ambiquity resolution
157+
self.assertEqual(graph.tasks["task1"].config.addend, 3)
158+
159+
# Now do two ambiguous chains
160+
pipeline1 = makeSimplePipeline(2)
161+
pipeline2 = makeSimplePipeline(2)
162+
pipeline2.addTask(SubTask, "task1")
163+
pipeline2.addConfigOverride("task1", "subtract", 10)
164+
pipeline2.mergePipeline(pipeline1)
165+
166+
pipeline3 = makeSimplePipeline(2)
167+
pipeline4 = makeSimplePipeline(2)
168+
pipeline4.addTask(SubTask, "task1")
169+
pipeline4.addConfigOverride("task1", "subtract", 7)
170+
pipeline4.mergePipeline(pipeline3)
171+
graph = pipeline4.to_graph()
172+
# assert equality from the graph to trigger ambiquity resolution
173+
self.assertEqual(graph.tasks["task1"].config.subtract, 7)
174+
133175
def testFindingSubset(self):
134176
pipeline = makeSimplePipeline(2)
135177
pipeline._pipelineIR.labeled_subsets["test1"] = LabeledSubset("test1", set(), None)

0 commit comments

Comments
 (0)