@@ -199,6 +199,114 @@ def makeTask(
199199 return task
200200
201201
202+ class SubTaskConnections (
203+ PipelineTaskConnections ,
204+ dimensions = ("instrument" , "detector" ),
205+ defaultTemplates = {"in_tmpl" : "_in" , "out_tmpl" : "_out" },
206+ ):
207+ """Connections for SubTask, has one input and two outputs,
208+ plus one init output.
209+ """
210+
211+ input = cT .Input (
212+ name = "add_dataset{in_tmpl}" ,
213+ dimensions = ["instrument" , "detector" ],
214+ storageClass = "NumpyArray" ,
215+ doc = "Input dataset type for this task" ,
216+ )
217+ output = cT .Output (
218+ name = "add_dataset{out_tmpl}" ,
219+ dimensions = ["instrument" , "detector" ],
220+ storageClass = "NumpyArray" ,
221+ doc = "Output dataset type for this task" ,
222+ )
223+ output2 = cT .Output (
224+ name = "add2_dataset{out_tmpl}" ,
225+ dimensions = ["instrument" , "detector" ],
226+ storageClass = "NumpyArray" ,
227+ doc = "Output dataset type for this task" ,
228+ )
229+ initout = cT .InitOutput (
230+ name = "add_init_output{out_tmpl}" ,
231+ storageClass = "NumpyArray" ,
232+ doc = "Init Output dataset type for this task" ,
233+ )
234+
235+
236+ class SubTaskConfig (PipelineTaskConfig , pipelineConnections = SubTaskConnections ):
237+ """Config for SubTask."""
238+
239+ subtract = pexConfig .Field [int ](doc = "amount to subtract" , default = 3 )
240+
241+
242+ class SubTask (PipelineTask ):
243+ """Trivial PipelineTask for testing, has some extras useful for specific
244+ unit tests.
245+ """
246+
247+ ConfigClass = SubTaskConfig
248+ _DefaultName = "sub_task"
249+
250+ initout = numpy .array ([999 ])
251+ """InitOutputs for this task"""
252+
253+ taskFactory : SubTaskFactoryMock | None = None
254+ """Factory that makes instances"""
255+
256+ def run (self , input : int ) -> Struct :
257+ if self .taskFactory :
258+ # do some bookkeeping
259+ if self .taskFactory .stopAt == self .taskFactory .countExec :
260+ raise RuntimeError ("pretend something bad happened" )
261+ self .taskFactory .countExec -= 1
262+
263+ self .config = cast (SubTaskConfig , self .config )
264+ self .metadata .add ("sub" , self .config .subtract )
265+ output = input - self .config .subtract
266+ output2 = output + self .config .subtract
267+ _LOG .info ("input = %s, output = %s, output2 = %s" , input , output , output2 )
268+ return Struct (output = output , output2 = output2 )
269+
270+
271+ class SubTaskFactoryMock (TaskFactory ):
272+ """Special task factory that instantiates AddTask.
273+
274+ It also defines some bookkeeping variables used by SubTask to report
275+ progress to unit tests.
276+
277+ Parameters
278+ ----------
279+ stopAt : `int`, optional
280+ Number of times to call `run` before stopping.
281+ """
282+
283+ def __init__ (self , stopAt : int = - 1 ):
284+ self .countExec = 100 # reduced by SubTask
285+ self .stopAt = stopAt # AddTask raises exception at this call to run()
286+
287+ def makeTask (
288+ self ,
289+ task_node : TaskDef | TaskNode ,
290+ / ,
291+ butler : LimitedButler ,
292+ initInputRefs : Iterable [DatasetRef ] | None ,
293+ ) -> PipelineTask :
294+ if isinstance (task_node , TaskDef ):
295+ # TODO: remove support on DM-40443.
296+ warnings .warn (
297+ "Passing TaskDef to TaskFactory is deprecated and will not be supported after v27." ,
298+ FutureWarning ,
299+ find_outside_stacklevel ("lsst.pipe.base" ),
300+ )
301+ task_class = task_node .taskClass
302+ assert task_class is not None
303+ else :
304+ task_class = task_node .task_class
305+ task = task_class (config = task_node .config , initInputs = None , name = task_node .label )
306+ task .taskFactory = self # type: ignore
307+ return task
308+
309+
202310def registerDatasetTypes (registry : Registry , pipeline : Pipeline | Iterable [TaskDef ] | PipelineGraph ) -> None :
203311 """Register all dataset types used by tasks in a registry.
204312
0 commit comments