-
Notifications
You must be signed in to change notification settings - Fork 36
[Feature] shard-level P2P weight transfer (Phase 1, #160) #161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py | ||
| index 1111111..2222222 100644 | ||
| --- a/vllm/distributed/weight_transfer/nccl_engine.py | ||
| +++ b/vllm/distributed/weight_transfer/nccl_engine.py | ||
| @@ -1,5 +1,7 @@ | ||
| from __future__ import annotations | ||
|
|
||
| +import logging | ||
| + | ||
| from typing import TYPE_CHECKING, Any, Optional | ||
|
|
||
| import torch | ||
| @@ -20,6 +22,7 @@ | ||
| if TYPE_CHECKING: | ||
| from vllm.config import VllmConfig | ||
|
|
||
| +logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| @@ -37,6 +40,9 @@ class NCCLWeightTransferInitInfo: | ||
| world_size: int | ||
| rank_offset: int = 0 | ||
| group_name: Optional[str] = None | ||
| + target_tp_rank: Optional[int] = None | ||
| + shard_rank: Optional[int] = None | ||
| + group_name_for_engine: Optional[str] = None | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| @@ -147,7 +153,9 @@ class NCCLWeightTransferEngine: | ||
| model_update_groups = {} | ||
| for group_name, init_info in init_infos.items(): | ||
| assert isinstance(init_info, NCCLWeightTransferInitInfo), init_info | ||
| - model_update_group = cls._stateless_init_process_group(init_info) | ||
| + model_update_group = cls._stateless_init_process_group( | ||
| + init_info, vllm_config=vllm_config | ||
| + ) | ||
| model_update_groups[group_name] = model_update_group | ||
| return model_update_groups | ||
|
|
||
| @@ -170,10 +178,16 @@ class NCCLWeightTransferEngine: | ||
| return model_update_groups | ||
|
|
||
| @classmethod | ||
| - def _stateless_init_process_group( | ||
| + def _stateless_init_process_group( # type: ignore[override] | ||
| cls, init_info: NCCLWeightTransferInitInfo | ||
| ) -> PyNcclCommunicator: | ||
| + return cls._stateless_init_process_group_v2(init_info) | ||
| + | ||
| + @classmethod | ||
| + def _stateless_init_process_group_v2( | ||
| + cls, init_info: NCCLWeightTransferInitInfo, *, vllm_config=None | ||
| + ) -> PyNcclCommunicator: | ||
| from vllm.distributed.parallel_state import ( | ||
| in_the_same_tp_group_as_pp, | ||
| ) | ||
| @@ -185,6 +199,23 @@ class NCCLWeightTransferEngine: | ||
| assert in_the_same_tp_group_as_pp(), ( | ||
| "Weight transfer across PP groups is not supported." | ||
| ) | ||
| + | ||
| + target_tp_rank = getattr(init_info, "target_tp_rank", None) | ||
| + if target_tp_rank is not None and vllm_config is not None: | ||
| + tp_rank = get_tp_group().rank_in_group | ||
| + if tp_rank != target_tp_rank: | ||
| + logger.debug( | ||
| + "Skipping NCCL group init for group_name=%s: " | ||
| + "tp_rank=%d != target_tp_rank=%d", | ||
| + init_info.group_name, tp_rank, target_tp_rank, | ||
| + ) | ||
| + return None | ||
| + logger.debug( | ||
| + "Creating NCCL group for group_name=%s: tp_rank=%d matches target_tp_rank=%d", | ||
| + init_info.group_name, tp_rank, target_tp_rank, | ||
| + ) | ||
| + | ||
| from vllm.distributed.device_communicators.stateless_comm import ( | ||
| StatelessProcessGroup, | ||
| ) | ||
| @@ -247,7 +278,12 @@ class NCCLWeightTransferEngine: | ||
| assert group_name in model_update_groups, ( | ||
| f"Group {group_name} not found in model_update_groups" | ||
| ) | ||
| - group = model_update_groups[group_name] | ||
| + group = model_update_groups.get(group_name) | ||
| + if group is None: | ||
| + logger.debug( | ||
| + "receive_weights: skipping group_name=%s (not active for this rank)", | ||
| + group_name, | ||
| + ) | ||
| + continue | ||
| if packed: | ||
| cls._packed_nccl_broadcast(group, update_info) | ||
| else: | ||
| diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py | ||
| index 3333333..4444444 100644 | ||
| --- a/vllm/v1/worker/gpu_worker.py | ||
| +++ b/vllm/v1/worker/gpu_worker.py | ||
| @@ -1,5 +1,5 @@ | ||
| # ... existing imports ... | ||
|
|
||
| def load_weights_direct(self, named_weights): | ||
| for name, weight in named_weights: | ||
| param = self.model.get_parameter(name) | ||
| - param.copy_(weight) | ||
| + param.data.copy_(weight) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,3 +69,101 @@ def convert_qwen2_to_hf(args, name, param): | |
| return [(f"model.layers.{layer_idx}.self_attn.k_norm.weight", param)] | ||
|
|
||
| raise ValueError(f"Unknown parameter name: {name}") | ||
|
|
||
|
|
||
| def convert_qwen2_to_hf_shard(args, name, param, tp_rank, tp_size): | ||
| """Shard-level HF conversion: operates on a single TP shard without all_gather. | ||
|
|
||
| For Qwen2/3 with GQA, Megatron shards by query groups, which maps directly | ||
| to vLLM's head-based sharding after QKV split. Each TP rank converts its | ||
| own shard and sends to the corresponding vLLM TP rank. | ||
|
|
||
| Args: | ||
| args: Model config (num_attention_heads, num_query_groups, etc.) | ||
| name: Megatron parameter name | ||
| param: TP-sharded parameter (this rank's shard only) | ||
| tp_rank: Current tensor model parallel rank | ||
| tp_size: Tensor model parallel world size | ||
|
|
||
| Returns: | ||
| List of (hf_name, shard_tensor) tuples. Empty list for duplicated | ||
| params on non-rank-0 (only rank 0 sends duplicated params). | ||
| """ | ||
| try: | ||
| head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads | ||
| except AttributeError: | ||
| head_dim = args.hidden_size // args.num_attention_heads | ||
| value_num_per_group = args.num_attention_heads // args.num_query_groups | ||
|
|
||
| # Duplicated params: every TP rank sends these (each group needs them) | ||
| if name == "module.module.decoder.final_layernorm.weight": | ||
| return [("model.norm.weight", param)] | ||
|
|
||
| decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)" | ||
| match = re.match(decoder_layers_pattern, name) | ||
| if match: | ||
| layer_idx, rest = match.groups() | ||
|
|
||
| # Duplicated params: layernorms, qk norm - every rank sends in its group | ||
| _duplicated_map = { | ||
| "self_attention.linear_qkv.layer_norm_weight": "input_layernorm", | ||
| "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm", | ||
| "self_attention.q_layernorm.weight": "self_attn.q_norm", | ||
| "self_attention.k_layernorm.weight": "self_attn.k_norm", | ||
| } | ||
| if rest in _duplicated_map: | ||
| hf_name = f"model.layers.{layer_idx}.{_duplicated_map[rest]}.weight" | ||
| return [(hf_name, param)] | ||
|
|
||
| # TP-sharded params: each rank converts its own shard | ||
| if rest == "self_attention.linear_qkv.weight": | ||
| groups_per_rank = args.num_query_groups // tp_size | ||
| param = param.view(groups_per_rank, -1, head_dim, args.hidden_size) | ||
| q_param, k_param, v_param = torch.split( | ||
| param, split_size_or_sections=[value_num_per_group, 1, 1], dim=1 | ||
| ) | ||
| q_param = q_param.reshape(-1, args.hidden_size) | ||
| k_param = k_param.reshape(-1, args.hidden_size) | ||
| v_param = v_param.reshape(-1, args.hidden_size) | ||
| # For is_checkpoint_format=False: produce combined qkv_proj matching vLLM layout | ||
| qkv_param = torch.cat([q_param, k_param, v_param], dim=0) | ||
| return [ | ||
| (f"model.layers.{layer_idx}.self_attn.qkv_proj.weight", qkv_param), | ||
| ] | ||
|
|
||
| if rest == "self_attention.linear_qkv.bias": | ||
| groups_per_rank = args.num_query_groups // tp_size | ||
| param = param.view(groups_per_rank, -1) | ||
| q_bias, k_bias, v_bias = torch.split( | ||
| param, | ||
| split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim], | ||
| dim=1, | ||
| ) | ||
| q_bias = q_bias.contiguous().flatten() | ||
| k_bias = k_bias.contiguous().flatten() | ||
| v_bias = v_bias.contiguous().flatten() | ||
| return [ | ||
| (f"model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias), | ||
| (f"model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias), | ||
| (f"model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias), | ||
| ] | ||
|
Comment on lines +134 to +149
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For if rest == "self_attention.linear_qkv.bias":
groups_per_rank = args.num_query_groups // tp_size
param = param.view(groups_per_rank, -1)
q_bias, k_bias, v_bias = torch.split(
param,
split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim],
dim=1,
)
q_bias = q_bias.contiguous().flatten()
k_bias = k_bias.contiguous().flatten()
v_bias = v_bias.contiguous().flatten()
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)
return [
(f"model.layers.{layer_idx}.self_attn.qkv_proj.bias", qkv_bias),
] |
||
|
|
||
| if rest == "self_attention.linear_proj.weight": | ||
| return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)] | ||
|
|
||
| if rest == "mlp.linear_fc1.weight": | ||
| # For is_checkpoint_format=False: produce combined gate_up_proj matching vLLM layout | ||
| return [ | ||
| (f"model.layers.{layer_idx}.mlp.gate_up_proj.weight", param), | ||
| ] | ||
|
|
||
| if rest == "mlp.linear_fc2.weight": | ||
| return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)] | ||
|
|
||
| # Embedding and output layer: TP-sharded along vocab dim | ||
| if name == "module.module.embedding.word_embeddings.weight": | ||
| return [("model.embed_tokens.weight", param)] | ||
| if name == "module.module.output_layer.weight": | ||
| return [("lm_head.weight", param)] | ||
|
|
||
| raise ValueError(f"Unknown parameter name: {name}") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The patched
`_stateless_init_process_group` method signature does not accept `*args` or `**kwargs` (such as `vllm_config`), but it is called with `vllm_config=vllm_config` in `nccl_engine.py`. This will raise a `TypeError` at runtime. Update the signature to accept and forward these arguments.