perf/accuracy: Flash Attention, torch-native SO(3), cosine schedule, DDIM, analytical g(t), acos fix #454

Open
mooreneural wants to merge 1 commit into RosettaCommons:main from mooreneural:main

@mooreneural

Scientific and performance improvements across the diffusion pipeline. No breaking changes — all new features are opt-in via existing config flags or new parameters with safe defaults.

Changes

Flash Attention (Attention_module.py)

  • Replaced hand-rolled einsum attention in Attention, AttentionWithBias, and MSAColAttention with F.scaled_dot_product_attention
  • On CUDA with PyTorch ≥ 2.0, this can dispatch automatically to Flash Attention, whose attention memory scales as O(L) in sequence length L rather than the O(L²) needed to materialize the full L×L attention matrix
  • Estimated 20–40% speedup on the attention-heavy MSA and pair tracks
  • AttentionWithBias passes its pairwise bias as attn_mask, fusing it into the kernel rather than a separate add
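
A minimal sketch of the equivalence this change relies on, assuming tensors laid out as (batch, heads, length, head_dim) and an additive bias of shape (batch, heads, length, length). The function names and shapes are illustrative and are not taken from Attention_module.py. Which fused backend PyTorch selects depends on the version, device, dtype, and whether a mask is supplied, so the sketch checks numerical agreement only, not speed or memory.

```python
import math

import torch
import torch.nn.functional as F


def einsum_attention(q, k, v, bias):
    """Reference attention that materializes the full L x L weight matrix."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    logits = torch.einsum("bhid,bhjd->bhij", q, k) * scale + bias
    weights = torch.softmax(logits, dim=-1)
    return torch.einsum("bhij,bhjd->bhid", weights, v)


def sdpa_attention(q, k, v, bias):
    """Same computation via SDPA; a float attn_mask is added to the logits."""
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)


torch.manual_seed(0)
B, H, L, D = 2, 4, 16, 8  # illustrative sizes
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
bias = torch.randn(B, H, L, L)

ref = einsum_attention(q, k, v, bias)
out = sdpa_attention(q, k, v, bias)
max_err = (ref - out).abs().max().item()
print(f"max abs difference: {max_err:.2e}")
```

A check of this kind on the real modules, with their actual layouts and dtypes, would help confirm that the swap preserves outputs.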

Torch-native SO(3) ops (igso3.py, diffusion.py, inference/utils.py)

  • Added hat_batch(), Log_torch(), Exp_torch() — Rodrigues-formula rotation ops that stay on-device
  • Eliminated all scipy_R / .cpu().numpy() roundtrips that fired at every denoising step in reverse_sample_vectorized(), diffuse_frames(), and get_next_frames()
  • Round-trip accuracy <1e-6 for angles in [0, π] (matches scipy's output for the same domain)
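
The on-device exponential and logarithm maps can be sketched with the Rodrigues formula as follows. This is an illustration under stated assumptions, not the PR's hat_batch(), Log_torch(), or Exp_torch(): the names hat, so3_exp, and so3_log are hypothetical, and this simple log map does not treat angles very close to π specially, so the demo stops at 3.0 rad.

```python
import torch


def hat(v):
    """Map vectors (..., 3) to skew-symmetric matrices (..., 3, 3)."""
    x, y, z = v.unbind(-1)
    o = torch.zeros_like(x)
    rows = torch.stack([o, -z, y, z, o, -x, -y, x, o], dim=-1)
    return rows.reshape(*v.shape[:-1], 3, 3)


def so3_exp(v, eps=1e-8):
    """Rotation vector -> rotation matrix via the Rodrigues formula."""
    theta = v.norm(dim=-1, keepdim=True).clamp_min(eps)[..., None]
    K = hat(v)
    eye = torch.eye(3, dtype=v.dtype, device=v.device)
    return eye + (torch.sin(theta) / theta) * K \
        + ((1 - torch.cos(theta)) / theta**2) * (K @ K)


def so3_log(R, eps=1e-8):
    """Rotation matrix -> rotation vector (not robust very near pi)."""
    trace = R.diagonal(dim1=-2, dim2=-1).sum(-1)
    theta = torch.acos(((trace - 1) / 2).clamp(-1.0, 1.0))
    # Half the antisymmetric part, read off as a vector: (sin(theta)/theta) * v
    w = torch.stack([R[..., 2, 1] - R[..., 1, 2],
                     R[..., 0, 2] - R[..., 2, 0],
                     R[..., 1, 0] - R[..., 0, 1]], dim=-1) / 2
    scale = torch.where(theta < 1e-6, torch.ones_like(theta),
                        theta / torch.sin(theta).clamp_min(eps))
    return scale[..., None] * w


torch.manual_seed(0)
axis = torch.nn.functional.normalize(
    torch.randn(32, 3, dtype=torch.float64), dim=-1)
angle = torch.linspace(0.0, 3.0, 32, dtype=torch.float64)
v = axis * angle[:, None]
R = so3_exp(v)
roundtrip_err = (so3_log(R) - v).abs().max().item()
orth_err = (R @ R.transpose(-1, -2)
            - torch.eye(3, dtype=torch.float64)).abs().max().item()
print(f"round-trip error: {roundtrip_err:.2e}, orthogonality error: {orth_err:.2e}")
```

Because every operation is a tensor op, inputs on a GPU stay on the GPU, which is the property the change above depends on.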

Cosine noise schedule (diffusion.py)

  • Added schedule_type="cosine" (Nichol & Dhariwal, 2021 — Improved DDPM)
  • Enable via diffuser.schedule_type=cosine; b0/bT are ignored for this mode
  • The signal-to-noise ratio decays more gradually at small t than with the linear schedule, which tends to produce more diverse samples
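
For reference, the cosine schedule from Nichol & Dhariwal (2021) can be sketched in a few lines. The function name, the discrete-step form, and the defaults s=0.008 and a 0.999 clip follow the paper; how this PR maps the schedule onto the Diffuser's own time grid is not shown here and is not assumed.

```python
import math


def cosine_beta_schedule(T, s=0.008, max_beta=0.999):
    """Return (betas, alpha_bar) for T steps of the cosine schedule."""

    def f(t):
        return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2

    # alpha_bar[t] is the cumulative signal fraction after t steps
    alpha_bar = [f(t) / f(0) for t in range(T + 1)]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, clipped for stability near t = T
    betas = [min(1 - alpha_bar[t] / alpha_bar[t - 1], max_beta)
             for t in range(1, T + 1)]
    return betas, alpha_bar


betas, alpha_bar = cosine_beta_schedule(50)
print(f"beta_1 = {betas[0]:.5f}, beta_T = {betas[-1]:.3f}")
```

No b0/bT endpoints appear in this parameterization, which is consistent with those settings being ignored in cosine mode.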

Analytical g(t) (diffusion.py)

  • For the linear σ schedule, replaced a per-step torch.autograd.grad call with the closed-form derivative: g(t) = sqrt(2·σ(t)·(min_b + t·(max_b − min_b)))
  • Falls back to autograd for the exponential schedule
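
The closed form follows from g(t)² = d(σ²)/dt = 2·σ(t)·σ′(t). The sketch below compares it against autograd, assuming σ(t) = min_sigma + t·min_b + ½·t²·(max_b − min_b), which is the form whose derivative matches the expression above. The function names and parameter values are illustrative, not read from diffusion.py.

```python
import torch


def sigma_linear(t, min_sigma, min_b, max_b):
    """Assumed linear-schedule sigma(t); d(sigma)/dt = min_b + t*(max_b - min_b)."""
    return min_sigma + t * min_b + 0.5 * t**2 * (max_b - min_b)


def g_autograd(t, min_sigma, min_b, max_b):
    """g(t) = sqrt(d sigma(t)^2 / dt), with the derivative taken by autograd."""
    t = t.clone().requires_grad_(True)
    sigma_sq = sigma_linear(t, min_sigma, min_b, max_b) ** 2
    (grad,) = torch.autograd.grad(sigma_sq.sum(), t)
    return torch.sqrt(grad)


def g_closed_form(t, min_sigma, min_b, max_b):
    """Same quantity from the chain rule: sqrt(2 * sigma * sigma')."""
    sigma = sigma_linear(t, min_sigma, min_b, max_b)
    return torch.sqrt(2 * sigma * (min_b + t * (max_b - min_b)))


params = dict(min_sigma=0.02, min_b=1.5, max_b=2.5)  # illustrative values
t = torch.linspace(0.01, 1.0, 20, dtype=torch.float64)
g_auto = g_autograd(t, **params)
g_exact = g_closed_form(t, **params)
g_max_diff = (g_auto - g_exact).abs().max().item()
print(f"max |autograd - closed form| = {g_max_diff:.2e}")
```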

IGSO3 module-level cache (diffusion.py)

  • Added _igso3_cache dict at module level
  • Prevents repeated disk deserialization of the precomputed CDF table when multiple Diffuser instances are created in the same process (e.g., batch inference scripts)

DDIM deterministic sampling (inference/utils.py)

  • Added get_mu_xt_x0_ddim() implementing the DDIM update rule (Song et al., 2021)
  • Wired ddim=True flag through Denoise.__init__() → get_next_pose() → get_next_ca()
  • Deterministic, lower-variance trajectories; enables fewer-step inference at equivalent quality
  • Default is False — no change to existing behavior
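
The deterministic DDIM update (η = 0) from Song et al. can be sketched as below. The helper ddim_step and its arguments are hypothetical: the signature of get_mu_xt_x0_ddim() is not shown in this description, and the cumulative signal fractions alpha_bar_t and alpha_bar_prev are assumed to come from whichever noise schedule is active.

```python
import torch


def ddim_step(x_t, x0_pred, alpha_bar_t, alpha_bar_prev):
    """Deterministic DDIM update (eta = 0) from step t to the previous step."""
    # Noise implied by the current sample and the predicted clean sample
    eps_pred = (x_t - alpha_bar_t**0.5 * x0_pred) / (1 - alpha_bar_t) ** 0.5
    # Re-noise the prediction to the previous level using that same noise
    return alpha_bar_prev**0.5 * x0_pred + (1 - alpha_bar_prev) ** 0.5 * eps_pred


torch.manual_seed(0)
x0 = torch.randn(10, 3, dtype=torch.float64)   # stand-in for CA coordinates
eps = torch.randn(10, 3, dtype=torch.float64)
a_t, a_prev = 0.5, 0.8                         # illustrative alpha_bar values

x_t = a_t**0.5 * x0 + (1 - a_t) ** 0.5 * eps
x_prev = ddim_step(x_t, x0, a_t, a_prev)

# With a perfect x0 prediction, the step lands on the same noise at level a_prev
expected = a_prev**0.5 * x0 + (1 - a_prev) ** 0.5 * eps
print(torch.allclose(x_prev, expected))
```

No random draw occurs inside the step, so repeated calls with the same inputs give the same trajectory.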

acos NaN clamp (kinematics.py)

  • torch.acos(vw) → torch.acos(torch.clamp(vw, -1.0, 1.0))
  • Prevents silent NaN propagation when float rounding pushes a normalized dot product just outside [−1, 1]
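
The failure mode and the fix can be reproduced in isolation. The helper angle_between below is a hypothetical stand-in for the relevant part of get_ang(), whose full body is not shown here.

```python
import torch


def angle_between(v, w, eps=1e-8):
    """Angle between vectors (..., 3), with the cosine clamped before acos."""
    v = v / (v.norm(dim=-1, keepdim=True) + eps)
    w = w / (w.norm(dim=-1, keepdim=True) + eps)
    cos = (v * w).sum(-1)
    return torch.acos(torch.clamp(cos, -1.0, 1.0))


# A cosine nudged just outside [-1, 1], as rounding error can produce
slightly_over = torch.tensor([1.000001, -1.000001])
raw = torch.acos(slightly_over)                           # NaN for both entries
safe = torch.acos(torch.clamp(slightly_over, -1.0, 1.0))  # 0 and pi

a = torch.tensor([[1.0, 0.0, 0.0]])
b = torch.tensor([[0.0, 2.0, 0.0]])
right_angle = angle_between(a, b)
print(raw, safe, right_angle)
```

Clamping addresses the forward NaN only. The derivative of acos is unbounded at ±1, so gradients through exactly parallel or antiparallel vectors may still need separate care.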

Notes for reviewers

  • The Flash Attention change requires PyTorch ≥ 2.0 for the fused kernel; it degrades gracefully to the standard implementation on older versions
  • DDIM is wired at the Denoise class level only — rotation (SO3) denoising still uses the stochastic IGSO3 reverse SDE; a full DDIM-on-SO3 implementation would be a follow-on
  • The legacy Log() and Exp() functions in igso3.py are kept for backward compatibility (they're used in offline precomputation, not inference)

…IM, analytical g(t)

Attention (Attention_module.py):
- Replace hand-rolled einsum attention with F.scaled_dot_product_attention in
  Attention, AttentionWithBias, and MSAColAttention. Uses Flash Attention
  automatically when available on CUDA (estimated 20-40% speedup, O(L) instead
  of O(L^2) attention memory).
- AttentionWithBias passes the pairwise bias as attn_mask so it is folded into
  the fused kernel rather than materializing a separate attention matrix.

SO3 diffusion (igso3.py, diffusion.py, inference/utils.py):
- Add hat_batch(), Log_torch(), Exp_torch() -- on-device rotation ops using
  the Rodrigues formula. Eliminates all scipy CPU round-trips during inference.
- Replace scipy_R calls in reverse_sample_vectorized() and diffuse_frames() with
  the new torch-native equivalents (stay on GPU, no .cpu()/.numpy() transfers).
- Remove redundant scipy rotation normalization in get_next_frames(); rotation
  matrices from rigid_from_3_points are already orthogonal.

Noise schedule (diffusion.py):
- Add cosine schedule (Nichol & Dhariwal, 2021). Enabled via
  schedule_type="cosine"; b0/bT are ignored for this mode.
- Analytical g(t) for linear schedule: eliminates a per-step autograd call.
  Formula: g(t) = sqrt(2 * sigma(t) * (min_b + t*(max_b - min_b))).

IGSO3 cache (diffusion.py):
- Add module-level _igso3_cache dict. Avoids repeated disk deserialization when
  multiple Diffuser objects are created in the same process (batch inference).

DDIM sampling (inference/utils.py):
- Add get_mu_xt_x0_ddim() implementing the deterministic DDIM update rule.
- Wire ddim=True flag through Denoise.__init__() -> get_next_pose() -> get_next_ca().
  Setting ddim=True produces deterministic, lower-variance trajectories and
  enables fewer-step inference at equivalent quality.

Numerical stability (kinematics.py):
- Clamp input to acos in get_ang() to [-1, 1] to prevent NaN from float
  rounding at exactly +/-1.
@mooreneural
Author

mooreneural commented May 19, 2026

Thanks for running the CI check. This failure is not introduced by my PR.

The test environment resolves dgl==2.1.0 + torchdata==0.11.0, which are incompatible. DGL 2.1.0 internally imports torchdata.datapipes, but that submodule was deprecated in torchdata 0.8 and removed in 0.10.0. The crash originates in util_module.py → import dgl, which is unrelated to the five files changed in this PR (Attention_module.py, diffusion.py, igso3.py, inference/utils.py, kinematics.py).

The identical failure would occur on the unmodified upstream main branch under these package versions.

To fix the CI environment (outside the scope of this PR):

Option A: Pin torchdata to a release that still ships datapipes

pip install "torchdata<0.10"

Option B: Upgrade DGL to a version that no longer imports datapipes

pip install "dgl>=2.4"

Happy to help investigate further if useful.

@rclune requested a review from woodsh17 on May 20, 2026, 08:42