Refactor: split CheferInterpretable into AttentionInterpretable and GradientInterpretable #1166

Description

@jhnwu3

Background

CheferInterpretable (in pyhealth/interpret/api.py) currently bundles two distinct concerns into a single interface:

  1. Forward attention capture: set_attention_hooks stores attention weight tensors during the forward pass so they can be retrieved via get_attention_layers.
  2. Backward gradient capture: set_attention_hooks also registers backward hooks so that attention gradients are stored after .backward().

This bundling made sense when CheferRelevance was the only attention-based interpreter, since it needs both. But PR #1158 adds AttentionRollout (Abnar & Zuidema 2020), which is gradient-free: it needs only the forward attention maps and explicitly discards attn_grad in its loop (for attn_map, _ in layers:). Requiring backward hook registration for a gradient-free method is semantically wrong, even if it works in practice.

Proposed refactor

Split the interface into two levels:

AttentionInterpretable          ← forward attention capture only
    set_attention_hooks(enabled)    (captures attn maps, no backward hooks)
    get_attention_layers()          (returns attn maps; attn_grad may be None)
    get_relevance_tensor(R, **data)

GradientInterpretable(AttentionInterpretable)   ← extends with backward hooks
    set_attention_hooks(enabled)    (also registers backward hooks)
    get_attention_layers()          (attn_grad is always populated)

With this split:

  • AttentionRollout checks isinstance(model, AttentionInterpretable)
  • CheferRelevance checks isinstance(model, GradientInterpretable)
  • Keep CheferInterpretable = GradientInterpretable as a backwards-compatible alias
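
A minimal sketch of how the two-level interface could look in Python. This is illustrative only. The method names come from the proposal above, but the type aliases, the `supports_rollout` / `supports_chefer` helpers, and `ToyForwardOnlyModel` are hypothetical, and the real signatures in pyhealth/interpret/api.py may differ.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

# Simplified assumption: one (attn_map, attn_grad) pair per layer.
# attn_grad may be None when only forward capture is enabled.
AttnLayer = Tuple[Any, Optional[Any]]


class AttentionInterpretable(ABC):
    """Forward attention capture only."""

    @abstractmethod
    def set_attention_hooks(self, enabled: bool) -> None:
        """Capture attention maps; no backward hooks are registered."""

    @abstractmethod
    def get_attention_layers(self) -> List[AttnLayer]:
        """Return captured attention maps; attn_grad may be None."""

    @abstractmethod
    def get_relevance_tensor(self, R: Any, **data: Any) -> Any:
        """Reduce a relevance matrix R to the tensor the interpreter needs."""


class GradientInterpretable(AttentionInterpretable):
    """Same methods, stronger contract: set_attention_hooks also registers
    backward hooks, so attn_grad is populated after .backward()."""


# Backwards-compatible alias for existing callers.
CheferInterpretable = GradientInterpretable


def supports_rollout(model: Any) -> bool:
    # Hypothetical helper: the check AttentionRollout would perform.
    return isinstance(model, AttentionInterpretable)


def supports_chefer(model: Any) -> bool:
    # Hypothetical helper: the check CheferRelevance would perform.
    return isinstance(model, GradientInterpretable)


class ToyForwardOnlyModel(AttentionInterpretable):
    """Hypothetical stand-in for a model that only captures forward maps."""

    def __init__(self) -> None:
        self._enabled = False

    def set_attention_hooks(self, enabled: bool) -> None:
        self._enabled = enabled

    def get_attention_layers(self) -> List[AttnLayer]:
        # A placeholder string stands in for a real attention tensor.
        return [("attn_map_placeholder", None)] if self._enabled else []

    def get_relevance_tensor(self, R: Any, **data: Any) -> Any:
        return R
```

Under this sketch, a model that implements only AttentionInterpretable passes the rollout check and fails the Chefer check, while anything typed against CheferInterpretable keeps working through the alias.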

Impact

  • Models (Transformer, StageAttentionNet) would need to support a "forward-only hooks" mode in set_attention_hooks, so attention maps are captured without registering backward hooks. This would also remove the torch.no_grad() incompatibility currently documented for AttentionRollout.
  • No breaking changes to existing callers if the alias is kept.
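
To illustrate the "forward-only hooks" mode, here is a rough PyTorch sketch on a toy attention module. It is not the Transformer or StageAttentionNet implementation. `ToyAttention`, the `with_grad` parameter, and the attribute names are assumptions made for the example. The point it shows is that a tensor hook is only registered when gradients are requested, so forward-only capture runs under torch.no_grad().

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical single-head attention with optional gradient capture."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.qk = nn.Linear(dim, 2 * dim)
        self.capture = False       # store forward attention maps
        self.capture_grad = False  # additionally register a backward hook
        self.attn_map = None
        self.attn_grad = None

    def set_attention_hooks(self, enabled: bool, with_grad: bool = False) -> None:
        # with_grad is an assumed parameter, not part of the proposed API.
        self.capture = enabled
        self.capture_grad = enabled and with_grad
        self.attn_map = None
        self.attn_grad = None

    def _save_grad(self, grad: torch.Tensor) -> None:
        self.attn_grad = grad

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k = self.qk(x).chunk(2, dim=-1)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        attn = torch.softmax(scores, dim=-1)
        if self.capture:
            self.attn_map = attn.detach()
            if self.capture_grad:
                # Tensor.register_hook raises if attn does not require grad,
                # which is why this branch is skipped in forward-only mode.
                attn.register_hook(self._save_grad)
        return attn @ x


layer = ToyAttention(dim=8)
x = torch.randn(2, 4, 8)

# Forward-only mode: works under no_grad, attn_grad stays None.
layer.set_attention_hooks(True, with_grad=False)
with torch.no_grad():
    layer(x)
forward_only_map, forward_only_grad = layer.attn_map, layer.attn_grad

# Gradient mode: attn_grad is populated after backward().
layer.set_attention_hooks(True, with_grad=True)
layer(x).sum().backward()
grad_mode_map, grad_mode_grad = layer.attn_map, layer.attn_grad
```

In the proposed design the choice between the two modes would presumably follow from which interface the model implements rather than from a flag; the flag here just keeps the example to one class.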

Context

Surfaced during review of PR #1158. The duck-typing AttentionRollout uses today is a reasonable workaround until this refactor lands.
