From 6f8d1edf13c2abda4d8dde0418ff78fcf637deeb Mon Sep 17 00:00:00 2001 From: LI MOU Date: Wed, 28 May 2025 02:47:54 +0000 Subject: [PATCH 01/23] add pp to vpp model convert tool --- tools/checkpoint/pp_tp_vpp/.gitignore | 2 + tools/checkpoint/pp_tp_vpp/README.md | 99 ++++++++ tools/checkpoint/pp_tp_vpp/distrib_optim.py | 228 ++++++++++++++++++ tools/checkpoint/pp_tp_vpp/main.py | 70 ++++++ tools/checkpoint/pp_tp_vpp/model_optim_rng.py | 139 +++++++++++ .../pp_tp_vpp/run_convert_pp_to_vpp.sh | 12 + tools/checkpoint/pp_tp_vpp/utils.py | 40 +++ tools/checkpoint/pp_tp_vpp/vpp_converter.py | 48 ++++ 8 files changed, 638 insertions(+) create mode 100644 tools/checkpoint/pp_tp_vpp/.gitignore create mode 100644 tools/checkpoint/pp_tp_vpp/README.md create mode 100644 tools/checkpoint/pp_tp_vpp/distrib_optim.py create mode 100644 tools/checkpoint/pp_tp_vpp/main.py create mode 100644 tools/checkpoint/pp_tp_vpp/model_optim_rng.py create mode 100644 tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh create mode 100644 tools/checkpoint/pp_tp_vpp/utils.py create mode 100644 tools/checkpoint/pp_tp_vpp/vpp_converter.py diff --git a/tools/checkpoint/pp_tp_vpp/.gitignore b/tools/checkpoint/pp_tp_vpp/.gitignore new file mode 100644 index 000000000..fffc64de4 --- /dev/null +++ b/tools/checkpoint/pp_tp_vpp/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +*.swp diff --git a/tools/checkpoint/pp_tp_vpp/README.md b/tools/checkpoint/pp_tp_vpp/README.md new file mode 100644 index 000000000..aa9ee819a --- /dev/null +++ b/tools/checkpoint/pp_tp_vpp/README.md @@ -0,0 +1,99 @@ +# pp_to_vpp +## description +This tool can convert a language model checkpoint without virtual pipeline parallelism into one with virtual pipeline parallelism by increasing the virtual pipeline stage size. + +Other model parallel parameters (tensor-parallel-size, pipeline-parallel-size, expert-parallel-size ...) remain unchanged. + +**Currently, tests have been conducted on the DeepSeek(v2, v3) models.** + +Note that currently, all of the following configurations must be satisfied to be supported. +1. current pipeline partition is even (num_layers for each pipeline stage is equal) +2. tensor-model-parallel-size=1 +3. ckpt_type=CheckpointType.LEGACY +4. ckpt_format=torch + +so the checkpoint for each iteration folder should look like this: +``` +iter_0000050 +├── mp_rank_00_000_000 +│ ├── distrib_optim.pt +│ └── model_optim_rng.pt +├── mp_rank_00_000_001 +│ ├── distrib_optim.pt +│ └── model_optim_rng.pt +├── mp_rank_00_000_002 +│ ├── distrib_optim.pt +│ └── model_optim_rng.pt +... +``` + + +## how to use +you can modify run_convert_pp_to_vpp.sh and launch it as an example +``` +usage: main.py [-h] --load-iteration-dir LOAD_ITERATION_DIR --expert-model-parallel-size EXPERT_MODEL_PARALLEL_SIZE --pipeline-model-parallel-size PIPELINE_MODEL_PARALLEL_SIZE + --save-iteration-dir SAVE_ITERATION_DIR --target-virtual-pipeline-model-parallel-size TARGET_VIRTUAL_PIPELINE_MODEL_PARALLEL_SIZE + [--num-max-processing-processes NUM_MAX_PROCESSING_PROCESSES] [--pipeline-ranks-to-process PIPELINE_RANKS_TO_PROCESS] + +convert a non-virtual pipeline checkpoint to virtual pipeline checkpoint + +options: + -h, --help show this help message and exit + --load-iteration-dir LOAD_ITERATION_DIR + iteration folder of source model checkpoint + --expert-model-parallel-size EXPERT_MODEL_PARALLEL_SIZE + ep_size of original model and the target model + --pipeline-model-parallel-size PIPELINE_MODEL_PARALLEL_SIZE + physical pp_size of original model and the target model + --save-iteration-dir SAVE_ITERATION_DIR + iteration folder of target model checkpoint, need to be empty if existed + --target-virtual-pipeline-model-parallel-size TARGET_VIRTUAL_PIPELINE_MODEL_PARALLEL_SIZE + vpp_size of target model + --num-max-processing-processes NUM_MAX_PROCESSING_PROCESSES + the maximum number of processing processes used by this script, increasing this value can speed up model conversion(but the final bottleneck may be disk + bandwidth), it will also consume more CPU memory. + --pipeline-ranks-to-process PIPELINE_RANKS_TO_PROCESS + pipeline rank list to process using this script, to accelerate converting user can launch multiple tasks on different nodes, each one process part of pipeline + ranks. example : --pipeline-ranks-to-process [0,1,2,3] default is None, means process all pipeline ranks +``` + +## examples +The target model has virtual_pipeline_size=2, and uses 4 processes in parallel. +``` +python main.py \ + --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ + --expert-model-parallel-size 4 \ + --pipeline-model-parallel-size 2 \ + --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ + --target-virtual-pipeline-model-parallel-size 2 \ + --num-max-processing-processes 4 +``` +Convert the checkpoints generated by pipeline ranks [0,1,2,3] on node 1, and convert the checkpoints generated by pipeline ranks [4,5,6,7] on node 2. (in cases where memory is limited on a single node) +``` +# node1 : +python main.py \ + --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ + --expert-model-parallel-size 8 \ + --pipeline-model-parallel-size 8 \ + --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ + --target-virtual-pipeline-model-parallel-size 2 \ + --num-max-processing-processes 4 \ + --pipeline-ranks-to-process [0,1,2,3] + +# node2: +python main.py \ + --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ + --expert-model-parallel-size 8 \ + --pipeline-model-parallel-size 8 \ + --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ + --target-virtual-pipeline-model-parallel-size 2 \ + --num-max-processing-processes 4 \ + --pipeline-ranks-to-process [4,5,6,7] +``` + +Some training logs from the tests are available in the **logs** directory for review. + +## NOTE +It's also possible to continue training by loading only the model weights without loading the optimizer state (add **--no-load-optim** argument when launch Megatron-LM, which will reset the optimizer), though performance may recover after training for a few more iterations. + + diff --git a/tools/checkpoint/pp_tp_vpp/distrib_optim.py b/tools/checkpoint/pp_tp_vpp/distrib_optim.py new file mode 100644 index 000000000..8da9b36c6 --- /dev/null +++ b/tools/checkpoint/pp_tp_vpp/distrib_optim.py @@ -0,0 +1,228 @@ +import os +import logging + +from utils import ( + log_and_exit, + get_folder_name, + get_vpp_source_position, + MODEL_OPTIM_RNG_FILENAME, + DISTRIB_OPTIM_FILENAME, +) + +import torch + +logger = logging.getLogger(__name__) + +def _fetch_opt_parameters(args, + src_pp_rank, + src_ep_rank, + src_start_layer_idx, + num_layers_per_virtual_stage, + count_dense : bool, + count_experts : bool): + src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, src_pp_rank, src_ep_rank)) + src_model_opt_rng_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) + src_distrib_optim_path = os.path.join(src_folder_path, DISTRIB_OPTIM_FILENAME) + + state_dict_model = torch.load(src_model_opt_rng_path, map_location="cpu", weights_only=False)["model"] + state_dict_disopts = torch.load(src_distrib_optim_path, map_location="cpu", weights_only=False) + if not isinstance(state_dict_disopts, list): + state_dict_disopts = [state_dict_disopts] + + if count_dense and count_experts: + assert len(state_dict_disopts)==1, "length of state_dict_disopts does not equal to 1" + state_dict_disopt = state_dict_disopts[0] + elif count_dense: + state_dict_disopt = state_dict_disopts[0] + else: + assert count_experts + assert len(state_dict_disopts)==2 + state_dict_disopt = state_dict_disopts[1] + + current_offset = 0 + start_offset, end_offset = -1, -1 + for (k, v) in state_dict_model.items(): + if not hasattr(v, "nelement"): + continue + if "router.expert_bias" in k: + continue + if ".layer_idx" in k: + # added for debug + continue + if not ".layers." in k: + if start_offset != -1: + end_offset = current_offset + break + current_offset += v.nelement() if count_dense else 0 + continue + layer_idx = int(k.split(".layers.")[1].split(".")[0]) + if layer_idx == src_start_layer_idx and start_offset == -1: + start_offset = current_offset + if layer_idx == src_start_layer_idx+num_layers_per_virtual_stage \ + and start_offset != -1: + end_offset = current_offset + break + if ".mlp.experts." in k: + current_offset += v.nelement() if count_experts else 0 + else: + current_offset += v.nelement() if count_dense else 0 + assert start_offset != -1 + if end_offset == -1: + end_offset = current_offset + + opt_parameters_dict = next(iter(state_dict_disopt[0].values())) + (param_part, exp_avg_part, exp_avg_sq_part) = ( + opt_parameters_dict["param"][-end_offset:(-start_offset if start_offset!=0 else None)], + opt_parameters_dict["exp_avg"][-end_offset:(-start_offset if start_offset!=0 else None)], + opt_parameters_dict["exp_avg_sq"][-end_offset:(-start_offset if start_offset!=0 else None)], + ) + logger.debug(f"_fetch_opt_parameters, src_pp_rank={src_pp_rank}, src_start_layer_idx={src_start_layer_idx}, " + "count_dense={}, count_experts={}, num_elements_in_param_part={}".format(count_dense, count_experts, param_part.nelement())) + return (param_part, exp_avg_part, exp_avg_sq_part) + +def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_state_dict): + src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, target_pp_rank, target_ep_rank)) + src_file_path = os.path.join(src_folder_path, DISTRIB_OPTIM_FILENAME) + + logger.info(f"loading distrib_optim state dict {src_file_path} ...") + target_disopt_state_dicts = torch.load(src_file_path, map_location="cpu", weights_only=False) + + if not isinstance(target_disopt_state_dicts, list): + logger.info("target_disopt_state_dicts is not a list, pp_size={}, ep_size={}, pp_rank={}, ep_rank={}".format( + args.pipeline_model_parallel_size, + args.expert_model_parallel_size, + target_pp_rank, + target_ep_rank)) + target_disopt_state_dicts = [target_disopt_state_dicts] + else: + logger.info("length of target_disopt_state_dicts : {}".format(len(target_disopt_state_dicts))) + + num_virtual_stages = args.target_virtual_pipeline_model_parallel_size + num_layers_per_virtual_stage = target_model_state_dict["args"].num_layers \ + // (args.pipeline_model_parallel_size * args.target_virtual_pipeline_model_parallel_size) + """ + in distrib_optim.pt + for ep=1, len(target_disopt_state_dicts) is always 1 containing all parameters + for ep>1: + if ep_rank=0 + target_disopt_state_dicts[0] contains non-experts parameters + target_disopt_state_dicts[i>0] only contains experts parameters if experts exists + if ep_rank>1 + target_disopt_state_dicts[0] is None + target_disopt_state_dicts[i>0] only contains experts parameters if experts exists + """ + + ### + for (i, target_disopt_state_dict) in enumerate(target_disopt_state_dicts): + if target_disopt_state_dict is None: + continue + if 0 not in target_disopt_state_dict: + log_and_exit("0 is not a key of target_disopt_state_dict, keys : {}".format(target_disopt_state_dict.keys())) + src_disopt_state_dict = target_disopt_state_dict[0] + if len(src_disopt_state_dict) != 1: + log_and_exit("length of src_disopt_state_dict is not 1, keys : {}".format(src_disopt_state_dict.keys())) + type_key = next(iter(src_disopt_state_dict)) + opt_parameters_dict = src_disopt_state_dict[type_key] + if target_pp_rank==0 and target_ep_rank<=1: + logger.info("[pp_rank=0][ep_rank={}] keys of opt_parameters_dict[{}] : {}".format( + target_ep_rank, i, opt_parameters_dict.keys())) + + vdisopts = [] + for vidx in range(num_virtual_stages): + """ + value in opt_parameters_dict are flatened tensors and were saved in reverse order compare to model state_dict + """ + num_elements_in_vmodel = 0 + vmodel_dict = target_model_state_dict[f"model{vidx}"] + new_parameters_dict = { + "param" : torch.tensor([], dtype=opt_parameters_dict["param"].dtype, device="cpu"), + "exp_avg" : torch.tensor([], dtype=opt_parameters_dict["exp_avg"].dtype, device="cpu"), + "exp_avg_sq" : torch.tensor([], dtype=opt_parameters_dict["exp_avg_sq"].dtype, device="cpu"), + "numel_unpadded" : 0, + } + + # final_layernorm and output_layer + if target_pp_rank == args.pipeline_model_parallel_size-1 \ + and vidx == num_virtual_stages-1 \ + and i == 0: + num_elements_tail = 0 + for (k, v) in reversed(vmodel_dict.items()): + if "final_layernorm." in k or "output_layer." in k: + if hasattr(v, "nelement"): + num_elements_tail += v.nelement() + else: + break + logger.debug(f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] apply output_layer, " + f"num_elements_tail={num_elements_tail}") + assert num_elements_tail != 0, f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] num_elements_tail is 0" + new_parameters_dict["param"] = torch.cat((new_parameters_dict["param"], + opt_parameters_dict["param"][:num_elements_tail])) + new_parameters_dict["exp_avg"] = torch.cat((new_parameters_dict["exp_avg"], + opt_parameters_dict["exp_avg"][:num_elements_tail])) + new_parameters_dict["exp_avg_sq"] = torch.cat((new_parameters_dict["exp_avg_sq"], + opt_parameters_dict["exp_avg_sq"][:num_elements_tail])) + num_elements_in_vmodel += num_elements_tail + + # middle layer + src_pp_rank, src_start_layer_idx = get_vpp_source_position( + target_pp_rank, + vidx, + args.pipeline_model_parallel_size, + num_virtual_stages, + num_layers_per_virtual_stage) + param_part, exp_avg_part, exp_avg_sq_part = _fetch_opt_parameters( + args, + src_pp_rank, + target_ep_rank, + src_start_layer_idx, + num_layers_per_virtual_stage, + i==0, + i>0 or len(target_disopt_state_dicts)==1) + new_parameters_dict["param"] = torch.cat((new_parameters_dict["param"], param_part)) + new_parameters_dict["exp_avg"] = torch.cat((new_parameters_dict["exp_avg"], exp_avg_part)) + new_parameters_dict["exp_avg_sq"] = torch.cat((new_parameters_dict["exp_avg_sq"], exp_avg_sq_part)) + num_elements_in_vmodel += param_part.nelement() + + # embedding layer + if target_pp_rank == 0 and vidx == 0 and i == 0: + num_elements_head = 0 + for (k, v) in vmodel_dict.items(): + if k.startswith("embedding."): + if hasattr(v, "nelement"): + num_elements_head += v.nelement() + else: + break + logger.debug(f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] apply embedding, " + f"num_elements_head={num_elements_head}") + assert num_elements_head != 0, f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] num_elements_head is 0" + + new_parameters_dict["param"] = torch.cat((new_parameters_dict["param"], + opt_parameters_dict["param"][-num_elements_head:])) + new_parameters_dict["exp_avg"] = torch.cat((new_parameters_dict["exp_avg"], + opt_parameters_dict["exp_avg"][-num_elements_head:])) + new_parameters_dict["exp_avg_sq"] = torch.cat((new_parameters_dict["exp_avg_sq"], + opt_parameters_dict["exp_avg_sq"][-num_elements_head:])) + num_elements_in_vmodel += num_elements_head + + new_parameters_dict["numel_unpadded"] = num_elements_in_vmodel + + for v in new_parameters_dict.values(): + v = v.clone().detach() if torch.is_tensor(v) else v + + if num_elements_in_vmodel != 0: + vdisopts.append({type_key : new_parameters_dict}) + + logger.debug(f"i={i}, vidx={vidx}, num_elements_in_vmodel={num_elements_in_vmodel}") + + for vidx, vdisopt in enumerate(vdisopts): + target_disopt_state_dict[vidx] = vdisopt + + # save + target_folder_path = os.path.join(args.save_iteration_dir, get_folder_name(args, target_pp_rank, target_ep_rank)) + os.makedirs(target_folder_path, exist_ok = True) + target_file_path = os.path.join(target_folder_path, DISTRIB_OPTIM_FILENAME) + + if len(target_disopt_state_dicts) == 1: + target_disopt_state_dicts = target_disopt_state_dicts[0] + logger.info(f"saving distrib_optim to {target_file_path} ...") + torch.save(target_disopt_state_dicts, target_file_path) \ No newline at end of file diff --git a/tools/checkpoint/pp_tp_vpp/main.py b/tools/checkpoint/pp_tp_vpp/main.py new file mode 100644 index 000000000..6761ab8a4 --- /dev/null +++ b/tools/checkpoint/pp_tp_vpp/main.py @@ -0,0 +1,70 @@ +import ast +import logging +import argparse +from vpp_converter import convert_checkpoint + +logger = logging.getLogger(__name__) + +def _parse_list(s): + if s is None: + return None + return ast.literal_eval(s) + +def parse_arguments(): + parser = argparse.ArgumentParser(description="convert a non-virtual pipeline checkpoint to virtual pipeline checkpoint") + parser.add_argument("--load-iteration-dir", type=str, required=True, help="iteration folder of source model checkpoint") + parser.add_argument("--expert-model-parallel-size", type=int, required=True, help="ep_size of original model and the target model") + parser.add_argument("--pipeline-model-parallel-size", type=int, required=True, help="physical pp_size of original model and the target model") + + # arguments of target model below + parser.add_argument("--save-iteration-dir", type=str, required=True, help="iteration folder of target model checkpoint, need to be empty if existed") + parser.add_argument("--target-virtual-pipeline-model-parallel-size", type=int, required=True, help="vpp_size of target model") + + parser.add_argument("--num-max-processing-processes", type=int, default=4, + help="the maximum number of processing processes used by this script, " \ + "increasing this value can speed up model conversion(but the final bottleneck may be disk bandwidth), it will also consume more CPU memory.") + parser.add_argument('--pipeline-ranks-to-process', type=_parse_list, default=None, + help="pipeline rank list to process using this script, to accelerate converting \ + user can launch multiple tasks on different nodes, each one process part of pipeline ranks. \ + example : --pipeline-ranks-to-process [0,1,2,3] \ + default is None, means process all pipeline ranks") + + args = parser.parse_args() + return args + +""" +This tool can convert a checkpoint without virtual pipeline parallelism into one with virtual pipeline parallelism + by increasing the virtual pipeline stage size. +Other model parallel parameters (tensor-parallel-size, pipeline-parallel-size, expert-parallel-size ...) remain unchanged. + +Note that currently, all of the following configurations must be satisfied to be supported. + current pipeline partition is even (num_layers for each pipeline stage is equal) + tensor-model-parallel-size=1 + expert-tensor-parallel-size=1 + ckpt_type=CheckpointType.LEGACY + ckpt_format=torch +so the checkpoint for each iteration folder should look like this: +. +├── mp_rank_00_000_000 +│ ├── distrib_optim.pt +│ └── model_optim_rng.pt +├── mp_rank_00_000_001 +│ ├── distrib_optim.pt +│ └── model_optim_rng.pt +├── mp_rank_00_000_002 +│ ├── distrib_optim.pt +│ └── model_optim_rng.pt +... + +""" +def main(args): + logger.info(f"args : {args}\n") + convert_checkpoint(args) + +if __name__ == "__main__": + logging.basicConfig( + level = logging.DEBUG, + format = "[%(asctime)s][%(levelname)s] %(message)s", + handlers = [logging.StreamHandler()]) + args = parse_arguments() + main(args) diff --git a/tools/checkpoint/pp_tp_vpp/model_optim_rng.py b/tools/checkpoint/pp_tp_vpp/model_optim_rng.py new file mode 100644 index 000000000..81ea3cfec --- /dev/null +++ b/tools/checkpoint/pp_tp_vpp/model_optim_rng.py @@ -0,0 +1,139 @@ +import os +import logging +from collections import OrderedDict + +from utils import ( + log_and_exit, + get_folder_name, + get_vpp_source_position, + MODEL_OPTIM_RNG_FILENAME, +) + +import torch + +logger = logging.getLogger(__name__) + +def _convert_state_dict_args(args, target_pp_rank, target_ep_rank, state_dict): + state_dict_args = state_dict["args"] + if args.target_virtual_pipeline_model_parallel_size <= 1: + log_and_exit(f"target_virtual_pipeline_model_parallel_size {args.target_virtual_pipeline_model_parallel_size} is smaller or equal to 1") + + if state_dict_args.tensor_model_parallel_size != 1: + log_and_exit("currently only tensor_model_parallel_size=1 is supported, but found {} in checkpoint".format( + state_dict_args.tensor_model_parallel_size)) + if args.expert_model_parallel_size != state_dict_args.expert_model_parallel_size: + log_and_exit("expert_model_parallel_size in args does not match the one in checkpoint, {} vs {}".format( + args.expert_model_parallel_size, state_dict_args.expert_model_parallel_size)) + if args.pipeline_model_parallel_size != state_dict_args.pipeline_model_parallel_size: + log_and_exit("pipeline_model_parallel_size in args does not match the one in checkpoint, {} vs {}".format( + args.pipeline_model_parallel_size, state_dict_args.pipeline_model_parallel_size)) + + if state_dict_args.num_layers and \ + state_dict_args.num_layers % (args.pipeline_model_parallel_size * args.target_virtual_pipeline_model_parallel_size) != 0: + log_and_exit("num_layers can not be evenly divided pipeline_model_parallel_size*target_virtual_pipeline_model_parallel_size, " + "num_layers={}, pipeline_model_parallel_size={}, target_virtual_pipeline_model_parallel_size={}".format( + state_dict_args.num_layers, args.pipeline_model_parallel_size, args.target_virtual_pipeline_model_parallel_size)) + + # args + state_dict_args.num_virtual_stages_per_pipeline_rank = args.target_virtual_pipeline_model_parallel_size + state_dict_args.virtual_pipeline_model_parallel_size = args.target_virtual_pipeline_model_parallel_size + state_dict_args.overlap_p2p_comm = True + state_dict_args.align_param_gather = True + +def _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, state_dict): + # TODO (optimizer state) + # suggest to reinitialize optimizer and do not load state_dict from checkpoint + #del state_dict["optimizer"] + pass + +def _convert_state_dict_rng(args, target_pp_rank, target_ep_rank, state_dict): + # TODO + # further check rng state is equal + pass + +def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, num_layers_per_virtual_stage): + src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, src_pp_rank, src_ep_rank)) + src_file_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) + + #logger.debug(f"loading {src_file_path} to fetch source tensors in virtual stage...") + state_dict = torch.load(src_file_path, map_location="cpu", weights_only=False) + state_dict_model = state_dict["model"] + + layer_idx_added = set() + outputs = OrderedDict() + for k, v in state_dict_model.items(): + if not ".layers." in k: + continue + layer_idx = int(k.split(".layers.")[1].split(".")[0]) + if src_start_layer_idx <= layer_idx < src_start_layer_idx+num_layers_per_virtual_stage: + new_key = k.replace(f".layers.{layer_idx}", f".layers.{layer_idx-src_start_layer_idx}") + outputs[new_key] = v.clone().detach() if torch.is_tensor(v) else v + layer_idx_added.add(layer_idx) + + assert len(layer_idx_added) == num_layers_per_virtual_stage, \ + "size of layer_idx_added does not equal to num_layers_per_virtual_stage, " \ + "{} vs {}".format(len(layer_idx_added), num_layers_per_virtual_stage) + return outputs + +def _convert_state_dict_model(args, target_pp_rank, target_ep_rank, state_dict): + state_dict_model = state_dict["model"] + + num_virtual_stages = args.target_virtual_pipeline_model_parallel_size + num_layers_per_virtual_stage = state_dict["args"].num_layers \ + // (args.pipeline_model_parallel_size * args.target_virtual_pipeline_model_parallel_size) + + vmodels = [OrderedDict() for i in range(num_virtual_stages)] + + for (vidx, vmodel) in enumerate(vmodels): + if target_pp_rank == 0 and vidx == 0: + for (k, v) in state_dict_model.items(): + if k.startswith("embedding."): + vmodel[k] = v.clone().detach() if torch.is_tensor(v) else v + + src_pp_rank, src_start_layer_idx = get_vpp_source_position( + target_pp_rank, + vidx, + args.pipeline_model_parallel_size, + num_virtual_stages, + num_layers_per_virtual_stage) + + src_model_state_dict = _fetch_model_state_dict(args, src_pp_rank, target_ep_rank, + src_start_layer_idx, num_layers_per_virtual_stage) + vmodel.update(src_model_state_dict) + + if target_pp_rank == args.pipeline_model_parallel_size-1 and vidx == num_virtual_stages-1: + for (k, v) in state_dict_model.items(): + if "final_layernorm." in k or k.startswith("output_layer."): + vmodel[k] = v.clone().detach() if torch.is_tensor(v) else v + + for i in range(num_virtual_stages): + state_dict[f"model{i}"] = vmodels[i] + + del state_dict["model"] + +def convert_model_optim_rng(args, target_pp_rank, target_ep_rank): + src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, target_pp_rank, target_ep_rank)) + src_file_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) + + logger.info(f"loading model_optim_rng from {src_file_path} ...") + target_state_dict = torch.load(src_file_path, map_location="cpu", weights_only=False) + if target_pp_rank==0 and target_ep_rank<=1: + logger.info("[pp_rank=0][ep_rank={}] keys of model state dict : {}\n".format(target_ep_rank, target_state_dict.keys())) + + _convert_state_dict_args(args, target_pp_rank, target_ep_rank, target_state_dict) + + _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, target_state_dict) + + _convert_state_dict_rng(args, target_pp_rank, target_ep_rank, target_state_dict) + + _convert_state_dict_model(args, target_pp_rank, target_ep_rank, target_state_dict) + + # save + target_folder_path = os.path.join(args.save_iteration_dir, get_folder_name(args, target_pp_rank, target_ep_rank)) + os.makedirs(target_folder_path, exist_ok = True) + target_file_path = os.path.join(target_folder_path, MODEL_OPTIM_RNG_FILENAME) + + logger.info(f"saving model_optim_rng to {target_file_path} ...") + torch.save(target_state_dict, target_file_path) + + return target_state_dict \ No newline at end of file diff --git a/tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh b/tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh new file mode 100644 index 000000000..7ddec9c95 --- /dev/null +++ b/tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh @@ -0,0 +1,12 @@ +set -ex + +CURDIR=$(cd $(dirname $0); pwd) +cd $CURDIR + +python main.py \ + --load-iteration-dir /path/to/checkpoint_load/iter_0000050/ \ + --expert-model-parallel-size 8 \ + --pipeline-model-parallel-size 2 \ + --save-iteration-dir /path/to/checkpoint_save/iter_0000050/ \ + --target-virtual-pipeline-model-parallel-size 2 \ + --num-max-processing-processes 4 diff --git a/tools/checkpoint/pp_tp_vpp/utils.py b/tools/checkpoint/pp_tp_vpp/utils.py new file mode 100644 index 000000000..63030c1af --- /dev/null +++ b/tools/checkpoint/pp_tp_vpp/utils.py @@ -0,0 +1,40 @@ +import os +import sys +import logging + +logger = logging.getLogger(__name__) + +MODEL_OPTIM_RNG_FILENAME = "model_optim_rng.pt" +DISTRIB_OPTIM_FILENAME = "distrib_optim.pt" + +def log_and_exit(message): + logger.fatal(message) + sys.exit(1) + +def get_folder_name(args, target_pp_rank, target_ep_rank): + folder_name = "mp_rank_00" + if args.pipeline_model_parallel_size != 1: + folder_name += f"_{target_pp_rank:03d}" + if args.expert_model_parallel_size != 1: + folder_name += f"_{target_ep_rank:03d}" + return folder_name + +def get_vpp_source_position( + target_pp_rank, + target_virtual_idx, + pipeline_parallel_size, + num_virtual_stages, + num_layers_per_virtual_stage): + + num_layers = pipeline_parallel_size * num_layers_per_virtual_stage * num_virtual_stages + assert num_layers % pipeline_parallel_size == 0, \ + f"num_layers({num_layers}) is not divisible by pipeline_parallel_size({pipeline_parallel_size})" + num_layers_per_pipeline_stage = num_layers // pipeline_parallel_size + + target_global_layer_idx = (target_virtual_idx * pipeline_parallel_size + target_pp_rank) * num_layers_per_virtual_stage + source_pp_rank = target_global_layer_idx // num_layers_per_pipeline_stage + source_start_layer_idx = target_global_layer_idx % num_layers_per_pipeline_stage + + logger.debug(f"get_vpp_source_position, target_pp_rank={target_pp_rank}, target_virtual_idx={target_virtual_idx}; " + f"source_pp_rank={source_pp_rank}, source_start_layer_idx={source_start_layer_idx}") + return (source_pp_rank, source_start_layer_idx) \ No newline at end of file diff --git a/tools/checkpoint/pp_tp_vpp/vpp_converter.py b/tools/checkpoint/pp_tp_vpp/vpp_converter.py new file mode 100644 index 000000000..169c5e0f3 --- /dev/null +++ b/tools/checkpoint/pp_tp_vpp/vpp_converter.py @@ -0,0 +1,48 @@ +import os +import logging +import itertools +import multiprocessing + +from utils import log_and_exit +from model_optim_rng import convert_model_optim_rng +from distrib_optim import convert_distrib_optim + +logger = logging.getLogger(__name__) + +def _check_output_folder(args): + output_folder = args.save_iteration_dir + if os.path.exists(output_folder): + if not os.path.isdir(output_folder): + log_and_exit(f"output path {output_folder} exists but is not a directory") + #if len(os.listdir(output_folder)) > 0: + # log_and_exit(f"output path {output_folder} exists but not empty") + else: + os.makedirs(output_folder) + +def _convert_checkpoint_partial(args, target_pp_rank, target_ep_rank): + logger.debug(f"start _convert_checkpoint_partial, pp_rank={target_pp_rank}, ep_rank={target_ep_rank}") + target_model_state_dict = convert_model_optim_rng(args, target_pp_rank, target_ep_rank) + convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_state_dict) + +def _func_arguments_wrapper(func_arguments): + args, pp_rank, ep_rank = func_arguments + _convert_checkpoint_partial(args, pp_rank, ep_rank) + +def convert_checkpoint(args): + _check_output_folder(args) + + if args.pipeline_ranks_to_process is None: + args.pipeline_ranks_to_process = range(args.pipeline_model_parallel_size) + + pp_ranges = args.pipeline_ranks_to_process + ep_ranges = range(args.expert_model_parallel_size) + + func_arguments_tuples = [(args, x, y) for x, y in itertools.product(pp_ranges, ep_ranges)] + logger.info("pp_ranges : {}".format(pp_ranges)) + logger.info("ep_ranges : {}".format(ep_ranges)) + + logger.info("start convert...") + logger.debug(f"args.num_max_processing_processes={args.num_max_processing_processes}") + with multiprocessing.Pool(processes = args.num_max_processing_processes) as pool: + pool.map(_func_arguments_wrapper, func_arguments_tuples) + logger.info("convert finished") \ No newline at end of file From dae071436499fdcb567150799db1649c91c88b04 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Tue, 3 Jun 2025 08:38:55 +0000 Subject: [PATCH 02/23] support uneven pipeline mode --- tools/checkpoint/pp_tp_vpp/README.md | 67 +++++++- tools/checkpoint/pp_tp_vpp/distrib_optim.py | 64 +++++--- tools/checkpoint/pp_tp_vpp/main.py | 34 ++++- tools/checkpoint/pp_tp_vpp/model_optim_rng.py | 134 ++++++++-------- .../pp_tp_vpp/run_convert_pp_to_vpp.sh | 8 +- tools/checkpoint/pp_tp_vpp/utils.py | 143 +++++++++++++++--- tools/checkpoint/pp_tp_vpp/vpp_converter.py | 13 +- 7 files changed, 338 insertions(+), 125 deletions(-) diff --git a/tools/checkpoint/pp_tp_vpp/README.md b/tools/checkpoint/pp_tp_vpp/README.md index aa9ee819a..57e08d5ff 100644 --- a/tools/checkpoint/pp_tp_vpp/README.md +++ b/tools/checkpoint/pp_tp_vpp/README.md @@ -4,14 +4,31 @@ This tool can convert a language model checkpoint without virtual pipeline paral Other model parallel parameters (tensor-parallel-size, pipeline-parallel-size, expert-parallel-size ...) remain unchanged. -**Currently, tests have been conducted on the DeepSeek(v2, v3) models.** +--- + +**(2025-05-30)** It now supports uneven pipeline mode, as well as cases where the number of layers in a pipeline stage is not divisible by the virtual pipeline degree. + +see arguments: +``` +--target-first-virtual-pipeline-num-layers-split +--target-last-virtual-pipeline-num-layers-split +``` +The above two parameters must either both be provided(or both be omitted), indicating that uneven pipeline mode is enabled + and specifying the virtual pipeline layer distribution for the first and last pipeline stages(this distribution may be even, but it still needs to be explicitly provided). + +This feature was introduced based on the following Pull Request. + + https://github.com/microsoft/ltp-megatron-lm/pull/27 + + The model after converted needs to be loaded using a Megatron-LM framework that has this Pull Request applied. + +--- + +**Currently, tests have been conducted on the DeepSeek(v2, v3) and Mixtral models.** Note that currently, all of the following configurations must be satisfied to be supported. -1. current pipeline partition is even (num_layers for each pipeline stage is equal) -2. tensor-model-parallel-size=1 -3. ckpt_type=CheckpointType.LEGACY -4. ckpt_format=torch - + tensor_parallel_size=1 + ckpt_format=torch so the checkpoint for each iteration folder should look like this: ``` iter_0000050 @@ -33,6 +50,8 @@ you can modify run_convert_pp_to_vpp.sh and launch it as an example ``` usage: main.py [-h] --load-iteration-dir LOAD_ITERATION_DIR --expert-model-parallel-size EXPERT_MODEL_PARALLEL_SIZE --pipeline-model-parallel-size PIPELINE_MODEL_PARALLEL_SIZE --save-iteration-dir SAVE_ITERATION_DIR --target-virtual-pipeline-model-parallel-size TARGET_VIRTUAL_PIPELINE_MODEL_PARALLEL_SIZE + [--target-first-virtual-pipeline-num-layers-split TARGET_FIRST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT [TARGET_FIRST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT ...]] + [--target-last-virtual-pipeline-num-layers-split TARGET_LAST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT [TARGET_LAST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT ...]] [--num-max-processing-processes NUM_MAX_PROCESSING_PROCESSES] [--pipeline-ranks-to-process PIPELINE_RANKS_TO_PROCESS] convert a non-virtual pipeline checkpoint to virtual pipeline checkpoint @@ -49,6 +68,10 @@ options: iteration folder of target model checkpoint, need to be empty if existed --target-virtual-pipeline-model-parallel-size TARGET_VIRTUAL_PIPELINE_MODEL_PARALLEL_SIZE vpp_size of target model + --target-first-virtual-pipeline-num-layers-split TARGET_FIRST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT [TARGET_FIRST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT ...] + only used in uneven pipeline mode, virtual pipeline split of the first stage + --target-last-virtual-pipeline-num-layers-split TARGET_LAST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT [TARGET_LAST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT ...] + only used in uneven pipeline mode, virtual pipeline split of the last stage --num-max-processing-processes NUM_MAX_PROCESSING_PROCESSES the maximum number of processing processes used by this script, increasing this value can speed up model conversion(but the final bottleneck may be disk bandwidth), it will also consume more CPU memory. @@ -58,7 +81,7 @@ options: ``` ## examples -The target model has virtual_pipeline_size=2, and uses 4 processes in parallel. +1) The target model has virtual_pipeline_size=2, and uses 4 processes in parallel. ``` python main.py \ --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ @@ -68,7 +91,8 @@ python main.py \ --target-virtual-pipeline-model-parallel-size 2 \ --num-max-processing-processes 4 ``` -Convert the checkpoints generated by pipeline ranks [0,1,2,3] on node 1, and convert the checkpoints generated by pipeline ranks [4,5,6,7] on node 2. (in cases where memory is limited on a single node) + +2) Convert the checkpoints generated by pipeline ranks [0,1,2,3] on node 1, and convert the checkpoints generated by pipeline ranks [4,5,6,7] on node 2. (in cases where memory is limited on a single node) ``` # node1 : python main.py \ @@ -91,6 +115,33 @@ python main.py \ --pipeline-ranks-to-process [4,5,6,7] ``` +3) convert a model with uneven pipeline mode, which was saved by Megatron-LM with arguments +``` +--decoder-first-pipeline-num-layers 8 +--decoder-last-pipeline-num-layers 7 +``` + +``` +# suppose pipeline_parallel_size=4, the model contains 31 layers in total, the layers distribution for each pipeline stages is [8, 8, 8, 7] +# now we use this model to inscrease virtual pipeline size to 2, +# the layer split in first pipeline stage is [4, 4] and the layer split in last pipeline stage is [4, 3] +# vpp0 vpp1 +# pp0 0, 1, 2, 3 16,17,18,19 +# pp1 4, 5, 6, 7 20,21,22,23 +# pp2 8, 9,10,11 24,25,26,27 +# pp3 12,13,14,15 28,29,30 + +python main.py \ + --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ + --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ + --expert-model-parallel-size 8 \ + --pipeline-model-parallel-size 4 \ + --target-virtual-pipeline-model-parallel-size 2 \ + --target-first-virtual-pipeline-num-layers-split 4 4 \ + --target-last-virtual-pipeline-num-layers-split 4 3 \ + --num-max-processing-processes 8 +``` + Some training logs from the tests are available in the **logs** directory for review. ## NOTE diff --git a/tools/checkpoint/pp_tp_vpp/distrib_optim.py b/tools/checkpoint/pp_tp_vpp/distrib_optim.py index 8da9b36c6..d3fb4aa0c 100644 --- a/tools/checkpoint/pp_tp_vpp/distrib_optim.py +++ b/tools/checkpoint/pp_tp_vpp/distrib_optim.py @@ -2,30 +2,34 @@ import logging from utils import ( + RetainLogLevel, log_and_exit, get_folder_name, get_vpp_source_position, MODEL_OPTIM_RNG_FILENAME, DISTRIB_OPTIM_FILENAME, + get_num_layers_for_this_vstage, ) import torch logger = logging.getLogger(__name__) -def _fetch_opt_parameters(args, +def _fetch_opt_parameters( + args, src_pp_rank, src_ep_rank, src_start_layer_idx, - num_layers_per_virtual_stage, + num_layers_for_this_virtual_stage, count_dense : bool, count_experts : bool): src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, src_pp_rank, src_ep_rank)) src_model_opt_rng_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) src_distrib_optim_path = os.path.join(src_folder_path, DISTRIB_OPTIM_FILENAME) - - state_dict_model = torch.load(src_model_opt_rng_path, map_location="cpu", weights_only=False)["model"] - state_dict_disopts = torch.load(src_distrib_optim_path, map_location="cpu", weights_only=False) + + with RetainLogLevel(): + state_dict_model = torch.load(src_model_opt_rng_path, map_location="cpu", weights_only=False)["model"] + state_dict_disopts = torch.load(src_distrib_optim_path, map_location="cpu") if not isinstance(state_dict_disopts, list): state_dict_disopts = [state_dict_disopts] @@ -41,24 +45,29 @@ def _fetch_opt_parameters(args, current_offset = 0 start_offset, end_offset = -1, -1 + upper_bound_layer_idx = -1 # variable for double-check for (k, v) in state_dict_model.items(): if not hasattr(v, "nelement"): continue if "router.expert_bias" in k: - continue - if ".layer_idx" in k: - # added for debug + # in Megatron-LM + # router.expert_bias is initialized by calling register_buffer continue if not ".layers." in k: if start_offset != -1: end_offset = current_offset + upper_bound_layer_idx += 1 break current_offset += v.nelement() if count_dense else 0 continue layer_idx = int(k.split(".layers.")[1].split(".")[0]) + + assert layer_idx >= upper_bound_layer_idx, "layer_idx stored unordered in checkpoint" + upper_bound_layer_idx = layer_idx + if layer_idx == src_start_layer_idx and start_offset == -1: start_offset = current_offset - if layer_idx == src_start_layer_idx+num_layers_per_virtual_stage \ + if layer_idx == src_start_layer_idx + num_layers_for_this_virtual_stage \ and start_offset != -1: end_offset = current_offset break @@ -69,6 +78,9 @@ def _fetch_opt_parameters(args, assert start_offset != -1 if end_offset == -1: end_offset = current_offset + upper_bound_layer_idx += 1 + assert upper_bound_layer_idx - src_start_layer_idx == num_layers_for_this_virtual_stage, \ + "layer idx double-check failed" opt_parameters_dict = next(iter(state_dict_disopt[0].values())) (param_part, exp_avg_part, exp_avg_sq_part) = ( @@ -80,26 +92,23 @@ def _fetch_opt_parameters(args, "count_dense={}, count_experts={}, num_elements_in_param_part={}".format(count_dense, count_experts, param_part.nelement())) return (param_part, exp_avg_part, exp_avg_sq_part) -def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_state_dict): +def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_state_dict, ckpt_ctx): src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, target_pp_rank, target_ep_rank)) src_file_path = os.path.join(src_folder_path, DISTRIB_OPTIM_FILENAME) logger.info(f"loading distrib_optim state dict {src_file_path} ...") - target_disopt_state_dicts = torch.load(src_file_path, map_location="cpu", weights_only=False) + target_disopt_state_dicts = torch.load(src_file_path, map_location="cpu") if not isinstance(target_disopt_state_dicts, list): logger.info("target_disopt_state_dicts is not a list, pp_size={}, ep_size={}, pp_rank={}, ep_rank={}".format( - args.pipeline_model_parallel_size, - args.expert_model_parallel_size, + ckpt_ctx.pp_size, + ckpt_ctx.ep_size, target_pp_rank, target_ep_rank)) target_disopt_state_dicts = [target_disopt_state_dicts] else: logger.info("length of target_disopt_state_dicts : {}".format(len(target_disopt_state_dicts))) - num_virtual_stages = args.target_virtual_pipeline_model_parallel_size - num_layers_per_virtual_stage = target_model_state_dict["args"].num_layers \ - // (args.pipeline_model_parallel_size * args.target_virtual_pipeline_model_parallel_size) """ in distrib_optim.pt for ep=1, len(target_disopt_state_dicts) is always 1 containing all parameters @@ -128,9 +137,19 @@ def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_sta target_ep_rank, i, opt_parameters_dict.keys())) vdisopts = [] - for vidx in range(num_virtual_stages): + for vidx in range(ckpt_ctx.vpp_size): """ value in opt_parameters_dict are flatened tensors and were saved in reverse order compare to model state_dict + for example: + model_ordered_state_dict(key-tensor pair) : { + k1 : t1, + k2 : t2, + k3 : t3, + ... + kn : tn + } + opt_parameters_dict["param"] : [tn, ... t3, t2, t1] + opt_parameters_dict["exp_avg"] : [tn, ... t3, t2, t1] """ num_elements_in_vmodel = 0 vmodel_dict = target_model_state_dict[f"model{vidx}"] @@ -142,8 +161,8 @@ def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_sta } # final_layernorm and output_layer - if target_pp_rank == args.pipeline_model_parallel_size-1 \ - and vidx == num_virtual_stages-1 \ + if target_pp_rank == ckpt_ctx.pp_size - 1 \ + and vidx == ckpt_ctx.vpp_size - 1 \ and i == 0: num_elements_tail = 0 for (k, v) in reversed(vmodel_dict.items()): @@ -167,15 +186,14 @@ def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_sta src_pp_rank, src_start_layer_idx = get_vpp_source_position( target_pp_rank, vidx, - args.pipeline_model_parallel_size, - num_virtual_stages, - num_layers_per_virtual_stage) + ckpt_ctx) + num_layers_for_this_virtual_stage = get_num_layers_for_this_vstage(target_pp_rank, vidx, ckpt_ctx) param_part, exp_avg_part, exp_avg_sq_part = _fetch_opt_parameters( args, src_pp_rank, target_ep_rank, src_start_layer_idx, - num_layers_per_virtual_stage, + num_layers_for_this_virtual_stage, i==0, i>0 or len(target_disopt_state_dicts)==1) new_parameters_dict["param"] = torch.cat((new_parameters_dict["param"], param_part)) diff --git a/tools/checkpoint/pp_tp_vpp/main.py b/tools/checkpoint/pp_tp_vpp/main.py index 6761ab8a4..8cb35d02c 100644 --- a/tools/checkpoint/pp_tp_vpp/main.py +++ b/tools/checkpoint/pp_tp_vpp/main.py @@ -1,4 +1,6 @@ +import os import ast +import sys import logging import argparse from vpp_converter import convert_checkpoint @@ -19,8 +21,13 @@ def parse_arguments(): # arguments of target model below parser.add_argument("--save-iteration-dir", type=str, required=True, help="iteration folder of target model checkpoint, need to be empty if existed") parser.add_argument("--target-virtual-pipeline-model-parallel-size", type=int, required=True, help="vpp_size of target model") + parser.add_argument("--target-first-virtual-pipeline-num-layers-split", type=int, nargs="+", default=None, + help="only used in uneven pipeline mode, virtual pipeline split of the first stage") + parser.add_argument("--target-last-virtual-pipeline-num-layers-split", type=int, nargs="+", default=None, + help="only used in uneven pipeline mode, virtual pipeline split of the last stage") - parser.add_argument("--num-max-processing-processes", type=int, default=4, + # arguments of acceleration in parallel + parser.add_argument("--num-max-processing-processes", type=int, default=8, help="the maximum number of processing processes used by this script, " \ "increasing this value can speed up model conversion(but the final bottleneck may be disk bandwidth), it will also consume more CPU memory.") parser.add_argument('--pipeline-ranks-to-process', type=_parse_list, default=None, @@ -35,16 +42,27 @@ def parse_arguments(): """ This tool can convert a checkpoint without virtual pipeline parallelism into one with virtual pipeline parallelism by increasing the virtual pipeline stage size. + +(2025-05-30) +It now supports uneven pipeline mode, as well as cases where the number of layers in a pipeline stage is not divisible by the virtual pipeline degree. + see arguments: + --target-first-virtual-pipeline-num-layers-split + --target-last-virtual-pipeline-num-layers-split +The above two parameters must either both be provided(or both be omitted), indicating that uneven pipeline mode is enabled + and specifying the virtual pipeline layer distribution for the first and last pipeline stages. + (this distribution may be even, but it still needs to be explicitly provided.) +This feature was introduced based on the following Pull Request. + https://github.com/microsoft/ltp-megatron-lm/pull/27 + The model after converted needs to be loaded using a Megatron-LM framework that has this Pull Request applied. + Other model parallel parameters (tensor-parallel-size, pipeline-parallel-size, expert-parallel-size ...) remain unchanged. Note that currently, all of the following configurations must be satisfied to be supported. - current pipeline partition is even (num_layers for each pipeline stage is equal) - tensor-model-parallel-size=1 - expert-tensor-parallel-size=1 - ckpt_type=CheckpointType.LEGACY + tensor_parallel_size=1 ckpt_format=torch -so the checkpoint for each iteration folder should look like this: -. +So the checkpoint for each iteration folder should look like this: + +iter_0000005 ├── mp_rank_00_000_000 │ ├── distrib_optim.pt │ └── model_optim_rng.pt @@ -57,8 +75,8 @@ def parse_arguments(): ... """ + def main(args): - logger.info(f"args : {args}\n") convert_checkpoint(args) if __name__ == "__main__": diff --git a/tools/checkpoint/pp_tp_vpp/model_optim_rng.py b/tools/checkpoint/pp_tp_vpp/model_optim_rng.py index 81ea3cfec..8ef4a9042 100644 --- a/tools/checkpoint/pp_tp_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_tp_vpp/model_optim_rng.py @@ -3,60 +3,65 @@ from collections import OrderedDict from utils import ( + RetainLogLevel, log_and_exit, get_folder_name, get_vpp_source_position, MODEL_OPTIM_RNG_FILENAME, + CKPTContext, + get_num_layers_for_this_vstage, ) import torch logger = logging.getLogger(__name__) -def _convert_state_dict_args(args, target_pp_rank, target_ep_rank, state_dict): +def _convert_state_dict_args(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): state_dict_args = state_dict["args"] - if args.target_virtual_pipeline_model_parallel_size <= 1: - log_and_exit(f"target_virtual_pipeline_model_parallel_size {args.target_virtual_pipeline_model_parallel_size} is smaller or equal to 1") - - if state_dict_args.tensor_model_parallel_size != 1: - log_and_exit("currently only tensor_model_parallel_size=1 is supported, but found {} in checkpoint".format( - state_dict_args.tensor_model_parallel_size)) - if args.expert_model_parallel_size != state_dict_args.expert_model_parallel_size: - log_and_exit("expert_model_parallel_size in args does not match the one in checkpoint, {} vs {}".format( - args.expert_model_parallel_size, state_dict_args.expert_model_parallel_size)) - if args.pipeline_model_parallel_size != state_dict_args.pipeline_model_parallel_size: - log_and_exit("pipeline_model_parallel_size in args does not match the one in checkpoint, {} vs {}".format( - args.pipeline_model_parallel_size, state_dict_args.pipeline_model_parallel_size)) - - if state_dict_args.num_layers and \ - state_dict_args.num_layers % (args.pipeline_model_parallel_size * args.target_virtual_pipeline_model_parallel_size) != 0: - log_and_exit("num_layers can not be evenly divided pipeline_model_parallel_size*target_virtual_pipeline_model_parallel_size, " - "num_layers={}, pipeline_model_parallel_size={}, target_virtual_pipeline_model_parallel_size={}".format( - state_dict_args.num_layers, args.pipeline_model_parallel_size, args.target_virtual_pipeline_model_parallel_size)) - - # args - state_dict_args.num_virtual_stages_per_pipeline_rank = args.target_virtual_pipeline_model_parallel_size - state_dict_args.virtual_pipeline_model_parallel_size = args.target_virtual_pipeline_model_parallel_size + if ckpt_ctx.uneven_mode: + state_dict_args.decoder_first_pipeline_num_layers_split = ckpt_ctx.first_vpp_layer_split + state_dict_args.decoder_last_pipeline_num_layers_split = ckpt_ctx.last_vpp_layer_split + + state_dict_args.num_virtual_stages_per_pipeline_rank = ckpt_ctx.vpp_size + state_dict_args.virtual_pipeline_model_parallel_size = ckpt_ctx.vpp_size state_dict_args.overlap_p2p_comm = True state_dict_args.align_param_gather = True - -def _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, state_dict): - # TODO (optimizer state) - # suggest to reinitialize optimizer and do not load state_dict from checkpoint - #del state_dict["optimizer"] + if hasattr(state_dict_args, "local_rank"): + delattr(state_dict_args, "local_rank") + if hasattr(state_dict_args, "rank"): + delattr(state_dict_args, "rank") + +def _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): + # optimizer state dict is equal + # TODO : step + optimizer_states = state_dict["optimizer"] + try: + current_step = -1 + param_group_cadidates = [] + for opt_state in optimizer_states: + for param_group in opt_state["optimizer"]["param_groups"]: + if "step" in param_group: + current_step = param_group["step"] + else: + param_group_cadidates.append(param_group) + if current_step != -1: + for param_group in param_group_cadidates: + param_group["step"] = current_step + logger.info(f"add step={current_step} in optimizer state") + except Exception: + logger.warning("add step to optimizer state failed") + +def _convert_state_dict_rng(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): + # rng state is equal pass -def _convert_state_dict_rng(args, target_pp_rank, target_ep_rank, state_dict): - # TODO - # further check rng state is equal - pass - -def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, num_layers_per_virtual_stage): +def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, num_layers_for_this_virtual_stage): src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, src_pp_rank, src_ep_rank)) src_file_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) #logger.debug(f"loading {src_file_path} to fetch source tensors in virtual stage...") - state_dict = torch.load(src_file_path, map_location="cpu", weights_only=False) + with RetainLogLevel(): + state_dict = torch.load(src_file_path, map_location="cpu", weights_only=False) state_dict_model = state_dict["model"] layer_idx_added = set() @@ -65,24 +70,25 @@ def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, if not ".layers." in k: continue layer_idx = int(k.split(".layers.")[1].split(".")[0]) - if src_start_layer_idx <= layer_idx < src_start_layer_idx+num_layers_per_virtual_stage: + if src_start_layer_idx <= layer_idx < src_start_layer_idx+num_layers_for_this_virtual_stage: new_key = k.replace(f".layers.{layer_idx}", f".layers.{layer_idx-src_start_layer_idx}") outputs[new_key] = v.clone().detach() if torch.is_tensor(v) else v layer_idx_added.add(layer_idx) - - assert len(layer_idx_added) == num_layers_per_virtual_stage, \ - "size of layer_idx_added does not equal to num_layers_per_virtual_stage, " \ - "{} vs {}".format(len(layer_idx_added), num_layers_per_virtual_stage) + + # TODO + # Currently, arbitrary layer partitioning is not supported, so I add a double-check assert here. + # If not supported, an error will be raised at this point. + assert len(layer_idx_added) == num_layers_for_this_virtual_stage, \ + "size of layer_idx_added does not equal to num_layers_for_this_virtual_stage, " \ + "{} vs {} ,".format(len(layer_idx_added), num_layers_for_this_virtual_stage) + \ + "src_pp_rank={}, src_ep_rank={}, src_start_layer_idx={}, num_layers_for_this_virtual_stage={}".format( + src_pp_rank, src_ep_rank, src_start_layer_idx, num_layers_for_this_virtual_stage) + return outputs -def _convert_state_dict_model(args, target_pp_rank, target_ep_rank, state_dict): +def _convert_state_dict_model(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): state_dict_model = state_dict["model"] - - num_virtual_stages = args.target_virtual_pipeline_model_parallel_size - num_layers_per_virtual_stage = state_dict["args"].num_layers \ - // (args.pipeline_model_parallel_size * args.target_virtual_pipeline_model_parallel_size) - - vmodels = [OrderedDict() for i in range(num_virtual_stages)] + vmodels = [OrderedDict() for i in range(ckpt_ctx.vpp_size)] for (vidx, vmodel) in enumerate(vmodels): if target_pp_rank == 0 and vidx == 0: @@ -93,20 +99,23 @@ def _convert_state_dict_model(args, target_pp_rank, target_ep_rank, state_dict): src_pp_rank, src_start_layer_idx = get_vpp_source_position( target_pp_rank, vidx, - args.pipeline_model_parallel_size, - num_virtual_stages, - num_layers_per_virtual_stage) + ckpt_ctx) - src_model_state_dict = _fetch_model_state_dict(args, src_pp_rank, target_ep_rank, - src_start_layer_idx, num_layers_per_virtual_stage) + num_layers_for_this_virtual_stage = get_num_layers_for_this_vstage(target_pp_rank, vidx, ckpt_ctx) + src_model_state_dict = _fetch_model_state_dict( + args, + src_pp_rank, + target_ep_rank, + src_start_layer_idx, + num_layers_for_this_virtual_stage) vmodel.update(src_model_state_dict) - if target_pp_rank == args.pipeline_model_parallel_size-1 and vidx == num_virtual_stages-1: + if target_pp_rank == ckpt_ctx.pp_size-1 and vidx == ckpt_ctx.vpp_size-1: for (k, v) in state_dict_model.items(): if "final_layernorm." in k or k.startswith("output_layer."): vmodel[k] = v.clone().detach() if torch.is_tensor(v) else v - for i in range(num_virtual_stages): + for i in range(ckpt_ctx.vpp_size): state_dict[f"model{i}"] = vmodels[i] del state_dict["model"] @@ -116,17 +125,22 @@ def convert_model_optim_rng(args, target_pp_rank, target_ep_rank): src_file_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) logger.info(f"loading model_optim_rng from {src_file_path} ...") - target_state_dict = torch.load(src_file_path, map_location="cpu", weights_only=False) + with RetainLogLevel(): + target_state_dict = torch.load(src_file_path, map_location="cpu", weights_only=False) + if target_pp_rank==0 and target_ep_rank<=1: logger.info("[pp_rank=0][ep_rank={}] keys of model state dict : {}\n".format(target_ep_rank, target_state_dict.keys())) - _convert_state_dict_args(args, target_pp_rank, target_ep_rank, target_state_dict) + ckpt_ctx = CKPTContext() + ckpt_ctx.check_args_and_fill(args, target_state_dict) + + _convert_state_dict_args(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) - _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, target_state_dict) + _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) - _convert_state_dict_rng(args, target_pp_rank, target_ep_rank, target_state_dict) + _convert_state_dict_rng(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) - _convert_state_dict_model(args, target_pp_rank, target_ep_rank, target_state_dict) + _convert_state_dict_model(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) # save target_folder_path = os.path.join(args.save_iteration_dir, get_folder_name(args, target_pp_rank, target_ep_rank)) @@ -136,4 +150,4 @@ def convert_model_optim_rng(args, target_pp_rank, target_ep_rank): logger.info(f"saving model_optim_rng to {target_file_path} ...") torch.save(target_state_dict, target_file_path) - return target_state_dict \ No newline at end of file + return (target_state_dict, ckpt_ctx) \ No newline at end of file diff --git a/tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh b/tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh index 7ddec9c95..270be5d45 100644 --- a/tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh +++ b/tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh @@ -4,9 +4,9 @@ CURDIR=$(cd $(dirname $0); pwd) cd $CURDIR python main.py \ - --load-iteration-dir /path/to/checkpoint_load/iter_0000050/ \ - --expert-model-parallel-size 8 \ + --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ + --expert-model-parallel-size 4 \ --pipeline-model-parallel-size 2 \ - --save-iteration-dir /path/to/checkpoint_save/iter_0000050/ \ + --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ --target-virtual-pipeline-model-parallel-size 2 \ - --num-max-processing-processes 4 + --num-max-processing-processes 8 \ No newline at end of file diff --git a/tools/checkpoint/pp_tp_vpp/utils.py b/tools/checkpoint/pp_tp_vpp/utils.py index 63030c1af..551286299 100644 --- a/tools/checkpoint/pp_tp_vpp/utils.py +++ b/tools/checkpoint/pp_tp_vpp/utils.py @@ -7,9 +7,19 @@ MODEL_OPTIM_RNG_FILENAME = "model_optim_rng.pt" DISTRIB_OPTIM_FILENAME = "distrib_optim.pt" +# torch.load sometimes change loglevel with weights_only=False +# so log message after torch.load does not show on terminal screen +class RetainLogLevel: + def __enter__(self): + self.origin_log_level = logging.getLogger().level + return self + def __exit__(self, exc_type, exc_value, traceback): + logging.getLogger().setLevel(self.origin_log_level) + def log_and_exit(message): logger.fatal(message) - sys.exit(1) + raise Exception("exit with fatal error") + #sys.exit(1) def get_folder_name(args, target_pp_rank, target_ep_rank): folder_name = "mp_rank_00" @@ -22,19 +32,118 @@ def get_folder_name(args, target_pp_rank, target_ep_rank): def get_vpp_source_position( target_pp_rank, target_virtual_idx, - pipeline_parallel_size, - num_virtual_stages, - num_layers_per_virtual_stage): - - num_layers = pipeline_parallel_size * num_layers_per_virtual_stage * num_virtual_stages - assert num_layers % pipeline_parallel_size == 0, \ - f"num_layers({num_layers}) is not divisible by pipeline_parallel_size({pipeline_parallel_size})" - num_layers_per_pipeline_stage = num_layers // pipeline_parallel_size - - target_global_layer_idx = (target_virtual_idx * pipeline_parallel_size + target_pp_rank) * num_layers_per_virtual_stage - source_pp_rank = target_global_layer_idx // num_layers_per_pipeline_stage - source_start_layer_idx = target_global_layer_idx % num_layers_per_pipeline_stage - - logger.debug(f"get_vpp_source_position, target_pp_rank={target_pp_rank}, target_virtual_idx={target_virtual_idx}; " - f"source_pp_rank={source_pp_rank}, source_start_layer_idx={source_start_layer_idx}") - return (source_pp_rank, source_start_layer_idx) \ No newline at end of file + ckpt_ctx): + + num_middle_stages = ckpt_ctx.pp_size - 2 # should be non-negative + num_layers_per_middle_virtual_stage = (ckpt_ctx.num_middle_layers // (num_middle_stages * ckpt_ctx.vpp_size)) if num_middle_stages > 0 else 0 + + target_global_layer_idx = target_virtual_idx * num_middle_stages * num_layers_per_middle_virtual_stage + \ + sum(ckpt_ctx.first_vpp_layer_split[:target_virtual_idx]) + sum(ckpt_ctx.last_vpp_layer_split[:target_virtual_idx]) \ + + (0 if target_pp_rank == 0 else ckpt_ctx.first_vpp_layer_split[target_virtual_idx] + (target_pp_rank-1) * num_layers_per_middle_virtual_stage) + + prefix_sum = 0 + for pp_stage in range(ckpt_ctx.pp_size): + if pp_stage == 0: + layers_in_current_stage = sum(ckpt_ctx.first_vpp_layer_split) + elif pp_stage == ckpt_ctx.pp_size - 1: + layers_in_current_stage = sum(ckpt_ctx.last_vpp_layer_split) + else: + layers_in_current_stage = num_layers_per_middle_virtual_stage * ckpt_ctx.vpp_size + if prefix_sum <= target_global_layer_idx < (prefix_sum + layers_in_current_stage): + source_start_layer_idx = target_global_layer_idx-prefix_sum + logger.info(f"get_vpp_source_position, target_pp_rank={target_pp_rank}, target_virtual_idx={target_virtual_idx}; " + f"source_pp_rank={pp_stage}, source_start_layer_idx={source_start_layer_idx}") + return (pp_stage, source_start_layer_idx) + prefix_sum += layers_in_current_stage + log_and_exit("should never reach here") + # + +def get_num_layers_for_this_vstage(pp_rank, vpp_rank, ckpt_ctx): + if not ckpt_ctx.uneven_mode: + return ckpt_ctx.num_layers // (ckpt_ctx.pp_size * ckpt_ctx.vpp_size) + if pp_rank == 0: + return ckpt_ctx.first_vpp_layer_split[vpp_rank] + if pp_rank == ckpt_ctx.pp_size - 1: + return ckpt_ctx.last_vpp_layer_split[vpp_rank] + + num_middle_stages = ckpt_ctx.pp_size - 2 + return ckpt_ctx.num_middle_layers // (num_middle_stages * ckpt_ctx.vpp_size) + + +class CKPTContext: + def check_args_and_fill(self, args, state_dict): + self.vpp_size = args.target_virtual_pipeline_model_parallel_size + self.pp_size = args.pipeline_model_parallel_size + self.ep_size = args.expert_model_parallel_size + self.first_vpp_layer_split = args.target_first_virtual_pipeline_num_layers_split + self.last_vpp_layer_split = args.target_last_virtual_pipeline_num_layers_split + + state_dict_args = state_dict["args"] + self.num_layers = state_dict_args.num_layers + + if state_dict_args.tensor_model_parallel_size != 1: + log_and_exit("currently only tensor_model_parallel_size=1 is supported, but found {} in checkpoint".format( + state_dict_args.tensor_model_parallel_size)) + + if self.vpp_size <= 1: + log_and_exit(f"target_virtual_pipeline_model_parallel_size {self.vpp_size} is smaller or equal to 1") + + if self.ep_size != state_dict_args.expert_model_parallel_size: + log_and_exit("expert_model_parallel_size in args does not match the one in checkpoint, {} vs {}".format( + self.ep_size, state_dict_args.expert_model_parallel_size)) + + if self.pp_size != state_dict_args.pipeline_model_parallel_size: + log_and_exit("pipeline_model_parallel_size in args does not match the one in checkpoint, {} vs {}".format( + self.pp_size, state_dict_args.pipeline_model_parallel_size)) + + if self.first_vpp_layer_split or self.last_vpp_layer_split: + self.uneven_mode = True + + if not (self.first_vpp_layer_split and self.last_vpp_layer_split): + log_and_exit("target_first_virtual_pipeline_num_layers_split and target_last_virtual_pipeline_num_layers_split " + "should be set at the same time for uneven pipeline mode") + + if state_dict_args.decoder_first_pipeline_num_layers != sum(self.first_vpp_layer_split) or \ + state_dict_args.decoder_last_pipeline_num_layers != sum(self.last_vpp_layer_split): + log_and_exit("uneven layer number does not match arguments in state_dict :" + "decoder_first_pipeline_num_layers={}, decoder_last_pipeline_num_layers={}".format( + state_dict_args.decoder_first_pipeline_num_layers, state_dict_args.decoder_last_pipeline_num_layers)) + + num_middle_pipeline_stages = self.pp_size - 2 + if num_middle_pipeline_stages < 0: + log_and_exit("pipeline_model_parallel_size is too small for uneven mode, pipeline_model_parallel_size={}".format( + self.pp_size)) + + if len(self.first_vpp_layer_split) != self.vpp_size or \ + len(self.first_vpp_layer_split) != self.vpp_size: + log_and_exit("length of target_first_virtual_pipeline_num_layers_split and target_last_virtual_pipeline_num_layers_split should " + "equal to target_virtual_pipeline_model_parallel_size") + + self.num_middle_layers = self.num_layers - sum(self.first_vpp_layer_split) \ + - sum(self.last_vpp_layer_split) + if num_middle_pipeline_stages > 0: + if self.num_middle_layers <= 0: + log_and_exit("num_middle_layers can not be non-positve, " + "num_middle_layers={}, num_middle_pipeline_stages={}".format(self.num_middle_layers, num_middle_pipeline_stages)) + if self.num_middle_layers % (num_middle_pipeline_stages*self.vpp_size) != 0: + log_and_exit("num_middle_layers can not be evenly divided by " + "num_middle_pipeline_stages*target_virtual_pipeline_model_parallel_size, " + "num_middle_layers={}, num_middle_pipeline_stages={}".format(self.num_middle_layers, num_middle_pipeline_stages)) + elif self.num_middle_layers > 0: + log_and_exit("insufficient num_middle_pipeline_stages, " + "num_middle_layers={}, num_middle_pipeline_stages={}".format(self.num_middle_layers, num_middle_pipeline_stages)) + + state_dict_args.decoder_first_pipeline_num_layers_split = args.target_first_virtual_pipeline_num_layers_split + state_dict_args.decoder_last_pipeline_num_layers_split = args.target_last_virtual_pipeline_num_layers_split + + else: + self.uneven_mode = False + if self.num_layers % (self.pp_size * self.vpp_size) != 0: + log_and_exit("for even pipeline mode, num_layers can not be evenly divided " + "pipeline_model_parallel_size*target_virtual_pipeline_model_parallel_size, " + "num_layers={}, pipeline_model_parallel_size={}, target_virtual_pipeline_model_parallel_size={}".format( + self.num_layers, self.pp_size, self.vpp_size)) + + num_layers_per_vpp_stage = self.num_layers // (self.pp_size * self.vpp_size) + self.first_vpp_layer_split = [num_layers_per_vpp_stage for _ in range(self.vpp_size)] + self.last_vpp_layer_split = [num_layers_per_vpp_stage for _ in range(self.vpp_size)] \ No newline at end of file diff --git a/tools/checkpoint/pp_tp_vpp/vpp_converter.py b/tools/checkpoint/pp_tp_vpp/vpp_converter.py index 169c5e0f3..a0b81da52 100644 --- a/tools/checkpoint/pp_tp_vpp/vpp_converter.py +++ b/tools/checkpoint/pp_tp_vpp/vpp_converter.py @@ -14,21 +14,23 @@ def _check_output_folder(args): if os.path.exists(output_folder): if not os.path.isdir(output_folder): log_and_exit(f"output path {output_folder} exists but is not a directory") - #if len(os.listdir(output_folder)) > 0: + # if len(os.listdir(output_folder)) > 0: # log_and_exit(f"output path {output_folder} exists but not empty") else: os.makedirs(output_folder) def _convert_checkpoint_partial(args, target_pp_rank, target_ep_rank): logger.debug(f"start _convert_checkpoint_partial, pp_rank={target_pp_rank}, ep_rank={target_ep_rank}") - target_model_state_dict = convert_model_optim_rng(args, target_pp_rank, target_ep_rank) - convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_state_dict) + target_model_state_dict, ckpt_ctx = convert_model_optim_rng(args, target_pp_rank, target_ep_rank) + convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_state_dict, ckpt_ctx) def _func_arguments_wrapper(func_arguments): args, pp_rank, ep_rank = func_arguments + logger.info("pp_rank={}, ep_rank={} ............".format(pp_rank, ep_rank)) _convert_checkpoint_partial(args, pp_rank, ep_rank) def convert_checkpoint(args): + _check_output_folder(args) if args.pipeline_ranks_to_process is None: @@ -41,8 +43,9 @@ def convert_checkpoint(args): logger.info("pp_ranges : {}".format(pp_ranges)) logger.info("ep_ranges : {}".format(ep_ranges)) - logger.info("start convert...") - logger.debug(f"args.num_max_processing_processes={args.num_max_processing_processes}") + logger.info(f"start convert with {args.num_max_processing_processes} processes...") + with multiprocessing.Pool(processes = args.num_max_processing_processes) as pool: pool.map(_func_arguments_wrapper, func_arguments_tuples) + logger.info("convert finished") \ No newline at end of file From f9afb04661ad25ac9ae38441a5b8fa0373c1ecde Mon Sep 17 00:00:00 2001 From: LI MOU Date: Wed, 4 Jun 2025 03:05:33 +0000 Subject: [PATCH 03/23] support arbitary distribution in layer split --- tools/checkpoint/pp_tp_vpp/distrib_optim.py | 20 ++++++++--- tools/checkpoint/pp_tp_vpp/model_optim_rng.py | 33 ++++++++++--------- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/tools/checkpoint/pp_tp_vpp/distrib_optim.py b/tools/checkpoint/pp_tp_vpp/distrib_optim.py index d3fb4aa0c..8dab8514d 100644 --- a/tools/checkpoint/pp_tp_vpp/distrib_optim.py +++ b/tools/checkpoint/pp_tp_vpp/distrib_optim.py @@ -79,8 +79,6 @@ def _fetch_opt_parameters( if end_offset == -1: end_offset = current_offset upper_bound_layer_idx += 1 - assert upper_bound_layer_idx - src_start_layer_idx == num_layers_for_this_virtual_stage, \ - "layer idx double-check failed" opt_parameters_dict = next(iter(state_dict_disopt[0].values())) (param_part, exp_avg_part, exp_avg_sq_part) = ( @@ -88,8 +86,22 @@ def _fetch_opt_parameters( opt_parameters_dict["exp_avg"][-end_offset:(-start_offset if start_offset!=0 else None)], opt_parameters_dict["exp_avg_sq"][-end_offset:(-start_offset if start_offset!=0 else None)], ) - logger.debug(f"_fetch_opt_parameters, src_pp_rank={src_pp_rank}, src_start_layer_idx={src_start_layer_idx}, " - "count_dense={}, count_experts={}, num_elements_in_param_part={}".format(count_dense, count_experts, param_part.nelement())) + + num_layers_remain = num_layers_for_this_virtual_stage - upper_bound_layer_idx + src_start_layer_idx + if num_layers_remain > 0: + (next_level_param_part, next_level_exp_avg_part, next_level_exp_avg_sq_part) = _fetch_opt_parameters( + args, + src_pp_rank + 1, + src_ep_rank, + 0, + num_layers_remain, + count_dense, + count_experts) + + param_part = torch.cat((next_level_param_part, param_part)) + exp_avg_part = torch.cat((next_level_exp_avg_part, exp_avg_part)) + exp_avg_sq_part = torch.cat((next_level_exp_avg_sq_part, exp_avg_sq_part)) + return (param_part, exp_avg_part, exp_avg_sq_part) def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_state_dict, ckpt_ctx): diff --git a/tools/checkpoint/pp_tp_vpp/model_optim_rng.py b/tools/checkpoint/pp_tp_vpp/model_optim_rng.py index 8ef4a9042..837139633 100644 --- a/tools/checkpoint/pp_tp_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_tp_vpp/model_optim_rng.py @@ -33,8 +33,9 @@ def _convert_state_dict_args(args, target_pp_rank, target_ep_rank, state_dict, c def _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): # optimizer state dict is equal - # TODO : step - optimizer_states = state_dict["optimizer"] + state_optimizer = state_dict["optimizer"] + optimizer_states = [state_optimizer] if not isinstance(state_optimizer, list) \ + else state_optimizer try: current_step = -1 param_group_cadidates = [] @@ -55,7 +56,7 @@ def _convert_state_dict_rng(args, target_pp_rank, target_ep_rank, state_dict, ck # rng state is equal pass -def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, num_layers_for_this_virtual_stage): +def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, num_layers_for_this_virtual_stage, base_layer_idx): src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, src_pp_rank, src_ep_rank)) src_file_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) @@ -71,19 +72,20 @@ def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, continue layer_idx = int(k.split(".layers.")[1].split(".")[0]) if src_start_layer_idx <= layer_idx < src_start_layer_idx+num_layers_for_this_virtual_stage: - new_key = k.replace(f".layers.{layer_idx}", f".layers.{layer_idx-src_start_layer_idx}") + new_key = k.replace(f".layers.{layer_idx}", f".layers.{layer_idx - src_start_layer_idx + base_layer_idx}") outputs[new_key] = v.clone().detach() if torch.is_tensor(v) else v layer_idx_added.add(layer_idx) - - # TODO - # Currently, arbitrary layer partitioning is not supported, so I add a double-check assert here. - # If not supported, an error will be raised at this point. - assert len(layer_idx_added) == num_layers_for_this_virtual_stage, \ - "size of layer_idx_added does not equal to num_layers_for_this_virtual_stage, " \ - "{} vs {} ,".format(len(layer_idx_added), num_layers_for_this_virtual_stage) + \ - "src_pp_rank={}, src_ep_rank={}, src_start_layer_idx={}, num_layers_for_this_virtual_stage={}".format( - src_pp_rank, src_ep_rank, src_start_layer_idx, num_layers_for_this_virtual_stage) - + + num_layers_remain = num_layers_for_this_virtual_stage - len(layer_idx_added) + if num_layers_remain > 0: + assert base_layer_idx == 0 + next_level_outputs = _fetch_model_state_dict(args, + src_pp_rank+1, + src_ep_rank, + 0, + num_layers_remain, + len(layer_idx_added)) + outputs.update(next_level_outputs) return outputs def _convert_state_dict_model(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): @@ -107,7 +109,8 @@ def _convert_state_dict_model(args, target_pp_rank, target_ep_rank, state_dict, src_pp_rank, target_ep_rank, src_start_layer_idx, - num_layers_for_this_virtual_stage) + num_layers_for_this_virtual_stage, + 0) vmodel.update(src_model_state_dict) if target_pp_rank == ckpt_ctx.pp_size-1 and vidx == ckpt_ctx.vpp_size-1: From 06c851b699e6d7af0ccb692fbb1ff1769e2bbb33 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Thu, 12 Jun 2025 03:32:51 +0000 Subject: [PATCH 04/23] fix typo, remove .gitignore file in subfolder --- tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/README.md | 0 tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/distrib_optim.py | 0 tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/main.py | 0 tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/model_optim_rng.py | 0 .../{pp_tp_vpp => pp_to_vpp}/run_convert_pp_to_vpp.sh | 0 tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/utils.py | 0 tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/vpp_converter.py | 0 tools/checkpoint/pp_tp_vpp/.gitignore | 2 -- 8 files changed, 2 deletions(-) rename tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/README.md (100%) rename tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/distrib_optim.py (100%) rename tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/main.py (100%) rename tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/model_optim_rng.py (100%) rename tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/run_convert_pp_to_vpp.sh (100%) rename tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/utils.py (100%) rename tools/checkpoint/{pp_tp_vpp => pp_to_vpp}/vpp_converter.py (100%) delete mode 100644 tools/checkpoint/pp_tp_vpp/.gitignore diff --git a/tools/checkpoint/pp_tp_vpp/README.md b/tools/checkpoint/pp_to_vpp/README.md similarity index 100% rename from tools/checkpoint/pp_tp_vpp/README.md rename to tools/checkpoint/pp_to_vpp/README.md diff --git a/tools/checkpoint/pp_tp_vpp/distrib_optim.py b/tools/checkpoint/pp_to_vpp/distrib_optim.py similarity index 100% rename from tools/checkpoint/pp_tp_vpp/distrib_optim.py rename to tools/checkpoint/pp_to_vpp/distrib_optim.py diff --git a/tools/checkpoint/pp_tp_vpp/main.py b/tools/checkpoint/pp_to_vpp/main.py similarity index 100% rename from tools/checkpoint/pp_tp_vpp/main.py rename to tools/checkpoint/pp_to_vpp/main.py diff --git a/tools/checkpoint/pp_tp_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py similarity index 100% rename from tools/checkpoint/pp_tp_vpp/model_optim_rng.py rename to tools/checkpoint/pp_to_vpp/model_optim_rng.py diff --git a/tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh b/tools/checkpoint/pp_to_vpp/run_convert_pp_to_vpp.sh similarity index 100% rename from tools/checkpoint/pp_tp_vpp/run_convert_pp_to_vpp.sh rename to tools/checkpoint/pp_to_vpp/run_convert_pp_to_vpp.sh diff --git a/tools/checkpoint/pp_tp_vpp/utils.py b/tools/checkpoint/pp_to_vpp/utils.py similarity index 100% rename from tools/checkpoint/pp_tp_vpp/utils.py rename to tools/checkpoint/pp_to_vpp/utils.py diff --git a/tools/checkpoint/pp_tp_vpp/vpp_converter.py b/tools/checkpoint/pp_to_vpp/vpp_converter.py similarity index 100% rename from tools/checkpoint/pp_tp_vpp/vpp_converter.py rename to tools/checkpoint/pp_to_vpp/vpp_converter.py diff --git a/tools/checkpoint/pp_tp_vpp/.gitignore b/tools/checkpoint/pp_tp_vpp/.gitignore deleted file mode 100644 index fffc64de4..000000000 --- a/tools/checkpoint/pp_tp_vpp/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -__pycache__ -*.swp From 46101dec2fe41189cf834295fd72b8cdde3c16bd Mon Sep 17 00:00:00 2001 From: LI MOU Date: Thu, 12 Jun 2025 03:33:49 +0000 Subject: [PATCH 05/23] remove run_convert_pp_to_vpp.sh --- tools/checkpoint/pp_to_vpp/run_convert_pp_to_vpp.sh | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 tools/checkpoint/pp_to_vpp/run_convert_pp_to_vpp.sh diff --git a/tools/checkpoint/pp_to_vpp/run_convert_pp_to_vpp.sh b/tools/checkpoint/pp_to_vpp/run_convert_pp_to_vpp.sh deleted file mode 100644 index 270be5d45..000000000 --- a/tools/checkpoint/pp_to_vpp/run_convert_pp_to_vpp.sh +++ /dev/null @@ -1,12 +0,0 @@ -set -ex - -CURDIR=$(cd $(dirname $0); pwd) -cd $CURDIR - -python main.py \ - --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ - --expert-model-parallel-size 4 \ - --pipeline-model-parallel-size 2 \ - --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ - --target-virtual-pipeline-model-parallel-size 2 \ - --num-max-processing-processes 8 \ No newline at end of file From 0ced57e328e5bb3d03d495a9b434b18af0c28cac Mon Sep 17 00:00:00 2001 From: LI MOU Date: Fri, 13 Jun 2025 08:30:44 +0000 Subject: [PATCH 06/23] fix bug --- tools/checkpoint/pp_to_vpp/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/checkpoint/pp_to_vpp/utils.py b/tools/checkpoint/pp_to_vpp/utils.py index 551286299..f9a5e6b36 100644 --- a/tools/checkpoint/pp_to_vpp/utils.py +++ b/tools/checkpoint/pp_to_vpp/utils.py @@ -143,6 +143,9 @@ def check_args_and_fill(self, args, state_dict): "pipeline_model_parallel_size*target_virtual_pipeline_model_parallel_size, " "num_layers={}, pipeline_model_parallel_size={}, target_virtual_pipeline_model_parallel_size={}".format( self.num_layers, self.pp_size, self.vpp_size)) + + num_layers_per_pp = self.num_layers // self.pp_size + self.num_middle_layers = self.num_layers - num_layers_per_pp * 2 num_layers_per_vpp_stage = self.num_layers // (self.pp_size * self.vpp_size) self.first_vpp_layer_split = [num_layers_per_vpp_stage for _ in range(self.vpp_size)] From 71485f69ab88ee5b916c3a3293daa41ed5579272 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 13:51:36 +0000 Subject: [PATCH 07/23] add unit test for pp_to_vpp convert tool --- tests/unit_tests/test_convert_checkpoint.py | 334 ++++++++++++++++++++ 1 file changed, 334 insertions(+) create mode 100644 tests/unit_tests/test_convert_checkpoint.py diff --git a/tests/unit_tests/test_convert_checkpoint.py b/tests/unit_tests/test_convert_checkpoint.py new file mode 100644 index 000000000..b58f3b2d8 --- /dev/null +++ b/tests/unit_tests/test_convert_checkpoint.py @@ -0,0 +1,334 @@ +import os +import sys +import subprocess +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch +import torch.distributed as dist + +from tests.unit_tests.test_utilities import Utils +from tests.unit_tests.dist_checkpointing import TempNamedDir + +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer +from megatron.training.checkpointing import load_checkpoint, save_checkpoint +from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator +from megatron.training.global_vars import set_args +from megatron.training.utils import print_rank_0 + +CURDIR = os.path.dirname(os.path.abspath(__file__)) + +def create_args( + num_layers, + hidden_size, + num_attn_heads, + pipeline_parallel_size, + ckpt_dir): + args = SimpleNamespace() + args.finetune = False + args.non_persistent_global_ckpt_dir = None + args.non_persistent_ckpt_type = None + args.non_persistent_save_interval = None + args.exit_on_missing_checkpoint = True + args.async_save = False + args.data_parallel_random_init = False + args.no_save_optim = False + args.no_save_rng = False + args.no_load_optim = False + args.no_load_rng = False + args.log_progress = False + args.ckpt_fully_parallel_save = False + args.auto_detect_ckpt_format = False + args.retro_add_retriever = False + args.ckpt_convert_update_legacy_dist_opt_format = False + args.ckpt_step = None + args.use_distributed_optimizer = True + args.use_dist_ckpt = False + args.consumed_train_samples = 0 + args.skipped_train_samples = 0 + args.consumed_valid_samples = 0 + args.add_position_embedding = False + args.vocab_file = None + args.tensor_model_parallel_size = 1 + args.ckpt_format = "torch" + args.ckpt_isolated_save = True + args.local_rank = int(os.environ["LOCAL_RANK"]) + args.ckpt_upload_blob_path = None + args.perform_initialization = True + args.num_virtual_stages_per_pipeline_rank = None + args.num_layers = num_layers + args.hidden_size = hidden_size + args.num_attention_heads = num_attn_heads + args.pipeline_model_parallel_size = pipeline_parallel_size + args.normalization = "RMSNorm" + args.transformer_impl = "transformer_engine" + args.expert_model_parallel_size = 1 + args.save = ckpt_dir / "save_dir" + args.load = ckpt_dir / "load_dir" + + return args + +def get_checkpoint_content(args): + + def model_provider(args): + transformer_config = TransformerConfig( + add_bias_linear = False, + params_dtype = torch.bfloat16, + pipeline_dtype = torch.bfloat16, + normalization = args.normalization, + num_layers = args.num_layers, + hidden_size = args.hidden_size, + num_attention_heads = args.num_attention_heads, + tensor_model_parallel_size = args.tensor_model_parallel_size, + pipeline_model_parallel_size = args.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size = args.num_virtual_stages_per_pipeline_rank, + perform_initialization = args.perform_initialization) + + def get_model(): + transformer_layer_spec = get_gpt_decoder_block_spec( + transformer_config, + use_transformer_engine = args.transformer_impl == "transformer_engine") + return GPTModel( + config = transformer_config, + transformer_layer_spec = transformer_layer_spec, + position_embedding_type = "rope", + vocab_size = 32, + max_sequence_length = 32, + pre_process = mpu.is_pipeline_first_stage(), + post_process = mpu.is_pipeline_last_stage()) + + ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) + model = [] + if args.num_virtual_stages_per_pipeline_rank \ + and args.num_virtual_stages_per_pipeline_rank > 1: + model = [] + for i in range(args.num_virtual_stages_per_pipeline_rank): + mpu.set_virtual_pipeline_model_parallel_rank(i) + # Set pre_process and post_process only after virtual rank is set. + this_model = DistributedDataParallel(transformer_config, + ddp_config, get_model()) + model.append(this_model) + else: + model.append(DistributedDataParallel(transformer_config, + ddp_config, get_model())) + return model + + model = model_provider(args) + + optimizer_config = OptimizerConfig( + optimizer='adam', + bf16=True, + use_distributed_optimizer=True, + params_dtype = torch.bfloat16, + lr = 1e-6, + min_lr = 1e-9) + optimizer = get_megatron_optimizer(optimizer_config, model) + optimizer.step() + + class MockState: + def __init__(self, state_dict): + self._state_dict = state_dict + self.is_stub_optimizer = False + + def state_dict(self, is_loading=False): + return self._state_dict + + def load_state_dict(self, state_dict): + self._state_dict = state_dict + + def save_parameter_state(self, *args, **kwargs): + pass + + def load_parameter_state(self, *args, **kwargs): + pass + + opt_scheduler = MockState({"opt_param_scheduler": "scheduler_state"}) + + return (model, optimizer, opt_scheduler) + +def reset_parallel_state(args): + Utils.initialize_model_parallel( + tensor_model_parallel_size = args.tensor_model_parallel_size, + pipeline_model_parallel_size = args.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size = args.num_virtual_stages_per_pipeline_rank) + model_parallel_cuda_manual_seed(123) + +def get_global_state(args, model, optimizer): + + def get_global_layer_index(num_layers, + pp_size, + vpp_size, + current_pp_rank, + current_vpp_rank, + current_local_layer_index): + num_layers_per_pipeline_rank = num_layers // pp_size + if vpp_size is None or vpp_size == 1: + return current_pp_rank * num_layers_per_pipeline_rank + current_local_layer_index + num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vpp_size + total_virtual_chunks = num_layers // vpp_size + return current_vpp_rank * total_virtual_chunks + ( + current_pp_rank * num_layers_per_virtual_rank) + + def to_cpu(x): + if torch.is_tensor(x): + return x.to("cpu") + for k in x: + x[k] = to_cpu(x[k]) + return x + + current_pp_rank = mpu.get_pipeline_model_parallel_rank() + + global_model_state = {} + global_optimizer_state = {} + for vpp_idx, model_chunk in enumerate(model): + for name, param in model_chunk.named_parameters(): + key = name + if ".layers." in key: + layer_idx = int(key.split(".layers.")[1].split(".")[0]) + global_layer_idx = get_global_layer_index( + args.num_layers, + args.pipeline_model_parallel_size, + args.num_virtual_stages_per_pipeline_rank, + current_pp_rank, + vpp_idx, + layer_idx) + key = key.replace(f".layers.{layer_idx}", f".layers.{global_layer_idx}") + optimizer_param = optimizer.chained_optimizers[0]._get_main_param_and_optimizer_states(param) + global_model_state[key] = to_cpu(param) + global_optimizer_state[key] = to_cpu(optimizer_param) + + return (global_model_state, global_optimizer_state) + +def merge_state_dict_to_pipeline_rank0(local_dict): + pipeline_group = mpu.get_pipeline_model_parallel_group() + rank_in_group = dist.get_rank(group=pipeline_group) + world_size = dist.get_world_size(group=pipeline_group) + + gathered_list = [None for _ in range(world_size)] if rank_in_group == 0 else None + dist.gather_object(local_dict, gathered_list, dst=0, group=pipeline_group) + + if rank_in_group == 0: + merged_dict = {} + for d in gathered_list: + merged_dict.update(d) + return merged_dict + return None + +def is_state_dict_equal(x, y): + if x.keys() != y.keys(): + return False + for k in x: + if isinstance(x[k], dict): + assert isinstance(y[k], dict) + if not is_state_dict_equal(x[k], y[k]): + return False + else: + assert torch.is_tensor(x[k]) + assert torch.is_tensor(y[k]) + if not torch.equal(x[k], y[k]): + return False + return True + +def _test_convert_pp_to_vpp_internal(ckpt_dir : Path): + rank = dist.get_rank() + world_size = dist.get_world_size() + assert world_size == 8, "current only support world_size=8" + init_num_microbatches_calculator(rank, None, 1, 1, 1) + + args = create_args( + num_layers=16, + hidden_size=64, + num_attn_heads=4, + pipeline_parallel_size=8, + ckpt_dir=ckpt_dir) + + set_args(args) + + reset_parallel_state(args) + + model, optimizer, opt_scheduler = get_checkpoint_content(args) + + # save model with virtual_pipeline_size=1 + iteration = 123 + flops = 456 + print_rank_0("saving checkpoint with virtual_pipeline_size=1...") + save_checkpoint(iteration, model, optimizer, opt_scheduler, flops) + + + global_model_state, global_optimizer_state = get_global_state(args, model, optimizer) + global_model_state = merge_state_dict_to_pipeline_rank0(global_model_state) + global_optimizer_state = merge_state_dict_to_pipeline_rank0(global_optimizer_state) + + if rank == 0: + # convert model, increase virtual_pipeline_size to 2 + command = ( + "mkdir -p {}/iter_{:07d} ".format(args.load, iteration) + + "&& echo {} > {}/latest_checkpointed_iteration.txt ".format(iteration, args.load) + + "&& python {}/../../tools/checkpoint/pp_to_vpp/main.py ".format(CURDIR) + + "--load-iteration-dir {}/iter_{:07d} ".format(args.save, iteration) + + "--expert-model-parallel-size 1 " + + "--pipeline-model-parallel-size 8 " + + "--save-iteration-dir {}/iter_{:07d} ".format(args.load, iteration) + + "--target-virtual-pipeline-model-parallel-size 2 " + + "--num-max-processing-processes 2 " + ) + print_rank_0("converting checkpoint from virtual_pipeline_size from 1 to 2") + subprocess_result = subprocess.run( + command, + shell = True, + stdout = subprocess.PIPE, + stderr = subprocess.PIPE, + text = True) + assert subprocess_result.returncode == 0 + + dist.barrier() + + # change virtual_pipeline_size to 2 and load the model converted + args.num_virtual_stages_per_pipeline_rank = 2 + args.perform_initialization = False + reset_parallel_state(args) + + new_model, new_optimizer, new_opt_scheduler = get_checkpoint_content(args) + + print_rank_0("loading checkpoint with virtual_pipeline_size=2") + loaded_iter, loaded_flops = load_checkpoint( + new_model, new_optimizer, new_opt_scheduler, strict=True + ) + + # check iteration and flops are equal + assert loaded_iter == iteration and loaded_flops == flops + + new_global_model_state, new_global_optimizer_state = get_global_state(args, new_model, new_optimizer) + new_global_model_state = merge_state_dict_to_pipeline_rank0(new_global_model_state) + new_global_optimizer_state = merge_state_dict_to_pipeline_rank0(new_global_optimizer_state) + + if mpu.get_pipeline_model_parallel_rank() == 0: + # check model_state and optimizer parameter state are equal + global_model_state_equal = is_state_dict_equal(global_model_state, new_global_model_state) + assert global_model_state_equal + global_optimizer_state_equal = is_state_dict_equal(global_optimizer_state, new_global_optimizer_state) + assert global_optimizer_state_equal + + +""" +launch test with command: + +torchrun \ + --nproc_per_node 8 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr localhost \ + --master_port 50326 \ + -m pytest -vx test_convert_checkpoint.py +""" +def test_convert_pp_to_vpp(tmp_path_dist_ckpt): + Utils.initialize_distributed() + with TempNamedDir(tmp_path_dist_ckpt / "test_convert_checkpoint", sync=True) as ckpt_dir: + _test_convert_pp_to_vpp_internal(ckpt_dir) \ No newline at end of file From 959254291af8a2c1a8033b1e4b951d920be167f3 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:00:21 +0000 Subject: [PATCH 08/23] delete _parse_list --- tools/checkpoint/pp_to_vpp/README.md | 6 +++--- tools/checkpoint/pp_to_vpp/main.py | 10 +++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/README.md b/tools/checkpoint/pp_to_vpp/README.md index 57e08d5ff..d64ff2380 100644 --- a/tools/checkpoint/pp_to_vpp/README.md +++ b/tools/checkpoint/pp_to_vpp/README.md @@ -77,7 +77,7 @@ options: bandwidth), it will also consume more CPU memory. --pipeline-ranks-to-process PIPELINE_RANKS_TO_PROCESS pipeline rank list to process using this script, to accelerate converting user can launch multiple tasks on different nodes, each one process part of pipeline - ranks. example : --pipeline-ranks-to-process [0,1,2,3] default is None, means process all pipeline ranks + ranks. example : --pipeline-ranks-to-process 0 1 2 3 default is None, means process all pipeline ranks ``` ## examples @@ -102,7 +102,7 @@ python main.py \ --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ --target-virtual-pipeline-model-parallel-size 2 \ --num-max-processing-processes 4 \ - --pipeline-ranks-to-process [0,1,2,3] + --pipeline-ranks-to-process 0 1 2 3 # node2: python main.py \ @@ -112,7 +112,7 @@ python main.py \ --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ --target-virtual-pipeline-model-parallel-size 2 \ --num-max-processing-processes 4 \ - --pipeline-ranks-to-process [4,5,6,7] + --pipeline-ranks-to-process 4 5 6 7 ``` 3) convert a model with uneven pipeline mode, which was saved by Megatron-LM with arguments diff --git a/tools/checkpoint/pp_to_vpp/main.py b/tools/checkpoint/pp_to_vpp/main.py index 8cb35d02c..8e02b65ae 100644 --- a/tools/checkpoint/pp_to_vpp/main.py +++ b/tools/checkpoint/pp_to_vpp/main.py @@ -7,11 +7,6 @@ logger = logging.getLogger(__name__) -def _parse_list(s): - if s is None: - return None - return ast.literal_eval(s) - def parse_arguments(): parser = argparse.ArgumentParser(description="convert a non-virtual pipeline checkpoint to virtual pipeline checkpoint") parser.add_argument("--load-iteration-dir", type=str, required=True, help="iteration folder of source model checkpoint") @@ -30,10 +25,10 @@ def parse_arguments(): parser.add_argument("--num-max-processing-processes", type=int, default=8, help="the maximum number of processing processes used by this script, " \ "increasing this value can speed up model conversion(but the final bottleneck may be disk bandwidth), it will also consume more CPU memory.") - parser.add_argument('--pipeline-ranks-to-process', type=_parse_list, default=None, + parser.add_argument('--pipeline-ranks-to-process', type=int, nargs="+", default=None, help="pipeline rank list to process using this script, to accelerate converting \ user can launch multiple tasks on different nodes, each one process part of pipeline ranks. \ - example : --pipeline-ranks-to-process [0,1,2,3] \ + example : --pipeline-ranks-to-process 0 1 2 3 \ default is None, means process all pipeline ranks") args = parser.parse_args() @@ -77,6 +72,7 @@ def parse_arguments(): """ def main(args): + logger.info("args : {}".format(args)) convert_checkpoint(args) if __name__ == "__main__": From a8ec492a6d7eb681177e5c4ac23e03d54315228c Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:01:57 +0000 Subject: [PATCH 09/23] add newline for each .py file --- tools/checkpoint/pp_to_vpp/distrib_optim.py | 2 +- tools/checkpoint/pp_to_vpp/model_optim_rng.py | 2 +- tools/checkpoint/pp_to_vpp/utils.py | 2 +- tools/checkpoint/pp_to_vpp/vpp_converter.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/distrib_optim.py b/tools/checkpoint/pp_to_vpp/distrib_optim.py index 8dab8514d..cf16d9647 100644 --- a/tools/checkpoint/pp_to_vpp/distrib_optim.py +++ b/tools/checkpoint/pp_to_vpp/distrib_optim.py @@ -255,4 +255,4 @@ def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_sta if len(target_disopt_state_dicts) == 1: target_disopt_state_dicts = target_disopt_state_dicts[0] logger.info(f"saving distrib_optim to {target_file_path} ...") - torch.save(target_disopt_state_dicts, target_file_path) \ No newline at end of file + torch.save(target_disopt_state_dicts, target_file_path) diff --git a/tools/checkpoint/pp_to_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py index 837139633..49ccb0cfd 100644 --- a/tools/checkpoint/pp_to_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_to_vpp/model_optim_rng.py @@ -153,4 +153,4 @@ def convert_model_optim_rng(args, target_pp_rank, target_ep_rank): logger.info(f"saving model_optim_rng to {target_file_path} ...") torch.save(target_state_dict, target_file_path) - return (target_state_dict, ckpt_ctx) \ No newline at end of file + return (target_state_dict, ckpt_ctx) diff --git a/tools/checkpoint/pp_to_vpp/utils.py b/tools/checkpoint/pp_to_vpp/utils.py index f9a5e6b36..f8f329170 100644 --- a/tools/checkpoint/pp_to_vpp/utils.py +++ b/tools/checkpoint/pp_to_vpp/utils.py @@ -149,4 +149,4 @@ def check_args_and_fill(self, args, state_dict): num_layers_per_vpp_stage = self.num_layers // (self.pp_size * self.vpp_size) self.first_vpp_layer_split = [num_layers_per_vpp_stage for _ in range(self.vpp_size)] - self.last_vpp_layer_split = [num_layers_per_vpp_stage for _ in range(self.vpp_size)] \ No newline at end of file + self.last_vpp_layer_split = [num_layers_per_vpp_stage for _ in range(self.vpp_size)] diff --git a/tools/checkpoint/pp_to_vpp/vpp_converter.py b/tools/checkpoint/pp_to_vpp/vpp_converter.py index a0b81da52..b7e2d23b6 100644 --- a/tools/checkpoint/pp_to_vpp/vpp_converter.py +++ b/tools/checkpoint/pp_to_vpp/vpp_converter.py @@ -48,4 +48,4 @@ def convert_checkpoint(args): with multiprocessing.Pool(processes = args.num_max_processing_processes) as pool: pool.map(_func_arguments_wrapper, func_arguments_tuples) - logger.info("convert finished") \ No newline at end of file + logger.info("convert finished") From 0778492b6ec51ec7033d40ddbe62507cca2c878a Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:06:03 +0000 Subject: [PATCH 10/23] rename to parallel_convert --- tools/checkpoint/pp_to_vpp/main.py | 2 +- .../pp_to_vpp/{vpp_converter.py => parallel_convert.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tools/checkpoint/pp_to_vpp/{vpp_converter.py => parallel_convert.py} (100%) diff --git a/tools/checkpoint/pp_to_vpp/main.py b/tools/checkpoint/pp_to_vpp/main.py index 8e02b65ae..ce635b983 100644 --- a/tools/checkpoint/pp_to_vpp/main.py +++ b/tools/checkpoint/pp_to_vpp/main.py @@ -3,7 +3,7 @@ import sys import logging import argparse -from vpp_converter import convert_checkpoint +from parallel_convert import convert_checkpoint logger = logging.getLogger(__name__) diff --git a/tools/checkpoint/pp_to_vpp/vpp_converter.py b/tools/checkpoint/pp_to_vpp/parallel_convert.py similarity index 100% rename from tools/checkpoint/pp_to_vpp/vpp_converter.py rename to tools/checkpoint/pp_to_vpp/parallel_convert.py From c622a19c866121c1df83ab00d300a3180c4456e2 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:11:20 +0000 Subject: [PATCH 11/23] remove uneeded code --- tools/checkpoint/pp_to_vpp/main.py | 2 +- tools/checkpoint/pp_to_vpp/model_optim_rng.py | 2 +- tools/checkpoint/pp_to_vpp/parallel_convert.py | 2 -- tools/checkpoint/pp_to_vpp/utils.py | 4 ++-- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/main.py b/tools/checkpoint/pp_to_vpp/main.py index ce635b983..771407ad6 100644 --- a/tools/checkpoint/pp_to_vpp/main.py +++ b/tools/checkpoint/pp_to_vpp/main.py @@ -77,7 +77,7 @@ def main(args): if __name__ == "__main__": logging.basicConfig( - level = logging.DEBUG, + level = logging.INFO, format = "[%(asctime)s][%(levelname)s] %(message)s", handlers = [logging.StreamHandler()]) args = parse_arguments() diff --git a/tools/checkpoint/pp_to_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py index 49ccb0cfd..45e425bdf 100644 --- a/tools/checkpoint/pp_to_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_to_vpp/model_optim_rng.py @@ -60,7 +60,7 @@ def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, src_pp_rank, src_ep_rank)) src_file_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) - #logger.debug(f"loading {src_file_path} to fetch source tensors in virtual stage...") + logger.debug(f"loading {src_file_path} to fetch source tensors in virtual stage...") with RetainLogLevel(): state_dict = torch.load(src_file_path, map_location="cpu", weights_only=False) state_dict_model = state_dict["model"] diff --git a/tools/checkpoint/pp_to_vpp/parallel_convert.py b/tools/checkpoint/pp_to_vpp/parallel_convert.py index b7e2d23b6..4278013a9 100644 --- a/tools/checkpoint/pp_to_vpp/parallel_convert.py +++ b/tools/checkpoint/pp_to_vpp/parallel_convert.py @@ -14,8 +14,6 @@ def _check_output_folder(args): if os.path.exists(output_folder): if not os.path.isdir(output_folder): log_and_exit(f"output path {output_folder} exists but is not a directory") - # if len(os.listdir(output_folder)) > 0: - # log_and_exit(f"output path {output_folder} exists but not empty") else: os.makedirs(output_folder) diff --git a/tools/checkpoint/pp_to_vpp/utils.py b/tools/checkpoint/pp_to_vpp/utils.py index f8f329170..1be88f93c 100644 --- a/tools/checkpoint/pp_to_vpp/utils.py +++ b/tools/checkpoint/pp_to_vpp/utils.py @@ -19,7 +19,6 @@ def __exit__(self, exc_type, exc_value, traceback): def log_and_exit(message): logger.fatal(message) raise Exception("exit with fatal error") - #sys.exit(1) def get_folder_name(args, target_pp_rank, target_ep_rank): folder_name = "mp_rank_00" @@ -55,7 +54,8 @@ def get_vpp_source_position( f"source_pp_rank={pp_stage}, source_start_layer_idx={source_start_layer_idx}") return (pp_stage, source_start_layer_idx) prefix_sum += layers_in_current_stage - log_and_exit("should never reach here") + + log_and_exit("double check failed, should never reach here") # def get_num_layers_for_this_vstage(pp_rank, vpp_rank, ckpt_ctx): From 8d65d7a8ee179a3ad9a8778362cb4a29265b7454 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:19:15 +0000 Subject: [PATCH 12/23] remove check_and_fill_args in CKPTContext, use __init__ instead --- tools/checkpoint/pp_to_vpp/model_optim_rng.py | 3 +-- tools/checkpoint/pp_to_vpp/utils.py | 6 +----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py index 45e425bdf..0f7684655 100644 --- a/tools/checkpoint/pp_to_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_to_vpp/model_optim_rng.py @@ -134,8 +134,7 @@ def convert_model_optim_rng(args, target_pp_rank, target_ep_rank): if target_pp_rank==0 and target_ep_rank<=1: logger.info("[pp_rank=0][ep_rank={}] keys of model state dict : {}\n".format(target_ep_rank, target_state_dict.keys())) - ckpt_ctx = CKPTContext() - ckpt_ctx.check_args_and_fill(args, target_state_dict) + ckpt_ctx = CKPTContext(args, target_state_dict) _convert_state_dict_args(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) diff --git a/tools/checkpoint/pp_to_vpp/utils.py b/tools/checkpoint/pp_to_vpp/utils.py index 1be88f93c..b8fdf429f 100644 --- a/tools/checkpoint/pp_to_vpp/utils.py +++ b/tools/checkpoint/pp_to_vpp/utils.py @@ -71,7 +71,7 @@ def get_num_layers_for_this_vstage(pp_rank, vpp_rank, ckpt_ctx): class CKPTContext: - def check_args_and_fill(self, args, state_dict): + def __init__(self, args, state_dict): self.vpp_size = args.target_virtual_pipeline_model_parallel_size self.pp_size = args.pipeline_model_parallel_size self.ep_size = args.expert_model_parallel_size @@ -132,10 +132,6 @@ def check_args_and_fill(self, args, state_dict): elif self.num_middle_layers > 0: log_and_exit("insufficient num_middle_pipeline_stages, " "num_middle_layers={}, num_middle_pipeline_stages={}".format(self.num_middle_layers, num_middle_pipeline_stages)) - - state_dict_args.decoder_first_pipeline_num_layers_split = args.target_first_virtual_pipeline_num_layers_split - state_dict_args.decoder_last_pipeline_num_layers_split = args.target_last_virtual_pipeline_num_layers_split - else: self.uneven_mode = False if self.num_layers % (self.pp_size * self.vpp_size) != 0: From a9498965cca02835d130f5815d121711bc6ed1d4 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:21:25 +0000 Subject: [PATCH 13/23] rename CKPTContext to TargetCkptContext --- tools/checkpoint/pp_to_vpp/model_optim_rng.py | 4 ++-- tools/checkpoint/pp_to_vpp/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py index 0f7684655..ad9262c11 100644 --- a/tools/checkpoint/pp_to_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_to_vpp/model_optim_rng.py @@ -8,7 +8,7 @@ get_folder_name, get_vpp_source_position, MODEL_OPTIM_RNG_FILENAME, - CKPTContext, + TargetCkptContext, get_num_layers_for_this_vstage, ) @@ -134,7 +134,7 @@ def convert_model_optim_rng(args, target_pp_rank, target_ep_rank): if target_pp_rank==0 and target_ep_rank<=1: logger.info("[pp_rank=0][ep_rank={}] keys of model state dict : {}\n".format(target_ep_rank, target_state_dict.keys())) - ckpt_ctx = CKPTContext(args, target_state_dict) + ckpt_ctx = TargetCkptContext(args, target_state_dict) _convert_state_dict_args(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) diff --git a/tools/checkpoint/pp_to_vpp/utils.py b/tools/checkpoint/pp_to_vpp/utils.py index b8fdf429f..cced781d4 100644 --- a/tools/checkpoint/pp_to_vpp/utils.py +++ b/tools/checkpoint/pp_to_vpp/utils.py @@ -70,7 +70,7 @@ def get_num_layers_for_this_vstage(pp_rank, vpp_rank, ckpt_ctx): return ckpt_ctx.num_middle_layers // (num_middle_stages * ckpt_ctx.vpp_size) -class CKPTContext: +class TargetCkptContext: def __init__(self, args, state_dict): self.vpp_size = args.target_virtual_pipeline_model_parallel_size self.pp_size = args.pipeline_model_parallel_size From 794cb5cdcd2b99178b89fc4bc2549fcb953d1b4a Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:23:11 +0000 Subject: [PATCH 14/23] fix typo on param_group_candidates --- tools/checkpoint/pp_to_vpp/model_optim_rng.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py index ad9262c11..e914a3d0f 100644 --- a/tools/checkpoint/pp_to_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_to_vpp/model_optim_rng.py @@ -38,15 +38,15 @@ def _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, state_di else state_optimizer try: current_step = -1 - param_group_cadidates = [] + param_group_candidates = [] for opt_state in optimizer_states: for param_group in opt_state["optimizer"]["param_groups"]: if "step" in param_group: current_step = param_group["step"] else: - param_group_cadidates.append(param_group) + param_group_candidates.append(param_group) if current_step != -1: - for param_group in param_group_cadidates: + for param_group in param_group_candidates: param_group["step"] = current_step logger.info(f"add step={current_step} in optimizer state") except Exception: From 9d42fa429bde7c5afb2f341494235b1bdc22e92d Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:27:03 +0000 Subject: [PATCH 15/23] remove unused arguments in model_optim_rng.py --- tools/checkpoint/pp_to_vpp/model_optim_rng.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py index e914a3d0f..dca57efb0 100644 --- a/tools/checkpoint/pp_to_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_to_vpp/model_optim_rng.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -def _convert_state_dict_args(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): +def _convert_state_dict_args(state_dict, ckpt_ctx): state_dict_args = state_dict["args"] if ckpt_ctx.uneven_mode: state_dict_args.decoder_first_pipeline_num_layers_split = ckpt_ctx.first_vpp_layer_split @@ -31,7 +31,7 @@ def _convert_state_dict_args(args, target_pp_rank, target_ep_rank, state_dict, c if hasattr(state_dict_args, "rank"): delattr(state_dict_args, "rank") -def _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): +def _convert_state_dict_optimizer(state_dict): # optimizer state dict is equal state_optimizer = state_dict["optimizer"] optimizer_states = [state_optimizer] if not isinstance(state_optimizer, list) \ @@ -52,7 +52,7 @@ def _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, state_di except Exception: logger.warning("add step to optimizer state failed") -def _convert_state_dict_rng(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): +def _convert_state_dict_rng(): # rng state is equal pass @@ -136,11 +136,11 @@ def convert_model_optim_rng(args, target_pp_rank, target_ep_rank): ckpt_ctx = TargetCkptContext(args, target_state_dict) - _convert_state_dict_args(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) + _convert_state_dict_args(target_state_dict, ckpt_ctx) - _convert_state_dict_optimizer(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) + _convert_state_dict_optimizer(target_state_dict) - _convert_state_dict_rng(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) + _convert_state_dict_rng() _convert_state_dict_model(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) From 60c1a0c0377369565a9a47bcf5cab49ccbeaed50 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:30:08 +0000 Subject: [PATCH 16/23] add comment for double check assert --- tools/checkpoint/pp_to_vpp/distrib_optim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/checkpoint/pp_to_vpp/distrib_optim.py b/tools/checkpoint/pp_to_vpp/distrib_optim.py index cf16d9647..48bab421d 100644 --- a/tools/checkpoint/pp_to_vpp/distrib_optim.py +++ b/tools/checkpoint/pp_to_vpp/distrib_optim.py @@ -62,7 +62,8 @@ def _fetch_opt_parameters( continue layer_idx = int(k.split(".layers.")[1].split(".")[0]) - assert layer_idx >= upper_bound_layer_idx, "layer_idx stored unordered in checkpoint" + assert layer_idx >= upper_bound_layer_idx, \ + "double check failed,layer_idx stored unordered in checkpoint" upper_bound_layer_idx = layer_idx if layer_idx == src_start_layer_idx and start_offset == -1: From 248ffbbe67c8b8baf33b7b73d650ad8bd1428189 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:36:34 +0000 Subject: [PATCH 17/23] fix for loop which does not change dict value --- tools/checkpoint/pp_to_vpp/distrib_optim.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/distrib_optim.py b/tools/checkpoint/pp_to_vpp/distrib_optim.py index 48bab421d..1fbf5d3a5 100644 --- a/tools/checkpoint/pp_to_vpp/distrib_optim.py +++ b/tools/checkpoint/pp_to_vpp/distrib_optim.py @@ -237,8 +237,11 @@ def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_sta new_parameters_dict["numel_unpadded"] = num_elements_in_vmodel - for v in new_parameters_dict.values(): - v = v.clone().detach() if torch.is_tensor(v) else v + for k, v in new_parameters_dict.items(): + if torch.is_tensor(v): + # .clone.detach() is needed for removing unneeded data + # and decrease checkpoint file size + new_parameters_dict[k] = v.clone().detach() if num_elements_in_vmodel != 0: vdisopts.append({type_key : new_parameters_dict}) From a335a5294d9cf1ceb092cfe623946ebb1fd66133 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 14:39:39 +0000 Subject: [PATCH 18/23] add comment for why use recursion instead of iteration --- tools/checkpoint/pp_to_vpp/distrib_optim.py | 2 ++ tools/checkpoint/pp_to_vpp/model_optim_rng.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tools/checkpoint/pp_to_vpp/distrib_optim.py b/tools/checkpoint/pp_to_vpp/distrib_optim.py index 1fbf5d3a5..d0ae6469a 100644 --- a/tools/checkpoint/pp_to_vpp/distrib_optim.py +++ b/tools/checkpoint/pp_to_vpp/distrib_optim.py @@ -90,6 +90,8 @@ def _fetch_opt_parameters( num_layers_remain = num_layers_for_this_virtual_stage - upper_bound_layer_idx + src_start_layer_idx if num_layers_remain > 0: + # The source parameters spans two pipeline stages + # and this recursion is executed at most once. (next_level_param_part, next_level_exp_avg_part, next_level_exp_avg_sq_part) = _fetch_opt_parameters( args, src_pp_rank + 1, diff --git a/tools/checkpoint/pp_to_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py index dca57efb0..c20227a68 100644 --- a/tools/checkpoint/pp_to_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_to_vpp/model_optim_rng.py @@ -78,6 +78,8 @@ def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, num_layers_remain = num_layers_for_this_virtual_stage - len(layer_idx_added) if num_layers_remain > 0: + # The source tensor state_dict spans two pipeline stages + # and this recursion is executed at most once. assert base_layer_idx == 0 next_level_outputs = _fetch_model_state_dict(args, src_pp_rank+1, From a8a72006a30e7d8b467a1091982c96768d87c975 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 15:01:56 +0000 Subject: [PATCH 19/23] extract tensor concat as a function --- tools/checkpoint/pp_to_vpp/distrib_optim.py | 48 +++++++++++++-------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/distrib_optim.py b/tools/checkpoint/pp_to_vpp/distrib_optim.py index d0ae6469a..ea945cb77 100644 --- a/tools/checkpoint/pp_to_vpp/distrib_optim.py +++ b/tools/checkpoint/pp_to_vpp/distrib_optim.py @@ -150,7 +150,19 @@ def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_sta if target_pp_rank==0 and target_ep_rank<=1: logger.info("[pp_rank=0][ep_rank={}] keys of opt_parameters_dict[{}] : {}".format( target_ep_rank, i, opt_parameters_dict.keys())) - + + def concat_parameter(new_parameters_dict, + param_master, + param_exp_avg, + param_exp_avg_sq, + slice_range): + new_parameters_dict["param"] = torch.cat((new_parameters_dict["param"], + param_master[slice_range])) + new_parameters_dict["exp_avg"] = torch.cat((new_parameters_dict["exp_avg"], + param_exp_avg[slice_range])) + new_parameters_dict["exp_avg_sq"] = torch.cat((new_parameters_dict["exp_avg_sq"], + param_exp_avg_sq[slice_range])) + vdisopts = [] for vidx in range(ckpt_ctx.vpp_size): """ @@ -189,12 +201,12 @@ def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_sta logger.debug(f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] apply output_layer, " f"num_elements_tail={num_elements_tail}") assert num_elements_tail != 0, f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] num_elements_tail is 0" - new_parameters_dict["param"] = torch.cat((new_parameters_dict["param"], - opt_parameters_dict["param"][:num_elements_tail])) - new_parameters_dict["exp_avg"] = torch.cat((new_parameters_dict["exp_avg"], - opt_parameters_dict["exp_avg"][:num_elements_tail])) - new_parameters_dict["exp_avg_sq"] = torch.cat((new_parameters_dict["exp_avg_sq"], - opt_parameters_dict["exp_avg_sq"][:num_elements_tail])) + + concat_parameter(new_parameters_dict, + opt_parameters_dict["param"], + opt_parameters_dict["exp_avg"], + opt_parameters_dict["exp_avg_sq"], + slice(None, num_elements_tail)) num_elements_in_vmodel += num_elements_tail # middle layer @@ -211,9 +223,12 @@ def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_sta num_layers_for_this_virtual_stage, i==0, i>0 or len(target_disopt_state_dicts)==1) - new_parameters_dict["param"] = torch.cat((new_parameters_dict["param"], param_part)) - new_parameters_dict["exp_avg"] = torch.cat((new_parameters_dict["exp_avg"], exp_avg_part)) - new_parameters_dict["exp_avg_sq"] = torch.cat((new_parameters_dict["exp_avg_sq"], exp_avg_sq_part)) + + concat_parameter(new_parameters_dict, + param_part, + exp_avg_part, + exp_avg_sq_part, + slice(None, None)) num_elements_in_vmodel += param_part.nelement() # embedding layer @@ -228,13 +243,12 @@ def convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_sta logger.debug(f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] apply embedding, " f"num_elements_head={num_elements_head}") assert num_elements_head != 0, f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] num_elements_head is 0" - - new_parameters_dict["param"] = torch.cat((new_parameters_dict["param"], - opt_parameters_dict["param"][-num_elements_head:])) - new_parameters_dict["exp_avg"] = torch.cat((new_parameters_dict["exp_avg"], - opt_parameters_dict["exp_avg"][-num_elements_head:])) - new_parameters_dict["exp_avg_sq"] = torch.cat((new_parameters_dict["exp_avg_sq"], - opt_parameters_dict["exp_avg_sq"][-num_elements_head:])) + + concat_parameter(new_parameters_dict, + opt_parameters_dict["param"], + opt_parameters_dict["exp_avg"], + opt_parameters_dict["exp_avg_sq"], + slice(-num_elements_head, None)) num_elements_in_vmodel += num_elements_head new_parameters_dict["numel_unpadded"] = num_elements_in_vmodel From 8209ca10e0bc8121340f4fa5bdab5d03657ab152 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 15:05:12 +0000 Subject: [PATCH 20/23] change get_num_layers_for_this_vstage to get_num_layers_for_this_vpp_stage --- tools/checkpoint/pp_to_vpp/distrib_optim.py | 4 ++-- tools/checkpoint/pp_to_vpp/model_optim_rng.py | 4 ++-- tools/checkpoint/pp_to_vpp/utils.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/distrib_optim.py b/tools/checkpoint/pp_to_vpp/distrib_optim.py index ea945cb77..711ccf1be 100644 --- a/tools/checkpoint/pp_to_vpp/distrib_optim.py +++ b/tools/checkpoint/pp_to_vpp/distrib_optim.py @@ -8,7 +8,7 @@ get_vpp_source_position, MODEL_OPTIM_RNG_FILENAME, DISTRIB_OPTIM_FILENAME, - get_num_layers_for_this_vstage, + get_num_layers_for_this_vpp_stage, ) import torch @@ -214,7 +214,7 @@ def concat_parameter(new_parameters_dict, target_pp_rank, vidx, ckpt_ctx) - num_layers_for_this_virtual_stage = get_num_layers_for_this_vstage(target_pp_rank, vidx, ckpt_ctx) + num_layers_for_this_virtual_stage = get_num_layers_for_this_vpp_stage(target_pp_rank, vidx, ckpt_ctx) param_part, exp_avg_part, exp_avg_sq_part = _fetch_opt_parameters( args, src_pp_rank, diff --git a/tools/checkpoint/pp_to_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py index c20227a68..76bf88c69 100644 --- a/tools/checkpoint/pp_to_vpp/model_optim_rng.py +++ b/tools/checkpoint/pp_to_vpp/model_optim_rng.py @@ -9,7 +9,7 @@ get_vpp_source_position, MODEL_OPTIM_RNG_FILENAME, TargetCkptContext, - get_num_layers_for_this_vstage, + get_num_layers_for_this_vpp_stage, ) import torch @@ -105,7 +105,7 @@ def _convert_state_dict_model(args, target_pp_rank, target_ep_rank, state_dict, vidx, ckpt_ctx) - num_layers_for_this_virtual_stage = get_num_layers_for_this_vstage(target_pp_rank, vidx, ckpt_ctx) + num_layers_for_this_virtual_stage = get_num_layers_for_this_vpp_stage(target_pp_rank, vidx, ckpt_ctx) src_model_state_dict = _fetch_model_state_dict( args, src_pp_rank, diff --git a/tools/checkpoint/pp_to_vpp/utils.py b/tools/checkpoint/pp_to_vpp/utils.py index cced781d4..28f2c1985 100644 --- a/tools/checkpoint/pp_to_vpp/utils.py +++ b/tools/checkpoint/pp_to_vpp/utils.py @@ -58,7 +58,7 @@ def get_vpp_source_position( log_and_exit("double check failed, should never reach here") # -def get_num_layers_for_this_vstage(pp_rank, vpp_rank, ckpt_ctx): +def get_num_layers_for_this_vpp_stage(pp_rank, vpp_rank, ckpt_ctx): if not ckpt_ctx.uneven_mode: return ckpt_ctx.num_layers // (ckpt_ctx.pp_size * ckpt_ctx.vpp_size) if pp_rank == 0: From b08170469c4bb9d6999a7b3df314bf02f2d82744 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Sat, 14 Jun 2025 15:10:52 +0000 Subject: [PATCH 21/23] Break down complex computation logic to improve code readability. --- tools/checkpoint/pp_to_vpp/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tools/checkpoint/pp_to_vpp/utils.py b/tools/checkpoint/pp_to_vpp/utils.py index 28f2c1985..1be645013 100644 --- a/tools/checkpoint/pp_to_vpp/utils.py +++ b/tools/checkpoint/pp_to_vpp/utils.py @@ -34,11 +34,15 @@ def get_vpp_source_position( ckpt_ctx): num_middle_stages = ckpt_ctx.pp_size - 2 # should be non-negative - num_layers_per_middle_virtual_stage = (ckpt_ctx.num_middle_layers // (num_middle_stages * ckpt_ctx.vpp_size)) if num_middle_stages > 0 else 0 + num_layers_per_middle_virtual_stage = (ckpt_ctx.num_middle_layers // (num_middle_stages * ckpt_ctx.vpp_size)) \ + if num_middle_stages > 0 else 0 target_global_layer_idx = target_virtual_idx * num_middle_stages * num_layers_per_middle_virtual_stage + \ - sum(ckpt_ctx.first_vpp_layer_split[:target_virtual_idx]) + sum(ckpt_ctx.last_vpp_layer_split[:target_virtual_idx]) \ - + (0 if target_pp_rank == 0 else ckpt_ctx.first_vpp_layer_split[target_virtual_idx] + (target_pp_rank-1) * num_layers_per_middle_virtual_stage) + sum(ckpt_ctx.first_vpp_layer_split[:target_virtual_idx]) + sum(ckpt_ctx.last_vpp_layer_split[:target_virtual_idx]) + + if target_pp_rank != 0: + target_global_layer_idx += (target_pp_rank - 1) * num_layers_per_middle_virtual_stage + \ + ckpt_ctx.first_vpp_layer_split[target_virtual_idx] prefix_sum = 0 for pp_stage in range(ckpt_ctx.pp_size): From a79c5d5b5534f55b3df93f5e2af3d7046bee3606 Mon Sep 17 00:00:00 2001 From: LI MOU Date: Wed, 25 Jun 2025 08:30:19 +0000 Subject: [PATCH 22/23] fix utest on NVIDIA Platform --- tests/unit_tests/test_convert_checkpoint.py | 12 +++++++----- tools/checkpoint/pp_to_vpp/distrib_optim.py | 6 +++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/unit_tests/test_convert_checkpoint.py b/tests/unit_tests/test_convert_checkpoint.py index b58f3b2d8..b7fc83c71 100644 --- a/tests/unit_tests/test_convert_checkpoint.py +++ b/tests/unit_tests/test_convert_checkpoint.py @@ -239,7 +239,9 @@ def is_state_dict_equal(x, y): def _test_convert_pp_to_vpp_internal(ckpt_dir : Path): rank = dist.get_rank() world_size = dist.get_world_size() - assert world_size == 8, "current only support world_size=8" + if world_size != 8: + print_rank_0("current test_convert_pp_to_vpp only support world_size=8") + return init_num_microbatches_calculator(rank, None, 1, 1, 1) args = create_args( @@ -269,7 +271,8 @@ def _test_convert_pp_to_vpp_internal(ckpt_dir : Path): if rank == 0: # convert model, increase virtual_pipeline_size to 2 command = ( - "mkdir -p {}/iter_{:07d} ".format(args.load, iteration) + + "PYTHONPATH={} ".format(os.path.join(CURDIR, "../..")) + + "&& mkdir -p {}/iter_{:07d} ".format(args.load, iteration) + "&& echo {} > {}/latest_checkpointed_iteration.txt ".format(iteration, args.load) + "&& python {}/../../tools/checkpoint/pp_to_vpp/main.py ".format(CURDIR) + "--load-iteration-dir {}/iter_{:07d} ".format(args.save, iteration) + @@ -283,9 +286,8 @@ def _test_convert_pp_to_vpp_internal(ckpt_dir : Path): subprocess_result = subprocess.run( command, shell = True, - stdout = subprocess.PIPE, - stderr = subprocess.PIPE, text = True) + print_rank_0(f"convert finished, exit code : {subprocess_result.returncode}") assert subprocess_result.returncode == 0 dist.barrier() @@ -294,7 +296,7 @@ def _test_convert_pp_to_vpp_internal(ckpt_dir : Path): args.num_virtual_stages_per_pipeline_rank = 2 args.perform_initialization = False reset_parallel_state(args) - + new_model, new_optimizer, new_opt_scheduler = get_checkpoint_content(args) print_rank_0("loading checkpoint with virtual_pipeline_size=2") diff --git a/tools/checkpoint/pp_to_vpp/distrib_optim.py b/tools/checkpoint/pp_to_vpp/distrib_optim.py index 711ccf1be..9123f5de9 100644 --- a/tools/checkpoint/pp_to_vpp/distrib_optim.py +++ b/tools/checkpoint/pp_to_vpp/distrib_optim.py @@ -47,7 +47,7 @@ def _fetch_opt_parameters( start_offset, end_offset = -1, -1 upper_bound_layer_idx = -1 # variable for double-check for (k, v) in state_dict_model.items(): - if not hasattr(v, "nelement"): + if not hasattr(v, "nelement") or k.endswith("._extra_state"): continue if "router.expert_bias" in k: # in Megatron-LM @@ -194,7 +194,7 @@ def concat_parameter(new_parameters_dict, num_elements_tail = 0 for (k, v) in reversed(vmodel_dict.items()): if "final_layernorm." in k or "output_layer." in k: - if hasattr(v, "nelement"): + if hasattr(v, "nelement") and not k.endswith("._extra_state"): num_elements_tail += v.nelement() else: break @@ -236,7 +236,7 @@ def concat_parameter(new_parameters_dict, num_elements_head = 0 for (k, v) in vmodel_dict.items(): if k.startswith("embedding."): - if hasattr(v, "nelement"): + if hasattr(v, "nelement") and not k.endswith("._extra_state"): num_elements_head += v.nelement() else: break From 7bfd7121686510560467c9320d67978b6bd32fd6 Mon Sep 17 00:00:00 2001 From: limou102 Date: Thu, 26 Jun 2025 09:58:36 +0800 Subject: [PATCH 23/23] fix env --- tests/unit_tests/test_convert_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/test_convert_checkpoint.py b/tests/unit_tests/test_convert_checkpoint.py index b7fc83c71..35ae3bd26 100644 --- a/tests/unit_tests/test_convert_checkpoint.py +++ b/tests/unit_tests/test_convert_checkpoint.py @@ -271,7 +271,7 @@ def _test_convert_pp_to_vpp_internal(ckpt_dir : Path): if rank == 0: # convert model, increase virtual_pipeline_size to 2 command = ( - "PYTHONPATH={} ".format(os.path.join(CURDIR, "../..")) + + "export PYTHONPATH={} ".format(os.path.join(CURDIR, "../..")) + "&& mkdir -p {}/iter_{:07d} ".format(args.load, iteration) + "&& echo {} > {}/latest_checkpointed_iteration.txt ".format(iteration, args.load) + "&& python {}/../../tools/checkpoint/pp_to_vpp/main.py ".format(CURDIR) + @@ -333,4 +333,4 @@ def _test_convert_pp_to_vpp_internal(ckpt_dir : Path): def test_convert_pp_to_vpp(tmp_path_dist_ckpt): Utils.initialize_distributed() with TempNamedDir(tmp_path_dist_ckpt / "test_convert_checkpoint", sync=True) as ckpt_dir: - _test_convert_pp_to_vpp_internal(ckpt_dir) \ No newline at end of file + _test_convert_pp_to_vpp_internal(ckpt_dir)