Background
CheferInterpretable (in pyhealth/interpret/api.py) currently bundles two distinct concerns into a single interface:
- Forward attention capture: set_attention_hooks stores attention weight tensors during the forward pass so they can be retrieved via get_attention_layers.
- Backward gradient capture: set_attention_hooks also registers backward hooks so that attention gradients are stored after .backward().
This bundling made sense when CheferRelevance was the only attention-based interpreter, since it needs both. But PR #1158 adds AttentionRollout (Abnar & Zuidema 2020), which is gradient-free: it needs only the forward attention maps and explicitly discards attn_grad (it iterates with for attn_map, _ in layers:). Requiring backward hook registration for a gradient-free method is semantically wrong, even if it works in practice.
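To make the "gradient-free" point concrete, here is a minimal, dependency-free sketch of rollout as described by Abnar & Zuidema (2020): mix each attention map with the identity to account for residual connections, renormalize rows, and multiply the layers together. It consumes (attn_map, attn_grad) pairs of the kind get_attention_layers returns, but only ever reads attn_map. The helper names, the plain-list matrices, and the assumption of one head-averaged n x n map per layer are illustrative; this is not PyHealth's AttentionRollout code.

```python
from typing import List, Optional, Sequence, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product, so the sketch needs no dependencies."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def attention_rollout(layers: Sequence[Tuple[Matrix, Optional[object]]]) -> Matrix:
    """Rollout over (attn_map, attn_grad) pairs, first layer first.

    Assumes each attn_map is a single (head-averaged) n x n matrix.
    The gradient slot is never read, so it may be None.
    """
    rollout: Optional[Matrix] = None
    for attn_map, _ in layers:  # attn_grad is discarded, as in PR #1158
        n = len(attn_map)
        # Residual connection: 0.5 * A + 0.5 * I.
        mixed = [
            [0.5 * attn_map[i][j] + (0.5 if i == j else 0.0) for j in range(n)]
            for i in range(n)
        ]
        # Renormalize every row so it sums to 1.
        normalized = []
        for row in mixed:
            total = sum(row)
            normalized.append([v / total for v in row])
        # Compose with earlier layers: A_l @ rollout_(l-1).
        rollout = normalized if rollout is None else matmul(normalized, rollout)
    if rollout is None:
        raise ValueError("no attention layers were captured")
    return rollout
```

Because nothing in the loop touches the second element of each pair, a model that only captures forward attention maps (and reports attn_grad as None) is sufficient for this method.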
Proposed refactor
Split the interface into two levels:
AttentionInterpretable ← forward attention capture only
    set_attention_hooks(enabled)       (captures attn maps, no backward hooks)
    get_attention_layers()             (returns attn maps; attn_grad may be None)
    get_relevance_tensor(R, **data)
GradientInterpretable(AttentionInterpretable) ← extends with backward hooks
    set_attention_hooks(enabled)       (also registers backward hooks)
    get_attention_layers()             (attn_grad is always populated)
- AttentionRollout checks isinstance(model, AttentionInterpretable)
- CheferRelevance checks isinstance(model, GradientInterpretable)
- Keep CheferInterpretable = GradientInterpretable as a backwards-compatible alias
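A minimal sketch of the proposed split using Python's abc module, together with the two isinstance checks and the alias listed above. The method bodies, type hints, and error messages are assumptions, and the two interpreter classes are simplified stand-ins for the constructors' type checks only, not PyHealth's actual AttentionRollout or CheferRelevance.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple


class AttentionInterpretable(ABC):
    """Forward attention capture only (hypothetical sketch of api.py)."""

    @abstractmethod
    def set_attention_hooks(self, enabled: bool) -> None:
        """Capture attention maps on forward; register no backward hooks."""

    @abstractmethod
    def get_attention_layers(self) -> List[Tuple[Any, Optional[Any]]]:
        """Return (attn_map, attn_grad) pairs; attn_grad may be None."""

    @abstractmethod
    def get_relevance_tensor(self, R: Any, **data: Any) -> Any:
        """Map relevance R back onto the model inputs."""


class GradientInterpretable(AttentionInterpretable):
    """Contract extension: set_attention_hooks(True) must also register
    backward hooks, so attn_grad is populated after .backward()."""


# Backwards-compatible alias: existing subclasses keep working.
CheferInterpretable = GradientInterpretable


class AttentionRollout:
    """Stand-in showing only the constructor check (not PyHealth's class)."""

    def __init__(self, model: Any) -> None:
        # Gradient-free, so forward capture is enough.
        if not isinstance(model, AttentionInterpretable):
            raise TypeError("AttentionRollout requires an AttentionInterpretable model")
        self.model = model


class CheferRelevance:
    """Stand-in showing only the constructor check (not PyHealth's class)."""

    def __init__(self, model: Any) -> None:
        # Needs attention gradients, so require the stricter interface.
        if not isinstance(model, GradientInterpretable):
            raise TypeError("CheferRelevance requires a GradientInterpretable model")
        self.model = model
```

Since GradientInterpretable subclasses AttentionInterpretable, a gradient-capable model passes both checks, while a forward-only model is accepted by AttentionRollout and rejected by CheferRelevance. Models that subclass CheferInterpretable today resolve to GradientInterpretable through the alias, so their existing isinstance checks still succeed.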
Impact
- Models (Transformer, StageAttentionNet) would need to support a "forward-only hooks" mode in set_attention_hooks, so that attention maps are captured without registering backward hooks. This would also remove the torch.no_grad() incompatibility currently documented for AttentionRollout.
- No breaking changes to existing callers if the alias is kept.
Context
Surfaced during review of PR #1158. The duck typing that AttentionRollout uses today is a reasonable workaround until this refactor lands.