diff --git a/python/triton_dist/function/nvidia/ep_moe_fused.py b/python/triton_dist/function/nvidia/ep_moe_fused.py index 87c214a02..42ccf6b5a 100644 --- a/python/triton_dist/function/nvidia/ep_moe_fused.py +++ b/python/triton_dist/function/nvidia/ep_moe_fused.py @@ -204,7 +204,7 @@ def backward(ctx, dy): triton_dist_ep_ctx = ctx.triton_dist_ep_ctx ep_a2a_layout_desc = triton_dist_ep_ctx.ep_a2a_layout_desc - assert triton_dist_ep_ctx.ep_group.size() == 8 # only for intra-node + assert triton_dist_ep_ctx.ep_group.size() <= 8 # only for intra-node optim_config = get_moe_optim_config(use_mega=True, is_forward=False) profile_config = get_triton_dist_moe_profile_enabled()