[parallel] fix: NPU Hang/Deadlock during DTensor parameter loading in FSDP2#642
First-Frost-code wants to merge 4 commits into
Code Review
This pull request updates the _dispatch_parameter function in veomni/models/module_utils.py to include a specialized handling path for NPU devices when a parallel plan is active. The change manually implements sharding for DTensor placements on NPUs. Feedback indicates that the current implementation moves the full tensor to the NPU before sharding, which could cause Out-Of-Memory (OOM) errors for large parameters; it is recommended to perform sharding on the CPU first. Additionally, a safety check should be added to verify that the tensor dimensions are large enough for the requested sharding to prevent runtime errors.
    local_device = orig_tensor.device
    tensor = tensor.to(device=local_device, dtype=orig_tensor.dtype)
    target_local = orig_tensor.to_local()

    for mesh_dim, p in enumerate(placements):
        if p.__class__.__name__ == "Shard":
            shard_dim = p.dim
            my_mesh_rank = device_mesh.get_coordinate()[mesh_dim]
            world_size = device_mesh.size(mesh_dim)

            shards = tensor.chunk(world_size, dim=shard_dim)
            tensor = shards[my_mesh_rank].contiguous()
The current implementation moves the entire tensor (or the EP-sharded version) to the NPU before performing the FSDP sharding. For large parameters (e.g., in MoE models), this can lead to unnecessary NPU memory spikes and potential OOM. It is more efficient to perform the sharding on the CPU first and only move the resulting local shard to the NPU. Additionally, it's safer to check that the tensor is large enough for the requested sharding: torch.chunk returns fewer chunks than requested when the dimension size is smaller than the mesh size, so indexing shards[my_mesh_rank] would raise an IndexError.
    for mesh_dim, p in enumerate(placements):
        if p.__class__.__name__ == "Shard":
            shard_dim = p.dim
            my_mesh_rank = device_mesh.get_coordinate()[mesh_dim]
            world_size = device_mesh.size(mesh_dim)
            if tensor.size(shard_dim) < world_size:
                raise ValueError(
                    f"Tensor size {tensor.size(shard_dim)} on dim {shard_dim} is too small "
                    f"for world_size {world_size} on mesh dimension {mesh_dim} "
                    f"for parameter {full_param_name}"
                )
            shards = tensor.chunk(world_size, dim=shard_dim)
            tensor = shards[my_mesh_rank]
    tensor = tensor.to(device=orig_tensor.device, dtype=orig_tensor.dtype).contiguous()
    target_local = orig_tensor.to_local()

…arameter dispatch
Bypass implicit redistribute collective sync in DTensor copy by replacing dtensor_factory with manual chunking and local physical copy. This resolves critical hang issues on Ascend NPUs when dispatching replicated or sharded weights like MoE gates.
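The size check discussed in the review can be explored with a small pure-Python emulation of how torch.chunk splits one dimension. This is only a sketch: chunk_sizes and every_rank_gets_a_chunk are hypothetical helper names, not part of PyTorch or this PR, and the emulation assumes the documented behavior that each chunk has ceil(dim_size / chunks) elements except possibly the last. It suggests that a dimension can be at least as large as the mesh size and still yield fewer chunks than ranks, so a divisibility or chunk-count check may be a safer guard than a plain size comparison.

```python
import math


def chunk_sizes(dim_size: int, chunks: int) -> list[int]:
    """Emulate the piece sizes torch.chunk produces along one dimension.

    Hypothetical helper for illustration; assumes dim_size > 0 and chunks > 0.
    """
    if dim_size <= 0 or chunks <= 0:
        raise ValueError("dim_size and chunks must be positive")
    step = math.ceil(dim_size / chunks)  # size of every chunk except maybe the last
    sizes = []
    remaining = dim_size
    while remaining > 0:
        take = min(step, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


def every_rank_gets_a_chunk(dim_size: int, world_size: int) -> bool:
    """True only if chunking yields exactly one piece per rank."""
    return len(chunk_sizes(dim_size, world_size)) == world_size


# A dimension of 5 split 4 ways gives only 3 pieces, so rank 3 has nothing to index.
print(chunk_sizes(5, 4), every_rank_gets_a_chunk(5, 4))
# An evenly divisible dimension gives one piece per rank.
print(chunk_sizes(16, 4), every_rank_gets_a_chunk(16, 4))
```

In this emulation, a guard such as len(shards) == world_size after chunking would cover both the too-small case and the uneven case.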
9a476f0 to 3ff6669
Thanks for your PR. Can you show the deadlock case (model, size, training args)? In fsdp2 we load on rank0 and broadcast to the other ranks to avoid OOM. This way, OOM only happens if a single tensor is larger than the device's maximum memory. I also notice that the fixed code is fsdp1, but the title says fsdp2.
Thanks for pointing that out! You brought up excellent points. Let me clarify the context and the root cause:
Model: Qwen3-VL-30B-A3B-Instruct
Hardware: 16-card Ascend NPU cluster (e.g., 910C with HCCL)
Why the hang happens when EP > 1:
Thanks again for the rigorous review! Please let me know if any further adjustments are needed.
    else:
        # Default execution path for GPUs or non-EP scenarios
        tensor =
Sry, it's my mistake only fsdp2 use
My apologies! 😅 My code editor completely messed up the copy-paste during my last commit, which accidentally truncated the GPU execution path and broke the syntax. I just pulled the branch back and pushed a clean fix to restore the logic. Thanks for catching that!

Thanks for double-checking and confirming that it is indeed the FSDP2 routing. I have updated the PR title back to FSDP2 to keep it accurate.

Regarding PR #648, I have reviewed its code and logic carefully. It does NOT fix my issue, and this PR (#642) is still necessary. Here is why:

Different Execution Paths: PR #648 explicitly targets and fixes the load_model_weights (all-ranks-read) path. As the author of #648 noted in their PR description: "The rank0_load_and_broadcast_weights path keeps src_data_rank=0 — it legitimately needs scatter since only rank 0 reads."

My Trigger Condition: In my training setup (loading HF format initially), the framework falls back to the broadcast path. My training logs explicitly show:
>> Loading model weights from disk on rank0 then broadcasting to other ranks...

The Deadlock Remains: Because PR #648 leaves the rank0_load_and_broadcast_weights path unchanged, my setup still executes the old logic. It still passes the tensor to dtensor_factory, which triggers the implicit Redistribute collective via PyTorch's __torch_dispatch__. This uncoordinated implicit collective is what disrupts the HCCL streams and causes the permanent hang on Ascend NPUs.

Conclusion: PR #648 is a fantastic fix for the direct-read scenario, but this PR (#642) acts as the safety net (physical bypass) for the broadcast scenario on NPUs. The two PRs complement each other.

Let me know if you need any more logs or tests from my side!
Describe the bug
When training VLM/MoE models on certain hardware backends (especially Ascend NPUs with HCCL), the program frequently hangs indefinitely during the weight loading phase.
Root Cause
The original _dispatch_parameter uses dtensor_factory and then calls .copy_() on a flat parameter. When assigning a DTensor view to a flat parameter, PyTorch's __torch_dispatch__ implicitly triggers a Redistribute collective communication. This uncoordinated implicit network communication disrupts the backend streams on NPU, causing permanent deadlocks.
Proposed Solution
This PR adds a defensive feature toggle for NPU environments to bypass implicit DTensor network communications:
… .to_local()
… .placements to perform pure mathematical .chunk(), assigning the correct shard based on the local mesh coordinate.
… .copy_().
This "physical bypass" effectively decouples the weight assignment from the distributed communication graph, completely resolving the deadlock while leaving the default GPU execution path untouched.
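The index arithmetic behind this kind of manual sharding can be sketched without any distributed runtime. The following is a minimal pure-Python illustration, not the PR's code: Shard, Replicate, and local_bounds are hypothetical stand-ins for the DTensor placements and for the chunk-by-mesh-coordinate loop, and the sketch assumes every sharded dimension divides evenly by the mesh size. It walks the placements in mesh-dimension order and narrows the (start, stop) range of the sharded tensor dimension to the piece owned by the local coordinate, using no collective communication.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Hypothetical stand-in for a DTensor Shard placement on tensor dim `dim`."""
    dim: int


@dataclass(frozen=True)
class Replicate:
    """Hypothetical stand-in for a DTensor Replicate placement."""


def local_bounds(shape, placements, mesh_shape, coordinate):
    """Return per-dimension (start, stop) ranges of the local shard.

    Assumes even divisibility; raises ValueError otherwise.
    """
    bounds = [(0, n) for n in shape]
    for mesh_dim, placement in enumerate(placements):
        if not isinstance(placement, Shard):
            continue  # replicated along this mesh dimension: nothing to slice
        start, stop = bounds[placement.dim]
        length = stop - start
        world_size = mesh_shape[mesh_dim]
        if length % world_size != 0:
            raise ValueError(
                f"dim {placement.dim} of length {length} is not divisible "
                f"by mesh size {world_size} on mesh dimension {mesh_dim}"
            )
        step = length // world_size
        rank = coordinate[mesh_dim]
        bounds[placement.dim] = (start + rank * step, start + (rank + 1) * step)
    return bounds


# An 8x6 weight on a 2x2 mesh, sharded twice along dim 0; local coordinate (1, 0).
print(local_bounds((8, 6), (Shard(0), Shard(0)), (2, 2), (1, 0)))
```

Because each rank derives its slice purely from its own coordinate, the subsequent copy into the local tensor needs no coordination with other ranks, which is the property the description relies on.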