11import argparse
2+ import os
23import random
3- from concurrent .futures import ProcessPoolExecutor
4+ import subprocess as sp
5+ import time
46from copy import deepcopy
7+ from datetime import timedelta
58from pathlib import Path
9+ from random import Random
610from typing import Optional
711
812import librosa
913import numpy as np
1014import torch
11- import torch .multiprocessing as mp
1215import torchcrepe
1316from fish_audio_preprocess .utils .file import AUDIO_EXTENSIONS , list_files
1417from loguru import logger
1518from mmengine import Config
16- from tqdm import tqdm
1719
1820from fish_diffusion .modules .energy_extractors import ENERGY_EXTRACTORS
1921from fish_diffusion .modules .feature_extractors import FEATURE_EXTRACTORS
@@ -38,14 +40,10 @@ def init(
3840 global model_caches
3941 device = torch .device ("cpu" )
4042
41- rank = mp .current_process ()._identity
42- rank = rank [0 ] if len (rank ) > 0 else 0
43-
4443 if torch .cuda .is_available ():
45- gpu_id = rank % torch .cuda .device_count ()
46- device = torch .device (f"cuda:{ gpu_id } " )
44+ device = torch .device ("cuda" )
4745
48- logger .info (f"Rank { rank } uses device { device } " )
46+ logger .info (f"{ curr_worker } Uses device { device } " )
4947
5048 text_features_extractor = None
5149 if getattr (config .preprocessing , "text_features_extractor" , None ):
@@ -213,7 +211,7 @@ def safe_process(args, config, audio_path: Path):
213211
214212 return aug_count + 1
215213 except Exception as e :
216- logger .error (f"Error processing { audio_path } " )
214+ logger .error (f"{ curr_worker } Error processing { audio_path } " )
217215
218216 if args .debug :
219217 logger .exception (e )
@@ -225,64 +223,131 @@ def parse_args():
225223 parser .add_argument ("--config" , type = str , required = True )
226224 parser .add_argument ("--path" , type = str , required = True )
227225 parser .add_argument ("--clean" , action = "store_true" )
228- parser .add_argument ("--num-workers" , type = int , default = 1 )
226+ parser .add_argument (
227+ "--num-workers" ,
228+ type = int ,
229+ default = 1 ,
230+ help = "Number of workers, will launch a process pool if > 1" ,
231+ )
229232 parser .add_argument ("--no-augmentation" , action = "store_true" )
230233 parser .add_argument ("--debug" , action = "store_true" )
231234
235+ # For multiprocessing
236+ parser .add_argument ("--rank" , type = int , default = 0 )
237+ parser .add_argument ("--world-size" , type = int , default = 1 )
238+
232239 return parser .parse_args ()
233240
234241
if __name__ == "__main__":
    args = parse_args()

    # Tag every log line with the worker identity so interleaved output from
    # multiple subprocesses can be told apart.
    curr_worker = f"[Rank {args.rank}]" if args.world_size > 1 else "[Main]"

    if torch.cuda.is_available():
        logger.info(f"{curr_worker} Found {torch.cuda.device_count()} GPUs")
    else:
        logger.warning(f"{curr_worker} No GPU found, using CPU")

    # Only clean on the main process: ranked workers are re-launched copies of
    # this script and must not delete features other workers already wrote.
    if args.clean and args.rank == 0:
        logger.info(f"{curr_worker} Cleaning *.npy files...")

        files = list_files(args.path, {".npy"}, recursive=True, sort=True)
        for f in files:
            f.unlink()

        logger.info(f"{curr_worker} Done!")

    # Multi-processing: instead of a process pool, the coordinator re-launches
    # this script once per worker with --rank/--world-size so each worker gets
    # its own CUDA context and a single visible device.
    if args.num_workers > 1:
        logger.info(f"{curr_worker} Launching {args.num_workers} workers")

        processes = []
        for idx in range(args.num_workers):
            new_args = [
                "python",
                __file__,
                "--config",
                args.config,
                "--path",
                args.path,
                "--rank",
                str(idx),
                "--world-size",
                str(args.num_workers),
            ]

            if args.no_augmentation:
                new_args.append("--no-augmentation")

            if args.debug:
                new_args.append("--debug")

            env = deepcopy(os.environ)

            # Respect CUDA_VISIBLE_DEVICES: round-robin over the devices the
            # user already exposed; otherwise round-robin over all local GPUs.
            if "CUDA_VISIBLE_DEVICES" in env:
                devices = env["CUDA_VISIBLE_DEVICES"].split(",")
                env["CUDA_VISIBLE_DEVICES"] = devices[idx % len(devices)]
            elif torch.cuda.is_available():
                # Guarded: on a CPU-only machine torch.cuda.device_count() is 0
                # and `idx % 0` would raise ZeroDivisionError.
                env["CUDA_VISIBLE_DEVICES"] = str(idx % torch.cuda.device_count())

            processes.append(sp.Popen(new_args, env=env))
            logger.info(f"{curr_worker} Launched worker {idx}")

        # Fix: the original error message reused the stale launch-loop variable
        # `idx` (always the last worker) — track the real index of the process
        # being waited on.
        for worker_idx, p in enumerate(processes):
            p.wait()

            if p.returncode != 0:
                logger.error(
                    f"{curr_worker} Worker {worker_idx} failed with code {p.returncode}, exiting..."
                )
                exit(p.returncode)

        logger.info(f"{curr_worker} All workers done!")
        exit(0)

    # Load config
    config = Config.fromfile(args.config)
    files = list_files(args.path, AUDIO_EXTENSIONS, recursive=True, sort=True)

    # Shuffle with a fixed seed so every rank sees the same order; combined
    # with the strided slice below this balances the workload across workers.
    Random(42).shuffle(files)

    logger.info(f"{curr_worker} Found {len(files)} files, processing...")

    # Chunk files: rank r takes every world_size-th file starting at r.
    if args.world_size > 1:
        files = files[args.rank :: args.world_size]
        logger.info(f"{curr_worker} Processing subset of {len(files)} files")

    # Main processing loop with periodic progress/ETA logging.
    total_samples, failed = 0, 0
    log_time = 0
    start_time = time.time()

    for idx, audio_path in enumerate(files):
        i = safe_process(args, config, audio_path)
        if isinstance(i, int):
            total_samples += i
        else:
            failed += 1

        # Log on 100-file boundaries, but at most once every 10 seconds.
        if (idx + 1) % 100 == 0 and time.time() - log_time > 10:
            eta = (time.time() - start_time) / (idx + 1) * (len(files) - idx - 1)

            logger.info(
                f"{curr_worker} "
                + f"Processed {idx + 1}/{len(files)} files, "
                + f"{total_samples} samples, {failed} failed, "
                + f"ETA: {timedelta(seconds=eta)}"
            )

            log_time = time.time()

    # Guard against an empty input directory: len(files) == 0 would raise
    # ZeroDivisionError when computing the augmentation ratio.
    ratio = total_samples / len(files) if files else 0.0
    logger.info(
        f"{curr_worker} Done! "
        + f"Original samples: {len(files)}, "
        + f"Augmented samples: {total_samples} (x{ratio:.2f}), "
        + f"Failed: {failed}"
    )
0 commit comments