Skip to content

Commit e271df3

Browse files
polarG and Hongtao Zhang authored
Enhancement - Fix preprocessing crash when num_workers=0 in Megatron GPT3 dataset generation (#800)
**Description:** `preprocess_data.py` uses `multiprocessing.Pool(workers)`, which requires `workers >= 1`. When `num_workers` is set to 0 (valid for DataLoader, where it means "load in main process"), the preprocessing step crashes. This change clamps the worker count to `max(1, self._args.num_workers)` before passing it to `preprocess_data.py`, while leaving the original `num_workers` value unchanged for other uses such as DataLoader.

---------

Co-authored-by: Hongtao Zhang <hongtaozhang@microsoft.com>
1 parent b6c3a8f commit e271df3

3 files changed

Lines changed: 152 additions & 4 deletions

File tree

superbench/benchmarks/model_benchmarks/megatron_gpt3.py

Lines changed: 41 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -638,12 +638,19 @@ def _init_distributed_setting(self):
638638
f'--node_rank {node_rank} --master_addr {addr} --master_port {port}'
639639
return True
640640

641-
def _generate_dataset(self):
641+
def _generate_dataset(self): # noqa: C901
642642
"""Generate dataset for benchmarking.
643643
644644
Return:
645645
True if dataset is created successfully.
646646
"""
647+
# Validate num_workers unconditionally so a negative value is rejected even when
648+
# dataset files already exist (it would otherwise be emitted as `--num-workers -1`
649+
# into the Megatron training command).
650+
if self._args.num_workers < 0:
651+
logger.error('num_workers must be >= 0 (got {}).'.format(self._args.num_workers))
652+
self._result.set_return_code(ReturnCode.INVALID_ARGUMENT)
653+
return False
647654
self._data_options = ''
648655
if self._args.mock_data:
649656
logger.info('Using mock data.')
@@ -657,15 +664,46 @@ def _generate_dataset(self):
657664
if not os.path.exists(os.path.join(self._args.data_home, f'{self._args.data_prefix}.bin')) \
658665
or not os.path.exists(os.path.join(self._args.data_home, f'{self._args.data_prefix}.idx')):
659666
if self._args.dataset_url:
667+
# Megatron's preprocess_data.py appends '_text_document' to --output-prefix
668+
# when producing the .bin/.idx files. For the existence check below
669+
# (which looks for {data_prefix}.bin/.idx) to pass, data_prefix must end
670+
# with '_text_document' and have a non-empty stem when generation is needed.
671+
suffix = '_text_document'
672+
if not self._args.data_prefix.endswith(suffix) or self._args.data_prefix == suffix:
673+
logger.error(
674+
'data_prefix must end with "{}" and have a non-empty stem when '
675+
'dataset generation is required (got "{}"). preprocess_data.py '
676+
'always appends "{}" to --output-prefix.'.format(suffix, self._args.data_prefix, suffix)
677+
)
678+
self._result.set_return_code(ReturnCode.INVALID_ARGUMENT)
679+
return False
680+
660681
self._raw_data_path = str(Path(self._args.data_home) / 'data.json')
661682
download_file(self._args.dataset_url, self._raw_data_path)
683+
684+
output_prefix_basename = self._args.data_prefix[:-len(suffix)]
685+
output_prefix = os.path.join(self._args.data_home, output_prefix_basename)
686+
687+
# num_workers=0 is valid for DataLoader (main process loads data),
688+
# but preprocess_data.py requires workers>=1 for multiprocessing.Pool.
689+
preprocess_workers = self._args.num_workers
690+
if self._args.num_workers == 0:
691+
preprocess_workers = 1
692+
logger.warning(
693+
'preprocess_data.py requires --workers >= 1; '
694+
'overriding num_workers={} to {} for dataset preprocessing only '
695+
'(DataLoader still uses num_workers={}).'.format(
696+
self._args.num_workers, preprocess_workers, self._args.num_workers
697+
)
698+
)
699+
662700
command = (
663701
'python3 '
664702
f'{os.path.join(self._args.code_base, "tools/preprocess_data.py")} '
665703
f'--input {self._raw_data_path} '
666704
f'--tokenizer-type {self._args.tokenizer_type} '
667-
f'--output-prefix {os.path.join(self._args.data_home, "dataset")} '
668-
f'--workers {str(self._args.num_workers)} '
705+
f'--output-prefix {output_prefix} '
706+
f'--workers {preprocess_workers} '
669707
f'--vocab-file {self._vocab_path} '
670708
f'--merge-file {self._merges_path}'
671709
)

superbench/benchmarks/model_benchmarks/model_base.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,8 @@ def _preprocess(self):
242242
self._args.sample_count = math.ceil(self._args.sample_count / self._args.batch_size) * self._args.batch_size
243243

244244
if not self._generate_dataset():
245-
self._result.set_return_code(ReturnCode.DATASET_GENERATION_FAILURE)
245+
if self._result.return_code == ReturnCode.SUCCESS:
246+
self._result.set_return_code(ReturnCode.DATASET_GENERATION_FAILURE)
246247
return False
247248

248249
if not self._init_dataloader():

tests/benchmarks/model_benchmarks/test_megatron_gpt.py

Lines changed: 109 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,115 @@ def test_megatron_gpt_dataset(self):
174174
ret = benchmark._generate_dataset()
175175
assert (ret is True)
176176

177+
@mock.patch('superbench.benchmarks.model_benchmarks.megatron_gpt3.run_command')
178+
@mock.patch('superbench.benchmarks.model_benchmarks.megatron_gpt3.download_file')
179+
def test_megatron_gpt_dataset_generate_command(self, mock_download_file, mock_run_command):
180+
"""Verify _generate_dataset clamps --workers to >=1 and derives --output-prefix from data_prefix."""
181+
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
182+
assert (benchmark_cls)
183+
os.environ['OMPI_COMM_WORLD_SIZE'] = '1'
184+
os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1'
185+
os.environ['OMPI_COMM_WORLD_RANK'] = '0'
186+
os.environ['MASTER_ADDR'] = 'localhost'
187+
os.environ['MASTER_PORT'] = '12345'
188+
189+
# Use a real, valid code_base so _preprocess() can validate it (avoid hardcoded /root path).
190+
# Clean up after this test so the alphabetically-later test_megatron_gpt_preprocess
191+
# (which expects pretrain_gpt.py to NOT exist initially) is not affected by leaked state.
192+
self.createMockFiles(['pretrain_gpt.py'])
193+
pretrain_path = Path(self._tmp_dir) / 'pretrain_gpt.py'
194+
195+
# Helper: make run_command's side_effect create the expected .bin/.idx files
196+
# so _generate_dataset() (invoked from within _preprocess()) succeeds.
197+
created_files = []
198+
199+
def _make_dataset_files(prefix):
200+
def _side_effect(*_args, **_kwargs):
201+
for ext in ('.bin', '.idx'):
202+
p = Path(self._tmp_dir) / f'{prefix}{ext}'
203+
p.touch()
204+
created_files.append(p)
205+
206+
return _side_effect
207+
208+
def _cleanup_created_files():
209+
for p in created_files + [pretrain_path]:
210+
if p.is_file():
211+
p.unlink()
212+
213+
self.addCleanup(_cleanup_created_files)
214+
215+
def _build_benchmark(extra_params):
216+
return benchmark_cls(
217+
self.benchmark_name,
218+
parameters=(
219+
f'--code_base {self._tmp_dir} --data_home {self._tmp_dir} '
220+
f'--batch_size 2048 --dataset_url http://example.com/data.json '
221+
f'{extra_params}'
222+
),
223+
)
224+
225+
def _run_case(extra_params, expected_workers, expected_prefix_basename, expected_data_prefix):
226+
mock_run_command.reset_mock()
227+
mock_run_command.side_effect = _make_dataset_files(expected_data_prefix)
228+
benchmark = _build_benchmark(extra_params)
229+
assert benchmark._preprocess() is True
230+
assert mock_run_command.call_count >= 1
231+
# Use tuple indexing instead of `.args` for Python 3.7 compatibility
232+
# (mock.call.args was added in Python 3.8).
233+
cmd = mock_run_command.call_args_list[0][0][0]
234+
units = normalize_command(cmd)
235+
assert f'--workers {expected_workers}' in units, units
236+
expected_output_prefix = os.path.join(self._tmp_dir, expected_prefix_basename)
237+
assert f'--output-prefix {expected_output_prefix}' in units, units
238+
239+
def _run_invalid_case(extra_params, expected_downloads):
240+
"""Assert _preprocess() fails fast with INVALID_ARGUMENT and no run_command call.
241+
242+
expected_downloads is the number of download_file calls before validation fails:
243+
negative num_workers is rejected before any download (0), while an invalid
244+
data_prefix is rejected only after the vocab + merges downloads (2).
245+
"""
246+
mock_run_command.reset_mock()
247+
mock_run_command.side_effect = None
248+
mock_download_file.reset_mock()
249+
benchmark = _build_benchmark(extra_params)
250+
assert benchmark._preprocess() is False
251+
assert mock_run_command.call_count == 0
252+
assert mock_download_file.call_count == expected_downloads
253+
assert benchmark.return_code == ReturnCode.INVALID_ARGUMENT
254+
255+
# Case 1: num_workers=0 with default data_prefix should produce '--workers 1' (clamped)
256+
# and '--output-prefix <data_home>/dataset' (default 'dataset_text_document' suffix stripped).
257+
_run_case(
258+
extra_params='--num_workers 0',
259+
expected_workers=1,
260+
expected_prefix_basename='dataset',
261+
expected_data_prefix='dataset_text_document',
262+
)
263+
264+
# Case 2: num_workers=4 with custom data_prefix='custom_text_document' should produce
265+
# '--workers 4' and '--output-prefix <data_home>/custom'.
266+
_run_case(
267+
extra_params='--num_workers 4 --data_prefix custom_text_document',
268+
expected_workers=4,
269+
expected_prefix_basename='custom',
270+
expected_data_prefix='custom_text_document',
271+
)
272+
273+
# Case 3: data_prefix without the '_text_document' suffix is invalid for generation
274+
# because preprocess_data.py would produce 'mydata_text_document.bin/.idx' but the
275+
# existence check looks for 'mydata.bin/.idx'. _preprocess() must fail fast (after the
276+
# vocab + merges downloads).
277+
_run_invalid_case(extra_params='--num_workers 2 --data_prefix mydata', expected_downloads=2)
278+
279+
# Case 4: data_prefix == '_text_document' has an empty stem after stripping the suffix,
280+
# which would produce a malformed '--output-prefix <data_home>/'. Must fail fast.
281+
_run_invalid_case(extra_params='--num_workers 1 --data_prefix _text_document', expected_downloads=2)
282+
283+
# Case 5: negative num_workers is invalid input and is rejected before any downloads.
284+
_run_invalid_case(extra_params='--num_workers -1 --data_prefix negative_text_document', expected_downloads=0)
285+
177286
@mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
178287
def test_megatron_gpt_command(self, mock_generate_dataset):
179288
"""Test command generation."""

0 commit comments

Comments
 (0)