diff --git a/fla/ops/common/chunk_delta_h.py b/fla/ops/common/chunk_delta_h.py index 63de5d54fc..deeaa87e34 100644 --- a/fla/ops/common/chunk_delta_h.py +++ b/fla/ops/common/chunk_delta_h.py @@ -13,7 +13,7 @@ from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from fla.ops.utils.cache import fla_cache_autotune from fla.ops.utils.op import exp2 -from fla.utils import IS_NVIDIA_HOPPER, USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem +from fla.utils import IS_NVIDIA_BLACKWELL, IS_NVIDIA_HOPPER, USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem NUM_WARPS = [2, 4] if IS_NVIDIA_HOPPER else [2, 4, 8, 16] @@ -30,7 +30,7 @@ configs=[ triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) for num_warps in [2, 4] - for num_stages in ([2, 3, 4] if check_shared_mem('ampere') else [2, 1]) + for num_stages in ([4] if IS_NVIDIA_BLACKWELL else [2, 3, 4] if check_shared_mem('ampere') else [2, 1]) for BV in ([32, 64] if check_shared_mem('ada') else [32]) ], key=['H', 'HV', 'K', 'V', 'BT', 'STATE_V_FIRST'], @@ -305,7 +305,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( configs=[ triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) for num_warps in [2, 4] - for num_stages in ([2, 3, 4] if check_shared_mem('ampere') else [1]) + for num_stages in ([4] if IS_NVIDIA_BLACKWELL else [2, 3, 4] if check_shared_mem('ampere') else [1]) for BV in ([32, 64] if check_shared_mem('ada') else [32]) ], key=['H', 'HV', 'K', 'V', 'BT', 'BV', 'USE_G', 'STATE_V_FIRST'],