@@ -227,7 +227,7 @@ def __init__(
227227 self .instance_params = instance
228228 self .use_window = use_window
229229 self .window_infer_size = window_infer_size
230- self .window_overlap_percentage = ( 0.8 ,)
230+ self .window_overlap_percentage = 0.8
231231 self .keep_on_cpu = keep_on_cpu
232232 self .stats_to_csv = stats_csv
233233 """These attributes are all arguments of :py:func:~inference, please see that for reference"""
@@ -346,7 +346,7 @@ def inference(self):
346346
347347 dims = self .model_dict ["model_input_size" ]
348348
349- model = self . model_dict [ "class" ]. get_net ()
349+
350350 if self .model_dict ["name" ] == "SegResNet" :
351351 model = self .model_dict ["class" ].get_net (
352352 input_image_size = [
@@ -360,6 +360,8 @@ def inference(self):
360360 img_size = [dims , dims , dims ],
361361 use_checkpoint = False ,
362362 )
363+ else :
364+ model = self .model_dict ["class" ].get_net ()
363365
364366 self .log_parameters ()
365367
@@ -445,6 +447,7 @@ def inference(self):
445447 inputs = inputs .to ("cpu" )
446448 print (inputs .shape )
447449
450+ # self.log("output")
448451 model_output = lambda inputs : post_process_transforms (
449452 self .model_dict ["class" ].get_output (model , inputs )
450453 )
@@ -460,6 +463,8 @@ def inference(self):
460463 else :
461464 window_size = None
462465 window_overlap = 0.25
466+
467+ # self.log("window")
463468 outputs = sliding_window_inference (
464469 inputs ,
465470 roi_size = window_size ,
0 commit comments