@@ -188,6 +188,103 @@ def makeTask(
188188 return task
189189
190190
191+ class SubTaskConnections (
192+ PipelineTaskConnections ,
193+ dimensions = ("instrument" , "detector" ),
194+ defaultTemplates = {"in_tmpl" : "_in" , "out_tmpl" : "_out" },
195+ ):
196+ """Connections for SubTask, has one input and two outputs,
197+ plus one init output.
198+ """
199+
200+ input = cT .Input (
201+ name = "add_dataset{in_tmpl}" ,
202+ dimensions = ["instrument" , "detector" ],
203+ storageClass = "NumpyArray" ,
204+ doc = "Input dataset type for this task" ,
205+ )
206+ output = cT .Output (
207+ name = "add_dataset{out_tmpl}" ,
208+ dimensions = ["instrument" , "detector" ],
209+ storageClass = "NumpyArray" ,
210+ doc = "Output dataset type for this task" ,
211+ )
212+ output2 = cT .Output (
213+ name = "add2_dataset{out_tmpl}" ,
214+ dimensions = ["instrument" , "detector" ],
215+ storageClass = "NumpyArray" ,
216+ doc = "Output dataset type for this task" ,
217+ )
218+ initout = cT .InitOutput (
219+ name = "add_init_output{out_tmpl}" ,
220+ storageClass = "NumpyArray" ,
221+ doc = "Init Output dataset type for this task" ,
222+ )
223+
224+
225+ class SubTaskConfig (PipelineTaskConfig , pipelineConnections = SubTaskConnections ):
226+ """Config for SubTask."""
227+
228+ subtract = pexConfig .Field [int ](doc = "amount to subtract" , default = 3 )
229+
230+
231+ class SubTask (PipelineTask ):
232+ """Trivial PipelineTask for testing, has some extras useful for specific
233+ unit tests.
234+ """
235+
236+ ConfigClass = SubTaskConfig
237+ _DefaultName = "sub_task"
238+
239+ initout = numpy .array ([999 ])
240+ """InitOutputs for this task"""
241+
242+ taskFactory : SubTaskFactoryMock | None = None
243+ """Factory that makes instances"""
244+
245+ def run (self , input : int ) -> Struct :
246+ if self .taskFactory :
247+ # do some bookkeeping
248+ if self .taskFactory .stopAt == self .taskFactory .countExec :
249+ raise RuntimeError ("pretend something bad happened" )
250+ self .taskFactory .countExec -= 1
251+
252+ self .config = cast (SubTaskConfig , self .config )
253+ self .metadata .add ("sub" , self .config .subtract )
254+ output = input - self .config .subtract
255+ output2 = output + self .config .subtract
256+ _LOG .info ("input = %s, output = %s, output2 = %s" , input , output , output2 )
257+ return Struct (output = output , output2 = output2 )
258+
259+
260+ class SubTaskFactoryMock (TaskFactory ):
261+ """Special task factory that instantiates AddTask.
262+
263+ It also defines some bookkeeping variables used by SubTask to report
264+ progress to unit tests.
265+
266+ Parameters
267+ ----------
268+ stopAt : `int`, optional
269+ Number of times to call `run` before stopping.
270+ """
271+
272+ def __init__ (self , stopAt : int = - 1 ):
273+ self .countExec = 100 # reduced by SubTask
274+ self .stopAt = stopAt # AddTask raises exception at this call to run()
275+
276+ def makeTask (
277+ self ,
278+ task_node : TaskNode ,
279+ / ,
280+ butler : LimitedButler ,
281+ initInputRefs : Iterable [DatasetRef ] | None ,
282+ ) -> PipelineTask :
283+ task = task_node .task_class (config = task_node .config , initInputs = None , name = task_node .label )
284+ task .taskFactory = self # type: ignore
285+ return task
286+
287+
191288def registerDatasetTypes (registry : Registry , pipeline : Pipeline | PipelineGraph ) -> None :
192289 """Register all dataset types used by tasks in a registry.
193290
0 commit comments