def create_args(
        num_layers,
        hidden_size,
        num_attn_heads,
        pipeline_parallel_size,
        ckpt_dir):
    """Build the argument namespace used by the checkpoint save/load round trip.

    The model-shape values come from the parameters; every other field is a
    fixed setting for this test (torch checkpoint format, distributed
    optimizer, tensor/expert parallel size of 1, no virtual pipeline stages).

    Args:
        num_layers: stored as ``num_layers``.
        hidden_size: stored as ``hidden_size``.
        num_attn_heads: stored as ``num_attention_heads``.
        pipeline_parallel_size: stored as ``pipeline_model_parallel_size``.
        ckpt_dir: base directory; must support the ``/`` operator because
            ``save`` and ``load`` are derived as ``ckpt_dir / "save_dir"``
            and ``ckpt_dir / "load_dir"``.

    Returns:
        A ``SimpleNamespace`` carrying all settings.

    Raises:
        KeyError: if the ``LOCAL_RANK`` environment variable is not set.
    """
    settings = {
        # checkpoint save / load behaviour
        "finetune": False,
        "non_persistent_global_ckpt_dir": None,
        "non_persistent_ckpt_type": None,
        "non_persistent_save_interval": None,
        "exit_on_missing_checkpoint": True,
        "async_save": False,
        "data_parallel_random_init": False,
        "no_save_optim": False,
        "no_save_rng": False,
        "no_load_optim": False,
        "no_load_rng": False,
        "log_progress": False,
        "ckpt_fully_parallel_save": False,
        "auto_detect_ckpt_format": False,
        "retro_add_retriever": False,
        "ckpt_convert_update_legacy_dist_opt_format": False,
        "ckpt_step": None,
        "use_distributed_optimizer": True,
        "use_dist_ckpt": False,
        # training progress counters
        "consumed_train_samples": 0,
        "skipped_train_samples": 0,
        "consumed_valid_samples": 0,
        "add_position_embedding": False,
        "vocab_file": None,
        "tensor_model_parallel_size": 1,
        "ckpt_format": "torch",
        "ckpt_isolated_save": True,
        # LOCAL_RANK is read from the process environment
        "local_rank": int(os.environ["LOCAL_RANK"]),
        "ckpt_upload_blob_path": None,
        "perform_initialization": True,
        "num_virtual_stages_per_pipeline_rank": None,
        # model shape supplied by the caller
        "num_layers": num_layers,
        "hidden_size": hidden_size,
        "num_attention_heads": num_attn_heads,
        "pipeline_model_parallel_size": pipeline_parallel_size,
        "normalization": "RMSNorm",
        "transformer_impl": "transformer_engine",
        "expert_model_parallel_size": 1,
        # checkpoint locations derived from the base directory
        "save": ckpt_dir / "save_dir",
        "load": ckpt_dir / "load_dir",
    }
    return SimpleNamespace(**settings)
def get_global_state(args, model, optimizer):
    """Collect this rank's parameters and optimizer states under global layer keys.

    Each parameter name that contains ``.layers.<local index>`` is rewritten
    so the index becomes the layer's global position in the whole model. This
    makes state dicts gathered from different pipeline / virtual pipeline
    layouts comparable by key.

    Args:
        args: namespace providing ``num_layers``,
            ``pipeline_model_parallel_size`` and
            ``num_virtual_stages_per_pipeline_rank``.
        model: list of model chunks; the list position is used as the virtual
            pipeline rank of the chunk.
        optimizer: optimizer whose ``chained_optimizers[0]`` exposes
            ``_get_main_param_and_optimizer_states``.

    Returns:
        ``(global_model_state, global_optimizer_state)``: two dicts keyed by
        the rewritten parameter name, with all tensors moved to CPU.
    """

    def get_global_layer_index(num_layers,
                               pp_size,
                               vpp_size,
                               current_pp_rank,
                               current_vpp_rank,
                               current_local_layer_index):
        num_layers_per_pipeline_rank = num_layers // pp_size
        if vpp_size is None or vpp_size == 1:
            return current_pp_rank * num_layers_per_pipeline_rank + current_local_layer_index
        num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vpp_size
        # Layers covered by one virtual rank across all pipeline ranks.
        num_layers_per_virtual_group = num_layers // vpp_size
        # The local index must be added here as well; without it every layer
        # of a chunk collapses onto the chunk's first global index as soon as
        # a chunk holds more than one layer.
        return (current_vpp_rank * num_layers_per_virtual_group
                + current_pp_rank * num_layers_per_virtual_rank
                + current_local_layer_index)

    def to_cpu(x):
        # Tensors are copied to CPU; anything else is treated as a mapping
        # whose values are converted in place, recursively.
        if torch.is_tensor(x):
            return x.to("cpu")
        for k in x:
            x[k] = to_cpu(x[k])
        return x

    current_pp_rank = mpu.get_pipeline_model_parallel_rank()

    global_model_state = {}
    global_optimizer_state = {}
    for vpp_idx, model_chunk in enumerate(model):
        for name, param in model_chunk.named_parameters():
            key = name
            if ".layers." in key:
                # The token right after ".layers." is the chunk-local index.
                layer_idx = int(key.split(".layers.")[1].split(".")[0])
                global_layer_idx = get_global_layer_index(
                    args.num_layers,
                    args.pipeline_model_parallel_size,
                    args.num_virtual_stages_per_pipeline_rank,
                    current_pp_rank,
                    vpp_idx,
                    layer_idx)
                key = key.replace(f".layers.{layer_idx}", f".layers.{global_layer_idx}")
            optimizer_param = optimizer.chained_optimizers[0]._get_main_param_and_optimizer_states(param)
            global_model_state[key] = to_cpu(param)
            global_optimizer_state[key] = to_cpu(optimizer_param)

    return (global_model_state, global_optimizer_state)
def _test_convert_pp_to_vpp_internal(ckpt_dir : Path):
    """Round-trip a checkpoint through the pp_to_vpp converter and compare states.

    Steps:
      1. build a model/optimizer without virtual pipeline stages and save it;
      2. on rank 0, run ``tools/checkpoint/pp_to_vpp/main.py`` to produce a
         checkpoint with a virtual pipeline size of 2;
      3. rebuild the model with a virtual pipeline size of 2, load the
         converted checkpoint and check that iteration, flops, parameters and
         optimizer states equal the saved ones.

    Args:
        ckpt_dir: base directory for the ``save_dir`` / ``load_dir`` folders.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    if world_size != 8:
        # Report the test as skipped instead of letting it pass without
        # having checked anything.
        pytest.skip("current test_convert_pp_to_vpp only support world_size=8")
    init_num_microbatches_calculator(rank, None, 1, 1, 1)

    args = create_args(
        num_layers=16,
        hidden_size=64,
        num_attn_heads=4,
        pipeline_parallel_size=8,
        ckpt_dir=ckpt_dir)

    set_args(args)

    reset_parallel_state(args)

    model, optimizer, opt_scheduler = get_checkpoint_content(args)

    # save model with virtual_pipeline_size=1
    iteration = 123
    flops = 456
    print_rank_0("saving checkpoint with virtual_pipeline_size=1...")
    save_checkpoint(iteration, model, optimizer, opt_scheduler, flops)

    global_model_state, global_optimizer_state = get_global_state(args, model, optimizer)
    global_model_state = merge_state_dict_to_pipeline_rank0(global_model_state)
    global_optimizer_state = merge_state_dict_to_pipeline_rank0(global_optimizer_state)

    # Exit code of the converter, filled in by rank 0 and shared with all ranks.
    convert_status = [None]
    if rank == 0:
        # convert model, increase virtual_pipeline_size to 2
        iteration_folder = "iter_{:07d}".format(iteration)
        target_iteration_dir = os.path.join(args.load, iteration_folder)
        os.makedirs(target_iteration_dir, exist_ok=True)
        tracker_path = os.path.join(args.load, "latest_checkpointed_iteration.txt")
        with open(tracker_path, "w") as tracker_file:
            tracker_file.write("{}\n".format(iteration))

        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact.
        command = [
            sys.executable,
            os.path.join(CURDIR, "../../tools/checkpoint/pp_to_vpp/main.py"),
            "--load-iteration-dir", os.path.join(args.save, iteration_folder),
            "--expert-model-parallel-size", "1",
            "--pipeline-model-parallel-size", "8",
            "--save-iteration-dir", target_iteration_dir,
            "--target-virtual-pipeline-model-parallel-size", "2",
            "--num-max-processing-processes", "2",
        ]
        converter_env = dict(os.environ, PYTHONPATH=os.path.join(CURDIR, "../.."))
        print_rank_0("converting checkpoint from virtual_pipeline_size from 1 to 2")
        subprocess_result = subprocess.run(
            command,
            env = converter_env,
            text = True)
        print_rank_0(f"convert finished, exit code : {subprocess_result.returncode}")
        convert_status[0] = subprocess_result.returncode

    # Every rank checks the exit code, so a failed conversion fails the test
    # on all ranks instead of leaving the others blocked in the barrier.
    dist.broadcast_object_list(convert_status, src=0)
    assert convert_status[0] == 0

    dist.barrier()

    # change virtual_pipeline_size to 2 and load the model converted
    args.num_virtual_stages_per_pipeline_rank = 2
    args.perform_initialization = False
    reset_parallel_state(args)

    new_model, new_optimizer, new_opt_scheduler = get_checkpoint_content(args)

    print_rank_0("loading checkpoint with virtual_pipeline_size=2")
    loaded_iter, loaded_flops = load_checkpoint(
        new_model, new_optimizer, new_opt_scheduler, strict=True
    )

    # check iteration and flops are equal
    assert loaded_iter == iteration and loaded_flops == flops

    new_global_model_state, new_global_optimizer_state = get_global_state(args, new_model, new_optimizer)
    new_global_model_state = merge_state_dict_to_pipeline_rank0(new_global_model_state)
    new_global_optimizer_state = merge_state_dict_to_pipeline_rank0(new_global_optimizer_state)

    if mpu.get_pipeline_model_parallel_rank() == 0:
        # check model_state and optimizer parameter state are equal
        global_model_state_equal = is_state_dict_equal(global_model_state, new_global_model_state)
        assert global_model_state_equal
        global_optimizer_state_equal = is_state_dict_equal(global_optimizer_state, new_global_optimizer_state)
        assert global_optimizer_state_equal
b/tools/checkpoint/pp_to_vpp/README.md @@ -0,0 +1,150 @@ +# pp_to_vpp +## description +This tool converts a language model checkpoint without virtual pipeline parallelism into one with virtual pipeline parallelism by increasing the virtual pipeline stage size. + +Other model parallel parameters (tensor-parallel-size, pipeline-parallel-size, expert-parallel-size ...) remain unchanged. + +--- + +**(2025-05-30)** It now supports uneven pipeline mode, as well as cases where the number of layers in a pipeline stage is not divisible by the virtual pipeline degree. + +See the arguments: +``` +--target-first-virtual-pipeline-num-layers-split +--target-last-virtual-pipeline-num-layers-split +``` +These two parameters must be provided together or omitted together; providing them enables uneven pipeline mode + and specifies the virtual pipeline layer distribution for the first and last pipeline stages (the distribution may be even, but it must still be provided explicitly). + +This feature was introduced based on the following Pull Request. + + https://github.com/microsoft/ltp-megatron-lm/pull/27 + + The converted model must be loaded with a Megatron-LM framework that has this Pull Request applied. + +--- + +**Currently, tests have been conducted on the DeepSeek (v2, v3) and Mixtral models.** + +Note that all of the following configurations are currently required: + tensor_parallel_size=1 + ckpt_format=torch +so each iteration folder of the checkpoint should look like this: +``` +iter_0000050 +├── mp_rank_00_000_000 +│ ├── distrib_optim.pt +│ └── model_optim_rng.pt +├── mp_rank_00_000_001 +│ ├── distrib_optim.pt +│ └── model_optim_rng.pt +├── mp_rank_00_000_002 +│ ├── distrib_optim.pt +│ └── model_optim_rng.pt +... 
+``` + + +## how to use +you can modify run_convert_pp_to_vpp.sh and launch it as an example +``` +usage: main.py [-h] --load-iteration-dir LOAD_ITERATION_DIR --expert-model-parallel-size EXPERT_MODEL_PARALLEL_SIZE --pipeline-model-parallel-size PIPELINE_MODEL_PARALLEL_SIZE + --save-iteration-dir SAVE_ITERATION_DIR --target-virtual-pipeline-model-parallel-size TARGET_VIRTUAL_PIPELINE_MODEL_PARALLEL_SIZE + [--target-first-virtual-pipeline-num-layers-split TARGET_FIRST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT [TARGET_FIRST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT ...]] + [--target-last-virtual-pipeline-num-layers-split TARGET_LAST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT [TARGET_LAST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT ...]] + [--num-max-processing-processes NUM_MAX_PROCESSING_PROCESSES] [--pipeline-ranks-to-process PIPELINE_RANKS_TO_PROCESS] + +convert a non-virtual pipeline checkpoint to virtual pipeline checkpoint + +options: + -h, --help show this help message and exit + --load-iteration-dir LOAD_ITERATION_DIR + iteration folder of source model checkpoint + --expert-model-parallel-size EXPERT_MODEL_PARALLEL_SIZE + ep_size of original model and the target model + --pipeline-model-parallel-size PIPELINE_MODEL_PARALLEL_SIZE + physical pp_size of original model and the target model + --save-iteration-dir SAVE_ITERATION_DIR + iteration folder of target model checkpoint, need to be empty if existed + --target-virtual-pipeline-model-parallel-size TARGET_VIRTUAL_PIPELINE_MODEL_PARALLEL_SIZE + vpp_size of target model + --target-first-virtual-pipeline-num-layers-split TARGET_FIRST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT [TARGET_FIRST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT ...] + only used in uneven pipeline mode, virtual pipeline split of the first stage + --target-last-virtual-pipeline-num-layers-split TARGET_LAST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT [TARGET_LAST_VIRTUAL_PIPELINE_NUM_LAYERS_SPLIT ...] 
+ only used in uneven pipeline mode, virtual pipeline split of the last stage + --num-max-processing-processes NUM_MAX_PROCESSING_PROCESSES + the maximum number of processing processes used by this script, increasing this value can speed up model conversion(but the final bottleneck may be disk + bandwidth), it will also consume more CPU memory. + --pipeline-ranks-to-process PIPELINE_RANKS_TO_PROCESS + pipeline rank list to process using this script, to accelerate converting user can launch multiple tasks on different nodes, each one process part of pipeline + ranks. example : --pipeline-ranks-to-process 0 1 2 3 default is None, means process all pipeline ranks +``` + +## examples +1) The target model has virtual_pipeline_size=2, and uses 4 processes in parallel. +``` +python main.py \ + --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ + --expert-model-parallel-size 4 \ + --pipeline-model-parallel-size 2 \ + --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ + --target-virtual-pipeline-model-parallel-size 2 \ + --num-max-processing-processes 4 +``` + +2) Convert the checkpoints generated by pipeline ranks [0,1,2,3] on node 1, and convert the checkpoints generated by pipeline ranks [4,5,6,7] on node 2. 
(in cases where memory is limited on a single node) +``` +# node1 : +python main.py \ + --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ + --expert-model-parallel-size 8 \ + --pipeline-model-parallel-size 8 \ + --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ + --target-virtual-pipeline-model-parallel-size 2 \ + --num-max-processing-processes 4 \ + --pipeline-ranks-to-process 0 1 2 3 + +# node2: +python main.py \ + --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ + --expert-model-parallel-size 8 \ + --pipeline-model-parallel-size 8 \ + --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ + --target-virtual-pipeline-model-parallel-size 2 \ + --num-max-processing-processes 4 \ + --pipeline-ranks-to-process 4 5 6 7 +``` + +3) Convert a model in uneven pipeline mode that was saved by Megatron-LM with the arguments +``` +--decoder-first-pipeline-num-layers 8 +--decoder-last-pipeline-num-layers 7 +``` + +``` +# suppose pipeline_parallel_size=4, the model contains 31 layers in total, the layer distribution across the pipeline stages is [8, 8, 8, 7] +# now we use this model to increase the virtual pipeline size to 2, +# the layer split in the first pipeline stage is [4, 4] and the layer split in the last pipeline stage is [4, 3] +# vpp0 vpp1 +# pp0 0, 1, 2, 3 16,17,18,19 +# pp1 4, 5, 6, 7 20,21,22,23 +# pp2 8, 9,10,11 24,25,26,27 +# pp3 12,13,14,15 28,29,30 + +python main.py \ + --load-iteration-dir /path/to/src_checkpoints/iter_0000050 \ + --save-iteration-dir /path/to/dst_checkpoints/iter_0000050 \ + --expert-model-parallel-size 8 \ + --pipeline-model-parallel-size 4 \ + --target-virtual-pipeline-model-parallel-size 2 \ + --target-first-virtual-pipeline-num-layers-split 4 4 \ + --target-last-virtual-pipeline-num-layers-split 4 3 \ + --num-max-processing-processes 8 +``` + +Some training logs from the tests are available in the **logs** directory for review. 
def _fetch_opt_parameters(
        args,
        src_pp_rank,
        src_ep_rank,
        src_start_layer_idx,
        num_layers_for_this_virtual_stage,
        count_dense : bool,
        count_experts : bool):
    """Return the optimizer-state slices that belong to a run of source layers.

    Loads the model and distributed-optimizer files of one source rank, walks
    the model state dict in order to find the element offsets of the
    requested layers, and cuts the matching range out of the flat ``param``,
    ``exp_avg`` and ``exp_avg_sq`` tensors. Offsets are applied from the end
    of each tensor, so the first model entry sits at the tail of the flat
    tensor.

    Args:
        args: namespace providing ``load_iteration_dir`` (also passed to
            ``get_folder_name``).
        src_pp_rank: pipeline rank of the source checkpoint folder.
        src_ep_rank: expert rank of the source checkpoint folder.
        src_start_layer_idx: first layer index (as stored in the source
            model state dict) to fetch.
        num_layers_for_this_virtual_stage: number of consecutive layers to
            fetch.
        count_dense: whether entries without ``.mlp.experts.`` in their key
            contribute to the offsets.
        count_experts: whether entries with ``.mlp.experts.`` in their key
            contribute to the offsets.

    Returns:
        ``(param_part, exp_avg_part, exp_avg_sq_part)`` for the requested
        layers. If the source rank holds fewer layers than requested, the
        remainder is fetched from ``src_pp_rank + 1`` and placed in front.

    Raises:
        AssertionError: if the start layer is not found, the layers are
            stored out of order, or the number of optimizer buckets does not
            match the ``count_dense`` / ``count_experts`` combination.
    """
    src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, src_pp_rank, src_ep_rank))
    src_model_opt_rng_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME)
    src_distrib_optim_path = os.path.join(src_folder_path, DISTRIB_OPTIM_FILENAME)

    with RetainLogLevel():
        state_dict_model = torch.load(src_model_opt_rng_path, map_location="cpu", weights_only=False)["model"]
        state_dict_disopts = torch.load(src_distrib_optim_path, map_location="cpu")
    if not isinstance(state_dict_disopts, list):
        state_dict_disopts = [state_dict_disopts]

    # Pick the bucket that holds the requested kind of parameters.
    if count_dense and count_experts:
        assert len(state_dict_disopts)==1, "length of state_dict_disopts does not equal to 1"
        state_dict_disopt = state_dict_disopts[0]
    elif count_dense:
        state_dict_disopt = state_dict_disopts[0]
    else:
        assert count_experts
        assert len(state_dict_disopts)==2
        state_dict_disopt = state_dict_disopts[1]

    current_offset = 0
    start_offset, end_offset = -1, -1
    upper_bound_layer_idx = -1 # variable for double-check
    for (k, v) in state_dict_model.items():
        if not hasattr(v, "nelement") or k.endswith("._extra_state"):
            continue
        if "router.expert_bias" in k:
            # in Megatron-LM
            # router.expert_bias is initialized by calling register_buffer
            continue
        if not ".layers." in k:
            if start_offset != -1:
                # A non-layer entry after the start layer closes the range.
                end_offset = current_offset
                upper_bound_layer_idx += 1
                break
            current_offset += v.nelement() if count_dense else 0
            continue
        layer_idx = int(k.split(".layers.")[1].split(".")[0])

        assert layer_idx >= upper_bound_layer_idx, \
            "double check failed,layer_idx stored unordered in checkpoint"
        upper_bound_layer_idx = layer_idx

        if layer_idx == src_start_layer_idx and start_offset == -1:
            start_offset = current_offset
        if layer_idx == src_start_layer_idx + num_layers_for_this_virtual_stage \
            and start_offset != -1:
            end_offset = current_offset
            break
        if ".mlp.experts." in k:
            current_offset += v.nelement() if count_experts else 0
        else:
            current_offset += v.nelement() if count_dense else 0
    assert start_offset != -1
    if end_offset == -1:
        # The requested range runs to the end of this rank's state dict.
        end_offset = current_offset
        upper_bound_layer_idx += 1

    def _slice_from_tail(flat_tensor):
        # Equivalent to flat_tensor[-end_offset:-start_offset], but written
        # with non-negative indices: "-0" is 0 in Python, so a negative-index
        # slice with end_offset == 0 would return the whole tensor instead of
        # an empty one.
        numel = flat_tensor.shape[0]
        return flat_tensor[max(numel - end_offset, 0):max(numel - start_offset, 0)]

    opt_parameters_dict = next(iter(state_dict_disopt[0].values()))
    (param_part, exp_avg_part, exp_avg_sq_part) = (
        _slice_from_tail(opt_parameters_dict["param"]),
        _slice_from_tail(opt_parameters_dict["exp_avg"]),
        _slice_from_tail(opt_parameters_dict["exp_avg_sq"]),
    )

    num_layers_remain = num_layers_for_this_virtual_stage - upper_bound_layer_idx + src_start_layer_idx
    if num_layers_remain > 0:
        # The source parameters spans two pipeline stages
        # and this recursion is executed at most once.
        (next_level_param_part, next_level_exp_avg_part, next_level_exp_avg_sq_part) = _fetch_opt_parameters(
            args,
            src_pp_rank + 1,
            src_ep_rank,
            0,
            num_layers_remain,
            count_dense,
            count_experts)

        # Later layers go first to keep the tail-first ordering.
        param_part = torch.cat((next_level_param_part, param_part))
        exp_avg_part = torch.cat((next_level_exp_avg_part, exp_avg_part))
        exp_avg_sq_part = torch.cat((next_level_exp_avg_sq_part, exp_avg_sq_part))

    return (param_part, exp_avg_part, exp_avg_sq_part)
{}".format(target_disopt_state_dict.keys())) + src_disopt_state_dict = target_disopt_state_dict[0] + if len(src_disopt_state_dict) != 1: + log_and_exit("length of src_disopt_state_dict is not 1, keys : {}".format(src_disopt_state_dict.keys())) + type_key = next(iter(src_disopt_state_dict)) + opt_parameters_dict = src_disopt_state_dict[type_key] + if target_pp_rank==0 and target_ep_rank<=1: + logger.info("[pp_rank=0][ep_rank={}] keys of opt_parameters_dict[{}] : {}".format( + target_ep_rank, i, opt_parameters_dict.keys())) + + def concat_parameter(new_parameters_dict, + param_master, + param_exp_avg, + param_exp_avg_sq, + slice_range): + new_parameters_dict["param"] = torch.cat((new_parameters_dict["param"], + param_master[slice_range])) + new_parameters_dict["exp_avg"] = torch.cat((new_parameters_dict["exp_avg"], + param_exp_avg[slice_range])) + new_parameters_dict["exp_avg_sq"] = torch.cat((new_parameters_dict["exp_avg_sq"], + param_exp_avg_sq[slice_range])) + + vdisopts = [] + for vidx in range(ckpt_ctx.vpp_size): + """ + value in opt_parameters_dict are flatened tensors and were saved in reverse order compare to model state_dict + for example: + model_ordered_state_dict(key-tensor pair) : { + k1 : t1, + k2 : t2, + k3 : t3, + ... + kn : tn + } + opt_parameters_dict["param"] : [tn, ... t3, t2, t1] + opt_parameters_dict["exp_avg"] : [tn, ... 
t3, t2, t1] + """ + num_elements_in_vmodel = 0 + vmodel_dict = target_model_state_dict[f"model{vidx}"] + new_parameters_dict = { + "param" : torch.tensor([], dtype=opt_parameters_dict["param"].dtype, device="cpu"), + "exp_avg" : torch.tensor([], dtype=opt_parameters_dict["exp_avg"].dtype, device="cpu"), + "exp_avg_sq" : torch.tensor([], dtype=opt_parameters_dict["exp_avg_sq"].dtype, device="cpu"), + "numel_unpadded" : 0, + } + + # final_layernorm and output_layer + if target_pp_rank == ckpt_ctx.pp_size - 1 \ + and vidx == ckpt_ctx.vpp_size - 1 \ + and i == 0: + num_elements_tail = 0 + for (k, v) in reversed(vmodel_dict.items()): + if "final_layernorm." in k or "output_layer." in k: + if hasattr(v, "nelement") and not k.endswith("._extra_state"): + num_elements_tail += v.nelement() + else: + break + logger.debug(f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] apply output_layer, " + f"num_elements_tail={num_elements_tail}") + assert num_elements_tail != 0, f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] num_elements_tail is 0" + + concat_parameter(new_parameters_dict, + opt_parameters_dict["param"], + opt_parameters_dict["exp_avg"], + opt_parameters_dict["exp_avg_sq"], + slice(None, num_elements_tail)) + num_elements_in_vmodel += num_elements_tail + + # middle layer + src_pp_rank, src_start_layer_idx = get_vpp_source_position( + target_pp_rank, + vidx, + ckpt_ctx) + num_layers_for_this_virtual_stage = get_num_layers_for_this_vpp_stage(target_pp_rank, vidx, ckpt_ctx) + param_part, exp_avg_part, exp_avg_sq_part = _fetch_opt_parameters( + args, + src_pp_rank, + target_ep_rank, + src_start_layer_idx, + num_layers_for_this_virtual_stage, + i==0, + i>0 or len(target_disopt_state_dicts)==1) + + concat_parameter(new_parameters_dict, + param_part, + exp_avg_part, + exp_avg_sq_part, + slice(None, None)) + num_elements_in_vmodel += param_part.nelement() + + # embedding layer + if target_pp_rank == 0 and vidx == 0 and i == 0: + num_elements_head = 0 + for (k, v) 
in vmodel_dict.items(): + if k.startswith("embedding."): + if hasattr(v, "nelement") and not k.endswith("._extra_state"): + num_elements_head += v.nelement() + else: + break + logger.debug(f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] apply embedding, " + f"num_elements_head={num_elements_head}") + assert num_elements_head != 0, f"[pp_rank={target_pp_rank}][ep_rank={target_ep_rank}] num_elements_head is 0" + + concat_parameter(new_parameters_dict, + opt_parameters_dict["param"], + opt_parameters_dict["exp_avg"], + opt_parameters_dict["exp_avg_sq"], + slice(-num_elements_head, None)) + num_elements_in_vmodel += num_elements_head + + new_parameters_dict["numel_unpadded"] = num_elements_in_vmodel + + for k, v in new_parameters_dict.items(): + if torch.is_tensor(v): + # .clone.detach() is needed for removing unneeded data + # and decrease checkpoint file size + new_parameters_dict[k] = v.clone().detach() + + if num_elements_in_vmodel != 0: + vdisopts.append({type_key : new_parameters_dict}) + + logger.debug(f"i={i}, vidx={vidx}, num_elements_in_vmodel={num_elements_in_vmodel}") + + for vidx, vdisopt in enumerate(vdisopts): + target_disopt_state_dict[vidx] = vdisopt + + # save + target_folder_path = os.path.join(args.save_iteration_dir, get_folder_name(args, target_pp_rank, target_ep_rank)) + os.makedirs(target_folder_path, exist_ok = True) + target_file_path = os.path.join(target_folder_path, DISTRIB_OPTIM_FILENAME) + + if len(target_disopt_state_dicts) == 1: + target_disopt_state_dicts = target_disopt_state_dicts[0] + logger.info(f"saving distrib_optim to {target_file_path} ...") + torch.save(target_disopt_state_dicts, target_file_path) diff --git a/tools/checkpoint/pp_to_vpp/main.py b/tools/checkpoint/pp_to_vpp/main.py new file mode 100644 index 000000000..771407ad6 --- /dev/null +++ b/tools/checkpoint/pp_to_vpp/main.py @@ -0,0 +1,84 @@ +import os +import ast +import sys +import logging +import argparse +from parallel_convert import convert_checkpoint + 
def parse_arguments(argv=None):
    """Parse and validate the command line of the pp_to_vpp converter.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read the process command line, as before.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: on invalid arguments, including the case where only one of
            ``--target-first-virtual-pipeline-num-layers-split`` and
            ``--target-last-virtual-pipeline-num-layers-split`` is given.
    """
    parser = argparse.ArgumentParser(description="convert a non-virtual pipeline checkpoint to virtual pipeline checkpoint")
    parser.add_argument("--load-iteration-dir", type=str, required=True, help="iteration folder of source model checkpoint")
    parser.add_argument("--expert-model-parallel-size", type=int, required=True, help="ep_size of original model and the target model")
    parser.add_argument("--pipeline-model-parallel-size", type=int, required=True, help="physical pp_size of original model and the target model")

    # arguments of target model below
    parser.add_argument("--save-iteration-dir", type=str, required=True, help="iteration folder of target model checkpoint, need to be empty if existed")
    parser.add_argument("--target-virtual-pipeline-model-parallel-size", type=int, required=True, help="vpp_size of target model")
    parser.add_argument("--target-first-virtual-pipeline-num-layers-split", type=int, nargs="+", default=None,
                        help="only used in uneven pipeline mode, virtual pipeline split of the first stage")
    parser.add_argument("--target-last-virtual-pipeline-num-layers-split", type=int, nargs="+", default=None,
                        help="only used in uneven pipeline mode, virtual pipeline split of the last stage")

    # arguments of acceleration in parallel
    parser.add_argument("--num-max-processing-processes", type=int, default=8,
                        help="the maximum number of processing processes used by this script, " \
                            "increasing this value can speed up model conversion(but the final bottleneck may be disk bandwidth), it will also consume more CPU memory.")
    parser.add_argument('--pipeline-ranks-to-process', type=int, nargs="+", default=None,
                        help="pipeline rank list to process using this script, to accelerate converting \
                            user can launch multiple tasks on different nodes, each one process part of pipeline ranks. \
                            example : --pipeline-ranks-to-process 0 1 2 3 \
                            default is None, means process all pipeline ranks")

    args = parser.parse_args(argv)

    # The two split options describe one configuration (uneven pipeline mode),
    # so one without the other is rejected here rather than failing later.
    first_split = args.target_first_virtual_pipeline_num_layers_split
    last_split = args.target_last_virtual_pipeline_num_layers_split
    if (first_split is None) != (last_split is None):
        parser.error(
            "--target-first-virtual-pipeline-num-layers-split and "
            "--target-last-virtual-pipeline-num-layers-split must be provided together or omitted together")
    return args
+ +""" + +def main(args): + logger.info("args : {}".format(args)) + convert_checkpoint(args) + +if __name__ == "__main__": + logging.basicConfig( + level = logging.INFO, + format = "[%(asctime)s][%(levelname)s] %(message)s", + handlers = [logging.StreamHandler()]) + args = parse_arguments() + main(args) diff --git a/tools/checkpoint/pp_to_vpp/model_optim_rng.py b/tools/checkpoint/pp_to_vpp/model_optim_rng.py new file mode 100644 index 000000000..76bf88c69 --- /dev/null +++ b/tools/checkpoint/pp_to_vpp/model_optim_rng.py @@ -0,0 +1,157 @@ +import os +import logging +from collections import OrderedDict + +from utils import ( + RetainLogLevel, + log_and_exit, + get_folder_name, + get_vpp_source_position, + MODEL_OPTIM_RNG_FILENAME, + TargetCkptContext, + get_num_layers_for_this_vpp_stage, +) + +import torch + +logger = logging.getLogger(__name__) + +def _convert_state_dict_args(state_dict, ckpt_ctx): + state_dict_args = state_dict["args"] + if ckpt_ctx.uneven_mode: + state_dict_args.decoder_first_pipeline_num_layers_split = ckpt_ctx.first_vpp_layer_split + state_dict_args.decoder_last_pipeline_num_layers_split = ckpt_ctx.last_vpp_layer_split + + state_dict_args.num_virtual_stages_per_pipeline_rank = ckpt_ctx.vpp_size + state_dict_args.virtual_pipeline_model_parallel_size = ckpt_ctx.vpp_size + state_dict_args.overlap_p2p_comm = True + state_dict_args.align_param_gather = True + if hasattr(state_dict_args, "local_rank"): + delattr(state_dict_args, "local_rank") + if hasattr(state_dict_args, "rank"): + delattr(state_dict_args, "rank") + +def _convert_state_dict_optimizer(state_dict): + # optimizer state dict is equal + state_optimizer = state_dict["optimizer"] + optimizer_states = [state_optimizer] if not isinstance(state_optimizer, list) \ + else state_optimizer + try: + current_step = -1 + param_group_candidates = [] + for opt_state in optimizer_states: + for param_group in opt_state["optimizer"]["param_groups"]: + if "step" in param_group: + current_step = 
param_group["step"] + else: + param_group_candidates.append(param_group) + if current_step != -1: + for param_group in param_group_candidates: + param_group["step"] = current_step + logger.info(f"add step={current_step} in optimizer state") + except Exception: + logger.warning("add step to optimizer state failed") + +def _convert_state_dict_rng(): + # rng state is equal + pass + +def _fetch_model_state_dict(args, src_pp_rank, src_ep_rank, src_start_layer_idx, num_layers_for_this_virtual_stage, base_layer_idx): + src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, src_pp_rank, src_ep_rank)) + src_file_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) + + logger.debug(f"loading {src_file_path} to fetch source tensors in virtual stage...") + with RetainLogLevel(): + state_dict = torch.load(src_file_path, map_location="cpu", weights_only=False) + state_dict_model = state_dict["model"] + + layer_idx_added = set() + outputs = OrderedDict() + for k, v in state_dict_model.items(): + if not ".layers." in k: + continue + layer_idx = int(k.split(".layers.")[1].split(".")[0]) + if src_start_layer_idx <= layer_idx < src_start_layer_idx+num_layers_for_this_virtual_stage: + new_key = k.replace(f".layers.{layer_idx}", f".layers.{layer_idx - src_start_layer_idx + base_layer_idx}") + outputs[new_key] = v.clone().detach() if torch.is_tensor(v) else v + layer_idx_added.add(layer_idx) + + num_layers_remain = num_layers_for_this_virtual_stage - len(layer_idx_added) + if num_layers_remain > 0: + # The source tensor state_dict spans two pipeline stages + # and this recursion is executed at most once. 
+ assert base_layer_idx == 0 + next_level_outputs = _fetch_model_state_dict(args, + src_pp_rank+1, + src_ep_rank, + 0, + num_layers_remain, + len(layer_idx_added)) + outputs.update(next_level_outputs) + return outputs + +def _convert_state_dict_model(args, target_pp_rank, target_ep_rank, state_dict, ckpt_ctx): + state_dict_model = state_dict["model"] + vmodels = [OrderedDict() for i in range(ckpt_ctx.vpp_size)] + + for (vidx, vmodel) in enumerate(vmodels): + if target_pp_rank == 0 and vidx == 0: + for (k, v) in state_dict_model.items(): + if k.startswith("embedding."): + vmodel[k] = v.clone().detach() if torch.is_tensor(v) else v + + src_pp_rank, src_start_layer_idx = get_vpp_source_position( + target_pp_rank, + vidx, + ckpt_ctx) + + num_layers_for_this_virtual_stage = get_num_layers_for_this_vpp_stage(target_pp_rank, vidx, ckpt_ctx) + src_model_state_dict = _fetch_model_state_dict( + args, + src_pp_rank, + target_ep_rank, + src_start_layer_idx, + num_layers_for_this_virtual_stage, + 0) + vmodel.update(src_model_state_dict) + + if target_pp_rank == ckpt_ctx.pp_size-1 and vidx == ckpt_ctx.vpp_size-1: + for (k, v) in state_dict_model.items(): + if "final_layernorm." 
in k or k.startswith("output_layer."): + vmodel[k] = v.clone().detach() if torch.is_tensor(v) else v + + for i in range(ckpt_ctx.vpp_size): + state_dict[f"model{i}"] = vmodels[i] + + del state_dict["model"] + +def convert_model_optim_rng(args, target_pp_rank, target_ep_rank): + src_folder_path = os.path.join(args.load_iteration_dir, get_folder_name(args, target_pp_rank, target_ep_rank)) + src_file_path = os.path.join(src_folder_path, MODEL_OPTIM_RNG_FILENAME) + + logger.info(f"loading model_optim_rng from {src_file_path} ...") + with RetainLogLevel(): + target_state_dict = torch.load(src_file_path, map_location="cpu", weights_only=False) + + if target_pp_rank==0 and target_ep_rank<=1: + logger.info("[pp_rank=0][ep_rank={}] keys of model state dict : {}\n".format(target_ep_rank, target_state_dict.keys())) + + ckpt_ctx = TargetCkptContext(args, target_state_dict) + + _convert_state_dict_args(target_state_dict, ckpt_ctx) + + _convert_state_dict_optimizer(target_state_dict) + + _convert_state_dict_rng() + + _convert_state_dict_model(args, target_pp_rank, target_ep_rank, target_state_dict, ckpt_ctx) + + # save + target_folder_path = os.path.join(args.save_iteration_dir, get_folder_name(args, target_pp_rank, target_ep_rank)) + os.makedirs(target_folder_path, exist_ok = True) + target_file_path = os.path.join(target_folder_path, MODEL_OPTIM_RNG_FILENAME) + + logger.info(f"saving model_optim_rng to {target_file_path} ...") + torch.save(target_state_dict, target_file_path) + + return (target_state_dict, ckpt_ctx) diff --git a/tools/checkpoint/pp_to_vpp/parallel_convert.py b/tools/checkpoint/pp_to_vpp/parallel_convert.py new file mode 100644 index 000000000..4278013a9 --- /dev/null +++ b/tools/checkpoint/pp_to_vpp/parallel_convert.py @@ -0,0 +1,49 @@ +import os +import logging +import itertools +import multiprocessing + +from utils import log_and_exit +from model_optim_rng import convert_model_optim_rng +from distrib_optim import convert_distrib_optim + +logger = 
logging.getLogger(__name__) + +def _check_output_folder(args): + output_folder = args.save_iteration_dir + if os.path.exists(output_folder): + if not os.path.isdir(output_folder): + log_and_exit(f"output path {output_folder} exists but is not a directory") + else: + os.makedirs(output_folder) + +def _convert_checkpoint_partial(args, target_pp_rank, target_ep_rank): + logger.debug(f"start _convert_checkpoint_partial, pp_rank={target_pp_rank}, ep_rank={target_ep_rank}") + target_model_state_dict, ckpt_ctx = convert_model_optim_rng(args, target_pp_rank, target_ep_rank) + convert_distrib_optim(args, target_pp_rank, target_ep_rank, target_model_state_dict, ckpt_ctx) + +def _func_arguments_wrapper(func_arguments): + args, pp_rank, ep_rank = func_arguments + logger.info("pp_rank={}, ep_rank={} ............".format(pp_rank, ep_rank)) + _convert_checkpoint_partial(args, pp_rank, ep_rank) + +def convert_checkpoint(args): + + _check_output_folder(args) + + if args.pipeline_ranks_to_process is None: + args.pipeline_ranks_to_process = range(args.pipeline_model_parallel_size) + + pp_ranges = args.pipeline_ranks_to_process + ep_ranges = range(args.expert_model_parallel_size) + + func_arguments_tuples = [(args, x, y) for x, y in itertools.product(pp_ranges, ep_ranges)] + logger.info("pp_ranges : {}".format(pp_ranges)) + logger.info("ep_ranges : {}".format(ep_ranges)) + + logger.info(f"start convert with {args.num_max_processing_processes} processes...") + + with multiprocessing.Pool(processes = args.num_max_processing_processes) as pool: + pool.map(_func_arguments_wrapper, func_arguments_tuples) + + logger.info("convert finished") diff --git a/tools/checkpoint/pp_to_vpp/utils.py b/tools/checkpoint/pp_to_vpp/utils.py new file mode 100644 index 000000000..1be645013 --- /dev/null +++ b/tools/checkpoint/pp_to_vpp/utils.py @@ -0,0 +1,152 @@ +import os +import sys +import logging + +logger = logging.getLogger(__name__) + +MODEL_OPTIM_RNG_FILENAME = "model_optim_rng.pt" 
DISTRIB_OPTIM_FILENAME = "distrib_optim.pt"


# torch.load sometimes changes the log level when called with weights_only=False,
# so log messages emitted after torch.load would not show on the terminal.
class RetainLogLevel:
    """Context manager that restores the root logger's level on exit."""

    def __enter__(self):
        self.origin_log_level = logging.getLogger().level
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        logging.getLogger().setLevel(self.origin_log_level)


def log_and_exit(message):
    """Log ``message`` at fatal level, then abort by raising ``Exception``."""
    logger.fatal(message)
    raise Exception("exit with fatal error")


def get_folder_name(args, target_pp_rank, target_ep_rank):
    """Return the per-rank checkpoint folder name.

    The name starts with ``mp_rank_00``; a 3-digit pipeline rank suffix is
    appended when ``args.pipeline_model_parallel_size != 1`` and a 3-digit
    expert rank suffix when ``args.expert_model_parallel_size != 1``.
    """
    folder_name = "mp_rank_00"
    if args.pipeline_model_parallel_size != 1:
        folder_name += f"_{target_pp_rank:03d}"
    if args.expert_model_parallel_size != 1:
        folder_name += f"_{target_ep_rank:03d}"
    return folder_name


def get_vpp_source_position(
        target_pp_rank,
        target_virtual_idx,
        ckpt_ctx):
    """Locate where a target virtual stage starts in the source (non-vpp) checkpoint.

    Computes the global index of the first layer of virtual stage
    ``target_virtual_idx`` on ``target_pp_rank``, then walks the source
    pipeline stages to find the one containing that layer.

    Returns:
        tuple: ``(source_pp_rank, source_start_layer_idx)`` where the second
        item is the layer index local to that source pipeline stage.

    Raises:
        Exception: via ``log_and_exit`` when no source stage contains the layer.
    """
    num_middle_stages = ckpt_ctx.pp_size - 2  # should be non-negative
    num_layers_per_middle_virtual_stage = (ckpt_ctx.num_middle_layers // (num_middle_stages * ckpt_ctx.vpp_size)) \
        if num_middle_stages > 0 else 0

    # Layers held by all earlier virtual stages, across every pipeline rank.
    target_global_layer_idx = target_virtual_idx * num_middle_stages * num_layers_per_middle_virtual_stage + \
        sum(ckpt_ctx.first_vpp_layer_split[:target_virtual_idx]) + sum(ckpt_ctx.last_vpp_layer_split[:target_virtual_idx])

    # Within this virtual stage, skip the first rank's chunk and the middle ranks before us.
    if target_pp_rank != 0:
        target_global_layer_idx += (target_pp_rank - 1) * num_layers_per_middle_virtual_stage + \
            ckpt_ctx.first_vpp_layer_split[target_virtual_idx]

    prefix_sum = 0
    for pp_stage in range(ckpt_ctx.pp_size):
        if pp_stage == 0:
            layers_in_current_stage = sum(ckpt_ctx.first_vpp_layer_split)
        elif pp_stage == ckpt_ctx.pp_size - 1:
            layers_in_current_stage = sum(ckpt_ctx.last_vpp_layer_split)
        else:
            layers_in_current_stage = num_layers_per_middle_virtual_stage * ckpt_ctx.vpp_size
        if prefix_sum <= target_global_layer_idx < (prefix_sum + layers_in_current_stage):
            source_start_layer_idx = target_global_layer_idx - prefix_sum
            logger.info(f"get_vpp_source_position, target_pp_rank={target_pp_rank}, target_virtual_idx={target_virtual_idx}; "
                        f"source_pp_rank={pp_stage}, source_start_layer_idx={source_start_layer_idx}")
            return (pp_stage, source_start_layer_idx)
        prefix_sum += layers_in_current_stage

    log_and_exit("double check failed, should never reach here")


def get_num_layers_for_this_vpp_stage(pp_rank, vpp_rank, ckpt_ctx):
    """Return how many layers virtual stage ``vpp_rank`` of ``pp_rank`` holds.

    In even mode every virtual stage holds
    ``num_layers // (pp_size * vpp_size)`` layers. In uneven mode the first and
    last pipeline ranks use their explicit splits and the middle ranks share
    ``num_middle_layers`` evenly.
    """
    if not ckpt_ctx.uneven_mode:
        return ckpt_ctx.num_layers // (ckpt_ctx.pp_size * ckpt_ctx.vpp_size)
    if pp_rank == 0:
        return ckpt_ctx.first_vpp_layer_split[vpp_rank]
    if pp_rank == ckpt_ctx.pp_size - 1:
        return ckpt_ctx.last_vpp_layer_split[vpp_rank]

    num_middle_stages = ckpt_ctx.pp_size - 2
    return ckpt_ctx.num_middle_layers // (num_middle_stages * ckpt_ctx.vpp_size)


class TargetCkptContext:
    """Validated description of the target (virtual pipeline) checkpoint layout.

    Built from the command-line ``args`` and the ``args`` object stored in a
    loaded checkpoint ``state_dict``. Every validation failure goes through
    ``log_and_exit`` and therefore raises ``Exception``.

    Attributes set here: ``vpp_size``, ``pp_size``, ``ep_size``, ``num_layers``,
    ``uneven_mode``, ``first_vpp_layer_split``, ``last_vpp_layer_split`` and
    ``num_middle_layers``.
    """

    def __init__(self, args, state_dict):
        self.vpp_size = args.target_virtual_pipeline_model_parallel_size
        self.pp_size = args.pipeline_model_parallel_size
        self.ep_size = args.expert_model_parallel_size
        self.first_vpp_layer_split = args.target_first_virtual_pipeline_num_layers_split
        self.last_vpp_layer_split = args.target_last_virtual_pipeline_num_layers_split

        state_dict_args = state_dict["args"]
        self.num_layers = state_dict_args.num_layers

        if state_dict_args.tensor_model_parallel_size != 1:
            log_and_exit("currently only tensor_model_parallel_size=1 is supported, but found {} in checkpoint".format(
                state_dict_args.tensor_model_parallel_size))

        if self.vpp_size <= 1:
            log_and_exit(f"target_virtual_pipeline_model_parallel_size {self.vpp_size} is smaller or equal to 1")

        if self.ep_size != state_dict_args.expert_model_parallel_size:
            log_and_exit("expert_model_parallel_size in args does not match the one in checkpoint, {} vs {}".format(
                self.ep_size, state_dict_args.expert_model_parallel_size))

        if self.pp_size != state_dict_args.pipeline_model_parallel_size:
            log_and_exit("pipeline_model_parallel_size in args does not match the one in checkpoint, {} vs {}".format(
                self.pp_size, state_dict_args.pipeline_model_parallel_size))

        if self.first_vpp_layer_split or self.last_vpp_layer_split:
            self.uneven_mode = True

            if not (self.first_vpp_layer_split and self.last_vpp_layer_split):
                log_and_exit("target_first_virtual_pipeline_num_layers_split and target_last_virtual_pipeline_num_layers_split "
                             "should be set at the same time for uneven pipeline mode")

            if state_dict_args.decoder_first_pipeline_num_layers != sum(self.first_vpp_layer_split) or \
                    state_dict_args.decoder_last_pipeline_num_layers != sum(self.last_vpp_layer_split):
                log_and_exit("uneven layer number does not match arguments in state_dict :"
                             "decoder_first_pipeline_num_layers={}, decoder_last_pipeline_num_layers={}".format(
                                 state_dict_args.decoder_first_pipeline_num_layers, state_dict_args.decoder_last_pipeline_num_layers))

            num_middle_pipeline_stages = self.pp_size - 2
            if num_middle_pipeline_stages < 0:
                log_and_exit("pipeline_model_parallel_size is too small for uneven mode, pipeline_model_parallel_size={}".format(
                    self.pp_size))

            # Both splits must list exactly one entry per virtual stage. The
            # second operand must test the *last* split; testing the first one
            # twice would let a wrong-length last split through.
            if len(self.first_vpp_layer_split) != self.vpp_size or \
                    len(self.last_vpp_layer_split) != self.vpp_size:
                log_and_exit("length of target_first_virtual_pipeline_num_layers_split and target_last_virtual_pipeline_num_layers_split should "
                             "equal to target_virtual_pipeline_model_parallel_size")

            self.num_middle_layers = self.num_layers - sum(self.first_vpp_layer_split) \
                - sum(self.last_vpp_layer_split)
            if num_middle_pipeline_stages > 0:
                if self.num_middle_layers <= 0:
                    log_and_exit("num_middle_layers can not be non-positve, "
                                 "num_middle_layers={}, num_middle_pipeline_stages={}".format(self.num_middle_layers, num_middle_pipeline_stages))
                if self.num_middle_layers % (num_middle_pipeline_stages * self.vpp_size) != 0:
                    log_and_exit("num_middle_layers can not be evenly divided by "
                                 "num_middle_pipeline_stages*target_virtual_pipeline_model_parallel_size, "
                                 "num_middle_layers={}, num_middle_pipeline_stages={}".format(self.num_middle_layers, num_middle_pipeline_stages))
            elif self.num_middle_layers > 0:
                log_and_exit("insufficient num_middle_pipeline_stages, "
                             "num_middle_layers={}, num_middle_pipeline_stages={}".format(self.num_middle_layers, num_middle_pipeline_stages))
        else:
            self.uneven_mode = False
            if self.num_layers % (self.pp_size * self.vpp_size) != 0:
                log_and_exit("for even pipeline mode, num_layers can not be evenly divided "
                             "pipeline_model_parallel_size*target_virtual_pipeline_model_parallel_size, "
                             "num_layers={}, pipeline_model_parallel_size={}, target_virtual_pipeline_model_parallel_size={}".format(
                                 self.num_layers, self.pp_size, self.vpp_size))

            num_layers_per_pp = self.num_layers // self.pp_size
            self.num_middle_layers = self.num_layers - num_layers_per_pp * 2

            # Even mode: synthesize uniform splits so later code can treat both modes alike.
            num_layers_per_vpp_stage = self.num_layers // (self.pp_size * self.vpp_size)
            self.first_vpp_layer_split = [num_layers_per_vpp_stage for _ in range(self.vpp_size)]
            self.last_vpp_layer_split = [num_layers_per_vpp_stage for _ in range(self.vpp_size)]