
[Feature] shard-level P2P weight transfer (Phase 1, #160) #161

Open
CalvinXKY wants to merge 1 commit into main from feature/p2p_transfer


@CalvinXKY

Collaborator

Summary

Implement Phase 1 of #160: shard-level P2P weight transfer when train_tp == infer_tp.

  • Each trainer TP rank converts only its own shard (convert_to_hf_shard) and P2P-sends to the matching infer TP rank
  • Embedding / output_layer still use all_gather + broadcast (vocab padding)
  • TP mismatch falls back to the original broadcast path — default behavior unchanged

Closes step 1 of the rollout plan in #160.
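
The routing described in the summary can be sketched as a small per-parameter decision function. This is only an illustration of the three cases listed above: the `SyncPath` enum, the `choose_sync_path` helper, and the parameter-name fragments used to detect vocab-padded tensors are assumptions made for the example, not names taken from this PR.

```python
from enum import Enum


class SyncPath(Enum):
    SHARD_P2P = "shard_p2p"                # each TP rank sends only its own shard
    GATHER_BROADCAST = "gather_broadcast"  # all_gather, then broadcast (vocab padding)
    PACKED_BROADCAST = "packed_broadcast"  # original fallback path


# Assumed name fragments for the vocab-padded tensors (hypothetical).
VOCAB_PADDED_MARKERS = ("embedding.word_embeddings", "output_layer")


def choose_sync_path(param_name: str, train_tp: int, infer_tp: int) -> SyncPath:
    """Pick a transfer path for one parameter (illustrative only)."""
    if train_tp != infer_tp:
        # TP mismatch: default behavior is unchanged.
        return SyncPath.PACKED_BROADCAST
    if any(marker in param_name for marker in VOCAB_PADDED_MARKERS):
        # Embedding / output_layer still go through all_gather + broadcast.
        return SyncPath.GATHER_BROADCAST
    return SyncPath.SHARD_P2P
```

With matched TP, only the embedding and output layer leave the shard path; any mismatch sends every parameter through the original broadcast.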

Changes

  • update_weight_from_distributed.py: shard P2P sync path + persistent NCCL groups
  • megatron_to_hf/qwen2.py: convert_to_hf_shard(), vLLM combined names (qkv_proj / gate_up_proj)
  • vllm_engine.py: target_tp_rank for selective NCCL group join
  • docker/patch/latest/vllm.patch: multi-group NCCL engine + param.data.copy_() fix

TODO (WIP)

  • CI smoke test: Qwen3-4B TP=2 matched scenario
  • Docs: weight sync mode guide + troubleshooting
  • Review vLLM patch upstreaming strategy

Test plan

  • train_tp == infer_tp: log shows Using shard P2P weight sync (TP rank X/Y)
  • train_tp != infer_tp: falls back to packed broadcast
  • Training correctness: train_rollout_logprob_abs_diff stable

- Each DP=0 TP rank converts its own shard via convert_to_hf_shard and sends
  directly to matching vLLM inference TP rank via dedicated NCCL group
- NCCL groups are persistent: created once, reused across update_weights calls,
  only disconnected at shutdown (fixes intermittent reconnection hang)
- Group naming: vime-pp_{pp_rank}_tp{tp_rank} with target_tp_rank filtering
- Fallback to broadcast when training TP != inference TP
- vllm_engine.py: add shard_rank param and conditional group_name forwarding
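
The "create once, reuse, disconnect only at shutdown" behavior in the commit message can be sketched with a small registry keyed by group name. This is a minimal sketch, not the PR's implementation: `PersistentGroups`, `group_name`, and the injected `connect` / `disconnect` callables are hypothetical stand-ins for the real NCCL setup, and only the `vime-pp_{pp_rank}_tp{tp_rank}` format comes from the commit message.

```python
from typing import Callable, Dict, Optional


def group_name(pp_rank: int, tp_rank: Optional[int] = None) -> str:
    """Build the group name; the tp_rank=None form is an assumed fallback."""
    base = f"vime-pp_{pp_rank}"
    return base if tp_rank is None else f"{base}_tp{tp_rank}"


class PersistentGroups:
    """Cache of connected groups that survives across update_weights calls."""

    def __init__(
        self,
        connect: Callable[[str], object],
        disconnect: Callable[[str, object], None],
    ) -> None:
        self._connect = connect
        self._disconnect = disconnect
        self._groups: Dict[str, object] = {}

    def ensure(self, name: str) -> object:
        # Connect only on first use; later calls reuse the existing handle.
        if name not in self._groups:
            self._groups[name] = self._connect(name)
        return self._groups[name]

    def shutdown(self) -> None:
        # Disconnect every group exactly once, at shutdown.
        for name, handle in list(self._groups.items()):
            self._disconnect(name, handle)
        self._groups.clear()
```

Calling `ensure()` on every weight update is then cheap after the first call, which is the property the commit message credits with removing the intermittent reconnection hang.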
CalvinXKY changed the title from "[WIP] feat: shard-level P2P weight transfer (Phase 1, #160)" to "[Feature] shard-level P2P weight transfer (Phase 1, #160)" on Jun 5, 2026

gemini-code-assist (bot) left a comment

Code Review

This pull request introduces shard-level P2P weight transfer from Megatron to vLLM without requiring a full all-gather operation, significantly reducing memory bottlenecks. It adds shard-aware Hugging Face conversion utilities (specifically for Qwen2/3) and updates the NCCL weight transfer engine initialization to support target tensor parallel ranks. The review feedback highlights critical runtime issues, including a signature mismatch in the patched vLLM process group initialization that would cause a TypeError, redundant sequential executions in full-packed synchronization paths, incorrect group-name formatting when shard conversion is disabled, and a layout mismatch where QKV bias is not properly combined into qkv_proj.bias.

Comment on lines +47 to +50
+    def _stateless_init_process_group(  # type: ignore[override]
+        cls, init_info: NCCLWeightTransferInitInfo
+    ) -> PyNcclCommunicator:
+        return cls._stateless_init_process_group_v2(init_info)

high

The patched _stateless_init_process_group method signature does not accept *args or **kwargs (such as vllm_config), but it is called with vllm_config=vllm_config in nccl_engine.py. This will raise a TypeError at runtime. Update the signature to accept and forward these arguments.

+    def _stateless_init_process_group(  # type: ignore[override]
+         cls, init_info: NCCLWeightTransferInitInfo, *args, **kwargs
+     ) -> PyNcclCommunicator:
+        return cls._stateless_init_process_group_v2(init_info, *args, **kwargs)
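
The failure mode described in this comment can be reproduced without vLLM. The sketch below uses hypothetical stand-in classes (`NarrowPatch`, `ForwardingPatch`, `call_like_engine`); only the pattern of an override that drops a keyword argument versus one that forwards `*args` / `**kwargs` mirrors the review comment.

```python
class NarrowPatch:
    """Override whose signature accepts only init_info."""

    @classmethod
    def init_group(cls, init_info):
        return ("narrow", init_info)


class ForwardingPatch:
    """Override that accepts and forwards any extra arguments."""

    @classmethod
    def init_group(cls, init_info, *args, **kwargs):
        return cls._init_group_v2(init_info, *args, **kwargs)

    @classmethod
    def _init_group_v2(cls, init_info, *args, **kwargs):
        return ("v2", init_info, kwargs.get("vllm_config"))


def call_like_engine(engine_cls):
    """Stand-in for a caller that always passes vllm_config as a keyword."""
    return engine_cls.init_group("info", vllm_config="cfg")
```

`call_like_engine(NarrowPatch)` raises `TypeError` because of the unexpected `vllm_config` keyword, while `call_like_engine(ForwardingPatch)` passes it through to the inner method.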

Comment on lines +90 to +107
         if self._is_dp0:
             self._group_name = f"vime-pp_{pp_rank}_tp{tp_rank}"
             if self._model_update_groups is not None:
-                disconnect_rollout_engines_from_distributed(
-                    self.args, self._group_name, self._model_update_groups, self.rollout_engines
                 logger.info("NCCL group %s already connected, skipping reconnection", self._group_name)
                 return
             tp_rank_for_group = self._tp_rank if self._use_shard_conversion() else None
             while not ray.get(self.rollout_engine_lock.acquire.remote()):
                 time.sleep(0.1)
             try:
                 self._model_update_groups = connect_rollout_engines_from_distributed(
                     self.args,
                     self._group_name,
                     rollout_engines,
                     engine_gpu_counts=engine_gpu_counts,
                     target_tp_rank=tp_rank_for_group,
                 )
-                self._model_update_groups = connect_rollout_engines_from_distributed(
-                    self.args,
-                    self._group_name,
-                    rollout_engines,
-                    engine_gpu_counts=engine_gpu_counts,
-                )
             finally:
                 ray.get(self.rollout_engine_lock.release.remote())

high

When _use_shard_conversion() is False, all self._is_dp0 ranks will still attempt to connect and initialize the NCCL group sequentially, overwriting each other's initialization on the vLLM side. Furthermore, self._group_name will contain _tp{tp_rank}, which triggers the _tp check in vllm_engine.py and causes an AssertionError on the vLLM side because the group name was not registered with _tp during initialization. Restrict the connection and group name formatting to the original fallback behavior when shard conversion is not used.

Suggested change
        should_connect = self._is_dp0 if self._use_shard_conversion() else self._is_pp_src_rank
        if should_connect:
            self._group_name = f"vime-pp_{pp_rank}_tp{tp_rank}" if self._use_shard_conversion() else f"vime-pp_{pp_rank}"
            if self._model_update_groups is not None:
                logger.info("NCCL group %s already connected, skipping reconnection", self._group_name)
                return
            tp_rank_for_group = self._tp_rank if self._use_shard_conversion() else None
            while not ray.get(self.rollout_engine_lock.acquire.remote()):
                time.sleep(0.1)
            try:
                self._model_update_groups = connect_rollout_engines_from_distributed(
                    self.args,
                    self._group_name,
                    rollout_engines,
                    engine_gpu_counts=engine_gpu_counts,
                    target_tp_rank=tp_rank_for_group,
                )
            finally:
                ray.get(self.rollout_engine_lock.release.remote())

Comment on lines +274 to +279
        if self._is_dp0:
            gathered_params.append((name, param))

        dist.barrier(group=get_gloo_group())

        if self._is_dp0:

high

In _sync_weights_full_packed, using self._is_dp0 instead of self._is_pp_src_rank causes all TP ranks to perform the full weight conversion and call _update_weights_vllm_packed sequentially. This is highly redundant and can cause hangs or crashes. Restrict this path to self._is_pp_src_rank.

Suggested change
        if self._is_pp_src_rank:
            gathered_params.append((name, param))
        dist.barrier(group=get_gloo_group())
        if self._is_pp_src_rank:
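
To see why the choice of predicate matters, the sketch below enumerates which ranks within one pipeline stage would run the full conversion under each condition. The rank definitions are assumptions made for illustration: `is_dp0` is taken to mean `dp_rank == 0`, and `is_pp_src_rank` is taken to mean the single rank with `dp_rank == 0` and `tp_rank == 0`.

```python
from itertools import product


def is_dp0(dp_rank: int, tp_rank: int) -> bool:
    # Assumption: true for every TP rank in the first data-parallel replica.
    return dp_rank == 0


def is_pp_src_rank(dp_rank: int, tp_rank: int) -> bool:
    # Assumption: true for exactly one rank per pipeline stage.
    return dp_rank == 0 and tp_rank == 0


def ranks_running_full_conversion(dp_size: int, tp_size: int, predicate):
    """List the (dp_rank, tp_rank) pairs that pass the guard."""
    return [
        (dp, tp)
        for dp, tp in product(range(dp_size), range(tp_size))
        if predicate(dp, tp)
    ]
```

Under these assumptions, with DP=2 and TP=2 the `is_dp0` guard selects two ranks that each convert and send the full weights, while `is_pp_src_rank` selects one.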

Comment on lines +134 to +149
        if rest == "self_attention.linear_qkv.bias":
            groups_per_rank = args.num_query_groups // tp_size
            param = param.view(groups_per_rank, -1)
            q_bias, k_bias, v_bias = torch.split(
                param,
                split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim],
                dim=1,
            )
            q_bias = q_bias.contiguous().flatten()
            k_bias = k_bias.contiguous().flatten()
            v_bias = v_bias.contiguous().flatten()
            return [
                (f"model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias),
                (f"model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias),
                (f"model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias),
            ]

medium

For is_checkpoint_format=False (which is always the case for shard weight transfer), the QKV weight is combined into qkv_proj.weight. However, the QKV bias is still returned as separate q_proj.bias, k_proj.bias, and v_proj.bias parameters. This will cause a layout mismatch on the vLLM side if the model has bias enabled. Combine the bias parameters into qkv_proj.bias to match the weight layout.

        if rest == "self_attention.linear_qkv.bias":
            groups_per_rank = args.num_query_groups // tp_size
            param = param.view(groups_per_rank, -1)
            q_bias, k_bias, v_bias = torch.split(
                param,
                split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim],
                dim=1,
            )
            q_bias = q_bias.contiguous().flatten()
            k_bias = k_bias.contiguous().flatten()
            v_bias = v_bias.contiguous().flatten()
            qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)
            return [
                (f"model.layers.{layer_idx}.self_attn.qkv_proj.bias", qkv_bias),
            ]
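
The layout change in this suggestion can be traced with plain Python lists, which avoids needing a GPU or torch to follow it. This is a sketch of the same regrouping under the assumption that each query group's row is laid out as `[q heads | k | v]`, as the `view` and `split` above imply; `regroup_qkv_bias_shard` and its parameter names are hypothetical.

```python
def regroup_qkv_bias_shard(flat, groups_per_rank, q_per_group, head_dim):
    """Turn a per-group interleaved bias shard into one [q | k | v] vector."""
    row = (q_per_group + 2) * head_dim  # q heads, then one k and one v per group
    assert len(flat) == groups_per_rank * row
    q, k, v = [], [], []
    for g in range(groups_per_rank):
        chunk = flat[g * row:(g + 1) * row]
        q_end = q_per_group * head_dim
        q.extend(chunk[:q_end])                   # all q heads of this group
        k.extend(chunk[q_end:q_end + head_dim])   # the group's k head
        v.extend(chunk[q_end + head_dim:])        # the group's v head
    # Equivalent to concatenating the flattened q, k and v pieces.
    return q + k + v


shard = ["q0", "q1", "k0", "v0", "q2", "q3", "k1", "v1"]
combined = regroup_qkv_bias_shard(shard, groups_per_rank=2, q_per_group=2, head_dim=1)
```

Here `combined` is `["q0", "q1", "q2", "q3", "k0", "k1", "v0", "v1"]`, the ordering a single combined `qkv_proj.bias` would carry if it follows the same convention as the combined weight.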

CalvinXKY marked this pull request as ready for review on June 6, 2026, 01:26
@CalvinXKY

Collaborator Author

Phase 1 validation results

Setup: Qwen3-4B, train TP=2 / infer TP=2, 4 training GPUs + 4 rollout GPUs (2 engines × TP=2), non-colocate.

Memory balance (main goal)

| | Main (broadcast) | Shard P2P |
|---|---|---|
| Infer TP imbalance | ~10 GiB | <200 MiB (192 / 194 MiB per engine) |
| Trainer rank skew (after sync) | ~10 GiB | ~6.3 GiB (rank0 22.99 GB vs rank1 16.71 GB) |

Weight transfer time

Prototype numbers below are from isolated NCCL send/recv (p2p_test_v2.py), excluding Megatron→HF conversion and vLLM engine interaction.

| Scenario | Time |
|---|---|
| Main branch broadcast (non-offload) | ~5.2s |
| Prototype pure NCCL (steps 0–3) | 0.96 / 0.53 / 0.84 / 0.60s |
| Real training, Timer update_weights (offload, A800) | ~19–21s |
| Real training, 48 steps (offload, A100) | mean 20.3s, median 20.2s, std 1.6s; min 13.8s (step 2, first reuse of persistent NCCL group) |

End-to-end update_weights in offload mode is dominated by convert + engine overhead; the main win here is memory balance, not raw step-time reduction.
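
As a quick arithmetic check on the reported figures, the snippet below recomputes the trainer rank skew from the two per-rank values and averages the four prototype timings. It uses only numbers quoted above; the prototype mean covers isolated NCCL send/recv only, so it is not directly comparable to the end-to-end `update_weights` times.

```python
from statistics import mean

# Prototype pure NCCL timings for steps 0-3, in seconds.
prototype_times_s = [0.96, 0.53, 0.84, 0.60]
mean_prototype_s = round(mean(prototype_times_s), 2)

# Trainer memory after sync: rank0 vs rank1, as reported.
trainer_skew = round(22.99 - 16.71, 2)

print(f"mean prototype transfer: {mean_prototype_s}s")
print(f"trainer rank skew: {trainer_skew}")
```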

Training correctness

| Step | train_rollout_logprob_abs_diff |
|---|---|
| 0 | 0.01213 |
| 1 | 0.01290 |
| 2 | 0.01214 |

Consistent with main branch baseline (~0.012). actor_train_tok_per_s at step 1: ~16334 (no throughput regression).

Activation / fallback

  • Enabled when train_tp == infer_tp (--tensor-model-parallel-size == --rollout-num-gpus-per-engine)
  • Log when active: Using shard P2P weight sync (TP rank X/Y)
  • TP mismatch → falls back to original packed broadcast path
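
The activation rule can be checked against the two command-line flags with a few lines of `argparse`. This is a sketch under stated assumptions: the flag names and the "Using shard P2P weight sync" log text come from this comment, while `parse_parallel_args`, `describe_sync_mode`, the default values, the reading of `X/Y` as TP rank over TP size, and the fallback message wording are hypothetical.

```python
import argparse


def parse_parallel_args(argv):
    """Parse only the two flags that decide the sync mode (defaults assumed)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--tensor-model-parallel-size", type=int, default=1)
    parser.add_argument("--rollout-num-gpus-per-engine", type=int, default=1)
    return parser.parse_args(argv)


def describe_sync_mode(args, tp_rank: int) -> str:
    train_tp = args.tensor_model_parallel_size
    infer_tp = args.rollout_num_gpus_per_engine
    if train_tp == infer_tp:
        # Log line quoted in the PR, with X/Y read as rank/size (assumption).
        return f"Using shard P2P weight sync (TP rank {tp_rank}/{train_tp})"
    # Hypothetical wording for the fallback case.
    return "TP mismatch: falling back to packed broadcast"
```

For example, passing `--tensor-model-parallel-size 2 --rollout-num-gpus-per-engine 2` yields the shard P2P message on each TP rank, and changing either value to 4 yields the fallback message.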

aoshen02 mentioned this pull request on Jun 21, 2026 (15 tasks)