Skip to content

Commit a46342f

Browse files
authored
Faster preprocessing
1 parent 3fa3153 commit a46342f

1 file changed

Lines changed: 111 additions & 46 deletions

File tree

tools/preprocessing/extract_features.py

Lines changed: 111 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
import argparse
2+
import os
23
import random
3-
from concurrent.futures import ProcessPoolExecutor
4+
import subprocess as sp
5+
import time
46
from copy import deepcopy
7+
from datetime import timedelta
58
from pathlib import Path
9+
from random import Random
610
from typing import Optional
711

812
import librosa
913
import numpy as np
1014
import torch
11-
import torch.multiprocessing as mp
1215
import torchcrepe
1316
from fish_audio_preprocess.utils.file import AUDIO_EXTENSIONS, list_files
1417
from loguru import logger
1518
from mmengine import Config
16-
from tqdm import tqdm
1719

1820
from fish_diffusion.modules.energy_extractors import ENERGY_EXTRACTORS
1921
from fish_diffusion.modules.feature_extractors import FEATURE_EXTRACTORS
@@ -38,14 +40,10 @@ def init(
3840
global model_caches
3941
device = torch.device("cpu")
4042

41-
rank = mp.current_process()._identity
42-
rank = rank[0] if len(rank) > 0 else 0
43-
4443
if torch.cuda.is_available():
45-
gpu_id = rank % torch.cuda.device_count()
46-
device = torch.device(f"cuda:{gpu_id}")
44+
device = torch.device("cuda")
4745

48-
logger.info(f"Rank {rank} uses device {device}")
46+
logger.info(f"{curr_worker} Uses device {device}")
4947

5048
text_features_extractor = None
5149
if getattr(config.preprocessing, "text_features_extractor", None):
@@ -213,7 +211,7 @@ def safe_process(args, config, audio_path: Path):
213211

214212
return aug_count + 1
215213
except Exception as e:
216-
logger.error(f"Error processing {audio_path}")
214+
logger.error(f"{curr_worker} Error processing {audio_path}")
217215

218216
if args.debug:
219217
logger.exception(e)
@@ -225,64 +223,131 @@ def parse_args():
225223
parser.add_argument("--config", type=str, required=True)
226224
parser.add_argument("--path", type=str, required=True)
227225
parser.add_argument("--clean", action="store_true")
228-
parser.add_argument("--num-workers", type=int, default=1)
226+
parser.add_argument(
227+
"--num-workers",
228+
type=int,
229+
default=1,
230+
help="Number of workers, will launch a process pool if > 1",
231+
)
229232
parser.add_argument("--no-augmentation", action="store_true")
230233
parser.add_argument("--debug", action="store_true")
231234

235+
# For multiprocessing
236+
parser.add_argument("--rank", type=int, default=0)
237+
parser.add_argument("--world-size", type=int, default=1)
238+
232239
return parser.parse_args()
233240

234241

235242
if __name__ == "__main__":
    # Entry point. Runs in one of two modes:
    #   * Coordinator (--num-workers > 1): re-launches this script once per
    #     worker as a subprocess, pinning each worker to a GPU, then waits.
    #   * Worker / single process: extracts features for its slice of files.
    args = parse_args()
    curr_worker = f"[Rank {args.rank}]" if args.world_size > 1 else "[Main]"

    if torch.cuda.is_available():
        logger.info(f"{curr_worker} Found {torch.cuda.device_count()} GPUs")
    else:
        logger.warning(f"{curr_worker} No GPU found, using CPU")

    # Only clean on the main process so workers don't race to delete files.
    if args.clean and args.rank == 0:
        logger.info(f"{curr_worker} Cleaning *.npy files...")

        files = list_files(args.path, {".npy"}, recursive=True, sort=True)
        for f in files:
            f.unlink()

        logger.info(f"{curr_worker} Done!")

    # Coordinator mode: spawn one subprocess per worker and wait for all.
    if args.num_workers > 1:
        import sys  # local import: only needed here to locate the interpreter

        logger.info(f"{curr_worker} Launching {args.num_workers} workers")

        processes = []
        for idx in range(args.num_workers):
            # sys.executable guarantees the same interpreter as the parent;
            # a bare "python" may resolve to a different env on PATH.
            new_args = [
                sys.executable,
                __file__,
                "--config",
                args.config,
                "--path",
                args.path,
                "--rank",
                str(idx),
                "--world-size",
                str(args.num_workers),
            ]

            if args.no_augmentation:
                new_args.append("--no-augmentation")

            if args.debug:
                new_args.append("--debug")

            env = deepcopy(os.environ)

            # Pin each worker to a single GPU, respecting a pre-existing
            # CUDA_VISIBLE_DEVICES restriction from the caller.
            if "CUDA_VISIBLE_DEVICES" in env:
                devices = env["CUDA_VISIBLE_DEVICES"].split(",")
                env["CUDA_VISIBLE_DEVICES"] = devices[idx % len(devices)]
            elif torch.cuda.is_available():
                env["CUDA_VISIBLE_DEVICES"] = str(idx % torch.cuda.device_count())
            # else: CPU-only run, nothing to pin. (The previous code computed
            # `idx % torch.cuda.device_count()` unconditionally, which raised
            # ZeroDivisionError on machines without a GPU.)

            processes.append(sp.Popen(new_args, env=env))
            logger.info(f"{curr_worker} Launched worker {idx}")

        for worker_id, p in enumerate(processes):
            p.wait()

            if p.returncode != 0:
                # BUG FIX: this used to log the stale launch-loop variable
                # `idx` (always the last worker) instead of the worker that
                # actually failed.
                logger.error(
                    f"{curr_worker} Worker {worker_id} failed with code "
                    f"{p.returncode}, exiting..."
                )

                # Don't leave the remaining workers running orphaned.
                for other in processes:
                    if other.poll() is None:
                        other.terminate()

                exit(p.returncode)

        logger.info(f"{curr_worker} All workers done!")
        exit(0)

    # Worker / single-process mode from here on.
    config = Config.fromfile(args.config)
    files = list_files(args.path, AUDIO_EXTENSIONS, recursive=True, sort=True)

    # Deterministic shuffle: combined with sort=True above, every rank sees
    # the same order, so the rank-strided slice below partitions the file set
    # without overlap while still balancing the workload across workers.
    Random(42).shuffle(files)

    logger.info(f"{curr_worker} Found {len(files)} files, processing...")

    # Take this rank's strided subset of the files.
    if args.world_size > 1:
        files = files[args.rank :: args.world_size]
        logger.info(f"{curr_worker} Processing subset of {len(files)} files")

    total_samples, failed = 0, 0
    log_time = 0  # timestamp of the last progress log (0 => log ASAP)
    start_time = time.time()

    for idx, audio_path in enumerate(files):
        # safe_process returns the number of produced samples (an int) on
        # success and a non-int on failure.
        i = safe_process(args, config, audio_path)
        if isinstance(i, int):
            total_samples += i
        else:
            failed += 1

        # Progress report at most once per 100 files and per 10 seconds.
        if (idx + 1) % 100 == 0 and time.time() - log_time > 10:
            eta = (time.time() - start_time) / (idx + 1) * (len(files) - idx - 1)

            logger.info(
                f"{curr_worker} "
                + f"Processed {idx + 1}/{len(files)} files, "
                + f"{total_samples} samples, {failed} failed, "
                + f"ETA: {timedelta(seconds=eta)}"
            )

            log_time = time.time()

    # max(..., 1) guards the augmentation ratio against an empty directory,
    # which previously raised ZeroDivisionError.
    logger.info(
        f"{curr_worker} Done! "
        + f"Original samples: {len(files)}, "
        + f"Augmented samples: {total_samples} "
        + f"(x{total_samples / max(len(files), 1):.2f}), "
        + f"Failed: {failed}"
    )

0 commit comments

Comments
 (0)