Fix backward regression in PR #5712#5825
Conversation
|
could you help clarify what cases will use this hip_mixed_d kernel? Please also put it in the PR summary. From previous discussion with Bernard, this is my understanding: could you modify the above based on the new push? cc: @avbokovoy @liligwu |
|
@spcyppt The dispatch logic after this PR should look like this: Let me know if you have other questions cc: @liligwu |
|
@avbokovoy could you rebase this? There's a conflict and I could not import the PR. I don't have permission to resolve conflicts. |
Summary: CTA kernel optimizations from D86108064 split out: - Enable subwarp shuffle for CTA kernel on ROCm - Adaptive work group sizing based on total_L/total_B ratio - Small-D template parameter override (max_D <= 128) with kFixedMaxVecsPerThread=1, kThreadGroupSize=32, kUseVecBlocking=false for tighter compiler codegen - PROCESS_BLOCK macro for grad accumulation loop unrolling in compute_grad_sum (shared by CTA and warp kernels) - subwarp_reduce_add for ROCm GROUP_REDUCE_ALL_SUM (shared by CTA and warp kernels) This diff should not impact any performance on NVIDIA GPUs. Differential Revision: D102946299 Pulled By: spcyppt
Summary: Warp kernel optimizations from D86108064 split out: - New hip_mixed_d_warp kernel for mixed-D and VBE backward on ROCm - Batch preloading of per-row metadata via warp shuffle - Momentum value preloading (split_precomputation_preload) to eliminate separate global memory reads - Overloaded table_update_kernel accepting pre-resolved placement/offset/optimizer state - AMD __builtin_amdgcn_readlane for efficient broadcasting - Small-D template override (max_D <= 128) for warp kernel This diff should not impact any performance on NVIDIA GPUs. Differential Revision: D102946325 Pulled By: spcyppt
1445a35 to
5ff55c7
Compare
@spcyppt Rebase is done. Apologies for the delay |
|
@spcyppt has imported this pull request. If you are a Meta employee, you can view this in D107407381. |
Guarded short segments cases to skip optimized kernel launch. New warp_per_row kernel dispatch logic (
FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL=1):