From 307e4456286a2bd5fa3db1f726a4c0aac8302e20 Mon Sep 17 00:00:00 2001 From: kaiyuan Date: Fri, 5 Jun 2026 09:02:02 +0800 Subject: [PATCH] feat: shard-level P2P weight transfer with persistent NCCL groups - Each DP=0 TP rank converts its own shard via convert_to_hf_shard and sends directly to matching vLLM inference TP rank via dedicated NCCL group - NCCL groups are persistent: created once, reused across update_weights calls, only disconnected at shutdown (fixes intermittent reconnection hang) - Group naming: vime-pp_{pp_rank}_tp{tp_rank} with target_tp_rank filtering - Fallback to broadcast when training TP != inference TP - vllm_engine.py: add shard_rank param and conditional group_name forwarding --- docker/Dockerfile | 6 + docker/patch/latest/vllm.patch | 108 ++++ .../megatron_utils/megatron_to_hf/__init__.py | 22 +- .../megatron_utils/megatron_to_hf/qwen2.py | 98 ++++ .../update_weight_from_distributed.py | 542 +++++++++++------- vime/backends/vllm_utils/vllm_engine.py | 14 +- 6 files changed, 585 insertions(+), 205 deletions(-) create mode 100644 docker/patch/latest/vllm.patch diff --git a/docker/Dockerfile b/docker/Dockerfile index d145b6a7..b15875e7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -130,6 +130,12 @@ RUN cd Megatron-LM && \ rm megatron.patch && \ pip install -e . 
+COPY docker/patch/${PATCH_VERSION}/vllm.patch /tmp/ +RUN VLLM_SITE=$(python3 -c "import vllm; print(vllm.__file__.rsplit('/',1)[0])") && \ + cd ${VLLM_SITE} && \ + patch -p2 --no-backup-if-mismatch < /tmp/vllm.patch && \ + rm /tmp/vllm.patch + # ====================================== Install main package ============================================ ARG VIME_COMMIT=main diff --git a/docker/patch/latest/vllm.patch b/docker/patch/latest/vllm.patch new file mode 100644 index 00000000..e3edd881 --- /dev/null +++ b/docker/patch/latest/vllm.patch @@ -0,0 +1,108 @@ +diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py +index 1111111..2222222 100644 +--- a/vllm/distributed/weight_transfer/nccl_engine.py ++++ b/vllm/distributed/weight_transfer/nccl_engine.py +@@ -1,5 +1,7 @@ + from __future__ import annotations + ++import logging ++ + from typing import TYPE_CHECKING, Any, Optional + + import torch +@@ -20,6 +22,7 @@ + if TYPE_CHECKING: + from vllm.config import VllmConfig + ++logger = logging.getLogger(__name__) + + + @dataclasses.dataclass +@@ -37,6 +40,9 @@ class NCCLWeightTransferInitInfo: + world_size: int + rank_offset: int = 0 + group_name: Optional[str] = None ++ target_tp_rank: Optional[int] = None ++ shard_rank: Optional[int] = None ++ group_name_for_engine: Optional[str] = None + + + @dataclasses.dataclass +@@ -147,7 +153,9 @@ class NCCLWeightTransferEngine: + model_update_groups = {} + for group_name, init_info in init_infos.items(): + assert isinstance(init_info, NCCLWeightTransferInitInfo), init_info +- model_update_group = cls._stateless_init_process_group(init_info) ++ model_update_group = cls._stateless_init_process_group_v2( ++ init_info, vllm_config=vllm_config ++ ) + model_update_groups[group_name] = model_update_group + return model_update_groups + +@@ -170,10 +178,16 @@ class NCCLWeightTransferEngine: + return model_update_groups + + @classmethod +- def _stateless_init_process_group( ++ def
_stateless_init_process_group( # type: ignore[override] + cls, init_info: NCCLWeightTransferInitInfo + ) -> PyNcclCommunicator: ++ return cls._stateless_init_process_group_v2(init_info) ++ ++ @classmethod ++ def _stateless_init_process_group_v2( ++ cls, init_info: NCCLWeightTransferInitInfo, *, vllm_config=None ++ ) -> PyNcclCommunicator: + from vllm.distributed.parallel_state import ( + in_the_same_tp_group_as_pp, + ) +@@ -185,6 +199,23 @@ class NCCLWeightTransferEngine: + assert in_the_same_tp_group_as_pp(), ( + "Weight transfer across PP groups is not supported." + ) ++ ++ target_tp_rank = getattr(init_info, "target_tp_rank", None) ++ if target_tp_rank is not None and vllm_config is not None: ++ tp_rank = get_tp_group().rank_in_group ++ if tp_rank != target_tp_rank: ++ logger.debug( ++ "Skipping NCCL group init for group_name=%s: " ++ "tp_rank=%d != target_tp_rank=%d", ++ init_info.group_name, tp_rank, target_tp_rank, ++ ) ++ return None ++ logger.debug( ++ "Creating NCCL group for group_name=%s: tp_rank=%d matches target_tp_rank=%d", ++ init_info.group_name, tp_rank, target_tp_rank, ++ ) ++ + from vllm.distributed.device_communicators.stateless_comm import ( + StatelessProcessGroup, + ) +@@ -247,7 +278,12 @@ class NCCLWeightTransferEngine: + assert group_name in model_update_groups, ( + f"Group {group_name} not found in model_update_groups" + ) +- group = model_update_groups[group_name] ++ group = model_update_groups.get(group_name) ++ if group is None: ++ logger.debug( ++ "receive_weights: skipping group_name=%s (not active for this rank)", ++ group_name, ++ ) ++ continue + if packed: + cls._packed_nccl_broadcast(group, update_info) + else: +diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py +index 3333333..4444444 100644 +--- a/vllm/v1/worker/gpu_worker.py ++++ b/vllm/v1/worker/gpu_worker.py +@@ -1,5 +1,5 @@ + # ... existing imports ... 
+ + def load_weights_direct(self, named_weights): + for name, weight in named_weights: + param = self.model.get_parameter(name) +- param.copy_(weight) ++ param.data.copy_(weight) diff --git a/vime/backends/megatron_utils/megatron_to_hf/__init__.py b/vime/backends/megatron_utils/megatron_to_hf/__init__.py index b51cb1b5..961cde13 100644 --- a/vime/backends/megatron_utils/megatron_to_hf/__init__.py +++ b/vime/backends/megatron_utils/megatron_to_hf/__init__.py @@ -5,7 +5,7 @@ from .llama import convert_llama_to_hf from .mimo import convert_mimo_to_hf from .processors import quantize_params, remove_padding -from .qwen2 import convert_qwen2_to_hf +from .qwen2 import convert_qwen2_to_hf, convert_qwen2_to_hf_shard from .qwen3_5 import convert_qwen3_5_to_hf from .qwen3_next import convert_qwen3_next_to_hf from .qwen3_vl import convert_qwen3vl_to_hf @@ -28,6 +28,20 @@ def convert_to_hf(args, model_name, name, param, quantization_config=None): return quantize_params(args, name, converted_named_tensors, quantization_config) +def remove_padding_shard(name, param, vocab_size, tp_rank, tp_size): + """Shard-aware remove_padding. Embedding/output_layer are handled separately.""" + return param + + +def convert_to_hf_shard(args, model_name, name, param, tp_rank, tp_size, quantization_config=None): + """Shard-level HF conversion without all_gather. 
Each TP rank converts its own shard.""" + param = remove_padding_shard(name, param, args.vocab_size, tp_rank, tp_size) + + converted_named_tensors = _convert_to_hf_shard_core(args, model_name, name, param, tp_rank, tp_size) + + return quantize_params(args, name, converted_named_tensors, quantization_config) + + # TODO optimize code details def _convert_to_hf_core(args, model_name, name, param): if "glm4moelite" in model_name or "deepseekv3" in model_name: @@ -56,3 +70,9 @@ def _convert_to_hf_core(args, model_name, name, param): raise ValueError(f"Unsupported model: {model_name}") return converted_named_tensors + + +def _convert_to_hf_shard_core(args, model_name, name, param, tp_rank, tp_size): + if "qwen2" in model_name or "qwen3" in model_name: + return convert_qwen2_to_hf_shard(args, name, param, tp_rank, tp_size) + raise ValueError(f"Shard-level conversion not yet supported for model: {model_name}") diff --git a/vime/backends/megatron_utils/megatron_to_hf/qwen2.py b/vime/backends/megatron_utils/megatron_to_hf/qwen2.py index f7b72935..cffe1e63 100644 --- a/vime/backends/megatron_utils/megatron_to_hf/qwen2.py +++ b/vime/backends/megatron_utils/megatron_to_hf/qwen2.py @@ -69,3 +69,101 @@ def convert_qwen2_to_hf(args, name, param): return [(f"model.layers.{layer_idx}.self_attn.k_norm.weight", param)] raise ValueError(f"Unknown parameter name: {name}") + + +def convert_qwen2_to_hf_shard(args, name, param, tp_rank, tp_size): + """Shard-level HF conversion: operates on a single TP shard without all_gather. + + For Qwen2/3 with GQA, Megatron shards by query groups, which maps directly + to vLLM's head-based sharding after QKV split. Each TP rank converts its + own shard and sends to the corresponding vLLM TP rank. + + Args: + args: Model config (num_attention_heads, num_query_groups, etc.) 
+ name: Megatron parameter name + param: TP-sharded parameter (this rank's shard only) + tp_rank: Current tensor model parallel rank + tp_size: Tensor model parallel world size + + Returns: + List of (hf_name, shard_tensor) tuples. Duplicated (replicated) params + are returned unchanged on every TP rank; no rank returns an empty list. + """ + try: + head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads + except AttributeError: + head_dim = args.hidden_size // args.num_attention_heads + value_num_per_group = args.num_attention_heads // args.num_query_groups + + # Duplicated params: every TP rank sends these (each group needs them) + if name == "module.module.decoder.final_layernorm.weight": + return [("model.norm.weight", param)] + + decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)" + match = re.match(decoder_layers_pattern, name) + if match: + layer_idx, rest = match.groups() + + # Duplicated params: layernorms, qk norm - every rank sends in its group + _duplicated_map = { + "self_attention.linear_qkv.layer_norm_weight": "input_layernorm", + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm", + "self_attention.q_layernorm.weight": "self_attn.q_norm", + "self_attention.k_layernorm.weight": "self_attn.k_norm", + } + if rest in _duplicated_map: + hf_name = f"model.layers.{layer_idx}.{_duplicated_map[rest]}.weight" + return [(hf_name, param)] + + # TP-sharded params: each rank converts its own shard + if rest == "self_attention.linear_qkv.weight": + groups_per_rank = args.num_query_groups // tp_size + param = param.view(groups_per_rank, -1, head_dim, args.hidden_size) + q_param, k_param, v_param = torch.split( + param, split_size_or_sections=[value_num_per_group, 1, 1], dim=1 + ) + q_param = q_param.reshape(-1, args.hidden_size) + k_param = k_param.reshape(-1, args.hidden_size) + v_param = v_param.reshape(-1, args.hidden_size) + # For is_checkpoint_format=False: produce combined qkv_proj matching
vLLM layout + qkv_param = torch.cat([q_param, k_param, v_param], dim=0) + return [ + (f"model.layers.{layer_idx}.self_attn.qkv_proj.weight", qkv_param), + ] + + if rest == "self_attention.linear_qkv.bias": + groups_per_rank = args.num_query_groups // tp_size + param = param.view(groups_per_rank, -1) + q_bias, k_bias, v_bias = torch.split( + param, + split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim], + dim=1, + ) + q_bias = q_bias.contiguous().flatten() + k_bias = k_bias.contiguous().flatten() + v_bias = v_bias.contiguous().flatten() + # Fuse like qkv_proj.weight above so weight and bias target the same combined vLLM param (is_checkpoint_format=False) + qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) + return [ + (f"model.layers.{layer_idx}.self_attn.qkv_proj.bias", qkv_bias), + ] + + if rest == "self_attention.linear_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)] + + if rest == "mlp.linear_fc1.weight": + # For is_checkpoint_format=False: produce combined gate_up_proj matching vLLM layout + return [ + (f"model.layers.{layer_idx}.mlp.gate_up_proj.weight", param), + ] + + if rest == "mlp.linear_fc2.weight": + return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)] + + # Embedding and output layer: TP-sharded along vocab dim + if name == "module.module.embedding.word_embeddings.weight": + return [("model.embed_tokens.weight", param)] + if name == "module.module.output_layer.weight": + return [("lm_head.weight", param)] + + raise ValueError(f"Unknown parameter name: {name}") diff --git a/vime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/vime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 66cd0f38..cc1fe45f 100644 --- a/vime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/vime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -6,7 +6,6 @@ import time from argparse import Namespace from collections.abc import Callable, Mapping,
Sequence -from typing import Any import ray import torch @@ -19,17 +18,16 @@ from vime.utils.distributed_utils import get_gloo_group -from ..megatron_to_hf import convert_to_hf +from ..megatron_to_hf import convert_to_hf, convert_to_hf_shard from .common import all_gather_param, named_params_and_buffers -from .hf_weight_iterator_base import HfWeightIteratorBase logger = logging.getLogger(__name__) -def _begin_vllm_weight_update_session(rollout_engines: Sequence[ActorHandle]) -> None: +def _begin_vllm_weight_update_session(rollout_engines: Sequence[ActorHandle], is_checkpoint_format: bool = True) -> None: if dist.get_rank() == 0: - logger.info("vLLM weight update: start_weight_update") - ray.get([engine.start_weight_update.remote(is_checkpoint_format=True) for engine in rollout_engines]) + logger.info("vLLM weight update: start_weight_update (checkpoint_format=%s)", is_checkpoint_format) + ray.get([engine.start_weight_update.remote(is_checkpoint_format=is_checkpoint_format) for engine in rollout_engines]) dist.barrier(group=get_gloo_group()) @@ -41,9 +39,14 @@ def _end_vllm_weight_update_session(rollout_engines: Sequence[ActorHandle]) -> N class UpdateWeightFromDistributed: - """ - Update distributed engines via NCCL. Each PP rank: group "vime-pp_{pp_rank}", - only DP=TP=0 broadcasts. Non-expert (TP) and expert (EP) params separate. + """Shard-level P2P weight transfer without all_gather. + + Each DP=0 TP rank converts only its own shard via convert_to_hf_shard + and broadcasts it to the corresponding vLLM inference ranks via its own + NCCL group. This eliminates the all_gather memory bottleneck. + + Optimized: each TP rank's NCCL group includes only matching vLLM workers, + eliminating wasted bandwidth from inactive broadcast participants. """ def __init__( @@ -55,26 +58,12 @@ def __init__( model_name: str, quantization_config: dict[str, int | str | list[str]] | None, ) -> None: - """ - Initialize. Groups created in connect_rollout_engines. 
- """ self.args = args self.model = model - self.weights_getter = weights_getter self.model_name = model_name self.quantization_config = quantization_config self.weight_version = 0 self._model_update_groups = None - self._hf_weight_iterator = ( - HfWeightIteratorBase.create( - args=args, - model=model, - model_name=model_name, - quantization_config=quantization_config, - ) - if args.megatron_to_hf_mode == "bridge" - else None - ) def connect_rollout_engines( self, @@ -83,37 +72,47 @@ def connect_rollout_engines( engine_gpu_counts: Sequence[int] | None = None, engine_gpu_offsets: Sequence[int] | None = None, ) -> None: - """ - Create NCCL "vime-pp_{pp_rank}" if PP source (DP=TP=0). Lock prevents concurrent broadcasts. - """ self.rollout_engines = rollout_engines self.rollout_engine_lock = rollout_engine_lock self._engine_gpu_counts = engine_gpu_counts - # For TP: - # 1. AllGather parameters to rank 0 - # 2. Broadcast parameters from rank 0 to all vLLM engines - self._is_pp_src_rank = ( - mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0 - ) + dp_rank = mpu.get_data_parallel_rank(with_context_parallel=True) + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() pp_rank = mpu.get_pipeline_model_parallel_rank() - if self._is_pp_src_rank: - self._group_name = f"vime-pp_{pp_rank}" - if self._is_pp_src_rank: + self._is_pp_src_rank = dp_rank == 0 and tp_rank == 0 + self._is_dp0 = dp_rank == 0 + self._tp_rank = tp_rank + self._tp_size = tp_size + self._pp_rank = pp_rank + + if self._is_dp0: + self._group_name = f"vime-pp_{pp_rank}_tp{tp_rank}" if self._model_update_groups is not None: - disconnect_rollout_engines_from_distributed( - self.args, self._group_name, self._model_update_groups, self.rollout_engines + logger.info("NCCL group %s already connected, skipping reconnection", self._group_name) + return + tp_rank_for_group = self._tp_rank if 
self._use_shard_conversion() else None + while not ray.get(self.rollout_engine_lock.acquire.remote()): + time.sleep(0.1) + try: + self._model_update_groups = connect_rollout_engines_from_distributed( + self.args, + self._group_name, + rollout_engines, + engine_gpu_counts=engine_gpu_counts, + target_tp_rank=tp_rank_for_group, ) - self._model_update_groups = connect_rollout_engines_from_distributed( - self.args, - self._group_name, - rollout_engines, - engine_gpu_counts=engine_gpu_counts, - ) + finally: + ray.get(self.rollout_engine_lock.release.remote()) def disconnect_rollout_engines(self) -> None: - if not getattr(self, "_is_pp_src_rank", False) or self._model_update_groups is None: + if not getattr(self, "_is_dp0", False) or self._model_update_groups is None: + return + logger.info("NCCL group %s kept alive (persistent connection)", self._group_name) + + def shutdown_rollout_engines(self) -> None: + if not getattr(self, "_is_dp0", False) or self._model_update_groups is None: return disconnect_rollout_engines_from_distributed( self.args, self._group_name, self._model_update_groups, self.rollout_engines @@ -122,16 +121,12 @@ def disconnect_rollout_engines(self) -> None: @torch.no_grad() def update_weights(self) -> None: - """ - Pause → flush → non-expert (TP) → expert (EP) → continue. Progress on PP source. 
- """ self.weight_version += 1 if dist.get_rank() == 0: ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) - # int4/fp4 pre_process if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: post_process_weights( restore_weights_before_load=True, @@ -140,7 +135,10 @@ def update_weights(self) -> None: ) dist.barrier(group=get_gloo_group()) - _begin_vllm_weight_update_session(self.rollout_engines) + use_shard = self._use_shard_conversion() + is_checkpoint_format = not use_shard + + _begin_vllm_weight_update_session(self.rollout_engines, is_checkpoint_format=is_checkpoint_format) try: self._sync_weights_to_rollout_engines() finally: @@ -148,7 +146,6 @@ def update_weights(self) -> None: dist.barrier(group=get_gloo_group()) if dist.get_rank() == 0: - # int4/fp4 post_process if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: post_process_weights( restore_weights_before_load=False, @@ -159,107 +156,185 @@ def update_weights(self) -> None: dist.barrier(group=get_gloo_group()) def _sync_weights_to_rollout_engines(self) -> None: - if self._hf_weight_iterator is not None: - self._sync_bridge_weights_to_rollout_engines() - return - use_vllm_packed = self._use_vllm_packed() - if use_vllm_packed and self._is_pp_src_rank: - logger.info("Using vLLM packed weight sync (bucketed; metadata + trainer_send_weights per bucket)") + use_shard = self._use_shard_conversion() - if use_vllm_packed: - buffer_size = 0 + if use_shard and use_vllm_packed and self._is_dp0: + logger.info("Using shard-level P2P weight sync (no all_gather)") + + if use_shard and use_vllm_packed: + self._sync_weights_shard_packed() + elif use_vllm_packed: + self._sync_weights_full_packed() + else: + self._sync_weights_full_nonpacked() + + if self._is_dp0: + torch.cuda.synchronize() + + def _use_shard_conversion(self) -> bool: + 
if self.quantization_config and self.quantization_config.get("quant_method") == "compressed-tensors": + return False + if any(".experts." in name for name, _ in named_params_and_buffers(self.args, self.model)): + return False + if self._engine_gpu_counts and any(c != self._tp_size for c in self._engine_gpu_counts): + return False + return True + + def _sync_weights_shard_packed(self) -> None: + """Shard-level P2P: each TP rank converts its own shard without all_gather. + + For embedding/output_layer, the Megatron shard layout (padded vocab) doesn't + align with vLLM's shard layout (unpadded vocab). These are handled with a + small all_gather + remove_padding + split, which is cheap since they're small. + """ + from vime.backends.megatron_utils.megatron_to_hf import remove_padding + from vime.backends.megatron_utils.misc_utils import strip_param_name_prefix + + if self._is_dp0: converted_named_tensors: list[tuple[str, torch.Tensor]] = [] pbar = ( - tqdm(desc=f"[{self._group_name}] Update weights (vLLM packed)", total=0) + tqdm(desc=f"[{self._group_name}] Shard P2P update", total=0) if self._is_pp_src_rank else None ) + buffer_size = 0 + + # Phase 1: Handle embedding/output_layer with all_gather (small params) for name, param in named_params_and_buffers(self.args, self.model): if ".experts." 
in name: continue - buffer_size = self._update_weight_from_distributed( - name, - param, - converted_named_tensors, - buffer_size, - pbar=pbar, - flush_packed=True, + stripped = strip_param_name_prefix(name) + if stripped not in {"embedding.word_embeddings.weight", "output_layer.weight"}: + continue + full_param = all_gather_param(name, param) + full_param = remove_padding(name, full_param, self.args.vocab_size) + converted = convert_to_hf( + self.args, self.model_name, name, full_param, + self.quantization_config, ) - if converted_named_tensors and self._is_pp_src_rank: - self._update_weights_vllm_packed(converted_named_tensors) - if pbar is not None: - pbar.update(1) - else: - buffer_size = 0 - converted_named_tensors = [] - pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None - + if not converted: + continue + for hf_name, hf_param in converted: + if hf_param.shape[0] % self._tp_size != 0: + logger.warning("Cannot split %s (shape %s) by tp_size=%d, skipping shard split", + hf_name, hf_param.shape, self._tp_size) + continue + shard_size = hf_param.shape[0] // self._tp_size + my_shard = hf_param[self._tp_rank * shard_size:(self._tp_rank + 1) * shard_size] + param_size = my_shard.numel() * my_shard.element_size() + if buffer_size + param_size > self.args.update_weight_buffer_size: + if converted_named_tensors: + self._update_weights_shard_packed(converted_named_tensors) + converted_named_tensors = [] + if pbar is not None: + pbar.update(1) + buffer_size = 0 + converted_named_tensors.append((hf_name, my_shard)) + buffer_size += param_size + + # Phase 2: Handle all other params with shard-level conversion (no all_gather) for name, param in named_params_and_buffers(self.args, self.model): if ".experts." 
in name: continue - buffer_size = self._update_weight_from_distributed( - name, param, converted_named_tensors, buffer_size, pbar=pbar + stripped = strip_param_name_prefix(name) + if stripped in {"embedding.word_embeddings.weight", "output_layer.weight"}: + continue + shard_converted = convert_to_hf_shard( + self.args, self.model_name, name, param.data, + self._tp_rank, self._tp_size, self.quantization_config, ) - + if not shard_converted: + continue + for hf_name, hf_param in shard_converted: + param_size = hf_param.numel() * hf_param.element_size() + if buffer_size + param_size > self.args.update_weight_buffer_size: + if converted_named_tensors: + self._update_weights_shard_packed(converted_named_tensors) + converted_named_tensors = [] + if pbar is not None: + pbar.update(1) + buffer_size = 0 + converted_named_tensors.append((hf_name, hf_param)) + buffer_size += param_size if converted_named_tensors: - self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) + self._update_weights_shard_packed(converted_named_tensors) + if pbar is not None: + pbar.update(1) dist.barrier(group=get_gloo_group()) - if not use_vllm_packed: - buffer_size = 0 - named_tensors = [] + def _sync_weights_full_packed(self) -> None: + """Original full all_gather + convert + broadcast path.""" + gathered_params: list[tuple[str, torch.Tensor]] = [] + for name, param in named_params_and_buffers(self.args, self.model): + if ".experts." 
in name: + continue + param = all_gather_param(name, param) + if self._is_dp0: + gathered_params.append((name, param)) + + dist.barrier(group=get_gloo_group()) + + if self._is_dp0: + converted_named_tensors: list[tuple[str, torch.Tensor]] = [] pbar = ( - tqdm(desc=f"[{self._group_name}] Update weights (experts)", total=0) if self._is_pp_src_rank else None + tqdm(desc=f"[{self._group_name}] Update weights (vLLM packed)", total=0) + if self._is_pp_src_rank + else None ) - for name, param in named_params_and_buffers(self.args, self.model): - if ".experts." not in name: - continue - buffer_size = self._update_expert_weight_from_distributed( - name, param, named_tensors, buffer_size, pbar=pbar - ) + buffer_size = 0 + for name, param in gathered_params: + param_size = param.numel() * param.element_size() + if buffer_size + param_size > self.args.update_weight_buffer_size: + if converted_named_tensors: + self._update_weights_vllm_packed(converted_named_tensors) + converted_named_tensors.clear() + if pbar is not None: + pbar.update(1) + buffer_size = 0 + converted_named_tensors += convert_to_hf(self.args, self.model_name, name, param, self.quantization_config) + buffer_size += param_size + if converted_named_tensors: + self._update_weights_vllm_packed(converted_named_tensors) + if pbar is not None: + pbar.update(1) - if named_tensors: - self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) + dist.barrier(group=get_gloo_group()) - if self._is_pp_src_rank: - torch.cuda.synchronize() + def _sync_weights_full_nonpacked(self) -> None: + buffer_size = 0 + converted_named_tensors = [] + pbar = tqdm(desc=f"[{self._group_name}] Update weights", total=0) if self._is_pp_src_rank else None - def _sync_bridge_weights_to_rollout_engines(self) -> None: - """ - Export HF weights through Megatron-Bridge, then send each exported chunk - over the same NCCL non-colocate path used by the raw converter. 
- """ - use_vllm_packed = self._use_vllm_packed() - if self._is_pp_src_rank: - logger.info("Using Megatron-Bridge HF weight export for non-colocate vLLM weight sync") - pbar = tqdm( - desc=f"[{self._group_name}] Update weights (Megatron-Bridge" - f"{', vLLM packed' if use_vllm_packed else ''})", - total=0, + for name, param in named_params_and_buffers(self.args, self.model): + if ".experts." in name: + continue + buffer_size = self._update_weight_from_distributed( + name, param, converted_named_tensors, buffer_size, pbar=pbar ) - else: - pbar = None - - megatron_local_weights = self.weights_getter() - for hf_named_tensors in self._hf_weight_iterator.get_hf_weight_chunks(megatron_local_weights): - if self._is_pp_src_rank: - hf_named_tensors = list(hf_named_tensors) - if use_vllm_packed: - self._update_weights_vllm_packed(hf_named_tensors) - if pbar is not None: - pbar.update(1) - else: - self._update_bucket_weights_from_distributed(hf_named_tensors, pbar=pbar) + + if converted_named_tensors: + self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) dist.barrier(group=get_gloo_group()) - if self._is_pp_src_rank: - torch.cuda.synchronize() + buffer_size = 0 + named_tensors = [] + pbar = ( + tqdm(desc=f"[{self._group_name}] Update weights (experts)", total=0) if self._is_pp_src_rank else None + ) + for name, param in named_params_and_buffers(self.args, self.model): + if ".experts." not in name: + continue + buffer_size = self._update_expert_weight_from_distributed( + name, param, named_tensors, buffer_size, pbar=pbar + ) + + if named_tensors: + self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) def _use_vllm_packed(self) -> bool: - """Use vLLM packed weight transfer (one-shot metadata + trainer_send_weights).""" if not getattr(self.args, "vllm_weight_sync_packed", True): return False if any(".experts." 
in name for name, _ in named_params_and_buffers(self.args, self.model)): @@ -268,13 +343,12 @@ def _use_vllm_packed(self) -> bool: return False return True - def _update_weights_vllm_packed(self, converted_named_tensors: list[tuple[str, torch.Tensor]]) -> None: - """Single-shot vLLM weight update using packed broadcast.""" + def _update_weights_shard_packed(self, converted_named_tensors: list[tuple[str, torch.Tensor]]) -> None: while not ray.get(self.rollout_engine_lock.acquire.remote()): time.sleep(0.1) try: - refs = update_weights_from_distributed( + refs = update_weights_from_distributed_shard( self._group_name, self._model_update_groups, self.weight_version, @@ -286,6 +360,36 @@ def _update_weights_vllm_packed(self, converted_named_tensors: list[tuple[str, t finally: ray.get(self.rollout_engine_lock.release.remote()) + def _update_weights_vllm_packed(self, converted_named_tensors: list[tuple[str, torch.Tensor]]) -> None: + use_shard = self._use_shard_conversion() + while not ray.get(self.rollout_engine_lock.acquire.remote()): + time.sleep(0.1) + + try: + if use_shard: + refs = update_weights_from_distributed_p2p( + self._group_name, + self._model_update_groups, + self.weight_version, + self.rollout_engines, + converted_named_tensors, + self._tp_rank, + self._tp_size, + packed=True, + ) + else: + refs = update_weights_from_distributed( + self._group_name, + self._model_update_groups, + self.weight_version, + self.rollout_engines, + converted_named_tensors, + packed=True, + ) + ray.get(refs) + finally: + ray.get(self.rollout_engine_lock.release.remote()) + def _update_weight_from_distributed( self, name: str, @@ -295,14 +399,10 @@ def _update_weight_from_distributed( pbar: tqdm | None = None, *, flush_packed: bool = False, - ) -> int | None: - """ - Non-expert: gather TP → rm pad → HF → buffer (flush if full). All gather, PP source buffers. - Returns updated bytes on source, None on non-source. 
- """ + ) -> int: param = all_gather_param(name, param) if not self._is_pp_src_rank: - return + return buffer_size param_size = param.numel() * param.element_size() if buffer_size + param_size > self.args.update_weight_buffer_size: @@ -327,9 +427,6 @@ def _update_expert_weight_from_distributed( buffer_size: int, pbar: tqdm | None = None, ) -> int: - """ - Expert: gather TP → rm pad → buffer. EP gather + HF deferred. Threshold × EP size. - """ param = all_gather_param(name, param) param_size = param.numel() * param.element_size() @@ -346,9 +443,6 @@ def _update_expert_weight_from_distributed( def _update_expert_bucket_weights_from_distributed( self, named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None ) -> None: - """ - Gather EP → HF → broadcast. Clears buffer. - """ names = [name for name, _ in named_tensors] all_names = [None] * mpu.get_expert_model_parallel_world_size() dist.all_gather_object(all_names, names, group=mpu.get_expert_model_parallel_group()) @@ -384,19 +478,16 @@ def _update_expert_bucket_weights_from_distributed( def _update_bucket_weights_from_distributed( self, converted_named_tensors: list[tuple[str, torch.Tensor]], pbar: tqdm | None = None ) -> None: - """ - Lock → broadcast → clear → unlock → pbar++. Lock prevents NCCL deadlock. - """ - # lock the rollout engines to prevent dead lock on broadcast. while not ray.get(self.rollout_engine_lock.acquire.remote()): time.sleep(0.1) - refs = update_weights_from_distributed( + # Only the PP source rank holds this bucket, so broadcast it whole (no per-TP-rank slicing). + refs = update_weights_from_distributed_shard( self._group_name, self._model_update_groups, self.weight_version, self.rollout_engines, converted_named_tensors, packed=False, ) @@ -412,59 +503,70 @@ def connect_rollout_engines_from_distributed( group_name: str, rollout_engines: Sequence[ActorHandle], engine_gpu_counts: Sequence[int] | None = None, -) -> Any: - """ - Create NCCL group: training rank 0 + all engine GPUs. Blocks until joined.
- - ``engine_gpu_counts`` gives the number of GPUs per engine. When engines - have heterogeneous TP sizes (e.g. prefill TP=2, decode TP=4), each engine - occupies a different number of ranks in the NCCL group. - - Trainer rank 0 uses ``NCCLWeightTransferEngine.trainer_init`` - in-process (StatelessProcessGroup + PyNcclCommunicator). - """ + target_tp_rank: int | None = None, +) -> dist.ProcessGroup: if engine_gpu_counts is None: engine_gpu_counts = [args.rollout_num_gpus_per_engine] * len(rollout_engines) master_address = ray._private.services.get_node_ip_address() + with socket.socket() as sock: sock.bind(("", 0)) master_port = sock.getsockname()[1] - world_size = sum(engine_gpu_counts) + 1 # +1 for training rank 0 - cumulative = [0] - for c in engine_gpu_counts: - cumulative.append(cumulative[-1] + c) + num_engines = len(rollout_engines) - refs = [ - engine.init_weights_update_group.remote( - master_address=master_address, - master_port=master_port, - rank_offset=cumulative[i] + 1, - world_size=world_size, - group_name=group_name, - backend="nccl", - ) - for i, engine in enumerate(rollout_engines) - ] + if target_tp_rank is not None: + world_size = 1 + num_engines + else: + world_size = sum(engine_gpu_counts) + 1 + + init_kwargs = dict( + master_address=master_address, + master_port=master_port, + world_size=world_size, + group_name=group_name, + backend="nccl", + ) + if target_tp_rank is not None: + init_kwargs["target_tp_rank"] = target_tp_rank + + if target_tp_rank is not None: + refs = [ + engine.init_weights_update_group.remote( + rank_offset=0, + shard_rank=1 + i, + **init_kwargs, + ) + for i, engine in enumerate(rollout_engines) + ] + else: + cumulative = [1] + for c in engine_gpu_counts: + cumulative.append(cumulative[-1] + c) + refs = [ + engine.init_weights_update_group.remote( + rank_offset=cumulative[i], + **init_kwargs, + ) + for i, engine in enumerate(rollout_engines) + ] torch.cuda.synchronize() torch.cuda.empty_cache() device = 
torch.cuda.current_device() logger.info( - "vLLM in-process weight transfer: addr=%s port=%d world_size=%d device=%d CVD=%s", - master_address, - master_port, - world_size, - device, - os.environ.get("CUDA_VISIBLE_DEVICES", ""), + "vLLM P2P weight transfer: group=%s addr=%s port=%d world_size=%d device=%d shard=%s", + group_name, master_address, master_port, world_size, device, + target_tp_rank is not None, ) model_update_groups = NCCLWeightTransferEngine.trainer_init( { "master_address": master_address, "master_port": master_port, "world_size": world_size, + "rank": 0, } ) @@ -475,44 +578,27 @@ def connect_rollout_engines_from_distributed( def disconnect_rollout_engines_from_distributed( args: Namespace, group_name: str, - model_update_groups: Any, + model_update_groups: dist.ProcessGroup, rollout_engines: Sequence[ActorHandle], ) -> None: - """ - Tear down the weight-update NCCL group on the rollout engines. - - ``model_update_groups`` is a vLLM ``PyNcclCommunicator`` returned by - ``NCCLWeightTransferEngine.trainer_init`` (built on a ``StatelessProcessGroup``), - NOT a torch c10d ``ProcessGroup``. It is deliberately not registered in - torch.distributed's global registry, so ``dist.destroy_process_group`` on it - raises ``ValueError: Invalid process group specified`` (see #127 regression). - - We therefore do not tear the trainer-side communicator down here; this matches - the pre-#127 behavior. (Note ``engine.destroy_weights_update_group`` is itself - a no-op on the engine side.) An explicit ``model_update_groups.destroy()`` would - abort the NCCL comm, but that changes long-standing behavior and risks the - CUDA-graph-capture self-deadlock documented in ``PyNcclCommunicator.destroy``; - leave it out of this fix. 
- """ refs = [engine.destroy_weights_update_group.remote(group_name) for engine in rollout_engines] + del model_update_groups ray.get(refs) -def update_weights_from_distributed( +def update_weights_from_distributed_shard( group_name: str, - group: Any, + group: dist.ProcessGroup, weight_version: int, rollout_engines: Sequence[ActorHandle], converted_named_tensors: Sequence[tuple[str, torch.Tensor]], *, packed: bool = False, ) -> list[ObjectRef]: - """ - Send metadata (Ray), broadcast tensors (NCCL rank 0 → engines). + """Broadcast all shard-converted tensors (no TP splitting - each rank already has its own shard).""" + if not converted_named_tensors: + return [] - The *group* is a vLLM ``PyNcclCommunicator`` from ``trainer_init`` - in the Megatron trainer process. - """ refs = [ engine.update_weights_from_distributed.remote( names=[name for name, _ in converted_named_tensors], @@ -537,14 +623,70 @@ def update_weights_from_distributed( return refs +def update_weights_from_distributed_p2p( + group_name: str, + group: dist.ProcessGroup, + weight_version: int, + rollout_engines: Sequence[ActorHandle], + converted_named_tensors: Sequence[tuple[str, torch.Tensor]], + tp_rank: int, + tp_size: int, + *, + packed: bool = False, +) -> list[ObjectRef]: + n = len(converted_named_tensors) + chunk_size = (n + tp_size - 1) // tp_size + start = tp_rank * chunk_size + end = min(start + chunk_size, n) + my_tensors = list(converted_named_tensors[start:end]) + + if not my_tensors: + return [] + + refs = [ + engine.update_weights_from_distributed.remote( + names=[name for name, _ in my_tensors], + dtypes=[param.dtype for _, param in my_tensors], + shapes=[param.shape for _, param in my_tensors], + group_name=group_name, + weight_version=str(weight_version), + packed=packed, + ) + for engine in rollout_engines + ] + + named_gpu_iter = ( + (name, (param.data if hasattr(param, "data") else param).contiguous()) + for name, param in my_tensors + ) + 
NCCLWeightTransferEngine.trainer_send_weights( + named_gpu_iter, + NCCLTrainerSendWeightsArgs(group=group, packed=packed), + ) + + return refs + + +def update_weights_from_distributed( + group_name: str, + group: dist.ProcessGroup, + weight_version: int, + rollout_engines: Sequence[ActorHandle], + converted_named_tensors: Sequence[tuple[str, torch.Tensor]], + *, + packed: bool = False, +) -> list[ObjectRef]: + return update_weights_from_distributed_p2p( + group_name, group, weight_version, rollout_engines, + converted_named_tensors, tp_rank=0, tp_size=1, packed=packed, + ) + + def post_process_weights( restore_weights_before_load: bool, post_process_quantization: bool, rollout_engines: Sequence[ActorHandle], ): - """ - Trigger post-process for int4/fp4 quantization on all rollout engines. - """ ray.get( [ engine.post_process_weights.remote( diff --git a/vime/backends/vllm_utils/vllm_engine.py b/vime/backends/vllm_utils/vllm_engine.py index 8286e117..8c5e836b 100644 --- a/vime/backends/vllm_utils/vllm_engine.py +++ b/vime/backends/vllm_utils/vllm_engine.py @@ -936,13 +936,13 @@ def check_weights(self, action: str): del action return {"ok": True, "supported": False, "note": "vLLM has no weights_checker endpoint."} - def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): + def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend, target_tp_rank=None, shard_rank=None): """Call ``POST /init_weight_transfer_engine`` with an ``init_info`` block. - ``group_name`` / ``backend`` are accepted for a uniform caller signature but are not sent to vLLM. + ``backend`` is accepted for a uniform caller signature but is not sent to vLLM; ``group_name`` may be forwarded in ``init_info``. Always uses the vllm-native weight transfer engine; reload-on-continue fallback is no longer supported.
""" - del group_name, backend + del backend payload = { "init_info": { "master_address": master_address, @@ -951,6 +951,11 @@ def init_weights_update_group(self, master_address, master_port, rank_offset, wo "world_size": world_size, } } + if target_tp_rank is not None: + payload["init_info"]["target_tp_rank"] = target_tp_rank + payload["init_info"]["group_name"] = group_name + if shard_rank is not None: + payload["init_info"]["shard_rank"] = shard_rank init_timeout_s = self._weight_transfer_http_timeout() last_error = None for attempt in range(1, 4): @@ -982,7 +987,6 @@ def update_weights_from_distributed( Payload matches vLLM NCCL weight transfer (see upstream rlhf_http_nccl example). """ - del group_name if weight_version is not None: self._weight_version = str(weight_version) if flush_cache: @@ -994,6 +998,8 @@ def update_weights_from_distributed( "shapes": [list(s) for s in shapes], "packed": bool(packed), } + if group_name.startswith("vime-shard-") or group_name.startswith("vime-pp_") and "_tp" in group_name: + update_info["group_name"] = group_name return self._post_vllm_update_weights_http(update_info) def update_weights_from_disk(self, model_path: str, load_format: str | None = None):