1212from abc import ABC , abstractmethod
1313from pathlib import Path
1414import random
15+ import tempfile
1516import numpy as np
1617from typing import Callable , Any
1718
@@ -336,18 +337,13 @@ def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""
336337
337338 def _launch_with_file_store (self , request , world_size ):
338339 tmpdir = request .getfixturevalue ("tmpdir" )
339- dist_file_store = tmpdir .join ("dist_file_store" )
340- assert not os .path .exists (dist_file_store )
341- init_method = f"file://{ dist_file_store } "
342340
343341 if isinstance (world_size , int ):
344342 world_size = [world_size ]
345343 for procs in world_size :
346- try :
347- self ._launch_procs (procs , init_method )
348- finally :
349- if os .path .exists (dist_file_store ):
350- os .remove (dist_file_store )
344+ with tempfile .NamedTemporaryFile (delete = False , dir = str (tmpdir ), suffix = '_filestore' ) as fp :
345+ init_method = f"file://{ fp .name } "
346+ self ._launch_procs (procs , init_method )
351347 time .sleep (0.5 )
352348
353349 def _dist_destroy (self ):
@@ -357,7 +353,7 @@ def _dist_destroy(self):
357353
358354 def _close_pool (self , pool , num_procs , force = False ):
359355 if force or not self .reuse_dist_env :
360- msg = pool .starmap (self ._dist_destroy , [() for _ in range (num_procs )])
356+ pool .starmap (self ._dist_destroy , [() for _ in range (num_procs )])
361357 pool .close ()
362358 pool .join ()
363359
0 commit comments