value_dist_injection with Filter #4253

Open
irobert0126 wants to merge 1 commit into meta-pytorch:main from irobert0126:export-D104766359

Summary:
Make `_register_param_grad_hook` robust to sharded / unhookable params, and add a diagnostic print so we can see *which* param the hook actually landed on.

`param.register_hook` only fires through `AccumulateGrad`. Sharded embedding params (`ShardedTensor` / `DTensor`) bypass `AccumulateGrad` via FBGEMM TBE fused backward, so a hook attached to one silently never fires — the value-dist diagnostic prints nothing and the injection point looks broken.

## What changes

1. **Index over *all* named params, not just hookable ones.** `hook_position` now maps to a stable index regardless of how many params happen to be hookable (which depends on shard topology). Previously the same `hook_position=0.5` could land on different params across runs.
2. **Outward fallback.** If the param at the target index is unhookable, walk outward (forward first at each ring, then backward) to the nearest hookable neighbor.
3. **Assert** if no param under the FQN is hookable at all. The previous check raised `ValueError` only over the filtered params; the new one also covers the all-sharded case.
4. **Diagnostic print** of `requested_idx`, `chosen_idx`, fqn, type, and `is_leaf` so it is obvious in logs which param the hook was attached to.
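
To illustrate point 1, here is a minimal pure-Python sketch of why indexing over the filtered (hookable-only) list is unstable across shard topologies while indexing over all named params is not. The helper names, the param names, and the assumption that the old code applied the same `int(position * n)` formula to the filtered list are hypothetical and not taken from the diff.

```python
from typing import List

# Hypothetical param names under one FQN, in declaration order.
NAMES = ["emb_a", "emb_b", "w_a", "w_b", "bias"]


def index_over_hookable(position: float, hookable: List[bool]) -> int:
    """Assumed old behavior: index into the list of hookable params only."""
    candidates = [i for i, ok in enumerate(hookable) if ok]
    k = min(int(position * len(candidates)), len(candidates) - 1)
    return candidates[k]


def index_over_all(position: float, hookable: List[bool]) -> int:
    """New behavior (target index only): index into all named params."""
    n = len(hookable)
    return min(int(position * n), n - 1)


# Two shard topologies for the same module: embeddings sharded vs. not.
sharded = [False, False, True, True, True]
unsharded = [True, True, True, True, True]

old = [NAMES[index_over_hookable(0.5, t)] for t in (sharded, unsharded)]
new = [NAMES[index_over_all(0.5, t)] for t in (sharded, unsharded)]
print("filtered index:", old)  # differs between topologies
print("full index:    ", new)  # same param in both
```

With the full index, `hook_position=0.5` always targets the same declaration slot; the fallback walk only matters when that slot is unhookable.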

## How `hook_position` resolves to a param

```
named_parameters under site.fqn  (declaration order, n = 5)

  idx        0       1       2       3       4
  param    [emb]   [w_a]   [w_b]   [emb]   [bias]
  hookable   ✗       ✓       ✓       ✗       ✓
             │                       │
             └─ ShardedTensor ───────┘
                bypasses AccumulateGrad via FBGEMM TBE fused backward
                → register_hook silently never fires

position → target_idx = min(int(position * n), n - 1)  → outward walk → chosen_idx

  0.00   →    0    →   0 ✗ → +1 ✓                       →   1
  0.30   →    1    →   1 ✓                              →   1
  0.50   →    2    →   2 ✓                              →   2
  0.60   →    3    →   3 ✗ → +1 ✓                       →   4
  1.00   →    4    →   4 ✓                              →   4
```

Outward walk at offset *k*: try `target+k` first, then `target-k`. So forward neighbors win at each ring, but the search expands symmetrically until it finds a hookable param or exhausts the module.
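
The resolution rule can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the code in the diff: `resolve_hook_index` is a hypothetical name, hookability is passed in as a list of booleans, and the clamp to `n - 1` is inferred from the `1.00 → 4` row of the table above.

```python
from typing import List


def resolve_hook_index(position: float, hookable: List[bool]) -> int:
    """Map a fractional position to the nearest hookable param index."""
    n = len(hookable)
    assert n > 0, "no params under the FQN"
    assert any(hookable), "no hookable param under the FQN"

    # Assumption: clamp so that position == 1.0 maps to the last param.
    target = min(int(position * n), n - 1)

    # Ring k: try target + k first, then target - k.
    for k in range(n):
        for idx in (target + k, target - k):
            if 0 <= idx < n and hookable[idx]:
                return idx
    raise AssertionError("unreachable: any(hookable) was checked above")


# The five-param example from the table above.
hookable = [False, True, True, False, True]
chosen = [resolve_hook_index(p, hookable) for p in (0.0, 0.3, 0.5, 0.6, 1.0)]
print(chosen)  # [1, 1, 2, 4, 4]
```

Because the forward neighbor is tried first at each ring, a tie between `target+k` and `target-k` resolves to the later param in declaration order.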

Differential Revision: D104766359

meta-codesync Bot commented May 13, 2026

@irobert0126 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104766359.

meta-cla Bot added the CLA Signed label May 13, 2026

Labels

CLA Signed, fb-exported, meta-exported
