@@ -105,6 +105,14 @@ def get_custom_model_vars(self):
105105 model_path = os .path .join (model_directory , model_file )
106106 return model_file , model_directory , model_path
107107
108+ def cleanup (self ):
109+ # Try to clear some GPU memory for other worker processes.
110+ try :
111+ from torch import cuda
112+ cuda .empty_cache ()
113+ except Exception as e :
114+ print (f"Unable to clear GPU memory. You may need to restart CellProfiler to change models. { e } " )
115+
108116class RunCellpose (ImageSegmentation ):
109117 category = "Object Processing"
110118
@@ -592,14 +600,6 @@ def validate_module(self, pipeline):
592600 % model_path , self .model_file_name ,
593601 )
594602
595- def cleanup (self ):
596- from torch import cuda
597- # Try to clear some GPU memory for other worker processes.
598- try :
599- cuda .empty_cache ()
600- except Exception as e :
601- print (f"Unable to clear GPU memory. You may need to restart CellProfiler to change models. { e } " )
602-
603603 def run (self , workspace ):
604604 x_name = self .x_name .value
605605 y_name = self .y_name .value
@@ -662,10 +662,6 @@ def run(self, workspace):
662662 from cellpose import models , io , core , utils
663663 self .cellpose_ver = importlib .metadata .version ('cellpose' )
664664
665- if self .use_gpu .value and model .torch :
666- from torch import cuda
667- cuda .set_per_process_memory_fraction (self .manual_GPU_memory_share .value )
668-
669665 if self .cellpose_version .value == 'omnipose' :
670666 assert int (self .cellpose_ver [0 ])< 2 , "Cellpose version selected in RunCellpose module doesn't match version in Python"
671667 assert float (self .cellpose_ver [0 :3 ]) >= 0.6 , "Cellpose v1/omnipose requires Cellpose >= 0.6"
@@ -675,6 +671,11 @@ def run(self, workspace):
675671 else :
676672 model_file , model_directory , model_path = get_custom_model_vars (self )
677673 model = models .CellposeModel (pretrained_model = model_path , gpu = self .use_gpu .value )
674+
675+ if self .use_gpu .value and model .torch :
676+ from torch import cuda
677+ cuda .set_per_process_memory_fraction (self .manual_GPU_memory_share .value )
678+
678679 try :
679680 y_data , flows , * _ = model .eval (
680681 x_data ,
@@ -694,7 +695,7 @@ def run(self, workspace):
694695 print (f"Unable to create masks. Check your module settings. { a } " )
695696 finally :
696697 if self .use_gpu .value and model .torch :
697- cleanup ()
698+ cleanup (self )
698699
699700 if self .cellpose_version .value == 'v2' :
700701 assert int (self .cellpose_ver [0 ])== 2 , "Cellpose version selected in RunCellpose module doesn't match version in Python"
@@ -704,6 +705,11 @@ def run(self, workspace):
704705 else :
705706 model_file , model_directory , model_path = get_custom_model_vars (self )
706707 model = models .CellposeModel (pretrained_model = model_path , gpu = self .use_gpu .value )
708+
709+ if self .use_gpu .value and model .torch :
710+ from torch import cuda
711+ cuda .set_per_process_memory_fraction (self .manual_GPU_memory_share .value )
712+
707713 try :
708714 y_data , flows , * _ = model .eval (
709715 x_data ,
@@ -722,7 +728,7 @@ def run(self, workspace):
722728 print (f"Unable to create masks. Check your module settings. { a } " )
723729 finally :
724730 if self .use_gpu .value and model .torch :
725- cleanup ()
731+ cleanup (self )
726732
727733 elif self .cellpose_version .value == 'v3' :
728734 assert int (self .cellpose_ver [0 ])== 3 , "Cellpose version selected in RunCellpose module doesn't match version in Python"
@@ -738,6 +744,15 @@ def run(self, workspace):
738744 model_type = self .mode .value , gpu = self .use_gpu .value )
739745 self .current_model_params = model_params
740746
747+ if self .use_gpu .value :
748+ try :
749+ from torch import cuda
750+ cuda .set_per_process_memory_fraction (self .manual_GPU_memory_share .value )
751+ except :
752+ print (
753+ "Failed to set GPU memory share. Please check your PyTorch installation. Not setting per-process memory share."
754+ )
755+
741756 if self .denoise .value :
742757 from cellpose import denoise
743758 recon_params = (
@@ -788,8 +803,8 @@ def run(self, workspace):
788803 except Exception as a :
789804 print (f"Unable to create masks. Check your module settings. { a } " )
790805 finally :
791- if self .use_gpu .value and model . torch :
792- cleanup ()
806+ if self .use_gpu .value :
807+ cleanup (self )
793808
794809 elif self .cellpose_version .value == 'v4' :
795810 assert int (self .cellpose_ver [0 ])== 4 , "Cellpose version selected in RunCellpose module doesn't match version in Python"
@@ -800,6 +815,15 @@ def run(self, workspace):
800815 self .current_model = models .CellposeModel (gpu = self .use_gpu .value )
801816 self .current_model_params = model_params
802817
818+ if self .use_gpu .value :
819+ try :
820+ from torch import cuda
821+ cuda .set_per_process_memory_fraction (self .manual_GPU_memory_share .value )
822+ except :
823+ print (
824+ "Failed to set GPU memory share. Please check your PyTorch installation. Not setting per-process memory share."
825+ )
826+
803827 if self .specify_diameter .value :
804828 try :
805829 y_data , flows , * _ = self .current_model .eval (
@@ -818,7 +842,7 @@ def run(self, workspace):
818842 print (f"Unable to create masks. Check your module settings. { a } " )
819843 finally :
820844 if self .use_gpu .value and model .torch :
821- cleanup ()
845+ cleanup (self )
822846 else :
823847 try :
824848 y_data , flows , * _ = self .current_model .eval (
@@ -835,8 +859,8 @@ def run(self, workspace):
835859 except Exception as a :
836860 print (f"Unable to create masks. Check your module settings. { a } " )
837861 finally :
838- if self .use_gpu .value and model . torch :
839- cleanup ()
862+ if self .use_gpu .value :
863+ cleanup (self )
840864
841865 if self .remove_edge_masks :
842866 y_data = utils .remove_edge_masks (y_data )
@@ -1055,6 +1079,7 @@ def display(self, workspace, figure):
10551079
10561080 def do_check_gpu (self ):
10571081 import importlib .util
1082+ from cellpose import core
10581083 torch_installed = importlib .util .find_spec ('torch' ) is not None
10591084 self .cellpose_ver = importlib .metadata .version ('cellpose' )
10601085 #if the old version of cellpose <2.0, then use istorch kwarg
0 commit comments