[NPU] colocate training with TMS, bridge sync, and IPC weight loading#285
[NPU] colocate training with TMS, bridge sync, and IPC weight loading#285CalvinXKY wants to merge 3 commits into
Conversation
Move packed-module remap into vime colocate worker extension and keep vllm-ascend patch limited to colocate init (free_memory check and profiling assert). NPU colocate uses is_checkpoint_format=False, skips layerwise_reload, and loads weights via direct param writes with QKV/gate_up remap. Signed-off-by: kaiyuan <kyxiezju@163.com>
Documentation build overview
28 files changed ·
|
There was a problem hiding this comment.
Code Review
This pull request introduces NPU support for colocated weight updates, including custom weight loading, NPU synchronization, and patching for the NPU worker. The review feedback highlights several critical issues: first, torch_npu.get_device_properties lacks a uuid attribute on Ascend NPU, which will cause an AttributeError and crash the process; second, the colocate worker bypasses general NPU patches (such as rotary embeddings and MoE weight loaders) which must be explicitly applied; and third, _load_colocate_weights_direct lacks a torch.no_grad() context and needs a shape length guard to prevent potential errors during dimension-sharding.
| def _current_gpu_uuid() -> str: | ||
| device_index = torch.cuda.current_device() | ||
| props = torch.cuda.get_device_properties(device_index) | ||
| if is_npu(): | ||
| device_index = torch.npu.current_device() | ||
| props = torch.npu.get_device_properties(device_index) | ||
| else: | ||
| device_index = torch.cuda.current_device() | ||
| props = torch.cuda.get_device_properties(device_index) | ||
| return str(props.uuid) |
There was a problem hiding this comment.
The torch_npu.get_device_properties object does not have a uuid attribute on Ascend NPU. Accessing props.uuid directly will raise an AttributeError and crash the weight update process. We should fallback to a unique identifier such as f"npu:{device_index}" if uuid is not present.
| def _current_gpu_uuid() -> str: | |
| device_index = torch.cuda.current_device() | |
| props = torch.cuda.get_device_properties(device_index) | |
| if is_npu(): | |
| device_index = torch.npu.current_device() | |
| props = torch.npu.get_device_properties(device_index) | |
| else: | |
| device_index = torch.cuda.current_device() | |
| props = torch.cuda.get_device_properties(device_index) | |
| return str(props.uuid) | |
| def _current_gpu_uuid() -> str: | |
| if is_npu(): | |
| device_index = torch.npu.current_device() | |
| props = torch.npu.get_device_properties(device_index) | |
| return getattr(props, "uuid", f"npu:{device_index}") | |
| else: | |
| device_index = torch.cuda.current_device() | |
| props = torch.cuda.get_device_properties(device_index) | |
| return str(props.uuid) |
| if is_npu(): | ||
| device_index = torch.npu.current_device() | ||
| physical_gpu_id = str(torch.npu.get_device_properties(device_index).uuid) | ||
| else: | ||
| device_index = torch.cuda.current_device() | ||
| physical_gpu_id = str(torch.cuda.get_device_properties(device_index).uuid) |
There was a problem hiding this comment.
The torch_npu.get_device_properties object does not have a uuid attribute on Ascend NPU. Accessing props.uuid directly will raise an AttributeError and crash the weight update process. We should fallback to a unique identifier such as f"npu:{device_index}" if uuid is not present.
| if is_npu(): | |
| device_index = torch.npu.current_device() | |
| physical_gpu_id = str(torch.npu.get_device_properties(device_index).uuid) | |
| else: | |
| device_index = torch.cuda.current_device() | |
| physical_gpu_id = str(torch.cuda.get_device_properties(device_index).uuid) | |
| if is_npu(): | |
| device_index = torch.npu.current_device() | |
| props = torch.npu.get_device_properties(device_index) | |
| physical_gpu_id = getattr(props, "uuid", f"npu:{device_index}") | |
| else: | |
| device_index = torch.cuda.current_device() | |
| physical_gpu_id = str(torch.cuda.get_device_properties(device_index).uuid) |
| def _patch_npu_colocate_worker() -> None: | ||
| """Skip layerwise_reload in NPUWorker weight-update hooks (colocate OOM fix).""" | ||
| try: | ||
| from vllm_ascend.worker.worker import NPUWorker | ||
| except ImportError: | ||
| return | ||
|
|
||
| if getattr(NPUWorker, "_vime_colocate_patched", False): | ||
| return | ||
|
|
||
| def _patched_start_weight_update(self, is_checkpoint_format: bool = True) -> None: | ||
| self._weight_update_active = True | ||
| self._is_checkpoint_format = is_checkpoint_format | ||
|
|
||
| def _patched_finish_weight_update(self) -> None: | ||
| self._weight_update_active = False | ||
|
|
||
| NPUWorker.start_weight_update = _patched_start_weight_update # type: ignore[method-assign] | ||
| NPUWorker.finish_weight_update = _patched_finish_weight_update # type: ignore[method-assign] | ||
| NPUWorker._vime_colocate_patched = True # type: ignore[attr-defined] |
There was a problem hiding this comment.
When using NPU colocate mode, the vLLMColocateWorkerExtension is loaded instead of vLLMWorkerExtension. This bypasses the general NPU patches applied in _VLLMHijack._patch_npu_worker() and _VLLMHijack._patch_npu_rotary_emb(). As a result, MoE models will fail due to the missing weight_loader attribute on EP, and rotary embeddings will fail to initialize correctly if flash_attn is not available.
To fix this, we should apply the rotary embedding patch and ensure patch_moe_weight_loader is called during _patched_start_weight_update.
def _patch_npu_colocate_worker() -> None:
"""Skip layerwise_reload in NPUWorker weight-update hooks (colocate OOM fix)."""
try:
from vllm_ascend.worker.worker import NPUWorker
except ImportError:
return
if getattr(NPUWorker, "_vime_colocate_patched", False):
return
_VLLMHijack._patch_npu_rotary_emb()
def _patched_start_weight_update(self, is_checkpoint_format: bool = True) -> None:
self._weight_update_active = True
self._is_checkpoint_format = is_checkpoint_format
_VLLMHijack.patch_moe_weight_loader(self.model_runner.model)
def _patched_finish_weight_update(self) -> None:
self._weight_update_active = False
NPUWorker.start_weight_update = _patched_start_weight_update # type: ignore[method-assign]
NPUWorker.finish_weight_update = _patched_finish_weight_update # type: ignore[method-assign]
NPUWorker._vime_colocate_patched = True # type: ignore[attr-defined]| def _load_colocate_weights_direct( | ||
| model: torch.nn.Module, | ||
| weights: list[tuple[str, torch.Tensor]], | ||
| *, | ||
| tp_rank: int, | ||
| ) -> None: | ||
| """Load HF-named weights into vLLM packed params (qkv_proj, gate_up_proj, etc.).""" | ||
| from vllm.model_executor.utils import get_packed_modules_mapping | ||
|
|
||
| packed_map = get_packed_modules_mapping(model) | ||
| sub_to_packed = {} | ||
| for packed_name, sub_names in packed_map.items(): | ||
| for idx, sub_name in enumerate(sub_names): | ||
| sub_to_packed[sub_name] = (packed_name, idx) | ||
| qkv_shard_ids = {"q_proj": "q", "k_proj": "k", "v_proj": "v"} | ||
|
|
||
| def _remap_and_get_shard_id(hf_name: str) -> tuple[str, object | None]: | ||
| parts = hf_name.split(".") | ||
| shard_id = None | ||
| for i in range(len(parts)): | ||
| if parts[i] in sub_to_packed: | ||
| packed_name, idx = sub_to_packed[parts[i]] | ||
| shard_id = qkv_shard_ids.get(parts[i], idx) | ||
| parts[i] = packed_name | ||
| break | ||
| return ".".join(parts), shard_id | ||
|
|
||
| for name, weight in weights: | ||
| remapped, shard_id = _remap_and_get_shard_id(name) | ||
| try: | ||
| param = model.get_parameter(remapped) | ||
| except AttributeError: | ||
| continue | ||
| weight_loader = getattr(param, "weight_loader", None) | ||
| if weight_loader is not None and shard_id is not None: | ||
| weight_loader(param, weight, loaded_shard_id=shard_id) | ||
| elif weight_loader is not None: | ||
| weight_loader(param, weight) | ||
| elif param.shape == weight.shape: | ||
| param.data.copy_(weight) | ||
| else: | ||
| for dim in range(len(param.shape)): | ||
| if param.shape[dim] != weight.shape[dim]: | ||
| shard_size = param.shape[dim] | ||
| param.data.copy_(weight.narrow(dim, tp_rank * shard_size, shard_size)) | ||
| break |
There was a problem hiding this comment.
There are two issues in _load_colocate_weights_direct:
- The weight loading and copying operations are not wrapped in a
torch.no_grad()context. Since this runs on the worker side via RPC, there is no activetorch.no_grad()decorator, which can lead to unnecessary autograd graph construction and memory overhead. - The dimension-sharding fallback assumes
len(param.shape) == len(weight.shape). If they differ, it could raise anIndexErroror perform incorrect narrowing.
We should wrap the loading loop in with torch.no_grad(): and add a guard for shape length equality.
def _load_colocate_weights_direct(
model: torch.nn.Module,
weights: list[tuple[str, torch.Tensor]],
*,
tp_rank: int,
) -> None:
"""Load HF-named weights into vLLM packed params (qkv_proj, gate_up_proj, etc.)."""
from vllm.model_executor.utils import get_packed_modules_mapping
packed_map = get_packed_modules_mapping(model)
sub_to_packed = {}
for packed_name, sub_names in packed_map.items():
for idx, sub_name in enumerate(sub_names):
sub_to_packed[sub_name] = (packed_name, idx)
qkv_shard_ids = {"q_proj": "q", "k_proj": "k", "v_proj": "v"}
def _remap_and_get_shard_id(hf_name: str) -> tuple[str, object | None]:
parts = hf_name.split(".")
shard_id = None
for i in range(len(parts)):
if parts[i] in sub_to_packed:
packed_name, idx = sub_to_packed[parts[i]]
shard_id = qkv_shard_ids.get(parts[i], idx)
parts[i] = packed_name
break
return ".".join(parts), shard_id
with torch.no_grad():
for name, weight in weights:
remapped, shard_id = _remap_and_get_shard_id(name)
try:
param = model.get_parameter(remapped)
except AttributeError:
continue
weight_loader = getattr(param, "weight_loader", None)
if weight_loader is not None and shard_id is not None:
weight_loader(param, weight, loaded_shard_id=shard_id)
elif weight_loader is not None:
weight_loader(param, weight)
elif param.shape == weight.shape:
param.data.copy_(weight)
else:
if len(param.shape) == len(weight.shape):
for dim in range(len(param.shape)):
if param.shape[dim] != weight.shape[dim]:
shard_size = param.shape[dim]
param.data.copy_(weight.narrow(dim, tp_rank * shard_size, shard_size))
breakAdd TMS torch mode, training region, colocate-only PYTORCH_NPU_ALLOC_CONF override, CANN PYTHONPATH for Ray workers, and safe NPU empty_cache. Complements IPC weight sync for bridge colocate on 8x910B. Signed-off-by: kaiyuan <kyxiezju@163.com>
8c9e13f to
916af60
Compare
Add npu_attention_patch before mindspeed import so Megatron train-side attention uses npu_fusion_attention on Ascend, matching pr266 colocate. Signed-off-by: kaiyuan <kyxiezju@163.com>
916af60 to
3eb5436
Compare
Test on new vllm version
Result
|

Summary
PYTORCH_NPU_ALLOC_CONF=expandable_segments:False, CANNPYTHONPATH, safeempty_cache, and IPC weight sync viaUpdateWeightFromTensor+ bridge mode.NPUIPCWeightTransferEngineAPI for vLLM 0.22.vllm-ascend.patchworker changes to colocate only; remap HF packed modules (QKV, gate_up) in vime IPC path.Motivation
New NPU images set
PYTORCH_NPU_ALLOC_CONF=expandable_segments:Truein Dockerfile, which breaks TMS/CaMem colocate. Validated on 8×910B (Qwen3-4B, TP=2, bridge mode, smoke + long run). Code must explicitly override alloc conf;pop/""is insufficient.Key changes
actor_group.pyLD_PRELOAD/TMS_INIT_ENABLE); colocate alloc conf overriderollout.py,vllm_engine.pyexpandable_segments:Falsefor Ray workers and vLLM subprocessmemory_utils.py,actor.pyempty_cacheunder TMS custom allocatorupdate_weight_from_tensor.pyvllm.patchvllm-ascend.patchnpu_ipc_enginemodelargTMS .so note
Dockerfile.npubuildstorch_memory_saverfromsgl-kernel-npu@2026.6.0. The stock wheel.somay still be incompatible with NPU TMS torch mode until upstream fixesTMS_INIT_ENABLEhandling. Validated runs used known-good.soMD5 (03adebff…/83181484…). Track upstream or pin binaries if needed.Test plan
pre-commit runon changed filesvime_pr285_test(8×910B, num-rollout=1)docker/Dockerfile.npuand verify colocate without manual.soswap