From da02ee68cd271e673016fd6a0418b2989136b059 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Fri, 19 Dec 2025 22:42:48 +0000 Subject: [PATCH] Fix checkpointing --- .../core/dist_checkpointing/exchange_utils.py | 2 +- megatron/core/dist_checkpointing/mapping.py | 2 +- .../strategies/filesystem_async.py | 16 +++++++++++++--- megatron/core/dist_checkpointing/validation.py | 2 +- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/megatron/core/dist_checkpointing/exchange_utils.py b/megatron/core/dist_checkpointing/exchange_utils.py index 8486c7efe..9fbc01580 100644 --- a/megatron/core/dist_checkpointing/exchange_utils.py +++ b/megatron/core/dist_checkpointing/exchange_utils.py @@ -62,7 +62,7 @@ class ShardDistribution(NamedTuple): def _shard_size(sh_ten: ShardedTensor): """Returns size in bytes of a given sharded tensor.""" if sh_ten.flattened_range is None: - numel = np.product(sh_ten.local_shape) + numel = np.prod(sh_ten.local_shape) else: numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start return numel * torch._utils._element_size(sh_ten.dtype) diff --git a/megatron/core/dist_checkpointing/mapping.py b/megatron/core/dist_checkpointing/mapping.py index 852b94f00..4df2222ab 100644 --- a/megatron/core/dist_checkpointing/mapping.py +++ b/megatron/core/dist_checkpointing/mapping.py @@ -208,7 +208,7 @@ def local_coordinates(self) -> Tuple[np.ndarray, ...]: ) # TODO: np.unravel_index? - mask = np.zeros(np.product(self.local_shape), dtype=bool) + mask = np.zeros(np.prod(self.local_shape), dtype=bool) mask[self.flattened_range] = True return np.nonzero(mask.reshape(self.local_shape)) diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py index b06ff9625..41a2f5ebe 100644 --- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py +++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py @@ -1,7 +1,8 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -""" Storage writer for PyT Distributed format allowing asynchronous save. """ +"""Storage writer for PyT Distributed format allowing asynchronous save.""" import dataclasses +import inspect import logging import os import queue @@ -314,16 +315,25 @@ def write_preloaded_data( local_results = [] try: file_name, storage_key, (bytes_data, tensor_data) = write_bucket + extra_kwargs = {} + if "serialization_format" in inspect.signature(_write_item).parameters: + from torch.distributed.checkpoint.filesystem import SerializationFormat + + extra_kwargs['serialization_format'] = SerializationFormat.TORCH_SAVE with open(file_name, "wb") as stream: for write_item, data in bytes_data: local_results.append( - _write_item(*transform_list, stream, data, write_item, storage_key) + _write_item( + *transform_list, stream, data, write_item, storage_key, **extra_kwargs + ) ) for write_item, tensor in tensor_data: assert tensor.is_cpu local_results.append( - _write_item(*transform_list, stream, tensor, write_item, storage_key) + _write_item( + *transform_list, stream, tensor, write_item, storage_key, **extra_kwargs + ) ) if use_fsync: diff --git a/megatron/core/dist_checkpointing/validation.py b/megatron/core/dist_checkpointing/validation.py index 546ec3547..f4572f11e 100644 --- a/megatron/core/dist_checkpointing/validation.py +++ b/megatron/core/dist_checkpointing/validation.py @@ -494,7 +494,7 @@ def _validate_sharding_for_key_flattened(tensors_by_shard): all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop)) starts, stops = map(np.asarray, zip(*sorted(all_slices))) - expected_size = np.product(local_shape) + expected_size = np.prod(local_shape) if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]): raise CheckpointingException( f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}'