diff --git a/superbench/benchmarks/model_benchmarks/megatron_gpt3.py b/superbench/benchmarks/model_benchmarks/megatron_gpt3.py index 37d27bf1a..9d34ab9eb 100644 --- a/superbench/benchmarks/model_benchmarks/megatron_gpt3.py +++ b/superbench/benchmarks/model_benchmarks/megatron_gpt3.py @@ -651,13 +651,16 @@ def _generate_dataset(self): if self._args.dataset_url: self._raw_data_path = str(Path(self._args.data_home) / 'data.json') download_file(self._args.dataset_url, self._raw_data_path) + command = ( 'python3 ' f'{os.path.join(self._args.code_base, "tools/preprocess_data.py")} ' f'--input {self._raw_data_path} ' f'--tokenizer-type {self._args.tokenizer_type} ' f'--output-prefix {os.path.join(self._args.data_home, "dataset")} ' - f'--workers {str(self._args.num_workers)} ' + # num_workers=0 is valid for DataLoader (main process loads data), + # but preprocess_data.py requires workers>=1 for multiprocessing.Pool. + f'--workers {max(1, self._args.num_workers)} ' f'--vocab-file {self._vocab_path} ' f'--merge-file {self._merges_path}' )