Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion megatron/core/dist_checkpointing/exchange_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ShardDistribution(NamedTuple):
def _shard_size(sh_ten: ShardedTensor):
"""Returns size in bytes of a given sharded tensor."""
if sh_ten.flattened_range is None:
numel = np.product(sh_ten.local_shape)
numel = np.prod(sh_ten.local_shape)

@yzygitzh yzygitzh Dec 20, 2025

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this due to that np.product is deprecated? Can we make this unchanged?

else:
numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start
return numel * torch._utils._element_size(sh_ten.dtype)
Expand Down
2 changes: 1 addition & 1 deletion megatron/core/dist_checkpointing/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def local_coordinates(self) -> Tuple[np.ndarray, ...]:
)

# TODO: np.unravel_index?
mask = np.zeros(np.product(self.local_shape), dtype=bool)
mask = np.zeros(np.prod(self.local_shape), dtype=bool)
mask[self.flattened_range] = True
return np.nonzero(mask.reshape(self.local_shape))

Expand Down
16 changes: 13 additions & 3 deletions megatron/core/dist_checkpointing/strategies/filesystem_async.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

""" Storage writer for PyT Distributed format allowing asynchronous save. """
"""Storage writer for PyT Distributed format allowing asynchronous save."""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not related to bug fix, shall we drop it?

import dataclasses
import inspect
import logging
import os
import queue
Expand Down Expand Up @@ -314,16 +315,25 @@ def write_preloaded_data(
local_results = []
try:
file_name, storage_key, (bytes_data, tensor_data) = write_bucket
extra_kwargs = {}
if "serialization_format" in inspect.signature(_write_item).parameters:
from torch.distributed.checkpoint.filesystem import SerializationFormat

extra_kwargs['serialization_format'] = SerializationFormat.TORCH_SAVE

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously there is no issue here, is this due to an upgrade of PyTorch? Shall we put the reason of change in PR description

with open(file_name, "wb") as stream:
for write_item, data in bytes_data:
local_results.append(
_write_item(*transform_list, stream, data, write_item, storage_key)
_write_item(
*transform_list, stream, data, write_item, storage_key, **extra_kwargs
)
)

for write_item, tensor in tensor_data:
assert tensor.is_cpu
local_results.append(
_write_item(*transform_list, stream, tensor, write_item, storage_key)
_write_item(
*transform_list, stream, tensor, write_item, storage_key, **extra_kwargs
)
)

if use_fsync:
Expand Down
2 changes: 1 addition & 1 deletion megatron/core/dist_checkpointing/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))

starts, stops = map(np.asarray, zip(*sorted(all_slices)))
expected_size = np.product(local_shape)
expected_size = np.prod(local_shape)
if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
raise CheckpointingException(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}'
Expand Down
Loading