Fix backward regression in PR #5712 #5825

Open
avbokovoy wants to merge 4 commits into
pytorch:mainfrom
ROCm:abokovoi/pr_5712_backward_regression
@avbokovoy avbokovoy commented Jun 3, 2026

Contributor

Guarded the short-segment cases so that they skip the optimized kernel launch. New warp_per_row kernel dispatch logic (FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL=1):

if (vbe || mixed_D) && num_unique_prev > 0 && total_L <= 2*num_unique_prev
  → launch hip_mixed_d_warp (introduced in the PR)
else if !mixed_D && D in [64, 128, 160, 192, 256, 320] && weights_dtype == output_dtype && weights_on_HBM
  → launch hip_warp (introduced previously - no change for this PR)
else
  → regular split_warp kernel
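
For reference, the three-way dispatch above can be sketched as a small host-side selector. This is only an illustration of the listed conditions, not the actual FBGEMM code: the `BackwardKernel` enum, the `DispatchInputs` struct, the `same_dtype` field, and the `select_warp_per_row_kernel` function are all hypothetical names introduced here.

```cpp
#include <algorithm>
#include <array>
#include <cstdint>

// Hypothetical names for illustration only; none of these are FBGEMM APIs.
enum class BackwardKernel { HipMixedDWarp, HipWarp, SplitWarp };

struct DispatchInputs {
  bool vbe;                 // variable batch size embeddings enabled
  bool mixed_D;             // tables have different embedding dimensions
  int64_t num_unique_prev;  // as named in the PR description
  int64_t total_L;          // as named in the PR description
  int32_t D;                // embedding dimension (used when !mixed_D)
  bool same_dtype;          // assumed flag: weights_dtype == output_dtype
  bool weights_on_HBM;      // weights reside in device memory
};

inline BackwardKernel select_warp_per_row_kernel(const DispatchInputs& in) {
  // Dimensions listed in the PR description for the hip_warp path.
  constexpr std::array<int32_t, 6> kSupportedD = {64, 128, 160, 192, 256, 320};

  // Branch 1: VBE or mixed-D, guarded by the total_L / num_unique_prev check.
  if ((in.vbe || in.mixed_D) && in.num_unique_prev > 0 &&
      in.total_L <= 2 * in.num_unique_prev) {
    return BackwardKernel::HipMixedDWarp;
  }

  // Branch 2: uniform D from the supported list, matching dtypes, HBM weights.
  const bool d_supported =
      std::find(kSupportedD.begin(), kSupportedD.end(), in.D) !=
      kSupportedD.end();
  if (!in.mixed_D && d_supported && in.same_dtype && in.weights_on_HBM) {
    return BackwardKernel::HipWarp;
  }

  // Branch 3: everything else uses the regular split_warp kernel.
  return BackwardKernel::SplitWarp;
}
```

Read literally, a VBE case with uniform D that fails the first guard can still reach hip_warp through the second branch, while a mixed-D case that fails the guard always falls through to split_warp.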

@meta-cla meta-cla Bot added the cla signed label Jun 3, 2026
spcyppt commented Jun 3, 2026

Contributor

Could you help clarify which cases will use this hip_mixed_d kernel? Please also put it in the PR summary.

From previous discussion with Bernard, this is my understanding:

if vbe
  → launch hip_mixed_d_warp kernel (introduced in the PR)
else
  if mixed D
    → launch hip_mixed_d_warp (introduced in the PR)
  else
    if D in [64, 128, 160, 192, 256, 320] && weights_dtype == output_dtype && weights_on_HBM
      → launch hip_warp kernel (introduced previously - no change for this PR)
    else
      → regular split_warp kernel

could you modify the above based on the new push?

cc: @avbokovoy @liligwu

avbokovoy commented Jun 4, 2026

Contributor Author

@spcyppt The dispatch logic after this PR should look like this:

if (vbe || mixed_D) && num_unique_prev > 0 && total_L <= 2*num_unique_prev
  → launch hip_mixed_d_warp (introduced in the PR)
else if !mixed_D && D in [64, 128, 160, 192, 256, 320] && weights_dtype == output_dtype && weights_on_HBM
  → launch hip_warp (introduced previously - no change for this PR)
else
  → regular split_warp kernel

Let me know if you have other questions

cc: @liligwu

spcyppt commented Jun 7, 2026

Contributor

@avbokovoy could you rebase this? There's a conflict and I could not import the PR. I don't have permission to resolve conflicts.

Bernard-Liu and others added 4 commits June 9, 2026 09:22
Summary:
CTA kernel optimizations from D86108064 split out:
- Enable subwarp shuffle for CTA kernel on ROCm
- Adaptive work group sizing based on total_L/total_B ratio
- Small-D template parameter override (max_D <= 128) with kFixedMaxVecsPerThread=1, kThreadGroupSize=32, kUseVecBlocking=false for tighter compiler codegen
- PROCESS_BLOCK macro for grad accumulation loop unrolling in compute_grad_sum (shared by CTA and warp kernels)
- subwarp_reduce_add for ROCm GROUP_REDUCE_ALL_SUM (shared by CTA and warp kernels)
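
The implementation of subwarp_reduce_add is not shown in this thread. As a rough illustration of what an all-sum reduction over a sub-warp does, here is a host-side simulation of a generic XOR-shuffle (butterfly) reduction. The function name `simulated_subwarp_reduce_add` and its parameters are hypothetical, and the loop stands in for the cross-lane shuffle a GPU would perform in hardware.

```cpp
#include <cstddef>
#include <vector>

// Host-side simulation (not GPU code) of a butterfly all-sum reduction.
// Each entry of `lanes` models the register value held by one lane.
// Assumptions: group_size is a power of two and lanes.size() is a
// multiple of group_size, so `lane ^ offset` stays inside the same group.
inline std::vector<float> simulated_subwarp_reduce_add(
    std::vector<float> lanes, std::size_t group_size) {
  for (std::size_t offset = group_size / 2; offset > 0; offset /= 2) {
    // Snapshot the values: on a GPU all lanes exchange simultaneously.
    const std::vector<float> prev = lanes;
    for (std::size_t lane = 0; lane < lanes.size(); ++lane) {
      // Each lane adds the value held by its XOR partner.
      lanes[lane] = prev[lane] + prev[lane ^ offset];
    }
  }
  // After log2(group_size) steps, every lane holds its group's total.
  return lanes;
}
```

With eight lanes holding 1 through 8 and a group size of 4, the first four lanes each end with 10 and the last four with 26, so every lane in a group sees the full sum without a separate broadcast step.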

This diff should not impact any performance on NVIDIA GPUs.

Differential Revision: D102946299

Pulled By: spcyppt

Summary:
Warp kernel optimizations from D86108064 split out:
- New hip_mixed_d_warp kernel for mixed-D and VBE backward on ROCm
- Batch preloading of per-row metadata via warp shuffle
- Momentum value preloading (split_precomputation_preload) to eliminate separate global memory reads
- Overloaded table_update_kernel accepting pre-resolved placement/offset/optimizer state
- AMD __builtin_amdgcn_readlane for efficient broadcasting
- Small-D template override (max_D <= 128) for warp kernel

This diff should not impact any performance on NVIDIA GPUs.


Differential Revision: D102946325

Pulled By: spcyppt
@avbokovoy avbokovoy force-pushed the abokovoi/pr_5712_backward_regression branch from 1445a35 to 5ff55c7 Compare June 9, 2026 10:25
Contributor Author

> @avbokovoy could you rebase this? There's a conflict and I could not import the PR. I don't have permission to resolve conflicts.

@spcyppt Rebase is done. Apologies for the delay

meta-codesync Bot commented Jun 9, 2026

Contributor

@spcyppt has imported this pull request. If you are a Meta employee, you can view this in D107407381.

3 participants