fix(neighbors): delegate batched method selection to ops#106
fix(neighbors): delegate batched method selection to ops#106nikitafedik wants to merge 1 commit into
Conversation
Greptile SummaryFixes a correctness bug where
Important Files Changed
Reviews (3): Last reviewed commit: "fix(neighbors): delegate batched method ..." | Re-trigger Greptile |
| def test_compute_neighbors_multi_graph_isolation(self, device: str): | ||
| """compute_neighbors must not build neighbors across Batch graph boundaries.""" | ||
| from nvalchemi.neighbors import compute_neighbors | ||
|
|
||
| batch = _line_batch(device, n_graphs=4) | ||
| compute_neighbors(batch, cutoff=_CUTOFF, max_neighbors=16) | ||
|
|
||
| _assert_no_cross_graph_neighbors(batch) | ||
|
|
||
| def test_compute_neighbors_passes_explicit_batched_method( | ||
| self, device: str, monkeypatch: pytest.MonkeyPatch | ||
| ): | ||
| """Toolkit should not rely on implicit Toolkit-Ops method selection.""" | ||
| from nvalchemi import neighbors as neighbors_mod | ||
| from nvalchemi.neighbors import compute_neighbors | ||
|
|
||
| methods: list[str | None] = [] | ||
|
|
||
| def fake_neighbor_list(**kwargs): | ||
| methods.append(kwargs.get("method")) | ||
| kwargs["num_neighbors"].zero_() | ||
|
|
||
| monkeypatch.setattr(neighbors_mod, "neighbor_list", fake_neighbor_list) | ||
|
|
||
| batch = _line_batch(device, n_graphs=4) | ||
| compute_neighbors(batch, cutoff=_CUTOFF, max_neighbors=16) | ||
|
|
||
| assert methods == ["batch_naive"] |
There was a problem hiding this comment.
New compute_neighbors tests placed in wrong class
test_compute_neighbors_multi_graph_isolation and test_compute_neighbors_passes_explicit_batched_method are appended to TestAdaptiveK, which is focused on neighbor-count overflow and shrinkage behaviour. Both tests cover graph-boundary isolation and explicit method dispatch — neither exercises the adaptive-K machinery. These would be easier to discover in a dedicated class, e.g. TestComputeNeighbors.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| cutoff=cutoff, | ||
| cell=cell, | ||
| pbc=pbc, | ||
| method=_select_batched_neighbor_list_method(N, batch.num_graphs), |
There was a problem hiding this comment.
Method string recomputed on every overflow retry
_select_batched_neighbor_list_method(N, batch.num_graphs) is evaluated on each loop iteration, even though N and batch.num_graphs are invariant inside the while True block. Computing it once before the loop would make the intent clearer. Not a correctness issue.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
Reproduced against Ops main checkout: I put the ops-main source checkout first on Output: Script: #!/usr/bin/env python3
"""Reproduce cross-system neighbor edges with forced unbatched NL methods.
Run against Toolkit-Ops main by putting the source checkout first on PYTHONPATH,
for example:
WARP_CACHE_PATH=/tmp/warp-cache-batching-bug \
PYTHONPATH=/path/to/nvalchemi-toolkit-ops-main \
python repro_ops_main_neighbor_boundaries.py
"""
from __future__ import annotations
import os
os.environ.setdefault("WARP_CACHE_PATH", "/tmp/warp-cache-batching-bug")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/torch-cache-batching-bug")
os.environ.setdefault("TRITON_CACHE_DIR", "/tmp/triton-cache-batching-bug")
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", "/tmp/torchinductor-cache-batching-bug")
for cache_dir in (
os.environ["WARP_CACHE_PATH"],
os.environ["XDG_CACHE_HOME"],
os.environ["TRITON_CACHE_DIR"],
os.environ["TORCHINDUCTOR_CACHE_DIR"],
):
os.makedirs(cache_dir, exist_ok=True)
import torch
import nvalchemiops
from nvalchemiops.torch.neighbors import neighbor_list
def build_batched_waters(
n_systems: int = 4,
device: torch.device | str = "cpu",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[str]]:
"""Return identical water-like systems packed into one coordinate array."""
one_water = torch.tensor(
[
[0.0, 0.0, 0.0],
[0.968565, 0.0, 0.0],
[-0.242, 0.928, 0.0],
],
dtype=torch.float32,
device=device,
)
positions = one_water.repeat(n_systems, 1)
batch_idx = torch.repeat_interleave(
torch.arange(n_systems, dtype=torch.int32, device=device), 3
)
batch_ptr = torch.arange(
0, 3 * n_systems + 1, 3, dtype=torch.int32, device=device
)
atom_names = ["O0", "H1", "H2"] * n_systems
return positions, batch_idx, batch_ptr, atom_names
def count_cross_edges(
neighbor_matrix: torch.Tensor,
num_neighbors: torch.Tensor,
batch_idx: torch.Tensor,
atom_names: list[str],
) -> tuple[int, int, list[str]]:
total_edges = 0
cross_edges = 0
examples: list[str] = []
for src in range(neighbor_matrix.shape[0]):
src_system = int(batch_idx[src].item())
for dst in neighbor_matrix[src, : int(num_neighbors[src].item())].tolist():
dst = int(dst)
total_edges += 1
dst_system = int(batch_idx[dst].item())
if src_system != dst_system:
cross_edges += 1
if len(examples) < 5:
examples.append(
f"{src}->{dst} "
f"(system {src_system}->{dst_system}, "
f"{atom_names[src]}->{atom_names[dst]})"
)
return total_edges, cross_edges, examples
def run_case(method: str | None) -> tuple[str, int, int, list[str]]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
positions, batch_idx, batch_ptr, atom_names = build_batched_waters(device=device)
n_atoms = positions.shape[0]
neighbor_matrix = torch.full(
(n_atoms, 32), n_atoms, dtype=torch.int32, device=device
)
num_neighbors = torch.zeros(n_atoms, dtype=torch.int32, device=device)
neighbor_list(
positions=positions,
cutoff=1.2,
batch_idx=batch_idx,
batch_ptr=batch_ptr,
max_neighbors=32,
neighbor_matrix=neighbor_matrix,
num_neighbors=num_neighbors,
method=method,
)
total_edges, cross_edges, examples = count_cross_edges(
neighbor_matrix.cpu(),
num_neighbors.cpu(),
batch_idx.cpu(),
atom_names,
)
return str(method), total_edges, cross_edges, examples
def main() -> None:
print(f"nvalchemiops: {nvalchemiops.__file__}")
print(f"torch device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
print()
print(f"{'method':16s} {'total_edges':>11s} {'cross_edges':>11s} examples")
print("-" * 80)
for method in ("batch_naive", "naive", "batch_cell_list", "cell_list", None):
label, total_edges, cross_edges, examples = run_case(method)
print(f"{label:16s} {total_edges:11d} {cross_edges:11d} {examples}")
if __name__ == "__main__":
main() |
24e312b to
af576e7
Compare
Signed-off-by: Nikita Fedik <nfedik@nvidia.com>
af576e7 to
d3480a3
Compare
|
Updated after bot review: removed the unused _alloc_nl_kwargs arguments and the stale batch_ptr/plumbing comment in _alloc_staging_buffers. Re-ran ruff, diff-check, focused allocation tests, and the full test/hooks/test_neighbor_list_hook.py file: 123 passed. |
ALCHEMI Toolkit Pull Request
Description
Align Toolkit neighbor-list construction with current Toolkit-Ops dispatch:
compute_neighborsandNeighborListHooknow pass batched metadata(
batch_idx/batch_ptr) without forcing an explicit neighbor-list method.That lets Toolkit-Ops choose the correct batched strategy via
method=None.This pairs with an upstream Toolkit-Ops guard that rejects explicit
single-system methods such as
method="naive"ormethod="cell_list"whenbatched metadata is provided.
The observed failure mode is that forcing
method="naive"can connect atomsfrom different batched systems as neighbors. The model then treats those
cross-system edges like real neighbor interactions/messages.
Type of Change
Related Issues
Relates to neighbor-list batching reports where explicit unbatched Toolkit-Ops
methods can create cross-graph neighbor edges.
Changes Made
compute_neighborsandNeighborListHookto leavemethod=Nonewhen calling Toolkit-Ops with batched metadata.NeighborListHook; Toolkit-Ops chooses among geometry-dependent strategies at dispatch time.compute_neighborsand strengthenedNeighborListHookboundary assertions.CHANGELOG.md.Testing
make pytest)make lint)Ran locally:
Result:
123 passed, 16 warnings.Result:
All checks passed!Coverage added in
test/hooks/test_neighbor_list_hook.py:test_compute_neighbors_multi_graph_isolationverifies one-shot neighborconstruction does not create cross-graph neighbor entries.
test_multi_graph_isolationnow checks every validNeighborListHookneighbor entry for
src_system == dst_system.test_compute_neighbors_delegates_method_selectionverifiescompute_neighborsleavesmethod=None.TestAllocNlKwargsverifiesNeighborListHookdoes not forward stalealgorithm-specific scratch kwargs while Toolkit-Ops owns method selection.
Here "hooks" refers to Toolkit runtime hooks, not Git or CI hooks. CI exercises
them by running pytest; the tests instantiate and call
NeighborListHookdirectly.
Live H2O boundary probe against the patched Toolkit branch:
Checklist
Additional Notes
This PR intentionally does not change the Toolkit public API. It updates
Toolkit's own neighbor-list callers to use Toolkit-Ops' official batched
auto-dispatch path.
The paired upstream Toolkit-Ops PR should land first or alongside this one so
direct explicit misuse fails loudly instead of silently treating a concatenated
batch as one system.
The current
CONTRIBUTING.mdsays public direct contributions are not acceptedduring the initial public beta, and signed-off commits are required.