Skip to content

Resolve ewald cache shape issue#110

Merged
laserkelvin merged 4 commits into
NVIDIA:mainfrom
ys-teh:fix/ewald-pme-cache
Jun 15, 2026
Merged

Resolve ewald cache shape issue#110
laserkelvin merged 4 commits into
NVIDIA:mainfrom
ys-teh:fix/ewald-pme-cache

Conversation

@ys-teh

@ys-teh ys-teh commented Jun 9, 2026

Copy link
Copy Markdown
Collaborator

ALCHEMI Toolkit Pull Request

Description

The Ewald/PME wrappers used torch.allclose(cell, self._cached_cell) to detect cell changes for cache invalidation. This assumes the current and cached cell tensors have identical shape, dtype, and device. When the same wrapper instance sees a different batch size, for example between training and validation batches, cell.shape changes from (B1, 3, 3) to (B2, 3, 3). In that case torch.allclose raises instead of returning False, so the cache cannot be safely invalidated and model training is affected.

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Performance improvement
  • Documentation update
  • Refactoring (no functional changes)
  • CI/CD or infrastructure change

Related Issues

Changes Made

  • Implements a core cache invalidation fix for Ewald and PME cell caches.
    • Added _cell_cache_needs_update() to both EwaldModelWrapper and PMEModelWrapper.
    • The helper now treats missing cache, shape mismatch, device mismatch, dtype mismatch, or changed cell values as stale cache conditions.
    • Updated the forward paths to call this helper before recomputing Ewald/PME cache state.
    • Adds regression tests for both wrappers covering: missing cached cell, identical cached cell reuse, train/validation batch-size shape mismatch, same-shape changed cell values, dtype mismatch

Testing

  • Unit tests pass locally (make pytest)
  • Linting passes (make lint)
  • New tests added for new functionality meets coverage expectations?

Checklist

  • I have read and understand the Contributing Guidelines
  • I have updated the CHANGELOG.md
  • I have performed a self-review of my code
  • I have added docstrings to new functions/classes
  • I have updated the documentation (if applicable)

Additional Notes

Tip

This repository uses Greptile, an AI code review service, to help conduct
pull request reviews. We encourage contributors to read and consider suggestions
made by Greptile, but note that human maintainers will provide the necessary
reviews for merging: Greptile's comments are not a qualitative judgement
of your code, nor is it an indication that the PR will be accepted/rejected.
We encourage the use of emoji reactions to Greptile comments, depending on
their usefulness and accuracy.

@copy-pr-bot

copy-pr-bot Bot commented Jun 9, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ys-teh ys-teh marked this pull request as ready for review June 9, 2026 21:17
@greptile-apps

greptile-apps Bot commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR fixes a crash in the Ewald and PME wrappers where torch.allclose was called on cell tensors with mismatched shapes (e.g. different batch sizes between training and validation), causing an exception instead of cache invalidation. The fix extracts the comparison into a shared cell_cache_needs_update helper in _utils.py that explicitly guards shape, device, and dtype before delegating to torch.allclose.

  • Core fix (_utils.py): New cell_cache_needs_update handles all mismatch conditions robustly, but its default tolerances (rtol=1e-5, atol≈1e-6) are 10–1000× looser than the originals (rtol=1e-6, atol=1e-9), which may cause stale-cache usage in fine-grained NPT simulations.
  • Wrapper updates (ewald.py, pme.py): Both wrappers now delegate to the shared helper and expose rtol/atol as constructor parameters for user control.
  • Tests (test_base.py): Five regression tests cover the main fix; the device-mismatch branch is not yet tested.

Important Files Changed

Filename Overview
nvalchemi/models/_utils.py Adds cell_cache_needs_update helper that guards against shape/device/dtype mismatches before calling torch.allclose; default tolerances are 10–1000× looser than the originals (rtol=1e-6, atol=1e-9), which could cause stale-cache misses in NPT simulations.
nvalchemi/models/ewald.py Replaces inline torch.allclose call with cell_cache_needs_update; adds rtol/atol constructor parameters with correct wiring into forward path.
nvalchemi/models/pme.py Mirrors the same cache-invalidation fix as ewald.py; changes are symmetric and correct.
test/models/test_base.py Adds five unit tests for cell_cache_needs_update covering the key regression cases; device-mismatch path is not tested.

Reviews (4): Last reviewed commit: "update tolerances and break up OR statem..." | Re-trigger Greptile

Comment thread nvalchemi/models/ewald.py Outdated
@ys-teh ys-teh requested a review from dallasfoster June 9, 2026 22:07

@laserkelvin laserkelvin left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general looks good to me, just have some minor things to discuss

Comment thread nvalchemi/models/_utils.py Outdated


def cell_cache_needs_update(
cell: torch.Tensor,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind adding the appropriate jaxtyping shape annotations?

You might need to have them as separate hints to denote that the cached and the incoming cells can be different shapes

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, thanks for the reminder.

Comment thread nvalchemi/models/_utils.py Outdated
Comment thread nvalchemi/models/_utils.py Outdated
ys-teh added 3 commits June 15, 2026 21:09
Signed-off-by: Ying Shi Teh <yteh@nvidia.com>
Signed-off-by: Ying Shi Teh <yteh@nvidia.com>
Signed-off-by: Ying Shi Teh <yteh@nvidia.com>
@ys-teh ys-teh force-pushed the fix/ewald-pme-cache branch from 62734a6 to 347d601 Compare June 15, 2026 21:46
Signed-off-by: Ying Shi Teh <yteh@nvidia.com>
@ys-teh

ys-teh commented Jun 15, 2026

Copy link
Copy Markdown
Collaborator Author

Thanks, I addressed all the comments including allowing users to change the tolerances when initializing the wrappers. All related tests passed.

@laserkelvin laserkelvin left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@laserkelvin

Copy link
Copy Markdown
Collaborator

/ok to test b0af7b4

@laserkelvin laserkelvin added this pull request to the merge queue Jun 15, 2026
Merged via the queue into NVIDIA:main with commit 01c99d5 Jun 15, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants