1616import lib .infers
1717import lib .trainers
1818from lib .model .vista_point_2pt5 .models_samm2pt5d import sam_model_registry
19- from monai .networks .nets import UNet
2019from monailabel .interfaces .config import TaskConfig
2120from monailabel .interfaces .tasks .infer_v2 import InferTask
22- from monailabel .interfaces .tasks .scoring import ScoringMethod
23- from monailabel .interfaces .tasks .strategy import Strategy
2421from monailabel .interfaces .tasks .train import TrainTask
25- from monailabel .tasks .activelearning .epistemic import Epistemic
26- from monailabel .tasks .scoring .dice import Dice
27- from monailabel .tasks .scoring .epistemic import EpistemicScoring
28- from monailabel .tasks .scoring .sum import Sum
2922from monailabel .utils .others .generic import download_file , strtobool
3023
3124logger = logging .getLogger (__name__ )
@@ -35,9 +28,6 @@ class VISTAPOINT2PT5(TaskConfig):
3528 def __init__ (self ):
3629 super ().__init__ ()
3730
38- self .epistemic_enabled = None
39- self .epistemic_samples = None
40-
4131 def init (self , name : str , model_dir : str , conf : Dict [str , str ], planner : Any , ** kwargs ):
4232 super ().init (name , model_dir , conf , planner , ** kwargs )
4333
@@ -166,10 +156,6 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
166156 self .roi_size = (96 , 96 , 96 )
167157
168158 self .network = sam_model_registry ["vit_b" ](checkpoint = None , image_size = 1024 , encoder_in_chans = 9 * 3 )
169- # Others
170- self .epistemic_enabled = strtobool (conf .get ("epistemic_enabled" , "false" ))
171- self .epistemic_samples = int (conf .get ("epistemic_samples" , "5" ))
172- logger .info (f"EPISTEMIC Enabled: { self .epistemic_enabled } ; Samples: { self .epistemic_samples } " )
173159
174160 def infer (self ) -> Union [InferTask , Dict [str , InferTask ]]:
175161 task : InferTask = lib .infers .VISTAPOINT2PT5 (
@@ -179,50 +165,3 @@ def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
179165 preload = strtobool (self .conf .get ("preload" , "false" )),
180166 )
181167 return task
182-
183- def trainer (self ) -> Optional [TrainTask ]:
184- output_dir = os .path .join (self .model_dir , self .name )
185- load_path = self .path [0 ] if os .path .exists (self .path [0 ]) else self .path [1 ]
186-
187- task : TrainTask = lib .trainers .SegmentationSpleen (
188- model_dir = output_dir ,
189- network = self .network ,
190- roi_size = self .roi_size ,
191- target_spacing = self .target_spacing ,
192- description = "Train Spleen Segmentation Model" ,
193- load_path = load_path ,
194- publish_path = self .path [1 ],
195- labels = self .labels ,
196- disable_meta_tracking = False ,
197- )
198- return task
199-
200- def strategy (self ) -> Union [None , Strategy , Dict [str , Strategy ]]:
201- strategies : Dict [str , Strategy ] = {}
202- if self .epistemic_enabled :
203- strategies [f"{ self .name } _epistemic" ] = Epistemic ()
204- return strategies
205-
206- def scoring_method (self ) -> Union [None , ScoringMethod , Dict [str , ScoringMethod ]]:
207- methods : Dict [str , ScoringMethod ] = {
208- "dice" : Dice (),
209- "sum" : Sum (),
210- }
211-
212- if self .epistemic_enabled :
213- methods [f"{ self .name } _epistemic" ] = EpistemicScoring (
214- model = self .path ,
215- network = UNet (
216- spatial_dims = 3 ,
217- in_channels = 1 ,
218- out_channels = 2 ,
219- channels = [16 , 32 , 64 , 128 , 256 ],
220- strides = [2 , 2 , 2 , 2 ],
221- num_res_units = 2 ,
222- norm = "batch" ,
223- dropout = 0.2 ,
224- ),
225- transforms = lib .infers .SegmentationSpleen (None ).pre_transforms (),
226- num_samples = self .epistemic_samples ,
227- )
228- return methods
0 commit comments