[Feature] shard-level P2P weight transfer (Phase 1, #160)#161
Conversation
- Each DP=0 TP rank converts its own shard via convert_to_hf_shard and sends
directly to matching vLLM inference TP rank via dedicated NCCL group
- NCCL groups are persistent: created once, reused across update_weights calls,
only disconnected at shutdown (fixes intermittent reconnection hang)
- Group naming: vime-pp_{pp_rank}_tp{tp_rank} with target_tp_rank filtering
- Fallback to broadcast when training TP != inference TP
- vllm_engine.py: add shard_rank param and conditional group_name forwarding
There was a problem hiding this comment.
Code Review
This pull request introduces shard-level P2P weight transfer from Megatron to vLLM without requiring a full all-gather operation, significantly reducing memory bottlenecks. It adds shard-aware Hugging Face conversion utilities (specifically for Qwen2/3) and updates the NCCL weight transfer engine initialization to support target tensor parallel ranks. The review feedback highlights critical runtime issues, including a signature mismatch in the patched vLLM process group initialization that would cause a TypeError, redundant sequential executions in full-packed synchronization paths, incorrect group-name formatting when shard conversion is disabled, and a layout mismatch where QKV bias is not properly combined into qkv_proj.bias.
| + def _stateless_init_process_group( # type: ignore[override] | ||
| cls, init_info: NCCLWeightTransferInitInfo | ||
| ) -> PyNcclCommunicator: | ||
| + return cls._stateless_init_process_group_v2(init_info) |
There was a problem hiding this comment.
The patched _stateless_init_process_group method signature does not accept *args or **kwargs (such as vllm_config), but it is called with vllm_config=vllm_config in nccl_engine.py. This will raise a TypeError at runtime. Update the signature to accept and forward these arguments.
+ def _stateless_init_process_group( # type: ignore[override]
+ cls, init_info: NCCLWeightTransferInitInfo, *args, **kwargs
+ ) -> PyNcclCommunicator:
+ return cls._stateless_init_process_group_v2(init_info, *args, **kwargs)
| if self._is_dp0: | ||
| self._group_name = f"vime-pp_{pp_rank}_tp{tp_rank}" | ||
| if self._model_update_groups is not None: | ||
| disconnect_rollout_engines_from_distributed( | ||
| self.args, self._group_name, self._model_update_groups, self.rollout_engines | ||
| logger.info("NCCL group %s already connected, skipping reconnection", self._group_name) | ||
| return | ||
| tp_rank_for_group = self._tp_rank if self._use_shard_conversion() else None | ||
| while not ray.get(self.rollout_engine_lock.acquire.remote()): | ||
| time.sleep(0.1) | ||
| try: | ||
| self._model_update_groups = connect_rollout_engines_from_distributed( | ||
| self.args, | ||
| self._group_name, | ||
| rollout_engines, | ||
| engine_gpu_counts=engine_gpu_counts, | ||
| target_tp_rank=tp_rank_for_group, | ||
| ) | ||
| self._model_update_groups = connect_rollout_engines_from_distributed( | ||
| self.args, | ||
| self._group_name, | ||
| rollout_engines, | ||
| engine_gpu_counts=engine_gpu_counts, | ||
| ) | ||
| finally: | ||
| ray.get(self.rollout_engine_lock.release.remote()) |
There was a problem hiding this comment.
When _use_shard_conversion() is False, all self._is_dp0 ranks will still attempt to connect and initialize the NCCL group sequentially, overwriting each other's initialization on the vLLM side. Furthermore, self._group_name will contain _tp{tp_rank}, which triggers the _tp check in vllm_engine.py and causes an AssertionError on the vLLM side because the group name was not registered with _tp during initialization. Restrict the connection and group name formatting to the original fallback behavior when shard conversion is not used.
| if self._is_dp0: | |
| self._group_name = f"vime-pp_{pp_rank}_tp{tp_rank}" | |
| if self._model_update_groups is not None: | |
| disconnect_rollout_engines_from_distributed( | |
| self.args, self._group_name, self._model_update_groups, self.rollout_engines | |
| logger.info("NCCL group %s already connected, skipping reconnection", self._group_name) | |
| return | |
| tp_rank_for_group = self._tp_rank if self._use_shard_conversion() else None | |
| while not ray.get(self.rollout_engine_lock.acquire.remote()): | |
| time.sleep(0.1) | |
| try: | |
| self._model_update_groups = connect_rollout_engines_from_distributed( | |
| self.args, | |
| self._group_name, | |
| rollout_engines, | |
| engine_gpu_counts=engine_gpu_counts, | |
| target_tp_rank=tp_rank_for_group, | |
| ) | |
| self._model_update_groups = connect_rollout_engines_from_distributed( | |
| self.args, | |
| self._group_name, | |
| rollout_engines, | |
| engine_gpu_counts=engine_gpu_counts, | |
| ) | |
| finally: | |
| ray.get(self.rollout_engine_lock.release.remote()) | |
| should_connect = self._is_dp0 if self._use_shard_conversion() else self._is_pp_src_rank | |
| if should_connect: | |
| self._group_name = f"vime-pp_{pp_rank}_tp{tp_rank}" if self._use_shard_conversion() else f"vime-pp_{pp_rank}" | |
| if self._model_update_groups is not None: | |
| logger.info("NCCL group %s already connected, skipping reconnection", self._group_name) | |
| return | |
| tp_rank_for_group = self._tp_rank if self._use_shard_conversion() else None | |
| while not ray.get(self.rollout_engine_lock.acquire.remote()): | |
| time.sleep(0.1) | |
| try: | |
| self._model_update_groups = connect_rollout_engines_from_distributed( | |
| self.args, | |
| self._group_name, | |
| rollout_engines, | |
| engine_gpu_counts=engine_gpu_counts, | |
| target_tp_rank=tp_rank_for_group, | |
| ) | |
| finally: | |
| ray.get(self.rollout_engine_lock.release.remote()) |
| if self._is_dp0: | ||
| gathered_params.append((name, param)) | ||
|
|
||
| dist.barrier(group=get_gloo_group()) | ||
|
|
||
| if self._is_dp0: |
There was a problem hiding this comment.
In _sync_weights_full_packed, using self._is_dp0 instead of self._is_pp_src_rank causes all TP ranks to perform the full weight conversion and call _update_weights_vllm_packed sequentially. This is highly redundant and can cause hangs or crashes. Restrict this path to self._is_pp_src_rank.
| if self._is_dp0: | |
| gathered_params.append((name, param)) | |
| dist.barrier(group=get_gloo_group()) | |
| if self._is_dp0: | |
| if self._is_pp_src_rank: | |
| gathered_params.append((name, param)) | |
| dist.barrier(group=get_gloo_group()) | |
| if self._is_pp_src_rank: |
| if rest == "self_attention.linear_qkv.bias": | ||
| groups_per_rank = args.num_query_groups // tp_size | ||
| param = param.view(groups_per_rank, -1) | ||
| q_bias, k_bias, v_bias = torch.split( | ||
| param, | ||
| split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim], | ||
| dim=1, | ||
| ) | ||
| q_bias = q_bias.contiguous().flatten() | ||
| k_bias = k_bias.contiguous().flatten() | ||
| v_bias = v_bias.contiguous().flatten() | ||
| return [ | ||
| (f"model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias), | ||
| (f"model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias), | ||
| (f"model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias), | ||
| ] |
There was a problem hiding this comment.
For is_checkpoint_format=False (which is always the case for shard weight transfer), the QKV weight is combined into qkv_proj.weight. However, the QKV bias is still returned as separate q_proj.bias, k_proj.bias, and v_proj.bias parameters. This will cause a layout mismatch on the vLLM side if the model has bias enabled. Combine the bias parameters into qkv_proj.bias to match the weight layout.
if rest == "self_attention.linear_qkv.bias":
groups_per_rank = args.num_query_groups // tp_size
param = param.view(groups_per_rank, -1)
q_bias, k_bias, v_bias = torch.split(
param,
split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim],
dim=1,
)
q_bias = q_bias.contiguous().flatten()
k_bias = k_bias.contiguous().flatten()
v_bias = v_bias.contiguous().flatten()
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)
return [
(f"model.layers.{layer_idx}.self_attn.qkv_proj.bias", qkv_bias),
]
Phase 1 validation resultsSetup: Qwen3-4B, train TP=2 / infer TP=2, 4 training GPUs + 4 rollout GPUs (2 engines × TP=2), non-colocate. Memory balance (main goal)
Weight transfer time
End-to-end Training correctness
Consistent with main branch baseline (~0.012). Activation / fallback
|
Summary
Implement Phase 1 of #160: shard-level P2P weight transfer when
train_tp == infer_tp.convert_to_hf_shard) and P2P-sends to the matching infer TP rankCloses step 1 of the rollout plan in #160.
Changes
update_weight_from_distributed.py— shard P2P sync path + persistent NCCL groupsmegatron_to_hf/qwen2.py—convert_to_hf_shard(), vLLM combined names (qkv_proj/gate_up_proj)vllm_engine.py—target_tp_rankfor selective NCCL group joindocker/patch/latest/vllm.patch— multi-group NCCL engine +param.data.copy_()fixTODO (WIP)
Test plan
train_tp == infer_tp: log showsUsing shard P2P weight sync (TP rank X/Y)train_tp != infer_tp: falls back to packed broadcasttrain_rollout_logprob_abs_diffstable