diff --git a/.claude/skills/README.md b/.claude/skills/README.md index f39a6399..1d820b0a 100644 --- a/.claude/skills/README.md +++ b/.claude/skills/README.md @@ -6,7 +6,8 @@ concise instructions on how to use the `nvalchemi` API for elementary tasks. - `nvalchemi-data-structures`: how to use individual atomic systems as well as batches. -- `nvalchemi-data-storage`: how to write and read atomic data. +- `nvalchemi-data-storage`: how to write, read, compose, and load atomic data. +- `nvalchemi-zarr-perf`: how to tune Zarr-backed Dataset/DataLoader throughput. - `nvalchemi-model-wrapping`: how to wrap MLIPs to use arbitrary models within `nvalchemi`. - `nvalchemi-dynamics-implementation`: how to implement a simple dynamics class. - `nvalchemi-dynamics-hooks`: how to implement and use `Hook`s in dynamics. diff --git a/.claude/skills/nvalchemi-data-storage/SKILL.md b/.claude/skills/nvalchemi-data-storage/SKILL.md index abbc3762..e1059281 100644 --- a/.claude/skills/nvalchemi-data-storage/SKILL.md +++ b/.claude/skills/nvalchemi-data-storage/SKILL.md @@ -1,6 +1,9 @@ --- name: nvalchemi-data-storage -description: How to write, read, and load atomic data using nvalchemi's composable Zarr-backed storage pipeline (Writer, Reader, Dataset, DataLoader). +description: >- + How to write, read, compose, and load atomic data using nvalchemi's + composable Zarr-backed storage pipeline (Writer, Reader, Dataset, + MultiDataset, DataLoader). 
--- # nvalchemi Data Storage @@ -9,15 +12,17 @@ description: How to write, read, and load atomic data using nvalchemi's composab `nvalchemi` provides a composable pipeline for persisting and loading atomic data: -``` +```text Writer Reader (AtomicData/Batch -> Zarr) (Zarr -> dict[str, Tensor]) | Dataset - (dict -> AtomicData, device transfer, prefetch) + (dict -> AtomicData, load_batches, prefetch) + | + optional MultiDataset composition | DataLoader - (AtomicData -> Batch, batching, iteration) + (Batch iteration) ``` ```python @@ -25,7 +30,9 @@ from nvalchemi.data.datapipes import ( AtomicDataZarrWriter, AtomicDataZarrReader, Dataset, + MultiDataset, DataLoader, + MultiDatasetBatchSampler, ) ``` @@ -33,7 +40,8 @@ from nvalchemi.data.datapipes import ( ## Writing Data -`AtomicDataZarrWriter` serializes `AtomicData`, `list[AtomicData]`, or `Batch` into a Zarr store. +`AtomicDataZarrWriter` serializes `AtomicData`, `list[AtomicData]`, or +`Batch` into a Zarr store. ```python from nvalchemi.data import AtomicData, Batch @@ -82,7 +90,7 @@ writer.defragment() # rebuild store without deleted samples ### Zarr store layout -``` +```text dataset.zarr/ ├── meta/ │ ├── atoms_ptr # int64 [N+1] — cumulative node counts @@ -144,6 +152,10 @@ atomic_data, metadata = ds[0] # AtomicData on target device # Lightweight metadata (no full construction) num_atoms, num_edges = ds.get_metadata(0) +# Explicit batch loading. This is the canonical synchronous batch API. +batches = ds.load_batches([[0, 3, 2], [4, 1, 5]]) +batch0 = batches[0] + len(ds) # number of samples ds.close() @@ -178,8 +190,8 @@ Iterates over a `Dataset` in batches, producing `Batch` objects. 
```python from nvalchemi.data.datapipes import AtomicDataZarrReader, Dataset, DataLoader -reader = AtomicDataZarrReader("dataset.zarr") -ds = Dataset(reader, device="cuda", num_workers=4) +reader = AtomicDataZarrReader("dataset.zarr", pin_memory=True) +ds = Dataset(reader, device="cuda", num_workers=1) loader = DataLoader( ds, @@ -187,11 +199,14 @@ loader = DataLoader( shuffle=True, drop_last=False, sampler=None, # optional torch Sampler - prefetch_factor=2, # batches to prefetch ahead - num_streams=4, # CUDA streams for prefetching + prefetch_factor=16, # fuse 16 batches per read_many call + num_streams=2, # CUDA streams for prefetching use_streams=True, # enable stream prefetching ) +# For throughput tuning (skip_validation, prefetch_factor, chunk/shard +# sizing), load the nvalchemi-zarr-perf agent skill. + for batch in loader: # batch is a Batch with concatenated tensors on target device print(batch.num_graphs, batch.num_nodes) @@ -200,6 +215,45 @@ len(loader) # number of batches loader.set_epoch(epoch) # for distributed sampler ``` +Use `prefetch_factor=0` to disable async fused prefetch while still reading each +emitted batch through `Dataset.load_batches([indices])`. For explicit/manual +batch reads, use `load_batches(...)`. 
+ +### Composing multiple datasets + +Use `MultiDataset` to concatenate multiple `Dataset` instances behind one global +index space while keeping the same `load_batches(...)` fast path: + +```python +from nvalchemi.data.datapipes import ( + AtomicDataZarrReader, + DataLoader, + Dataset, + MultiDataset, + MultiDatasetBatchSampler, +) + +ds_a = Dataset(AtomicDataZarrReader("dataset_a.zarr"), device="cuda") +ds_b = Dataset(AtomicDataZarrReader("dataset_b.zarr"), device="cuda") +dataset = MultiDataset(ds_a, ds_b, output_strict=True) + +batch_sampler = MultiDatasetBatchSampler.balanced( + dataset, + batch_size=64, + epoch_policy="max_size", # oversample smaller datasets when replacement=True + replacement=True, +) + +loader = DataLoader(dataset, batch_sampler=batch_sampler, prefetch_factor=16) +``` + +Sampler notes: + +- `samples_per_dataset` accepts integer counts or float ratios. +- `epoch_policy="min_size"` stops at the smallest contributing dataset. +- `epoch_policy="max_size"` covers the largest dataset and oversamples smaller + datasets when `replacement=True`. + --- ## Custom Readers @@ -218,6 +272,10 @@ class MyReader(Reader): """Load raw tensor dict for a single sample.""" ... + def _load_many_samples(self, indices) -> list[dict[str, torch.Tensor]]: + """Optional fast path for coalesced batch reads.""" + ... + def __len__(self) -> int: """Total number of samples.""" ... diff --git a/.claude/skills/nvalchemi-data-structures/SKILL.md b/.claude/skills/nvalchemi-data-structures/SKILL.md index f8c0e298..720ecf87 100644 --- a/.claude/skills/nvalchemi-data-structures/SKILL.md +++ b/.claude/skills/nvalchemi-data-structures/SKILL.md @@ -1,6 +1,8 @@ --- name: nvalchemi-data-structures -description: How to use AtomicData and Batch — the core graph-based data structures for representing atomic systems and batching them for GPU computation. 
+description: >- + How to use AtomicData and Batch, the core graph-based data structures for + representing atomic systems and batching them for GPU computation. --- # nvalchemi Data Structures @@ -10,7 +12,8 @@ description: How to use AtomicData and Batch — the core graph-based data struc `nvalchemi` represents atomic systems as graphs using two core classes: - **`AtomicData`** — a single atomic system (molecule, crystal, etc.) -- **`Batch`** — an efficient container of multiple `AtomicData` objects stored as concatenated tensors +- **`Batch`** — an efficient container of multiple `AtomicData` objects + stored as concatenated tensors Both are Pydantic `BaseModel` subclasses with `DataMixin` for device/dtype operations. @@ -274,7 +277,10 @@ batch.model_dump_json() # JSON string ### Distributed communication -`Batch` supports point-to-point distributed communication via `torch.distributed`. Data is sent in three phases: a metadata header (`num_graphs`, `num_nodes`, `num_edges`), per-group segment lengths, and bulk tensor data. +`Batch` supports point-to-point distributed communication via +`torch.distributed`. Data is sent in three phases: a metadata header +(`num_graphs`, `num_nodes`, `num_edges`), per-group segment lengths, +and bulk tensor data. **Blocking send/recv:** @@ -304,10 +310,14 @@ received = handle.wait() # block until data arrives, returns Batch **Key details:** -- `template` is required on the receiver to know the attribute keys, dtypes, and group structure (atoms/edges/system). Cache it across calls. -- A 0-graph (sentinel) batch can be sent/received — only the metadata header is transmitted. -- `tag` is a base tag; it is incremented internally per group. Use distinct base tags for concurrent send/recv pairs. -- `empty_like(batch)` creates a 0-graph batch with the same schema — useful for sentinel signals. +- `template` is required on the receiver to know the attribute keys, + dtypes, and group structure (atoms/edges/system). Cache it across calls. 
+- A 0-graph sentinel batch can be sent or received. Only the metadata + header is transmitted. +- `tag` is a base tag incremented internally per group. Use distinct + base tags for concurrent send/recv pairs. +- `empty_like(batch)` creates a 0-graph batch with the same schema, which + is useful for sentinel signals. ```python sentinel = Batch.empty_like(batch, device="cuda") # 0-graph, same schema diff --git a/.claude/skills/nvalchemi-dynamics-api/SKILL.md b/.claude/skills/nvalchemi-dynamics-api/SKILL.md index 07799e46..94db22b0 100644 --- a/.claude/skills/nvalchemi-dynamics-api/SKILL.md +++ b/.claude/skills/nvalchemi-dynamics-api/SKILL.md @@ -215,11 +215,17 @@ stage = DemoDynamics( ) ``` -The default `comm_mode` is `"async_recv"`. The three modes differ in when blocking occurs: - -- `"sync"`: `irecv` completes inline in `_prestep_sync_buffers` — simplest, good for debugging -- `"async_recv"`: `irecv` is posted in `_prestep_sync_buffers` but `wait()` deferred to `_complete_pending_recv` — allows compute/communication overlap -- `"fully_async"`: both send and receive are deferred — maximum overlap, highest throughput; pending sends from the previous step are drained at the start of the next `_prestep_sync_buffers` +The default `comm_mode` is `"async_recv"`. The three modes differ in when +blocking occurs: + +- `"sync"`: `irecv` completes inline in `_prestep_sync_buffers`; simplest + and good for debugging. +- `"async_recv"`: `irecv` is posted in `_prestep_sync_buffers`, but + `wait()` is deferred to `_complete_pending_recv` for compute/communication + overlap. +- `"fully_async"`: send and receive are both deferred for maximum + overlap. Pending sends from the prior step are drained at the start of + the next `_prestep_sync_buffers`. ### Pre-allocated buffers @@ -240,9 +246,12 @@ stage = DemoDynamics( ) ``` -Buffers are **lazily initialized** on the first step using the first concrete batch as a template for attribute keys, dtypes, and shapes. 
This means the first step has slightly more overhead. +Buffers are **lazily initialized** on the first step using the first +concrete batch as a template for attribute keys, dtypes, and shapes. +This means the first step has slightly more overhead. -Adjacent stages must use identical `BufferConfig` values — this is validated in `DistributedPipeline.setup()`. +Adjacent stages must use identical `BufferConfig` values. This is +validated in `DistributedPipeline.setup()`. --- @@ -262,20 +271,27 @@ The dynamics framework manages data flow through three layers: Each pipeline step follows a four-phase protocol: -1. `_prestep_sync_buffers()` — zeros send buffer, posts `irecv` from prior rank -2. `_complete_pending_recv()` — waits on deferred recv, routes into active batch, drains overflow sinks -3. `step()` — dynamics integration -4. `_poststep_sync_buffers(converged_indices)` — extracts converged into send buffer, sends to next rank +1. `_prestep_sync_buffers()` zeros the send buffer and posts `irecv` + from the prior rank. +2. `_complete_pending_recv()` waits on deferred receive, routes into + the active batch, and drains overflow sinks. +3. `step()` runs dynamics integration. +4. `_poststep_sync_buffers(converged_indices)` extracts converged + samples into the send buffer and sends them to the next rank. -**Deadlock prevention:** when no samples converge, an empty send buffer is still sent so the downstream `irecv` completes. +**Deadlock prevention:** when no samples converge, an empty send buffer +is still sent so the downstream `irecv` completes. 
### Back-pressure When `send_buffer` has limited capacity (via `BufferConfig`): - Only `min(converged_count, remaining_capacity)` samples are extracted -- Excess converged samples remain in the active batch as **no-ops** — their positions/velocities are saved before the integrator and restored after -- Without `BufferConfig`, all converged samples are sent without constraints (backward compat) +- Excess converged samples remain in the active batch as **no-ops**. + Their positions and velocities are saved before the integrator and + restored after it runs. +- Without `BufferConfig`, all converged samples are sent without + constraints (backward compatible). ### Buffer lifecycle: put/defrag/zero @@ -294,7 +310,9 @@ src_batch.defrag() buffer.zero() ``` -**Important:** `Batch.put()` uses Warp GPU kernels that only handle float32 attributes. Adjacent pipeline stages must have identical `BufferConfig` values. +**Important:** `Batch.put()` uses Warp GPU kernels that only handle +float32 attributes. Adjacent pipeline stages must have identical +`BufferConfig` values. ### Data routing methods @@ -348,7 +366,8 @@ When `refill_frequency` triggers (every N steps), `_refill_check()`: 5. Appends replacements via `Batch.append` 6. Rebuilds `status` (replacements get `0`) and `fmax` (replacements get `inf`) tensors -This produces a **new** `Batch` object (not in-place mutation). Returns `None` when the sampler is exhausted and no active samples remain. +This produces a **new** `Batch` object, not an in-place mutation. It +returns `None` when the sampler is exhausted and no active samples remain. 
### With FusedStage diff --git a/.claude/skills/nvalchemi-dynamics-implementation/SKILL.md b/.claude/skills/nvalchemi-dynamics-implementation/SKILL.md index 788a959e..1cc91875 100644 --- a/.claude/skills/nvalchemi-dynamics-implementation/SKILL.md +++ b/.claude/skills/nvalchemi-dynamics-implementation/SKILL.md @@ -22,7 +22,7 @@ from nvalchemi.data import Batch Each call to `step(batch)` executes: -``` +```text 1. BEFORE_STEP hooks 2. BEFORE_PRE_UPDATE hooks → pre_update(batch) → AFTER_PRE_UPDATE hooks 3. BEFORE_COMPUTE hooks → compute(batch) → AFTER_COMPUTE hooks diff --git a/.claude/skills/nvalchemi-loss-api/SKILL.md b/.claude/skills/nvalchemi-loss-api/SKILL.md new file mode 100644 index 00000000..7329c574 --- /dev/null +++ b/.claude/skills/nvalchemi-loss-api/SKILL.md @@ -0,0 +1,221 @@ +--- +name: nvalchemi-loss-api +description: How to use built-in loss functions and implement custom losses using the BaseLossFunction template-method pattern — residual types, per-atom normalization, masking, and graph-balanced reductions. +--- + +# nvalchemi Loss API + +## Overview + +Loss functions are `torch.nn.Module` subclasses rooted at `BaseLossFunction`. +Each leaf consumes `(pred, target, **kwargs)` and returns a scalar. +`ComposedLossFunction` routes keyed prediction/target mappings to leaves, +applies per-component weights (float or `LossWeightSchedule`), and returns +a `ComposedLossOutput` TypedDict. 
+ +```python +from nvalchemi.training import ( + BaseLossFunction, + ComposedLossFunction, + ReductionContext, + EnergyMSELoss, + EnergyMAELoss, + ForceMSELoss, + ForceL2NormLoss, + StressMSELoss, +) +``` + +--- + +## Built-in losses + +| Class | Target shape | Residual | Key defaults | Extra knobs | +|---|---|---|---|---| +| `EnergyMSELoss` | `(B, 1)` | squared | `energy` / `predicted_energy` | `per_atom`, `ignore_nonfinite` | +| `EnergyMAELoss` | `(B, 1)` or `(B,)` | absolute | `energy` / `predicted_energy` | `per_atom`, `ignore_nonfinite` | +| `ForceMSELoss` | `(V, 3)` or `(B, V_max, 3)` | squared component | `forces` / `predicted_forces` | `normalize_by_atom_count`, `ignore_nonfinite` | +| `ForceL2NormLoss` | `(V, 3)` or `(B, V_max, 3)` | vector L2 norm | `forces` / `predicted_forces` | `normalize_by_atom_count`, `ignore_nonfinite` | +| `StressMSELoss` | `(B, 3, 3)` | squared (Frobenius) | `stress` / `predicted_stress` | `ignore_nonfinite` | + +**Composition sugar:** + +```python +loss_fn = 1.0 * EnergyMSELoss() + 10.0 * ForceMSELoss() + 0.1 * StressMSELoss() +out = loss_fn(predictions, targets, step=step, epoch=epoch, batch=batch) +out["total_loss"].backward() +``` + +**Graph metadata:** losses that need graph structure (`per_atom=True`, +`normalize_by_atom_count=True`, or padded layouts) accept `batch=` +(pulls `batch_idx`, `num_graphs`, `num_nodes_per_graph` automatically) +or explicit kwargs. + +--- + +## Template-method pattern + +`BaseLossFunction.forward` orchestrates five hooks: + +```text +forward(pred, target, **kwargs) + 1. validate(pred, target) # shape checks + 2. pred, target, ctx = normalize(pred, target, **kwargs) # pre-processing + 3. valid = mask(pred, target, ctx, **kwargs) # boolean validity mask + 4. residual = compute_residual(pred, target, valid) # ABSTRACT — must override + 5. scalar = reduce(residual, valid, ctx, **kwargs) # collapse to scalar +``` + +**Minimum implementation:** override `compute_residual` only. 
Defaults +handle shape validation, all-True masking, and validity-weighted mean reduction. + +--- + +## Writing a custom loss + +### Minimal: compute_residual only + +```python +class HuberEnergyLoss(BaseLossFunction): + def __init__(self, *, target_key="energy", prediction_key="predicted_energy", delta=1.0): + super().__init__() + self.target_key = target_key + self.prediction_key = prediction_key + self.delta = delta + + def compute_residual(self, pred, target, valid): + residual = torch.where(valid, pred - target, torch.zeros_like(pred)) + abs_r = residual.abs() + return torch.where( + abs_r < self.delta, + 0.5 * residual.pow(2), + self.delta * (abs_r - 0.5 * self.delta), + ) +``` + +### Per-atom normalization (normalize override) + +Override `normalize` to divide by atom counts and pass weights via +`ReductionContext["weights"]`. The base `reduce` picks up weights +automatically. + +```python +class PerAtomEnergyMSE(BaseLossFunction): + target_key = "energy" + prediction_key = "predicted_energy" + + def normalize(self, pred, target, **kwargs): + ctx = ReductionContext() + counts = kwargs["num_nodes_per_graph"].to(dtype=pred.dtype).unsqueeze(-1).clamp_min(1.0) + ctx["weights"] = counts # base reduce uses this for atom-count-weighted mean + return pred / counts, target / counts, ctx + + def compute_residual(self, pred, target, valid): + residual = torch.where(valid, pred - target, torch.zeros_like(pred)) + return residual.pow(2) +``` + +### Custom masking (mask override) + +Override `mask` to exclude non-finite targets, padding, or other invalid entries. +Return a boolean tensor broadcast-compatible with `pred`/`target`. 
+ +```python +def mask(self, pred, target, ctx, **kwargs): + if self.ignore_nonfinite: + return torch.isfinite(target) + return torch.ones_like(target, dtype=torch.bool) +``` + +For padded force layouts `(B, V_max, 3)`, combine a node mask with nonfinite check: + +```python +def mask(self, pred, target, ctx, **kwargs): + num_nodes_per_graph = kwargs.get("num_nodes_per_graph") + node_mask = _padded_node_mask(num_nodes_per_graph, pred, pred.shape[1]) + valid = node_mask.unsqueeze(-1).expand_as(pred) + if self.ignore_nonfinite: + valid = valid & torch.isfinite(target) + return valid +``` + +The `valid` tensor flows into `compute_residual` as the third argument. +Zero invalid entries with `torch.where(valid, ..., torch.zeros_like(...))`. + +### Custom reduction (reduce override) + +Override `reduce` for graph-balanced or other non-mean reductions. +Populate `self.per_sample_loss` with a detached `(B,)` tensor for diagnostics. + +```python +from nvalchemi.training.losses.reductions import per_graph_sum + +def reduce(self, residual, valid, ctx, **kwargs): + batch_idx = kwargs["batch_idx"] + num_graphs = kwargs["num_graphs"] + valid_f = valid.to(dtype=residual.dtype) + # Per-atom SE summed over xyz, then per-graph mean, then mean over graphs + per_atom_se = residual.sum(dim=-1) + per_atom_valid = valid_f.sum(dim=-1) + per_graph_num = per_graph_sum(per_atom_se, batch_idx, num_graphs) + per_graph_den = per_graph_sum(per_atom_valid, batch_idx, num_graphs) + per_sample = per_graph_num / per_graph_den.clamp_min(1.0) + self.per_sample_loss = per_sample.detach() + return per_sample.mean() +``` + +### Layout dispatch with plum (dense vs padded forces) + +`ForceMSELoss` and `ForceL2NormLoss` use `plum-dispatch` to handle both +dense `(V, 3)` and padded `(B, V_max, 3)` layouts without `if/else` on +`ndim`. Their `mask` and `reduce` hooks delegate to `@overload`/`@dispatch` +helper methods — one overload per layout. 
See these implementations in +`nvalchemi/training/losses/terms.py` as the reference pattern for +multi-layout losses. + +```python +from plum import dispatch, overload + +@overload +def _my_helper(self, pred: Forces, target: Forces, ...): + """Dense (V, 3) path.""" + ... + +@overload +def _my_helper(self, pred: _PaddedForces, target: _PaddedForces, ...): + """Padded (B, V_max, 3) path.""" + ... + +@dispatch +def _my_helper(self, pred, target, ...): + pass # plum routes to matching overload at runtime +``` + +--- + +## Conventions + +1. **Define `target_key` and `prediction_key`** on any loss that participates + in `ComposedLossFunction` — these route tensors from the prediction/target + mappings. +2. **Accept `**kwargs`** in hooks that receive them — `ComposedLossFunction` + forwards metadata kwargs to every component. +3. **`compute_residual` must zero invalid entries** using the `valid` mask + argument — the base `reduce` handles weighting but not masking. +4. **`ReductionContext`** is a `dict` subclass (not TypedDict) for + `torch.compile` compatibility. Conventional key: `"weights"` for + atom-count weights consumed by the base `reduce`. 
+ +--- + +## Key files + +| File | Contents | +|---|---| +| `nvalchemi/training/losses/composition.py` | `BaseLossFunction`, `ComposedLossFunction`, `ReductionContext` | +| `nvalchemi/training/losses/terms.py` | All 5 built-in leaf losses | +| `nvalchemi/training/losses/reductions.py` | `per_graph_sum`, `per_graph_mean`, `frobenius_mse` | +| `nvalchemi/training/losses/schedules.py` | `ConstantWeight`, `LinearWeight`, `CosineWeight`, `PiecewiseWeight` | +| `nvalchemi/training/losses/base.py` | `LossWeightSchedule` protocol, re-exports | +| `test/training/test_losses.py` | Comprehensive tests for all loss terms | +| `docs/userguide/losses.md` | Full user guide with examples | diff --git a/.claude/skills/nvalchemi-model-wrapping/SKILL.md b/.claude/skills/nvalchemi-model-wrapping/SKILL.md index cc8a4501..60012e61 100644 --- a/.claude/skills/nvalchemi-model-wrapping/SKILL.md +++ b/.claude/skills/nvalchemi-model-wrapping/SKILL.md @@ -22,7 +22,7 @@ from nvalchemi.data import AtomicData, Batch A wrapped model uses **multiple inheritance**: your PyTorch model class + `BaseModelMixin`. -``` +```text ┌──────────────────────┐ ┌──────────────────┐ │ YourModel(nn.Module)│ │ BaseModelMixin │ │ - forward() │ │ - model_card │ diff --git a/.claude/skills/nvalchemi-zarr-perf/SKILL.md b/.claude/skills/nvalchemi-zarr-perf/SKILL.md new file mode 100644 index 00000000..37975020 --- /dev/null +++ b/.claude/skills/nvalchemi-zarr-perf/SKILL.md @@ -0,0 +1,280 @@ +--- +name: nvalchemi-zarr-perf +description: > + Performance tuning for nvalchemi's Zarr-backed Reader, Dataset, and + DataLoader pipeline. Use when configuring AtomicDataZarrReader, Dataset, + DataLoader, ZarrWriteConfig, or nvalchemi-io-test for training/inference + throughput, especially shuffled access, graph-like random access, fused + prefetch, pinned memory, validation overhead, or Zarr chunk/shard choices. 
+--- + +# Zarr DataLoader Performance Tuning + +Use this skill when optimizing nvalchemi Zarr reads or writing stores that will +later be read through the nvalchemi DataLoader. + +## Current API model + +The pipeline has clean ownership boundaries: + +- `Reader`: storage I/O only. Returns raw CPU tensor dictionaries plus metadata. +- `Dataset`: validation, optional validation skipping, device transfer, and async + prefetch orchestration. Its canonical explicit batch API is + `load_batches(batch_index_lists)`. +- `DataLoader`: sampler/batch iteration, fused prefetch, stream usage, and batch + construction. +- `MultiDataset`: global index composition over multiple Datasets while routing + `load_batches` requests to child datasets. +- `Sampler` / `batch_sampler`: semantic sample order and batch membership. Do not + rely on sampler windows to optimize storage I/O. + +Reader public methods: + +- `reader.read(index)`: one sample. +- `reader.read_many(indices)`: many samples, returned in the request order. + +Reader backend hooks: + +- `_load_sample(index)`: implement for simple single-sample formats. +- `_load_many_samples(indices)`: implement for batch-optimized formats. +- `__len__()`: total logical samples. + +The base `Reader` owns metadata finalization and optional pinned memory. Index +validity is the concrete reader's responsibility. `AtomicDataZarrReader` supports +negative logical indices, maps through the active sample mask, and implements +`_load_many_samples` as the fast path. 
+ +## Recommended DataLoader setup + +```python +from nvalchemi.data.datapipes import ( + AtomicDataZarrReader, + Dataset, + DataLoader, +) + +reader = AtomicDataZarrReader("store.zarr") + +dataset = Dataset( + reader, + device="cuda", + num_workers=1, # 1 is enough; concurrent Zarr reads contend + skip_validation=True, # safe when store was written by the toolkit +) + +loader = DataLoader( + dataset, + batch_size=64, + shuffle=True, + prefetch_factor=16, # up to 64 * 16 = 1024 indices per backend read + num_streams=2, + use_streams=True, + pin_memory=True, # request pinned CPU tensors from the reader +) +``` + +Use `pin_memory=True` on `AtomicDataZarrReader(...)` directly only for manual +reader usage. For normal training, prefer `DataLoader(..., pin_memory=True)` so +the loader owns the transfer optimization. + +## Key knobs + +### `prefetch_factor` (DataLoader) + +Controls how many emitted batches are fused into one backend read: + +```text +effective_read_window = batch_size * prefetch_factor +``` + +For `batch_size=64, prefetch_factor=16`, the model still receives batches of 64 +graphs, but the Zarr reader sees up to 1024 logical indices per `read_many`. + +| Access pattern | Recommended `prefetch_factor` | +|----------------|------------------------------:| +| Sequential | 2-4 | +| Shuffled | 16-64 | +| Block-shuffle | 2-8 | + +Use `prefetch_factor=0` to disable fused prefetch and issue one backend read per +emitted batch through `Dataset.load_batches([indices])`. This is useful for +debugging or for stores where larger windows do not help. Positive +`prefetch_factor` values use the async +`prefetch_fused_batches(...)` / `get_fused_batches()` path. + +Manual batch reads should use: + +```python +batches = dataset.load_batches([[0, 4, 2], [8, 1, 3]]) +``` + +### `skip_validation` (Dataset) + +Bypasses per-sample `AtomicData` Pydantic validation (~4 ms/sample). +Constructs `Batch` directly from raw tensor dicts via +`Batch.from_raw_dicts()`. 
+ +**Use when:** the store was written by `AtomicDataZarrWriter` or has been +validated externally. +**Do not use when:** the store contents are untrusted or from a third party. + +### `num_workers` (Dataset) + +Thread pool size for background Dataset prefetch work. Start with **1**. +Increase only if profiling shows CPU-side validation or device transfer is not +overlapping with storage reads and those reads are not contending. + +### `pin_memory` (DataLoader or Reader) + +Pinned CPU tensors make async CPU-to-GPU transfer possible. Use with CUDA targets +and `use_streams=True`. + +Normal path: + +```python +loader = DataLoader(dataset, batch_size=64, pin_memory=True) +``` + +Manual reader path: + +```python +reader = AtomicDataZarrReader("store.zarr", pin_memory=True) +data, metadata = reader.read(0) +``` + +## Writing stores for fast random reads + +For shuffled training reads, avoid extremely large chunks unless reads are mostly +sequential. A practical starting point: + +```python +from zarr.codecs import ZstdCodec + +from nvalchemi.data.datapipes import ( + AtomicDataZarrWriter, + ZarrWriteConfig, + ZarrArrayConfig, +) + +config = ZarrWriteConfig( + core=ZarrArrayConfig( + compressors=(ZstdCodec(level=3),), + chunk_size=10_000, + shard_size=500_000, + ), +) +writer = AtomicDataZarrWriter("store.zarr", config=config) +``` + +Guidance: + +- `chunk_size` is rows along dimension 0, not number of structures. Atom fields + are stored on the total atom axis; edge fields on the total edge axis. +- Smaller chunks reduce single-sample read amplification but increase metadata + and codec overhead. +- Sharding groups many chunks into fewer storage objects and is useful when small + chunks would create too many files. +- Use `edge_chunk_size` / `edge_shard_size` in `nvalchemi-io-test` when edge + arrays need different tuning from atom/system arrays. +- Zstd level 3 is a good default ratio/speed tradeoff. LZ4 is useful when write + and decompression speed matter more than compression ratio. 
+ +## How the reader optimizes random access + +`AtomicDataZarrReader._load_many_samples(indices)` is the optimized path behind +public `reader.read_many(indices)`. + +It currently: + +1. Resolves logical indices through the active sample mask. +2. Sorts requests by physical sample index. +3. Groups physical positions by Zarr chunk locality. +4. Uses coalesced range reads when a small number of chunk-local runs exist. +5. Falls back to orthogonal selection for highly fragmented requests. +6. Restores the caller's original request order. + +This is transparent to Dataset, DataLoader, and Samplers. Larger fused read +windows give the Zarr backend more indices to coalesce, which is why +`prefetch_factor` matters most for shuffled reads. + +For multidataset training, use `MultiDatasetBatchSampler` or +`MultiDatasetBatchSampler.balanced(...)` to define semantic dataset mixing +rates. +`samples_per_dataset` may be integer counts or float ratios. Use +`epoch_policy="max_size", replacement=True` when smaller datasets should be +oversampled so the largest dataset does not dominate an epoch. + +## Benchmark workflow + +Use the current CLI subcommands: + +```bash +# Self-contained write + read benchmark. +env COLUMNS=240 uv run nvalchemi-io-test roundtrip \ + -n 10000 \ + --read-mode batch \ + --read-order shuffle \ + --batch-size 64 \ + --prefetch-factor 16 \ + --pin-memory + +# Sweep prefetch factors on the same access pattern. +for pf in 8 16 32 64 128; do + env COLUMNS=240 uv run nvalchemi-io-test roundtrip \ + -n 10000 \ + --read-mode batch \ + --read-order shuffle \ + --batch-size 64 \ + --prefetch-factor "$pf" \ + --pin-memory +done + +# Benchmark an existing store without rewriting it. +env COLUMNS=240 uv run nvalchemi-io-test read /path/to/store.zarr \ + --read-order shuffle \ + --batch-size 64 \ + --prefetch-factor 32 \ + --pin-memory + +# Compare DataLoader fused reads against one-sample-at-a-time reads. 
+env COLUMNS=240 uv run nvalchemi-io-test read /path/to/store.zarr \ + --read-mode both \ + --read-order shuffle \ + --batch-size 64 \ + --prefetch-factor 32 +``` + +Important benchmark semantics: + +- `read-mode=batch` uses the public DataLoader path with fused prefetch. +- Benchmark batch mode uses `Dataset(skip_validation=True)` to focus on storage + and batching throughput. +- `read-mode=single` calls `reader.read(index)` once per sample and is only a + baseline for one-sample-at-a-time access. +- `batch_size` is the model-facing batch size. +- `prefetch_factor` controls the backend read window. +- Use `read-order=shuffle` to model fully shuffled training reads. +- Use `read-order=block-shuffle` to test partial locality. + +## Diagnosing bottlenecks + +1. Run `nvalchemi-io-test read` on an existing representative store. +2. Sweep `prefetch_factor` at the target `batch_size`. +3. Compare `read-mode=batch` against `read-mode=single`. +4. If batch mode is fast but training is slow, inspect validation, batching, and + device-transfer overhead. Try `skip_validation=True`, `pin_memory=True`, and + CUDA streams. +5. If batch mode is slow, inspect chunk/shard configuration, compression codec, + filesystem metadata pressure, and read order. + +## Quick checklist + +- [ ] Use `Dataset(skip_validation=True)` for trusted toolkit-written stores. +- [ ] Use `DataLoader(pin_memory=True)` for CUDA training. +- [ ] Start with `batch_size=64`. +- [ ] Start with `prefetch_factor=16` or `32` for shuffled reads. +- [ ] Sweep `prefetch_factor=8,16,32,64,128` with `nvalchemi-io-test`. +- [ ] Keep sampler semantics independent from storage locality. +- [ ] Use `load_batches(...)` for explicit batch reads. +- [ ] Tune chunk/shard sizes on a representative store and filesystem. +- [ ] Use `read-mode=single` only as a baseline, not as the training path. 
diff --git a/CHANGELOG.md b/CHANGELOG.md index ec3b5840..131d4a23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,43 @@ ## Unreleased +### Added + +- `EMAHook._build_averaged_model` override seam, so a caller that owns + model sharding can supply a pre-built `AveragedModel` instead of the + default deepcopy — enabling EMA on `fully_shard` (FSDP2) / DTensor + models. Default behaviour unchanged. +- Checkpointable training hooks. Hooks such as EMA can now save restart + state with strategy checkpoints, so resumed training keeps averaged + weights instead of starting them over. +- Training strategy checkpoint restart support, including a periodic + checkpoint hook for step- or epoch-based saves and restart loading with + models, optimizers, schedulers, runtime counters, and restart-safe device + placement. +- PhysicsNeMo-compatible atomic datapipes with `MultiDataset` composition, + multidataset-aware sampling policies, and fused batch loading that preserves + the Zarr reader's coalesced I/O path. +- First-class validation on `TrainingStrategy`. Set a `ValidationConfig` + on `strategy.validation_config` and validation runs automatically at the + configured step or epoch cadence, plus one final pass at end-of-training; + the latest summary is stored on `strategy.last_validation`. Mechanics live + in a public, context-managed `ValidationLoop` that can also be run + standalone outside training. An `inference_model` slot lets EMA (or SWA / + a distillation teacher) publish averaged weights for validation to read. + A new `AFTER_VALIDATION` hook stage fires immediately after each pass so + loggers can read the live summary. For per-batch logging, pass a + `batch_callback` (any object matching the `BatchValidationCallback` + protocol) on the config; it is invoked once per validation batch with the + batch, predictions, and per-batch loss. +- Metric-driven learning-rate schedulers. 
`ReduceLROnPlateau` is now + supported via `OptimizerConfig.scheduler_metric_adapter` (a summary-dict + key string or a callable). Time-based schedulers step every optimizer + step as before; metric-driven schedulers step only at validation + checkpoints, where the validation summary supplies the metric. + ### Core Data Layer -- **User-specified transforms** — `Dataset` accepts a `transforms=` kwarg +- **User-specified transforms** - `Dataset` accepts a `transforms=` kwarg (per-sample `(AtomicData, metadata) -> (AtomicData, metadata)`) and `DataLoader` accepts a `batch_transforms=` kwarg (per-batch `Batch -> Batch`). Both default to `None` (backward compatible). New `nvalchemi.data.transforms` @@ -17,6 +51,14 @@ ### Fixed +- **Zarr dataloader custom fields** — validated `Dataset` batch paths now + preserve reader field-level metadata so custom atom-, edge-, and + system-level tensors survive batching like the `skip_validation` path. +- EMA checkpointing now restores averaged tensors to the corresponding live + model tensor devices, publishes restored EMA weights during SETUP before validation, + and supports callable reconstruction specs for model wrappers that must + rebuild from factory methods, including MACE checkpoints with + cuEquivariance enabled. - **MTK NPT barostat runaway** (#89, #90) — four bugs in `nvalchemi/dynamics/integrators/npt.py` (with matching fixes in `nph.py`) that combined to drive unbounded cell-volume drift in long @@ -39,6 +81,11 @@ ### Breaking Changes +- Dataset-level explicit batch reads now use `load_batches(...)`. The raw + `read_many(...)` API remains on readers, where storage backends can optimize + ordered I/O, but `Dataset.read_many(...)` and `Dataset.get_batch(...)` have + been removed to keep the public Dataset API focused on sample access, + batch materialization, and prefetching. - Split hook context state into `HookContext`, `DynamicsContext`, and `TrainContext` so each workflow exposes only the fields it owns. 
Dynamics-specific state such as `step_count`, `converged_mask`, and @@ -48,6 +95,27 @@ - Standardized public `stress` outputs on tensile-positive Cauchy stress (`sigma = -W / V`) while keeping low-level virials defined as negative strain derivatives. +- Removed `EvaluateHook` in favor of first-class validation on + `TrainingStrategy`. Validation is no longer a registered hook. Migrate by + moving the hook's arguments onto a `ValidationConfig`: + + ```python + # Before + strategy.register_hook( + EvaluateHook(validation_data=val_data, every_n_epochs=1) + ) + + # After + strategy.validation_config = ValidationConfig( + validation_data=val_data, every_n_epochs=1 + ) + ``` + + Validation then runs automatically during `strategy.run(...)` at the + configured cadence and once at end-of-training. The `EvaluationSink` / + `EvaluationZarrSink` output classes were removed; replace summary logging + with an `AFTER_VALIDATION` hook and per-batch logging with a + `ValidationConfig(batch_callback=...)`. ## 0.1.0 — 2026-04-16 diff --git a/SECURITY.md b/SECURITY.md index 9d1a7116..fe471bb0 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,24 +1,44 @@ -## Security +# Security -NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization. +NVIDIA is dedicated to the security and trust of our software products and +services, including all source code repositories managed through our +organization. -If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub.** If a potential security issue is inadvertently reported via a public issue or pull request, NVIDIA maintainers may limit public discussion and redirect the reporter to the appropriate private disclosure channels. +If you need to report a security issue, use the appropriate contact points +outlined below. 
**Do not report security vulnerabilities through GitHub.** If +a potential security issue is inadvertently reported through a public issue or +pull request, NVIDIA maintainers may limit public discussion and redirect the +reporter to the appropriate private disclosure channels. ## Reporting Potential Security Vulnerability in an NVIDIA Product To report a potential security vulnerability in any NVIDIA product: -- Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html) -- E-Mail: psirt@nvidia.com - - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key) - - Please include the following information: - - Product/Driver name and version/branch that contains the vulnerability - - Type of vulnerability (code execution, denial of service, buffer overflow, etc.) - - Instructions to reproduce the vulnerability - - Proof-of-concept or exploit code - - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability - -While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information. + +- Web: [Security Vulnerability Submission Form][vulnerability-form] +- Email: [psirt@nvidia.com][psirt-email] + - We encourage you to use the [NVIDIA public PGP key][pgp-key] for + secure email communication. 
+ - Please include the following information: + - Product or driver name and version or branch containing the + vulnerability + - Type of vulnerability, such as code execution, denial of service, + or buffer overflow + - Instructions to reproduce the vulnerability + - Proof-of-concept or exploit code + - Potential impact, including how an attacker could exploit the + vulnerability + +NVIDIA currently does not have a bug bounty program, but we do offer +acknowledgement when an externally reported security issue is addressed under +our coordinated vulnerability disclosure policy. Visit the [PSIRT policies +page][psirt-policies] for more information. ## NVIDIA Product Security -For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security +For all security-related concerns, visit NVIDIA's [Product Security portal][product-security]. + +[vulnerability-form]:https://www.nvidia.com/object/submit-security-vulnerability.html +[psirt-email]:mailto:psirt@nvidia.com +[pgp-key]:https://www.nvidia.com/en-us/security/pgp-key +[psirt-policies]:https://www.nvidia.com/en-us/security/psirt-policies/ +[product-security]:https://www.nvidia.com/en-us/security diff --git a/docs/modules/data.rst b/docs/modules/data.rst index fac84dd3..16e787c0 100644 --- a/docs/modules/data.rst +++ b/docs/modules/data.rst @@ -31,6 +31,20 @@ I/O and pipelines DataLoader Reader +Dataset composition and sampling +-------------------------------- + +.. currentmodule:: nvalchemi.data.datapipes + +.. 
autosummary:: + :toctree: generated + :template: class.rst + :nosignatures: + + MultiDataset + MultiDatasetSampler + MultiDatasetBatchSampler + Write configuration ------------------- diff --git a/docs/modules/dynamics/api.rst b/docs/modules/dynamics/api.rst index b63f7356..cc85b70e 100644 --- a/docs/modules/dynamics/api.rst +++ b/docs/modules/dynamics/api.rst @@ -54,8 +54,9 @@ Hooks LoggingHook MaxForceClampHook NaNDetectorHook - ProfilerHook + TorchProfilerHook SnapshotHook + StageTimingHook General-purpose hooks (:class:`~nvalchemi.hooks.NeighborListHook`, :class:`~nvalchemi.hooks.BiasedPotentialHook`, diff --git a/docs/modules/dynamics/hooks.rst b/docs/modules/dynamics/hooks.rst index ffcf0046..ac092f02 100644 --- a/docs/modules/dynamics/hooks.rst +++ b/docs/modules/dynamics/hooks.rst @@ -145,11 +145,13 @@ monitor simulation state. * - :class:`~nvalchemi.dynamics.hooks.EnergyDriftMonitorHook` - Track cumulative energy drift in NVE runs; warn or halt on excessive drift. - * - :class:`~nvalchemi.dynamics.hooks.ProfilerHook` - - Instrument steps with NVTX ranges and wall-clock timing for - Nsight Systems profiling. Fires at multiple stages via - ``_runs_on_stage`` and uses ``plum`` dispatch to support - dynamics and custom workflows. + * - :class:`~nvalchemi.dynamics.hooks.StageTimingHook` + - Measure elapsed time between dynamics stages, with optional NVTX ranges, + CSV output, and console summaries. + * - :class:`~nvalchemi.dynamics.hooks.TorchProfilerHook` + - Capture PyTorch profiler Chrome traces through PhysicsNeMo. Starts at + ``BEFORE_STEP``, advances at ``AFTER_STEP``, and writes rank-specific + trace directories. 
Post-compute hooks (modify batch, fire at ``AFTER_COMPUTE``) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -252,17 +254,39 @@ Enhanced sampling with a bias potential hook = BiasedPotentialHook(bias_fn=harmonic_restraint, stage=DynamicsStage.AFTER_COMPUTE) dynamics = DemoDynamics(model=model, dt=0.5, hooks=[hook]) -Profiling with Nsight Systems -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Timing dynamics stages +~~~~~~~~~~~~~~~~~~~~~~ + +Use :class:`~nvalchemi.dynamics.hooks.StageTimingHook` for lightweight stage +timing and optional NVTX ranges. .. code-block:: python - from nvalchemi.dynamics.hooks import ProfilerHook + from nvalchemi.dynamics.hooks import StageTimingHook - hook = ProfilerHook(enable_nvtx=True, enable_timer=True, frequency=10) + hook = StageTimingHook("step", frequency=10, log_path="stage_timing.csv") dynamics = DemoDynamics(model=model, n_steps=1_000, dt=0.5, hooks=[hook]) + dynamics.run(batch) + +Capturing PyTorch Chrome traces +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Run under: nsys profile python my_script.py +Use :class:`~nvalchemi.dynamics.hooks.TorchProfilerHook` to capture PyTorch +operator traces through PhysicsNeMo. The hook starts at ``BEFORE_STEP`` and +advances the profiler schedule at ``AFTER_STEP``. + +.. code-block:: python + + from torch.profiler import ProfilerActivity, schedule + + from nvalchemi.dynamics.hooks import TorchProfilerHook + + hook = TorchProfilerHook( + output_dir="profiles/dynamics-run", + activities=(ProfilerActivity.CPU, ProfilerActivity.CUDA), + schedule=schedule(wait=2, warmup=2, active=5, repeat=1), + ) + dynamics = DemoDynamics(model=model, n_steps=100, dt=0.5, hooks=[hook]) dynamics.run(batch) NVE energy drift monitoring diff --git a/docs/modules/hooks.rst b/docs/modules/hooks.rst index 0f17b0f7..17800b6c 100644 --- a/docs/modules/hooks.rst +++ b/docs/modules/hooks.rst @@ -18,6 +18,8 @@ hooks that are useful regardless of the specific engine type. patterns. 
- **Dynamics hooks**: :ref:`dynamics-hooks` — hooks and stages specific to dynamics simulations. + - **Training update hooks**: :ref:`training-update-hooks` — update-stage + ownership, veto semantics, and constraints for training hooks. The Hook protocol @@ -55,6 +57,11 @@ use it as a type hint and check membership with ``isinstance``: class---or even a frozen ``dataclass``---that provides ``frequency``, ``stage``, and ``__call__`` works as a hook. +:class:`~nvalchemi.hooks.CheckpointableHook` is a second, optional protocol for +hooks that own restart-critical runtime state. It requires ``state_dict()`` and +``load_state_dict()`` and is used by training checkpoints to persist only hooks +that explicitly opt in. + Context dataclasses ------------------- @@ -108,6 +115,12 @@ only for one workflow category. * - ``step_count`` - ``int`` - Current optimizer step. + * - ``batch_count`` + - ``int`` + - Number of training batches consumed, including skipped optimizer steps. + * - ``epoch_step_count`` + - ``int`` + - Number of batches consumed within the current training epoch. * - ``epoch`` - ``int`` - Current training epoch. @@ -203,6 +216,59 @@ that uses the hook system, not just dynamics. - Wrap atomic positions back into the unit cell under PBC. Fires at ``AFTER_POST_UPDATE``, respects per-system ``batch.pbc`` flags. + * - :class:`~nvalchemi.hooks.StageTimingHook` + - Measure elapsed time between hook stages, with optional NVTX ranges, CSV + output, and console summaries. + * - :class:`~nvalchemi.hooks.TorchProfilerHook` + - Capture PyTorch profiler Chrome traces for training and dynamics through + PhysicsNeMo's profiler wrapper, with rank-specific output directories. + + +Stage timing +------------ + +:class:`~nvalchemi.hooks.StageTimingHook` records elapsed time between selected +hook stages. It is useful for lightweight per-stage timing and NVTX annotation; +use :class:`~nvalchemi.hooks.TorchProfilerHook` when you need PyTorch operator +Chrome traces. + +.. 
code-block:: python + + from nvalchemi.dynamics.base import DynamicsStage + from nvalchemi.hooks import StageTimingHook + + timing_hook = StageTimingHook( + {DynamicsStage.BEFORE_STEP, DynamicsStage.AFTER_STEP}, + log_path="stage_timing.csv", + ) + + +PyTorch profiler traces +----------------------- + +:class:`~nvalchemi.hooks.TorchProfilerHook` captures PyTorch profiler traces +for both :class:`~nvalchemi.training.strategy.TrainingStrategy` and dynamics +workflows. It starts lazily on the first training or dynamics stage, advances +``torch.profiler`` on each training batch or dynamics step, and finalizes when +the workflow context exits. + +.. code-block:: python + + from torch.profiler import ProfilerActivity, schedule + + from nvalchemi.hooks import TorchProfilerHook + + profile_hook = TorchProfilerHook( + output_dir="profiles/run-001", + activities=(ProfilerActivity.CPU, ProfilerActivity.CUDA), + schedule=schedule(wait=2, warmup=2, active=5, repeat=1), + record_shapes=True, + profile_memory=True, + with_flops=True, + ) + +Outputs are written under ``rank_<rank>/torch/`` unless PhysicsNeMo's +own distributed manager is active and already owns rank suffixing. API Reference @@ -218,6 +284,7 @@ Protocol :nosignatures: Hook + CheckpointableHook HookContext DynamicsContext TrainContext @@ -232,4 +299,22 @@ General-purpose hooks BiasedPotentialHook NeighborListHook + StageTimingHook + TorchProfilerHook WrapPeriodicHook + +Reporting +~~~~~~~~~ + +.. 
autosummary:: + :toctree: generated + :nosignatures: + + ReportingOrchestrator + ReportingState + TensorBoardReporter + RichReporter + RichLayout + BaseRichLayout + TrainingRichLayout + DynamicsRichLayout diff --git a/docs/modules/index.md b/docs/modules/index.md index 4a11a997..c5bcf507 100644 --- a/docs/modules/index.md +++ b/docs/modules/index.md @@ -12,5 +12,6 @@ data hooks dynamics/index models +training/index typing ``` diff --git a/docs/modules/training/checkpoints.rst b/docs/modules/training/checkpoints.rst new file mode 100644 index 00000000..9fafea65 --- /dev/null +++ b/docs/modules/training/checkpoints.rst @@ -0,0 +1,341 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +.. _training-checkpoints: + +Training checkpoints +==================== + +Training checkpoints capture enough state to stop and restart a +:class:`~nvalchemi.training.TrainingStrategy`: model weights, optimizer state, +learning-rate scheduler state, strategy runtime counters, checkpointable hook +state, and the serializable strategy recipe. They are intended for training +restarts, not just inference weight export. + +Manual save and restart +----------------------- + +Use :meth:`~nvalchemi.training.TrainingStrategy.save_checkpoint` when a script +wants to take a one-off checkpoint at a known point: + +.. code-block:: python + + from nvalchemi.training import TrainingStrategy + + strategy = TrainingStrategy(...) + strategy.run(train_loader) + + checkpoint_index = strategy.save_checkpoint("runs/example/checkpoints") + +Reload with :meth:`~nvalchemi.training.TrainingStrategy.load_checkpoint` when +the checkpoint was written from a strategy: + +.. 
code-block:: python + + from nvalchemi.training import TrainingStrategy + + strategy = TrainingStrategy.load_checkpoint( + "runs/example/checkpoints", + map_location="cpu", + training_fn=training_fn, + ) + + strategy.num_steps = 20_000 + strategy.run(train_loader) + strategy.save_checkpoint("runs/example/checkpoints") + +``checkpoint_index=-1`` loads the latest checkpoint recorded in +``manifest.json``. Pass an explicit index to restart from an older point: + +.. code-block:: python + + strategy = TrainingStrategy.load_checkpoint( + "runs/example/checkpoints", + checkpoint_index=3, + ) + +Training functions +------------------ + +Checkpoint metadata stores the training function only when it can be expressed +as an importable dotted path. If the original strategy used a local function, a +closure, or another non-importable callable, pass ``training_fn=...`` when +loading. Importable functions do not need to be passed again. + +Hooks are runtime objects and are intentionally supplied at load time: + +.. code-block:: python + + from nvalchemi.training import CheckpointHook, TrainingStrategy + + strategy = TrainingStrategy.load_checkpoint( + "runs/example/checkpoints", + hooks=[ + CheckpointHook("runs/example/checkpoints", step_interval=1000), + ], + ) + +Restartable hook state +---------------------- + +Hooks are still runtime objects and must be supplied when loading a strategy. +However, hooks that implement :class:`~nvalchemi.hooks.CheckpointableHook` have +their runtime state stored in strategy checkpoints and restored into the +matching hook supplied at load time. This is intended for hooks whose state +changes training semantics, such as :class:`~nvalchemi.training.hooks.EMAHook` +and its averaged weights. + +.. 
code-block:: python + + from nvalchemi.training import CheckpointHook, EMAHook, TrainingStrategy + + checkpoint_dir = "runs/example/checkpoints" + + ema = EMAHook(model_key="main", decay=0.999) + strategy = TrainingStrategy( + ..., + hooks=[ + ema, + CheckpointHook(checkpoint_dir, step_interval=1000), + ], + ) + strategy.run(train_loader) + + restored_ema = EMAHook(model_key="main", decay=0.999) + restored = TrainingStrategy.load_checkpoint( + checkpoint_dir, + hooks=[ + restored_ema, + CheckpointHook(checkpoint_dir, step_interval=1000), + ], + ) + +When a script already constructs the strategy and its runtime hooks, use +:meth:`~nvalchemi.training.TrainingStrategy.restore_checkpoint` to hydrate those +live objects in place instead of reconstructing the strategy from metadata: + +.. code-block:: python + + restored = TrainingStrategy( + ..., + hooks=[ + restored_ema, + CheckpointHook(checkpoint_dir, step_interval=1000), + ], + ) + restored.restore_checkpoint(checkpoint_dir) + +Checkpointable hooks are matched by class occurrence in the runtime hook list, +so load-time hooks should be registered in the same relative order as the hooks +that wrote the checkpoint. Non-checkpointable hook state remains the user's +responsibility. Prefer deriving transient state from restored strategy counters +or rebuilding caches at setup time when possible. + +Periodic checkpoint hook +------------------------ + +Use :class:`~nvalchemi.training.hooks.CheckpointHook` for long-running jobs that +should save without custom logic in the training loop: + +.. code-block:: python + + from nvalchemi.training import CheckpointHook, TrainingStrategy + + strategy = TrainingStrategy( + ..., + hooks=[ + CheckpointHook("runs/example/checkpoints", step_interval=1000), + ], + ) + strategy.run(train_loader) + +A checkpoint hook owns one cadence policy. Use ``step_interval`` to save every +N completed optimizer steps, or ``epoch_interval`` to save every N completed +epochs. 
Register separate hooks only when a job intentionally needs separate +checkpoint roots or policies. + +By default, ``CheckpointHook`` captures a CPU snapshot on the training thread +and writes that snapshot on a background thread. This avoids racing live model +and optimizer tensors while moving filesystem writes off the main training +path. Pending async writes are flushed when the strategy exits its hook +context. + +Model reconstruction specs +-------------------------- + +Strategy checkpoints store model weights separately from a small JSON model +spec. A model spec records an importable callable plus JSON-serializable +keyword arguments. For ordinary modules, this callable is usually the class +constructor:: + + create_model_spec(torch.nn.Linear, in_features=16, out_features=1) + +For models that are created by a factory, adapter, monkey patch, or optimized +conversion pass, the spec can point at that factory instead:: + + create_model_spec( + MACEWrapper.from_checkpoint, + checkpoint_path="small-0b", + dtype=torch.float32, + enable_cueq=True, + ) + +During load, the checkpoint layer rebuilds the model from the spec and then +loads the saved training weights. If the factory accepts ``device``, the loader +passes ``map_location`` into the factory so device-sensitive conversions, such +as MACE cuEquivariance conversion, happen directly on the target device. + +Models may provide their own reconstruction spec by implementing +``checkpoint_spec()`` and returning a :class:`~nvalchemi.training._spec.BaseSpec` +or ``None``. Returning ``None`` keeps the default constructor-introspection +fallback. This is useful for wrappers whose live module cannot be reconstructed +from its transformed ``__init__`` arguments. 
+ +MACE checkpoints and cuEquivariance +----------------------------------- + +When training starts from an existing MACE checkpoint, construct the wrapper +with :meth:`~nvalchemi.models.mace.MACEWrapper.from_checkpoint` and then let +:class:`~nvalchemi.training.TrainingStrategy` save and reload the full restart checkpoint:: + + import torch + + from nvalchemi.models.mace import MACEWrapper + from nvalchemi.training import EMAHook, TrainingStrategy + + model = MACEWrapper.from_checkpoint( + "small-0b", + device=torch.device("cuda"), + dtype=torch.float32, + enable_cueq=True, + ) + + ema = EMAHook(model_key="main", decay=0.999) + strategy = TrainingStrategy( + models=model, + ..., + hooks=[ema], + ) + strategy.run(train_loader) + strategy.save_checkpoint(checkpoint_dir) + +On restart, reload the strategy checkpoint rather than saving or loading the +EMA hook in isolation:: + + restored_ema = EMAHook(model_key="main", decay=0.999) + restored = TrainingStrategy.load_checkpoint( + checkpoint_dir, + map_location=torch.device("cuda"), + hooks=[restored_ema], + training_fn=training_fn, + ) + +The saved model spec calls ``MACEWrapper.from_checkpoint`` again with the +recorded MACE checkpoint and options, then the strategy loader restores model +weights, optimizer state, counters, and checkpointable hook state such as EMA +averages. + +Distributed training +-------------------- + +Distributed checkpointing follows the same file layout as single-process +checkpointing, but only one process should write the shared checkpoint. The +default ``CheckpointHook(rank_zero_only=True)`` uses the +:class:`~nvalchemi.hooks.TrainContext` global rank and saves only on rank 0. +Other ranks continue training and do not write duplicate manifests or state +files. + +The usual end-to-end pattern is: + +.. 
code-block:: python + + from nvalchemi.training import CheckpointHook, TrainingStrategy + + checkpoint_dir = "runs/example/checkpoints" + + strategy = TrainingStrategy( + ..., + hooks=[ + CheckpointHook(checkpoint_dir, step_interval=1000), + ], + ) + strategy.run(train_loader) + +On restart, launch the distributed job again and have each process load the +same checkpoint path: + +.. code-block:: python + + from nvalchemi.training import CheckpointHook, TrainingStrategy + + checkpoint_dir = "runs/example/checkpoints" + + strategy = TrainingStrategy.load_checkpoint( + checkpoint_dir, + map_location=local_device, + training_fn=training_fn, + hooks=[ + CheckpointHook(checkpoint_dir, step_interval=1000), + ], + ) + strategy.num_steps = 20_000 + strategy.run(train_loader) + +``load_checkpoint`` is not rank-zero-only: every process reconstructs its local +strategy, model, optimizer, scheduler, and counters from the shared checkpoint +files. Pass ``map_location`` when the restored process should load onto a +rank-local device instead of the device recorded in the checkpoint metadata. + +The checkpoint directory must be visible to every rank before restart. For +periodic hook saves, the async writer is flushed when the strategy exits. For +manual save workflows, users should coordinate their distributed script so only +one rank calls :meth:`~nvalchemi.training.TrainingStrategy.save_checkpoint`, +then ensure all ranks wait until the checkpoint is complete before any rank +tries to reload it. + +Current checkpoints store replicated strategy and optimizer state. They are +intended for the training strategies used by this package and do not provide a +separate sharded checkpoint format for distributed optimizers or model shards. +Workflows that shard model or optimizer state outside the strategy checkpoint +must save and restore those sharded states separately. 
+ +``DistributedDataParallel`` wrappers are unwrapped before model specs and model +weights are written, so native checkpoints store the underlying model state +without ``module.`` key prefixes. FSDP and FSDP2 require PyTorch Distributed +Checkpoint (DCP) so that each rank can save its shard and reload under a +possibly different topology. Native strategy checkpoints currently reject +FSDP/FSDP2-wrapped models instead of writing incomplete rank-local state. See +the `PyTorch Distributed Checkpoint recipe <https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html>`_ +for the DCP workflow. + +Lower-level loader +------------------ + +The module-level :func:`~nvalchemi.training.save_checkpoint` and +:func:`~nvalchemi.training.load_checkpoint` functions remain available when +callers need the full manifest, component dictionaries, validators, model +subsets, or adapter loads. ``TrainingStrategy.load_checkpoint`` deliberately +returns only the restored strategy and rejects component-only checkpoints. + +API reference +------------- + +.. currentmodule:: nvalchemi.training + +.. autosummary:: + :toctree: generated + :nosignatures: + + TrainingStrategy.save_checkpoint + TrainingStrategy.restore_checkpoint + TrainingStrategy.load_checkpoint + save_checkpoint + load_checkpoint + +.. currentmodule:: nvalchemi.training.hooks + +.. autosummary:: + :toctree: generated + :nosignatures: + + CheckpointHook diff --git a/docs/modules/training/hooks.rst b/docs/modules/training/hooks.rst new file mode 100644 index 00000000..5861f70f --- /dev/null +++ b/docs/modules/training/hooks.rst @@ -0,0 +1,393 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +.. _training-hooks-api: +.. _training-hooks: +.. 
_training-update-hooks: + +Training update hooks +===================== + +Training update hooks are for policies that need to participate in the +weight-update portion of a training batch. 
They are intentionally narrower than +general :class:`~nvalchemi.hooks.Hook` objects: a +:class:`~nvalchemi.training.hooks.TrainingUpdateHook` only runs on the stages +owned by :class:`~nvalchemi.training.hooks.TrainingUpdateOrchestrator`, and the +orchestrator performs the actual ``backward()``, optimizer step, scheduler step, +and gradient zeroing calls. + +Use this hook family when multiple update policies need to coordinate around the +same batch update. Typical examples include gradient accumulation, mixed +precision, gradient clipping, spike skipping, and post-step model averaging. +Use a standard training hook for read-only observation or lifecycle logic that +does not need to own backward or optimizer-step behavior. + +``ctx.step_count`` tracks completed optimizer/scheduler steps. If an update hook +vetoes ``DO_OPTIMIZER_STEP`` for gradient accumulation or spike skipping, the +batch still advances ``ctx.batch_count`` and ``ctx.epoch_step_count`` but does +not advance ``ctx.step_count``. + +Distributed data parallel +------------------------- + +:class:`~nvalchemi.training.hooks.DDPHook` wraps optimized models in +``torch.nn.parallel.DistributedDataParallel`` during +``TrainingStage.SETUP``. This setup stage runs after distributed rank/device +resolution and before optimizer construction, so optimizers are built from the +DDP-wrapped model parameters. +See :ref:`distributed_manager_guide` for the workflow-level +``DistributedManager`` guide. + +.. code-block:: python + + from nvalchemi.distributed import DistributedManager + from nvalchemi.training.hooks import DDPHook, MixedPrecisionHook + from nvalchemi.training.strategy import TrainingStrategy + + DistributedManager.initialize() + manager = DistributedManager() + + strategy = TrainingStrategy( + ..., + distributed_manager=manager, + hooks=[ + DDPHook(find_unused_parameters=False), + MixedPrecisionHook(precision="bf16"), + ], + ) + +Launch single-node distributed training with ``torchrun``: + +.. 
code-block:: bash + + torchrun --nproc_per_node=2 train.py + +``DDPHook`` can also use ``TrainingStrategy.distributed_manager`` when a caller +provides a manager object. The recommended manager is +:class:`nvalchemi.distributed.DistributedManager`, which re-exports +``physicsnemo.distributed.DistributedManager``. Users should call +``DistributedManager.initialize()`` before constructing the manager. The hook +uses the manager's rank, world-size, local-rank, device, process group, and DDP +defaults such as ``broadcast_buffers`` and ``find_unused_parameters``. Without a +manager, the hook falls back to ``torch.distributed`` and torchrun environment +variables. + +Sampler handling is automatic for supported dataloaders. For +``torch.utils.data.DataLoader``, the hook returns a replacement loader with a +configured sampler when one is not already present. The default sampler is +``torch.utils.data.DistributedSampler``; pass ``sampler_kwargs`` to override +its inferred ``rank``, ``num_replicas``, ``shuffle``, ``seed``, or +``drop_last`` arguments, or pass ``sampler_cls`` with ``sampler_kwargs`` to use +a custom distributed sampler. For +``nvalchemi.data.datapipes.DataLoader``, it mutates ``loader.sampler`` in place. +Custom ``batch_sampler`` instances must already be distributed-aware. +The strategy's epoch handling calls ``sampler.set_epoch(...)`` when available. + +``DDPHook`` is not a training-update hook, so it does not participate in +``DO_BACKWARD`` or ``DO_OPTIMIZER_STEP``. Register it alongside +``MixedPrecisionHook`` normally; DDP wrapping happens before AMP opens its +per-batch autocast/update path. + +PyTorch profiler traces +----------------------- + +:class:`~nvalchemi.training.hooks.TorchProfilerHook` captures PyTorch profiler +Chrome traces through PhysicsNeMo's profiler wrapper. 
In training workflows it +starts at ``TrainingStage.BEFORE_TRAINING``, advances the profiler schedule at +``TrainingStage.AFTER_BATCH``, and finalizes at ``TrainingStage.AFTER_TRAINING`` +or when the strategy context exits. Standalone ``train_batch()`` calls start +lazily at ``TrainingStage.BEFORE_BATCH`` and still finalize when the context +closes. + +.. code-block:: python + + from torch.profiler import ProfilerActivity, schedule + + from nvalchemi.training.hooks import TorchProfilerHook + from nvalchemi.training.strategy import TrainingStrategy + + profile_hook = TorchProfilerHook( + output_dir="profiles/train-run", + activities=(ProfilerActivity.CPU, ProfilerActivity.CUDA), + schedule=schedule(wait=2, warmup=2, active=5, repeat=1), + record_shapes=True, + profile_memory=True, + with_flops=True, + ) + + strategy = TrainingStrategy(..., hooks=[profile_hook]) + strategy.run(train_loader) + +Each process writes to ``profiles/train-run/rank_<rank>/torch/`` unless +PhysicsNeMo's distributed manager is active and already owns rank suffixing. + +Mixed precision +--------------- + +:class:`~nvalchemi.training.hooks.MixedPrecisionHook` enables +``torch.amp.autocast`` for the forward/loss portion of the batch and uses +``torch.amp.GradScaler`` when ``precision`` is ``torch.float16``. The +``precision`` argument is required, so configs must choose one of the supported +policies explicitly: + +.. code-block:: python + + import torch + + from nvalchemi.training.hooks import MixedPrecisionHook + from nvalchemi.training.strategy import TrainingStrategy + + strategy = TrainingStrategy( + ..., + hooks=[MixedPrecisionHook(precision=torch.bfloat16)], + ) + +``precision`` accepts the dtype objects ``torch.float32``, ``torch.bfloat16``, +and ``torch.float16``, the canonical strings ``"float32"``, ``"bfloat16"``, +and ``"float16"``, or the shorthand aliases ``"fp32"``, ``"bf16"``, and +``"fp16"``.
+ +The policies are: + +* ``torch.float32``: no autocast context is created and no scaler is used. +* ``torch.bfloat16``: eligible ops run under bf16 autocast and no scaler is used. +* ``torch.float16``: eligible forward/loss ops run under fp16 autocast, the hook + scales the loss before backward, unscales gradients immediately before an + optimizer step proceeds, and lets the scaler skip steps with ``inf`` or + ``nan`` gradients. + +Register at most one ``MixedPrecisionHook`` per strategy. The strategy rejects +multiple mixed-precision hooks so that autocast, loss scaling, unscale, scaler +step, and scaler update cannot be applied twice in one batch update. + +Autocast scope +-------------- + +Autocast begins from the update-hook ``BEFORE_BATCH`` stage and is released +before ``backward()`` during ``DO_BACKWARD``. In normal strategy execution, that +covers the model forward and configured loss calculation while keeping backward +outside autocast. ``torch.float32`` is a no-op policy and does not create an +autocast context. Model wrappers or custom losses that need full precision for +a numerically sensitive subregion should open a local +``torch.amp.autocast(..., enabled=False)`` block or choose ``torch.float32`` / +``torch.bfloat16`` for the strategy. + +Gradient accumulation +--------------------- + +With fp16 gradient scaling, accumulated gradients stay scaled until the +effective batch is ready to step. A gradient-accumulation update hook should +veto ``TrainingStage.DO_OPTIMIZER_STEP`` on intermediate microbatches; that +suppresses AMP unscale, scaler step, and scaler update for those batches. When +the accumulation window is complete, the optimizer-step stage proceeds and +``MixedPrecisionHook`` unscales once per optimizer just before stepping. + +The scaler path has a small fast path when no schedulers are configured: +``GradScaler.step`` and ``GradScaler.update`` are sufficient. 
When schedulers are +present, the orchestrator checks whether each scaler step was skipped so it can +advance only schedulers whose paired optimizer actually stepped. + +Validation +---------- + +``MixedPrecisionHook`` is tied to the training update path owned by +``TrainingStrategy``. Validation is first-class on the strategy: +``TrainingStrategy.validate()`` (driven by a :class:`~nvalchemi.training.ValidationConfig`) +automatically honors a registered ``MixedPrecisionHook``'s inference autocast +according to the config's ``use_mixed_precision`` policy, so no separate +validation hook is required. The standalone +:class:`~nvalchemi.training.ValidationLoop` is hook-agnostic and instead takes +an explicit ``autocast`` callable. See :doc:`validation`. + +Stage constraints +----------------- + +Training update hooks always receive ``(ctx, stage, will_skip)`` and return +``(proceed, loss)``. The meaning of those values depends on the stage: + +.. list-table:: Training update hook stage contract + :widths: 18 22 22 38 + :header-rows: 1 + + * - Stage + - Hook responsibility + - Return contract + - Restrictions and expectations + * - ``BEFORE_BATCH`` + - Decide whether the orchestrator should call + :func:`~nvalchemi.training.optimizers.zero_gradients`. + - ``proceed`` must be a strict ``bool``. Any ``False`` vetoes gradient + zeroing. ``loss`` is ignored. + - Do not call ``backward()``, ``optimizer.step()``, or + ``scheduler.step()``. Use this stage for zero-grad policy, per-batch + update bookkeeping, or resetting state that is safe before the forward + pass. + * - ``DO_BACKWARD`` + - Transform or replace ``ctx.loss`` before the orchestrator calls + ``backward()`` once. + - ``loss`` must be a :class:`torch.Tensor`. ``proceed`` is ignored. + - Do not call ``backward()`` directly. Return the loss tensor the next + update hook should see. This is the stage for loss scaling and other + loss-space transforms. 
+ * - ``DO_OPTIMIZER_STEP`` + - Decide whether the orchestrator should call + :func:`~nvalchemi.training.optimizers.step_optimizers` and + :func:`~nvalchemi.training.optimizers.step_lr_schedulers`. + - ``proceed`` must be a strict ``bool``. Any ``False`` vetoes both the + optimizer and scheduler step. ``loss`` is ignored. + - Do not call ``backward()``. Avoid side effects that assume a step will + run when ``will_skip`` is ``True``. This is the stage for pre-step logic + such as gradient clipping, scaler updates, and accumulation/spike-skip + decisions. + * - ``AFTER_OPTIMIZER_STEP`` + - Observe the final step decision and run post-step bookkeeping. + - ``proceed`` and ``loss`` are ignored. ``will_skip`` tells the hook + whether the optimizer/scheduler step was vetoed. + - Do not call ``backward()`` or perform another optimizer/scheduler step. + Use this stage for work that should happen after the step path, such as + EMA updates, diagnostics, and state cleanup. + +Composition rules +----------------- + +All update hooks for a strategy are composed into one orchestrator. Lower +``priority`` values run first, and registration order breaks ties. The +orchestrator keeps calling later hooks after a veto so they can observe +``will_skip=True`` and update their own state consistently. + +Only one object may own ``DO_BACKWARD`` or ``DO_OPTIMIZER_STEP`` in a +:class:`~nvalchemi.training.strategy.TrainingStrategy`. For convenience, the +strategy auto-wraps bare :class:`~nvalchemi.training.hooks.TrainingUpdateHook` +instances into one :class:`~nvalchemi.training.hooks.TrainingUpdateOrchestrator`. +Passing ``stage=...`` while registering an update hook is not supported because +update hooks declare their stages through the orchestrator. + +EMA model averaging +------------------- + +:class:`~nvalchemi.training.hooks.EMAHook` maintains an +``AveragedModel`` for one model in ``ctx.models``. 
It runs during +``AFTER_OPTIMIZER_STEP`` and updates only after a successful optimizer step. If +an earlier update hook vetoes ``DO_OPTIMIZER_STEP``, the orchestrator passes +``will_skip=True`` and the EMA weights are left unchanged for that batch. + +.. code-block:: python + + from nvalchemi.training.hooks import EMAHook + from nvalchemi.training.strategy import TrainingStrategy + + ema = EMAHook(model_key="main", decay=0.999) + strategy = TrainingStrategy(..., hooks=[ema]) + +Restartable update hooks +------------------------ + +Most hooks should not write checkpoint state. If a training hook owns state +that changes resumed training behavior, such as EMA averaged weights or a +learned/adaptive update policy, make it satisfy +:class:`~nvalchemi.hooks.CheckpointableHook`. The strategy checkpoint loader +restores state only into runtime hooks that implement this specialized +protocol. + +For Pydantic hooks, keep constructor/configuration fields on the model and use +``model_dump()`` inside ``state_dict()`` before adding non-field runtime state: + +.. 
code-block:: python + + from collections.abc import Mapping + from typing import Any + + import torch + from pydantic import BaseModel, Field, PrivateAttr + + from nvalchemi.hooks import CheckpointableHook + from nvalchemi.training import TrainingStage + from nvalchemi.training.hooks import TrainingUpdateHook + + class RestartableMetricHook(BaseModel, TrainingUpdateHook): + update_every: int = Field(gt=0, default=1) + num_updates: int = 0 + + _metric_total: torch.Tensor | None = PrivateAttr(default=None) + + def __call__(self, ctx, stage, will_skip): + if ( + stage is TrainingStage.AFTER_OPTIMIZER_STEP + and not will_skip + and ctx.loss is not None + ): + value = ctx.loss.detach().to("cpu") + self._metric_total = ( + value + if self._metric_total is None + else self._metric_total + value + ) + self.num_updates += 1 + return True, ctx.loss + + def state_dict(self) -> dict[str, Any]: + state = self.model_dump() + if self._metric_total is not None: + state["metric_total"] = self._metric_total + return state + + def load_state_dict(self, state: Mapping[str, Any]) -> None: + if state.get("update_every", self.update_every) != self.update_every: + raise ValueError("RestartableMetricHook config mismatch") + self.num_updates = int(state.get("num_updates", self.num_updates)) + self._metric_total = state.get("metric_total") + + assert isinstance(RestartableMetricHook(), CheckpointableHook) + +Use ``model_dump_json()`` for JSON configuration records or diagnostics, not for +tensor-bearing checkpoint state. Tensor state should remain in ``state_dict()`` +so the checkpoint layer can save it with the rest of the training state. + +Example +------- + +.. 
code-block:: python + + import torch + + from nvalchemi.training import TrainingStage + from nvalchemi.training.hooks import TrainingUpdateHook + + class ClipGradients(TrainingUpdateHook): + priority = 30 + + def __init__(self, max_norm: float) -> None: + self.max_norm = max_norm + + def __call__(self, ctx, stage, will_skip): + match stage: + case TrainingStage.DO_OPTIMIZER_STEP: + if not will_skip: + for optimizer in ctx.optimizers: + params = ( + param + for group in optimizer.param_groups + for param in group["params"] + ) + torch.nn.utils.clip_grad_norm_(params, self.max_norm) + case _: + pass + return True, ctx.loss + + +API reference +------------- + +.. currentmodule:: nvalchemi.training.hooks + +.. autosummary:: + :toctree: generated + :nosignatures: + + DDPHook + MixedPrecisionHook + TorchProfilerHook + TrainingUpdateHook + TrainingUpdateOrchestrator + EMAHook + CheckpointHook diff --git a/docs/modules/training/index.rst b/docs/modules/training/index.rst new file mode 100644 index 00000000..1d6588ca --- /dev/null +++ b/docs/modules/training/index.rst @@ -0,0 +1,13 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +Training module +=============== + +.. toctree:: + :maxdepth: 2 + + checkpoints + hooks + losses + validation diff --git a/docs/modules/training/losses.rst b/docs/modules/training/losses.rst new file mode 100644 index 00000000..fc9e10b9 --- /dev/null +++ b/docs/modules/training/losses.rst @@ -0,0 +1,84 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +.. _losses-api: + +======================= +Losses — Training Terms +======================= + +Composable, tensor-first loss functions for MLIP training. + +.. seealso:: + + - **User guide**: :ref:`losses_guide` — conceptual overview, usage + patterns, and how to write your own loss term. 
+ + +Leaf and composition +-------------------- + +Leaf losses subclass :class:`~nvalchemi.training.BaseLossFunction`; +compositions use :class:`~nvalchemi.training.ComposedLossFunction` and +return a :class:`~nvalchemi.training.ComposedLossOutput`. + +.. currentmodule:: nvalchemi.training + +.. autosummary:: + :toctree: generated + :nosignatures: + + BaseLossFunction + ReductionContext + ComposedLossFunction + ComposedLossOutput + LossWeightSchedule + + +Concrete losses +--------------- + +Built-in leaf losses for common quantum-chemistry targets. + +.. autosummary:: + :toctree: generated + :nosignatures: + + EnergyMSELoss + EnergyMAELoss + ForceMSELoss + ForceL2NormLoss + StressMSELoss + + +Weight schedules +---------------- + +Pydantic ``frozen`` models satisfying :class:`~nvalchemi.training.LossWeightSchedule`. + +.. autosummary:: + :toctree: generated + :nosignatures: + + ConstantWeight + LinearWeight + CosineWeight + PiecewiseWeight + + +Reduction helpers +----------------- + +Per-graph reduction helpers — scatter reductions (``V ... → B ...``) +and matrix reductions (``B ... m n → B ...``) — importable for use in +custom losses. + +.. currentmodule:: nvalchemi.training.losses.reductions + +.. autosummary:: + :toctree: generated + :nosignatures: + + per_graph_sum + per_graph_mean + frobenius_mse diff --git a/docs/modules/training/validation.rst b/docs/modules/training/validation.rst new file mode 100644 index 00000000..0d96a3bd --- /dev/null +++ b/docs/modules/training/validation.rst @@ -0,0 +1,312 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +.. _validation-api: + +========== +Validation +========== + +Validation reuses the training loop's forward pass and loss machinery but runs +it under configurable inference conditions. 
By default a validation pass calls +the strategy's ``training_fn`` and evaluates the same +:class:`~nvalchemi.training.ComposedLossFunction`, so it reports the +metrics you train against. Both are overridable: set ``validation_fn`` to run a +user-defined validation function, and set ``loss_fn`` to score against a +different metric (for example a plain MAE for monitoring while you train against +a weighted energy/force loss). + +What stays fixed is that there is **no backward pass and no optimizer step** — a +validation pass only runs the forward and the loss, then reduces the per-batch +results across ranks into a single summary. The remaining inference semantics are +configuration-driven, not automatic: modules are placed in eval mode when +``set_eval`` is true (the default) and restored afterward, and autograd is +governed by ``grad_mode`` (see :ref:`configuring-validation-gradients`). + +Because validation is a first-class part of +:class:`~nvalchemi.training.TrainingStrategy`, you do not register a validation +hook — you attach a :class:`~nvalchemi.training.ValidationConfig` to the +strategy and passes run automatically. The mechanics live in a reusable +:class:`~nvalchemi.training.ValidationLoop`, which you can also drive yourself +for standalone metric evaluation (see +:ref:`standalone-validation`). + +.. seealso:: + + - :doc:`hooks` — training lifecycle stages and update hooks, including + ``AFTER_VALIDATION``. + + +How validation differs from training +------------------------------------ + +Both loops move each batch to the device, call the forward function, and +evaluate the same composed loss. From there they diverge: + +..
list-table:: + :header-rows: 1 + :widths: 26 32 42 + + * - Aspect + - Training step + - Validation pass + * - Backward / optimizer step + - Yes + - No — forward + loss only + * - Module mode + - ``train()`` + - ``eval()`` by default (``set_eval``), restored afterward + * - Autograd + - Always on + - Driven by ``grad_mode`` (see below) + * - Weights + - Live training weights + - Live, or the EMA / inference slot + * - Per-batch output + - Loss for the update + - Accumulated into a reduced summary + * - Gradient buffers + - Updated in place + - Snapshotted, cleared, restored + +Validation snapshots, clears, and restores parameter ``.grad`` buffers around +the pass, so it never corrupts in-flight training gradients even when it runs +with autograd enabled. Module training modes are likewise snapshotted and +restored. + + +.. _configuring-validation-gradients: + +Configuring gradients +---------------------- + +Some validation losses need autograd at inference time. Force and stress losses, +for example, differentiate energy with respect to positions, so the forward pass +must build a graph even though no optimizer step follows. ``ValidationConfig`` +exposes this through ``grad_mode``: + +- ``"auto"`` (default) — enable gradients when any loss component reports + ``requires_eval_grad=True`` (e.g. force/stress terms) and disable them + otherwise. This usually does the right thing without configuration. +- ``"enabled"`` — always run under ``torch.enable_grad()``. +- ``"disabled"`` — always run under ``torch.no_grad()``. + +When gradients are enabled the loop runs each batch under +``torch.enable_grad()``; otherwise it uses ``torch.no_grad()``. Either way the +parameter gradient buffers are restored on exit. + + +The validation flow +-------------------- + +A single pass proceeds as: + +1. **Setup** — snapshot module training modes and set them to eval (when + ``set_eval=True``); snapshot and clear parameter gradients (when grad-enabled). +2. 
**Per batch** — move the batch to the device; clear gradients; run the + forward + loss under the resolved autograd and autocast contexts; accumulate + the per-component loss diagnostics; invoke the optional ``batch_callback``. +3. **Reduce** — all-reduce the accumulated totals across ranks and build the + summary dict (published on rank 0; ``None`` on other ranks). +4. **Teardown** — restore parameter gradients and module training modes, even + if the pass raised. + + +Strategy-owned validation +------------------------- + +Assign a :class:`~nvalchemi.training.ValidationConfig` to +``strategy.validation_config`` and validation runs automatically inside +``strategy.run(...)``: + +- at a **step cadence** (``every_n_steps``), after the completed optimizer + step so EMA weights are already current, or +- at an **epoch cadence** (``every_n_epochs``), at the epoch boundary, and +- once **unconditionally at end-of-training** whenever a config is present. + +Each pass stores its summary on ``strategy.last_validation`` and fires the +``AFTER_VALIDATION`` hook stage, so loggers can read the live summary before +any metric-driven learning-rate scheduler consumes it. + +.. code-block:: python + + from nvalchemi.training import TrainingStrategy, ValidationConfig + + strategy = TrainingStrategy(...) + strategy.validation_config = ValidationConfig( + validation_data=val_data, # a re-iterable container of Batch + every_n_epochs=1, + ) + strategy.run(train_loader) + +``validation_data`` must be a *re-iterable* container (a ``list``, +``DataLoader``, or ``Dataset``) — the strategy walks it afresh on every pass, so +one-shot generators are rejected at construction time. By default validation +reuses the strategy's ``training_fn`` and ``loss_fn``; set ``validation_fn`` or +``loss_fn`` on the config to override either. + + +Using regular hooks with validation +------------------------------------ + +Validation does not bypass the hook system. 
Validation passes execute inside +``strategy.run(...)``, so every hook you register on the strategy keeps firing +on its normal stages. The dedicated tap-off point is ``AFTER_VALIDATION``, +fired from inside ``TrainingStrategy.validate()`` the moment a summary is +produced — and before any metric-driven scheduler consumes it. Register an +ordinary hook on that stage to log aggregate metrics from ``ctx.validation``: + +.. code-block:: python + + from nvalchemi.training import TrainingStage + + class SummaryLogger: + stage = TrainingStage.AFTER_VALIDATION + frequency = 1 + + def __call__(self, ctx, stage): + summary = ctx.validation # None on non-publishing ranks + if summary is not None: + my_tracker.log(val_loss=float(summary["total_loss"])) + + strategy.register_hook(SummaryLogger()) + +This is also how metric-driven learning-rate scheduling is wired (see +:ref:`metric-driven-schedulers`): the summary is available to consumers on the +same iteration the pass runs. + + +Inference model slot +-------------------- + +``TrainingStrategy`` owns an ``inference_model`` slot. Validation reads it via +the config's ``use_ema`` policy; an :class:`~nvalchemi.training.hooks.EMAHook` +publishes its averaged module into the slot at ``AFTER_OPTIMIZER_STEP``. The +writer (EMA / SWA / a distillation teacher) and the reader (validation) never +inspect each other — both only know the strategy. An empty slot falls back to +the live training model(s). + + +.. _metric-driven-schedulers: + +Metric-driven schedulers +------------------------ + +``ReduceLROnPlateau`` and subclasses are metric-driven: they step only at +validation checkpoints, consuming a scalar extracted from the validation +summary via :attr:`OptimizerConfig.scheduler_metric_adapter` +(a summary-dict key string or a +callable). Time-based schedulers continue to step every optimizer step. + + +..
_tapping-off-validation-data: + +Tapping off per-batch data with ``batch_callback`` +-------------------------------------------------- + +The ``AFTER_VALIDATION`` hook above sees only the reduced *summary*. When you +need the individual batches — to stream predictions to a Zarr store, dump +per-sample diagnostics, or run a custom error analysis — configure a +``batch_callback``. The toolkit ships no output-sink machinery: you bring your +own sink and the loop simply hands you each batch as it goes. + +A ``batch_callback`` is any object matching the +:class:`~nvalchemi.training.BatchValidationCallback` protocol. It is invoked +once per validation batch from inside +:meth:`~nvalchemi.training.ValidationLoop.execute`, immediately after that +batch's predictions and loss are computed. The call is keyword-only — +``batch``, ``predictions``, ``loss``, ``batch_count``, ``step_count``, and +``epoch`` — and you own the sink, its buffering, and its I/O: + +.. code-block:: python + + from nvalchemi.training import ValidationConfig + + class ZarrBatchSink: + """Example escape-hatch sink — write predictions to a Zarr store.""" + + def __init__(self, store): + self._store = store + + def __call__( + self, *, batch, predictions, loss, batch_count, step_count, epoch + ): + group = self._store.require_group(f"step_{step_count}") + group[f"batch_{batch_count}"] = predictions["energy"].cpu().numpy() + + config = ValidationConfig( + validation_data=val_data, + batch_callback=ZarrBatchSink(my_zarr_store), + ) + +A plain function works too — any callable with the keyword-only signature +satisfies the protocol: + +.. code-block:: python + + def log_batch(*, batch, predictions, loss, batch_count, step_count, epoch): + ... # write predictions / per-batch loss to your store of choice + + config = ValidationConfig(validation_data=val_data, batch_callback=log_batch) + + +.. 
_standalone-validation: + +Standalone validation (metric evaluation) +----------------------------------------- + +The same :class:`~nvalchemi.training.ValidationLoop` that the strategy drives +can be run on its own — for example, to evaluate a +trained checkpoint against a held-out set and read back the metrics. Standalone +construction takes the dependencies the strategy would otherwise supply: an +explicit ``model`` (or named ``models``), a ``validation_fn``, a loss (directly +or via ``config.loss_fn``), and optionally an ``autocast`` factory and an +explicit ``grad_enabled`` override. It is a context manager — ``execute()`` must +be called inside the ``with`` block — that snapshots and restores training modes +and gradients on exit, even on exception: + +.. code-block:: python + + from nvalchemi.training import ValidationConfig, ValidationLoop + + config = ValidationConfig(validation_data=val_data, loss_fn=loss_fn) + loop = ValidationLoop( + validation_data=val_data, + config=config, + device=device, + model=model, + validation_fn=validation_fn, + ) + with loop as active: + summary = active.execute() + + print(summary["total_loss"]) + +The returned ``summary`` is the same dictionary surfaced on +``ctx.validation`` / ``strategy.last_validation`` during integrated training. It +contains ``total_loss``, per-component totals/weights/samples, batch and sample +counts, ``model_source`` (``"ema"`` / ``"mixed"`` / ``"live"``), +``ema_model_keys``, ``precision``, and ``distributed_reduced``. Under +distributed execution it is published on rank 0 and ``None`` elsewhere. + +.. note:: + + If your loss differentiates the model output (force or stress losses), set + ``grad_mode="enabled"`` on the config or pass ``grad_enabled=True`` so the + standalone forward builds an autograd graph; ``grad_mode="auto"`` does this + for you when run through a strategy, but standalone callers should be explicit + when the loss object cannot be introspected.
+ + +API reference +------------- + +.. currentmodule:: nvalchemi.training + +.. autosummary:: + :toctree: generated + :nosignatures: + + ValidationConfig + ValidationLoop + BatchValidationCallback diff --git a/docs/userguide/agent_skills.md b/docs/userguide/agent_skills.md index 9a81ed61..ee72c430 100644 --- a/docs/userguide/agent_skills.md +++ b/docs/userguide/agent_skills.md @@ -29,7 +29,8 @@ available. | Skill | Description | Related user guide | |-------|-------------|--------------------| | `nvalchemi-data-structures` | How to use {py:class}`~nvalchemi.data.AtomicData` and {py:class}`~nvalchemi.data.Batch` for representing atomic systems and batching them for GPU computation. | {ref}`data_guide` | -| `nvalchemi-data-storage` | How to write, read, and load atomic data using the composable Zarr-backed storage pipeline (Writer, Reader, Dataset, DataLoader). | {ref}`datapipes_guide` | +| `nvalchemi-data-storage` | How to write, read, compose, and load atomic data using the composable Zarr-backed storage pipeline (Writer, Reader, Dataset, MultiDataset, DataLoader). | {ref}`datapipes_guide` | +| `nvalchemi-zarr-perf` | How to tune Zarr-backed Reader, Dataset, MultiDataset, and DataLoader throughput with fused reads, validation skipping, pinned memory, and benchmark sweeps. | {ref}`read_performance_tuning` | | `nvalchemi-model-wrapping` | How to wrap an arbitrary MLIP using the {py:class}`~nvalchemi.models.base.BaseModelMixin` interface to standardize inputs, outputs, and embeddings. | {ref}`models_guide` | | `nvalchemi-dynamics-api` | How to configure and run dynamics simulations, compose multi-stage pipelines ({py:class}`~nvalchemi.dynamics.FusedStage`, {py:class}`~nvalchemi.dynamics.DistributedPipeline`), use inflight batching, and manage data sinks. 
| {ref}`dynamics_guide` | | `nvalchemi-dynamics-implementation` | How to implement a dynamics integrator by subclassing {py:class}`~nvalchemi.dynamics.base.BaseDynamics` and overriding `pre_update()` and `post_update()`. | {ref}`dynamics_guide` | diff --git a/docs/userguide/datapipes.md b/docs/userguide/datapipes.md index f35d47d1..0798b58a 100644 --- a/docs/userguide/datapipes.md +++ b/docs/userguide/datapipes.md @@ -9,8 +9,9 @@ workloads. It is built from four composable pieces: a **Reader** that pulls raw tensors from storage, a **Dataset** that validates them into {py:class}`nvalchemi.data.AtomicData` objects, a **DataLoader** that batches them into {py:class}`nvalchemi.data.Batch` objects, and an optional **Sampler** that -controls batching strategy. Each layer adds exactly one concern, and you can swap -any of them independently. +controls batching strategy. A **MultiDataset** can also compose several datasets +behind one global index space. Each layer adds exactly one concern, and you can +swap any of them independently. ```{note} The ``datapipes`` abstraction is shared with ``physicsnemo``: there are some @@ -26,14 +27,42 @@ reading, and loading atomic data through the Zarr-backed storage pipeline. ## Reader: raw tensor I/O -A {py:class}`~nvalchemi.data.datapipes.backends.base.Reader` is the simplest -abstraction in the pipeline. It knows how to load a single sample from storage and -return it as a plain `dict[str, torch.Tensor]` --- no validation, no device -transfers, no threading. Readers are intentionally minimal so that adding a new -storage backend only requires implementing two methods: +A {py:class}`~nvalchemi.data.datapipes.backends.base.Reader` is the storage-facing +layer of the pipeline. It returns plain `dict[str, torch.Tensor]` objects on CPU: +no `AtomicData` validation, no device transfers, no batching policy, and no +threading. 
That separation keeps storage backends focused on I/O and lets +samplers decide *which* samples to request without also needing to know how the +store should be read efficiently. -- `_load_sample(index) -> dict[str, torch.Tensor]`: Reads one sample into CPU tensors. -- `__len__() -> int`: Returns the total number of available samples. +Readers expose two public loading methods: + +- `read(index)`: Load one sample and return `(raw_tensor_dict, metadata)`. +- `read_many(indices)`: Load several samples in the requested order and return one + `(raw_tensor_dict, metadata)` pair per requested index. + +Both public methods attach per-sample metadata and optionally pin CPU tensors when +`pin_memory=True`. Index validity is a backend concern: for example, the Zarr +reader supports negative logical indices and maps them through its active-sample +mask, while another backend may choose different index semantics. + +`read_many` has an ordered contract: results must align one-for-one with the +requested indices. Backends can reorder internally for physical I/O, but they must +restore the caller's requested order before returning. + +Backend authors implement one or both raw loading hooks: + +- `_load_sample(index) -> dict[str, torch.Tensor]`: Simple single-sample path. +- `_load_many_samples(indices) -> list[dict[str, torch.Tensor]]`: Batch-oriented + path for amortizing I/O across many requested samples. +- `__len__() -> int`: Total number of available logical samples. + +For simple formats, implementing `_load_sample` is enough; the base `Reader` +implements `read_many` by looping over `_load_sample`. Readers that only have an +efficient batch path can implement `_load_many_samples`; the base single-sample +hook can call it with a one-index request. 
For storage formats with high per-call +overhead or chunk locality, implement `_load_many_samples` so the backend can +sort, merge, cache, or otherwise coalesce physical reads before returning samples +in the caller's original order. The built-in reader is {py:class}`~nvalchemi.data.datapipes.backends.zarr.AtomicDataZarrReader`, which @@ -42,6 +71,13 @@ reads from the structured Zarr stores produced by the toolkit's layout uses separate groups for core fields, metadata, and custom attributes, and supports soft-deletes via a validity mask. +`AtomicDataZarrReader` implements `_load_many_samples` as the fast path. Given a +shuffled list of logical indices, it maps them to physical sample positions, sorts +by physical order, groups reads by Zarr chunk locality, loads each array in +coalesced ranges or orthogonal selections, and then restores the caller's requested +sample order. This is why downstream code should prefer `read_many` for batches +instead of looping over `read`. + ```{tip} The writer supports per-group compression and chunking via {py:class}`~nvalchemi.data.datapipes.ZarrWriteConfig`. See the @@ -50,9 +86,38 @@ recommendations and storage estimates. ``` If your data lives in a different format (HDF5, LMDB, a collection of files), you -can subclass `Reader` and implement `_load_sample` and `__len__`. Everything +can subclass `Reader` and implement the hook that matches the backend. Everything downstream --- Dataset, DataLoader, Sampler --- will work without changes. +```python +from collections.abc import Sequence + +import torch + +from nvalchemi.data.datapipes.backends.base import Reader + + +class MyReader(Reader): + def __len__(self) -> int: + return 10_000 + + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + # Good enough for simple formats or true random-access stores. 
+ return load_one_sample(index) + + +class MyBatchOptimizedReader(Reader): + def __len__(self) -> int: + return 10_000 + + def _load_many_samples( + self, indices: Sequence[int] + ) -> list[dict[str, torch.Tensor]]: + # Use backend-specific locality here, then return results in the same + # logical order as ``indices``. + return load_samples_with_coalesced_io(indices) +``` + ## Dataset: validation and prefetching {py:class}`~nvalchemi.data.datapipes.dataset.Dataset` wraps a Reader and adds two @@ -60,28 +125,62 @@ responsibilities: 1. **Validation**: Raw dictionaries are validated into {py:class}`nvalchemi.data.AtomicData` objects, catching schema issues early. + Pass `skip_validation=True` to bypass Pydantic validation when the backing + store is already known to be well-formed (see + [Read performance tuning](read_performance_tuning)). 2. **Async prefetching**: A background `ThreadPoolExecutor` loads and transfers - samples to the target device ahead of time, so the GPU is never starved. + samples to the target device ahead of time, reducing stalls while the model + consumes previous batches. + +The Dataset talks to readers through public `reader.read_many(...)`. This is true +even when a caller asks for one sample: single-sample Dataset access is +represented as a one-element read request, so batch-capable readers keep their +optimized path and Dataset does not need to know backend-specific private hooks. +Duck-typed readers can be used without inheriting from `Reader` if they implement +`read_many`, `__len__`, and `close`. 
```python from nvalchemi.data.datapipes.dataset import Dataset from nvalchemi.data.datapipes.backends.zarr import AtomicDataZarrReader reader = AtomicDataZarrReader("/path/to/store.zarr") -dataset = Dataset(reader=reader, device="cuda:0", num_workers=4) +dataset = Dataset( + reader=reader, + device="cuda:0", + num_workers=1, + skip_validation=True, # use for trusted stores written by the toolkit +) # Fetch a single sample (AtomicData on GPU) data, metadata = dataset[0] ``` -### CUDA stream prefetching +### Batch loading and CUDA stream prefetching -When called by a DataLoader, the Dataset uses -{py:meth}`~nvalchemi.data.datapipes.dataset.Dataset.prefetch` to overlap -host-to-device data transfers with compute. The DataLoader issues prefetch calls on -non-default CUDA streams; the Dataset records the transfer and synchronises the -stream before returning the data. This means the next batch is already on the GPU -while the model is processing the current one. +When called by a DataLoader, the Dataset can overlap host-to-device transfers with +compute. The DataLoader issues prefetch calls on non-default CUDA streams; the +Dataset records the transfer and synchronises the stream before returning the data. +This means the next batch can already be on the GPU while the model is processing +the current one. + +The canonical synchronous batch API is +{py:meth}`~nvalchemi.data.datapipes.dataset.Dataset.load_batches`. It accepts a +sequence of batch-index lists and returns one {py:class}`nvalchemi.data.Batch` per +input list. 
Even for a single emitted batch, this path goes through one +`reader.read_many(...)` request so batch-capable readers can use the same +coalesced I/O implementation everywhere: + +```python +batches = dataset.load_batches([[0, 4, 2], [8, 1, 3]]) +batch0, batch1 = batches +``` + +For asynchronous loader iteration, the important path is fused prefetch: +{py:meth}`~nvalchemi.data.datapipes.dataset.Dataset.prefetch_fused_batches` +accepts several upcoming DataLoader batches, flattens their indices into one +`reader.read_many(...)` request, and then splits the loaded samples back into the +original batch boundaries. This improves I/O throughput without requiring the +sampler to choose storage-friendly windows. ### Lightweight metadata access @@ -103,9 +202,11 @@ from nvalchemi.data.datapipes.dataloader import DataLoader loader = DataLoader( dataset=dataset, - batch_size=32, - prefetch_factor=2, + batch_size=64, + prefetch_factor=16, num_streams=2, + use_streams=True, + pin_memory=True, ) for batch in loader: @@ -118,14 +219,126 @@ Key parameters: | Parameter | Purpose | |---------------------|--------------------------------------------------------------| | `batch_size` | Number of graphs per batch | -| `prefetch_factor` | How many **batches** to load ahead of the current one | +| `prefetch_factor` | How many **batches** to fuse into each background read ([tuning guide](read_performance_tuning)) | | `num_streams` | Number of CUDA streams used for overlapping transfers | +| `use_streams` | Whether to enable CUDA-stream prefetching when CUDA is available | +| `pin_memory` | Request page-locked CPU tensors from readers that support pinned memory | | `sampler` | Controls index ordering (defaults to sequential or random) | +| `batch_sampler` | Supplies complete batches of indices and overrides `batch_size`, `shuffle`, and `sampler` | Unlike PyTorch's `torch.utils.data.DataLoader`, this implementation returns {py:class}`nvalchemi.data.Batch` objects (disjoint graphs with 
proper node-index offsets) rather than generic collated tensors. +### Batch throughput and fused prefetch + +`batch_size` controls the number of samples emitted to the training loop. +`prefetch_factor` controls how many emitted batches are fused into one background +backend read. For positive `prefetch_factor`, together they define the effective +read window: + +```text +effective_read_window = batch_size * prefetch_factor +``` + +For example, `batch_size=64` and `prefetch_factor=16` produces batches of 64 +graphs for the model, but the reader sees read requests of up to 1,024 logical +indices. The model-facing batch size stays unchanged; only the storage access +window grows. + +This distinction is useful for graph-like data with shuffled access: + +- Samplers remain semantic: they decide ordering and batch membership based on + training needs, size limits, or distributed partitioning. +- Readers remain physical: they can exploit chunk locality, sort by physical + position, merge adjacent ranges, and amortize per-call overhead. +- Dataset and DataLoader connect the two by converting several upcoming batches + into one larger `read_many` request, then yielding the original batch sequence. + +Use `prefetch_factor=0` to disable fused prefetch and issue one backend read per +emitted batch. This is useful for debugging or for stores where large read windows +do not help. For shuffled Zarr training reads, start with `prefetch_factor=16` or +`32`, then benchmark with `nvalchemi-io-test` on a representative store. Enable +`pin_memory=True` for CUDA training so the DataLoader requests page-locked CPU +tensors before asynchronous transfer. See +[Read performance tuning](read_performance_tuning) and the +[I/O benchmark tool](io_benchmark_section) for concrete commands. 
+ +## MultiDataset: composing datasets + +{py:class}`~nvalchemi.data.datapipes.multidataset.MultiDataset` concatenates +multiple {py:class}`~nvalchemi.data.datapipes.dataset.Dataset` instances behind +one global index space. It follows the PhysicsNeMo multidataset indexing contract +while preserving the nvalchemi batch fast path: `load_batches(...)` routes each +global batch to the relevant child datasets and recombines mixed-child batches in +the requested sample order. + +```python +from nvalchemi.data.datapipes import ( + AtomicDataZarrReader, + DataLoader, + Dataset, + MultiDataset, + MultiDatasetBatchSampler, +) + +dataset_a = Dataset(AtomicDataZarrReader("dataset_a.zarr"), device="cuda:0") +dataset_b = Dataset(AtomicDataZarrReader("dataset_b.zarr"), device="cuda:0") +dataset = MultiDataset(dataset_a, dataset_b, output_strict=True) + +batch_sampler = MultiDatasetBatchSampler.balanced( + dataset, + batch_size=64, + epoch_policy="max_size", + replacement=True, +) + +loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + prefetch_factor=16, + pin_memory=True, +) +``` + +By default, `output_strict=True` requires all non-empty child datasets to expose +the same field names. Empty children are skipped when choosing the reference +field set. Use `output_strict=False` only when downstream code can handle +source-specific fields. 
+ +### Multidataset sampler policies + +The multidataset samplers operate on global indices but allocate samples at the +child-dataset level: + +| Sampler | Use case | +|---------------------|--------------------------------------------------------------| +| {py:class}`~nvalchemi.data.datapipes.samplers.MultiDatasetSampler` | Draw individual samples from child datasets at custom rates | +| {py:class}`~nvalchemi.data.datapipes.samplers.MultiDatasetBatchSampler` | Build batches with explicit or weighted per-dataset allocations | +| {py:meth}`~nvalchemi.data.datapipes.samplers.MultiDatasetBatchSampler.balanced` | Build dataset-balanced batches | + +`samples_per_dataset` accepts integer counts or floating-point relative ratios. +For example, `[1.0, 3.0]` allocates roughly one quarter of each batch to the first +dataset and three quarters to the second dataset. + +When `num_batches` is omitted, `epoch_policy` controls the default epoch length: + +| `epoch_policy` | Behavior | +|----------------|----------| +| `"dataset_size"` | Preserve the historical default based on total dataset size | +| `"min_size"` | Stop when the smallest contributing dataset would be exhausted | +| `"max_size"` | Run until the largest contributing dataset is covered, oversampling smaller datasets when `replacement=True` | + +Use `"max_size"` for balanced training over datasets of different sizes when you +want smaller datasets to be oversampled instead of dominated by the largest +dataset. Without replacement, `"max_size"` raises if oversampling would be +required. + +For data-parallel training, the multidataset samplers can shard sample or batch +orders across ranks. See {ref}`distributed_manager_guide` for examples ranging +from the default `DDPHook` sampler injection to distributed +`MultiDatasetBatchSampler` composition. 
+ ## Transforms: per-sample and per-batch hooks Both {py:class}`~nvalchemi.data.datapipes.dataset.Dataset` and @@ -233,15 +446,22 @@ from nvalchemi.data.datapipes.dataloader import DataLoader from nvalchemi.dynamics.sampler import SizeAwareSampler reader = AtomicDataZarrReader("/path/to/store.zarr") -dataset = Dataset(reader=reader, device="cuda:0", num_workers=4) +dataset = Dataset( + reader=reader, + device="cuda:0", + num_workers=1, + skip_validation=True, +) sampler = SizeAwareSampler(dataset=dataset, max_atoms=4096) loader = DataLoader( dataset=dataset, batch_size=64, # upper bound; sampler may produce smaller batches sampler=sampler, - prefetch_factor=2, + prefetch_factor=16, num_streams=2, + use_streams=True, + pin_memory=True, ) for batch in loader: @@ -259,5 +479,5 @@ for batch in loader: - **Compression**: The [Zarr Compression Tuning Guide](zarr_compression_guide) covers how to configure compression and chunking when writing Zarr stores. - **I/O benchmark**: The [I/O benchmark tool](io_benchmark_section) lets you - measure write throughput and compression ratios on synthetic data before - choosing a configuration. + measure write throughput, readback throughput, and compression ratios on + synthetic data before choosing a configuration. diff --git a/docs/userguide/distributed_training.md b/docs/userguide/distributed_training.md new file mode 100644 index 00000000..8a2e1abb --- /dev/null +++ b/docs/userguide/distributed_training.md @@ -0,0 +1,223 @@ + + +(distributed_manager_guide)= + +# Distributed Training + +`DistributedManager` is the recommended entry point for distributed runtime +state in ALCHEMI training workflows. ALCHEMI re-exports PhysicsNeMo's manager as +`nvalchemi.distributed.DistributedManager` so training code can use one object +for process rank, local rank, world size, device selection, process groups, and +DistributedDataParallel defaults. + +You can still manage `torch.distributed` directly in advanced workflows. 
Passing +a `DistributedManager` to {py:class}`~nvalchemi.training.TrainingStrategy` gives +ALCHEMI hooks a shared view of the distributed runtime without each hook needing +to read environment variables or initialize communication on its own. + +## Basic pattern + +Call `DistributedManager.initialize()` before constructing the manager, then pass +the instance into the strategy. {py:class}`~nvalchemi.training.hooks.DDPHook` uses the manager during +setup to choose the rank-local device, wrap optimized models in +`torch.nn.parallel.DistributedDataParallel`, and install a distributed sampler +for supported dataloaders. + +```python +from nvalchemi.distributed import DistributedManager +from nvalchemi.training import TrainingStrategy +from nvalchemi.training.hooks import DDPHook + +DistributedManager.initialize() +manager = DistributedManager() + +strategy = TrainingStrategy( + ..., + distributed_manager=manager, + hooks=[ + DDPHook(), + ], +) + +strategy.run(train_loader) +``` + +Launch the script with the process launcher for your environment. For a simple +single-node PyTorch launch: + +```bash +$ torchrun --nproc_per_node=4 train.py +``` + +`DistributedManager.initialize()` also supports single-process execution. In +that case `DDPHook` is a no-op because the world size is one, so the same script +can run locally and under a distributed launcher. + +For a complete single-node dummy training script, see +{doc}`/examples/intermediate/06_ddp_mlp_training`. It can be launched with: + +```bash +$ uv run --extra cu12 torchrun --standalone --nproc_per_node=2 \ + examples/intermediate/06_ddp_mlp_training.py --backend auto +``` + +## Data loaders and samplers + +Each data-parallel rank must see a different slice of the training data. The +right composition depends on whether you use the default sampler, a custom +sampler, or a sampler that already emits complete batches. 
+ +### Simple case: let DDPHook install the sampler + +For standard dataloaders with a `dataset` and mutable `sampler`, use an ordinary +loader and let {py:class}`~nvalchemi.training.hooks.DDPHook` install +`torch.utils.data.DistributedSampler` during strategy setup. The hook infers +`num_replicas`, `rank`, `shuffle`, and `drop_last` from the distributed manager +and dataloader, and uses `seed=0` unless overridden. + +```python +from nvalchemi.data.datapipes import DataLoader, Dataset +from nvalchemi.distributed import DistributedManager +from nvalchemi.training import TrainingStrategy +from nvalchemi.training.hooks import DDPHook + +DistributedManager.initialize() +manager = DistributedManager() + +dataset = Dataset(reader, device=manager.device) +train_loader = DataLoader( + dataset, + batch_size=64, + shuffle=True, + pin_memory=True, +) + +strategy = TrainingStrategy( + ..., + distributed_manager=manager, + hooks=[DDPHook()], +) +strategy.run(train_loader) +``` + +This is the preferred starting point for a single dataset. The loader stays +single-process friendly: when `manager.world_size == 1`, `DDPHook` leaves it +unchanged. + +Use `sampler_kwargs` to override arguments passed to the default sampler: + +```python +DDPHook( + sampler_kwargs={ + "shuffle": False, + "seed": 1234, + }, +) +``` + +### Custom distributed sampler + +If a dataloader already has a distributed-aware sampler, `DDPHook` preserves it +instead of replacing it. A sampler is considered distributed-aware when it +satisfies {py:class}`~nvalchemi.data.datapipes.samplers.DistributedSamplerProtocol`: +it exposes `num_replicas`, `rank`, and `set_epoch(epoch)`. Native PyTorch +`DistributedSampler` satisfies this protocol. + +For a sampler class or factory that accepts PyTorch-style distributed sampler +arguments, pass it to `DDPHook`. The hook supplies `num_replicas`, `rank`, +`shuffle`, `seed`, and `drop_last` defaults before applying your +`sampler_kwargs`. 
+ +```python +DDPHook( + sampler_cls=MyDistributedSampler, + sampler_kwargs={ + "seed": 1234, + }, +) +``` + +If your sampler uses different constructor names, pass those names explicitly in +`sampler_kwargs`. + +```python +DDPHook( + sampler_cls=MyDistributedSampler, + sampler_kwargs={ + "replicas": manager.world_size, + "worker_rank": manager.rank, + }, +) +``` + +### Multidataset batch sampling + +When a dataloader is constructed with `batch_sampler`, the sampler is already +responsible for emitting complete batches. In that case, `DDPHook` cannot safely +replace the sampler with a plain `DistributedSampler`; the batch sampler itself +must be distributed-aware. + +Use {py:class}`~nvalchemi.data.datapipes.samplers.MultiDatasetBatchSampler` when +you need per-dataset batch composition and distributed sharding together. Pass +the initialized manager to the sampler so each rank receives a different shard of +the batch sequence. + +```python +from nvalchemi.data.datapipes import ( + AtomicDataZarrReader, + DataLoader, + Dataset, + MultiDataset, + MultiDatasetBatchSampler, +) +from nvalchemi.distributed import DistributedManager +from nvalchemi.training import TrainingStrategy +from nvalchemi.training.hooks import DDPHook + +DistributedManager.initialize() +manager = DistributedManager() + +dataset = MultiDataset( + Dataset(AtomicDataZarrReader("dataset_a.zarr"), device=manager.device), + Dataset(AtomicDataZarrReader("dataset_b.zarr"), device=manager.device), +) + +batch_sampler = MultiDatasetBatchSampler.balanced( + dataset, + batch_size=64, + epoch_policy="max_size", + replacement=True, + distributed_manager=manager, + seed=1234, +) + +train_loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + prefetch_factor=16, + pin_memory=True, +) + +strategy = TrainingStrategy( + ..., + distributed_manager=manager, + hooks=[DDPHook()], +) +strategy.run(train_loader) +``` + +`MultiDatasetBatchSampler` first builds the global batch order according to its +per-dataset 
allocation policy, then splits that batch order across data-parallel +ranks. With `drop_last=False`, it pads the batch order so each rank emits the +same number of batches, matching PyTorch `DistributedSampler` behavior. With +`drop_last=True`, it truncates the uneven tail instead. + +Use {py:meth}`~nvalchemi.data.datapipes.dataloader.DataLoader.set_epoch` or let +{py:class}`~nvalchemi.training.TrainingStrategy` call it during training so +distributed samplers reshuffle deterministically from epoch to epoch. + +## API details + +For the complete manager API, including process-group methods and distributed +configuration knobs, see the +[PhysicsNeMo DistributedManager API](https://docs.nvidia.com/physicsnemo/latest/physicsnemo/api/physicsnemo.distributed.html#physicsnemo.distributed.manager.DistributedManager). diff --git a/docs/userguide/dynamics.md b/docs/userguide/dynamics.md index fe3886b0..fc17693b 100644 --- a/docs/userguide/dynamics.md +++ b/docs/userguide/dynamics.md @@ -272,7 +272,13 @@ digraph buffer_sync { d_sinks -> d_batch [label="drain when\ncapacity available" style=dotted] } - u_send -> d_recv [label="isend / irecv\n(NCCL)" style=bold color="#c0392b" fontcolor="#c0392b" penwidth=2] + u_send -> d_recv [ + label="isend / irecv\n(NCCL)"; + style=bold; + color="#c0392b"; + fontcolor="#c0392b"; + penwidth=2; + ] } ``` diff --git a/docs/userguide/hooks.md b/docs/userguide/hooks.md index 3e7a38dd..0a0f5eed 100644 --- a/docs/userguide/hooks.md +++ b/docs/userguide/hooks.md @@ -102,6 +102,8 @@ Training loops pass {py:class}`~nvalchemi.hooks.TrainContext`, which adds: | Field | Type | Meaning | |-------|------|---------| | `step_count` | `int` | Current optimizer step | +| `batch_count` | `int` | Training batches consumed, including skipped optimizer steps | +| `epoch_step_count` | `int` | Batches consumed within the current epoch | | `epoch` | `int` | Current epoch | | `loss` | `torch.Tensor \| None` | Aggregate loss | | `losses` | `dict[str, torch.Tensor] \| 
None` | Named loss components | @@ -216,7 +218,7 @@ hook = ConvergenceHook( In a single-stage simulation (no status arguments), convergence simply causes those systems to stop being updated. -### LoggingHook +### Dynamics LoggingHook {py:class}`~nvalchemi.dynamics.hooks.LoggingHook` records scalar observables (energy, temperature, maximum force, etc.) at a configurable interval: @@ -228,6 +230,22 @@ hook = LoggingHook(backend="csv", log_path="hooks.csv", frequency=10) # log eve ``` The hook implements the context manager protocol to manage its logger lifecycle. +It is the current built-in dynamics logger, not the full logging abstraction for +all workflows. + +### Logging vs. reporting + +Use logging hooks when you want simple, direct records from a workflow: rows, +files, or lightweight backend writes that are easy to inspect later. Logging is +workflow-general; dynamics and training can each have loggers that understand +their own event model. For example, the built-in dynamics `LoggingHook` writes +per-graph dynamics observables to CSV, TensorBoard, or a custom sink without +imposing a higher-level analysis model. + +Use reporting when you want workflow-level summaries: scalar collection, +rank-aware reductions, serialized reporting snapshots, live dashboards, or +analysis-facing output across training and dynamics. The reporting abstractions +are described separately in the {doc}`reporting user guide <reporting>`. ### SnapshotHook @@ -356,9 +374,10 @@ class UniversalLoggerHook: print(f"[custom] stage={stage.name}, graphs={ctx.batch.num_graphs}") ``` -The built-in {py:class}`~nvalchemi.dynamics.hooks.ProfilerHook` uses this -pattern to instrument dynamics and custom workflows with appropriate -NVTX domain annotations. +Cross-category hooks such as {py:class}`~nvalchemi.hooks.TorchProfilerHook` use +this pattern to claim the training and dynamics stages they support. 
+{py:class}`~nvalchemi.hooks.StageTimingHook` uses the same multi-stage hook +protocol for lightweight per-stage timing. ### Resource management with `__enter__` / `__exit__` @@ -390,6 +409,70 @@ class FileWriterHook: self._file.close() ``` +### Restartable hooks with `CheckpointableHook` + +Hooks are stateless by default. If a hook owns state that changes training +semantics after a restart (for example EMA weights, a dynamic schedule, or a +history buffer), make it satisfy +{py:class}`~nvalchemi.hooks.CheckpointableHook` by adding `state_dict()` and +`load_state_dict()`. Training checkpoints discover this protocol at runtime and +store only hooks that opt in. + +Pydantic-backed hooks should keep declarative configuration in model fields and +use `model_dump()` for the configuration part of `state_dict()`. Use +`model_dump_json()` when you need a JSON representation for logs or separate +configuration files. Runtime tensors or counters that are not Pydantic fields +can then be added explicitly. 
+ +```python +from collections.abc import Mapping +from typing import Any + +import torch +from pydantic import BaseModel, Field, PrivateAttr + +from nvalchemi.hooks import CheckpointableHook +from nvalchemi.training import TrainingStage +from nvalchemi.training.hooks import TrainingUpdateHook + +class RunningLossHook(BaseModel, TrainingUpdateHook): + window: int = Field(gt=0, default=100) + num_updates: int = 0 + + _loss_sum: torch.Tensor | None = PrivateAttr(default=None) + + def __call__(self, ctx, stage, will_skip): + if ( + stage is TrainingStage.AFTER_OPTIMIZER_STEP + and not will_skip + and ctx.loss is not None + ): + value = ctx.loss.detach().to("cpu") + self._loss_sum = ( + value if self._loss_sum is None else self._loss_sum + value + ) + self.num_updates += 1 + return True, ctx.loss + + def state_dict(self) -> dict[str, Any]: + state = self.model_dump() + if self._loss_sum is not None: + state["loss_sum"] = self._loss_sum + return state + + def load_state_dict(self, state: Mapping[str, Any]) -> None: + if "window" in state and state["window"] != self.window: + raise ValueError("RunningLossHook checkpoint window does not match") + self.num_updates = int(state.get("num_updates", self.num_updates)) + self._loss_sum = state.get("loss_sum") + +assert isinstance(RunningLossHook(), CheckpointableHook) +``` + +Only implement this protocol for state that must survive restart. Temporary +resources, cached buffers that can be rebuilt, and bookkeeping derived from the +workflow counters should stay out of hook checkpoints. + ## Composing hooks Hooks are independent and composable. 
A typical production setup combines diff --git a/docs/userguide/index.md b/docs/userguide/index.md index f44c21d7..b1f14c3c 100644 --- a/docs/userguide/index.md +++ b/docs/userguide/index.md @@ -33,11 +33,14 @@ $ python -c "import nvalchemi; print(nvalchemi.__version__)" - [AtomicData and Batch](data) - [Data Loading Pipeline](datapipes) - {doc}`Models: Wrapping ML Interatomic Potentials <models>` +- {doc}`Losses: Composable Training Terms <losses>` - {doc}`Hooks: Observe & Modify <hooks>` +- {doc}`Reporting: Summaries and Dashboards <reporting>` - [Dynamics: Optimization and MD](dynamics) ## Advanced Usage +- [Distributed Training](distributed_training) - [Zarr Compression Tuning](zarr_compression) - [Agent Skills](agent_skills) @@ -62,7 +65,9 @@ about/contributing data datapipes models +losses hooks +reporting dynamics ``` @@ -71,6 +76,7 @@ dynamics :maxdepth: 1 :hidden: +distributed_training zarr_compression agent_skills ``` diff --git a/docs/userguide/losses.md b/docs/userguide/losses.md new file mode 100644 index 00000000..bf154843 --- /dev/null +++ b/docs/userguide/losses.md @@ -0,0 +1,888 @@ + + +(losses_guide)= + +# Losses + +Loss functions in ALCHEMI are tensor-first, composable +{py:class}`torch.nn.Module` objects. A **leaf loss** consumes a +prediction tensor and a target tensor and returns a scalar; a +{py:class}`~nvalchemi.training.ComposedLossFunction` *routes* keyed +mappings of predictions and targets into each leaf, applies the +composition's per-component weights, and returns a structured +{py:class}`~nvalchemi.training.ComposedLossOutput` with a `total_loss` +plus per-component contributions. 
+ +This page covers: + +- the built-in leaf losses and how to call them directly; +- {py:class}`~nvalchemi.training.ComposedLossFunction` for multi-task + training and where per-loss coefficients live; +- loss-weight scheduling via the + {py:class}`~nvalchemi.training.LossWeightSchedule` protocol, applied + at the composition level; +- how to write your own loss — first a pure tensor-to-tensor loss, + then a metadata-aware one. + +```{tip} +Leaves are tensor-first: they consume plain `(pred, target)` plus +optional `**kwargs`. For how graph metadata is threaded through, see +[Passing graph metadata](passing_graph_metadata). +``` + +## Built-in losses + +The built-in losses cover standard MLIP training targets and additional +MAE/L2 norm tensor reductions. Each is a {py:class}`torch.nn.Module` with +configurable `target_key` / `prediction_key` attributes used by +composition. The MSE-style losses expose an opt-in `ignore_nonfinite` flag; +the MAE/L2 norm losses expose `ignore_nonfinite` and mask target `NaN` +and `inf` values. 
+ +| Class | Target | Key defaults | Extra knobs | +|-------|--------|--------------|-------------| +| {py:class}`~nvalchemi.training.EnergyMSELoss` | Per-graph energy `(B, 1)` | `"energy"` / `"predicted_energy"` | `per_atom` normalization, `ignore_nonfinite` | +| {py:class}`~nvalchemi.training.EnergyMAELoss` | Per-graph energy `(B, 1)` or `(B,)` | `"energy"` / `"predicted_energy"` | MAE reduction, `per_atom`, `ignore_nonfinite` | +| {py:class}`~nvalchemi.training.EnergyHuberLoss` | Per-graph energy `(B, 1)` | `"energy"` / `"predicted_energy"` | Huber residual, `per_atom`, `delta`, `ignore_nonfinite` | +| {py:class}`~nvalchemi.training.ForceMSELoss` | Per-atom forces, dense `(V, 3)` or padded `(B, V_max, 3)` | `"forces"` / `"predicted_forces"` | `normalize_by_atom_count`, `ignore_nonfinite` | +| {py:class}`~nvalchemi.training.ForceHuberLoss` | Per-atom forces, dense `(V, 3)` or padded `(B, V_max, 3)` | `"forces"` / `"predicted_forces"` | Huber residual, `normalize_by_atom_count`, `delta`, `ignore_nonfinite` | +| {py:class}`~nvalchemi.training.ForceL2NormLoss` | Per-atom forces, dense `(V, 3)` or padded `(B, V_max, 3)` | `"forces"` / `"predicted_forces"` | Vector-L2 reduction, `normalize_by_atom_count`, `ignore_nonfinite` | +| {py:class}`~nvalchemi.training.StressMSELoss` | Per-graph stress `(B, 3, 3)` | `"stress"` / `"predicted_stress"` | `ignore_nonfinite` | +| {py:class}`~nvalchemi.training.StressHuberLoss` | Per-graph stress `(B, 3, 3)` | `"stress"` / `"predicted_stress"` | Huber residual, `delta`, `ignore_nonfinite` | + +### Calling a leaf loss directly + +A leaf loss is a plain `nn.Module`. 
For losses that do not require +graph metadata — `EnergyMSELoss(per_atom=False)` (the default), dense +`ForceMSELoss(normalize_by_atom_count=False)`, +`ForceHuberLoss(normalize_by_atom_count=False)`, +`StressMSELoss`, `StressHuberLoss`, `EnergyMAELoss(per_atom=False)`, +and dense `ForceL2NormLoss(normalize_by_atom_count=False)` — call it +with `(pred, target)` and get a scalar back. Leaves carry no weight or +schedule of their own; a direct call returns the unweighted value: + +```python +import torch +from nvalchemi.training import EnergyMSELoss + +loss_fn = EnergyMSELoss() +pred = torch.randn(4, 1, requires_grad=True) +target = torch.randn(4, 1) + +loss = loss_fn(pred, target) # scalar Tensor +loss.backward() +``` + +`ForceMSELoss()` and `ForceL2NormLoss()` (default +`normalize_by_atom_count=True`), `EnergyHuberLoss()` (default +`per_atom=True`), and `EnergyMSELoss` / `EnergyMAELoss` with `per_atom=True` require +graph metadata and will raise `ValueError` on a bare `(pred, target)` +call. Either pass metadata kwargs (see +[Passing graph metadata](passing_graph_metadata)) or, for dense `(V, 3)` +forces, disable the per-graph normalization for a tensor-only call: + +```python +from nvalchemi.training import ForceL2NormLoss, ForceMSELoss + +force_fn = ForceMSELoss(normalize_by_atom_count=False) # plain MSE over (V, 3) +force_pred = torch.randn(10, 3, requires_grad=True) +force_target = torch.randn(10, 3) +loss = force_fn(force_pred, force_target) # no metadata needed + +l2_fn = ForceL2NormLoss(normalize_by_atom_count=False) +l2_loss = l2_fn(force_pred, force_target) # no metadata needed +``` + +Padded `(B, V_max, 3)` forces still require `num_nodes_per_graph` even +with `normalize_by_atom_count=False`, since padding rows must be +masked before reduction. + +#### Expected shape layouts + +Built-in leaves call `assert_same_shape(..., strict=True)`, so +prediction and target shapes must match exactly. The table below lists +the layouts these losses are designed for. 
+ +| Loss | `pred` shape | `target` shape | +|------|--------------|----------------| +| `EnergyMSELoss` | `(B, 1)` | `(B, 1)` | +| `EnergyMAELoss` | `(B, 1)` or `(B,)` | exact same shape as `pred` | +| `EnergyHuberLoss` | `(B, 1)` | `(B, 1)` | +| `ForceMSELoss` (dense) | `(V, 3)` | `(V, 3)` | +| `ForceMSELoss` (padded) | `(B, V_max, 3)` | `(B, V_max, 3)` | +| `ForceHuberLoss` (dense) | `(V, 3)` | `(V, 3)` | +| `ForceHuberLoss` (padded) | `(B, V_max, 3)` | `(B, V_max, 3)` | +| `ForceL2NormLoss` (dense) | `(V, 3)` | `(V, 3)` | +| `ForceL2NormLoss` (padded) | `(B, V_max, 3)` | `(B, V_max, 3)` | +| `StressMSELoss` | `(B, 3, 3)` | `(B, 3, 3)` | +| `StressHuberLoss` | `(B, 3, 3)` | `(B, 3, 3)` | + +```{warning} +`(B, 1)` versus `(B,)` is broadcast-compatible but rejected by the +built-ins. Keep the explicit trailing `1` on per-graph tensors unless +both prediction and target intentionally use the `(B,)` layout supported +by `EnergyMAELoss`. +``` + +Leaf losses do not receive schedule counters. `step=` and `epoch=` +belong to {py:class}`~nvalchemi.training.ComposedLossFunction`, which +uses them to resolve schedule-driven weights before calling each leaf +(see [Composition weights and schedules](composition_weights)). + +(passing_graph_metadata)= + +### Passing graph metadata + +Concrete losses may require graph metadata as keyword arguments. 
For +example, `ForceMSELoss` with the default graph-balanced normalization +needs `batch_idx` and `num_graphs` for dense `(V, 3)` forces: + +```python +from nvalchemi.training import ForceMSELoss + +force_fn = ForceMSELoss() # normalize_by_atom_count=True + +pred = torch.randn(10, 3, requires_grad=True) +target = torch.randn(10, 3) +batch_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2]) + +loss = force_fn(pred, target, batch_idx=batch_idx, num_graphs=3) +``` + +The same loss accepts a padded `(B, V_max, 3)` layout with per-graph +counts instead: + +```python +pred_padded = torch.randn(3, 4, 3, requires_grad=True) +target_padded = torch.randn(3, 4, 3) +counts = torch.tensor([3, 4, 3]) + +loss = force_fn(pred_padded, target_padded, num_nodes_per_graph=counts) +``` + +{py:class}`~nvalchemi.training.EnergyMSELoss`, +{py:class}`~nvalchemi.training.EnergyMAELoss`, +{py:class}`~nvalchemi.training.EnergyHuberLoss`, +{py:class}`~nvalchemi.training.ForceMSELoss`, +{py:class}`~nvalchemi.training.ForceHuberLoss`, and +{py:class}`~nvalchemi.training.ForceL2NormLoss` accept an optional +`batch=` keyword argument as a convenience source for metadata when the +selected reduction needs it. When `batch=` is provided, the loss pulls +`batch_idx`, `num_graphs`, and `num_nodes_per_graph` directly from it: + +```python +# Batch-derived metadata — shorter callsite +loss = force_fn(pred, target, batch=batch) + +# Equivalent explicit call — fine-grained control +loss = force_fn( + pred, target, + batch_idx=batch.batch_idx, + num_graphs=batch.num_graphs, +) +``` + +Explicit kwargs always win when both are provided — useful if you want +to override `num_graphs` for a sub-batch without rebuilding a `Batch`. +A duck-typed `batch` that's missing a required attribute still falls +through to the descriptive `ValueError` raised by the metadata +resolver, so you don't have to pre-validate it. 
+ +### Ignoring missing labels + +`EnergyMSELoss`, `ForceMSELoss`, and `StressMSELoss` accept an `ignore_nonfinite` +flag (default `False`). When `True`, `NaN` target entries contribute zero to both +the loss value and the gradient — a "nanmean"-style reduction +implemented with branch-free tensor ops so it stays `torch.compile`-safe: + +```python +energy_loss = EnergyMSELoss(ignore_nonfinite=True) + +target = torch.tensor([[1.0], [float("nan")], [3.0]]) +pred = torch.zeros_like(target, requires_grad=True) + +loss = energy_loss(pred, target) +loss.backward() + +assert torch.isfinite(loss) +assert pred.grad[1].item() == 0.0 # masked row has zero gradient +``` + +`NaN` targets contribute zero loss and zero gradient; a graph whose +target is entirely `NaN` contributes exactly `0.0` because the numerator +and denominator both go to zero and the denominator is clamped to a minimum of +`1`. The default (`ignore_nonfinite=False`) lets `NaN` propagate, which is +usually what you want during development when a label *shouldn't* be +missing. + +```{warning} +For these MSE-style losses, only target `NaN`s are treated as missing +labels. Prediction `NaN`s still propagate whenever the corresponding +target is finite; if the target is `NaN`, that position contributes zero +loss and zero gradient. Do not rely on `ignore_nonfinite` to hide model +explosions. +``` + +### MAE and force-L2 reductions + +`EnergyMAELoss` and `ForceL2NormLoss` implement tensor reductions only. +They do not apply dataset normalization, target transforms, +element-reference corrections, or any other preprocessing; apply those +outside the loss before passing tensors in. + +`EnergyMAELoss` computes absolute energy residuals and defaults to +`per_atom=True`: prediction and target are divided by +`num_nodes_per_graph`, then reduced with atom-count weights so that +larger graphs contribute in proportion to their size — matching the +reduction semantics of `EnergyMSELoss(per_atom=True)`. 
+ +`ForceL2NormLoss` computes a per-atom vector norm before reduction: + +```python +per_atom = torch.linalg.vector_norm(predicted_forces - forces, ord=2, dim=-1) +``` + +With `normalize_by_atom_count=True`, dense forces use `batch_idx` and +`num_graphs` to compute a valid-atom mean per graph, then mean over +graphs; padded forces use `num_nodes_per_graph` counts or a node mask to +exclude padding before the same per-graph reduction. With +`normalize_by_atom_count=False`, the scalar is a global mean over valid +atom L2 norms. + +Both MAE/L2 norm losses have `ignore_nonfinite=True` by default and use +`torch.isfinite(target)` (`.all(dim=-1)` for force vectors), excluding +target `NaN` and `inf` labels while preserving gradients through valid +prediction entries. + +(shape_validation)= + +### Shape and dtype validation + +Built-in leaves opt in to shape and dtype validation via the +{py:meth}`~nvalchemi.training.BaseLossFunction.validate` hook, which +calls {py:func}`nvalchemi.training.losses.assert_same_shape`: + +```python +from nvalchemi.training.losses import assert_same_shape + +assert_same_shape( + pred, target, + name="MyLoss", + prediction_key="predicted_energy", + target_key="energy", +) +``` + +`assert_same_shape` checks strict `dtype` equality first. With its +default `strict=False`, it then uses `torch.broadcast_shapes` to verify +shape compatibility — so `(B, 1)` vs. `(B,)` passes (broadcastable) but +mismatched dtypes do not. With `strict=True`, it requires exact shape +equality. The helper raises `ValueError` with the component `name` and +the prediction/target keys embedded in the message. + +Validation is opt-in because some legitimate losses (e.g. dipole +derived from per-atom charges) have `pred.shape != target.shape` by +design. 
When writing a custom loss, call `assert_same_shape` at the +top of your `forward` with `strict=True` if pred and target are supposed +to match exactly; use the default broadcast-compatible policy only when +that is intentional. Skip the call when they don't. Note that +`assert_same_shape` is exported from `nvalchemi.training.losses` only — +it is not re-exported from the top-level `nvalchemi.training`. + +## Composition + +Real training objectives typically combine several targets. The idiomatic way is +to add leaves together and use the resulting +{py:class}`~nvalchemi.training.ComposedLossFunction`: + +```python +from nvalchemi.training import EnergyMSELoss, ForceMSELoss, StressMSELoss + +loss_fn = EnergyMSELoss() + ForceMSELoss() + StressMSELoss() +``` + +`loss_fn` is an `nn.Module` whose components sit in an +`nn.ModuleList`, so `.to(device)`, `.state_dict()`, `.modules()`, and +the nested `__repr__` work the way you'd expect. Adding a +`ComposedLossFunction` to another loss flattens transparently: + +```python +loss_fn_a = EnergyMSELoss() + ForceMSELoss() +loss_fn_b = loss_fn_a + StressMSELoss() # still 3 flat components +``` + +### The call signature + +A composed loss takes **keyed mappings**, not tensors: + +```python +def loss_fn( + predictions: Mapping[str, torch.Tensor], + targets: Mapping[str, torch.Tensor], + *, + step: int = 0, + epoch: int | None = None, + **kwargs, +) -> ComposedLossOutput: ... +``` + +Each component reads its own `prediction_key` and `target_key` +attributes to pull tensors out of the two mappings. Any extra `**kwargs` +(graph metadata, for example) are forwarded unchanged to every leaf; +each leaf consumes the kwargs it needs and ignores the rest. 
+ +```python +predictions = { + "predicted_energy": model_outputs["energy"], + "predicted_forces": model_outputs["forces"], + "predicted_stress": model_outputs["stress"], +} +targets = { + "energy": batch.energy, + "forces": batch.forces, + "stress": batch.stress, +} + +out = loss_fn( + predictions, targets, + step=global_step, epoch=epoch, + batch_idx=batch.batch_idx, + num_graphs=batch.num_graphs, + num_nodes_per_graph=batch.num_nodes_per_graph, +) + +out["total_loss"].backward() +``` + +Or equivalently `loss_fn(predictions, targets, step=..., epoch=..., +batch=batch)`; see [Passing graph metadata](passing_graph_metadata). + +### The return type + +`ComposedLossFunction.forward` returns a +{py:class}`~nvalchemi.training.ComposedLossOutput` — a +{py:class}`typing.TypedDict` with five fields: + +| Field | Type | Meaning | +|-------|------|---------| +| `total_loss` | `torch.Tensor` | Scalar sum of `effective_weight * component_loss` across components. `.backward()` on this. | +| `per_component_unweighted` | `dict[str, torch.Tensor]` | Raw per-component loss before applying the effective weight. Keyed by component class name with suffixes on duplicates. | +| `per_component_weight` | `dict[str, float]` | Effective (post-normalization) weights actually applied at this call. | +| `per_component_raw_weight` | `dict[str, float]` | Raw (pre-normalization) weights, equal to `per_component_weight` when `normalize_weights=False`. | +| `per_component_sample` | `dict[str, torch.Tensor]` | Weighted, detached `(B,)` tensors for components that populate `per_sample_loss`. Absent when the leaf stores `None`. See [Per-sample loss diagnostics](#per-sample-loss-diagnostics) below for details (including aggregation caveats). 
| + +```python +out = loss_fn(predictions, targets) +out["total_loss"].backward() + +for name, value in out["per_component_unweighted"].items(): + logger.log_scalar(f"loss/{name}", value.detach(), step=global_step) +for name, w in out["per_component_weight"].items(): + logger.log_scalar(f"loss_weight/{name}", w, step=global_step) +``` + +Duplicate class names get numeric suffixes (`StressMSELoss_0`, +`StressMSELoss_1`, …) so keys remain unique. + +### Per-sample loss diagnostics + +Every leaf carries an optional `per_sample_loss: torch.Tensor | None` attribute. +Concrete losses populate it as a side effect of `forward` with a detached +per-graph tensor of shape `(B,)`, cleared to `None` at the top of every call. +The scalar return still carries gradients — this attribute is for logging and +diagnostics only. + +| Loss | When populated | Aggregation caveat | +|------|----------------|--------------------| +| `EnergyMSELoss` | Recognizable `(B,)` or `(B, 1)` residuals | `per_atom=True` stores per-graph squared per-atom residuals; scalar applies atom-count weights. `ignore_nonfinite=True` uses a global valid-entry divisor. | +| `EnergyMAELoss` | Supported `(B,)` or `(B, 1)` layouts | `per_atom=True` stores per-graph absolute per-atom residuals; scalar applies atom-count weights. `ignore_nonfinite=True` stores masked entries as zero; scalar divides by valid atom-count-weighted sum. | +| `EnergyHuberLoss` | Recognizable `(B,)` or `(B, 1)` residuals | Same layout caveats as `EnergyMSELoss`; scalar is a graph-balanced mean over labeled structures when `per_atom=True`. | +| `StressMSELoss` | Always | None; per-graph Frobenius MSE is already the scalar mean input. | +| `StressHuberLoss` | Always | Same as `StressMSELoss`; per-graph component Huber mean, then mean over graphs. | +| `ForceMSELoss` | Graph-balanced paths and padded global path | Dense `normalize_by_atom_count=False` leaves it absent. Padded global path divides by total valid components. 
| +| `ForceHuberLoss` | Same paths as `ForceMSELoss` | Inherits `ForceMSELoss` reduction; default global component mean leaves `per_sample_loss` absent for dense inputs. | +| `ForceL2NormLoss` | Graph-balanced paths and padded global path | Dense `normalize_by_atom_count=False` leaves it absent. Padded global path divides by total valid atoms. | + +`ComposedLossOutput["per_component_sample"]` carries +`effective_weight * component.per_sample_loss` (detached) for each component +that populated the attribute. Components whose `per_sample_loss` was `None` +are **absent** from the dict: + +```python +out = loss_fn(predictions, targets) +if "EnergyMSELoss" in out["per_component_sample"]: + per_graph_energy_loss = out["per_component_sample"]["EnergyMSELoss"] + # shape (B,), detached, weighted by the effective energy weight at this step +``` + +```{note} +For paths with an aggregation caveat, inspect individual components rather than +assuming `per_sample_loss.mean()` equals the scalar return. +``` + +### Routing errors + +`ComposedLossFunction` validates its inputs eagerly and fails with a +focused error when a contract is broken: + +- A missing `prediction_key` or `target_key` in the input mappings + raises `KeyError`. +- A mapping entry that is not a `torch.Tensor` raises `TypeError`. +- A component class without `prediction_key` / `target_key` + attributes (e.g. a bespoke loss you forgot to configure) raises + `AttributeError`. +- A non-finite or non-positive **sum** of resolved weights + (when `normalize_weights=True`) raises `ValueError` — see + [Weight normalization](weight_normalization) for details. + +(composition_weights)= + +## Composition weights and schedules + +Per-loss coefficients live on +{py:class}`~nvalchemi.training.ComposedLossFunction`, not on leaves. +Leaves have no `weight` argument. A composition stores a parallel +`weights` list — one entry per top-level component — of +`float | LossWeightSchedule | None`. `None` defaults to `1.0`. 
+ +The idiomatic way to assemble a weighted composition is with operator +sugar: + +```python +from nvalchemi.training import EnergyMSELoss, ForceMSELoss, StressMSELoss + +loss_fn = 1.0 * EnergyMSELoss() + 10.0 * ForceMSELoss() + 0.1 * StressMSELoss() +``` + +`3.0 * EnergyMSELoss()` returns a one-component +`ComposedLossFunction([EnergyMSELoss()], weights=[3.0])`. Multiplying a +leaf attaches a weight; subsequent additions combine weights into a +single flat composition. + +For a direct construction with named arguments: + +```python +from nvalchemi.training import ComposedLossFunction, LinearWeight + +loss_fn = ComposedLossFunction( + [EnergyMSELoss(), ForceMSELoss(), StressMSELoss()], + weights=[1.0, LinearWeight(start=0.0, end=10.0, num_steps=1000), 0.1], + normalize_weights=True, +) +``` + +(weight_normalization)= + +### Weight normalization + +`ComposedLossFunction` normalizes its resolved weights to sum to `1.0` +at every call by default (`normalize_weights=True`). That keeps the +loss magnitude independent of how many terms you add and puts +scheduling in control of relative weighting rather than absolute +magnitude. + +Opt out when you want raw arithmetic sums (e.g. 
if you're reproducing +results from a paper that hard-codes coefficients): + +```python +loss_fn = ComposedLossFunction( + [EnergyMSELoss(), ForceMSELoss()], + weights=[1.0, 10.0], + normalize_weights=False, +) +``` + +For direct summed task losses, construct the composition +explicitly and set `normalize_weights=False` so coefficients are applied +as raw multipliers rather than renormalized relative weights: + +```python +from nvalchemi.training import ComposedLossFunction, EnergyMAELoss, ForceL2NormLoss + +loss_fn = ComposedLossFunction( + [EnergyMAELoss(), ForceL2NormLoss()], + weights=[1.0, 10.0], + normalize_weights=False, +) +``` + +When `normalize_weights=True`, the raw-weight sum must be finite and +strictly positive at every call; otherwise a `ValueError` fires before +any gradient can be computed. + +### Operator sugar and its constraints + +Common forms: `3.0 * EnergyMSELoss()` to attach a weight, +`schedule * EnergyMSELoss()` to attach a schedule, `a + b + c` and +`sum([a, b, c])` to compose. A handful of non-obvious constraints: + +- **`composition + composition`** requires both sides to share the + same `normalize_weights` flag. Mismatch raises `ValueError`; + construct the combined composition explicitly with + `ComposedLossFunction(..., normalize_weights=...)` to choose. +- **`schedule * composition`** is **rejected** with `TypeError`. + Scale each component individually (`schedule * EnergyMSELoss()` and + compose the results) or multiply the composition by a plain float. +- **`bool * loss`** is **rejected** to avoid `True` silently + coercing to `1.0`. Pass `1.0` explicitly. + +### Weight schedules + +Any entry in `weights` may be a +{py:class}`~nvalchemi.training.LossWeightSchedule` instead of a +float. 
The composition evaluates it at every call with the `(step, +epoch)` you pass to `forward`: + +```python +from nvalchemi.training import ( + ConstantWeight, + CosineWeight, + EnergyMSELoss, + ForceMSELoss, + LinearWeight, + PiecewiseWeight, + StressMSELoss, +) + +energy_sched = ConstantWeight(value=1.0) +force_sched = LinearWeight(start=0.0, end=1.0, num_steps=1000) +stress_sched = PiecewiseWeight( + boundaries=(0, 10, 20), + values=(0.0, 0.5, 1.0, 1.0), + per_epoch=True, +) + +loss_fn = ( + energy_sched * EnergyMSELoss() + + force_sched * ForceMSELoss() + + stress_sched * StressMSELoss() +) + +out = loss_fn(predictions, targets, step=500, epoch=7, batch=batch) +``` + +| Schedule | Shape | Typical use | +|----------|-------|-------------| +| {py:class}`~nvalchemi.training.ConstantWeight` | Flat | Static task weight | +| {py:class}`~nvalchemi.training.LinearWeight` | `start` → `end` over `num_steps`, clamped | Curriculum warm-up | +| {py:class}`~nvalchemi.training.CosineWeight` | Half-cosine `start` → `end`, clamped | Smooth curriculum | +| {py:class}`~nvalchemi.training.PiecewiseWeight` | Step function over boundaries | Phase changes | + +### Step vs. epoch + +Every schedule has a `per_epoch: bool` field. When `False` (the default) +the schedule advances by the `step` argument passed to the loss; when +`True`, it advances by `epoch`. Mixing the two lets most schedules +advance per batch while keeping others, such as a stress-weight +curriculum, aligned with learning-rate epochs. + +A `per_epoch=True` schedule called with `epoch=None` raises +`ValueError` — passing `epoch` is required whenever any attached +schedule opts in. + +### Bring your own schedule + +{py:class}`~nvalchemi.training.LossWeightSchedule` is a +`runtime_checkable` {py:class}`typing.Protocol`: any object with a +`per_epoch` attribute and a `__call__(step: int, epoch: int) -> float` +method qualifies. 
You don't need to subclass anything to use a custom +schedule in a composition; it just has to quack like one. + +```python +class CappedInverse: + """Return min(1.0, 1.0 / max(step, 1)) — reciprocal step decay.""" + + per_epoch = False + + def __call__(self, step: int, epoch: int) -> float: + return min(1.0, 1.0 / max(step, 1)) + +loss_fn = CappedInverse() * ForceMSELoss() + EnergyMSELoss() +``` + +Subclass the internal `_BaseWeightSchedule` (from +`nvalchemi.training.losses.base`) instead when you want Pydantic +validation and `create_model_spec` round-tripping for checkpoints. + +## Writing your own loss + +{py:class}`~nvalchemi.training.BaseLossFunction` uses a **template-method** +`forward` that orchestrates five hooks: + +1. {py:meth}`~nvalchemi.training.BaseLossFunction.validate` — shape/dtype + checks (default calls `assert_same_shape`). +2. {py:meth}`~nvalchemi.training.BaseLossFunction.normalize` — pre-process + `pred` and `target` (e.g. per-atom energy division) and return a + {py:class}`~nvalchemi.training.ReductionContext` for downstream hooks. +3. {py:meth}`~nvalchemi.training.BaseLossFunction.mask` — produce a boolean + validity tensor (e.g. `torch.isfinite`, padding masks). +4. {py:meth}`~nvalchemi.training.BaseLossFunction.compute_residual` — + **abstract**, the only method every leaf must implement. +5. {py:meth}`~nvalchemi.training.BaseLossFunction.reduce` — collapse the + residual + validity mask to a scalar (default: validity-weighted mean, + incorporating optional `ctx["weights"]`). + +Subclass `BaseLossFunction` and override `compute_residual` at a +minimum. The default hooks handle shape validation, all-valid masking, +and weighted-mean reduction out of the box. Override individual hooks +when you need domain-specific behaviour (per-atom normalization in +`normalize`, padding-aware masking in `mask`, graph-balanced reduction +in `reduce`). Weight scheduling lives on `ComposedLossFunction`, so +your hooks return unweighted values only. 
+ +You may also override `forward` directly to bypass the template — useful +for losses with non-standard signatures — but you lose the composable +hook structure. + +Four conventions worth knowing: + +1. **Define `target_key` and `prediction_key`.** These attributes tell + `ComposedLossFunction` which slots in the prediction/target mappings + to wire into your loss. Without them, your loss works standalone but + cannot participate in a composition. +2. **Accept `**kwargs` in hooks that receive them.** `ComposedLossFunction` + forwards extra metadata kwargs to every component. Swallowing the ones + you don't use keeps your loss composable with any other loss in the mix. +3. **Keep hooks tensor-first.** See + [Passing graph metadata](passing_graph_metadata) for the kwarg + contract. +4. **Override `validate` for non-standard shapes** (skip or customize it + when `pred.shape != target.shape` by design). + +### Example 1: a metadata-aware per-atom energy loss (normalize + compute_residual) + +When your loss depends on graph structure, override `normalize` to +inject per-atom division and return atom-count weights via +{py:class}`~nvalchemi.training.ReductionContext`. The base `reduce` +picks up `ctx["weights"]` automatically. + +```python +from typing import Any + +import torch + +from nvalchemi.training import BaseLossFunction, ReductionContext + + +class PerAtomEnergyMSELoss(BaseLossFunction): + """Energy MSE normalized by atom count, with atom-count-weighted reduction.""" + + target_key = "energy" + prediction_key = "predicted_energy" + + def normalize( + self, + pred: torch.Tensor, + target: torch.Tensor, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, ReductionContext]: + ctx = ReductionContext() + counts = kwargs.get("num_nodes_per_graph") + if counts is None: + raise ValueError( + "PerAtomEnergyMSELoss requires num_nodes_per_graph=... metadata." 
+ ) + counts = counts.to(dtype=pred.dtype).unsqueeze(-1).clamp_min(1.0) + ctx["weights"] = counts + return pred / counts, target / counts, ctx + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + residual = torch.where(valid, pred - target, torch.zeros_like(pred)) + return residual.pow(2) +``` + +`target_key` and `prediction_key` are resolved by composition via +`getattr`, so class-level defaults are enough when a loss has no other +constructor state. If you want callers to override routing keys or +configure additional fields, expose those via `__init__` (for example +`delta` on {py:class}`~nvalchemi.training.EnergyHuberLoss`). + +### Example 2: custom masking (mask override) + +Override `mask` when your loss needs validity logic beyond the base +default (all-True). The mask is a boolean tensor broadcast-compatible +with `pred`/`target`; entries where `mask` is `False` are zeroed in +`compute_residual` and excluded from the reduction denominator. + +A common pattern is excluding non-finite targets so that missing labels +contribute zero loss and zero gradient. The built-in +`EnergyMSELoss.mask` is a one-liner: + +```python +def mask( + self, + pred: torch.Tensor, + target: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, +) -> torch.Tensor: + if self.ignore_nonfinite: + return torch.isfinite(target) + return torch.ones_like(target, dtype=torch.bool) +``` + +For padded tensor layouts, the mask must also exclude padding rows. 
The +built-in force losses combine a node-validity mask (derived from +`num_nodes_per_graph`) with an optional `isfinite` check: + +```python +def mask(self, pred, target, ctx, **kwargs): + num_nodes_per_graph = kwargs.get("num_nodes_per_graph") + # Build a (B, V_max) node mask from counts, expand to (B, V_max, 3) + node_mask = _padded_node_mask(num_nodes_per_graph, pred, pred.shape[1]) + valid = node_mask.unsqueeze(-1).expand_as(pred) + if self.ignore_nonfinite: + valid = valid & torch.isfinite(target) + return valid +``` + +The key contract: `mask` returns a boolean tensor, and `compute_residual` +receives it as the `valid` argument. Your `compute_residual` should use +`torch.where(valid, ..., torch.zeros_like(...))` to zero invalid +entries, and the base `reduce` weights the denominator by +`valid.to(dtype=residual.dtype)`. + +### Example 3: custom reduction (reduce override) + +Override `reduce` when the base validity-weighted mean is not the +reduction you need — for example, a graph-balanced reduction that +computes a per-graph mean first, then averages over graphs: + +```python +import torch + +from nvalchemi.training import BaseLossFunction, ReductionContext +from nvalchemi.training.losses.reductions import per_graph_mean, per_graph_sum + + +class GraphBalancedForceMSE(BaseLossFunction): + """Force MSE with graph-balanced reduction for dense (V, 3) forces.""" + + target_key = "forces" + prediction_key = "predicted_forces" + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + residual = torch.where(valid, pred - target, torch.zeros_like(pred)) + return residual.pow(2) + + def reduce( + self, + residual: torch.Tensor, + valid: torch.Tensor, + ctx: ReductionContext, + **kwargs, + ) -> torch.Tensor: + batch_idx = kwargs["batch_idx"] + num_graphs = kwargs["num_graphs"] + valid_f = valid.to(dtype=residual.dtype) + # Per-atom squared error summed over xyz, then per-graph mean + per_atom_se = 
residual.sum(dim=-1) + per_atom_valid = valid_f.sum(dim=-1) + per_graph_num = per_graph_sum(per_atom_se, batch_idx, num_graphs) + per_graph_den = per_graph_sum(per_atom_valid, batch_idx, num_graphs) + per_sample = per_graph_num / per_graph_den.clamp_min(1.0) + self.per_sample_loss = per_sample.detach() + return per_sample.mean() +``` + +When overriding `reduce`, populate `self.per_sample_loss` with a +detached `(B,)` tensor for diagnostics, or leave it `None` when a +per-graph decomposition is not meaningful. + +### Layout dispatch with plum (advanced) + +The built-in force losses (`ForceMSELoss`, `ForceHuberLoss`, `ForceL2NormLoss`) +accept both dense `(V, 3)` and padded `(B, V_max, 3)` inputs. Rather than +branching on `pred.ndim` inside each hook, they use +[plum-dispatch](https://github.com/beartype/plum) to route to +type-annotated overloads. For example, `ForceMSELoss._valid_force_components` +has two `@overload` implementations — one for `Forces` (dense, 2-D) and +one for `_PaddedForces` (padded, 3-D) — plus a `@dispatch` fallback: + +```python +from plum import dispatch, overload + +class ForceMSELoss(BaseLossFunction): + # ... + + @overload + def _valid_force_components(self, pred: Forces, target: Forces, ...): + """Dense (V, 3) path — no padding mask needed.""" + ... + + @overload + def _valid_force_components(self, pred: _PaddedForces, target: _PaddedForces, ...): + """Padded (B, V_max, 3) path — build node mask from counts.""" + ... + + @dispatch + def _valid_force_components(self, pred, target, num_nodes_per_graph): + pass # plum routes to the matching overload at runtime +``` + +The `mask` and `reduce` hooks delegate to these dispatched helpers, +keeping each layout's logic in a focused, testable overload. If you are +writing a loss that handles multiple tensor layouts, the `ForceMSELoss` +and `ForceL2NormLoss` implementations in +`nvalchemi/training/losses/terms.py` are the reference patterns to +follow. 
+ +### Populating `per_sample_loss` (optional) + +The base `reduce` populates `self.per_sample_loss` automatically for +residuals with a recognizable `(B,)` or `(B, 1)` shape. For custom +`reduce` overrides, set `self.per_sample_loss` to a detached `(B,)` tensor +to expose per-graph diagnostics through +`ComposedLossOutput["per_component_sample"]`. See +[Per-sample loss diagnostics](#per-sample-loss-diagnostics) for the full +contract; leave it `None` when a per-graph decomposition is unavailable. + +### Testing a custom loss + +Two checks usually suffice: + +1. A direct call returns a scalar of the expected dtype and gradient + flows back to `pred`. +2. If `ignore_nonfinite` semantics matter for your loss, assert that a + `NaN`-filled target row contributes zero to `pred.grad`. + +```python +import torch + +from nvalchemi.training import EnergyMSELoss + +loss_fn = EnergyMSELoss() +pred = torch.randn(4, 1, requires_grad=True) +target = torch.randn(4, 1) + +value = loss_fn(pred, target) +assert value.ndim == 0 +value.backward() +assert pred.grad is not None +``` + +For composed losses, assert `total_loss` equals the expected weighted +sum of per-component values on a tiny batch — inspect +`out["per_component_unweighted"]` and `out["per_component_weight"]` to see +exactly what the composition applied. + +## See also + +- **API**: {ref}`losses-api` for the full class and schedule reference. +- **Reductions**: the `nvalchemi.training.losses.reductions` module for + scatter-based per-graph helpers usable in custom losses. +- **Models**: the [models guide](models) covers the model-side of the + contract (how `predictions` mappings are produced). +- **Hooks**: the [hooks guide](hooks_guide) covers the + {py:class}`~nvalchemi.hooks.HookContext` fields a training loop + makes available, including `ctx.loss`. 
diff --git a/docs/userguide/models.md b/docs/userguide/models.md index 775399df..7cb76483 100644 --- a/docs/userguide/models.md +++ b/docs/userguide/models.md @@ -41,6 +41,49 @@ potentials: are lazily imported --- they only load when accessed, so missing dependencies will not break other imports. +### MACE checkpoints in training + +When starting from an existing MACE checkpoint, prefer +{py:meth}`~nvalchemi.models.mace.MACEWrapper.from_checkpoint` over manually +loading the underlying MACE module and wrapping it. The wrapper records a +factory-based reconstruction spec that strategy checkpoints can use later. +This matters for optimized variants such as cuEquivariance, where the live +transformed module is not reliably reconstructible from its Python constructor. + +```python +import torch + +from nvalchemi.models.mace import MACEWrapper +from nvalchemi.training import EMAHook, TrainingStrategy + +model = MACEWrapper.from_checkpoint( + "small-0b", + device=torch.device("cuda"), + dtype=torch.float32, + enable_cueq=True, +) + +ema = EMAHook(model_key="main", decay=0.999) +strategy = TrainingStrategy( + models=model, + ..., + hooks=[ema], +) +strategy.save_checkpoint(checkpoint_dir) + +restored_ema = EMAHook(model_key="main", decay=0.999) +restored = TrainingStrategy.load_checkpoint( + checkpoint_dir, + map_location=torch.device("cuda"), + hooks=[restored_ema], + training_fn=training_fn, +) +``` + +Avoid saving only `ema.state_dict()` for MACE training restarts. Strategy +checkpoints preserve the model reconstruction recipe, model weights, optimizer +state, runtime counters, and checkpointable hook state together. 
+ ## Architecture overview A wrapped model uses **multiple inheritance**: your existing {py:class}`~torch.nn.Module` diff --git a/docs/userguide/reporting.md b/docs/userguide/reporting.md new file mode 100644 index 00000000..dd153ca9 --- /dev/null +++ b/docs/userguide/reporting.md @@ -0,0 +1,364 @@ + + +(reporting_guide)= + +# Reporting + +Reporting is the higher-level observability layer for hook-enabled workflows. +It collects scalar summaries from hook contexts, tracks reporting metadata, +optionally reduces values across ranks, and sends the resulting snapshots to +reporting sinks such as TensorBoard or live Rich dashboards. + +## Reporting vs. logging + +Logging and reporting have different intent: + +| Use this | When you want | +|----------|----------------| +| Logging | Workflow event records: rows, files, or backend writes that preserve a direct stream of events. | +| Reporting | Curated workflow summaries: scalar snapshots, rank-safe reductions, previews, dashboards, and analysis-facing output. | + +Logging is not inherently dynamics-specific. A training workflow can also have a +logger when it needs a direct record of training events, optimizer steps, +gradient statistics, or validation passes. The current built-in +{py:class}`~nvalchemi.dynamics.hooks.LoggingHook` is dynamics-focused because it +computes per-graph observables such as energy, `fmax`, and temperature and writes +one row per system. A future training logger should be a separate +training-specific implementation rather than overloading the dynamics hook with a +different event model. + +Reporters sit one level up. They receive the current hook context and shared +reporting state, collect scalar metrics, decide whether to reduce across ranks, +and then render or serialize a summary. A reporter may intentionally discard +low-level detail if the output is meant to be a compact dashboard or analysis +record. + +Backends do not define the layer. 
CSV, TensorBoard, W&B, and MLflow can be +used for logging or reporting depending on what is being written. In this +package, {py:class}`~nvalchemi.hooks.TensorBoardReporter` is a reporter because +it writes {py:class}`~nvalchemi.hooks.ScalarSnapshot` payloads collected by +{py:class}`~nvalchemi.hooks.ReportingOrchestrator`. By contrast, the dynamics +`LoggingHook` TensorBoard backend is logging because it writes the hook's raw +per-graph dynamics rows directly. + +## Basic usage + +{py:class}`~nvalchemi.hooks.ReportingOrchestrator` is the hook that fans events +out to reporters: + +```python +from nvalchemi.hooks import ReportingOrchestrator, RichReporter, TensorBoardReporter + +reporting = ReportingOrchestrator( + [ + TensorBoardReporter("runs/example"), + RichReporter(), + ], + stages={"AFTER_OPTIMIZER_STEP"}, + frequency=10, +) +``` + +`RichReporter()` defaults to automatic layout selection. It chooses the first +built-in layout that matches the first reported context and keeps that choice +for the workflow run. Pin a layout when you want a specific dashboard surface: + +```python +from nvalchemi.hooks import ReportingOrchestrator, RichReporter + +reporting = ReportingOrchestrator( + [RichReporter(layout="dynamics", refresh_per_second=2.0)], + stages={"AFTER_STEP"}, +) +``` + +You can preview a Rich layout without running a workflow: + +```python +from nvalchemi.hooks import RichReporter + +RichReporter.preview(layout="dynamics", title="dynamics preview") +``` + +For a live training dashboard demo without real training logic, run the +synthetic example: + +```bash +uv run python examples/intermediate/07_rich_training_reporting.py --steps 80 --delay 0.05 +``` + +## What happens under the hood + +The reporting path has two boundaries: workflow engines emit hook contexts, and +reporters decide how to turn those contexts into an output artifact. 
+ +```{graphviz} +digraph reporting_orchestrator { + graph [rankdir=LR, bgcolor="transparent"]; + node [ + shape=box, + style="rounded,filled", + fillcolor="#F8F9FA", + color="#5C677D", + fontname="Helvetica" + ]; + edge [color="#5C677D", fontname="Helvetica"]; + + workflow [label="Training, dynamics,\nor custom workflow"]; + context [label="HookContext\n+ stage enum"]; + orchestrator [label="ReportingOrchestrator"]; + state [label="ReportingState\n event metadata"]; + reporter [label="Reporter\n(TensorBoard, Rich, ...)"]; + output [label="Output\nfile, run log, dashboard"]; + + workflow -> context [label="engine hook call"]; + context -> orchestrator [label="stage and frequency match"]; + orchestrator -> state [label="mark_event"]; + orchestrator -> reporter [label="report(ctx, stage, state)"]; + reporter -> output [label="write or render"]; +} +``` + +At each matching hook event, `ReportingOrchestrator`: + +1. Updates a shared {py:class}`~nvalchemi.hooks.ReportingState`. +2. Skips rank-zero-only reporters on nonzero ranks. +3. Calls each reporter with `(ctx, stage, state)`. +4. Applies the configured error policy if a reporter raises. + +Scalar reporters then call {py:func}`~nvalchemi.hooks.collect_scalars`. The +collector builds a {py:class}`~nvalchemi.hooks.ScalarSnapshot` containing: + +- `stage`, timestamp, elapsed time, event count, step count, rank, optional + training metadata, and recent reporter messages. +- A flat dictionary of scalar values, using slash-separated keys such as + `loss/total`, `optimizer/lr`, `scheduler/lr`, `converged_fraction`, or + `dynamics/graduated_count`. + +Reporters can also request rank reductions. When enabled, every rank must call +the reporter with the same scalar keys, and only rank zero writes or renders the +reduced result. 
+ +```{graphviz} +digraph reporting_reduction { + graph [rankdir=LR, bgcolor="transparent"]; + node [ + shape=box, + style="rounded,filled", + fillcolor="#F8F9FA", + color="#5C677D", + fontname="Helvetica" + ]; + edge [color="#5C677D", fontname="Helvetica"]; + + rank0 [label="rank 0\ncollect_scalars"]; + rank1 [label="rank 1\ncollect_scalars"]; + rankn [label="rank n\ncollect_scalars"]; + reduce [label="reduce_scalar_snapshot\nmean, sum, min, or max"]; + write [label="rank 0\nwrites or renders"]; + skip [label="nonzero ranks\nreturn after reduction"]; + + rank0 -> reduce; + rank1 -> reduce; + rankn -> reduce; + reduce -> write; + reduce -> skip; +} +``` + +## Rich dashboards + +{py:class}`~nvalchemi.hooks.RichReporter` owns the terminal dashboard mechanics: + +- scalar collection and optional rank reduction, +- retained per-metric history, +- Rich `Live` lifecycle, +- automatic layout selection, +- static preview seeding, +- rank-zero-only rendering. + +The selected layout owns the visual policy. Built-in layouts live under +`nvalchemi.hooks.reporting.layouts`: + +```python +from nvalchemi.hooks.reporting.layouts.train import TrainingRichLayout +from nvalchemi.hooks.reporting.layouts.dynamics import DynamicsRichLayout +``` + +`layout="auto"` and `layout=None` defer layout selection until the first report. +`layout="training"` prioritizes loss curves, optimizer and scheduler learning +rates, step progress, throughput, ETA, and recent reporter messages. +`layout="dynamics"` prioritizes energy, `fmax`, temperature, convergence, +active/graduated counts, status counts, dynamics progress, throughput, ETA, and +recent reporter messages. The dynamics layout also requests default dynamics +scalar collection when it is selected. + +Progress and ETA scalars are collected for Rich dashboards only. Durable +reporters keep their scalar snapshots stable unless you add the same values with +custom scalar callbacks. 
+ +## Custom Rich layouts + +Rich layouts are plain Python objects. `RichReporter` passes the layout: + +- the latest `ScalarSnapshot`, or `None` before the first report, +- retained scalar history as `dict[str, Sequence[tuple[int, float]]]`, +- display options such as title, precision, max rows, plot keys, and plot size. + +The layout returns a Rich renderable, usually a {py:class}`rich.layout.Layout`. +It does not collect scalars, perform rank reduction, or manage `Live`. +Use `snapshot.scalars` for current values, `history` for curves, and +`snapshot.messages` for recent reporter messages or warnings. RichReporter also +adds workflow progress scalars when the context exposes enough metadata, such as +`training/progress_fraction`, `training/eta_s`, `dynamics/progress_fraction`, +and `dynamics/eta_s`. + +### Subclass BaseRichLayout + +For most dashboards, subclass {py:class}`~nvalchemi.hooks.BaseRichLayout`. This +keeps the standard header, latest-metric table, and plot panel. You only choose +metric priority, panel titles, and preview curves: + +```python +from collections.abc import Mapping, Sequence + +from nvalchemi.hooks import BaseRichLayout, RichReporter + + +class ValidationRichLayout(BaseRichLayout): + def __init__(self) -> None: + super().__init__( + name="validation", + preferred_plot_keys=("validation/loss", "validation/mae"), + latest_title="Validation", + history_title="Curves", + ) + + def default_preview_history(self) -> Mapping[str, Sequence[float]]: + return { + "validation/loss": (0.8, 0.62, 0.51, 0.44), + "validation/mae": (0.31, 0.24, 0.19, 0.16), + } + + +reporter = RichReporter(layout=ValidationRichLayout()) +``` + +`BaseRichLayout` also provides preview metadata hooks. Override them when the +default training metadata is wrong for your workflow: + +```python +class ValidationRichLayout(BaseRichLayout): + ... 
+ + def default_preview_stage(self) -> str: + return "AFTER_VALIDATION" + + def default_preview_epoch(self) -> int | None: + return None + + def default_preview_batch_count(self) -> int | None: + return None +``` + +### Implement render directly + +For a fully custom surface, implement {py:class}`~nvalchemi.hooks.RichLayout` +directly. This is useful when the dashboard is not a table plus plots. +Custom layouts compose normal Rich renderables, but they do so inside the +`RichReporter` lifecycle: the reporter owns the console, `Live`, rank filtering, +scalar collection, history retention, and refresh cadence. The layout should +remain a pure rendering policy that turns `snapshot`, `history`, and display +options into a renderable. + +Useful Rich components inside `render(...)` include: + +| Component | Use in a `RichReporter` layout | API | +|-----------|--------------------------------|-----| +| `Layout` | Split the terminal into named regions that can hold independent panels. | [Layout](https://rich.readthedocs.io/en/stable/layout.html) | +| `Panel` | Frame one region, table, plot, or status summary with a title. | [Panel](https://rich.readthedocs.io/en/stable/panel.html) | +| `Table` | Show latest scalar values, rank summaries, or status counts. | [Table](https://rich.readthedocs.io/en/stable/tables.html) | +| `Text` | Build styled labels, headers, and compact status lines. | [Text](https://rich.readthedocs.io/en/stable/text.html) | +| `Group` | Stack several renderables inside one layout region. | [Renderables](https://rich.readthedocs.io/en/stable/group.html) | +| `Columns` | Arrange small repeated panels, such as per-rank or per-status summaries. | [Columns](https://rich.readthedocs.io/en/stable/columns.html) | +| `Align` and `Padding` | Position or pad a renderable without creating another `Layout` region. | [Padding](https://rich.readthedocs.io/en/stable/padding.html) | + +`Live` is intentionally absent from this list because `RichReporter` manages it. 
+Do not create or enter a nested `Live` display inside `render(...)`. If you want +standard line plots from retained metric history, subclass `BaseRichLayout`; it +already converts `history` into plotext-backed Rich renderables. + +```python +from collections.abc import Mapping, Sequence + +from rich import box +from rich.console import Group +from rich.layout import Layout +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from nvalchemi.hooks import RichLayout, RichReporter, ScalarSnapshot +from nvalchemi.hooks.reporting.layouts import RichMetricHistory, RichPreviewHistory + + +class CompactRichLayout: + include_dynamics_scalars = False + + def default_preview_history(self) -> RichPreviewHistory: + return {"metric": (1.0, 0.8, 0.6)} + + def default_preview_stage(self) -> str: + return "AFTER_STEP" + + def default_preview_epoch(self) -> None: + return None + + def default_preview_batch_count(self) -> None: + return None + + def render( + self, + snapshot: ScalarSnapshot | None, + history: RichMetricHistory, + *, + title: str, + precision: int, + max_scalars: int | None, + plot_keys: Sequence[str] | None, + max_plots: int, + plot_height: int, + ) -> Layout: + layout = Layout(name="root") + layout.split_column(Layout(name="header", size=3), Layout(name="body")) + subtitle = Text("waiting for metrics" if snapshot is None else snapshot.stage) + layout["header"].update(Panel(Group(Text(title), subtitle), box=box.SIMPLE)) + + table = Table(box=box.SIMPLE_HEAD, expand=True) + table.add_column("Metric") + table.add_column("Latest", justify="right") + if snapshot is None: + table.add_row("(waiting)", "") + else: + for key, value in sorted(snapshot.scalars.items()): + table.add_row(key, f"{value:.{precision}g}") + + layout["body"].update(Panel(table, title="Summary")) + return layout + +layout: RichLayout = CompactRichLayout() +reporter = RichReporter(layout=layout) +``` + +The `render(...)` parameters are intentionally the same 
values that +`RichReporter` already manages: + +- `snapshot` is the latest scalar payload. +- `history` contains retained `(step, value)` points for each metric. +- `plot_keys`, `max_plots`, and `plot_height` are user display preferences. +- `max_scalars` is the row limit for latest-value tables. + +If your layout wants default dynamics observables, set +`include_dynamics_scalars = True`. The reporter will then include available +dynamics metrics such as energy, `fmax`, temperature, convergence fraction, and +active fraction before calling `render(...)`. diff --git a/docs/userguide/zarr_compression.md b/docs/userguide/zarr_compression.md index 1cb6b047..5d2a4dbe 100644 --- a/docs/userguide/zarr_compression.md +++ b/docs/userguide/zarr_compression.md @@ -5,10 +5,10 @@ # Zarr Compression Tuning Zarr stores are the primary persistence format for atomic simulation data in the -toolkit. Configuring compression and chunking correctly can reduce disk usage by -2–4× and significantly improve I/O throughput for data pipelines. This -guide covers the configuration options, codec trade-offs, and practical recipes -for common workloads. +toolkit. Configuring compression, chunking, sharding, and read windows correctly +can reduce disk usage and improve training-time I/O throughput. This guide covers +the configuration options, codec trade-offs, and practical recipes for common +workloads. ## Quick start @@ -88,14 +88,14 @@ arrays are read sequentially. ## Codec comparison Zarr v3 supports pluggable codecs via the `zarr.abc.codec.Codec` interface. The -toolkit has been tested with the following: +toolkit writer accepts any codec supported by Zarr; the benchmark CLI exposes the +common choices `zstd`, `lz4`, and `blosc-zstd`. 
| Codec | Class | Strengths | Weaknesses | Typical use | |-------|-------|-----------|------------|-------------| | Zstd | `zarr.codecs.ZstdCodec` | Good ratio, fast decompress | Moderate compress speed | General purpose, sequential data | | Blosc/LZ4 | `zarr.codecs.BloscCodec(cname="lz4")` | Very fast compress+decompress | Lower ratio | Trajectories, real-time I/O | -| Blosc/Zstd | `zarr.codecs.BloscCodec(cname="zstd")` | Blosc multithreading + Zstd ratio | Slightly more complex | Large arrays, parallel writes | -| Gzip | `zarr.codecs.GzipCodec` | Universal compatibility | Slow | Archival, interop | +| Blosc/Zstd | `zarr.codecs.BloscCodec(cname="zstd")` | Blosc blocking + Zstd ratio | Slightly more complex | Large arrays, balanced ratio/speed | ```{note} Compression level controls the ratio/speed trade-off. Higher levels yield better @@ -104,16 +104,16 @@ improves ratio modestly at the cost of write throughput. For LZ4, the level parameter has minimal effect---speed is consistently high. ``` -### Blosc multithreading +### Blosc options -`BloscCodec` can use multiple threads internally, which helps when compressing -large chunks. By default it uses a single thread; pass `nthreads=4` (or similar) -if your workload benefits from parallel compression: +`BloscCodec` exposes codec name, compression level, shuffle, and blocksize through +its constructor. 
Keep these settings explicit in `ZarrWriteConfig` when you want +reproducible stores: ```python from zarr.codecs import BloscCodec -compressor = BloscCodec(cname="zstd", clevel=5, nthreads=4) +compressor = BloscCodec(cname="zstd", clevel=5) ``` ## Chunk size tuning @@ -172,7 +172,7 @@ The following table gives concrete values for common arrays: | neighbor_list `[E, 2]` | 2 | int64 | 16 | 62,500 | 250,000 | | shifts `[E, 3]` | 3 | float32 | 12 | 83,333 | 333,333 | -**Example: positions (float32, shape [V, 3]), 1 MB target** +#### Positions Example $$ \text{bytes\_per\_row} = 3 \times 4 = 12 \text{ bytes} @@ -181,7 +181,7 @@ $$ \text{chunk\_size} = \left\lfloor \frac{1{,}000{,}000}{12} \right\rfloor = 83{,}333 $$ -**Example: energy (float64, shape [B]), 1 MB target** +#### Energy Example $$ \text{bytes\_per\_row} = 1 \times 8 = 8 \text{ bytes} @@ -193,8 +193,9 @@ $$ ### Read amplification When reading a single structure by index, the reader fetches the slice -`positions[atoms_ptr[i]:atoms_ptr[i+1], :]` — typically ~50 rows (600 bytes). -With large chunks, most of the decompressed data is discarded: +`positions[atoms_ptr[i]:atoms_ptr[i+1], :]` --- typically about 50 rows +(600 bytes for float32 positions). With large chunks, most of the decompressed +data is discarded: | chunk_size | Chunk bytes (positions) | Amplification (50-atom read) | |------------|------------------------|------------------------------| @@ -202,9 +203,12 @@ With large chunks, most of the decompressed data is discarded: | 83,333 | 1 MB | 1,667× | | 10,000 | 120 KB | 200× | -For purely sequential workloads (sequential DataLoader) amplification does not -matter — every row is consumed. For random-access workloads, prefer smaller -chunks or consider field overrides for frequently accessed arrays. +For purely sequential workloads, amplification does not matter because every row +is consumed. 
For shuffled training, amplification depends on the effective read +window: the DataLoader fuses `prefetch_factor` batches into one `read_many` call, +and the Zarr reader can group indices that share chunk locality. Larger chunks can +still hurt fully random single-sample access, so prefer smaller chunks or field +overrides when interactive lookup or visualization is a primary workload. ```{warning} Atom-level fields (positions, forces, atomic_numbers) are stored as @@ -423,32 +427,65 @@ versus only 1,000 shard files with ``shard_size=500,000``. ## I/O benchmark tool -The toolkit ships a command-line benchmark for measuring Zarr write throughput -and compression ratios on synthetic data. Use it to validate configuration -choices before committing to a production workflow. +The toolkit ships a command-line benchmark for measuring Zarr write throughput, +readback throughput, and compression ratios on synthetic data. Use it to +validate storage configuration and readback strategy before committing to a +production workflow. -### Running the benchmark +The CLI has two subcommands: + +- **`roundtrip`** — generate synthetic data, write it to a temporary Zarr + store, then read it back and report timing. +- **`read`** — benchmark read throughput against a pre-existing Zarr store, + without writing anything. + +Run `nvalchemi-io-test --help` to see the available subcommands. Use +`roundtrip` when you want the benchmark to create a temporary store, and use +`read` when you already have a representative store on the target filesystem. 
+ +### Running the roundtrip benchmark ```bash # Install (if not already) $ uv sync # Basic: compare codec overhead across dataset sizes -$ nvalchemi-io-test -n 1000 -n 10000 --codec zstd --level 3 --chunk-size 83333 +$ nvalchemi-io-test roundtrip -n 1000 -n 10000 \ + --codec zstd --level 3 \ + --chunk-size 83333 --edge-chunk-size 62500 + +# Compare fast batch readback against one-sample-at-a-time +$ nvalchemi-io-test roundtrip -n 1000 -n 10000 \ + --read-mode both --batch-size 64 --prefetch-factor 8 + +# Model shuffled training reads against compressed stores +$ nvalchemi-io-test roundtrip -n 1000 -n 10000 \ + --read-order shuffle --batch-size 64 --prefetch-factor 16 +$ nvalchemi-io-test roundtrip -n 1000 -n 10000 \ + --read-order block-shuffle --read-order-block-size 8192 \ + --batch-size 64 --prefetch-factor 16 # Fast codec with smaller chunks for trajectory-style workloads -$ nvalchemi-io-test -n 1000 -n 10000 --codec lz4 --chunk-size 10000 +$ nvalchemi-io-test roundtrip -n 1000 -n 10000 --codec lz4 \ + --chunk-size 10000 --edge-chunk-size 10000 # Larger molecules with edge-specific chunking -$ nvalchemi-io-test -n 1000 -n 10000 --min-atoms 100 --max-atoms 500 \ +$ nvalchemi-io-test roundtrip -n 1000 -n 10000 \ + --min-atoms 100 --max-atoms 500 \ --codec zstd --chunk-size 83333 --edge-chunk-size 62500 # With sharding enabled -$ nvalchemi-io-test -n 1000 -n 10000 --codec zstd \ - --chunk-size 1000 --shard-size 10000 +$ nvalchemi-io-test roundtrip -n 1000 -n 10000 \ + --chunk-size 10000 --shard-size 500000 \ + --edge-chunk-size 10000 --edge-shard-size 500000 + +# Write a store to a specific directory for later read benchmarking +$ nvalchemi-io-test roundtrip -n 10000 --codec zstd \ + --chunk-size 1024 --shard-size 4096 \ + --output-dir /scratch/benchmark_stores/ ``` -Key options: +Roundtrip options: | Option | Default | Description | |--------|---------|-------------| @@ -461,94 +498,315 @@ Key options: | `--shard-size` | — | Shard size for node/system arrays | | 
`--edge-chunk-size` | — | Chunk size for edge arrays (neighbor_list, shifts) | | `--edge-shard-size` | — | Shard size for edge arrays | +| `--read-mode` | `batch` | Readback path to time: `batch`, `single`, or `both` | +| `--batch-size` | 64 | Number of samples per emitted DataLoader batch in `batch` mode | +| `--prefetch-factor` | 16 | Number of emitted batches to fuse into each backend read in `batch` mode | +| `--read-order` | `sequential` | Logical read order: `sequential`, `shuffle`, or `block-shuffle` | +| `--read-seed` | 0 | Random seed for shuffled read orders | +| `--read-order-block-size` | 8192 | Contiguous block size for `block-shuffle` read order | +| `--pin-memory` | `False` | Request pinned CPU tensors in batch read mode | +| `--output-dir` | — | Persist the written store(s) here instead of a temp directory | + +### Read-only benchmark + +Use `nvalchemi-io-test read` to benchmark against an existing Zarr store. +This isolates read performance from generation and write overhead, and lets +you test multiple read configurations against the same store without +rewriting it each time. -### Example output +```bash +# Sequential read baseline +$ nvalchemi-io-test read /path/to/store.zarr + +# Shuffled access at different read windows +$ nvalchemi-io-test read /path/to/store.zarr \ + --read-order shuffle --batch-size 64 --prefetch-factor 8 +$ nvalchemi-io-test read /path/to/store.zarr \ + --read-order shuffle --batch-size 64 --prefetch-factor 64 + +# Compare batch vs. 
single-sample under shuffle +$ nvalchemi-io-test read /path/to/store.zarr \ + --read-mode both --read-order shuffle +``` -**Small molecules (10–100 atoms), Zstd level 3, 1 MB chunks:** +Read options: -```text -nvalchemi Zarr I/O benchmark atoms=10-100 config=zstd L3, chunk=83,333, - edge_chunk=62,500 -Pre-computed: 100,000 systems, 5,504,449 total atoms (avg 55.0), - 11,062,584 total edges (avg 110.6) -Estimated uncompressed: 484.9 MB +| Option | Default | Description | +|--------|---------|-------------| +| `PATH` | — | Path to an existing Zarr store (directory) | +| `--read-mode` | `batch` | `batch`, `single`, or `both` | +| `--batch-size` | 64 | Number of samples per emitted DataLoader batch | +| `--prefetch-factor` | 16 | Number of emitted batches to fuse into each backend read | +| `--read-order` | `sequential` | `sequential`, `shuffle`, or `block-shuffle` | +| `--read-seed` | 0 | Random seed for shuffled orders | +| `--read-order-block-size` | 8192 | Block size for `block-shuffle` | +| `--pin-memory` | `False` | Request pinned CPU tensors in batch read mode | - Zarr I/O Benchmark — zstd L3, chunk=83,333, edge_chunk=62,500 +```{tip} +The ``read`` subcommand measures the public DataLoader read path by default: +``batch_size`` controls emitted batches, and ``prefetch_factor`` controls how +many emitted batches are fused into one backend read. Use ``single`` mode only +as a one-sample-at-a-time baseline. +``` - Avg Avg Raw Disk Write - Systems atoms edges size size Ratio Files time Systems/s - ───────────────────────────────────────────────────────────────────────────── - 1,000 56 115 4.8 MB 2.8 MB 1.74x 36 0.14s 7,282 - 10,000 55 112 47.1 MB 27.0 MB 1.75x 96 0.48s 20,736 - 100,000 55 111 467.5 MB 267.7 MB 1.75x 691 4.66s 21,471 +```{note} +Benchmark `batch` mode uses `Dataset(skip_validation=True)` to focus on storage +and batching throughput for stores that are already trusted. 
If your training +pipeline keeps validation enabled, expect lower end-to-end throughput. ``` -**Small molecules, LZ4, 120 KB chunks (trajectory-optimised):** +### Readback mode: batch vs. single sample -```text -nvalchemi Zarr I/O benchmark atoms=10-100 config=lz4 L3, chunk=10,000, - edge_chunk=10,000 +The benchmark reports write time plus a full-store readback. Readback uses the +batch path by default: - Zarr I/O Benchmark — lz4 L3, chunk=10,000, edge_chunk=10,000 - - Avg Avg Raw Disk Write - Systems atoms edges size size Ratio Files time Systems/s - ───────────────────────────────────────────────────────────────────────────── - 1,000 56 115 4.8 MB 3.0 MB 1.61x 76 0.12s 8,207 - 10,000 55 112 47.1 MB 28.9 MB 1.63x 480 0.80s 12,446 - 100,000 55 111 467.5 MB 287.5 MB 1.63x 4,509 8.10s 12,341 +```bash +$ nvalchemi-io-test roundtrip -n 10000 --codec zstd \ + --chunk-size 83333 ``` -**Small molecules, sharded (chunk=10,000 inside shard=500,000):** +In `batch` mode the benchmark uses the toolkit +{py:class}`~nvalchemi.data.datapipes.dataloader.DataLoader` with fused prefetch. +The emitted batch size is controlled by `--batch-size`; the backend read window is +controlled by `--batch-size * --prefetch-factor`. The Zarr reader then receives +large `read_many(...)` requests and can coalesce physical I/O across the requested +indices. 
-```text -nvalchemi Zarr I/O benchmark atoms=10-100 config=chunk=10,000, - shard=500,000, edge_chunk=10,000, edge_shard=500,000 +Use `single` mode to time a one-sample-at-a-time access pattern: - Zarr I/O Benchmark — chunk=10,000, shard=500,000, - edge_chunk=10,000, edge_shard=500,000 +```bash +$ nvalchemi-io-test roundtrip -n 10000 --read-mode single +``` + +Use `both` to emit one row per read path from the same written store: - Avg Avg Raw Disk Write - Systems atoms edges size size Ratio Files time Systems/s - ───────────────────────────────────────────────────────────────────────────── - 1,000 56 115 4.8 MB 2.8 MB 1.73x 34 0.14s 6,998 - 10,000 55 112 47.1 MB 27.0 MB 1.74x 46 0.63s 15,930 - 100,000 55 111 467.5 MB 268.2 MB 1.74x 158 6.46s 15,471 +```bash +$ nvalchemi-io-test roundtrip -n 10000 \ + --read-mode both --batch-size 64 --prefetch-factor 8 ``` -Note the dramatic file count reduction with sharding: **4,509 → 158** at 100k -systems with the same chunk size, while compression ratio and disk size remain -essentially unchanged. +`batch` mode should be faster for DataLoader-style workloads because it amortises +Python dispatch, Zarr array indexing, chunk lookup, decompression setup, and +filesystem metadata access over many samples. `single` mode remains useful as a +baseline for debugging and for estimating the penalty paid by code that reads one +structure at a time. -**Larger molecules (100–500 atoms), Zstd with edge-specific chunks:** +### Read order: sequential vs. shuffled training access -```text -nvalchemi Zarr I/O benchmark atoms=100-500 config=zstd L3, chunk=83,333, - edge_chunk=62,500 -Pre-computed: 10,000 systems, 3,016,657 total atoms (avg 301.7), - 6,073,861 total edges (avg 607.4) -Estimated uncompressed: 263.5 MB +For compressed Zarr stores, the logical index order can dominate throughput. +Sequential readback gives the Zarr reader mostly contiguous physical positions. 
+Fully shuffled readback models `DataLoader(shuffle=True)`: each emitted batch can +contain unrelated samples, but fused prefetch still gives the reader a larger +window of indices to sort and group by chunk locality. + +Use `--read-order shuffle` to benchmark that worst-case training pattern: + +```bash +$ nvalchemi-io-test roundtrip -n 10000 --codec zstd \ + --chunk-size 83333 --edge-chunk-size 62500 \ + --read-order shuffle +``` - Zarr I/O Benchmark — zstd L3, chunk=83,333, edge_chunk=62,500 +Use `--read-order block-shuffle` to model one locality-preserving training +order: - Avg Avg Raw Disk Write - Systems atoms edges size size Ratio Files time Systems/s - ───────────────────────────────────────────────────────────────────────────── - 1,000 303 615 25.7 MB 15.4 MB 1.67x 66 0.21s 4,737 - 10,000 302 607 254.7 MB 152.9 MB 1.67x 394 1.23s 8,138 +```bash +$ nvalchemi-io-test roundtrip -n 10000 --codec zstd \ + --chunk-size 83333 --edge-chunk-size 62500 \ + --read-order block-shuffle --read-order-block-size 8192 ``` +`block-shuffle` splits the index range into contiguous blocks of +`--read-order-block-size` samples, shuffles the *blocks*, and leaves the +indices inside each block in sequential order. For example, with 10,000 +samples and a block size of 2,000 the reader sees five blocks in random +order, but within each block it reads indices 0–1,999, 2,000–3,999, etc. +sequentially. + +This benchmark mode does **not** correspond to a specific DataLoader API; +it is a synthetic access pattern that helps you measure how much throughput +you recover when read locality is partially preserved. Compare +`block-shuffle` against `shuffle` to quantify the cost of fully random +access. In practice, a +{py:class}`~nvalchemi.dynamics.sampler.SizeAwareSampler` with bin-packing +can produce similar locality as a side-effect of grouping similarly-sized +systems. + ```{note} -Zarr v3 defaults to ``ZstdCodec(level=0)`` when no compressor is specified. 
-The "Raw size" column reflects the data as written by the toolkit (including -Zarr metadata overhead), so even runs without an explicit ``--codec`` flag -will show some compression. +When `--read-mode both` is used, the two read paths run back-to-back against the +same freshly written store. This is useful for relative comparisons, but the +second mode may benefit from filesystem cache. For strict cold-cache numbers, +run `batch` and `single` in separate invocations with the same benchmark +configuration. +``` + +The following output illustrates the expected shape of the result table. Treat +numbers as machine- and store-specific; use the CLI on the target filesystem for +decisions. + +```text +Zarr I/O Roundtrip Benchmark — no compression + + Systems Read path Read order Batch Prefetch Read window Write Read I/O/s + ────────────────────────────────────────────────────────────────────────────────────────── + 10,000 batch shuffle 64 32 2,048 0.54s 3.17s 2,695 +``` + +(read_performance_tuning)= + +## Read performance tuning + +The benchmark commands above measure the public read paths: `batch` mode uses +the toolkit DataLoader with fused prefetch and `single` mode calls +`reader.read(...)` once per sample. In production, validation, batching, and +device-transfer overhead can dominate the end-to-end pipeline. This section +covers the knobs that matter most for read throughput, especially under shuffled +access patterns. + +```{graphviz} +:caption: End-to-end read pipeline. 
+ +digraph read_pipeline { + rankdir=LR + compound=true + fontname="Helvetica" + node [fontname="Helvetica" fontsize=11 shape=box style="filled,rounded"] + edge [fontname="Helvetica" fontsize=10] + + subgraph cluster_dataloader { + label="DataLoader" + style=rounded + color="#4a90d9" + fontcolor="#4a90d9" + + sampler [label="Sampler\n(indices)" fillcolor="#dce6f1"] + fuse [label="Fuse\nprefetch_factor\nbatches" fillcolor="#f9e2ae"] + sampler -> fuse [label="batch of\nindices"] + } + + subgraph cluster_dataset { + label="Dataset (background thread)" + style=rounded + color="#5bb35b" + fontcolor="#5bb35b" + + read_many [label="reader.read_many()\ncoalesced backend read" fillcolor="#dce6f1"] + validate [label="AtomicData\nvalidation\n(Pydantic)" fillcolor="#fddede"] + raw [label="raw tensor\ndicts" fillcolor="#d5f5d5"] + batch_val [label="Batch.from_data_list()" fillcolor="#e8daef"] + batch_raw [label="Batch.from_raw_dicts()" fillcolor="#e8daef"] + + read_many -> validate [label="skip_validation\n= False"] + read_many -> raw [label="skip_validation\n= True"] + validate -> batch_val + raw -> batch_raw + } + + subgraph cluster_consumer { + label="Consumer" + style=rounded + color="#c0392b" + fontcolor="#c0392b" + + device [label=".to(device)" fillcolor="#f9e2ae"] + model [label="Model" fillcolor="#dce6f1"] + device -> model + } + + fuse -> read_many [label="N indices\n(N = pf \u00d7 bs)" lhead=cluster_dataset style=bold] + batch_val -> device [ltail=cluster_dataset lhead=cluster_consumer style=bold] + batch_raw -> device [ltail=cluster_dataset lhead=cluster_consumer style=bold] +} ``` +### The read window: `prefetch_factor` + +{py:class}`~nvalchemi.data.datapipes.dataloader.DataLoader` groups +`prefetch_factor` consecutive batches into a single +{py:meth}`~nvalchemi.data.datapipes.dataset.Dataset.prefetch_fused_batches` call. 
+The reader sees one large `read_many(...)` request containing up to +`prefetch_factor * batch_size` indices instead of many small calls, which lets +the Zarr backend coalesce random indices into larger physical reads. + +The synchronous counterpart is +{py:meth}`~nvalchemi.data.datapipes.dataset.Dataset.load_batches`, which accepts +one or more batch-index lists and returns one +{py:class}`~nvalchemi.data.Batch` per list. `DataLoader` uses this same +batch-construction path when `prefetch_factor=0`; only the async double-buffered +prefetch is disabled. New code should prefer `load_batches(...)` for explicit +batch reads rather than calling older one-batch helpers directly. + +Larger windows amortise per-call Zarr overhead across more samples. For +shuffled training, a `prefetch_factor` of 16–32 is a good starting point, but +the best value depends on store size, chunking, compression, filesystem, and +whether pinned memory is enabled. Use the benchmark tool described above on a +representative store before treating any value as a default for production. + ```{tip} -Run with ``--min-atoms`` and ``--max-atoms`` matching your actual dataset to get -realistic estimates. The benchmark uses uniform random atom counts; real-world -distributions may be skewed toward smaller or larger structures. +For sequential access the reader already detects contiguous runs, so +``prefetch_factor=2`` is enough. Increase it primarily when +``read_order=shuffle`` or ``read_order=block-shuffle``. +``` + +### Skipping validation: `skip_validation` + +By default the {py:class}`~nvalchemi.data.datapipes.dataset.Dataset` +validates every loaded sample through +{py:class}`~nvalchemi.data.AtomicData` (Pydantic), which adds CPU overhead. 
+When the backing store is known to +contain well-formed data --- for example, stores written by the toolkit's +own writer --- you can bypass this: + +```python +dataset = Dataset(reader=reader, device="cuda:0", skip_validation=True) +``` + +With `skip_validation=True` the Dataset constructs +{py:class}`~nvalchemi.data.Batch` objects directly from raw tensor +dictionaries via +{py:meth}`~nvalchemi.data.Batch.from_raw_dicts`, avoiding per-sample +Pydantic overhead entirely. + +```{warning} +``skip_validation`` trusts the store contents. Use it only with stores +produced by +{py:class}`~nvalchemi.data.datapipes.backends.zarr.AtomicDataZarrWriter` +or stores whose schema you have already validated independently. +``` + +### How the Zarr reader coalesces random indices + +The public `read_many` method delegates raw loading to the Zarr reader's +batch-oriented `_load_many_samples` hook. That hook applies several +backend-specific optimisations automatically: + +1. **Resolve logical indices**: requested logical indices are mapped through the + active-sample mask, so soft-deleted samples are skipped consistently. +2. **Sort by physical position**: requests are ordered by physical sample index + so the underlying storage sees monotonic offsets where possible. +3. **Group by chunk locality**: samples that share Zarr chunks are grouped into + range reads, with an amplification cap to avoid pathological over-reads when + indices are very sparse. +4. **Fallback for fragmentation**: highly fragmented requests use orthogonal + selections instead of many tiny range reads. + +These optimisations are transparent: `read_many` still returns results in the +caller's original request order. + +### Starting configurations + +| Access pattern | `prefetch_factor` | `skip_validation` | Notes | +|----------------|------------------:|:-----------------:|-------| +| Sequential training | 2–4 | `False` or `True` | Small windows are usually enough because samples are already contiguous. 
| +| Shuffled training (trusted store) | 16–64 | `True` | Larger windows give the Zarr reader more indices to coalesce. | +| Shuffled training (untrusted store) | 16–64 | `False` | Keeps validation enabled, but validation can dominate end-to-end time. | +| Block-shuffle (block ≥ chunk) | 2–8 | `True` | Preserves some locality while still mixing batches. | + +```{note} +Treat these as starting points, not throughput guarantees. Benchmark with +``nvalchemi-io-test read`` or ``nvalchemi-io-test roundtrip`` using the same +read order, batch size, prefetch factor, compression, and storage backend you +expect in training. ``` ## See also diff --git a/examples/intermediate/06_ddp_mlp_training.py b/examples/intermediate/06_ddp_mlp_training.py new file mode 100644 index 00000000..1f83c156 --- /dev/null +++ b/examples/intermediate/06_ddp_mlp_training.py @@ -0,0 +1,513 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Distributed Training: DDPHook with a Dummy MLP +============================================== + +This example trains a small MLP on synthetic per-system energy labels and uses +:class:`~nvalchemi.training.hooks.DDPHook` to configure +``torch.nn.parallel.DistributedDataParallel``. 
The dataset is intentionally +small and generated on the fly so the example focuses on the distributed +training wiring rather than model quality. + +Run on a single node with ``torchrun`` through ``uv``: + +.. code-block:: bash + + uv run --extra cu12 torchrun --standalone --nproc_per_node=2 \ + examples/intermediate/06_ddp_mlp_training.py --backend auto + +The ``--backend`` option accepts: + +* ``auto``: choose ``nccl`` when the requested local ranks fit on visible GPUs, + otherwise choose ``gloo``. +* ``gloo``: run on CPU with the Gloo process group. +* ``nccl``: require one visible CUDA device per requested local rank. + +The backend selection is intentionally single-node oriented: ``auto`` treats the +torchrun world size as the requested local rank count. +""" + +from __future__ import annotations + +import argparse +from collections.abc import Sequence +from typing import Any + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, Dataset, DistributedSampler + +from nvalchemi.data import AtomicData, Batch +from nvalchemi.distributed import DistributedManager +from nvalchemi.models.base import BaseModelMixin, ModelConfig +from nvalchemi.training import ( + DDPHook, + EnergyMSELoss, + OptimizerConfig, + TrainingStage, + TrainingStrategy, + default_training_fn, +) + +# --------------------------------------------------------------------------- +# Launch and backend setup +# --------------------------------------------------------------------------- +# This block is the only part of the example that deals with process launch +# mechanics. The training code below only consumes a DistributedManager and a +# resolved backend string. 
+ + +def _is_torchrun() -> bool: + """Return whether this process appears to be launched by torchrun.""" + return dist.is_torchelastic_launched() + + +def resolve_backend(requested: str, *, requested_ranks: int) -> str: + """Resolve ``auto``/``gloo``/``nccl`` into a concrete process backend.""" + if requested == "gloo": + return "gloo" + + cuda_count = torch.cuda.device_count() + nccl_ready = torch.cuda.is_available() and dist.is_nccl_available() + ranks_fit_on_gpus = cuda_count >= requested_ranks + + if requested == "auto": + # Keep the example single-node and self-contained: prefer NCCL only + # when each requested local rank can own a visible CUDA device. + return "nccl" if nccl_ready and ranks_fit_on_gpus else "gloo" + + if not nccl_ready: + raise RuntimeError( + "--backend nccl requested, but CUDA or the NCCL process group is " + "not available in this environment." + ) + if not ranks_fit_on_gpus: + raise RuntimeError( + "--backend nccl requested, but visible CUDA devices " + f"({cuda_count}) are fewer than requested local ranks " + f"({requested_ranks})." + ) + return "nccl" + + +def setup_distributed_runtime(requested_backend: str) -> tuple[DistributedManager, str]: + """Initialize process communication from torchrun and return the manager.""" + # DistributedManager.initialize_env() is the public entry point we want users + # to see. It discovers rank metadata first, then calls setup(), so this + # example briefly wraps setup() to inject the backend chosen by the CLI. 
+ original_setup = DistributedManager.setup + original_cuda_is_available = torch.cuda.is_available + original_init_process_group = dist.init_process_group + resolved_backend: dict[str, str] = {} + + def init_process_group_without_cpu_device_id(*args: Any, **kwargs: Any) -> Any: + device_id = kwargs.get("device_id") + if device_id is not None and torch.device(device_id).type == "cpu": + kwargs = dict(kwargs) + kwargs.pop("device_id") + return original_init_process_group(*args, **kwargs) + + def setup( + *, + rank: int = 0, + world_size: int = 1, + local_rank: int | None = None, + addr: str = "localhost", + port: str = "12355", + backend: str = "nccl", + method: str = "env", + ) -> None: + # initialize_env() already read rank/world-size/local-rank from the + # torchrun environment. We only decide which backend should be handed + # to PhysicsNeMo's setup() call. + selected = resolve_backend(requested_backend, requested_ranks=world_size) + resolved_backend["value"] = selected + if selected == "gloo": + # PhysicsNeMo normally chooses a CUDA device when CUDA is visible. + # For an explicit Gloo run, keep the example CPU-only so it can be + # used as a portable debug path even on GPU machines. + torch.cuda.is_available = lambda: False + dist.init_process_group = init_process_group_without_cpu_device_id + original_setup( + rank=rank, + world_size=world_size, + local_rank=local_rank, + addr=addr, + port=port, + backend=selected, + method=method, + ) + + DistributedManager.setup = staticmethod(setup) + try: + # This is the recommended PhysicsNeMo entry point for torchrun-launched + # processes. It initializes torch.distributed and populates the manager + # singleton with rank, world-size, local-rank, and device metadata. 
+ DistributedManager.initialize_env() + return DistributedManager(), resolved_backend["value"] + except Exception: + DistributedManager._shared_state = {} + raise + finally: + DistributedManager.setup = staticmethod(original_setup) + torch.cuda.is_available = original_cuda_is_available + dist.init_process_group = original_init_process_group + + +def cleanup_distributed_runtime(manager: DistributedManager) -> None: + """Destroy the process group created by this example.""" + DistributedManager.cleanup() + + +def training_device(manager: DistributedManager) -> torch.device: + """Return the training device implied by the selected backend.""" + return torch.device(manager.device) + + +# --------------------------------------------------------------------------- +# Dummy data +# --------------------------------------------------------------------------- +# The training stack expects ALCHEMI AtomicData/Batch objects. This toy dataset +# creates fixed-size systems so the MLP can flatten positions without padding. + + +class DummyEnergyDataset(Dataset[AtomicData]): + """Deterministic synthetic systems with per-system energy labels.""" + + def __init__(self, *, num_samples: int, num_atoms: int, seed: int) -> None: + self.num_samples = num_samples + self.num_atoms = num_atoms + self.seed = seed + + def __len__(self) -> int: + """Return the number of synthetic samples.""" + return self.num_samples + + def __getitem__(self, index: int) -> AtomicData: + """Generate one deterministic synthetic atomic system.""" + generator = torch.Generator().manual_seed(self.seed + index) + positions = torch.randn(self.num_atoms, 3, generator=generator) + atomic_numbers = torch.ones(self.num_atoms, dtype=torch.long) + # A deliberately learnable target: the model only has to regress a + # smooth function of positions, not a real atomistic potential. 
+ energy = positions.square().sum().view(1, 1) + return AtomicData( + positions=positions, + atomic_numbers=atomic_numbers, + atomic_masses=torch.ones(self.num_atoms), + energy=energy, + forces=torch.zeros(self.num_atoms, 3), + ) + + +def collate_atomic_data(samples: Sequence[AtomicData]) -> Batch: + """Collate synthetic systems into an ALCHEMI batch.""" + return Batch.from_data_list(list(samples)) + + +# --------------------------------------------------------------------------- +# Model wrapper +# --------------------------------------------------------------------------- +# TrainingStrategy works with BaseModelMixin wrappers. The wrapper advertises +# that the model produces "energy"; default_training_fn will therefore expose it +# to the loss as "predicted_energy". + + +class SimpleEnergyMLP(torch.nn.Module, BaseModelMixin): + """Small MLP that predicts one total energy per fixed-size system.""" + + def __init__(self, *, num_atoms: int, hidden_dim: int) -> None: + super().__init__() + self.num_atoms = num_atoms + self.network = torch.nn.Sequential( + torch.nn.Linear(num_atoms * 3, hidden_dim), + torch.nn.SiLU(), + torch.nn.Linear(hidden_dim, hidden_dim), + torch.nn.SiLU(), + torch.nn.Linear(hidden_dim, 1), + ) + self.model_config = ModelConfig( + outputs=frozenset({"energy"}), + autograd_outputs=frozenset(), + autograd_inputs=frozenset(), + required_inputs=frozenset({"positions"}), + optional_inputs=frozenset(), + supports_pbc=False, + needs_pbc=False, + neighbor_config=None, + ) + + @property + def embedding_shapes(self) -> dict[str, tuple[int, ...]]: + """Return no named embeddings for this toy model.""" + return {} + + def compute_embeddings( + self, data: AtomicData | Batch, **kwargs: Any + ) -> AtomicData | Batch: + """Return ``data`` unchanged because the toy MLP has no embeddings.""" + return data + + def forward( + self, data: AtomicData | Batch, **kwargs: Any + ) -> dict[str, torch.Tensor]: + """Predict per-graph energies from flattened atomic 
positions.""" + num_graphs = data.batch_size if isinstance(data, Batch) else 1 + # The dataset uses a fixed atom count, so every graph has the same + # feature width. Production MLIPs usually avoid this flattening pattern. + features = data.positions.reshape(num_graphs, self.num_atoms * 3) + return {"energy": self.network(features)} + + +# --------------------------------------------------------------------------- +# Rank-zero reporting hooks +# --------------------------------------------------------------------------- +# Hooks keep the example output tied to the actual TrainingStrategy lifecycle: +# SETUP runs after DDPHook has prepared the model/dataloader, and AFTER_BATCH +# runs after each optimizer step. + + +def sampler_description(sampler: Any) -> str: + """Return a compact human-readable dataloader sampler summary.""" + if sampler is None: + return "None" + fields = [ + f"{name}={getattr(sampler, name)}" + for name in ("num_replicas", "rank", "shuffle") + if hasattr(sampler, name) + ] + suffix = f" ({', '.join(fields)})" if fields else "" + return f"{type(sampler).__name__}{suffix}" + + +class RankZeroSetupLogger: + """Explain the distributed training setup once DDPHook has run.""" + + stage = TrainingStage.SETUP + frequency = 1 + + def __init__( + self, + *, + requested_backend: str, + resolved_backend: str, + manager: DistributedManager, + num_samples: int, + num_atoms: int, + batch_size: int, + hidden_dim: int, + lr: float, + ) -> None: + self.requested_backend = requested_backend + self.resolved_backend = resolved_backend + self.manager = manager + self.num_samples = num_samples + self.num_atoms = num_atoms + self.batch_size = batch_size + self.hidden_dim = hidden_dim + self.lr = lr + + def __call__(self, ctx: Any, stage: TrainingStage) -> None: + """Print a rank-zero summary of the setup-stage side effects.""" + if ctx.global_rank != 0: + return + strategy = ctx.workflow + # DDPHook stores the active dataloader on the strategy workflow. 
Looking + # here lets the log report whether the hook replaced the sampler. + sampler = getattr(getattr(strategy, "active_dataloader", None), "sampler", None) + sampler_status = ( + "DDPHook installed a DistributedSampler" + if isinstance(sampler, DistributedSampler) + else "DDPHook left the dataloader sampler unchanged" + ) + print( + "\nDDP MLP training example\n" + "------------------------\n" + f"requested backend: {self.requested_backend}\n" + f"resolved backend: {self.resolved_backend}\n" + f"world size: {self.manager.world_size}\n" + f"rank-0 device: {self.manager.device}\n" + f"dataset: {self.num_samples} synthetic systems, " + f"{self.num_atoms} atoms each\n" + "target: energy = sum(positions ** 2) per system\n" + f"model: SimpleEnergyMLP(hidden_dim={self.hidden_dim})\n" + f"optimizer: Adam(lr={self.lr})\n" + f"batch size: {self.batch_size} systems per rank\n" + f"sampler after DDP: {sampler_description(sampler)}\n" + f"sampler status: {sampler_status}\n" + "progress log: rank-0 local mini-batch loss after each " + "optimizer step\n", + flush=True, + ) + + +class RankZeroLossLogger: + """Record local losses and print progress on rank zero.""" + + stage = TrainingStage.AFTER_BATCH + frequency = 1 + + def __init__(self, *, every: int) -> None: + self.every = every + self.last_loss: float | None = None + + def __call__(self, ctx: Any, stage: TrainingStage) -> None: + """Record the latest scalar loss and print occasional progress.""" + if ctx.loss is None: + return + self.last_loss = float(ctx.loss.detach().cpu()) + if ctx.global_rank == 0 and ctx.step_count % self.every == 0: + print( + "progress: " + f"optimizer_step={ctx.step_count:03d} " + f"epoch={ctx.epoch:02d} " + f"rank0_local_loss={self.last_loss:.6f}", + flush=True, + ) + + +def mean_across_ranks(value: float, device: torch.device) -> float: + """Return the distributed mean of a scalar value.""" + tensor = torch.tensor(value, device=device) + if dist.is_available() and dist.is_initialized(): + # The 
progress lines show rank-0 local loss. This final value averages + # the last local loss from every rank so users see one global summary. + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + tensor /= dist.get_world_size() + return float(tensor.cpu()) + + +# --------------------------------------------------------------------------- +# CLI and training assembly +# --------------------------------------------------------------------------- + + +def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: + """Parse command-line arguments for the DDP MLP example.""" + parser = argparse.ArgumentParser( + description="Train a simple MLP with nvalchemi DDPHook on dummy data.", + ) + parser.add_argument( + "--backend", + choices=("auto", "gloo", "nccl"), + default="auto", + help=( + "Distributed backend. auto uses nccl when requested local ranks fit " + "on visible GPUs, otherwise gloo." + ), + ) + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--num-samples", type=int, default=64) + parser.add_argument("--num-atoms", type=int, default=4) + parser.add_argument("--hidden-dim", type=int, default=32) + parser.add_argument("--lr", type=float, default=5e-3) + parser.add_argument("--seed", type=int, default=123) + parser.add_argument("--log-every", type=int, default=2) + return parser.parse_args(argv) + + +def main(argv: Sequence[str] | None = None) -> int: + """Run the DDP MLP training example.""" + args = parse_args(argv) + if not _is_torchrun(): + print( + "This example is intended to run under torchrun. Try:\n" + "uv run --extra cu12 torchrun --standalone --nproc_per_node=2 " + "examples/intermediate/06_ddp_mlp_training.py --backend auto", + flush=True, + ) + return 0 + + manager: DistributedManager | None = None + + try: + # 1. Initialize distributed runtime from torchrun, then let the manager + # tell the rest of the script which rank/device it owns. 
+ manager, backend = setup_distributed_runtime(args.backend) + device = training_device(manager) + torch.manual_seed(args.seed) + if device.type == "cuda": + torch.cuda.manual_seed_all(args.seed) + + # 2. Build ordinary PyTorch data/model pieces. DDPHook will make the + # dataloader distributed-aware during TrainingStage.SETUP. + dataset = DummyEnergyDataset( + num_samples=args.num_samples, + num_atoms=args.num_atoms, + seed=args.seed, + ) + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=True, + collate_fn=collate_atomic_data, + num_workers=0, + ) + logger = RankZeroLossLogger(every=args.log_every) + setup_logger = RankZeroSetupLogger( + requested_backend=args.backend, + resolved_backend=backend, + manager=manager, + num_samples=len(dataset), + num_atoms=args.num_atoms, + batch_size=args.batch_size, + hidden_dim=args.hidden_dim, + lr=args.lr, + ) + # 3. Hand the manager and DDPHook to TrainingStrategy. DDPHook runs + # before optimizer construction, so Adam sees DDP-wrapped parameters. + strategy = TrainingStrategy( + models=SimpleEnergyMLP( + num_atoms=args.num_atoms, + hidden_dim=args.hidden_dim, + ), + optimizer_configs=OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": args.lr}, + ), + num_epochs=args.epochs, + training_fn=default_training_fn, + loss_fn=EnergyMSELoss(), + devices=[device], + distributed_manager=manager, + hooks=[ + DDPHook(backend=backend), + setup_logger, + logger, + ], + ) + + # 4. Run training. The setup logger explains the resolved distributed + # configuration before the first batch, then the loss logger reports + # rank-zero progress after optimizer steps. 
+ strategy.run(dataloader) + if logger.last_loss is not None: + final_loss = mean_across_ranks(logger.last_loss, device) + if manager.rank == 0: + print( + f"summary: mean_final_loss_across_ranks={final_loss:.6f}", + flush=True, + ) + finally: + if manager is not None: + cleanup_distributed_runtime(manager) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/intermediate/07_rich_training_reporting.py b/examples/intermediate/07_rich_training_reporting.py new file mode 100644 index 00000000..3de8975e --- /dev/null +++ b/examples/intermediate/07_rich_training_reporting.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Rich Training Reporting +======================= + +This example drives the Rich reporting dashboard with synthetic training +metrics. The scalar values are deterministic and intentionally lightweight; the +goal is to demonstrate the live terminal UI without requiring a real model, +dataset, or training strategy. + +Run it directly from the repository root to watch the dashboard refresh: + +.. 
code-block:: bash + + uv run python examples/intermediate/07_rich_training_reporting.py --steps 80 --delay 0.05 +""" + +from __future__ import annotations + +import argparse +import math +import time +from collections.abc import Sequence +from enum import Enum, auto +from types import SimpleNamespace + +import torch + +from nvalchemi.hooks import ReportingOrchestrator, RichReporter, TrainContext + + +class SyntheticTrainingStage(Enum): + """Minimal training-like hook stage enum for this reporting demo.""" + + AFTER_OPTIMIZER_STEP = auto() + + +def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: + """Parse command-line arguments for the Rich reporting demo.""" + parser = argparse.ArgumentParser( + description="Preview the Rich training reporter with synthetic metrics.", + ) + parser.add_argument( + "--steps", + type=int, + default=24, + help="Number of synthetic reporting steps to emit.", + ) + parser.add_argument( + "--epochs", + type=int, + default=3, + help="Number of synthetic epochs represented in the progress panel.", + ) + parser.add_argument( + "--delay", + type=float, + default=0.03, + help="Seconds to sleep between dashboard refreshes.", + ) + parser.add_argument( + "--lr", + type=float, + default=1.0e-3, + help="Initial optimizer learning rate shown in the dashboard.", + ) + parser.add_argument( + "--refresh-per-second", + type=float, + default=8.0, + help="Rich Live refresh rate.", + ) + parser.add_argument( + "--final-delay", + type=float, + default=0.0, + help="Seconds to keep the final dashboard visible before exit.", + ) + return parser.parse_args(argv) + + +def synthetic_losses(step: int, total_steps: int) -> dict[str, float]: + """Return deterministic loss values for one synthetic training step.""" + progress = step / max(total_steps, 1) + energy = 0.70 * math.exp(-3.0 * progress) + 0.04 + forces = 1.10 * math.exp(-2.1 * progress) + 0.08 + ripple = 0.015 * math.sin(step / 2.5) + validation = 0.55 * math.exp(-2.4 * progress) + 
0.06 + abs(ripple) + total = 0.25 * energy + 0.75 * forces + ripple + return { + "total": max(total, 0.0), + "energy": max(energy, 0.0), + "forces": max(forces, 0.0), + "validation": max(validation, 0.0), + } + + +def build_context( + *, + step: int, + total_steps: int, + epochs: int, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + workflow: SimpleNamespace, +) -> TrainContext: + """Build a training hook context populated with synthetic metrics.""" + losses = synthetic_losses(step, total_steps) + steps_per_epoch = max(math.ceil(total_steps / max(epochs, 1)), 1) + epoch = min((step - 1) // steps_per_epoch, max(epochs - 1, 0)) + epoch_step = step - epoch * steps_per_epoch + + return TrainContext( + batch=None, + global_rank=0, + workflow=workflow, + step_count=step, + batch_count=step, + epoch_step_count=epoch_step, + epoch=epoch, + loss=torch.tensor(losses["total"]), + losses={ + "total_loss": torch.tensor(losses["total"]), + "validation": torch.tensor(losses["validation"]), + "per_component_unweighted": { + "energy": torch.tensor(losses["energy"]), + "forces": torch.tensor(losses["forces"]), + }, + "per_component_weight": { + "energy": torch.tensor(0.25), + "forces": torch.tensor(0.75), + }, + "per_component_raw_weight": { + "energy": torch.tensor(1.0), + "forces": torch.tensor(3.0), + }, + }, + optimizers=[optimizer], + lr_schedulers=[scheduler], + ) + + +def main(argv: Sequence[str] | None = None) -> int: + """Run the synthetic Rich reporting demo.""" + args = parse_args(argv) + if args.steps < 1: + raise ValueError("--steps must be at least 1.") + if args.epochs < 1: + raise ValueError("--epochs must be at least 1.") + if args.delay < 0: + raise ValueError("--delay must be non-negative.") + if args.final_delay < 0: + raise ValueError("--final-delay must be non-negative.") + + parameter = torch.nn.Parameter(torch.tensor(0.0)) + optimizer = torch.optim.AdamW([parameter], lr=args.lr) + scheduler = 
torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=max(args.steps, 1), + eta_min=args.lr * 0.08, + ) + workflow = SimpleNamespace(num_steps=args.steps, num_epochs=args.epochs) + stage = SyntheticTrainingStage.AFTER_OPTIMIZER_STEP + reporter = RichReporter( + title="nvalchemi synthetic training", + layout="training", + max_scalars=12, + max_plots=4, + plot_height=6, + plot_keys=( + "loss/total", + "loss/validation", + "loss/energy/unweighted", + "loss/forces/unweighted", + "scheduler/lr", + ), + refresh_per_second=args.refresh_per_second, + transient=False, + ) + reporting = ReportingOrchestrator([reporter], stages={stage}, rank_zero_only=True) + + with reporting: + for step in range(1, args.steps + 1): + losses = synthetic_losses(step, args.steps) + parameter.grad = torch.tensor(losses["total"]) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + scheduler.step() + + ctx = build_context( + step=step, + total_steps=args.steps, + epochs=args.epochs, + optimizer=optimizer, + scheduler=scheduler, + workflow=workflow, + ) + if step == 1: + reporting.state.add_message( + "info", + "synthetic warmup finished", + ctx=ctx, + stage=stage, + ) + elif step == math.ceil(args.steps * 0.55): + reporting.state.add_message( + "info", + "validation curve refreshed", + ctx=ctx, + stage=stage, + ) + reporting(ctx, stage) + time.sleep(args.delay) + + if args.final_delay: + time.sleep(args.final_delay) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/intermediate/README.rst b/examples/intermediate/README.rst index bb636e45..ad5da106 100644 --- a/examples/intermediate/README.rst +++ b/examples/intermediate/README.rst @@ -19,3 +19,11 @@ system_id tracking, ConvergedSnapshotHook collecting results. **05 — Safety and Monitoring**: NaNDetectorHook, MaxForceClampHook, EnergyDriftMonitorHook, ProfilerHook — defensive MD patterns. 
+ +**06 — DDP MLP Training**: DDPHook with a simple MLP, dummy AtomicData, +single-node ``torchrun`` launch, and ``auto``/``gloo``/``nccl`` backend +selection. + +**07 — Rich Training Reporting**: Live Rich dashboard driven by synthetic +training losses, validation metrics, progress counters, and learning-rate +scheduler values. diff --git a/nvalchemi/_optional.py b/nvalchemi/_optional.py index 52b308ca..797a790d 100644 --- a/nvalchemi/_optional.py +++ b/nvalchemi/_optional.py @@ -93,6 +93,7 @@ def needs_pymatgen(): PYMATGEN = ("pymatgen", "nvalchemi-toolkit[pymatgen]") MACE = ("mace", "nvalchemi-toolkit[mace]") AIMNET = ("aimnet", "nvalchemi-toolkit[aimnet]") + TENSORBOARD = ("tensorboard", "nvalchemi-toolkit[tensorboard]") def __init__(self, import_name: str, install_target: str) -> None: self.import_name = import_name diff --git a/nvalchemi/_serialization.py b/nvalchemi/_serialization.py new file mode 100644 index 00000000..c8850552 --- /dev/null +++ b/nvalchemi/_serialization.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Shared no-pickle serialization helpers.""" + +from __future__ import annotations + +import importlib +import inspect +from collections.abc import Callable +from functools import lru_cache +from types import NoneType, UnionType +from typing import Annotated, Any, Union, get_args, get_origin + +import torch +from pydantic import BeforeValidator, PlainSerializer + +_TYPE_SERIALIZERS: dict[type, tuple[Callable[[Any], Any], Callable[[Any], Any]]] = {} +"""Registry mapping a type to its ``(serialize, deserialize)`` callable pair.""" + + +def register_type_serializer( + type_: type, + serialize: Callable[[Any], Any], + deserialize: Callable[[Any], Any], +) -> None: + """Register JSON (de)serializers for a custom type. + + Parameters + ---------- + type_ + The Python type to register, for example :class:`torch.dtype`. + serialize + Callable converting a ``type_`` instance to a JSON-safe value. + deserialize + Callable converting the JSON-safe value back into a ``type_`` instance. + """ + _TYPE_SERIALIZERS[type_] = (serialize, deserialize) + + +def _wrap_custom_type(t: type) -> Any: + """Wrap a registered type in an ``Annotated[...]`` with Pydantic hooks.""" + ser, deser = _TYPE_SERIALIZERS[t] + + def _before(v: Any) -> Any: + return v if isinstance(v, t) else deser(v) + + return Annotated[t, BeforeValidator(_before), PlainSerializer(ser)] + + +def _dtype_deserialize(s: Any) -> torch.dtype: + """Rehydrate a :class:`torch.dtype` from its string form with a type guard.""" + if isinstance(s, torch.dtype): + return s + if not isinstance(s, str): + raise TypeError( + f"torch.dtype deserializer expected str, got {type(s).__name__}" + ) + result = getattr(torch, s.removeprefix("torch."), None) + if not isinstance(result, torch.dtype): + raise ValueError( + f"{s!r} does not resolve to a torch.dtype " + "(defense-in-depth against attacker-controlled JSON smuggling " + "non-dtype torch.* attributes)." 
+ ) + return result + + +def _tensor_serialize(t: torch.Tensor) -> dict[str, Any]: + """Serialize a :class:`torch.Tensor` as ``{data, dtype, shape}``.""" + return { + "data": t.detach().cpu().tolist(), + "dtype": str(t.dtype), + "shape": list(t.shape), + } + + +def _tensor_deserialize(v: Any) -> torch.Tensor: + """Rehydrate a :class:`torch.Tensor` from its ``{data, dtype, shape}`` dict.""" + if isinstance(v, torch.Tensor): + return v + if not isinstance(v, dict): + raise TypeError(f"Cannot deserialize torch.Tensor from {type(v).__name__}") + dtype = _dtype_deserialize(v["dtype"]) + out = torch.tensor(v["data"], dtype=dtype) + expected_shape = tuple(v["shape"]) + if tuple(out.shape) != expected_shape: + out = out.reshape(expected_shape) + return out + + +register_type_serializer( + torch.dtype, + serialize=str, + deserialize=_dtype_deserialize, +) +register_type_serializer( + torch.device, + serialize=str, + deserialize=lambda s: s if isinstance(s, torch.device) else torch.device(s), +) +register_type_serializer(torch.Tensor, _tensor_serialize, _tensor_deserialize) + + +@lru_cache(maxsize=None) +def _import_object(path: str) -> Any: + """Import an object identified by a dotted module/attribute path.""" + parts = path.split(".") + module: Any = None + module_depth = 0 + for i in range(1, len(parts)): + try: + module = importlib.import_module(".".join(parts[:i])) + except ModuleNotFoundError: + break + module_depth = i + if module is None: + raise ModuleNotFoundError( + f"Could not import any module prefix of {path!r}. " + "Expected a dotted path like 'pkg.mod.Object' or " + "'pkg.mod.Outer.method'." 
+ ) + obj: Any = module + for part in parts[module_depth:]: + obj = getattr(obj, part) + return obj + + +@lru_cache(maxsize=None) +def _import_cls(cls_path: str) -> type: + """Import the class identified by a dotted path.""" + obj = _import_object(cls_path) + if not isinstance(obj, type): + raise TypeError(f"{cls_path!r} resolved to non-class {obj!r}") + return obj + + +@lru_cache(maxsize=None) +def _import_callable(target_path: str) -> Callable[..., Any]: + """Import the callable identified by a dotted path.""" + obj = _import_object(target_path) + if not callable(obj): + raise TypeError(f"{target_path!r} resolved to non-callable {obj!r}") + return obj + + +def _callable_path_of(target: Callable[..., Any]) -> str: + """Return the canonical dotted path (``module.QualName``) for ``target``.""" + module = getattr(target, "__module__", None) + qualname = getattr(target, "__qualname__", None) + if not module or not qualname or "<locals>" in qualname or "<lambda>" in qualname: + raise TypeError( + f"{target!r} is not an importable callable. Specs require a " + "module-level class, function, staticmethod, or classmethod." 
+ ) + return f"{module}.{qualname}" + + +def _cls_path_of(cls_: type) -> str: + """Return the canonical dotted path (``module.QualName``) for ``cls_``.""" + return _callable_path_of(cls_) + + +@lru_cache(maxsize=None) +def _callable_signature(target: Callable[..., Any]) -> inspect.Signature: + """Return the string-annotation-resolved signature for ``target``.""" + return inspect.signature(target, eval_str=True) + + +@lru_cache(maxsize=None) +def _constructor_signature(cls_: type) -> inspect.Signature: + """Return the string-annotation-resolved constructor signature for ``cls_``.""" + return _callable_signature(cls_) + + +def _extract_init_kwargs_from_attrs(instance: Any) -> dict[str, Any]: + """Extract constructor kwargs from matching attributes on ``instance``.""" + sig = _constructor_signature(type(instance)) + kwargs: dict[str, Any] = {} + for name, param in sig.parameters.items(): + if name == "self" or param.kind in { + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + }: + continue + try: + kwargs[name] = getattr(instance, name) + except AttributeError: + continue + return kwargs + + +def _serialize_type(value: type | None) -> str | None: + """Serialize a class to its dotted path; pass ``None`` through.""" + if value is None: + return None + return _cls_path_of(value) + + +def _validate_type(value: Any) -> Any: + """Accept a ``type`` or dotted-path string; convert strings to classes.""" + if value is None or isinstance(value, type): + return value + if _is_tagged_type(value): + return _deserialize_tagged_type(value) + if isinstance(value, str): + try: + return _import_cls(value) + except (ImportError, AttributeError, TypeError) as exc: + raise ValueError(f"{value!r} must resolve to an importable class.") from exc + return value + + +def _is_class_type_annotation(annotation: Any) -> bool: + """Return whether ``annotation`` accepts a class object.""" + if annotation is type: + return True + return get_origin(annotation) is type + + +def 
_is_optional_class_type_annotation(annotation: Any) -> bool: + """Return whether ``annotation`` accepts a class object or ``None``.""" + origin = get_origin(annotation) + if origin not in {Union, UnionType}: + return False + args = get_args(annotation) + non_none_args = [arg for arg in args if arg is not NoneType] + return len(non_none_args) == 1 and _is_class_type_annotation(non_none_args[0]) + + +def _wrap_class_type_annotation(annotation: Any) -> Any: + """Wrap class-object annotations with dotted-path Pydantic hooks.""" + return Annotated[ + annotation, + BeforeValidator(_validate_type), + PlainSerializer(_serialize_type), + ] + + +def _serialize_tagged_type(value: type) -> dict[str, str]: + """Serialize an inferred class value with an explicit type tag.""" + return {"__type__": _cls_path_of(value)} + + +def _is_tagged_type(value: Any) -> bool: + """Return whether ``value`` is a tagged class serialization payload.""" + return isinstance(value, dict) and set(value) == {"__type__"} + + +def _deserialize_tagged_type(value: Any) -> type: + """Deserialize a tagged class serialization payload.""" + if isinstance(value, type): + return value + if not _is_tagged_type(value): + raise TypeError( + f"tagged type deserializer expected {{'__type__': str}}, " + f"got {type(value).__name__}" + ) + cls_path = value["__type__"] + if not isinstance(cls_path, str): + raise TypeError(f"tagged type path must be str, got {type(cls_path).__name__}") + return _deserialize_type(cls_path) + + +SerializableTaggedClass = Annotated[ + type, + BeforeValidator(_deserialize_tagged_type), + PlainSerializer(_serialize_tagged_type), +] +"""``type`` annotation for inferred class fields using tagged JSON.""" + + +def _is_serializable_class_annotation(annotation: Any) -> bool: + """Return whether ``annotation`` should use class dotted-path hooks.""" + return _is_class_type_annotation(annotation) or _is_optional_class_type_annotation( + annotation + ) + + +def _deserialize_type(value: Any) -> type: 
+ """Deserialize a class object from a dotted path for the type registry.""" + if isinstance(value, type): + return value + if not isinstance(value, str): + raise TypeError( + f"type deserializer expected str or type, got {type(value).__name__}" + ) + try: + return _import_cls(value) + except (ImportError, AttributeError, TypeError) as exc: + raise ValueError( + f"{value!r} is not a dotted path resolving to a class." + ) from exc + + +register_type_serializer( + type, + serialize=_serialize_type, + deserialize=_deserialize_type, +) + + +SerializableClass = Annotated[ + type, + BeforeValidator(_validate_type), + PlainSerializer(_serialize_type), +] +"""``type`` field annotation that round-trips via dotted-path strings.""" + +SerializableOptionalClass = Annotated[ + type | None, + BeforeValidator(_validate_type), + PlainSerializer(_serialize_type), +] +"""``type | None`` field annotation that round-trips via dotted-path strings.""" diff --git a/nvalchemi/data/batch.py b/nvalchemi/data/batch.py index bb33afc9..2140a26b 100644 --- a/nvalchemi/data/batch.py +++ b/nvalchemi/data/batch.py @@ -64,6 +64,121 @@ _OWN_ATTRS = frozenset({"device", "keys", "_storage", "_data_class"}) +def _build_batch_storage( + samples: Iterator[tuple[Iterator[tuple[str, Tensor]], int, int]], + *, + node_keys: frozenset[str] | set[str], + edge_keys: frozenset[str] | set[str], + system_keys: frozenset[str] | set[str], + device: torch.device, + validate: bool, + attr_map: LevelSchema, + field_levels: dict[str, str] | None = None, + fallback_level: str | None = None, +) -> tuple[MultiLevelStorage, dict[str, set[str]]]: + """Shared batch-construction pipeline for from_data_list / from_raw_dicts. + + Parameters + ---------- + samples : Iterator + Yields ``(key_value_pairs, num_nodes, num_edges)`` per sample. + ``key_value_pairs`` is an iterator of ``(key, tensor)`` pairs. + node_keys, edge_keys, system_keys : set-like + Key sets for level classification. 
+ device : torch.device + Target device for tensors. + validate : bool + Whether to validate storage shapes. + attr_map : LevelSchema + Attribute registry. + field_levels : dict[str, str] or None, default=None + Explicit per-field level overrides (``"atom"`` / ``"edge"`` / + ``"system"``), typically from reader metadata. Checked for keys + not found in the static key sets. + fallback_level : str or None, default=None + Level to assign keys not in any key set and not in + *field_levels*. ``"system"`` for raw-dict paths. + ``None`` to silently drop unclassified keys. + + Returns + ------- + tuple[MultiLevelStorage, dict[str, set[str]]] + The constructed storage and tracked key sets. + """ + node_tensors: dict[str, list[Tensor]] = defaultdict(list) + edge_tensors: dict[str, list[Tensor]] = defaultdict(list) + system_tensors: dict[str, list[Tensor]] = defaultdict(list) + node_counts: list[int] = [] + edge_counts: list[int] = [] + + def _classify(key: str) -> str | None: + if key in node_keys: + return "atom" + if key in edge_keys: + return "edge" + if key in system_keys: + return "system" + if field_levels is not None and key in field_levels: + return field_levels[key] + return fallback_level + + node_offset = 0 + for key_value_pairs, n_nodes, n_edges in samples: + node_counts.append(n_nodes) + edge_counts.append(n_edges) + for key, value in key_value_pairs: + level = _classify(key) + if level is None: + continue + value = value.to(device, non_blocking=True) + if level == "atom": + node_tensors[key].append(value) + elif level == "edge": + if key in _INDEX_KEYS: + value = value + node_offset + edge_tensors[key].append(value) + else: + system_tensors[key].append(value) + node_offset += n_nodes + + atoms_data = {k: torch.cat(v, dim=0) for k, v in node_tensors.items()} + edges_data = {k: torch.cat(v, dim=0) for k, v in edge_tensors.items()} + system_data = {k: torch.cat(v, dim=0) for k, v in system_tensors.items()} + + groups: dict[str, UniformLevelStorage | 
SegmentedLevelStorage] = {} + if atoms_data: + groups["atoms"] = SegmentedLevelStorage( + data=atoms_data, + device=device, + segment_lengths=node_counts, + validate=validate, + attr_map=attr_map, + ) + if edges_data: + groups["edges"] = SegmentedLevelStorage( + data=edges_data, + device=device, + segment_lengths=edge_counts, + validate=validate, + attr_map=attr_map, + ) + if system_data: + groups["system"] = UniformLevelStorage( + data=system_data, + device=device, + validate=validate, + attr_map=attr_map, + ) + + storage = MultiLevelStorage(groups=groups, attr_map=attr_map, validate=validate) + tracked_keys = { + "node": set(node_tensors.keys()), + "edge": set(edge_tensors.keys()), + "system": set(system_tensors.keys()), + } + return storage, tracked_keys + + class Batch(DataMixin): """Graph-aware batch built on :class:`MultiLevelStorage`. @@ -277,6 +392,7 @@ def from_data_list( skip_validation: bool = False, attr_map: LevelSchema | None = None, exclude_keys: list[str] | None = None, + field_levels: dict[str, str] | None = None, ) -> Batch: """Construct a batch from a list of :class:`AtomicData` objects. @@ -292,6 +408,10 @@ def from_data_list( Attribute registry. Defaults to ``LevelSchema()``. exclude_keys : list[str], optional Keys to exclude from batching. + field_levels : dict[str, str], optional + Explicit per-field level map (``"atom"`` / ``"edge"`` / + ``"system"``), typically from :attr:`Reader.field_levels`. + Used to classify custom keys not in the data class key sets. 
Returns ------- @@ -309,82 +429,32 @@ def from_data_list( representative = data_list[0] data_cls = representative.__class__ - node_keys = representative.__node_keys__ - edge_keys = representative.__edge_keys__ - system_keys = representative.__system_keys__ + node_key_set = representative.__node_keys__ + edge_key_set = representative.__edge_keys__ + system_key_set = representative.__system_keys__ excluded = _EXCLUDED_KEYS | set(exclude_keys or []) actual_keys = set(data_list[0].model_dump(exclude_none=True).keys()) - excluded - node_tensors: dict[str, list[Tensor]] = defaultdict(list) - edge_tensors: dict[str, list[Tensor]] = defaultdict(list) - system_tensors: dict[str, list[Tensor]] = defaultdict(list) - node_counts: list[int] = [] - edge_counts: list[int] = [] - - node_offset = 0 - for data in data_list: - n_nodes = data.num_nodes - n_edges = data.num_edges - node_counts.append(n_nodes) - edge_counts.append(n_edges) - - for key in actual_keys: - value = getattr(data, key, None) - if not isinstance(value, Tensor): - continue - value = value.to(device) - - if key in node_keys: - node_tensors[key].append(value) - elif key in edge_keys: - if key in _INDEX_KEYS: - value = value + node_offset - edge_tensors[key].append(value) - elif key in system_keys: - system_tensors[key].append(value) - - node_offset += n_nodes - - atoms_data = {k: torch.cat(v, dim=0) for k, v in node_tensors.items()} - edges_data: dict[str, Tensor] = {} - for k, v in edge_tensors.items(): - edges_data[k] = torch.cat(v, dim=0) - system_data = {k: torch.cat(v, dim=0) for k, v in system_tensors.items()} - - validate = not skip_validation - groups: dict[str, UniformLevelStorage | SegmentedLevelStorage] = {} - if atoms_data: - groups["atoms"] = SegmentedLevelStorage( - data=atoms_data, - device=device, - segment_lengths=node_counts, - validate=validate, - attr_map=attr_map, - ) - if edges_data: - groups["edges"] = SegmentedLevelStorage( - data=edges_data, - device=device, - segment_lengths=edge_counts, 
- validate=validate, - attr_map=attr_map, - ) - if system_data: - groups["system"] = UniformLevelStorage( - data=system_data, - device=device, - validate=validate, - attr_map=attr_map, - ) - - storage = MultiLevelStorage(groups=groups, attr_map=attr_map, validate=validate) + def _iter_samples() -> Iterator[tuple[Iterator[tuple[str, Tensor]], int, int]]: + for data in data_list: + pairs = ( + (key, value) + for key in actual_keys + if isinstance((value := getattr(data, key, None)), Tensor) + ) + yield pairs, data.num_nodes, data.num_edges - tracked_keys = { - "node": set(node_tensors.keys()), - "edge": set(edge_tensors.keys()), - "system": set(system_tensors.keys()), - } + storage, tracked_keys = _build_batch_storage( + _iter_samples(), + node_keys=node_key_set, + edge_keys=edge_key_set, + system_keys=system_key_set, + device=device, + validate=not skip_validation, + attr_map=attr_map, + field_levels=field_levels, + ) batch = cls._construct( device=device, keys=tracked_keys, @@ -393,6 +463,100 @@ def from_data_list( ) return batch._make_contiguous() + @classmethod + def from_raw_dicts( + cls, + data_list: list[dict[str, Tensor]], + device: torch.device | str | None = None, + attr_map: LevelSchema | None = None, + exclude_keys: list[str] | None = None, + field_levels: dict[str, str] | None = None, + ) -> Batch: + """Construct a batch directly from raw tensor dictionaries. + + Bypasses :class:`AtomicData` construction and Pydantic validation + entirely, using ``AtomicData._default_*_keys`` for level + classification. Keys not found in the default key sets are + classified using *field_levels* (e.g. from + :attr:`Reader.field_levels`). This is significantly faster when + the data is already known to be well-formed (e.g. read from a + validated Zarr store). + + Parameters + ---------- + data_list : list[dict[str, Tensor]] + Per-sample tensor dictionaries. + device : torch.device | str, optional + Target device. Inferred from first dict if ``None``. 
+ attr_map : LevelSchema, optional + Attribute registry. Defaults to ``LevelSchema()``. + exclude_keys : list[str], optional + Keys to exclude from batching. + field_levels : dict[str, str], optional + Explicit per-field level map (``"atom"`` / ``"edge"`` / + ``"system"``), typically from :attr:`Reader.field_levels`. + Used to classify custom keys not in the default key sets. + + Returns + ------- + Batch + """ + if not data_list: + raise ValueError("Cannot create batch from empty data list") + + first = data_list[0] + if device is None: + for v in first.values(): + if isinstance(v, Tensor): + device = v.device + break + else: + device = torch.device("cpu") + device = torch.device(device) if isinstance(device, str) else device + + if attr_map is None: + attr_map = LevelSchema() + + node_key_set = AtomicData._default_node_keys + edge_key_set = AtomicData._default_edge_keys + system_key_set = AtomicData._default_system_keys + + excluded = _EXCLUDED_KEYS | set(exclude_keys or []) + actual_keys = [ + k for k in first if k not in excluded and isinstance(first[k], Tensor) + ] + + def _iter_samples() -> Iterator[tuple[Iterator[tuple[str, Tensor]], int, int]]: + for data in data_list: + n_nodes = data["atomic_numbers"].shape[0] + nl = data.get("neighbor_list") + n_edges = nl.shape[0] if isinstance(nl, Tensor) else 0 + pairs = ( + (key, value) + for key in actual_keys + if isinstance((value := data.get(key)), Tensor) + ) + yield pairs, n_nodes, n_edges + + storage, tracked_keys = _build_batch_storage( + _iter_samples(), + node_keys=node_key_set, + edge_keys=edge_key_set, + system_keys=system_key_set, + device=device, + validate=False, + attr_map=attr_map, + field_levels=field_levels, + fallback_level="system", + ) + batch = cls._construct( + device=device, + keys=tracked_keys, + storage=storage, + data_class=AtomicData, + ) + return batch._make_contiguous() + @classmethod def empty( cls, @@ -1114,14 +1278,14 @@ def to( dtype : torch.dtype, optional Ignored (present for API 
compatibility). non_blocking : bool - Ignored (present for API compatibility). + Whether tensor copies may be asynchronous when supported. Returns ------- Batch """ new = self.clone() - new._storage.to_device(device) + new._storage.to_device(device, non_blocking=non_blocking) new.device = torch.device(device) if isinstance(device, str) else device return new diff --git a/nvalchemi/data/datapipes/__init__.py b/nvalchemi/data/datapipes/__init__.py index c5257f55..70c0aa6e 100644 --- a/nvalchemi/data/datapipes/__init__.py +++ b/nvalchemi/data/datapipes/__init__.py @@ -31,7 +31,11 @@ [optional] per-sample transforms, prefetch) | DataLoader - (collate Batch, [optional] per-batch transforms, iterate) + (Dataset.load_batches -> Batch, + [optional] per-batch transforms, iteration) + + MultiDataset can wrap several Dataset instances behind one global + index space while preserving the same batch-loading contract. **Writer** (:class:`AtomicDataZarrWriter`) serializes ``AtomicData`` or ``Batch`` objects into a structured Zarr store with CSR-style pointer @@ -45,14 +49,21 @@ handling device transfers and optional CUDA-stream prefetching. It also applies optional per-sample transforms after device transfer; see :class:`~nvalchemi.data.transforms.Compose`, passed via the -``transforms=`` kwarg. +``transforms=`` kwarg. Its canonical explicit batch API is +:meth:`~nvalchemi.data.datapipes.dataset.Dataset.load_batches`, which +uses fused ``read_many`` requests and returns one ``Batch`` per requested +batch-index list. **DataLoader** iterates over a Dataset in batches, collating -``AtomicData`` samples into ``Batch`` objects via -:meth:`~nvalchemi.data.batch.Batch.from_data_list`. Optional per-batch -transforms run on the collated batch; see -:class:`~nvalchemi.data.transforms.Compose`, passed via the -``batch_transforms=`` kwarg. +``AtomicData`` samples into ``Batch`` objects through the Dataset batch +loader. 
Positive ``prefetch_factor`` values fuse several emitted batches +into one background read window. Optional per-batch transforms run on the +collated batch; see :class:`~nvalchemi.data.transforms.Compose`, passed +via the ``batch_transforms=`` kwarg. + +**MultiDataset** composes multiple Dataset instances and routes +``load_batches`` requests to the owning child datasets before restoring +the requested global sample order. """ from __future__ import annotations @@ -66,6 +77,12 @@ ) from nvalchemi.data.datapipes.dataloader import DataLoader from nvalchemi.data.datapipes.dataset import Dataset +from nvalchemi.data.datapipes.multidataset import MultiDataset +from nvalchemi.data.datapipes.samplers import ( + DistributedSamplerProtocol, + MultiDatasetBatchSampler, + MultiDatasetSampler, +) __all__ = [ # Backends @@ -76,5 +93,9 @@ "ZarrWriteConfig", # Pipeline "Dataset", + "MultiDataset", + "DistributedSamplerProtocol", + "MultiDatasetSampler", + "MultiDatasetBatchSampler", "DataLoader", ] diff --git a/nvalchemi/data/datapipes/backends/base.py b/nvalchemi/data/datapipes/backends/base.py index 51c24090..57b2381c 100644 --- a/nvalchemi/data/datapipes/backends/base.py +++ b/nvalchemi/data/datapipes/backends/base.py @@ -18,15 +18,16 @@ import logging from abc import ABC, abstractmethod -from collections.abc import Iterator +from collections.abc import Iterator, Sequence from typing import Any import torch +from physicsnemo.datapipes.readers.base import Reader as PhysicsNeMoReader logger = logging.getLogger(__name__) -class Reader(ABC): +class Reader(PhysicsNeMoReader, ABC): """Abstract base class for data readers. Readers are intentionally simple and transactional: @@ -35,7 +36,10 @@ class Reader(ABC): - Return ``(dict[str, torch.Tensor], metadata_dict)`` tuples with CPU tensors - No threading, no prefetching, no device transfers - Subclasses must implement :meth:`_load_sample` and :meth:`__len__`. 
+ Subclasses must implement :meth:`__len__` and at least one loading hook: + :meth:`_load_sample` for simple single-sample readers, or + :meth:`_load_many_samples` for readers that can amortize I/O across a + group of samples. Parameters ---------- @@ -61,6 +65,7 @@ def __init__( *, pin_memory: bool = False, include_index_in_metadata: bool = True, + coordinated_subsampling: dict[str, Any] | None = None, ) -> None: """Initialize the Reader base class. @@ -71,11 +76,15 @@ def __init__( async CPU→GPU transfers. include_index_in_metadata : bool, default=True If True, automatically add ``"index"`` to each sample's metadata dict. + coordinated_subsampling : dict[str, Any] | None, optional + PhysicsNeMo-compatible coordinated subsampling configuration. """ - self.pin_memory = pin_memory - self.include_index_in_metadata = include_index_in_metadata + super().__init__( + pin_memory=pin_memory, + include_index_in_metadata=include_index_in_metadata, + coordinated_subsampling=coordinated_subsampling, + ) - @abstractmethod def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load raw tensor data for a single sample. @@ -89,7 +98,44 @@ def _load_sample(self, index: int) -> dict[str, torch.Tensor]: dict[str, torch.Tensor] Mapping of field names to CPU tensors. """ - raise NotImplementedError + if type(self)._load_many_samples is not Reader._load_many_samples: + data_dicts = self._load_many_samples([index]) + if len(data_dicts) != 1: + raise RuntimeError( + f"{type(self).__name__}._load_many_samples returned " + f"{len(data_dicts)} samples for one index" + ) + return data_dicts[0] + raise NotImplementedError( + f"{type(self).__name__} must implement _load_sample() or " + "_load_many_samples()." + ) + + def _load_many_samples( + self, indices: Sequence[int] + ) -> list[dict[str, torch.Tensor]]: + """Load raw tensor data for multiple samples. + + The default implementation loops over :meth:`_load_sample`. 
Backends + can override this hook to coalesce physical I/O while keeping + metadata and optional pinned memory in the base class. + + Parameters + ---------- + indices : Sequence[int] + Sample indices to load. + + Returns + ------- + list[dict[str, torch.Tensor]] + Raw tensor dictionaries in requested order. + """ + if type(self)._load_sample is Reader._load_sample: + raise NotImplementedError( + f"{type(self).__name__} must implement _load_sample() or " + "_load_many_samples()." + ) + return [self._load_sample(index) for index in indices] @abstractmethod def __len__(self) -> int: @@ -102,6 +148,22 @@ def __len__(self) -> int: """ raise NotImplementedError + @property + def field_levels(self) -> dict[str, str]: + """Per-field level classification: ``"atom"``, ``"edge"``, or ``"system"``. + + Override in subclasses that store explicit level metadata (e.g. + Zarr stores). The default returns an empty dict, which causes + downstream consumers to fall back to + :data:`AtomicData._default_*_keys` for classification. + + Returns + ------- + dict[str, str] + Mapping of field name to level string. + """ + return {} + def _get_field_names(self) -> list[str]: """Return field names by inspecting the first sample. 
@@ -112,7 +174,7 @@ def _get_field_names(self) -> list[str]: """ if len(self) == 0: return [] - data = self._load_sample(0) + data, _metadata = self.read(0) return list(data.keys()) def _get_sample_metadata(self, index: int) -> dict[str, Any]: @@ -144,16 +206,29 @@ def field_names(self) -> list[str]: """ return self._get_field_names() - def __getitem__(self, index: int) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + def _finalize_sample( + self, index: int, data_dict: dict[str, torch.Tensor] + ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + """Attach metadata and optional pinned memory to loaded sample data.""" + metadata = self._get_sample_metadata(index) + if self.include_index_in_metadata: + metadata.setdefault("index", index) + + if self.pin_memory: + data_dict = {k: v.pin_memory() for k, v in data_dict.items()} + + return data_dict, metadata + + def read(self, index: int) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: """Load a sample and its metadata by index. - Handles negative indexing, bounds checking, optional pin-memory, - and automatic index injection into metadata. + Handles optional pin-memory and automatic index injection into + metadata. Index validity is determined by the concrete reader. Parameters ---------- index : int - Sample index. Negative values are supported. + Sample index. Concrete readers determine supported values. Returns ------- @@ -163,24 +238,67 @@ def __getitem__(self, index: int) -> tuple[dict[str, torch.Tensor], dict[str, An Raises ------ IndexError - If *index* is out of range. + If the concrete reader considers *index* out of range. 
""" - if index < 0: - index = len(self) + index - if index < 0 or index >= len(self): - raise IndexError( - f"Index {index} out of range for reader with {len(self)} samples" + data_dict = self._load_sample(index) + return self._finalize_sample(index, data_dict) + + def read_many( + self, indices: Sequence[int] + ) -> list[tuple[dict[str, torch.Tensor], dict[str, Any]]]: + """Load multiple samples and their metadata. + + The default implementation delegates raw tensor loading to + :meth:`_load_many_samples`, and then attaches metadata and optional + pinned memory. Backend implementations should override + :meth:`_load_many_samples` instead of this method. Index validity is + determined by the concrete reader. + + Parameters + ---------- + indices : Sequence[int] + Sample indices to load. Concrete readers determine supported values. + + Returns + ------- + list[tuple[dict[str, torch.Tensor], dict[str, Any]]] + Ordered ``(data_dict, metadata)`` pairs with CPU tensors. + + Raises + ------ + IndexError + If the concrete reader considers any requested index out of range. + """ + data_dicts = self._load_many_samples(indices) + if len(data_dicts) != len(indices): + raise RuntimeError( + f"{type(self).__name__}._load_many_samples returned " + f"{len(data_dicts)} samples for {len(indices)} indices" ) + return [ + self._finalize_sample(index, data_dict) + for index, data_dict in zip(indices, data_dicts, strict=True) + ] - data_dict = self._load_sample(index) - metadata = self._get_sample_metadata(index) - if self.include_index_in_metadata: - metadata["index"] = index + def __getitem__(self, index: int) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + """Load a sample and its metadata by index. - if self.pin_memory: - data_dict = {k: v.pin_memory() for k, v in data_dict.items()} + Parameters + ---------- + index : int + Sample index. Concrete readers determine supported values. 
- return data_dict, metadata + Returns + ------- + tuple[dict[str, torch.Tensor], dict[str, Any]] + ``(data_dict, metadata)`` pair with CPU tensors. + + Raises + ------ + IndexError + If the concrete reader considers *index* out of range. + """ + return self.read(index) def __iter__(self) -> Iterator[tuple[dict[str, torch.Tensor], dict[str, Any]]]: """Iterate over all samples sequentially. diff --git a/nvalchemi/data/datapipes/backends/zarr.py b/nvalchemi/data/datapipes/backends/zarr.py index 913e3ea9..8b1eb120 100644 --- a/nvalchemi/data/datapipes/backends/zarr.py +++ b/nvalchemi/data/datapipes/backends/zarr.py @@ -31,7 +31,7 @@ from __future__ import annotations import re -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from pathlib import Path from typing import Annotated, Any, Literal, TypeAlias @@ -52,8 +52,6 @@ # Type alias for zarr store-like objects StoreLike: TypeAlias = Store | StorePath | Path | str | dict[str, Any] -# TODO: make classes inherit from PNM when stable - class ZarrArrayConfig(BaseModel): """Configuration for Zarr array compression, chunking, and sharding. @@ -210,6 +208,165 @@ def _get_field_level(key: str) -> str: return "atom" +# --------------------------------------------------------------------------- +# Gap-merge run construction +# --------------------------------------------------------------------------- +# +# Policy: merge adjacent sorted physical indices into contiguous ranges when +# the gap between them is <= *gap_threshold* (defaults to the batch size). +# This reduces the number of Zarr codec-pipeline / shard-index round trips +# — the dominant cost for random access — at the expense of reading some +# unrequested rows ("read amplification"). +# +# To keep amplification bounded, each merged range is capped so that +# span / requested_count <= max_amplification +# where *span* is `last_physical - first_physical + 1` and +# *requested_count* is the number of positions in the run. 
Default +# cap is 8x, meaning we never decompress more than 8x the data we +# actually need within a single range. +_DEFAULT_MAX_AMPLIFICATION: int = 8 + + +def _leading_storage_size(arr: Any) -> int | None: + """Return the leading Zarr storage-object length when available.""" + metadata = getattr(arr, "metadata", None) + chunk_grid = getattr(metadata, "chunk_grid", None) + chunk_shape = getattr(chunk_grid, "chunk_shape", None) + if chunk_shape is not None and len(chunk_shape) > 0: + return int(chunk_shape[0]) + + shards = getattr(arr, "shards", None) + if shards is not None and len(shards) > 0 and shards[0] is not None: + return int(shards[0]) + + chunks = getattr(arr, "chunks", None) + if chunks is not None and len(chunks) > 0 and chunks[0] is not None: + return int(chunks[0]) + + return None + + +def _chunk_span_for_slice( + start: int, end: int, chunk_size: int +) -> tuple[int, int] | None: + """Return inclusive leading-axis chunk span for a half-open row slice.""" + if end <= start: + return None + return start // chunk_size, (end - 1) // chunk_size + + +def _sample_chunk_spans( + physical_idx: int, + fields: Sequence[tuple[str, str, Any]], + atoms_ptr: torch.Tensor, + edges_ptr: torch.Tensor, +) -> list[tuple[int, int, int]]: + """Return per-field chunk spans touched by one physical sample.""" + spans: list[tuple[int, int, int]] = [] + + atom_start = int(atoms_ptr[physical_idx].item()) + atom_end = int(atoms_ptr[physical_idx + 1].item()) + edge_start = int(edges_ptr[physical_idx].item()) + edge_end = int(edges_ptr[physical_idx + 1].item()) + + for field_idx, (_key, level, arr) in enumerate(fields): + if level == "atom": + start, end = atom_start, atom_end + elif level == "edge": + start, end = edge_start, edge_end + else: + continue + + chunk_size = _leading_storage_size(arr) + if chunk_size is None or chunk_size <= 0: + continue + chunk_span = _chunk_span_for_slice(start, end, chunk_size) + if chunk_span is not None: + spans.append((field_idx, *chunk_span)) 
+ + return spans + + +def _spans_overlap( + run_spans: Mapping[int, tuple[int, int]], + sample_spans: Sequence[tuple[int, int, int]], +) -> bool: + """Return True when a sample touches a chunk already covered by a run.""" + for field_idx, first, last in sample_spans: + if field_idx not in run_spans: + continue + run_first, run_last = run_spans[field_idx] + if first <= run_last and last >= run_first: + return True + return False + + +def _merge_chunk_spans( + run_spans: dict[int, tuple[int, int]], + sample_spans: Sequence[tuple[int, int, int]], +) -> None: + """Extend run chunk spans in-place with spans from another sample.""" + for field_idx, first, last in sample_spans: + if field_idx not in run_spans: + run_spans[field_idx] = (first, last) + continue + run_first, run_last = run_spans[field_idx] + run_spans[field_idx] = (min(run_first, first), max(run_last, last)) + + +def _merge_physical_runs_by_chunks( + sorted_physical: Sequence[int], + fields: Sequence[tuple[str, str, Any]], + atoms_ptr: torch.Tensor, + edges_ptr: torch.Tensor, + *, + max_amplification: int = _DEFAULT_MAX_AMPLIFICATION, +) -> list[list[int]]: + """Group physical indices while preserving Zarr chunk locality.""" + if not sorted_physical: + return [] + + gap_threshold = max(len(sorted_physical), 1) + runs: list[list[int]] = [[0]] + run_first_physical = sorted_physical[0] + sample_spans = [ + _sample_chunk_spans(physical_idx, fields, atoms_ptr, edges_ptr) + for physical_idx in sorted_physical + ] + run_spans: dict[int, tuple[int, int]] = {} + _merge_chunk_spans(run_spans, sample_spans[0]) + + for position in range(1, len(sorted_physical)): + gap = sorted_physical[position] - sorted_physical[position - 1] + span = sorted_physical[position] - run_first_physical + 1 + count = len(runs[-1]) + 1 + within_gap_policy = gap <= gap_threshold and span <= count * max_amplification + overlaps_existing_chunk = _spans_overlap(run_spans, sample_spans[position]) + + if overlaps_existing_chunk or 
within_gap_policy: + runs[-1].append(position) + _merge_chunk_spans(run_spans, sample_spans[position]) + else: + runs.append([position]) + run_first_physical = sorted_physical[position] + run_spans = {} + _merge_chunk_spans(run_spans, sample_spans[position]) + + return runs + + +def _row_indices_for_ranges(starts: Sequence[int], ends: Sequence[int]) -> np.ndarray: + """Return concatenated row indices for a sequence of half-open ranges.""" + ranges = [ + np.arange(start, end, dtype=np.int64) + for start, end in zip(starts, ends, strict=True) + if end > start + ] + if not ranges: + return np.empty(0, dtype=np.int64) + return np.concatenate(ranges) + + # NOTE: the generic *index*/*face* regex fallback returning -1 is local to # the Zarr backend. No current AtomicData edge field reaches it, and the Zarr # read paths (_slice_edge_array) reject cat_dim != 0 with a RuntimeError. @@ -1316,8 +1473,32 @@ def refresh(self) -> None: self._root.attrs.get("fields", {"core": {}, "custom": {}}) ) + @property + def field_levels(self) -> dict[str, str]: + """Per-field level classification from store metadata. + + Returns + ------- + dict[str, str] + Mapping of field name to ``"atom"``, ``"edge"``, or ``"system"``. + """ + flat: dict[str, str] = {} + for fields in self._fields_metadata.values(): + flat.update(fields) + return flat + + def _resolve_logical_index(self, index: int) -> int: + """Resolve a logical index according to this store's active sample mask.""" + if index < 0: + index = len(self) + index + if index < 0 or index >= len(self): + raise IndexError( + f"Index {index} out of range for reader with {len(self)} samples" + ) + return index + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: - """Load raw data for a single sample. + """Load raw data for a single sample through the batch read path. Parameters ---------- @@ -1334,59 +1515,200 @@ def _load_sample(self, index: int) -> dict[str, torch.Tensor]: IndexError If index is out of range. 
""" - # Map logical index to physical index - physical_idx = int(self._active_indices[index].item()) + return self._load_many_samples([self._resolve_logical_index(index)])[0] + + def _read_many_orthogonal( + self, + normalized_indices: Sequence[int], + sorted_order: Sequence[int], + sorted_physical: Sequence[int], + fields: Sequence[tuple[str, str, Any]], + ) -> list[dict[str, torch.Tensor]]: + """Load fragmented samples using one orthogonal selection per field.""" + data_by_sorted: list[dict[str, torch.Tensor]] = [{} for _ in sorted_order] + + atom_starts = [] + atom_ends = [] + edge_starts = [] + edge_ends = [] + for physical_idx in sorted_physical: + atom_starts.append(int(self._atoms_ptr[physical_idx].item())) + atom_ends.append(int(self._atoms_ptr[physical_idx + 1].item())) + edge_starts.append(int(self._edges_ptr[physical_idx].item())) + edge_ends.append(int(self._edges_ptr[physical_idx + 1].item())) + + for key, level, arr in fields: + if level == "atom": + rows = _row_indices_for_ranges(atom_starts, atom_ends) + block = torch.from_numpy(arr.oindex[rows] if len(rows) else arr[:0]) + + offset = 0 + for i, (start, end) in enumerate( + zip(atom_starts, atom_ends, strict=True) + ): + count = end - start + data_by_sorted[i][key] = block[offset : offset + count] + offset += count + elif level == "edge": + rows = _row_indices_for_ranges(edge_starts, edge_ends) + block = torch.from_numpy( + arr.oindex[rows] if len(rows) else _slice_edge_array(arr, key, 0, 0) + ) + + offset = 0 + for i, (start, end) in enumerate( + zip(edge_starts, edge_ends, strict=True) + ): + count = end - start + tensor = block[offset : offset + count] + if key == "neighbor_list": + tensor = tensor - atom_starts[i] + data_by_sorted[i][key] = tensor + offset += count + else: + rows = np.asarray(sorted_physical, dtype=np.int64) + block = torch.from_numpy(arr.oindex[rows]) + for i in range(len(sorted_physical)): + data_by_sorted[i][key] = block[i : i + 1] + + inverse = [0] * len(sorted_order) + for 
new_pos, old_pos in enumerate(sorted_order): + inverse[old_pos] = new_pos + + return [data_by_sorted[inverse[i]] for i in range(len(normalized_indices))] + + def _load_many_samples( + self, indices: Sequence[int] + ) -> list[dict[str, torch.Tensor]]: + """Load raw data for multiple samples in requested order. + + Contiguous physical samples are read as ranges so each Zarr array is + opened once and sliced once per range. The range tensors are then + split back into per-sample dictionaries. The base ``Reader`` attaches + metadata and optional pinned memory. + + Parameters + ---------- + indices : Sequence[int] + Logical sample indices to load. Negative values are supported. + + Returns + ------- + list[dict[str, torch.Tensor]] + Ordered raw tensor dictionaries with CPU tensors. + + Raises + ------ + RuntimeError + If the reader has been closed. + IndexError + If any requested index is out of range. + """ + if self._root is None: + raise RuntimeError("Cannot read from a closed reader.") + + normalized_indices = [self._resolve_logical_index(index) for index in indices] + if not normalized_indices: + return [] - # Get slice ranges from pointer arrays - atom_start = int(self._atoms_ptr[physical_idx].item()) - atom_end = int(self._atoms_ptr[physical_idx + 1].item()) - edge_start = int(self._edges_ptr[physical_idx].item()) - edge_end = int(self._edges_ptr[physical_idx + 1].item()) + physical_indices = [ + int(self._active_indices[index].item()) for index in normalized_indices + ] + + # Sort by physical index to maximise contiguous runs and avoid + # decompressing the same Zarr chunk more than once. We keep a + # permutation so the output order matches the caller's request. 
+ sorted_order = sorted( + range(len(physical_indices)), key=physical_indices.__getitem__ + ) + sorted_physical = [physical_indices[i] for i in sorted_order] - data: dict[str, torch.Tensor] = {} + data_by_sorted: list[dict[str, torch.Tensor]] = [{} for _ in sorted_order] - # Load core fields + fields: list[tuple[str, str, Any]] = [] core_group = self._root["core"] for key in core_group.array_keys(): level = self._fields_metadata.get("core", {}).get( key, _get_field_level(key) ) - arr = core_group[key] + fields.append((key, level, core_group[key])) - if level == "atom": - data[key] = torch.from_numpy(arr[atom_start:atom_end]) - elif level == "edge": - tensor = torch.from_numpy( - _slice_edge_array(arr, key, edge_start, edge_end) - ) - - # neighbor_list needs to be converted from global to local indices - # by subtracting the atom offset for this sample - if key == "neighbor_list": - tensor = tensor - atom_start - - data[key] = tensor - else: # system level - # Keep batch dim for system-level fields - data[key] = torch.from_numpy(arr[physical_idx : physical_idx + 1]) - - # Load custom fields if present if "custom" in self._root: custom_group = self._root["custom"] for key in custom_group.array_keys(): level = self._fields_metadata.get("custom", {}).get(key, "system") - arr = custom_group[key] + fields.append((key, level, custom_group[key])) + + run_positions = _merge_physical_runs_by_chunks( + sorted_physical, + fields, + self._atoms_ptr, + self._edges_ptr, + ) + if len(run_positions) > 4: + return self._read_many_orthogonal( + normalized_indices, + sorted_order, + sorted_physical, + fields, + ) + for positions in run_positions: + first_physical = sorted_physical[positions[0]] + last_physical = sorted_physical[positions[-1]] + + atom_range_start = int(self._atoms_ptr[first_physical].item()) + atom_range_end = int(self._atoms_ptr[last_physical + 1].item()) + edge_range_start = int(self._edges_ptr[first_physical].item()) + edge_range_end = 
int(self._edges_ptr[last_physical + 1].item()) + + # Precompute per-position pointer offsets once, shared across + # all fields. Avoids O(B*F) redundant int(..item()) calls. + pos_atom_starts = [] + pos_atom_ends = [] + pos_edge_starts = [] + pos_edge_ends = [] + for position in positions: + pidx = sorted_physical[position] + pos_atom_starts.append(int(self._atoms_ptr[pidx].item())) + pos_atom_ends.append(int(self._atoms_ptr[pidx + 1].item())) + pos_edge_starts.append(int(self._edges_ptr[pidx].item())) + pos_edge_ends.append(int(self._edges_ptr[pidx + 1].item())) + + for key, level, arr in fields: if level == "atom": - data[key] = torch.from_numpy(arr[atom_start:atom_end]) + block = torch.from_numpy(arr[atom_range_start:atom_range_end]) elif level == "edge": - data[key] = torch.from_numpy( - _slice_edge_array(arr, key, edge_start, edge_end) + block = torch.from_numpy( + _slice_edge_array(arr, key, edge_range_start, edge_range_end) ) - else: # system level - data[key] = torch.from_numpy(arr[physical_idx : physical_idx + 1]) + else: + block = torch.from_numpy(arr[first_physical : last_physical + 1]) + + for i, position in enumerate(positions): + data = data_by_sorted[position] - return data + if level == "atom": + rel_start = pos_atom_starts[i] - atom_range_start + rel_end = pos_atom_ends[i] - atom_range_start + data[key] = block[rel_start:rel_end] + elif level == "edge": + rel_start = pos_edge_starts[i] - edge_range_start + rel_end = pos_edge_ends[i] - edge_range_start + tensor = block[rel_start:rel_end] + if key == "neighbor_list": + tensor = tensor - pos_atom_starts[i] + data[key] = tensor + else: + system_offset = sorted_physical[position] - first_physical + data[key] = block[system_offset : system_offset + 1] + + # Map sorted results back to caller's request order. 
+ inverse = [0] * len(sorted_order) + for new_pos, old_pos in enumerate(sorted_order): + inverse[old_pos] = new_pos + + return [data_by_sorted[inverse[i]] for i in range(len(normalized_indices))] def __len__(self) -> int: """Return the number of active (non-deleted) samples. @@ -1398,7 +1720,7 @@ def __len__(self) -> int: """ return len(self._active_indices) - def _get_sample_metadata(self, index: int) -> dict[str, str]: + def _get_sample_metadata(self, index: int) -> dict[str, str | int]: """Return metadata for a sample. Parameters @@ -1411,12 +1733,37 @@ def _get_sample_metadata(self, index: int) -> dict[str, str]: dict[str, str] Dictionary containing source file information. """ + index = self._resolve_logical_index(index) physical_idx = int(self._active_indices[index].item()) return { + "index": index, "source_file": str(self._store), "physical_index": str(physical_idx), } + def get_metadata(self, index: int) -> tuple[int, int]: + """Return atom and edge counts from cached pointer arrays. + + Parameters + ---------- + index : int + Logical sample index. Negative values are supported. + + Returns + ------- + tuple[int, int] + ``(num_atoms, num_edges)`` for the sample. + """ + index = self._resolve_logical_index(index) + physical_idx = int(self._active_indices[index].item()) + num_atoms = int( + (self._atoms_ptr[physical_idx + 1] - self._atoms_ptr[physical_idx]).item() + ) + num_edges = int( + (self._edges_ptr[physical_idx + 1] - self._edges_ptr[physical_idx]).item() + ) + return num_atoms, num_edges + def close(self) -> None: """Release the Zarr store reference and clean up resources.""" self._root = None diff --git a/nvalchemi/data/datapipes/dataloader.py b/nvalchemi/data/datapipes/dataloader.py index c9a048d0..5fbe90c2 100644 --- a/nvalchemi/data/datapipes/dataloader.py +++ b/nvalchemi/data/datapipes/dataloader.py @@ -12,27 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -"""AtomicData-native DataLoader with CUDA-stream prefetching. +"""AtomicData-native DataLoader with amortized prefetching. The ``DataLoader`` class is designed to be a drop-in replacement for ``torch.data.DataLoader``, specializing for ``nvalchemi`` and atomistic systems by emitting ``Batch`` data. -Additionally, the ``DataLoader`` provides two mechanisms for -performant data loading: an asynchronous prefetching mechanism, -as well as the use of CUDA streams; both of which can be used -to develop highly performant data loading and preprocessing -workflows. An optional ``batch_transforms`` hook applies -user-supplied callables to each collated :class:`Batch` on the -consumer thread. +Additionally, the ``DataLoader`` can fuse several emitted batches into one +backend read. ``prefetch_factor`` controls that read window, while optional +CUDA streams can overlap device transfers when available. An optional +``batch_transforms`` hook applies user-supplied callables to each collated +:class:`Batch` on the consumer thread. """ from __future__ import annotations -from collections import deque from collections.abc import Iterator, Sequence +from math import ceil import torch +from physicsnemo.datapipes.dataloader import DataLoader as PhysicsNeMoDataLoader from torch.utils.data import RandomSampler, Sampler, SequentialSampler from nvalchemi._typing import BatchTransform @@ -41,12 +40,13 @@ from nvalchemi.data.transforms import Compose -class DataLoader: +class DataLoader(PhysicsNeMoDataLoader): """Batch-iterating data loader that yields :class:`~nvalchemi.data.batch.Batch`. Wraps a :class:`Dataset` and yields ``Batch`` objects - built via :meth:`Batch.from_data_list`. CUDA-stream prefetching is - supported for overlapping I/O with computation. + built via :meth:`Batch.from_data_list`. 
Fused prefetching is used by + default to amortize I/O across multiple emitted batches; CUDA streams are + supported for overlapping device transfers when available. Parameters ---------- @@ -60,12 +60,19 @@ class DataLoader: Drop the last incomplete batch. sampler : torch.utils.data.Sampler | None, default=None Custom sampler (overrides ``shuffle``). + batch_sampler : torch.utils.data.Sampler | None, default=None + Custom sampler that yields batches of sample indices. prefetch_factor : int, default=2 - How many batches to prefetch ahead. + Number of emitted batches to fuse into each backend read. The effective + read window is ``batch_size * prefetch_factor``. Set to 0 to disable + fused prefetching and read one emitted batch at a time. num_streams : int, default=4 Number of CUDA streams for prefetching. use_streams : bool, default=True Enable CUDA-stream prefetching. + pin_memory : bool, default=False + If True, request page-locked CPU tensors from readers that support + pinned-memory reads. batch_transforms : Sequence[BatchTransform] | None, default=None Optional per-batch transforms applied to each yielded :class:`~nvalchemi.data.batch.Batch` after collation. ``None`` @@ -94,11 +101,14 @@ class DataLoader: Whether stream-based prefetching is actually enabled. Stored as ``use_streams and torch.cuda.is_available()``; reflects runtime availability, not the raw argument. + pin_memory : bool + Whether page-locked CPU tensors are requested from compatible readers. Raises ------ ValueError - Raised at construction if ``batch_size < 1``. + Raised at construction if ``batch_size < 1`` or + ``prefetch_factor < 0``. TypeError Raised at construction if ``batch_transforms`` is not a :class:`~collections.abc.Sequence` (e.g. 
a single callable or a @@ -111,7 +121,7 @@ class DataLoader: Notes ----- Batch transforms run on the consumer (main) thread after - collation, not on the prefetch workers — the fully-assembled + collation, not on the prefetch workers; the fully assembled ``Batch`` does not exist until the main thread constructs it. Transforms are applied in order via :class:`~nvalchemi.data.transforms.Compose` and execute on the @@ -139,14 +149,22 @@ def __init__( shuffle: bool = False, drop_last: bool = False, sampler: Sampler | None = None, + batch_sampler: Sampler[Sequence[int]] | None = None, prefetch_factor: int = 2, num_streams: int = 4, use_streams: bool = True, + pin_memory: bool = False, batch_transforms: Sequence[BatchTransform] | None = None, ) -> None: """Initialize the AtomicData-native DataLoader.""" if batch_size < 1: raise ValueError(f"batch_size must be >= 1, got {batch_size}") + if prefetch_factor < 0: + raise ValueError(f"prefetch_factor must be >= 0, got {prefetch_factor}") + if batch_sampler is not None and (sampler is not None or shuffle): + raise ValueError( + "batch_sampler is mutually exclusive with sampler and shuffle" + ) if batch_transforms is not None and not isinstance(batch_transforms, Sequence): raise TypeError( @@ -156,21 +174,31 @@ def __init__( self.dataset = dataset self.batch_size = batch_size + self.shuffle = shuffle self.drop_last = drop_last self.prefetch_factor = prefetch_factor self.num_streams = num_streams self.use_streams = use_streams and torch.cuda.is_available() + self.batch_sampler = batch_sampler + self.pin_memory = pin_memory + + if pin_memory: + self._set_pin_memory(self.dataset, True) + self._batch_transform: Compose | None = ( Compose(batch_transforms) if batch_transforms else None ) # Handle sampler - if sampler is not None: - self.sampler = sampler - elif shuffle: - self.sampler = RandomSampler(dataset) + if self.batch_sampler is None: + if sampler is not None: + self.sampler = sampler + elif shuffle: + self.sampler = 
RandomSampler(dataset) + else: + self.sampler = SequentialSampler(dataset) else: - self.sampler = SequentialSampler(dataset) + self.sampler = None self._streams: list[torch.cuda.Stream] = ( [torch.cuda.Stream() for _ in range(num_streams)] @@ -178,6 +206,17 @@ def __init__( else [] ) + @staticmethod + def _set_pin_memory(dataset: object, enabled: bool) -> None: + """Request pinned-memory reads from a single dataset when supported.""" + if hasattr(dataset, "pin_memory"): + setattr(dataset, "pin_memory", enabled) + + @property + def effective_read_window(self) -> int: + """Return the maximum sample count in one fused backend read.""" + return self.batch_size * max(self.prefetch_factor, 1) + def __len__(self) -> int: """Return the number of batches. @@ -186,23 +225,26 @@ def __len__(self) -> int: int Number of batches in the dataloader. """ - n_samples = len(self.dataset) + if self.batch_sampler is not None: + return len(self.batch_sampler) # type: ignore[arg-type] + + n_samples = len(self.sampler) if self.sampler is not None else len(self.dataset) if self.drop_last: return n_samples // self.batch_size - return (n_samples + self.batch_size - 1) // self.batch_size + return ceil(n_samples / self.batch_size) def __iter__(self) -> Iterator[Batch]: """Iterate over batches. - Uses stream-based prefetching when enabled to overlap IO, - GPU transfers, and computation. + Uses fused prefetching when ``prefetch_factor`` is positive, with + CUDA streams added when enabled and available. Yields ------ Batch Batched AtomicData as a disjoint graph. """ - if self.prefetch_factor > 0 and self.use_streams: + if self.prefetch_factor > 0: yield from self._iter_prefetch() else: yield from self._iter_simple() @@ -215,7 +257,14 @@ def _generate_batches(self) -> Iterator[list[int]]: list[int] List of sample indices for each batch. 
""" + if self.batch_sampler is not None: + for batch_indices in self.batch_sampler: + yield list(batch_indices) + return + batch: list[int] = [] + if self.sampler is None: + return for idx in self.sampler: batch.append(idx) if len(batch) == self.batch_size: @@ -235,30 +284,31 @@ def _iter_simple(self) -> Iterator[Batch]: """ transform = self._batch_transform for batch_indices in self._generate_batches(): - samples = [self.dataset[idx] for idx in batch_indices] - # Extract AtomicData from (AtomicData, metadata) tuples - data_list = [atomic_data for atomic_data, _ in samples] - batch = Batch.from_data_list(data_list, skip_validation=True) + batch = self.dataset.load_batches([batch_indices])[0] if transform is not None: batch = transform(batch) yield batch def _iter_prefetch(self) -> Iterator[Batch]: - """Iteration with stream-based prefetching. - - Uses a lazy sliding window of size ``prefetch_factor`` over the - batch-index generator so that the full epoch plan is never - materialised in memory. - - Strategy: - - 1. Fill a window of up to ``prefetch_factor`` batches, submitting - each for async prefetch. - 2. Pop the front batch, yield it, then pull one more batch from - the generator and prefetch it (keeping the window full). - 3. Cleanup runs in a ``finally`` block so that - ``cancel_prefetch()`` fires on normal exhaustion, early break, - and exceptions. + """Iteration with fused prefetching. + + Fuses ``prefetch_factor`` consecutive batches into a single + ``read_many`` call so that Zarr reader optimisations can coalesce + scattered indices into fewer large reads. + + Strategy (true double-buffered): + + 1. Collect and submit two chunks upfront so that one Zarr + read is always in flight while the other is being consumed. + 2. Consume the oldest completed chunk, submit a fresh chunk + into the now-free queue slot, then yield the consumed + batches. The next Zarr read runs in the background while + the caller processes each yielded batch. + 3. 
Drain the remaining queued chunk after the sampler is + exhausted. + 4. Cleanup runs in a ``finally`` block so that + ``cancel_prefetch()`` fires on normal exhaustion, early + break, and exceptions. Yields ------ @@ -266,39 +316,59 @@ def _iter_prefetch(self) -> Iterator[Batch]: Collated batch of AtomicData. """ stream_idx = 0 - transform = self._batch_transform - - def _prefetch_batch(batch_indices: list[int]) -> None: - nonlocal stream_idx - for sample_idx in batch_indices: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 - batch_iter = self._generate_batches() - window: deque[list[int]] = deque() + transform = self._batch_transform - try: + def _collect_chunk() -> list[list[int]]: + """Collect up to prefetch_factor batch-index lists.""" + chunk: list[list[int]] = [] for _ in range(self.prefetch_factor): batch_indices = next(batch_iter, None) if batch_indices is None: break - window.append(batch_indices) - _prefetch_batch(batch_indices) - - while window: - batch_indices = window.popleft() - samples = [self.dataset[idx] for idx in batch_indices] - data_list = [atomic_data for atomic_data, _ in samples] - batch = Batch.from_data_list(data_list, skip_validation=True) - if transform is not None: - batch = transform(batch) - yield batch + chunk.append(batch_indices) + return chunk - next_batch = next(batch_iter, None) - if next_batch is not None: - window.append(next_batch) - _prefetch_batch(next_batch) + def _submit_chunk(chunk: list[list[int]]) -> None: + nonlocal stream_idx + stream = ( + self._streams[stream_idx % self.num_streams] if self._streams else None + ) + self.dataset.prefetch_fused_batches(chunk, stream=stream) + stream_idx += 1 + + try: + # Prime: fill both queue slots so one read is always in + # flight while the other is consumed. 
+ chunk_a = _collect_chunk() + if not chunk_a: + return + _submit_chunk(chunk_a) + + chunk_b = _collect_chunk() + if chunk_b: + _submit_chunk(chunk_b) + + while True: + # Consume oldest completed read. + completed_batches = list(self.dataset.get_fused_batches()) + + # Refill: collect and submit next chunk into the freed + # queue slot so the background thread starts reading + # immediately -- *before* we yield any batches. + next_chunk = _collect_chunk() + if next_chunk: + _submit_chunk(next_chunk) + + for batch in completed_batches: + if transform is not None: + batch = transform(batch) + yield batch + + # Stop when both the sampler is exhausted and the + # queue has been drained. + if not next_chunk and not self.dataset.has_pending_fused_batches(): + break finally: self.dataset.cancel_prefetch() @@ -310,5 +380,6 @@ def set_epoch(self, epoch: int) -> None: epoch : int Current epoch number. """ - if hasattr(self.sampler, "set_epoch"): - self.sampler.set_epoch(epoch) + sampler = self.batch_sampler if self.batch_sampler is not None else self.sampler + if hasattr(sampler, "set_epoch"): + sampler.set_epoch(epoch) diff --git a/nvalchemi/data/datapipes/dataset.py b/nvalchemi/data/datapipes/dataset.py index 6e4bfdc4..146e6773 100644 --- a/nvalchemi/data/datapipes/dataset.py +++ b/nvalchemi/data/datapipes/dataset.py @@ -31,14 +31,18 @@ from __future__ import annotations import logging +from collections import deque from collections.abc import Iterator, Sequence from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable import torch +from physicsnemo.datapipes.dataset import Dataset as PhysicsNeMoDataset +from physicsnemo.datapipes.readers.base import Reader as PhysicsNeMoReader from nvalchemi.data.atomic_data import AtomicData +from nvalchemi.data.batch import Batch from nvalchemi.data.datapipes.backends.base import Reader from nvalchemi.data.transforms import Compose @@ 
@dataclass
class _FusedBatchPrefetchResult:
    """Container for fused multi-batch prefetch results.

    Used for both validated (AtomicData) and raw (dict) fused-prefetch
    paths. When ``raw`` is ``True``, ``data`` holds raw tensor dicts
    and ``metadata`` is ``None``.

    Attributes
    ----------
    batch_splits : list[int]
        Number of samples in each sub-batch, used to split
        the flat result list back into per-batch groups.
    raw : bool
        Whether the data contains raw tensor dicts (True) or
        AtomicData objects (False).
    data : list[Any] | None
        Loaded samples in request order, or None on error.
    metadata : list[dict[str, Any]] | None
        Per-sample metadata (validated path only), or None.
    error : Exception | None
        Exception if loading failed, or None.
    event : torch.cuda.Event | None
        CUDA event for stream synchronization, or None.
    """

    # Only ``batch_splits`` is required; every other field defaults to an
    # "empty" value.
    batch_splits: list[int]
    raw: bool = False
    data: list[Any] | None = None
    metadata: list[dict[str, Any]] | None = None
    error: Exception | None = None
    event: torch.cuda.Event | None = None


@dataclass
class _PendingFusedBatch:
    """Queued fused batch request and its submitted future.

    Attributes
    ----------
    batch_index_lists : tuple[tuple[int, ...], ...]
        The request's index lists, one inner tuple per batch, stored as
        immutable tuples.
    future : Future[_FusedBatchPrefetchResult]
        Future that resolves to the fused read result for this request.
    """

    batch_index_lists: tuple[tuple[int, ...], ...]
    future: Future[_FusedBatchPrefetchResult]
""" - # Validate reader implements the required protocol - if not isinstance(reader, (Reader, ReaderProtocol)): + has_batch_reader = hasattr(reader, "read_many") + has_sample_reader = hasattr(reader, "_load_sample") and hasattr( + reader, "_get_sample_metadata" + ) + if not isinstance(reader, (PhysicsNeMoReader, Reader)) and not ( + has_batch_reader or has_sample_reader + ): raise TypeError( f"reader must implement Reader interface, got {type(reader).__name__}" ) @@ -203,29 +258,26 @@ def __init__( "callable or generator. Pass [fn] instead of fn." ) - self.reader = reader - self.num_workers = num_workers - - # Resolve device - if device is not None: - if isinstance(device, str): - device = torch.device(device) - if not isinstance(device, torch.device): - raise TypeError( - "Device expected to be a string or instance of `torch.device`." - f" Got {device}." - ) - self.target_device = device + target_device = self._resolve_target_device(device) + if isinstance(reader, PhysicsNeMoReader): + super().__init__( + reader, + transforms=None, + device=target_device, + num_workers=num_workers, + ) else: - # fallback - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - self.target_device = torch.device(device) + self.reader = reader + self.num_workers = num_workers + self.target_device = target_device + self.transforms = None + + self.skip_validation = skip_validation + self._field_levels: dict[str, str] = getattr(reader, "field_levels", {}) or {} # Prefetch state self._prefetch_futures: dict[int, Future[_PrefetchResult]] = {} + self._fused_batch_prefetch_queue: deque[_PendingFusedBatch] = deque() self._executor: ThreadPoolExecutor | None = None # Per-sample transform pipeline (None when no transforms configured so @@ -234,6 +286,37 @@ def __init__( Compose(transforms) if transforms else None ) + @staticmethod + def _resolve_target_device( + device: str | torch.device | None, + ) -> torch.device: + """Resolve the target device while preserving nvalchemi 
defaults. + + Parameters + ---------- + device : str | torch.device | None + Requested device. ``None`` and ``"auto"`` select CUDA when + available, otherwise CPU. + + Returns + ------- + torch.device + Resolved target device. + + Raises + ------ + TypeError + If *device* is not a string, ``torch.device``, or ``None``. + """ + if device is None or device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + elif not isinstance(device, (str, torch.device)): + raise TypeError( + "Device expected to be a string or instance of `torch.device`." + f" Got {device}." + ) + return torch.device(device) + def _ensure_executor(self) -> ThreadPoolExecutor: """Lazily create the thread pool executor. @@ -249,34 +332,52 @@ def _ensure_executor(self) -> ThreadPoolExecutor: ) return self._executor - def _finalize_on_device( - self, data: AtomicData, metadata: dict[str, Any] - ) -> tuple[AtomicData, dict[str, Any]]: - """Move ``data`` to ``target_device`` and apply the transform pipeline. - - Shared by the prefetch worker path (both stream and non-stream - branches) and the synchronous ``__getitem__`` fallback. When - ``self._sample_transform`` is ``None`` the transform step is - skipped, making the no-transforms hot path a single - ``is None`` check past the device transfer. + def _read_raw_samples( + self, indices: Sequence[int] + ) -> list[tuple[dict[str, torch.Tensor], dict[str, Any]]]: + """Read raw samples from the underlying reader.""" + if hasattr(self.reader, "read_many"): + return self.reader.read_many(indices) # type: ignore[attr-defined] + return [ + ( + self.reader._load_sample(index), # type: ignore[attr-defined] + self.reader._get_sample_metadata(index), # type: ignore[attr-defined] + ) + for index in indices + ] - Parameters - ---------- - data : AtomicData - Freshly constructed sample on the reader's (CPU) device. - metadata : dict[str, Any] - Per-sample metadata dict. 
+ def _to_atomic_samples( + self, + raw_samples: Sequence[tuple[dict[str, torch.Tensor], dict[str, Any]]], + stream: torch.cuda.Stream | None = None, + ) -> tuple[list[tuple[AtomicData, dict[str, Any]]], torch.cuda.Event | None]: + """Validate raw samples and transfer them to the target device.""" + samples = [ + (AtomicData.model_validate(data_dict), metadata) + for data_dict, metadata in raw_samples + ] + + event: torch.cuda.Event | None = None + if stream is not None: + with torch.cuda.stream(stream): + if self.target_device is not None: + samples = [ + (data.to(self.target_device, non_blocking=True), metadata) + for data, metadata in samples + ] + if self._sample_transform is not None: + samples = [ + self._sample_transform(data, metadata) + for data, metadata in samples + ] + event = torch.cuda.Event() + event.record(stream) + else: + samples = [ + self._finalize_on_device(data, metadata) for data, metadata in samples + ] - Returns - ------- - tuple[AtomicData, dict[str, Any]] - The (possibly transformed) pair, ready to return to the caller. - """ - if self.target_device is not None: - data = data.to(self.target_device, non_blocking=True) - if self._sample_transform is not None: - data, metadata = self._sample_transform(data, metadata) - return data, metadata + return samples, event def _load_and_transform( self, @@ -302,28 +403,12 @@ def _load_and_transform( result = _PrefetchResult(index=index) try: - # Load raw dict from reader (CPU, potentially slow IO) - data_dict = self.reader._load_sample(index) - metadata = self.reader._get_sample_metadata(index) - - # Construct AtomicData directly from dict - data = AtomicData.model_validate(data_dict) - - # Device transfer + transform pipeline. On the stream branch the - # helper call stays inside the ``with torch.cuda.stream(stream):`` - # block so any CUDA ops launched by transforms enqueue on the - # prefetch stream. 
``event.record(stream)`` fires after the helper - # so the consumer's ``event.synchronize()`` waits for both. - if stream is not None: - with torch.cuda.stream(stream): - data, metadata = self._finalize_on_device(data, metadata) - result.event = torch.cuda.Event() - result.event.record(stream) - else: - data, metadata = self._finalize_on_device(data, metadata) - - result.data = data - result.metadata = metadata + samples, event = self._to_atomic_samples( + self._read_raw_samples([index]), stream + ) + result.data = samples[0][0] + result.metadata = samples[0][1] + result.event = event except Exception as e: result.error = e @@ -366,6 +451,197 @@ def prefetch_batch( stream = streams[i % len(streams)] if streams else None self.prefetch(idx, stream=stream) + def prefetch_many( + self, indices: Sequence[int], stream: torch.cuda.Stream | None = None + ) -> None: + """Submit one batch of sample indices as a fused async prefetch. + + Parameters + ---------- + indices : Sequence[int] + Sample indices to prefetch as one batch. + stream : torch.cuda.Stream | None, default=None + CUDA stream for GPU operations. + """ + self.prefetch_fused_batches([indices], stream=stream) + + def _load_fused_batches( + self, + batch_index_lists: Sequence[Sequence[int]], + stream: torch.cuda.Stream | None = None, + ) -> _FusedBatchPrefetchResult: + """Load multiple batches in one fused read_many call. + + When ``self.skip_validation`` is ``True``, returns raw tensor + dicts (no ``AtomicData`` construction). Otherwise validates + each sample through ``AtomicData.model_validate``. + + Parameters + ---------- + batch_index_lists : Sequence[Sequence[int]] + Per-batch index lists to concatenate and read together. + stream : torch.cuda.Stream | None, default=None + Optional CUDA stream for GPU operations. + + Returns + ------- + _FusedBatchPrefetchResult + Combined result with batch split metadata. 
+ """ + batch_splits = [len(b) for b in batch_index_lists] + raw = self.skip_validation + result = _FusedBatchPrefetchResult(batch_splits=batch_splits, raw=raw) + + try: + all_indices: list[int] = [] + for batch_indices in batch_index_lists: + all_indices.extend(batch_indices) + + raw_samples = self._read_raw_samples(all_indices) + + if raw: + raw_dicts = [tensor_dict for tensor_dict, _ in raw_samples] + result.data = raw_dicts + result.event = None + else: + samples, event = self._to_atomic_samples(raw_samples, stream) + result.data = [atomic_data for atomic_data, _ in samples] + result.metadata = [metadata for _, metadata in samples] + result.event = event + except Exception as e: + result.error = e + + return result + + def prefetch_fused_batches( + self, + batch_index_lists: Sequence[Sequence[int]], + stream: torch.cuda.Stream | None = None, + ) -> None: + """Submit multiple batches as one fused async read. + + All indices across the provided batch lists are concatenated + into a single ``read_many`` call, amortizing Zarr I/O overhead. + Use :meth:`get_fused_batches` to consume the results. + + Parameters + ---------- + batch_index_lists : Sequence[Sequence[int]] + Per-batch index lists. + stream : torch.cuda.Stream | None, default=None + CUDA stream for GPU operations. + """ + if len(self._fused_batch_prefetch_queue) >= 2: + raise RuntimeError( + "Fused batch prefetch queue is full; consume a pending chunk first." 
+ ) + frozen_batch_index_lists = tuple( + tuple(indices) for indices in batch_index_lists + ) + executor = self._ensure_executor() + self._fused_batch_prefetch_queue.append( + _PendingFusedBatch( + batch_index_lists=frozen_batch_index_lists, + future=executor.submit( + self._load_fused_batches, frozen_batch_index_lists, stream + ), + ) + ) + + def _fused_result_to_batches( + self, result: _FusedBatchPrefetchResult + ) -> list[Batch]: + """Convert a fused prefetch result into per-batch objects.""" + if result.error is not None: + raise result.error + if result.event is not None: + result.event.synchronize() + if result.data is None: + raise RuntimeError("Fused batch prefetch returned None data without error") + + batches: list[Batch] = [] + offset = 0 + for size in result.batch_splits: + batch_slice = result.data[offset : offset + size] + offset += size + if result.raw: + batches.append( + Batch.from_raw_dicts( + batch_slice, + device=self.target_device, + field_levels=self._field_levels, + ) + ) + else: + batches.append( + Batch.from_data_list( + batch_slice, + skip_validation=True, + field_levels=self._field_levels, + ) + ) + return batches + + def load_batches( + self, + batch_index_lists: Sequence[Sequence[int]], + stream: torch.cuda.Stream | None = None, + ) -> list[Batch]: + """Load several batches immediately. + + This is the synchronous counterpart to + :meth:`prefetch_fused_batches`/:meth:`get_fused_batches`. The provided + batch index lists are read through one fused reader request so backends + can coalesce I/O while returning one :class:`Batch` per input list. + + Parameters + ---------- + batch_index_lists : Sequence[Sequence[int]] + Per-batch sample indices. + stream : torch.cuda.Stream | None, default=None + CUDA stream for device transfer when supported. + + Returns + ------- + list[Batch] + One :class:`Batch` per input batch-index list. 
+ """ + return self._fused_result_to_batches( + self._load_fused_batches(batch_index_lists, stream) + ) + + def has_pending_fused_batches(self) -> bool: + """Return whether a fused prefetch chunk is waiting to be consumed.""" + return bool(self._fused_batch_prefetch_queue) + + def get_fused_batches(self) -> Iterator[Batch]: + """Consume the pending fused prefetch and yield per-batch results. + + Blocks until the fused read completes, then splits the flat + result list according to the original batch sizes and yields + one :class:`~nvalchemi.data.batch.Batch` per sub-batch. + + Yields + ------ + Batch + One batch per sub-batch from the fused read. + + Raises + ------ + RuntimeError + If no fused prefetch is pending. + Exception + If the background read failed, re-raises the original error. + """ + if not self._fused_batch_prefetch_queue: + raise RuntimeError( + "No fused batch prefetch pending; call prefetch_fused_batches() " + "before get_fused_batches()." + ) + pending = self._fused_batch_prefetch_queue.popleft() + + yield from self._fused_result_to_batches(pending.future.result()) + def cancel_prefetch(self, index: int | None = None) -> None: """Cancel pending prefetch operations. @@ -376,6 +652,7 @@ def cancel_prefetch(self, index: int | None = None) -> None: """ if index is None: self._prefetch_futures.clear() + self._fused_batch_prefetch_queue.clear() else: self._prefetch_futures.pop(index, None) @@ -427,13 +704,86 @@ def __getitem__(self, index: int) -> tuple[AtomicData, dict[str, Any]]: ) return result.data, result.metadata - # Not prefetched, load synchronously - data_dict = self.reader._load_sample(index) - metadata = self.reader._get_sample_metadata(index) + # Not prefetched, load synchronously through the reader batch path. + raw_samples = self._read_raw_samples([index]) + samples, _ = self._to_atomic_samples(raw_samples) + return samples[0] - # Construct AtomicData directly from dict, then transfer and transform. 
- data = AtomicData.model_validate(data_dict) - return self._finalize_on_device(data, metadata) + def read_many( + self, indices: Sequence[int] + ) -> list[tuple[AtomicData, dict[str, Any]]]: + """Read and validate multiple samples in one dataset request. + + Parameters + ---------- + indices : Sequence[int] + Sample indices to load in order. + + Returns + ------- + list[tuple[AtomicData, dict[str, Any]]] + Ordered ``(AtomicData, metadata)`` pairs. + """ + raw_samples = self._read_raw_samples(indices) + samples, _ = self._to_atomic_samples(raw_samples) + return samples + + def get_batch(self, indices: Sequence[int]) -> Batch: + """Read sample indices and return a validated :class:`Batch`. + + Parameters + ---------- + indices : Sequence[int] + Sample indices to batch in order. + + Returns + ------- + Batch + Batched AtomicData as a disjoint graph. + """ + key = (tuple(indices),) + if ( + self._fused_batch_prefetch_queue + and self._fused_batch_prefetch_queue[0].batch_index_lists == key + ): + pending = self._fused_batch_prefetch_queue.popleft() + batches = self._fused_result_to_batches(pending.future.result()) + if len(batches) != 1: + raise RuntimeError( + f"Prefetch for indices {key[0]} returned {len(batches)} batches" + ) + return batches[0] + + return self.load_batches([indices])[0] + + def _finalize_on_device( + self, data: AtomicData, metadata: dict[str, Any] + ) -> tuple[AtomicData, dict[str, Any]]: + """Move ``data`` to ``target_device`` and apply the transform pipeline. + + Shared by the prefetch worker path (both stream and non-stream + branches) and the synchronous ``__getitem__`` fallback. When + ``self._sample_transform`` is ``None`` the transform step is + skipped, making the no-transforms hot path a single + ``is None`` check past the device transfer. + + Parameters + ---------- + data : AtomicData + Freshly constructed sample on the reader's (CPU) device. + metadata : dict[str, Any] + Per-sample metadata dict. 
+ + Returns + ------- + tuple[AtomicData, dict[str, Any]] + The (possibly transformed) pair, ready to return to the caller. + """ + if self.target_device is not None: + data = data.to(self.target_device, non_blocking=True) + if self._sample_transform is not None: + data, metadata = self._sample_transform(data, metadata) + return data, metadata def __len__(self) -> int: """Return the number of samples in the dataset. @@ -445,10 +795,60 @@ def __len__(self) -> int: """ return len(self.reader) + @property + def pin_memory(self) -> bool: + """Whether the underlying reader should return pinned CPU tensors.""" + return bool(getattr(self.reader, "pin_memory", False)) + + @pin_memory.setter + def pin_memory(self, enabled: bool) -> None: + """Request pinned-memory reads from the underlying reader. + + Parameters + ---------- + enabled : bool + Whether reader outputs should be page-locked. + """ + if hasattr(self.reader, "pin_memory"): + self.reader.pin_memory = enabled + + @property + def prefetch_count(self) -> int: + """Return the number of pending prefetch requests. + + Returns + ------- + int + Count of queued single-sample and fused-batch prefetches. + """ + return len(self._prefetch_futures) + len(self._fused_batch_prefetch_queue) + + @property + def field_names(self) -> list[str]: + """Return field names available in reader samples. + + Returns + ------- + list[str] + Field names exposed by the backing reader. + """ + field_names = getattr(self.reader, "field_names", None) + if field_names is not None: + return list(field_names) + + if len(self) == 0: + return [] + raw_samples = self._read_raw_samples([0]) + if not raw_samples: + return [] + data_dict, _metadata = raw_samples[0] + return list(data_dict) + def get_metadata(self, index: int) -> tuple[int, int]: """Return lightweight metadata for a sample without full construction. 
- Loads the raw tensor dictionary from the reader and extracts shape + Delegates to the reader when it provides lightweight metadata; + otherwise loads the raw tensor dictionary and extracts shape information for atom and edge counts, avoiding the overhead of full ``AtomicData`` construction and validation. @@ -469,7 +869,10 @@ def get_metadata(self, index: int) -> tuple[int, int]: KeyError If the sample dict does not contain ``"atomic_numbers"``. """ - data_dict = self.reader._load_sample(index) + if hasattr(self.reader, "get_metadata"): + return self.reader.get_metadata(index) # type: ignore[attr-defined] + + data_dict, _metadata = self._read_raw_samples([index])[0] num_atoms = len(data_dict["atomic_numbers"]) num_edges = 0 if "neighbor_list" in data_dict and data_dict["neighbor_list"] is not None: @@ -494,12 +897,17 @@ def close(self) -> None: executor, and closes the underlying reader. """ # Drain pending futures - for future in self._prefetch_futures.values(): + futures_to_drain: list[Future] = [ + *self._prefetch_futures.values(), + *[pending.future for pending in self._fused_batch_prefetch_queue], + ] + for future in futures_to_drain: try: future.result(timeout=1.0) except Exception: logger.debug("Ignoring error during prefetch future cleanup") self._prefetch_futures.clear() + self._fused_batch_prefetch_queue.clear() # Shutdown executor if self._executor is not None: diff --git a/nvalchemi/data/datapipes/multidataset.py b/nvalchemi/data/datapipes/multidataset.py new file mode 100644 index 00000000..a30305e4 --- /dev/null +++ b/nvalchemi/data/datapipes/multidataset.py @@ -0,0 +1,624 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Compose multiple AtomicData-native datasets behind one index space.""" + +from __future__ import annotations + +import logging +from bisect import bisect_right +from collections import deque +from collections.abc import Iterator, Sequence +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any + +import torch +from physicsnemo.datapipes.multi_dataset import ( + DATASET_INDEX_METADATA_KEY, +) +from physicsnemo.datapipes.multi_dataset import ( + MultiDataset as PhysicsNeMoMultiDataset, +) + +from nvalchemi.data.atomic_data import AtomicData +from nvalchemi.data.batch import Batch +from nvalchemi.data.datapipes.dataset import Dataset + +logger = logging.getLogger(__name__) + + +@dataclass +class _FusedBatchResult: + """Container for async multidataset fused-batch results.""" + + batches: list[Batch] | None = None + error: Exception | None = None + + +@dataclass +class _DelegatedFusedBatch: + """Marker for fused reads delegated to one child dataset.""" + + dataset_index: int + + +@dataclass +class _ChildFusedBatchRequest: + """Per-child route for one mixed multidataset fused read.""" + + output_batch_indices: list[int] + local_batch_lists: list[list[int]] + output_positions: list[list[int]] + + +@dataclass +class _BatchRoute: + """Route for one child dataset within a global batch request.""" + + dataset_index: int + local_indices: list[int] + positions: list[int] + + +@dataclass +class _BatchRoutePlan: + """Child-dataset routes for one global sample request.""" + + routes: list[_BatchRoute] + 
size: int + + @property + def single_route(self) -> _BatchRoute | None: + """Return the only route when all samples belong to one child.""" + return self.routes[0] if len(self.routes) == 1 else None + + +PendingFusedBatch = Future[_FusedBatchResult] | _DelegatedFusedBatch + + +class MultiDataset(PhysicsNeMoMultiDataset): + """Compose multiple :class:`Dataset` instances behind one index space. + + The class follows PhysicsNeMo's ``MultiDataset`` contract for indexing and + prefetching, while adding nvalchemi-specific batch APIs used by + :class:`~nvalchemi.data.datapipes.dataloader.DataLoader`. + + Parameters + ---------- + *datasets : Dataset + One or more nvalchemi datasets. Order defines the global index mapping. + output_strict : bool, default=True + If True, require all datasets to expose identical field names. + num_workers : int, default=2 + Thread pool size for mixed-dataset fused prefetches. + """ + + def __init__( + self, + *datasets: Dataset, + output_strict: bool = True, + num_workers: int = 2, + ) -> None: + """Initialize the multidataset wrapper. + + Parameters + ---------- + *datasets : Dataset + Datasets to concatenate. + output_strict : bool, default=True + Require matching field names across datasets. + num_workers : int, default=2 + Worker count for mixed-dataset fused prefetches. + + Raises + ------ + TypeError + If any child is not a nvalchemi Dataset. + ValueError + If no datasets are provided or strict field names differ. 
+ """ + if len(datasets) < 1: + raise ValueError( + f"MultiDataset requires at least one dataset, got {len(datasets)}" + ) + for i, dataset in enumerate(datasets): + if not isinstance(dataset, Dataset): + raise TypeError( + f"datasets[{i}] must be a Dataset instance, got {type(dataset).__name__}" + ) + + self._datasets = list(datasets) + self._output_strict = output_strict + self.num_workers = num_workers + + cumulative_lengths = [0] + for dataset in self._datasets: + cumulative_lengths.append(cumulative_lengths[-1] + len(dataset)) + self._cumul = cumulative_lengths + + self._field_names = self.validate_field_names(output_strict) + self._fused_batch_prefetch_queue: deque[PendingFusedBatch] = deque() + self._executor: ThreadPoolExecutor | None = None + + def validate_field_names(self, output_strict: bool | None = None) -> list[str]: + """Validate and return the field names exposed by this wrapper. + + Parameters + ---------- + output_strict : bool | None, default=None + Strictness mode to use for validation. ``None`` uses the mode passed + to :class:`MultiDataset` at construction time. + + Returns + ------- + list[str] + Field names this multidataset exposes. + + Raises + ------ + ValueError + If ``output_strict=True`` and non-empty child datasets expose + different field names. + + Notes + ----- + With ``output_strict=True``, all non-empty child datasets must expose + identical field names. Empty children are skipped, matching + PhysicsNeMo's ``MultiDataset`` strict-output behavior. + + With ``output_strict=False``, no cross-dataset validation is performed + and the first child dataset's field names are returned. Use this mode + for heterogeneous datasets where a custom training loop or collator + handles source-specific fields. 
+ """ + if output_strict is None: + output_strict = self._output_strict + if not output_strict: + return list(self._datasets[0].field_names) + + reference: list[str] | None = None + reference_index: int | None = None + for i, dataset in enumerate(self._datasets): + if len(dataset) == 0: + continue + + current = list(dataset.field_names) + if reference is None: + reference = current + reference_index = i + continue + + reference_set = set(reference) + field_names = set(dataset.field_names) + if field_names != reference_set: + raise ValueError( + "output_strict=True requires identical field names across " + f"datasets: dataset {reference_index} has {sorted(reference_set)}, " + f"dataset {i} has {sorted(field_names)}" + ) + return ( + reference if reference is not None else list(self._datasets[0].field_names) + ) + + def _ensure_executor(self) -> ThreadPoolExecutor: + """Lazily create the thread pool executor.""" + if self._executor is None: + self._executor = ThreadPoolExecutor( + max_workers=self.num_workers, + thread_name_prefix="multidataset_prefetch", + ) + return self._executor + + def _index_to_dataset_and_local(self, index: int) -> tuple[int, int]: + """Map a global index to ``(dataset_index, local_index)``.""" + length = len(self) + original_index = index + if index < 0: + index += length + if index < 0 or index >= length: + raise IndexError( + f"Index {original_index} out of range for MultiDataset with {length} samples" + ) + + dataset_index = bisect_right(self._cumul, index) - 1 + return dataset_index, index - self._cumul[dataset_index] + + def _index_to_dataset_and_local_optional( + self, index: int + ) -> tuple[int, int] | None: + """Map a global index, returning None when it is out of range.""" + try: + return self._index_to_dataset_and_local(index) + except IndexError: + return None + + @staticmethod + def _with_dataset_metadata( + metadata: dict[str, Any], dataset_index: int + ) -> dict[str, Any]: + """Return metadata annotated with its source dataset 
index.""" + enriched = dict(metadata) + enriched[DATASET_INDEX_METADATA_KEY] = dataset_index + return enriched + + def _route_indices(self, indices: Sequence[int]) -> _BatchRoutePlan: + """Plan child-dataset reads for a global sample request.""" + grouped_indices: dict[int, list[int]] = {} + grouped_positions: dict[int, list[int]] = {} + for position, index in enumerate(indices): + dataset_index, local_index = self._index_to_dataset_and_local(index) + grouped_indices.setdefault(dataset_index, []).append(local_index) + grouped_positions.setdefault(dataset_index, []).append(position) + + return _BatchRoutePlan( + routes=[ + _BatchRoute( + dataset_index=dataset_index, + local_indices=local_indices, + positions=grouped_positions[dataset_index], + ) + for dataset_index, local_indices in grouped_indices.items() + ], + size=len(indices), + ) + + @staticmethod + def _combine_child_batches(parts: list[tuple[list[int], Batch]]) -> Batch: + """Append child batch parts and restore the original sample order.""" + if not parts: + raise ValueError("MultiDataset.load_batches() requires non-empty batches") + + combined_positions = list(parts[0][0]) + combined = parts[0][1] + if combined.num_graphs != len(combined_positions): + raise RuntimeError( + "Child dataset returned a batch with " + f"{combined.num_graphs} graphs for {len(combined_positions)} indices" + ) + + if len(parts) > 1: + combined = combined.clone() + for positions, child_batch in parts[1:]: + if child_batch.num_graphs != len(positions): + raise RuntimeError( + "Child dataset returned a batch with " + f"{child_batch.num_graphs} graphs for {len(positions)} indices" + ) + combined.append(child_batch) + combined_positions.extend(positions) + + restore_order = [ + combined_index + for combined_index, _position in sorted( + enumerate(combined_positions), key=lambda item: item[1] + ) + ] + if restore_order == list(range(len(restore_order))): + return combined + return combined.index_select(restore_order) + + def 
__len__(self) -> int: + """Return the total number of samples.""" + return self._cumul[-1] + + @property + def datasets(self) -> tuple[Dataset, ...]: + """Child datasets in global index order.""" + return tuple(self._datasets) + + @property + def offsets(self) -> tuple[int, ...]: + """Cumulative global index offsets for child datasets.""" + return tuple(self._cumul) + + def to_global_index(self, dataset_index: int, local_index: int) -> int: + """Map a child dataset index and local index to one global index.""" + if dataset_index < 0: + dataset_index += len(self._datasets) + if dataset_index < 0 or dataset_index >= len(self._datasets): + raise IndexError( + f"dataset_index {dataset_index} out of range for " + f"{len(self._datasets)} child datasets" + ) + + child_length = len(self._datasets[dataset_index]) + original_local_index = local_index + if local_index < 0: + local_index += child_length + if local_index < 0 or local_index >= child_length: + raise IndexError( + f"local_index {original_local_index} out of range for " + f"dataset {dataset_index} with {child_length} samples" + ) + return self._cumul[dataset_index] + local_index + + def to_local_index(self, index: int) -> tuple[int, int]: + """Map one global index to ``(dataset_index, local_index)``.""" + return self._index_to_dataset_and_local(index) + + def __getitem__(self, index: int) -> tuple[AtomicData, dict[str, Any]]: + """Return one sample by global index.""" + dataset_index, local_index = self._index_to_dataset_and_local(index) + data, metadata = self._datasets[dataset_index][local_index] + return data, self._with_dataset_metadata(metadata, dataset_index) + + def prefetch(self, index: int, stream: torch.cuda.Stream | None = None) -> None: + """Start prefetching one sample by global index.""" + dataset_index, local_index = self._index_to_dataset_and_local(index) + self._datasets[dataset_index].prefetch(local_index, stream=stream) + + def prefetch_batch( + self, + indices: Sequence[int], + streams: 
Sequence[torch.cuda.Stream] | None = None, + ) -> None: + """Start prefetching multiple samples by global index.""" + for i, index in enumerate(indices): + stream = streams[i % len(streams)] if streams else None + self.prefetch(index, stream=stream) + + def _local_batch_lists_if_single_dataset( + self, batch_index_lists: Sequence[Sequence[int]] + ) -> tuple[int, list[list[int]]] | None: + """Return local batch lists when a fused chunk belongs to one child.""" + dataset_index: int | None = None + local_batch_lists: list[list[int]] = [] + for batch_indices in batch_index_lists: + local_batch: list[int] = [] + for index in batch_indices: + current_dataset_index, local_index = self._index_to_dataset_and_local( + index + ) + if dataset_index is None: + dataset_index = current_dataset_index + elif current_dataset_index != dataset_index: + return None + local_batch.append(local_index) + local_batch_lists.append(local_batch) + + if dataset_index is None: + return None + return dataset_index, local_batch_lists + + def _child_fused_batch_requests( + self, batch_index_lists: Sequence[Sequence[int]] + ) -> dict[int, _ChildFusedBatchRequest]: + """Build per-child fused-batch routes for a mixed global chunk.""" + requests: dict[int, _ChildFusedBatchRequest] = {} + for output_batch_index, batch_indices in enumerate(batch_index_lists): + if not batch_indices: + raise ValueError("Fused batch prefetch does not support empty batches") + + route_plan = self._route_indices(batch_indices) + for route in route_plan.routes: + request = requests.setdefault( + route.dataset_index, + _ChildFusedBatchRequest( + output_batch_indices=[], + local_batch_lists=[], + output_positions=[], + ), + ) + request.output_batch_indices.append(output_batch_index) + request.local_batch_lists.append(route.local_indices) + request.output_positions.append(route.positions) + return requests + + def _load_fused_batches( + self, + batch_index_lists: Sequence[Sequence[int]], + stream: torch.cuda.Stream | None = 
None, + ) -> _FusedBatchResult: + """Load multiple global batches by grouping reads per child dataset.""" + try: + routed_requests = self._child_fused_batch_requests(batch_index_lists) + batch_parts: list[list[tuple[list[int], Batch]]] = [ + [] for _ in batch_index_lists + ] + + for dataset_index, request in routed_requests.items(): + child_batches = self._datasets[dataset_index].load_batches( + request.local_batch_lists, stream=stream + ) + if len(child_batches) != len(request.local_batch_lists): + raise RuntimeError( + f"Dataset {dataset_index} returned {len(child_batches)} " + f"batches for {len(request.local_batch_lists)} fused requests" + ) + for output_batch_index, positions, child_batch in zip( + request.output_batch_indices, + request.output_positions, + child_batches, + strict=True, + ): + batch_parts[output_batch_index].append((positions, child_batch)) + + batches = [self._combine_child_batches(parts) for parts in batch_parts] + return _FusedBatchResult(batches=batches) + except Exception as e: + return _FusedBatchResult(error=e) + + def prefetch_fused_batches( + self, + batch_index_lists: Sequence[Sequence[int]], + stream: torch.cuda.Stream | None = None, + ) -> None: + """Submit multiple global batches as one fused async read.""" + if len(self._fused_batch_prefetch_queue) >= 2: + raise RuntimeError( + "Fused batch prefetch queue is full; consume a pending chunk first." 
+ ) + + local = self._local_batch_lists_if_single_dataset(batch_index_lists) + if local is not None: + dataset_index, local_batch_lists = local + self._datasets[dataset_index].prefetch_fused_batches( + local_batch_lists, stream=stream + ) + self._fused_batch_prefetch_queue.append( + _DelegatedFusedBatch(dataset_index=dataset_index) + ) + return + + executor = self._ensure_executor() + self._fused_batch_prefetch_queue.append( + executor.submit(self._load_fused_batches, batch_index_lists, stream) + ) + + def load_batches( + self, + batch_index_lists: Sequence[Sequence[int]], + stream: torch.cuda.Stream | None = None, + ) -> list[Batch]: + """Load several global batches immediately. + + This is the synchronous counterpart to + :meth:`prefetch_fused_batches`/:meth:`get_fused_batches`. Same-child + chunks are delegated directly to the owning child dataset, while mixed + chunks are routed per child and recombined in the requested batch order. + + Parameters + ---------- + batch_index_lists : Sequence[Sequence[int]] + Per-batch global sample indices. + stream : torch.cuda.Stream | None, default=None + CUDA stream for child dataset transfers when supported. + + Returns + ------- + list[Batch] + One :class:`Batch` per input batch-index list. 
+ """ + local = self._local_batch_lists_if_single_dataset(batch_index_lists) + if local is not None: + dataset_index, local_batch_lists = local + return self._datasets[dataset_index].load_batches( + local_batch_lists, stream=stream + ) + + result = self._load_fused_batches(batch_index_lists, stream=stream) + if result.error is not None: + raise result.error + if result.batches is None: + raise RuntimeError( + "MultiDataset fused batch load returned None batches without error" + ) + return result.batches + + def has_pending_fused_batches(self) -> bool: + """Return whether a fused prefetch chunk is waiting to be consumed.""" + return bool(self._fused_batch_prefetch_queue) + + def get_fused_batches(self) -> Iterator[Batch]: + """Consume one pending fused prefetch chunk.""" + if not self._fused_batch_prefetch_queue: + raise RuntimeError( + "No fused batch prefetch pending; call prefetch_fused_batches() " + "before get_fused_batches()." + ) + + pending = self._fused_batch_prefetch_queue.popleft() + if isinstance(pending, _DelegatedFusedBatch): + yield from self._datasets[pending.dataset_index].get_fused_batches() + return + + result = pending.result() + if result.error is not None: + raise result.error + if result.batches is None: + raise RuntimeError( + "MultiDataset fused batch prefetch returned None batches without error" + ) + yield from result.batches + + def cancel_prefetch(self, index: int | None = None) -> None: + """Cancel prefetch for one global index or all child datasets.""" + if index is None: + self._fused_batch_prefetch_queue.clear() + for dataset in self._datasets: + dataset.cancel_prefetch() + return + + mapped = self._index_to_dataset_and_local_optional(index) + if mapped is None: + return + + dataset_index, local_index = mapped + self._datasets[dataset_index].cancel_prefetch(local_index) + + @property + def prefetch_count(self) -> int: + """Return queued prefetch count across this wrapper and children.""" + return len(self._fused_batch_prefetch_queue) 
+ sum( + dataset.prefetch_count for dataset in self._datasets + ) + + @property + def field_names(self) -> list[str]: + """Return field names exposed by child datasets.""" + return list(self._field_names) + + def get_metadata(self, index: int) -> tuple[int, int]: + """Return lightweight metadata for a sample by global index.""" + dataset_index, local_index = self._index_to_dataset_and_local(index) + return self._datasets[dataset_index].get_metadata(local_index) + + def __iter__(self) -> Iterator[tuple[AtomicData, dict[str, Any]]]: + """Iterate over all samples in global index order.""" + for index in range(len(self)): + yield self[index] + + def close(self) -> None: + """Close all child datasets and release wrapper resources.""" + futures_to_drain: list[Future] = [ + *[ + pending + for pending in self._fused_batch_prefetch_queue + if not isinstance(pending, _DelegatedFusedBatch) + ], + ] + for future in futures_to_drain: + try: + future.result(timeout=1.0) + except Exception: + logger.debug("Ignoring error during multidataset prefetch cleanup") + + self._fused_batch_prefetch_queue.clear() + + if self._executor is not None: + self._executor.shutdown(wait=False) + self._executor = None + + for dataset in self._datasets: + dataset.close() + + def __enter__(self) -> MultiDataset: + """Enter context manager.""" + return self + + def __exit__( + self, exc_type: type | None, exc_val: BaseException | None, exc_tb: Any + ) -> None: + """Exit context manager.""" + self.close() + + def __repr__(self) -> str: + """Return a human-readable representation.""" + parts = [f" ({i}): {dataset}" for i, dataset in enumerate(self._datasets)] + return ( + f"{self.__class__.__name__}(\n" + f" output_strict={self._output_strict},\n" + f" datasets=[\n" + ",\n".join(parts) + "\n ]\n)" + ) diff --git a/nvalchemi/data/datapipes/samplers.py b/nvalchemi/data/datapipes/samplers.py new file mode 100644 index 00000000..ce262860 --- /dev/null +++ b/nvalchemi/data/datapipes/samplers.py @@ -0,0 +1,756 
@@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Samplers for datasets composed with :class:`MultiDataset`.""" + +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from math import ceil +from numbers import Integral, Real +from typing import TYPE_CHECKING, Literal, Protocol, Self, TypeAlias, runtime_checkable + +import torch +from torch.utils.data import Sampler + +from nvalchemi.data.datapipes.multidataset import MultiDataset + +if TYPE_CHECKING: + from nvalchemi.distributed import DistributedManager + +EpochPolicy: TypeAlias = Literal["dataset_size", "min_size", "max_size"] + + +@runtime_checkable +class DistributedSamplerProtocol(Protocol): + """Protocol for samplers that partition work across distributed ranks. + + This intentionally matches the public surface provided by + :class:`torch.utils.data.DistributedSampler` so native PyTorch samplers + satisfy the protocol structurally. + + Attributes + ---------- + num_replicas : int + Number of distributed workers participating in sampling. + rank : int + Rank local to the sampler's process group. 
+ """ + + num_replicas: int + rank: int + + def set_epoch(self, epoch: int) -> None: + """Set the current epoch for deterministic per-epoch shuffling.""" + + +def _generator_kwargs(generator: torch.Generator | None) -> dict[str, torch.Generator]: + """Return keyword arguments for torch random APIs.""" + return {"generator": generator} if generator is not None else {} + + +def _normalise_weights( + weights: Sequence[float] | None, lengths: Sequence[int] +) -> torch.Tensor: + """Return positive finite weights for each child dataset.""" + if weights is None: + weights = lengths + if len(weights) != len(lengths): + raise ValueError(f"Expected {len(lengths)} dataset weights, got {len(weights)}") + + tensor = torch.as_tensor(list(weights), dtype=torch.float64) + if not torch.isfinite(tensor).all(): + raise ValueError("Dataset weights must be finite") + if (tensor < 0).any(): + raise ValueError("Dataset weights must be non-negative") + if tensor.sum().item() <= 0: + raise ValueError("At least one dataset weight must be positive") + + for i, (weight, length) in enumerate(zip(tensor.tolist(), lengths, strict=True)): + if weight > 0 and length == 0: + raise ValueError(f"Dataset {i} has positive weight but no samples") + return tensor / tensor.sum() + + +def _counts_from_weights(weights: torch.Tensor, total: int) -> list[int]: + """Allocate an integer total according to fractional weights.""" + if total < 1: + raise ValueError(f"total must be >= 1, got {total}") + + raw_counts = weights * total + counts = torch.floor(raw_counts).to(torch.int64) + remaining = total - int(counts.sum().item()) + if remaining > 0: + fractions = raw_counts - counts + for index in torch.argsort(fractions, descending=True)[:remaining].tolist(): + counts[index] += 1 + return counts.tolist() + + +def _local_order( + length: int, *, shuffle: bool, generator: torch.Generator | None +) -> list[int]: + """Return one local index order for a child dataset.""" + if shuffle: + return torch.randperm(length, 
**_generator_kwargs(generator)).tolist() + return list(range(length)) + + +def _shuffle_indices( + indices: list[int], generator: torch.Generator | None +) -> list[int]: + """Return a shuffled copy of indices.""" + if len(indices) <= 1: + return indices + order = torch.randperm(len(indices), **_generator_kwargs(generator)).tolist() + return [indices[i] for i in order] + + +def _num_sharded_items(length: int, num_replicas: int, drop_last: bool) -> int: + """Return number of items emitted by one distributed rank.""" + if num_replicas == 1: + return length + if drop_last and length % num_replicas != 0: + return ceil((length - num_replicas) / num_replicas) + return ceil(length / num_replicas) + + +def _distributed_shard( + indices: list, + *, + num_replicas: int, + rank: int, + drop_last: bool, +) -> list: + """Return the subset of epoch items assigned to one distributed rank. + + Parameters + ---------- + indices : list + Sample indices in the order they would be retrieved for this epoch + before splitting the work across data-parallel ranks. In a + single-process run, this would be the sampler order. + num_replicas : int + Number of distributed ranks sharing the epoch. + rank : int + Rank whose local shard should be returned. + drop_last : bool + Whether to truncate the full epoch instead of padding it when the epoch + length is not evenly divisible by ``num_replicas``. + + Returns + ------- + list + Rank-local shard of ``indices``. + + Notes + ----- + To make strided sharding produce the same number of items on each rank, the + full epoch order is first resized to ``total_size``. ``num_samples`` is the + number of items one rank should emit, computed as + ``ceil(len(indices) / num_replicas)`` unless ``drop_last=True`` requires + truncating an uneven tail. ``total_size`` is the all-rank item count, + ``num_samples * num_replicas``. + + With ``drop_last=True``, the full list is truncated to ``total_size``. 
+ Otherwise, items from the beginning of the epoch are repeated until the list + is evenly divisible across ranks, matching PyTorch + :class:`~torch.utils.data.DistributedSampler` behavior. After resizing, rank + ``r`` receives every ``num_replicas``-th item starting at offset ``r``: + ``indices[r:total_size:num_replicas]``. + """ + if num_replicas == 1: + return indices + + num_samples = _num_sharded_items(len(indices), num_replicas, drop_last) + total_size = num_samples * num_replicas + if drop_last: + indices = indices[:total_size] + elif len(indices) < total_size: + padding_size = total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * ceil(padding_size / len(indices)))[:padding_size] + return indices[rank:total_size:num_replicas] + + +def _contains_float(values: Sequence[int | float]) -> bool: + """Return whether any value should switch counts to ratio semantics.""" + return any( + isinstance(value, Real) and not isinstance(value, Integral) for value in values + ) + + +def _num_batches_from_policy( + *, + epoch_policy: EpochPolicy, + lengths: Sequence[int], + samples_per_dataset: Sequence[int], + batch_size: int, + total_length: int, + replacement: bool, +) -> int: + """Compute default epoch length from per-dataset batch allocations.""" + contributing = [ + (length, count) + for length, count in zip(lengths, samples_per_dataset, strict=True) + if count > 0 + ] + if not contributing: + raise ValueError("At least one dataset must contribute samples per batch") + + if replacement: + min_batches = min(ceil(length / count) for length, count in contributing) + max_batches = max(ceil(length / count) for length, count in contributing) + else: + min_batches = min(length // count for length, count in contributing) + max_batches = max(length // count for length, count in contributing) + + if epoch_policy == "dataset_size": + return ceil(total_length / batch_size) if replacement else min_batches + if 
epoch_policy == "min_size": + return min_batches + if epoch_policy == "max_size": + if not replacement and max_batches > min_batches: + raise ValueError( + "epoch_policy='max_size' requires replacement=True when smaller " + "datasets would need oversampling" + ) + return max_batches + raise ValueError( + "epoch_policy must be one of 'dataset_size', 'min_size', or 'max_size'" + ) + + +class MultiDatasetSampler(Sampler[int]): + """Sample global indices from a :class:`MultiDataset` at dataset-level rates. + + Parameters + ---------- + dataset : MultiDataset + Dataset wrapper that defines child dataset offsets. + weights : Sequence[float] | None, default=None + Per-child dataset sampling rates. ``None`` uses child lengths, matching + proportional sampling from the concatenated global index space. + num_samples : int | None, default=None + Number of global indices emitted per epoch. ``None`` emits + ``len(dataset)`` samples. + replacement : bool, default=True + Whether local samples may repeat within an epoch. + shuffle : bool, default=True + Randomize dataset choices and local sample order. + generator : torch.Generator | None, default=None + Optional random generator for reproducible sampling. + num_replicas : int | None, default=None + Number of distributed ranks. ``None`` uses initialized + ``distributed_manager.world_size`` or defaults to ``1``. + rank : int | None, default=None + Rank for this sampler. ``None`` uses initialized + ``distributed_manager.rank`` or defaults to ``0``. + distributed_manager : DistributedManager | None, default=None + Optional distributed manager used to infer rank and world size. + seed : int, default=0 + Base seed used for deterministic shuffling across epochs when + ``generator`` is ``None``. + drop_last : bool, default=False + Drop tail samples to make the epoch evenly divisible across ranks. 
+ """ + + def __init__( + self, + dataset: MultiDataset, + *, + weights: Sequence[float] | None = None, + num_samples: int | None = None, + replacement: bool = True, + shuffle: bool = True, + generator: torch.Generator | None = None, + num_replicas: int | None = None, + rank: int | None = None, + distributed_manager: DistributedManager | None = None, + seed: int = 0, + drop_last: bool = False, + ) -> None: + """Initialize the sampler.""" + self.dataset = dataset + self.lengths = [len(child) for child in dataset.datasets] + self.weights = _normalise_weights(weights, self.lengths) + self.num_samples = len(dataset) if num_samples is None else num_samples + if self.num_samples < 1: + raise ValueError(f"num_samples must be >= 1, got {self.num_samples}") + self.replacement = replacement + self.shuffle = shuffle + self.generator = generator + if distributed_manager is not None and distributed_manager.is_initialized(): + num_replicas = distributed_manager.world_size + rank = distributed_manager.rank + if num_replicas is None: + num_replicas = 1 + if rank is None: + rank = 0 + if num_replicas < 1: + raise ValueError(f"num_replicas must be >= 1, got {num_replicas}") + if rank < 0 or rank >= num_replicas: + raise ValueError( + f"rank must be in the range [0, {num_replicas}), got {rank}" + ) + self.num_replicas = num_replicas + self.rank = rank + self.seed = seed + self.drop_last = drop_last + self.epoch = 0 + + # if sampling without replacement, we go through the datasets + # and make sure there are sufficient samples to meet the weights + if not replacement: + counts = _counts_from_weights(self.weights, self.num_samples) + for dataset_index, (count, length) in enumerate( + zip(counts, self.lengths, strict=True) + ): + if count > length: + raise ValueError( + "replacement=False cannot draw " + f"{count} samples from dataset {dataset_index} " + f"with only {length} samples" + ) + + def _epoch_generator(self) -> torch.Generator | None: + """Return the generator used for 
this epoch.""" + if self.generator is not None: + return self.generator + generator = torch.Generator() + generator.manual_seed(self.seed + self.epoch) + return generator + + def _global_indices(self) -> list[int]: + """Return the full unsharded epoch of global sample indices.""" + generator = self._epoch_generator() + if self.replacement and self.shuffle: + dataset_choices = torch.multinomial( + self.weights, + self.num_samples, + replacement=True, + **_generator_kwargs(generator), + ).tolist() + # if shuffling with replacement, there is a possibility of + # encountering the same sample from a given dataset + indices = [] + for dataset_index in dataset_choices: + local_index = int( + torch.randint( + self.lengths[dataset_index], + (1,), + **_generator_kwargs(generator), + ).item() + ) + indices.append(self.dataset.to_global_index(dataset_index, local_index)) + return indices + + # remaining cases: no replacement, or replacement without shuffling + counts = _counts_from_weights(self.weights, self.num_samples) + dataset_choices = [ + dataset_index + for dataset_index, count in enumerate(counts) + for _ in range(count) + ] + if self.shuffle: + dataset_choices = _shuffle_indices(dataset_choices, generator) + + local_orders = [ + _local_order(length, shuffle=self.shuffle, generator=generator) + for length in self.lengths + ] + cursors = [0] * len(self.lengths) + indices = [] + for dataset_index in dataset_choices: + cursor = cursors[dataset_index] + if self.replacement: + local_index = local_orders[dataset_index][ + cursor % self.lengths[dataset_index] + ] + else: + local_index = local_orders[dataset_index][cursor] + cursors[dataset_index] += 1 + indices.append(self.dataset.to_global_index(dataset_index, local_index)) + return indices + + def __iter__(self) -> Iterator[int]: + """Yield rank-local global sample indices.""" + yield from _distributed_shard( + self._global_indices(), + num_replicas=self.num_replicas, + rank=self.rank, + drop_last=self.drop_last, + ) + + def __len__(self) 
-> int: + """Return the number of rank-local emitted global indices.""" + return _num_sharded_items(self.num_samples, self.num_replicas, self.drop_last) + + def set_epoch(self, epoch: int) -> None: + """Set the epoch used for deterministic distributed shuffling. + + Parameters + ---------- + epoch : int + Epoch number added to ``seed`` when this sampler owns its generator. + """ + self.epoch = epoch + + +class MultiDatasetBatchSampler(Sampler[list[int]]): + """Sample full global-index batches from a :class:`MultiDataset`. + + Parameters + ---------- + dataset : MultiDataset + Dataset wrapper that defines child dataset offsets. + batch_size : int + Number of samples in each emitted batch. + weights : Sequence[float] | None, default=None + Per-child rates used to allocate ``batch_size`` slots. ``None`` uses + child lengths, matching proportional sampling from the global index + space. + samples_per_dataset : Sequence[int | float] | None, default=None + Per-child batch allocation. Integer entries are exact sample counts + per batch. If any entry is a float, the full sequence is interpreted + as relative per-dataset rates and allocated across ``batch_size``. + Mutually exclusive with ``weights``. + num_batches : int | None, default=None + Number of batches per epoch. For replacement sampling, the default is + ``ceil(len(dataset) / batch_size)``. Without replacement, the default is + the number of complete batches supported by the smallest requested child + allocation. + epoch_policy : {"dataset_size", "min_size", "max_size"}, default="dataset_size" + Policy used to compute ``num_batches`` when it is not provided. + ``"dataset_size"`` simply returns the combined dataset length divided + by the batch size when ``replacement=True``, otherwise ``min_size``. ``"min_size"`` + stops when the smallest contributing dataset would be exhausted. 
+ ``"max_size"`` runs until the largest contributing dataset would be + exhausted, oversampling smaller datasets when ``replacement=True``. + replacement : bool, default=True + Whether local samples may repeat within an epoch. + shuffle : bool, default=True + Randomize local sample order and sample order within each batch. + generator : torch.Generator | None, default=None + Optional random generator for reproducible sampling. + num_replicas : int | None, default=None + Number of distributed ranks. ``None`` uses initialized + ``distributed_manager.world_size`` or defaults to ``1``. + rank : int | None, default=None + Rank for this sampler. ``None`` uses initialized + ``distributed_manager.rank`` or defaults to ``0``. + distributed_manager : DistributedManager | None, default=None + Optional distributed manager used to infer rank and world size. + seed : int, default=0 + Base seed used for deterministic shuffling across epochs when + ``generator`` is ``None``. + drop_last : bool, default=False + Drop tail batches to make the epoch evenly divisible across ranks. 
+ """ + + def __init__( + self, + dataset: MultiDataset, + *, + batch_size: int, + weights: Sequence[float] | None = None, + samples_per_dataset: Sequence[int | float] | None = None, + num_batches: int | None = None, + epoch_policy: EpochPolicy = "dataset_size", + replacement: bool = True, + shuffle: bool = True, + generator: torch.Generator | None = None, + num_replicas: int | None = None, + rank: int | None = None, + distributed_manager: DistributedManager | None = None, + seed: int = 0, + drop_last: bool = False, + ) -> None: + """Initialize the batch sampler.""" + if batch_size < 1: + raise ValueError(f"batch_size must be >= 1, got {batch_size}") + if weights is not None and samples_per_dataset is not None: + raise ValueError("weights and samples_per_dataset are mutually exclusive") + + self.dataset = dataset + self.batch_size = batch_size + self.lengths = [len(child) for child in dataset.datasets] + self.replacement = replacement + self.shuffle = shuffle + self.generator = generator + self.epoch_policy = epoch_policy + if distributed_manager is not None and distributed_manager.is_initialized(): + num_replicas = distributed_manager.world_size + rank = distributed_manager.rank + if num_replicas is None: + num_replicas = 1 + if rank is None: + rank = 0 + if num_replicas < 1: + raise ValueError(f"num_replicas must be >= 1, got {num_replicas}") + if rank < 0 or rank >= num_replicas: + raise ValueError( + f"rank must be in the range [0, {num_replicas}), got {rank}" + ) + self.num_replicas = num_replicas + self.rank = rank + self.seed = seed + self.drop_last = drop_last + self.epoch = 0 + + if samples_per_dataset is None: + normalised_weights = _normalise_weights(weights, self.lengths) + self.samples_per_dataset = _counts_from_weights( + normalised_weights, batch_size + ) + else: + if len(samples_per_dataset) != len(self.lengths): + raise ValueError( + f"Expected {len(self.lengths)} per-dataset counts, " + f"got {len(samples_per_dataset)}" + ) + # if floats are 
provided, we treat them as ratios + if _contains_float(samples_per_dataset): + normalised_weights = _normalise_weights( + samples_per_dataset, self.lengths + ) + self.samples_per_dataset = _counts_from_weights( + normalised_weights, batch_size + ) + else: + exact_counts: list[int] = [] + for count in samples_per_dataset: + if isinstance(count, bool) or not isinstance(count, Integral): + raise TypeError( + "Integer samples_per_dataset entries must be " + f"integral counts, got {count!r}" + ) + exact_counts.append(int(count)) + self.samples_per_dataset = exact_counts + + if any(count < 0 for count in self.samples_per_dataset): + raise ValueError("samples_per_dataset counts must be non-negative") + if sum(self.samples_per_dataset) != batch_size: + raise ValueError( + "samples_per_dataset counts must sum to batch_size: " + f"{sum(self.samples_per_dataset)} != {batch_size}" + ) + if all(count == 0 for count in self.samples_per_dataset): + raise ValueError("At least one dataset must contribute samples per batch") + + for dataset_index, (count, length) in enumerate( + zip(self.samples_per_dataset, self.lengths, strict=True) + ): + if count > 0 and length == 0: + raise ValueError( + f"Dataset {dataset_index} contributes {count} samples per " + "batch but has no samples" + ) + + if replacement: + self.num_batches = ( + _num_batches_from_policy( + epoch_policy=epoch_policy, + lengths=self.lengths, + samples_per_dataset=self.samples_per_dataset, + batch_size=batch_size, + total_length=len(dataset), + replacement=True, + ) + if num_batches is None + else num_batches + ) + else: + max_complete_batches = min( + length // count + for length, count in zip( + self.lengths, self.samples_per_dataset, strict=True + ) + if count > 0 + ) + self.num_batches = ( + _num_batches_from_policy( + epoch_policy=epoch_policy, + lengths=self.lengths, + samples_per_dataset=self.samples_per_dataset, + batch_size=batch_size, + total_length=len(dataset), + replacement=False, + ) + if num_batches is 
None + else num_batches + ) + if self.num_batches > max_complete_batches: + raise ValueError( + "replacement=False supports at most " + f"{max_complete_batches} complete batches for the requested " + "per-dataset counts" + ) + if self.num_batches < 1: + raise ValueError(f"num_batches must be >= 1, got {self.num_batches}") + + @classmethod + def balanced( + cls, + dataset: MultiDataset, + *, + batch_size: int, + num_batches: int | None = None, + epoch_policy: EpochPolicy = "dataset_size", + replacement: bool = True, + shuffle: bool = True, + generator: torch.Generator | None = None, + num_replicas: int | None = None, + rank: int | None = None, + distributed_manager: DistributedManager | None = None, + seed: int = 0, + drop_last: bool = False, + ) -> Self: + """Create a batch sampler with equal dataset-level sampling rates. + + Parameters + ---------- + dataset : MultiDataset + Dataset wrapper that defines child dataset offsets. + batch_size : int + Number of samples in each emitted batch. + num_batches : int | None, default=None + Number of batches per epoch. + epoch_policy : {"dataset_size", "min_size", "max_size"}, default="dataset_size" + Policy used to compute ``num_batches`` when it is not provided. + replacement : bool, default=True + Whether local samples may repeat within an epoch. + shuffle : bool, default=True + Randomize local sample order and sample order within each batch. + generator : torch.Generator | None, default=None + Optional random generator for reproducible sampling. + num_replicas : int | None, default=None + Number of distributed ranks. + rank : int | None, default=None + Rank for this sampler. + distributed_manager : DistributedManager | None, default=None + Optional distributed manager used to infer rank and world size. + seed : int, default=0 + Base seed used for deterministic shuffling across epochs. + drop_last : bool, default=False + Drop tail batches to make the epoch evenly divisible across ranks. 
+ + Returns + ------- + Self + Batch sampler with one equal relative weight per child dataset. + """ + return cls( + dataset, + batch_size=batch_size, + weights=[1.0] * len(dataset.datasets), + num_batches=num_batches, + epoch_policy=epoch_policy, + replacement=replacement, + shuffle=shuffle, + generator=generator, + num_replicas=num_replicas, + rank=rank, + distributed_manager=distributed_manager, + seed=seed, + drop_last=drop_last, + ) + + def _epoch_generator(self) -> torch.Generator | None: + """Return the generator used for this epoch.""" + if self.generator is not None: + return self.generator + generator = torch.Generator() + generator.manual_seed(self.seed + self.epoch) + return generator + + def _global_batches(self) -> list[list[int]]: + """Return the full unsharded epoch of global-index batches.""" + generator = self._epoch_generator() + batches: list[list[int]] = [] + if self.replacement: + cursors = [0] * len(self.lengths) + for _ in range(self.num_batches): + batch: list[int] = [] + for dataset_index, count in enumerate(self.samples_per_dataset): + if count == 0: + continue + if self.shuffle: + local_indices = torch.randint( + self.lengths[dataset_index], + (count,), + **_generator_kwargs(generator), + ).tolist() + else: + cursor = cursors[dataset_index] + local_indices = [ + (cursor + i) % self.lengths[dataset_index] + for i in range(count) + ] + cursors[dataset_index] += count + batch.extend( + self.dataset.to_global_index(dataset_index, local_index) + for local_index in local_indices + ) + batches.append( + _shuffle_indices(batch, generator) if self.shuffle else batch + ) + return batches + + local_orders = [ + _local_order(length, shuffle=self.shuffle, generator=generator) + for length in self.lengths + ] + cursors = [0] * len(self.lengths) + for _ in range(self.num_batches): + batch = [] + for dataset_index, count in enumerate(self.samples_per_dataset): + if count == 0: + continue + cursor = cursors[dataset_index] + local_indices = 
local_orders[dataset_index][cursor : cursor + count] + cursors[dataset_index] += count + batch.extend( + self.dataset.to_global_index(dataset_index, local_index) + for local_index in local_indices + ) + batches.append( + _shuffle_indices(batch, generator) if self.shuffle else batch + ) + return batches + + def __iter__(self) -> Iterator[list[int]]: + """Yield rank-local batches of global sample indices.""" + yield from _distributed_shard( + self._global_batches(), + num_replicas=self.num_replicas, + rank=self.rank, + drop_last=self.drop_last, + ) + + def __len__(self) -> int: + """Return the number of rank-local emitted batches.""" + return _num_sharded_items(self.num_batches, self.num_replicas, self.drop_last) + + def set_epoch(self, epoch: int) -> None: + """Set the epoch used for deterministic distributed shuffling. + + Parameters + ---------- + epoch : int + Epoch number added to ``seed`` when this sampler owns its generator. + """ + self.epoch = epoch diff --git a/nvalchemi/data/io_test.py b/nvalchemi/data/io_test.py index f4c29d62..a9956b3b 100644 --- a/nvalchemi/data/io_test.py +++ b/nvalchemi/data/io_test.py @@ -12,13 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Quick Zarr I/O benchmark for measuring write throughput and compression. +"""Quick Zarr I/O benchmark for measuring write/read throughput and compression. Run with:: nvalchemi-io-test --help nvalchemi-io-test --num-systems 1000 5000 --codec zstd --chunk-size 10000 +Readback uses the dataloader fused-prefetch path by default. 
To compare +against one-sample-at-a-time reads:: + + nvalchemi-io-test -n 1000 --read-mode both --batch-size 64 --prefetch-factor 8 + nvalchemi-io-test -n 1000 --read-mode single + nvalchemi-io-test -n 1000 --read-order shuffle + nvalchemi-io-test -n 1000 --read-order block-shuffle --read-order-block-size 8192 + Edge-specific chunking (useful for large graphs):: nvalchemi-io-test -n 100 --codec zstd --chunk-size 10000 --edge-chunk-size 5000 @@ -32,8 +40,9 @@ import shutil import tempfile import time +from collections.abc import Iterator, Sequence from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, TypeAlias, cast import click import torch @@ -47,12 +56,19 @@ TimeElapsedColumn, ) from rich.table import Table +from torch.utils.data import Sampler if TYPE_CHECKING: from nvalchemi.data.atomic_data import AtomicData console = Console(stderr=True) +ReadMode: TypeAlias = Literal["batch", "single"] +ReadOrder: TypeAlias = Literal["sequential", "shuffle", "block-shuffle"] +DEFAULT_BATCH_SIZE = 64 +DEFAULT_PREFETCH_FACTOR = 16 +DEFAULT_READ_ORDER_BLOCK_SIZE = 8192 + def _make_atomic_data(num_atoms: int, num_edges: int) -> AtomicData: """Create a minimal AtomicData with random data. @@ -82,7 +98,8 @@ def _make_atomic_data(num_atoms: int, num_edges: int) -> AtomicData: [ torch.randint(0, max(num_atoms, 1), (num_edges,)), torch.randint(0, max(num_atoms, 1), (num_edges,)), - ] + ], + dim=1, ), shifts=torch.randn(num_edges, 3), ) @@ -324,6 +341,236 @@ def _fmt_bytes(n: int) -> str: return f"{n:.1f} TB" +def _tensor_bytes(data: AtomicData | dict[str, torch.Tensor]) -> int: + """Return the total tensor payload size in bytes. + + Parameters + ---------- + data : AtomicData | dict[str, torch.Tensor] + AtomicData object or raw tensor dictionary. + + Returns + ------- + int + Total tensor bytes. 
+ """ + if isinstance(data, dict): + return sum(val.nelement() * val.element_size() for val in data.values()) + + total = 0 + for key in data.model_fields_set: + val = getattr(data, key, None) + if isinstance(val, torch.Tensor): + total += val.nelement() * val.element_size() + return total + + +def _expand_read_modes(read_mode_options: tuple[str, ...]) -> tuple[ReadMode, ...]: + """Expand CLI read-mode options into concrete benchmark modes. + + Parameters + ---------- + read_mode_options : tuple[str, ...] + CLI options. ``"both"`` expands to ``("batch", "single")``. + + Returns + ------- + tuple[ReadMode, ...] + Concrete readback modes in benchmark order. + + Raises + ------ + click.BadParameter + If an unknown mode is provided. + """ + modes: list[ReadMode] = [] + for option in read_mode_options: + normalized = option.lower() + if normalized == "both": + modes.extend(("batch", "single")) + elif normalized in ("batch", "single"): + modes.append(cast(ReadMode, normalized)) + else: + msg = f"Unknown read mode: {option!r}" + raise click.BadParameter(msg) + + return tuple(modes) if modes else ("batch",) + + +def _build_read_indices( + expected_num_systems: int, + read_order: ReadOrder, + seed: int, + read_order_block_size: int, +) -> list[int]: + """Build the logical sample order used for readback benchmarking. + + Parameters + ---------- + expected_num_systems : int + Number of readable samples. + read_order : {"sequential", "shuffle", "block-shuffle"} + Logical index order to benchmark. + seed : int + Seed for randomized read orders. + read_order_block_size : int + Number of contiguous samples per shuffled block in block-shuffle mode. + + Returns + ------- + list[int] + Logical sample indices in readback order. + + Raises + ------ + ValueError + If *read_order_block_size* is less than 1 or *read_order* is unknown. 
+ """ + if read_order_block_size < 1: + raise ValueError( + f"read_order_block_size must be >= 1, got {read_order_block_size}" + ) + + indices = list(range(expected_num_systems)) + rng = random.Random(seed) + + if read_order == "sequential": + return indices + if read_order == "shuffle": + rng.shuffle(indices) + return indices + if read_order == "block-shuffle": + blocks = [ + indices[start : start + read_order_block_size] + for start in range(0, expected_num_systems, read_order_block_size) + ] + rng.shuffle(blocks) + return [index for block in blocks for index in block] + + msg = f"Unknown read order: {read_order!r}" + raise ValueError(msg) + + +class _FixedOrderSampler(Sampler[int]): + """Sampler that yields a precomputed logical read order.""" + + def __init__(self, indices: Sequence[int]) -> None: + self._indices = list(indices) + + def __iter__(self) -> Iterator[int]: + """Yield indices in the configured order.""" + return iter(self._indices) + + def __len__(self) -> int: + """Return the number of configured sample indices.""" + return len(self._indices) + + +def _read_back_store( + store_path: Path, + expected_num_systems: int, + read_mode: ReadMode = "batch", + batch_size: int = DEFAULT_BATCH_SIZE, + prefetch_factor: int = DEFAULT_PREFETCH_FACTOR, + read_order: ReadOrder = "sequential", + read_seed: int = 0, + read_order_block_size: int = DEFAULT_READ_ORDER_BLOCK_SIZE, + pin_memory: bool = False, +) -> tuple[float, int]: + """Read every sample from a Zarr store and return timing and payload bytes. + + Parameters + ---------- + store_path : Path + Zarr store to read. + expected_num_systems : int + Expected number of readable samples. + read_mode : {"batch", "single"}, default="batch" + Readback path to benchmark. ``"batch"`` uses the public + :class:`~nvalchemi.data.datapipes.DataLoader` path with fused + prefetching; ``"single"`` uses one ``reader.read`` call per sample. 
+ batch_size : int, default=64 + Number of samples per emitted dataloader batch in batch mode. + prefetch_factor : int, default=16 + Number of emitted dataloader batches to fuse into each backend read in + batch mode. The effective read window is + ``batch_size * prefetch_factor``. + read_order : {"sequential", "shuffle", "block-shuffle"}, default="sequential" + Logical sample order used for readback. ``"shuffle"`` models fully + shuffled dataloading. ``"block-shuffle"`` shuffles contiguous index + blocks while preserving locality inside each block. + read_seed : int, default=0 + Seed for randomized read orders. + read_order_block_size : int, default=8192 + Number of contiguous samples per shuffled block in block-shuffle mode. + pin_memory : bool, default=False + Request pinned CPU tensors from readers that support pinned-memory reads. + + Returns + ------- + tuple[float, int] + Read time in seconds and total tensor payload bytes read. + + Raises + ------ + ValueError + If *batch_size* is less than 1, if *prefetch_factor* is negative, if + *read_order_block_size* is less than 1, or if *read_mode* / + *read_order* is unknown. + RuntimeError + If the store does not expose the expected number of samples. + """ + from nvalchemi.data.datapipes import DataLoader, Dataset + from nvalchemi.data.datapipes.backends.zarr import AtomicDataZarrReader + + if batch_size < 1: + raise ValueError(f"batch_size must be >= 1, got {batch_size}") + if prefetch_factor < 0: + raise ValueError(f"prefetch_factor must be >= 0, got {prefetch_factor}") + + read_indices = _build_read_indices( + expected_num_systems, + read_order, + read_seed, + read_order_block_size, + ) + + read_bytes = 0 + t0 = time.perf_counter() + with AtomicDataZarrReader(store_path) as reader: + if len(reader) != expected_num_systems: + msg = ( + f"Expected {expected_num_systems} readable samples, " + f"found {len(reader)}." 
+ ) + raise RuntimeError(msg) + if read_mode == "batch": + dataset = Dataset(reader, device="cpu", skip_validation=True) + loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=_FixedOrderSampler(read_indices), + prefetch_factor=prefetch_factor, + use_streams=False, + pin_memory=pin_memory, + ) + for batch in loader: + read_bytes += sum( + value.nelement() * value.element_size() + for _key, value in batch + if isinstance(value, torch.Tensor) + ) + elif read_mode == "single": + for index in read_indices: + data_dict, _metadata = reader.read(index) + read_bytes += _tensor_bytes(data_dict) + else: + msg = f"Unknown read mode: {read_mode!r}" + raise ValueError(msg) + read_time = time.perf_counter() - t0 + return read_time, read_bytes + + def _run_benchmark( num_systems_list: list[int], min_atoms: int, @@ -331,8 +578,15 @@ def _run_benchmark( seed: int, config: dict | None, store_dir: Path, + read_modes: tuple[ReadMode, ...] = ("batch",), + batch_size: int = DEFAULT_BATCH_SIZE, + prefetch_factor: int = DEFAULT_PREFETCH_FACTOR, + read_order: ReadOrder = "sequential", + read_seed: int = 0, + read_order_block_size: int = DEFAULT_READ_ORDER_BLOCK_SIZE, + pin_memory: bool = False, ) -> list[dict]: - """Run the write benchmark for each system count. + """Run the write/read benchmark for each system count. Parameters ---------- @@ -348,6 +602,21 @@ def _run_benchmark( ZarrWriteConfig dict. store_dir : Path Temporary directory for Zarr stores. + read_modes : tuple[ReadMode, ...], default=("batch",) + Readback modes to benchmark for each written store. + batch_size : int, default=64 + Number of samples per emitted dataloader batch in ``"batch"`` mode. + prefetch_factor : int, default=16 + Number of emitted batches to fuse into each backend read in ``"batch"`` + mode. + read_order : {"sequential", "shuffle", "block-shuffle"}, default="sequential" + Logical sample order used during readback. + read_seed : int, default=0 + Seed for randomized read orders. 
+ read_order_block_size : int, default=8192 + Number of contiguous samples per shuffled block in block-shuffle mode. + pin_memory : bool, default=False + Request pinned CPU tensors from readers that support pinned-memory reads. Returns ------- @@ -362,6 +631,8 @@ def _run_benchmark( write_config = ( ZarrWriteConfig.model_validate(config) if config else ZarrWriteConfig() ) + if not read_modes: + raise ValueError("At least one read mode must be provided.") # Pre-compute plans for all system counts max_systems = max(num_systems_list) @@ -391,7 +662,9 @@ def _run_benchmark( with progress: for num_systems in num_systems_list: - task = progress.add_task(f"[cyan]{num_systems:>10,} systems", total=3) + task = progress.add_task( + f"[cyan]{num_systems:>10,} systems", total=3 + len(read_modes) + ) # Step 1: generate data from pre-computed plan progress.update(task, description=f"[cyan]{num_systems:>10,} gen") @@ -410,18 +683,34 @@ def _run_benchmark( write_time = time.perf_counter() - t0 progress.advance(task) - # Step 3: measure + # Step 3: read back through each requested path. 
+ read_results: list[tuple[ReadMode, float, int]] = [] + for read_mode in read_modes: + progress.update( + task, + description=f"[cyan]{num_systems:>10,} read-{read_mode}", + ) + read_time, read_bytes = _read_back_store( + store_path, + num_systems, + read_mode=read_mode, + batch_size=batch_size, + prefetch_factor=prefetch_factor, + read_order=read_order, + read_seed=read_seed, + read_order_block_size=read_order_block_size, + pin_memory=pin_memory, + ) + read_results.append((read_mode, read_time, read_bytes)) + progress.advance(task) + + # Final step: measure progress.update(task, description=f"[cyan]{num_systems:>10,} measure") disk_bytes = _dir_size(store_path) num_files = _file_count(store_path) # compute uncompressed size from numpy arrays - raw_bytes = 0 - for d in data_list: - for key in d.model_fields_set: - val = getattr(d, key, None) - if isinstance(val, torch.Tensor): - raw_bytes += val.nelement() * val.element_size() + raw_bytes = sum(_tensor_bytes(d) for d in data_list) progress.advance(task) progress.update( @@ -433,21 +722,50 @@ def _run_benchmark( avg_edges_run = total_edges / num_systems ratio = raw_bytes / disk_bytes if disk_bytes > 0 else float("inf") - results.append( - { - "num_systems": num_systems, - "total_atoms": total_atoms, - "total_edges": total_edges, - "avg_atoms": avg_atoms_run, - "avg_edges": avg_edges_run, - "raw_bytes": raw_bytes, - "disk_bytes": disk_bytes, - "num_files": num_files, - "ratio": ratio, - "write_time": write_time, - "throughput": num_systems / write_time if write_time > 0 else 0, - } - ) + for read_mode, read_time, read_bytes in read_results: + profile_time = write_time + read_time + results.append( + { + "num_systems": num_systems, + "read_mode": read_mode, + "read_order": read_order, + "read_order_block_size": ( + read_order_block_size + if read_order == "block-shuffle" + else None + ), + "batch_size": batch_size if read_mode == "batch" else 1, + "prefetch_factor": ( + prefetch_factor if read_mode == "batch" else 0 
+ ), + "effective_read_window": ( + batch_size * max(prefetch_factor, 1) + if read_mode == "batch" + else 1 + ), + "total_atoms": total_atoms, + "total_edges": total_edges, + "avg_atoms": avg_atoms_run, + "avg_edges": avg_edges_run, + "raw_bytes": raw_bytes, + "disk_bytes": disk_bytes, + "read_bytes": read_bytes, + "num_files": num_files, + "ratio": ratio, + "write_time": write_time, + "read_time": read_time, + "profile_time": profile_time, + "write_throughput": ( + num_systems / write_time if write_time > 0 else 0 + ), + "read_throughput": ( + num_systems / read_time if read_time > 0 else 0 + ), + "profile_throughput": ( + num_systems / profile_time if profile_time > 0 else 0 + ), + } + ) return results @@ -463,37 +781,224 @@ def _print_results(results: list[dict], config_desc: str) -> None: Description of the configuration used. """ table = Table( - title=f"Zarr I/O Benchmark — {config_desc}", + title=f"Zarr I/O Roundtrip Benchmark — {config_desc}", box=box.SIMPLE_HEAD, ) - table.add_column("Systems", justify="right", style="cyan") - table.add_column("Avg atoms", justify="right") - table.add_column("Avg edges", justify="right") - table.add_column("Raw size", justify="right") - table.add_column("Disk size", justify="right", style="green") - table.add_column("Ratio", justify="right", style="yellow") - table.add_column("Files", justify="right") - table.add_column("Write time", justify="right") - table.add_column("Systems/s", justify="right", style="bold") + table.add_column("Systems", justify="right", style="cyan", no_wrap=True) + table.add_column("Read path", justify="left", no_wrap=True) + table.add_column("Read order", justify="left", no_wrap=True) + table.add_column("Batch", justify="right", no_wrap=True) + table.add_column("Prefetch", justify="right", no_wrap=True) + table.add_column("Read window", justify="right", no_wrap=True) + table.add_column("Atoms", justify="right", no_wrap=True) + table.add_column("Edges", justify="right", no_wrap=True) + 
table.add_column("Raw", justify="right", no_wrap=True) + table.add_column("Disk", justify="right", style="green", no_wrap=True) + table.add_column("Ratio", justify="right", style="yellow", no_wrap=True) + table.add_column("Write", justify="right", no_wrap=True) + table.add_column("Read", justify="right", no_wrap=True) + table.add_column("I/O/s", justify="right", style="bold", no_wrap=True) for r in results: table.add_row( f"{r['num_systems']:,}", + r["read_mode"], + r["read_order"], + f"{r['batch_size']:,}", + f"{r['prefetch_factor']:,}", + f"{r['effective_read_window']:,}", f"{r['avg_atoms']:.0f}", f"{r['avg_edges']:.0f}", _fmt_bytes(r["raw_bytes"]), _fmt_bytes(r["disk_bytes"]), f"{r['ratio']:.2f}x", - f"{r['num_files']:,}", f"{r['write_time']:.2f}s", - f"{r['throughput']:,.0f}", + f"{r['read_time']:.2f}s", + f"{r['profile_throughput']:,.0f}", ) console.print() console.print(table) -@click.command("nvalchemi-io-test") +def _run_read_benchmark( + store_path: Path, + read_modes: tuple[ReadMode, ...] = ("batch",), + batch_size: int = DEFAULT_BATCH_SIZE, + prefetch_factor: int = DEFAULT_PREFETCH_FACTOR, + read_order: ReadOrder = "sequential", + read_seed: int = 0, + read_order_block_size: int = DEFAULT_READ_ORDER_BLOCK_SIZE, + pin_memory: bool = False, +) -> list[dict]: + """Benchmark read performance against an existing Zarr store. + + Parameters + ---------- + store_path : Path + Path to an existing Zarr store written by ``AtomicDataZarrWriter``. + read_modes : tuple[ReadMode, ...], default=("batch",) + Readback modes to benchmark. + batch_size : int, default=64 + Number of samples per emitted dataloader batch in batch mode. + prefetch_factor : int, default=16 + Number of emitted batches to fuse into each backend read in batch mode. + read_order : {"sequential", "shuffle", "block-shuffle"}, default="sequential" + Logical sample order for readback. + read_seed : int, default=0 + Seed for randomized read orders. 
+ read_order_block_size : int, default=8192 + Block size for block-shuffle mode. + pin_memory : bool, default=False + Request pinned CPU tensors from readers that support pinned-memory reads. + + Returns + ------- + list[dict] + One result dict per read mode. + """ + from nvalchemi.data.datapipes.backends.zarr import AtomicDataZarrReader + + if not read_modes: + raise ValueError("At least one read mode must be provided.") + + with AtomicDataZarrReader(store_path) as reader: + num_systems = len(reader) + + results = [] + progress = Progress( + TextColumn("{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + console=console, + ) + + with progress: + for read_mode in read_modes: + task = progress.add_task( + f"[cyan]read-{read_mode} ({read_order})", + total=1, + ) + read_time, read_bytes = _read_back_store( + store_path, + num_systems, + read_mode=read_mode, + batch_size=batch_size, + prefetch_factor=prefetch_factor, + read_order=read_order, + read_seed=read_seed, + read_order_block_size=read_order_block_size, + pin_memory=pin_memory, + ) + progress.advance(task) + progress.update(task, description=f"[green]read-{read_mode} done") + + results.append( + { + "store_path": str(store_path), + "num_systems": num_systems, + "read_mode": read_mode, + "read_order": read_order, + "read_order_block_size": ( + read_order_block_size if read_order == "block-shuffle" else None + ), + "batch_size": batch_size if read_mode == "batch" else 1, + "prefetch_factor": prefetch_factor if read_mode == "batch" else 0, + "effective_read_window": ( + batch_size * max(prefetch_factor, 1) + if read_mode == "batch" + else 1 + ), + "read_time": read_time, + "read_bytes": read_bytes, + "read_throughput": ( + num_systems / read_time if read_time > 0 else 0 + ), + } + ) + + return results + + +def _print_read_results(results: list[dict]) -> None: + """Print read-only benchmark results as a Rich table. 
+ + Parameters + ---------- + results : list[dict] + Read benchmark results from ``_run_read_benchmark``. + """ + if not results: + return + + store_path = results[0].get("store_path", "?") + table = Table( + title=f"Zarr Read Benchmark — {store_path}", + box=box.SIMPLE_HEAD, + ) + table.add_column("Samples", justify="right", style="cyan", no_wrap=True) + table.add_column("Read path", justify="left", no_wrap=True) + table.add_column("Read order", justify="left", no_wrap=True) + table.add_column("Batch", justify="right", no_wrap=True) + table.add_column("Prefetch", justify="right", no_wrap=True) + table.add_column("Read window", justify="right", no_wrap=True) + table.add_column("Read time", justify="right", no_wrap=True) + table.add_column("Samples/s", justify="right", style="bold", no_wrap=True) + table.add_column("Data read", justify="right", style="green", no_wrap=True) + + for r in results: + order_desc = r["read_order"] + if r["read_order_block_size"] is not None: + order_desc += f" (blk={r['read_order_block_size']:,})" + table.add_row( + f"{r['num_systems']:,}", + r["read_mode"], + order_desc, + f"{r['batch_size']:,}", + f"{r['prefetch_factor']:,}", + f"{r['effective_read_window']:,}", + f"{r['read_time']:.2f}s", + f"{r['read_throughput']:,.0f}", + _fmt_bytes(r["read_bytes"]), + ) + + console.print() + console.print(table) + + +class _DefaultRoundtripGroup(click.Group): + """Click group that falls back to ``roundtrip`` for unrecognised args. + + When users invoke ``nvalchemi-io-test --num-systems 1000`` (the pre- + group signature), Click would normally fail because ``--num-systems`` + is not a group-level option. This subclass detects that the first + argument is not a known subcommand and transparently inserts + ``roundtrip`` so the old invocation style keeps working. 
+ """ + + def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]: + """Insert ``roundtrip`` when the first arg is not a subcommand.""" + if args and args[0] not in self.commands and not args[0].startswith("--help"): + args = ["roundtrip", *args] + return super().parse_args(ctx, args) + + +@click.group( + "nvalchemi-io-test", cls=_DefaultRoundtripGroup, invoke_without_command=True +) +@click.pass_context +def main(ctx: click.Context) -> None: + """Zarr I/O benchmarks for nvalchemi atomic data. + + Run without a subcommand to see available benchmarks, or use + ``roundtrip`` / ``read`` directly. + """ + if ctx.invoked_subcommand is None: + click.echo(ctx.get_help()) + + +@main.command("roundtrip") @click.option( "--num-systems", "-n", @@ -567,7 +1072,62 @@ def _print_results(results: list[dict], config_desc: str) -> None: default=None, help="Directory for Zarr stores (default: tempdir, cleaned up).", ) -def main( +@click.option( + "--read-mode", + type=click.Choice(["batch", "single", "both"], case_sensitive=False), + multiple=True, + default=("batch",), + show_default=True, + help=( + "Readback path to benchmark. 'batch' uses DataLoader fused prefetch; " + "'single' uses reader.read per sample; repeat to control order." + ), +) +@click.option( + "--batch-size", + type=click.IntRange(min=1), + default=DEFAULT_BATCH_SIZE, + show_default=True, + help="Number of samples per emitted DataLoader batch for --read-mode=batch.", +) +@click.option( + "--prefetch-factor", + type=click.IntRange(min=0), + default=DEFAULT_PREFETCH_FACTOR, + show_default=True, + help="Number of DataLoader batches to fuse into each backend read.", +) +@click.option( + "--read-order", + type=click.Choice(["sequential", "shuffle", "block-shuffle"], case_sensitive=False), + default="sequential", + show_default=True, + help=( + "Logical sample order used for readback. 'shuffle' models full random " + "dataloader reads; 'block-shuffle' shuffles contiguous index blocks." 
+ ), +) +@click.option( + "--read-seed", + type=int, + default=0, + show_default=True, + help="Random seed for --read-order=shuffle and --read-order=block-shuffle.", +) +@click.option( + "--read-order-block-size", + type=click.IntRange(min=1), + default=DEFAULT_READ_ORDER_BLOCK_SIZE, + show_default=True, + help="Contiguous block size for --read-order=block-shuffle.", +) +@click.option( + "--pin-memory/--no-pin-memory", + default=False, + show_default=True, + help="Request pinned CPU tensors from readers that support it.", +) +def roundtrip( num_systems: tuple[int, ...], min_atoms: int, max_atoms: int, @@ -579,12 +1139,20 @@ def main( edge_shard_size: int | None, seed: int, output_dir: Path | None, + read_mode: tuple[str, ...], + batch_size: int, + prefetch_factor: int, + read_order: str, + read_seed: int, + read_order_block_size: int, + pin_memory: bool, ) -> None: - """Run quick Zarr write benchmarks for nvalchemi data. + """Write+read roundtrip benchmark. Generates random AtomicData structures with uniform atom counts between --min-atoms and --max-atoms, writes them to a Zarr store - with the specified configuration, and reports timing and size. + with the specified configuration, reads them back, and reports timing + and size. 
""" # Build config description for table title parts = [] @@ -598,11 +1166,17 @@ def main( parts.append(f"edge_chunk={edge_chunk_size:,}") if edge_shard_size is not None: parts.append(f"edge_shard={edge_shard_size:,}") + read_modes = _expand_read_modes(read_mode) + read_desc = ", ".join(read_modes) + read_order = cast(ReadOrder, read_order.lower()) config_desc = ", ".join(parts) if parts else "no compression" console.print( - f"[bold]nvalchemi Zarr I/O benchmark[/bold] " - f"atoms={min_atoms}-{max_atoms} config={config_desc}" + f"[bold]nvalchemi Zarr I/O roundtrip benchmark[/bold] " + f"atoms={min_atoms}-{max_atoms} config={config_desc} " + f"read={read_desc} read_order={read_order} " + f"batch={batch_size:,} prefetch={prefetch_factor:,} " + f"read_window={batch_size * max(prefetch_factor, 1):,}" ) config = _build_config( @@ -623,6 +1197,13 @@ def main( seed=seed, config=config, store_dir=store_dir, + read_modes=read_modes, + batch_size=batch_size, + prefetch_factor=prefetch_factor, + read_order=read_order, + read_seed=read_seed, + read_order_block_size=read_order_block_size, + pin_memory=pin_memory, ) _print_results(results, config_desc) finally: @@ -630,5 +1211,102 @@ def main( shutil.rmtree(store_dir, ignore_errors=True) +@main.command("read") +@click.argument("path", type=click.Path(exists=True, file_okay=False, path_type=Path)) +@click.option( + "--read-mode", + type=click.Choice(["batch", "single", "both"], case_sensitive=False), + multiple=True, + default=("batch",), + show_default=True, + help=( + "Readback path to benchmark. 'batch' uses DataLoader fused prefetch; " + "'single' uses reader.read per sample; repeat to control order." 
+ ), +) +@click.option( + "--batch-size", + type=click.IntRange(min=1), + default=DEFAULT_BATCH_SIZE, + show_default=True, + help="Number of samples per emitted DataLoader batch for --read-mode=batch.", +) +@click.option( + "--prefetch-factor", + type=click.IntRange(min=0), + default=DEFAULT_PREFETCH_FACTOR, + show_default=True, + help="Number of DataLoader batches to fuse into each backend read.", +) +@click.option( + "--read-order", + type=click.Choice(["sequential", "shuffle", "block-shuffle"], case_sensitive=False), + default="sequential", + show_default=True, + help=( + "Logical sample order used for readback. 'shuffle' models full random " + "dataloader reads; 'block-shuffle' shuffles contiguous index blocks." + ), +) +@click.option( + "--read-seed", + type=int, + default=0, + show_default=True, + help="Random seed for --read-order=shuffle and --read-order=block-shuffle.", +) +@click.option( + "--read-order-block-size", + type=click.IntRange(min=1), + default=DEFAULT_READ_ORDER_BLOCK_SIZE, + show_default=True, + help="Contiguous block size for --read-order=block-shuffle.", +) +@click.option( + "--pin-memory/--no-pin-memory", + default=False, + show_default=True, + help="Request pinned CPU tensors from readers that support it.", +) +def read_cmd( + path: Path, + read_mode: tuple[str, ...], + batch_size: int, + prefetch_factor: int, + read_order: str, + read_seed: int, + read_order_block_size: int, + pin_memory: bool, +) -> None: + """Benchmark read throughput against an existing Zarr store. + + Reads all samples from PATH using the specified access pattern and + reports timing and throughput. Useful for profiling read performance + in isolation, or comparing sequential vs. shuffled access. 
+ """ + read_modes = _expand_read_modes(read_mode) + read_order_typed = cast(ReadOrder, read_order.lower()) + + console.print( + f"[bold]nvalchemi Zarr read benchmark[/bold] " + f"store={path} read={', '.join(read_modes)} " + f"order={read_order_typed} batch={batch_size:,} " + f"prefetch={prefetch_factor:,} " + f"read_window={batch_size * max(prefetch_factor, 1):,}" + ) + + results = _run_read_benchmark( + store_path=path, + read_modes=read_modes, + batch_size=batch_size, + prefetch_factor=prefetch_factor, + read_order=read_order_typed, + read_seed=read_seed, + read_order_block_size=read_order_block_size, + pin_memory=pin_memory, + ) + _print_read_results(results) + + if __name__ == "__main__": main() diff --git a/nvalchemi/data/level_storage.py b/nvalchemi/data/level_storage.py index 302c3617..6428572b 100644 --- a/nvalchemi/data/level_storage.py +++ b/nvalchemi/data/level_storage.py @@ -826,13 +826,17 @@ def clone(self) -> BaseLevelStorage: validate=False, ) - def to_device(self, device: DeviceType) -> BaseLevelStorage: + def to_device( + self, device: DeviceType, *, non_blocking: bool = False + ) -> BaseLevelStorage: """Move all tensors to *device* (in-place). Parameters ---------- device : DeviceType Target device. + non_blocking : bool, default False + Whether tensor copies may be asynchronous when supported. 
Returns ------- @@ -841,7 +845,7 @@ def to_device(self, device: DeviceType) -> BaseLevelStorage: """ device = torch.device(device) self.device = device - self._data = self._data.to(device) + self._data = self._data.to(device, non_blocking=non_blocking) return self @@ -1226,9 +1230,11 @@ def defrag( def is_segmented(self) -> bool: return False - def to_device(self, device: DeviceType) -> UniformLevelStorage: + def to_device( + self, device: DeviceType, *, non_blocking: bool = False + ) -> UniformLevelStorage: """Move all tensors to *device* (in-place).""" - super().to_device(device) + super().to_device(device, non_blocking=non_blocking) return self def clone(self) -> UniformLevelStorage: @@ -1619,7 +1625,9 @@ def update_at(self, key: str, value: Any, idx: IndexType) -> None: # -- Device / copy ------------------------------------------------------ - def to_device(self, device: DeviceType) -> SegmentedLevelStorage: + def to_device( + self, device: DeviceType, *, non_blocking: bool = False + ) -> SegmentedLevelStorage: """Move all tensors (including bookkeeping) to *device*. Returns @@ -1627,14 +1635,18 @@ def to_device(self, device: DeviceType) -> SegmentedLevelStorage: Self For method chaining. 
""" - super().to_device(device) - self.segment_lengths = self.segment_lengths.to(device) + super().to_device(device, non_blocking=non_blocking) + self.segment_lengths = self.segment_lengths.to( + device, non_blocking=non_blocking + ) if self._batch_idx is not None: - self._batch_idx = self._batch_idx.to(device) + self._batch_idx = self._batch_idx.to(device, non_blocking=non_blocking) if self._batch_ptr is not None: - self._batch_ptr = self._batch_ptr.to(device) + self._batch_ptr = self._batch_ptr.to(device, non_blocking=non_blocking) if self._segment_indices is not None: - self._segment_indices = self._segment_indices.to(device) + self._segment_indices = self._segment_indices.to( + device, non_blocking=non_blocking + ) self._batch_ptr_np = None return self @@ -2190,7 +2202,9 @@ def is_segmented(self) -> bool: """Return ``True`` if any group is segmented.""" return any(g.is_segmented() for g in self.groups.values()) - def to_device(self, device: DeviceType) -> MultiLevelStorage: + def to_device( + self, device: DeviceType, *, non_blocking: bool = False + ) -> MultiLevelStorage: """Move all groups to *device*. Returns @@ -2201,7 +2215,7 @@ def to_device(self, device: DeviceType) -> MultiLevelStorage: device = torch.device(device) self.device = device for group in self.groups.values(): - group.to_device(device) + group.to_device(device, non_blocking=non_blocking) return self def clone(self) -> MultiLevelStorage: diff --git a/nvalchemi/distributed.py b/nvalchemi/distributed.py new file mode 100644 index 00000000..50bafc73 --- /dev/null +++ b/nvalchemi/distributed.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Recommended distributed runtime manager for nvalchemi workflows.""" + +from __future__ import annotations + +import os + +from physicsnemo.distributed import ( + DistributedManager, + PhysicsNeMoUninitializedDistributedManagerWarning, +) +from torch import distributed as dist + +__all__ = [ + "DistributedManager", + "PhysicsNeMoUninitializedDistributedManagerWarning", + "resolve_global_rank", + "resolve_world_size", +] + + +def resolve_world_size() -> int: + """Resolve world size from PhysicsNeMo, torch.distributed, or environment.""" + if DistributedManager.is_initialized(): + return int(DistributedManager().world_size) + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + return int(os.getenv("WORLD_SIZE", 1)) + + +def resolve_global_rank(global_rank: int | None = None) -> int: + """Resolve global rank from an explicit value, distributed state, or env.""" + if global_rank is not None: + return int(global_rank) + if DistributedManager.is_initialized(): + return int(DistributedManager().rank) + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() + return int(os.getenv("RANK", 0)) diff --git a/nvalchemi/dynamics/base.py b/nvalchemi/dynamics/base.py index 690ee2cc..7beaeb81 100644 --- a/nvalchemi/dynamics/base.py +++ b/nvalchemi/dynamics/base.py @@ -1475,7 +1475,7 @@ def _close_hooks(self) -> None: For hooks that support the context-manager protocol, calls ``__exit__(None, None, None)``. For hooks that only expose a - ``close()`` method (e.g. 
``ProfilerHook``), calls ``close()`` + ``close()`` method (e.g. ``TorchProfilerHook``), calls ``close()`` directly. A ``seen`` set prevents double-closing hooks. Called automatically at the end of :meth:`run`. diff --git a/nvalchemi/dynamics/hooks/__init__.py b/nvalchemi/dynamics/hooks/__init__.py index 1e1df3dc..1a05f9fa 100644 --- a/nvalchemi/dynamics/hooks/__init__.py +++ b/nvalchemi/dynamics/hooks/__init__.py @@ -38,8 +38,8 @@ - Freeze selected atoms by category during dynamics. * - :mod:`cell_align` - Align periodic cells to upper-triangular form for variable-cell optimization. - * - :mod:`profiling` - - Performance profiling and step timing. + * - :mod:`nvalchemi.hooks.physicsnemo_profiling` + - PyTorch profiler trace capture through PhysicsNeMo. All hooks implement the :class:`~nvalchemi.hooks.Hook` protocol and accept a :class:`~nvalchemi.hooks.DynamicsContext` plus a stage enum in their @@ -52,9 +52,10 @@ from nvalchemi.dynamics.hooks.freeze import FreezeAtomsHook from nvalchemi.dynamics.hooks.logging import LoggingHook from nvalchemi.dynamics.hooks.monitors import EnergyDriftMonitorHook -from nvalchemi.dynamics.hooks.profiling import ProfilerHook from nvalchemi.dynamics.hooks.safety import MaxForceClampHook, NaNDetectorHook from nvalchemi.dynamics.hooks.snapshot import ConvergedSnapshotHook, SnapshotHook +from nvalchemi.hooks.physicsnemo_profiling import TorchProfilerHook +from nvalchemi.hooks.stage_timing import StageTimingHook __all__ = [ "AlignCellHook", @@ -64,6 +65,19 @@ "LoggingHook", "MaxForceClampHook", "NaNDetectorHook", - "ProfilerHook", "SnapshotHook", + "StageTimingHook", + "TorchProfilerHook", ] + +_REMOVED_PROFILER_HOOKS = {"ProfilerHook"} + + +def __getattr__(name: str) -> object: + """Raise a targeted import error for removed profiler hook names.""" + if name in _REMOVED_PROFILER_HOOKS: + raise ImportError( + f"nvalchemi.dynamics.hooks.{name} was removed. 
" + "Use nvalchemi.dynamics.hooks.TorchProfilerHook for PyTorch traces or nvalchemi.dynamics.hooks.StageTimingHook for per-stage timing instead." + ) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/nvalchemi/dynamics/hooks/profiling.py b/nvalchemi/dynamics/hooks/profiling.py index 49a81bab..8c83f0c3 100644 --- a/nvalchemi/dynamics/hooks/profiling.py +++ b/nvalchemi/dynamics/hooks/profiling.py @@ -12,455 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Per-stage wall-clock profiling for dynamics workflows. - -Provides :class:`ProfilerHook`, a single hook that registers at multiple -stages and records the elapsed time between consecutive stages at each -step. Supports NVTX range annotations for Nsight Systems, CSV logging, -and formatted console output via ``loguru``. - -The hook supports dynamics and custom workflows via plum dispatch, -automatically detecting the stage type and annotating NVTX ranges with -the appropriate domain (``dynamics`` or ``custom``). -""" +"""Compatibility shim for removed dynamics timing profiler hooks.""" from __future__ import annotations -import csv -import io -import statistics -import time -from enum import Enum -from pathlib import Path -from typing import Literal - -import torch -from loguru import logger -from plum import dispatch - -from nvalchemi.data import Batch -from nvalchemi.dynamics.base import DynamicsStage -from nvalchemi.hooks._context import DynamicsContext - -try: - import nvtx -except ImportError: - nvtx = None - -__all__ = ["ProfilerHook"] - - -def _sort_stages(stages: set[Enum]) -> list[Enum]: - """Sort stage enum members by their integer value.""" - return sorted(stages, key=lambda s: s.value) - - -class ProfilerHook: - """Per-stage timing hook for dynamics workflows. 
- - A single ``ProfilerHook`` instance registers itself at every - requested stage. On each call it records a timestamp; when the - last profiled stage in a step fires, it computes the elapsed time - between consecutive stages and (optionally) writes to CSV / console. - - The hook uses ``stages`` (plural) so that - :meth:`~nvalchemi.dynamics.base.BaseDynamics.register_hook` - registers it at all listed stages in one call. - - The hook supports :class:`DynamicsStage` and custom enum types - via plum dispatch, automatically annotating NVTX ranges with the - appropriate domain (``dynamics`` or ``custom``). - - Parameters - ---------- - profiled_stages : set[Enum] | {"all", "step", "detailed"} - Which stages to instrument. - - * ``"all"`` (default): every :class:`DynamicsStage` except ``ON_CONVERGE``. - * ``"step"``: ``BEFORE_STEP`` and ``AFTER_STEP`` only. - * ``"detailed"``: all stages from ``BEFORE_STEP`` through - ``AFTER_STEP`` (excluding ``ON_CONVERGE``). - * A custom ``set[Enum]`` for fine-grained control. - frequency : int, optional - Profile every ``frequency`` steps. Default ``1``. - enable_nvtx : bool, optional - Emit NVTX push/pop ranges for Nsight Systems. Default ``True``. - timer_backend : {"cuda_event", "perf_counter", "auto"}, optional - Timing backend. ``"auto"`` selects ``cuda_event`` on GPU - devices and ``perf_counter`` on CPU. Default ``"auto"``. - log_path : str | Path | None, optional - Path to a CSV file for persistent timing logs. Each row - records the rank, step, stage transition, wall-clock offset, - and delta. Default ``None`` (no file). - show_console : bool, optional - Print a formatted timing table via ``loguru`` at each - profiled step. Default ``False``. - console_frequency : int, optional - When ``show_console`` is ``True``, print every - ``console_frequency`` profiled steps. Default ``1``. - - Attributes - ---------- - _profiled_stages : list[Enum] - Profiled stages in execution order (private). 
- frequency : int - Execution frequency in steps. - timings : dict[Enum, list[float]] - Accumulated per-transition timing data (seconds). - - Examples - -------- - >>> from nvalchemi.dynamics.hooks import ProfilerHook - >>> profiler = ProfilerHook() - >>> dynamics = DemoDynamics(model=model, n_steps=100, dt=0.5, hooks=[profiler]) - >>> dynamics.run(batch) - >>> print(profiler.summary()) - - With CSV logging and console output: - - >>> profiler = ProfilerHook( - ... "detailed", - ... log_path="profiler.csv", - ... show_console=True, - ... console_frequency=10, - ... ) - >>> dynamics = DemoDynamics(model=model, n_steps=1000, dt=0.5, hooks=[profiler]) - >>> dynamics.run(batch) - """ - - def __init__( - self, - profiled_stages: set[Enum] | Literal["all", "step", "detailed"] = "all", - *, - frequency: int = 1, - enable_nvtx: bool = True, - timer_backend: Literal["cuda_event", "perf_counter", "auto"] = "auto", - log_path: str | Path | None = None, - show_console: bool = False, - console_frequency: int = 1, - stage: Enum = DynamicsStage.BEFORE_STEP, - ) -> None: - # Init file handle early so __del__ is safe on validation errors. - self._csv_file: io.TextIOWrapper | None = None - self._csv_writer: csv.DictWriter | None = None - - if isinstance(profiled_stages, str): - if profiled_stages == "all": - resolved = {s for s in DynamicsStage if s != DynamicsStage.ON_CONVERGE} - elif profiled_stages == "step": - resolved = {DynamicsStage.BEFORE_STEP, DynamicsStage.AFTER_STEP} - elif profiled_stages == "detailed": - resolved = { - DynamicsStage.BEFORE_STEP, - DynamicsStage.BEFORE_PRE_UPDATE, - DynamicsStage.AFTER_PRE_UPDATE, - DynamicsStage.BEFORE_COMPUTE, - DynamicsStage.AFTER_COMPUTE, - DynamicsStage.BEFORE_POST_UPDATE, - DynamicsStage.AFTER_POST_UPDATE, - DynamicsStage.AFTER_STEP, - } - else: - raise ValueError( - f"Unknown stages preset {profiled_stages!r}. " - f"Use 'all', 'step', 'detailed', or a set of Enum." 
- ) - else: - resolved = set(profiled_stages) - - if len(resolved) < 2: - raise ValueError( - "At least two stages are required to measure timing deltas." - ) - - # Primary stage for protocol compliance - self.stage = stage - # Sorted by execution order — private profiled stages list. - self._profiled_stages: list[Enum] = _sort_stages(resolved) - self.frequency = frequency - self.enable_nvtx = enable_nvtx - self.timer_backend = timer_backend - self.log_path = Path(log_path) if log_path is not None else None - self.show_console = show_console - self.console_frequency = console_frequency - - # Per-step scratch — separate dicts for type safety. - self._current_step: int = -1 - self._step_cuda_events: dict[Enum, torch.cuda.Event] = {} - self._step_cpu_timestamps: dict[Enum, int] = {} - - # Accumulated timing: transition endpoint -> list of delta_s. - self.timings: dict[Enum, list[float]] = {s: [] for s in self._profiled_stages} - - self._t0_ns: int = time.perf_counter_ns() - self._backend_resolved: str | None = None - self._steps_recorded: int = 0 - - # ------------------------------------------------------------------ - # Hook entry point - # ------------------------------------------------------------------ - - def _runs_on_stage(self, stage: Enum) -> bool: - """Check if this hook should run on the given stage. +from nvalchemi.hooks.stage_timing import StageTimingHook - Parameters - ---------- - stage : Enum - The stage to check. +__all__ = ["StageTimingHook"] - Returns - ------- - bool - True if this hook runs on the given stage. - """ - return stage in set(self._profiled_stages) +_REMOVED_HOOKS = {"ProfilerHook"} - @torch.compiler.disable - def _record( - self, - batch: Batch, - current_stage: Enum, - step_count: int, - global_rank: int, - domain: str = "dynamics", - ) -> None: - """Record a timestamp for the current stage. - Parameters - ---------- - batch : Batch - The current batch of atomic data. 
- current_stage : Enum - The current dynamics stage being executed. - step_count : int - The current step number. - global_rank : int - The distributed rank of this process. - domain : str, optional - The domain for NVTX annotation (e.g., "dynamics", "custom"). - Default ``"dynamics"``. - """ - # New step: flush the previous one, then reset scratch. - if step_count != self._current_step: - if self._current_step >= 0: - self._flush_step(global_rank) - self._current_step = step_count - self._step_cuda_events.clear() - self._step_cpu_timestamps.clear() - - # NVTX annotation. - if self.enable_nvtx and nvtx is not None: - idx = self._profiled_stages.index(current_stage) - if idx > 0: - nvtx.pop_range() - nvtx.push_range(f"{domain}/{current_stage.name}/{step_count}") - - # Timestamp. - dev = batch.device - if isinstance(dev, str): - dev = torch.device(dev) - if self._backend_resolved is None: - self._backend_resolved = self._resolve_backend(dev) - if self._backend_resolved == "cuda_event": - event = torch.cuda.Event(enable_timing=True) - event.record() - self._step_cuda_events[current_stage] = event - else: - self._step_cpu_timestamps[current_stage] = time.perf_counter_ns() - - # If this is the last profiled stage in the step, flush now. - if current_stage == self._profiled_stages[-1]: - self._flush_step(global_rank) - self._current_step = -1 - self._step_cuda_events.clear() - self._step_cpu_timestamps.clear() - - @dispatch - def __call__(self, ctx: DynamicsContext, stage: DynamicsStage) -> None: # noqa: F811 - """Record timing for a dynamics stage.""" - self._record( - ctx.batch, stage, ctx.step_count, ctx.global_rank or 0, domain="dynamics" +def __getattr__(name: str) -> object: + """Raise a targeted import error for removed profiler hook names.""" + if name in _REMOVED_HOOKS: + raise ImportError( + f"nvalchemi.dynamics.hooks.profiling.{name} was removed. 
" + "Use nvalchemi.dynamics.hooks.TorchProfilerHook for PyTorch traces or nvalchemi.dynamics.hooks.StageTimingHook for per-stage timing instead." ) - - @dispatch - def __call__(self, ctx: DynamicsContext, stage: Enum) -> None: # noqa: F811 - """Record timing for a generic stage.""" - self._record( - ctx.batch, stage, ctx.step_count, ctx.global_rank or 0, domain="custom" - ) - - # ------------------------------------------------------------------ - # Backend resolution - # ------------------------------------------------------------------ - - def _resolve_backend(self, device: torch.device) -> str: - """Resolve the timing backend based on configuration and device.""" - if self.timer_backend != "auto": - return self.timer_backend - if device.type == "cuda": - return "cuda_event" - return "perf_counter" - - # ------------------------------------------------------------------ - # Step flush — compute deltas, log - # ------------------------------------------------------------------ - - def _flush_step(self, rank: int) -> None: - """Compute per-transition deltas for the current step and log.""" - use_cuda = self._backend_resolved == "cuda_event" - - if use_cuda: - ordered = [s for s in self._profiled_stages if s in self._step_cuda_events] - else: - ordered = [ - s for s in self._profiled_stages if s in self._step_cpu_timestamps - ] - - if len(ordered) < 2: - return - - if use_cuda: - torch.cuda.synchronize() - - deltas: dict[Enum, float] = {} - for i in range(1, len(ordered)): - prev_stage, curr_stage = ordered[i - 1], ordered[i] - if use_cuda: - prev_ev = self._step_cuda_events[prev_stage] - curr_ev = self._step_cuda_events[curr_stage] - delta_s = prev_ev.elapsed_time(curr_ev) / 1000.0 - else: - prev_ts = self._step_cpu_timestamps[prev_stage] - curr_ts = self._step_cpu_timestamps[curr_stage] - delta_s = (curr_ts - prev_ts) / 1e9 - deltas[curr_stage] = delta_s - self.timings[curr_stage].append(delta_s) - - t_since_init_s = (time.perf_counter_ns() - self._t0_ns) / 1e9 - 
self._steps_recorded += 1 - - if self.log_path is not None: - self._write_csv(rank, self._current_step, t_since_init_s, ordered, deltas) - - if self.show_console and (self._steps_recorded % self.console_frequency == 0): - self._print_console( - rank, self._current_step, t_since_init_s, ordered, deltas - ) - - # Close NVTX range for the last stage in this step. - if self.enable_nvtx and nvtx is not None: - nvtx.pop_range() - - # ------------------------------------------------------------------ - # CSV output - # ------------------------------------------------------------------ - - def _write_csv( - self, - rank: int, - step: int, - t_since_init: float, - ordered: list[Enum], - deltas: dict[Enum, float], - ) -> None: - """Append one row per transition to the CSV log.""" - rows = [] - for i, stage in enumerate(ordered[1:], start=1): - rows.append( - { - "rank": rank, - "step": step, - "stage": f"{ordered[i - 1].name}->{stage.name}", - "t_since_init_s": f"{t_since_init:.6f}", - "delta_s": f"{deltas[stage]:.6f}", - } - ) - if self._csv_writer is None: - log_path = self.log_path - if log_path is None: - return - fh = open(log_path, "w", newline="") # noqa: SIM115 - self._csv_file = fh - self._csv_writer = csv.DictWriter( - fh, - fieldnames=["rank", "step", "stage", "t_since_init_s", "delta_s"], - ) - self._csv_writer.writeheader() - self._csv_writer.writerows(rows) - if self._csv_file is not None: - self._csv_file.flush() - - # ------------------------------------------------------------------ - # Console output - # ------------------------------------------------------------------ - - def _print_console( - self, - rank: int, - step: int, - t_since_init: float, - ordered: list[Enum], - deltas: dict[Enum, float], - ) -> None: - """Print a formatted timing table for the current step.""" - lines = [f"[Profiler] rank={rank} step={step} t={t_since_init:.3f}s"] - for i, stage in enumerate(ordered[1:], start=1): - prev_name = ordered[i - 1].name - lines.append( - f" 
{prev_name} -> {stage.name}: {deltas[stage] * 1000:.3f} ms" - ) - logger.info("\n".join(lines)) - - # ------------------------------------------------------------------ - # Summary / reset / close - # ------------------------------------------------------------------ - - def summary(self) -> dict[str, dict[str, float]]: - """Return per-transition timing statistics. - - Returns - ------- - dict[str, dict[str, float]] - Mapping from ``"PREV_STAGE->STAGE"`` label to a stats dict - with keys ``mean_s``, ``std_s``, ``min_s``, ``max_s``, - ``total_s``, ``n_samples``. - """ - result: dict[str, dict[str, float]] = {} - for idx, stage in enumerate(self._profiled_stages): - samples = self.timings[stage] - if not samples: - continue - prev_name = self._profiled_stages[idx - 1].name - label = f"{prev_name}->{stage.name}" - n = len(samples) - result[label] = { - "mean_s": statistics.mean(samples), - "std_s": statistics.stdev(samples) if n > 1 else 0.0, - "min_s": min(samples), - "max_s": max(samples), - "total_s": sum(samples), - "n_samples": float(n), - } - return result - - def reset(self) -> None: - """Clear all accumulated timing data.""" - for stage in self.timings: - self.timings[stage].clear() - self._step_cuda_events.clear() - self._step_cpu_timestamps.clear() - self._current_step = -1 - self._backend_resolved = None - self._t0_ns = time.perf_counter_ns() - self._steps_recorded = 0 - - def close(self) -> None: - """Flush and close the CSV log file, if open.""" - if self._csv_file is not None: - self._csv_file.close() - self._csv_file = None - self._csv_writer = None - - def __del__(self) -> None: - self.close() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/nvalchemi/hooks/__init__.py b/nvalchemi/hooks/__init__.py index 9e47f1db..0bad9ea8 100644 --- a/nvalchemi/hooks/__init__.py +++ b/nvalchemi/hooks/__init__.py @@ -17,19 +17,64 @@ from __future__ import annotations from nvalchemi.hooks._context import DynamicsContext, HookContext, 
TrainContext -from nvalchemi.hooks._protocol import Hook +from nvalchemi.hooks._protocol import CheckpointableHook, Hook from nvalchemi.hooks._registry import HookRegistryMixin from nvalchemi.hooks.bias import BiasedPotentialHook from nvalchemi.hooks.neighbor_list import NeighborListHook from nvalchemi.hooks.periodic import WrapPeriodicHook +from nvalchemi.hooks.reporting import ( + BaseRichLayout, + DynamicsRichLayout, + Reporter, + ReporterMessage, + ReportingErrorPolicy, + ReportingOrchestrator, + ReportingState, + RichLayout, + RichReporter, + ScalarCallback, + ScalarSnapshot, + TensorBoardReporter, + TensorBoardWriter, + TrainingRichLayout, + collect_scalars, + extract_dynamics_scalars, + extract_loss_scalars, + extract_optimizer_lr_scalars, + extract_scalars, +) +from nvalchemi.hooks.physicsnemo_profiling import TorchProfilerHook +from nvalchemi.hooks.stage_timing import StageTimingHook __all__ = [ + "BaseRichLayout", "BiasedPotentialHook", + "CheckpointableHook", "DynamicsContext", + "DynamicsRichLayout", "Hook", "HookContext", "HookRegistryMixin", "NeighborListHook", + "Reporter", + "ReporterMessage", + "ReportingErrorPolicy", + "ReportingOrchestrator", + "ReportingState", + "RichLayout", + "RichReporter", + "ScalarCallback", + "ScalarSnapshot", + "TensorBoardReporter", + "TensorBoardWriter", + "StageTimingHook", + "TorchProfilerHook", "TrainContext", + "TrainingRichLayout", "WrapPeriodicHook", + "collect_scalars", + "extract_dynamics_scalars", + "extract_loss_scalars", + "extract_optimizer_lr_scalars", + "extract_scalars", ] diff --git a/nvalchemi/hooks/_context.py b/nvalchemi/hooks/_context.py index 4602473a..96e1a740 100644 --- a/nvalchemi/hooks/_context.py +++ b/nvalchemi/hooks/_context.py @@ -16,11 +16,12 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import torch from torch.nn import ModuleDict +from torch.optim.lr_scheduler import LRScheduler if 
TYPE_CHECKING: from nvalchemi.data.batch import Batch @@ -37,8 +38,9 @@ class HookContext: Attributes ---------- - batch : Batch - Current batch being processed. + batch : Batch | None + Current batch being processed. ``None`` is used for lifecycle stages + that run before the first batch is available. model : BaseModelMixin | None Model being used (if applicable). global_rank : int @@ -48,7 +50,7 @@ class HookContext: the workflow does not inject itself. """ - batch: Batch + batch: Batch | None model: BaseModelMixin | None = None global_rank: int = 0 workflow: Any = None @@ -79,6 +81,11 @@ class TrainContext(HookContext): ---------- step_count : int Current optimizer step number. + batch_count : int + Number of training batches consumed, including batches whose + optimizer step was skipped by update hooks. + epoch_step_count : int + Number of batches consumed within the current training epoch. epoch : int Current training epoch. loss : torch.Tensor | None @@ -92,19 +99,37 @@ class TrainContext(HookContext): key/model mapping should be semantic, e.g. 'student' and 'teacher' in distillation workflows, with 'student' being the intended 'main' model. - optimizers : list[torch.optim.Optimizer] | None - Optimizers participating in the training step. - lr_schedulers : list[object] | None + optimizers : list[torch.optim.Optimizer] + Optimizers participating in the training step. Empty when no + optimizer is attached (e.g. eval-only or manually-driven hook + contexts); ``TrainingUpdateOrchestrator`` and similar consumers + treat an empty list as a no-op. + lr_schedulers : list[torch.optim.lr_scheduler.LRScheduler | None] Learning rate schedulers participating in the training step. + Aligned positionally with ``optimizers`` when populated; entries + may be ``None`` when an optimizer has no scheduler. Empty when no + scheduler is attached. gradients : dict[str, torch.Tensor] | None Parameter gradients for the current step. 
+ grad_scaler : torch.amp.GradScaler | None + AMP gradient scaler for mixed-precision training; ``None`` when + AMP is not in use. + validation : dict[str, Any] | None + Latest validation summary produced by the training strategy's + validation checkpoint (``TrainingStrategy.validate()``). + ``None`` until validation has run, and on non-publishing + distributed ranks. """ step_count: int = 0 + batch_count: int = 0 + epoch_step_count: int = 0 epoch: int = 0 loss: torch.Tensor | None = None losses: dict[str, torch.Tensor] | None = None models: dict[str, BaseModelMixin] | ModuleDict | None = None - optimizers: list[torch.optim.Optimizer] | None = None - lr_schedulers: list[object] | None = None + optimizers: list[torch.optim.Optimizer] = field(default_factory=list) + lr_schedulers: list[LRScheduler | None] = field(default_factory=list) gradients: dict[str, torch.Tensor] | None = None + grad_scaler: torch.amp.GradScaler | None = None + validation: dict[str, Any] | None = None diff --git a/nvalchemi/hooks/_protocol.py b/nvalchemi/hooks/_protocol.py index c2bd7347..2b28dc04 100644 --- a/nvalchemi/hooks/_protocol.py +++ b/nvalchemi/hooks/_protocol.py @@ -16,8 +16,9 @@ from __future__ import annotations +from collections.abc import Mapping from enum import Enum -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: from nvalchemi.hooks._context import HookContext @@ -61,3 +62,23 @@ def __call__(self, ctx: HookContext, stage: Enum) -> None: The stage being dispatched. """ ... + + +@runtime_checkable +class CheckpointableHook(Protocol): + """Protocol for hooks that own restart-critical runtime state. + + Most hooks should remain stateless and omit this protocol. Hooks that + affect resumed training semantics can opt in by exposing ``state_dict`` + and ``load_state_dict``. 
Pydantic-backed hooks should use + ``model_dump()`` inside their ``state_dict`` implementation for + declarative fields and add only the extra runtime state they own. + """ + + def state_dict(self) -> Mapping[str, Any]: + """Return hook state to store with a training checkpoint.""" + ... + + def load_state_dict(self, state: Mapping[str, Any]) -> None: + """Restore hook state from a training checkpoint.""" + ... diff --git a/nvalchemi/hooks/_registry.py b/nvalchemi/hooks/_registry.py index 90af18f6..127f9c63 100644 --- a/nvalchemi/hooks/_registry.py +++ b/nvalchemi/hooks/_registry.py @@ -121,7 +121,7 @@ def register_hook(self, hook: Hook, stage: Enum | None = None) -> None: ) self.hooks.append(hook) - def _build_context(self, batch: Batch) -> HookContext: + def _build_context(self, batch: Batch | None) -> HookContext: """Build a base HookContext for the current state. Override in subclasses to return a workflow-specific context @@ -130,8 +130,8 @@ def _build_context(self, batch: Batch) -> HookContext: Parameters ---------- - batch : Batch - Current batch being processed. + batch : Batch | None + Current batch being processed, if available. Returns ------- @@ -149,7 +149,7 @@ def _build_context(self, batch: Batch) -> HookContext: workflow=self, ) - def _call_hooks(self, stage: Enum, batch: Batch) -> None: + def _call_hooks(self, stage: Enum, batch: Batch | None) -> None: """Call hooks registered for the given stage, gated by frequency. Hooks fire when ``self.step_count % hook.frequency == 0``. @@ -161,8 +161,8 @@ def _call_hooks(self, stage: Enum, batch: Batch) -> None: ---------- stage : Enum Current workflow stage. - batch : Batch - Current batch being processed. + batch : Batch | None + Current batch being processed, if available. 
""" ctx = self._build_context(batch) for hook in self.hooks: diff --git a/nvalchemi/hooks/physicsnemo_profiling.py b/nvalchemi/hooks/physicsnemo_profiling.py new file mode 100644 index 00000000..7a23e43b --- /dev/null +++ b/nvalchemi/hooks/physicsnemo_profiling.py @@ -0,0 +1,331 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PhysicsNeMo-backed PyTorch profiler hook.""" + +from __future__ import annotations + +from collections.abc import Callable +from enum import Enum +from pathlib import Path +from typing import Annotated, Any, ClassVar + +from physicsnemo.utils.profiling import ( + Profiler, + TorchProfilerConfig, + TorchProfileWrapper, +) +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator +from torch.profiler import ProfilerActivity + +from nvalchemi.distributed import ( + DistributedManager, + resolve_global_rank, + resolve_world_size, +) +from nvalchemi.hooks._context import HookContext + +__all__ = ["TorchProfilerHook"] + + +def _parse_activity(activity: ProfilerActivity | str) -> ProfilerActivity: + """Normalize a profiler activity enum or string alias.""" + if isinstance(activity, ProfilerActivity): + return activity + normalized = activity.lower() + match normalized: + case "cpu": + return ProfilerActivity.CPU + case "cuda": + return ProfilerActivity.CUDA + case _: + raise ValueError( + 
f"Unknown profiler activity {activity!r}; expected 'cpu' or 'cuda'." + ) + + +class TorchProfilerHook(BaseModel): + """Capture PyTorch profiler traces through PhysicsNeMo's profiler wrapper. + + The hook supports both training and dynamics workflows. It starts the + PhysicsNeMo profiler when entering its context, or lazily at the first + supported stage if called without a context manager. It advances the PyTorch + profiler schedule at each batch or dynamics step and finalizes traces at the + end of training or when the hook context closes. + + Parameters + ---------- + output_dir : str | Path + Root directory for profiler outputs. + activities : tuple[ProfilerActivity | str, ...] | None, optional + PyTorch profiler activities. ``None`` lets PhysicsNeMo choose CPU and + CUDA when CUDA is available. Strings may be ``"cpu"`` or ``"cuda"``. + schedule : Callable | None, optional + PyTorch profiler schedule created by :func:`torch.profiler.schedule`. + record_shapes : bool, optional + Whether to record tensor shapes. + profile_memory : bool, optional + Whether to profile memory allocations. + with_flops : bool, optional + Whether to estimate FLOPs for supported operations. + with_stack : bool, optional + Whether to record Python stack traces. + on_trace_ready_path : str | Path | None, optional + Directory passed to PyTorch's tensorboard trace handler. When provided, + it is rank-suffixed because those traces bypass PhysicsNeMo's final + ``trace.json`` export. + frequency : int, optional + Hook dispatch frequency. Keep the default ``1`` unless you explicitly + want the profiler schedule to advance less often. + rank_subdirs : bool, optional + Whether to place nvalchemi-managed outputs under ``rank_``. + Enabled by default for a consistent single- and multi-process layout. + + Attributes + ---------- + stage : Enum | None + ``None`` because this hook dispatches across training and dynamics + stages through :meth:`_runs_on_stage`. 
+ frequency : int + Hook dispatch cadence. + """ + + output_dir: Annotated[ + Path, + Field(description="Root directory for PhysicsNeMo profiler outputs."), + ] + activities: Annotated[ + tuple[ProfilerActivity, ...] | None, + Field( + default=None, + description=( + "PyTorch profiler activities, or None to let PhysicsNeMo " + "choose CPU and CUDA when available." + ), + ), + ] = None + schedule: Annotated[ + Callable[..., Any] | None, + Field(default=None, description="Optional torch.profiler schedule."), + ] = None + record_shapes: Annotated[ + bool, Field(description="Record input tensor shapes in the trace.") + ] = True + profile_memory: Annotated[ + bool, Field(description="Profile memory allocations.") + ] = True + with_flops: Annotated[ + bool, Field(description="Estimate FLOPs for supported operations.") + ] = True + with_stack: Annotated[bool, Field(description="Record Python stack traces.")] = ( + False + ) + on_trace_ready_path: Annotated[ + Path | None, + Field( + default=None, + description="Optional path for PyTorch tensorboard trace handler output.", + ), + ] = None + frequency: Annotated[ + int, + Field( + default=1, + ge=1, + description="Run every N workflow steps.", + ), + ] = 1 + name: Annotated[ + str, + Field(default="torch", description="PhysicsNeMo profiler output name."), + ] = "torch" + rank_subdirs: Annotated[ + bool, + Field( + default=True, + description="Write nvalchemi-managed outputs under rank_.", + ), + ] = True + + stage: ClassVar[Enum | None] = None + + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_assignment=False, + extra="forbid", + ) + + _profiler: Any | None = PrivateAttr(default=None) + _torch_profiler: Any | None = PrivateAttr(default=None) + _started: bool = PrivateAttr(default=False) + _closed: bool = PrivateAttr(default=False) + _entered_context: bool = PrivateAttr(default=False) + + @field_validator("activities", mode="before") + @classmethod + def _normalize_activities(cls, value: Any) -> 
tuple[ProfilerActivity, ...] | None: + """Normalize activity aliases before pydantic validation.""" + if value is None: + return None + if isinstance(value, (str, ProfilerActivity)): + raw_values = (value,) + else: + raw_values = tuple(value) + return tuple(_parse_activity(activity) for activity in raw_values) + + def __enter__(self) -> TorchProfilerHook: + """Enter the hook context and start profiling.""" + self._start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any, + ) -> None: + """Finalize profiler output when a workflow context exits.""" + self.close() + + def _runs_on_stage(self, stage: Enum) -> bool: + """Return whether this hook handles ``stage``. + + Parameters + ---------- + stage : Enum + Workflow stage enum value. + + Returns + ------- + bool + ``True`` for supported training and dynamics stages. + """ + from nvalchemi.dynamics.base import DynamicsStage + from nvalchemi.training._stages import TrainingStage + + match stage: + case ( + TrainingStage.BEFORE_TRAINING + | TrainingStage.BEFORE_BATCH + | TrainingStage.AFTER_BATCH + | TrainingStage.AFTER_TRAINING + | DynamicsStage.BEFORE_STEP + | DynamicsStage.AFTER_STEP + ): + return True + case _: + return False + + def __call__(self, ctx: HookContext, stage: Enum) -> None: + """Handle a supported training or dynamics stage. + + Parameters + ---------- + ctx : HookContext + Workflow context containing rank and workflow metadata. + stage : Enum + Current workflow stage. 
+ """ + from nvalchemi.dynamics.base import DynamicsStage + from nvalchemi.training._stages import TrainingStage + + match stage: + case TrainingStage.BEFORE_TRAINING | DynamicsStage.BEFORE_STEP: + self._start(ctx) + case TrainingStage.BEFORE_BATCH if not self._started: + self._start(ctx) + case TrainingStage.AFTER_BATCH | DynamicsStage.AFTER_STEP: + if not self._started: + self._start(ctx) + if self._profiler is not None: + self._profiler.step() + case TrainingStage.AFTER_TRAINING: + self.close() + case _: + return + + def _start(self, ctx: HookContext | None = None) -> None: + """Start the PhysicsNeMo profiler.""" + if self._started: + return + if self._closed: + raise RuntimeError( + "TorchProfilerHook cannot be restarted after it has finalized." + ) + + profiler = Profiler() + if getattr(profiler, "initialized", False) or getattr( + profiler, "enabled", False + ): + raise RuntimeError( + "PhysicsNeMo Profiler is already initialized or enabled. " + "Create and register TorchProfilerHook before other " + "PhysicsNeMo profiler configuration, or finalize the existing " + "profiler before starting this hook." 
+ ) + + rank = resolve_global_rank(None if ctx is None else ctx.global_rank) + output_path = self._resolve_output_path(rank) + trace_path = self._resolve_trace_path(rank) + output_path.mkdir(parents=True, exist_ok=True) + if trace_path is not None: + trace_path.mkdir(parents=True, exist_ok=True) + + config = TorchProfilerConfig( + name=self.name, + torch_prof_activities=self.activities, + record_shapes=self.record_shapes, + with_stack=self.with_stack, + profile_memory=self.profile_memory, + with_flops=self.with_flops, + schedule=self.schedule, + on_trace_ready_path=trace_path, + ) + torch_profiler = TorchProfileWrapper(config) + enabled_torch_profiler = profiler.enable("torch") + profiler.output_path = output_path + profiler.__enter__() + + self._profiler = profiler + self._torch_profiler = enabled_torch_profiler or torch_profiler + self._started = True + self._entered_context = True + + def _resolve_output_path(self, rank: int) -> Path: + """Return the PhysicsNeMo output path for this process.""" + output_dir = self.output_dir + if DistributedManager.is_initialized() and not DistributedManager().distributed: + return output_dir + if self.rank_subdirs or resolve_world_size() > 1: + return output_dir / f"rank_{rank}" + return output_dir + + def _resolve_trace_path(self, rank: int) -> Path | None: + """Return the rank-specific tensorboard trace path, if configured.""" + if self.on_trace_ready_path is None: + return None + return self.on_trace_ready_path / f"rank_{rank}" + + def close(self) -> None: + """Finalize profiler outputs once.""" + if not self._started: + return + if self._profiler is None: + return + if self._entered_context: + self._profiler.__exit__(None, None, None) + self._entered_context = False + self._profiler.finalize() + self._started = False + self._closed = True diff --git a/nvalchemi/hooks/reporting/__init__.py b/nvalchemi/hooks/reporting/__init__.py new file mode 100644 index 00000000..f0b6b7f1 --- /dev/null +++ 
b/nvalchemi/hooks/reporting/__init__.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hook-native workflow reporting.""" + +from __future__ import annotations + +from nvalchemi.hooks.reporting._orchestrator import ( + DEFAULT_REPORT_STAGES, + ReportingErrorPolicy, + ReportingOrchestrator, +) +from nvalchemi.hooks.reporting._protocol import Reporter +from nvalchemi.hooks.reporting._rich import RichReporter +from nvalchemi.hooks.reporting._scalars import ( + ScalarCallback, + ScalarSnapshot, + collect_scalars, + extract_dynamics_scalars, + extract_loss_scalars, + extract_optimizer_lr_scalars, + extract_scalars, +) +from nvalchemi.hooks.reporting._state import ReporterMessage, ReportingState +from nvalchemi.hooks.reporting._tensorboard import ( + TensorBoardReporter, + TensorBoardWriter, +) +from nvalchemi.hooks.reporting.layouts import ( + BaseRichLayout, + DynamicsRichLayout, + RichLayout, + TrainingRichLayout, +) + +__all__ = [ + "DEFAULT_REPORT_STAGES", + "BaseRichLayout", + "Reporter", + "ReporterMessage", + "ReportingErrorPolicy", + "ReportingOrchestrator", + "ReportingState", + "DynamicsRichLayout", + "RichLayout", + "RichReporter", + "ScalarCallback", + "ScalarSnapshot", + "TensorBoardReporter", + "TensorBoardWriter", + "TrainingRichLayout", + "collect_scalars", + "extract_dynamics_scalars", + 
"extract_loss_scalars", + "extract_optimizer_lr_scalars", + "extract_scalars", +] diff --git a/nvalchemi/hooks/reporting/_distributed.py b/nvalchemi/hooks/reporting/_distributed.py new file mode 100644 index 00000000..41c637e8 --- /dev/null +++ b/nvalchemi/hooks/reporting/_distributed.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Distributed reporting helpers.""" + +from __future__ import annotations + +from dataclasses import replace + +import torch +from torch import distributed as dist + +from nvalchemi.hooks.reporting._scalars import ScalarSnapshot + +_STRING_REDUCTIONS = { + "sum": (dist.ReduceOp.SUM, False), + "min": (dist.ReduceOp.MIN, False), + "max": (dist.ReduceOp.MAX, False), + "mean": (dist.ReduceOp.SUM, True), + "avg": (dist.ReduceOp.SUM, True), + "average": (dist.ReduceOp.SUM, True), +} + + +def reduce_scalar_snapshot( + snapshot: ScalarSnapshot, + reduction: dist.ReduceOp | str | None, + *, + reporter_name: str, +) -> ScalarSnapshot: + """Reduce snapshot scalar values across distributed ranks. + + Parameters + ---------- + snapshot : ScalarSnapshot + Local scalar snapshot. + reduction : torch.distributed.ReduceOp | str | None + Reduction operation to apply. ``None`` and ``"none"`` disable + reduction. 
``"mean"``, ``"avg"``, and ``"average"`` use + :data:`torch.distributed.ReduceOp.SUM` followed by explicit world-size + division. + reporter_name : str + Reporter name used in validation error messages. + + Returns + ------- + ScalarSnapshot + Snapshot with reduced scalar values. The original snapshot is returned + unchanged outside initialized distributed runs or when ``reduction`` is + ``None``. + + Raises + ------ + RuntimeError + If PhysicsNeMo's distributed manager is unavailable, uninitialized, or + selects an unavailable CUDA device. + ValueError + If ranks report different scalar keys. + """ + op, average = normalize_rank_reduction(reduction) + if op is None: + return snapshot + if not dist.is_available() or not dist.is_initialized(): + return snapshot + keys = tuple(sorted(snapshot.scalars)) + gathered_keys: list[tuple[str, ...]] = [() for _ in range(dist.get_world_size())] + dist.all_gather_object(gathered_keys, keys) + if any(rank_keys != keys for rank_keys in gathered_keys): + raise ValueError( + f"{reporter_name} rank reduction requires every rank to report " + "the same scalar keys." + ) + if not keys: + return replace(snapshot, scalars={}) + values = torch.tensor( + [snapshot.scalars[key] for key in keys], + device=_collective_device(), + dtype=torch.float64, + ) + dist.all_reduce(values, op=op) + if average: + values /= dist.get_world_size() + reduced_values = values.cpu().tolist() + reduced_scalars = { + key: float(value) for key, value in zip(keys, reduced_values, strict=True) + } + return replace(snapshot, scalars=reduced_scalars) + + +def normalize_rank_reduction( + reduction: dist.ReduceOp | str | None, +) -> tuple[dist.ReduceOp | None, bool]: + """Normalize user-facing rank reduction input to a PyTorch reduction op. + + Parameters + ---------- + reduction : torch.distributed.ReduceOp | str | None + Reduction configuration supplied by a reporter. 
+ + Returns + ------- + tuple[torch.distributed.ReduceOp | None, bool] + Normalized PyTorch reduction op plus whether to divide by world size + after the collective. + + Raises + ------ + ValueError + If a string reduction is not recognized. + TypeError + If ``reduction`` is not ``None``, a string, or a PyTorch + :class:`torch.distributed.ReduceOp`. + """ + if reduction is None: + return None, False + if isinstance(reduction, str): + key = reduction.lower() + if key == "none": + return None, False + try: + return _STRING_REDUCTIONS[key] + except KeyError as exc: + raise ValueError( + "rank_reduction must be None, a torch.distributed.ReduceOp, " + "or one of 'none', 'mean', 'avg', 'average', 'sum', 'min', " + "or 'max'." + ) from exc + if not isinstance(reduction, dist.ReduceOp): + raise TypeError( + "rank_reduction must be None, a string, or torch.distributed.ReduceOp." + ) + return reduction, False + + +def _collective_device() -> torch.device: + try: + from physicsnemo.distributed import DistributedManager + except ImportError as exc: + raise RuntimeError( + "Rank reduction requires physicsnemo.distributed.DistributedManager. " + "Install nvalchemi-toolkit with the PhysicsNeMo dependency set." + ) from exc + if not DistributedManager.is_initialized(): + raise RuntimeError( + "Rank reduction requires PhysicsNeMo DistributedManager to be " + "initialized before reporting. Call DistributedManager.initialize() " + "during distributed workflow setup." + ) + device = torch.device(DistributedManager().device) + if device.type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError( + "PhysicsNeMo DistributedManager selected a CUDA device, but CUDA " + "is not available for reporting rank reduction." 
+ ) + return device diff --git a/nvalchemi/hooks/reporting/_orchestrator.py b/nvalchemi/hooks/reporting/_orchestrator.py new file mode 100644 index 00000000..fcaec4fc --- /dev/null +++ b/nvalchemi/hooks/reporting/_orchestrator.py @@ -0,0 +1,364 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hook-native reporting orchestrator.""" + +from __future__ import annotations + +import warnings +from collections.abc import Sequence +from enum import Enum +from types import TracebackType + +from torch import distributed as dist + +from nvalchemi.hooks._context import HookContext +from nvalchemi.hooks.reporting._protocol import Reporter +from nvalchemi.hooks.reporting._state import ReportingState + +ReportingStage = Enum | str + +DEFAULT_REPORT_STAGES: frozenset[str] = frozenset( + {"AFTER_OPTIMIZER_STEP", "AFTER_STEP"} +) + + +class ReportingErrorPolicy(str, Enum): + """Policy used when an individual reporter raises. + + Attributes + ---------- + RAISE : ReportingErrorPolicy + Re-raise reporter exceptions. + WARN : ReportingErrorPolicy + Emit :class:`UserWarning` and continue to later reporters. + IGNORE : ReportingErrorPolicy + Record the error in :class:`ReportingState` and continue silently. + """ + + RAISE = "raise" + WARN = "warn" + IGNORE = "ignore" + + +class ReportingOrchestrator: + """Fan out hook contexts to reporting sinks. 
+ + ``ReportingOrchestrator`` is itself a normal hook. It uses + ``_runs_on_stage`` so it can be registered with both training and dynamics + hook registries while still choosing the workflow stages it observes. + + Parameters + ---------- + reporters : Sequence[Reporter] + Reporters to call in order for each reporting event. + frequency : int, optional + Run every ``frequency`` workflow steps, using the existing hook + registry gating. Default ``1``. + stages : set[Enum | str] | None, optional + Stages to report. Enum values are matched by identity; strings are + matched against enum member names. Defaults to + ``{"AFTER_OPTIMIZER_STEP", "AFTER_STEP"}``, which gives once-per-step + training and dynamics reporting without importing either workflow. + rank_zero_only : bool, optional + If ``True``, suppress child reporters on nonzero ranks unless they + expose ``requires_all_ranks=True`` for distributed collectives. + Individual reporters may also expose ``rank_zero_only=True`` to + request their own gating. Default ``False``. + error_policy : ReportingErrorPolicy | str, optional + Reporter failure handling policy. Default ``"raise"``. + state : ReportingState | None, optional + Shared reporting state. If omitted, a new state object is created. 
+ """ + + def __init__( + self, + reporters: Sequence[Reporter], + *, + frequency: int = 1, + stages: set[ReportingStage] | None = None, + rank_zero_only: bool = False, + error_policy: ReportingErrorPolicy | str = ReportingErrorPolicy.RAISE, + state: ReportingState | None = None, + ) -> None: + self.reporters = list(reporters) + self.frequency = frequency + self.stage: Enum | None = None + self.rank_zero_only = rank_zero_only + self.error_policy = ReportingErrorPolicy(error_policy) + self.state = state if state is not None else ReportingState() + self._stages = frozenset( + stages if stages is not None else DEFAULT_REPORT_STAGES + ) + self._context_depth = 0 + self._entered_reporters: list[Reporter] = [] + self._disabled_reporter_ids: set[int] = set() + self._closed = False + + @property + def global_rank(self) -> int: + """Return the current distributed rank, or zero outside distributed runs.""" + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() + return 0 + + @property + def is_rank_zero(self) -> bool: + """Return whether this process is rank zero.""" + return self.global_rank == 0 + + def _runs_on_stage(self, stage: Enum) -> bool: + """Return whether reporting should run for ``stage``. + + Parameters + ---------- + stage : Enum + Hook stage under consideration. + + Returns + ------- + bool + ``True`` when the orchestrator should receive this stage. + """ + return stage in self._stages or stage.name in self._stages + + def __call__(self, ctx: HookContext, stage: Enum) -> None: + """Dispatch one hook event to child reporters. + + Parameters + ---------- + ctx : HookContext + Workflow hook context. + stage : Enum + Hook stage being dispatched. 
+ """ + active_reporters = [ + reporter + for reporter in self.reporters + if id(reporter) not in self._disabled_reporter_ids + and not self._skip_reporter_for_rank(reporter) + ] + if not active_reporters: + return + self.state.mark_event(ctx, stage) + for reporter in active_reporters: + try: + reporter.report(ctx, stage, self.state) + except Exception as exc: + self._handle_reporter_error( + reporter, exc, ctx, stage, operation="report" + ) + + def __enter__(self) -> ReportingOrchestrator: + """Enter reporters that implement the context manager protocol.""" + if self._context_depth > 0: + self._context_depth += 1 + return self + self._closed = False + self._entered_reporters = [] + self._disabled_reporter_ids = set() + for reporter in self.reporters: + if self._skip_reporter_for_rank(reporter): + self._disabled_reporter_ids.add(id(reporter)) + continue + enter = getattr(reporter, "__enter__", None) + if enter is not None: + try: + enter() + except Exception as exc: + self._disabled_reporter_ids.add(id(reporter)) + self._handle_enter_error(reporter, exc) + else: + self._entered_reporters.append(reporter) + self._context_depth = 1 + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + """Exit reporters without replacing an active workflow exception.""" + if self._context_depth == 0: + return + self._context_depth -= 1 + if self._context_depth > 0: + return + self._finish_close(exc_type, exc, tb) + + def close(self) -> None: + """Close reporters in reverse order.""" + self._finish_close(None, None, None) + + def _reporter_rank_zero_only(self, reporter: Reporter) -> bool: + """Return whether ``reporter`` requests rank-zero-only dispatch.""" + return bool(getattr(reporter, "rank_zero_only", False)) + + def _reporter_requires_all_ranks(self, reporter: Reporter) -> bool: + """Return whether ``reporter`` must be dispatched on every rank.""" + return bool(getattr(reporter, 
"requires_all_ranks", False)) + + def _skip_reporter_for_rank(self, reporter: Reporter) -> bool: + """Return whether ``reporter`` should be skipped on this rank.""" + if self.is_rank_zero: + return False + if self._reporter_requires_all_ranks(reporter): + return False + return self.rank_zero_only or self._reporter_rank_zero_only(reporter) + + def _handle_enter_error(self, reporter: Reporter, exc: Exception) -> None: + """Handle a reporter ``__enter__`` failure.""" + if self.error_policy == ReportingErrorPolicy.RAISE: + try: + self._close_reporters( + list(self._entered_reporters), + type(exc), + exc, + exc.__traceback__, + preserve_workflow_exception=True, + ) + finally: + self._entered_reporters = [] + self._closed = True + self._handle_reporter_error(reporter, exc, operation="enter") + + def _finish_close( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + """Close reporters once and reset lifecycle state.""" + if self._closed: + self._context_depth = 0 + return + try: + self._close_reporters( + self.reporters, + exc_type, + exc, + tb, + preserve_workflow_exception=exc_type is not None, + ) + finally: + self._entered_reporters = [] + self._context_depth = 0 + self._closed = True + + def _close_reporters( + self, + reporters: Sequence[Reporter], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + *, + preserve_workflow_exception: bool = False, + ) -> None: + """Close reporters, preferring ``__exit__`` for entered reporters.""" + errors: list[tuple[str, Exception]] = [] + entered_ids = {id(reporter) for reporter in self._entered_reporters} + for reporter in reversed(reporters): + reporter_id = id(reporter) + if ( + reporter_id in self._disabled_reporter_ids + and reporter_id not in entered_ids + ): + continue + exit_fn = getattr(reporter, "__exit__", None) + close_fn = getattr(reporter, "close", None) + was_entered = reporter_id in entered_ids + if 
(not was_entered or exit_fn is None) and close_fn is None: + continue + try: + if was_entered and exit_fn is not None: + exit_fn(exc_type, exc, tb) + else: + close_fn() + except Exception as close_exc: + message = self._record_reporter_error( + reporter, + close_exc, + operation="close", + ) + errors.append((message, close_exc)) + self._apply_close_error_policy(errors, preserve_workflow_exception) + + def _apply_close_error_policy( + self, + errors: Sequence[tuple[str, Exception]], + preserve_workflow_exception: bool, + ) -> None: + """Apply failure policy after all close attempts have completed.""" + if not errors or self.error_policy == ReportingErrorPolicy.IGNORE: + return + if ( + self.error_policy == ReportingErrorPolicy.WARN + or preserve_workflow_exception + ): + for message, _ in errors: + warnings.warn(message, UserWarning, stacklevel=2) + return + raise errors[0][1] + + def _handle_reporter_error( + self, + reporter: Reporter, + exc: Exception, + ctx: HookContext | None = None, + stage: Enum | None = None, + *, + operation: str, + preserve_workflow_exception: bool = False, + ) -> None: + """Apply the configured reporter failure policy.""" + message = self._record_reporter_error( + reporter, + exc, + ctx=ctx, + stage=stage, + operation=operation, + ) + if self.error_policy == ReportingErrorPolicy.IGNORE: + return + if ( + self.error_policy == ReportingErrorPolicy.WARN + or preserve_workflow_exception + ): + warnings.warn(message, UserWarning, stacklevel=2) + return + raise exc + + def _record_reporter_error( + self, + reporter: Reporter, + exc: Exception, + ctx: HookContext | None = None, + stage: Enum | None = None, + *, + operation: str, + ) -> str: + """Record a reporter error message and return its text.""" + message = ( + f"{type(reporter).__name__} failed during {operation}: " + f"{type(exc).__name__}: {exc}" + ) + self.state.add_message( + "error", + message, + reporter=reporter, + ctx=ctx, + stage=stage, + ) + return message diff --git 
a/nvalchemi/hooks/reporting/_protocol.py b/nvalchemi/hooks/reporting/_protocol.py new file mode 100644 index 00000000..97ec0a68 --- /dev/null +++ b/nvalchemi/hooks/reporting/_protocol.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Reporter protocol definitions.""" + +from __future__ import annotations + +from enum import Enum +from typing import Protocol, runtime_checkable + +from nvalchemi.hooks._context import HookContext +from nvalchemi.hooks.reporting._state import ReportingState + + +@runtime_checkable +class Reporter(Protocol): + """Protocol for reporting sinks owned by ``ReportingOrchestrator``. + + Reporters consume the existing hook context directly. They should not + require the orchestrator to construct separate workflow event objects. + Reporters may optionally expose ``rank_zero_only: bool`` to request + per-reporter rank gating. Reporters that run distributed collectives must + expose ``requires_all_ranks: bool`` so orchestrator-level rank gating does + not skip nonzero ranks before a collective. Reporters may also implement + ``__enter__``, ``__exit__``, or ``close`` for resource lifecycle + management. + """ + + def report(self, ctx: HookContext, stage: Enum, state: ReportingState) -> None: + """Consume one reporting event. 
+ + Parameters + ---------- + ctx : HookContext + Workflow hook context. + stage : Enum + Hook stage being reported. + state : ReportingState + Shared reporting state for the orchestrator run. + """ + ... diff --git a/nvalchemi/hooks/reporting/_rich.py b/nvalchemi/hooks/reporting/_rich.py new file mode 100644 index 00000000..13e3d5c7 --- /dev/null +++ b/nvalchemi/hooks/reporting/_rich.py @@ -0,0 +1,489 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Rich live dashboard reporting sink.""" + +from __future__ import annotations + +from collections import deque +from collections.abc import Mapping, Sequence +from enum import Enum +from types import TracebackType + +from rich.console import Console +from rich.layout import Layout +from rich.live import Live +from torch import distributed as dist + +from nvalchemi.hooks._context import DynamicsContext, HookContext, TrainContext +from nvalchemi.hooks.reporting._distributed import ( + normalize_rank_reduction, + reduce_scalar_snapshot, +) +from nvalchemi.hooks.reporting._scalars import ( + ScalarCallback, + ScalarSnapshot, + collect_scalars, +) +from nvalchemi.hooks.reporting._state import ReportingState +from nvalchemi.hooks.reporting.layouts import ( + DynamicsRichLayout, + RichLayout, + TrainingRichLayout, + resolve_rich_layout, +) + +_PREVIEW_DEFAULT = object() + + +class RichReporter: + """Render scalar reporting snapshots as a live Rich dashboard. + + Parameters + ---------- + custom_scalars : Mapping[str, ScalarCallback] | None, optional + Additional scalar callbacks passed to :func:`collect_scalars`. + include_losses : bool, default True + When ``True``, include loss scalars from the hook context. + include_optimizer_lrs : bool, default True + When ``True``, include optimizer learning rates from the hook context. + include_dynamics_scalars : bool | None, optional + When ``True``, include default dynamics observables from the hook + context. ``None`` lets the selected layout choose; the built-in + dynamics layout enables them. + rank_reduction : torch.distributed.ReduceOp | {"none", "mean", "sum", "min", "max"} | None, default None + Optional distributed reduction applied to scalars before rendering. + String values are normalized to :class:`torch.distributed.ReduceOp`. + Reduction requires every rank to call this reporter; only rank zero + renders the reduced dashboard. + title : str, default "nvalchemi report" + Dashboard title. 
+ precision : int, default 6 + Significant digits used when formatting scalar values. + max_scalars : int | None, optional + Maximum number of scalar rows to show. When omitted, all scalars are + shown. + history_size : int, default 200 + Maximum history points retained per scalar. + layout : RichLayout | {"training", "dynamics"} | None, optional + Dashboard layout policy. ``None`` and ``"auto"`` select the first + built-in layout that supports the first reported context. + plot_keys : Sequence[str] | None, optional + Scalar keys to plot. When omitted, the selected layout chooses common + metrics for that workflow before falling back to alphabetical order. + max_plots : int, default 3 + Maximum number of history plots shown in the dashboard. + plot_height : int, default 8 + Height in terminal rows for each plotext plot. + refresh_per_second : float, default 2.0 + Rich ``Live`` refresh rate used while the reporter is entered. + console : Console | None, optional + Rich console used for output. When omitted, a stderr console is created. + screen : bool, default False + Whether Rich ``Live`` should use the terminal alternate screen. + transient : bool, default False + Whether Rich ``Live`` should clear the dashboard on exit. + rank_zero_only : bool, default True + Request rank-zero-only dispatch from :class:`ReportingOrchestrator`. + strict_layout : bool, default False + When ``True``, raise if automatic layout selection cannot match the + incoming context. When ``False``, unmatched contexts are ignored. 
+ """ + + def __init__( + self, + *, + custom_scalars: Mapping[str, ScalarCallback] | None = None, + include_losses: bool = True, + include_optimizer_lrs: bool = True, + include_dynamics_scalars: bool | None = None, + rank_reduction: dist.ReduceOp | str | None = None, + title: str = "nvalchemi report", + precision: int = 6, + max_scalars: int | None = None, + history_size: int = 200, + layout: RichLayout | str | None = None, + plot_keys: Sequence[str] | None = None, + max_plots: int = 3, + plot_height: int = 8, + refresh_per_second: float = 2.0, + console: Console | None = None, + screen: bool = False, + transient: bool = False, + rank_zero_only: bool = True, + strict_layout: bool = False, + ) -> None: + if precision < 0: + raise ValueError("RichReporter precision must be non-negative.") + if max_scalars is not None and max_scalars < 1: + raise ValueError("RichReporter max_scalars must be positive.") + if history_size < 1: + raise ValueError("RichReporter history_size must be positive.") + if max_plots < 0: + raise ValueError("RichReporter max_plots must be non-negative.") + if plot_height < 4: + raise ValueError("RichReporter plot_height must be at least 4.") + if refresh_per_second <= 0: + raise ValueError("RichReporter refresh_per_second must be positive.") + self.custom_scalars = custom_scalars + self.include_losses = include_losses + self.include_optimizer_lrs = include_optimizer_lrs + self.rank_reduction = rank_reduction + self._rank_reduction_op, _ = normalize_rank_reduction(rank_reduction) + self.title = title + self.precision = precision + self.max_scalars = max_scalars + self.history_size = history_size + self._auto_layout = layout is None or layout == "auto" + self._layout_selected = not self._auto_layout + self.layout = ( + TrainingRichLayout() if self._auto_layout else resolve_rich_layout(layout) + ) + self._include_dynamics_scalars_override = include_dynamics_scalars + self.include_dynamics_scalars = ( + bool(getattr(self.layout, 
"include_dynamics_scalars", False)) + if include_dynamics_scalars is None + else include_dynamics_scalars + ) + self.plot_keys = tuple(plot_keys) if plot_keys is not None else None + self.max_plots = max_plots + self.plot_height = plot_height + self.refresh_per_second = refresh_per_second + self.console = console if console is not None else Console(stderr=True) + self.screen = screen + self.transient = transient + self.strict_layout = strict_layout + self._write_rank_zero_only = ( + rank_zero_only or self._rank_reduction_op is not None + ) + self.rank_zero_only = rank_zero_only and self._rank_reduction_op is None + self.requires_all_ranks = self._rank_reduction_op is not None + self._history: dict[str, deque[tuple[int, float]]] = {} + self._latest_snapshot: ScalarSnapshot | None = None + self._live: Live | None = None + self._entered = False + + @classmethod + def preview( + cls, + *, + history: Mapping[str, Sequence[float]] | None = None, + layout: RichLayout | str | None = None, + steps: Sequence[int] | None = None, + console: Console | None = None, + stage: str | None = None, + step_count: int | None = None, + epoch: int | None | object = _PREVIEW_DEFAULT, + batch_count: int | None | object = _PREVIEW_DEFAULT, + **reporter_kwargs: object, + ) -> None: + """Render a synthetic dashboard preview. + + Parameters + ---------- + history : Mapping[str, Sequence[float]] | None, optional + Metric history used to populate plots and latest values. Defaults to + representative curves from the selected layout. + layout : RichLayout | {"training", "dynamics"} | None, optional + Dashboard layout policy. ``None`` selects the training layout. + steps : Sequence[int] | None, optional + Step values aligned with each history sequence. Defaults to + ``range(len(series))``. + console : Console | None, optional + Rich console used for preview output. + stage : str | None, optional + Stage label shown in the dashboard header. 
When omitted, the + selected layout supplies a workflow-appropriate default. + step_count : int | None, optional + Step shown in the dashboard header. Defaults to the final step. + epoch : int | None, optional + Epoch shown in dashboard metadata. When omitted, the selected + layout supplies a workflow-appropriate default. + batch_count : int | None, optional + Batch count shown in dashboard metadata. When omitted, the + selected layout supplies a workflow-appropriate default. + **reporter_kwargs : object + Additional keyword arguments forwarded to :class:`RichReporter`. + """ + reporter = cls( + console=console, + layout=layout, + rank_zero_only=False, + **reporter_kwargs, + ) + if reporter._auto_layout: + reporter._set_layout(TrainingRichLayout()) + reporter.seed_history( + reporter.layout.default_preview_history() if history is None else history, + steps=steps, + stage=stage + if stage is not None + else reporter.layout.default_preview_stage(), + step_count=step_count, + epoch=reporter.layout.default_preview_epoch() + if epoch is _PREVIEW_DEFAULT + else epoch, + batch_count=reporter.layout.default_preview_batch_count() + if batch_count is _PREVIEW_DEFAULT + else batch_count, + ) + reporter.console.print(reporter.renderable()) + + @property + def history(self) -> dict[str, tuple[tuple[int, float], ...]]: + """Return retained scalar history. + + Returns + ------- + dict[str, tuple[tuple[int, float], ...]] + Mapping from scalar key to ``(step, value)`` history tuples. 
+ """ + return {key: tuple(values) for key, values in self._history.items()} + + def __enter__(self) -> RichReporter: + """Start the live dashboard.""" + if self._entered: + return self + self._entered = True + if self._rank_reduction_op is None and not ( + self._auto_layout and not self._layout_selected + ): + self._start_live() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + """Stop the live dashboard.""" + self.close() + + def close(self) -> None: + """Stop the live dashboard if it is active.""" + if self._live is None: + self._entered = False + return + self._live.stop() + self._live = None + self._entered = False + + def report(self, ctx: HookContext, stage: Enum, state: ReportingState) -> None: + """Update the dashboard from one scalar snapshot. + + Parameters + ---------- + ctx : HookContext + Workflow hook context. + stage : Enum + Hook stage being reported. + state : ReportingState + Shared reporting state from the orchestrator. 
+ """ + if not self._ensure_layout(ctx, stage): + return + snapshot = collect_scalars( + ctx, + stage, + state, + custom_scalars=self.custom_scalars, + include_losses=self.include_losses, + include_optimizer_lrs=self.include_optimizer_lrs, + include_dynamics=self.include_dynamics_scalars, + include_progress=True, + ) + if self._rank_reduction_op is not None: + snapshot = reduce_scalar_snapshot( + snapshot, + self.rank_reduction, + reporter_name=type(self).__name__, + ) + if not self._is_rank_zero(ctx): + return + elif self._write_rank_zero_only and not self._is_rank_zero(ctx): + return + self._record_snapshot(snapshot) + renderable = self.renderable() + if self._live is not None: + self._live.update(renderable, refresh=False) + elif self._entered: + self._start_live(renderable) + else: + self.console.print(renderable) + + def seed_history( + self, + history: Mapping[str, Sequence[float]], + *, + steps: Sequence[int] | None = None, + stage: str = "AFTER_OPTIMIZER_STEP", + step_count: int | None = None, + epoch: int | None = None, + batch_count: int | None = None, + global_rank: int = 0, + ) -> ScalarSnapshot: + """Seed dashboard history without running a workflow. + + Parameters + ---------- + history : Mapping[str, Sequence[float]] + Metric history used to populate plots and latest scalar values. + steps : Sequence[int] | None, optional + Step values aligned with each metric series. + stage : str, default "AFTER_OPTIMIZER_STEP" + Stage label for the synthetic snapshot. + step_count : int | None, optional + Step count for the synthetic snapshot. Defaults to the final step. + epoch : int | None, optional + Epoch metadata for the synthetic snapshot. + batch_count : int | None, optional + Batch metadata for the synthetic snapshot. + global_rank : int, default 0 + Rank metadata for the synthetic snapshot. + + Returns + ------- + ScalarSnapshot + Synthetic latest snapshot produced from ``history``. 
+ """ + if not history: + raise ValueError("RichReporter preview history cannot be empty.") + first_values = next(iter(history.values())) + if not first_values: + raise ValueError( + "RichReporter preview history cannot contain empty series." + ) + if steps is None: + resolved_steps = tuple(range(len(first_values))) + else: + resolved_steps = tuple(steps) + if len(resolved_steps) != len(first_values): + raise ValueError("RichReporter preview steps must match series length.") + self._history = {} + latest_scalars: dict[str, float] = {} + for key, values in history.items(): + if len(values) != len(resolved_steps): + raise ValueError("RichReporter preview series lengths must match.") + numeric_values = tuple(float(value) for value in values) + self._history[key] = deque( + zip(resolved_steps, numeric_values, strict=True), + maxlen=self.history_size, + ) + latest_scalars[key] = numeric_values[-1] + resolved_step_count = ( + step_count if step_count is not None else resolved_steps[-1] + ) + snapshot = ScalarSnapshot( + stage=stage, + scalars=latest_scalars, + step_count=resolved_step_count, + batch_count=batch_count, + epoch=epoch, + global_rank=global_rank, + ) + self._latest_snapshot = snapshot + return snapshot + + def renderable(self) -> Layout: + """Build the current dashboard renderable. + + Returns + ------- + Layout + Rich layout containing the header, latest scalar table, and plots. 
+ """ + return self.layout.render( + self._latest_snapshot, + self.history, + title=self.title, + precision=self.precision, + max_scalars=self.max_scalars, + plot_keys=self.plot_keys, + max_plots=self.max_plots, + plot_height=self.plot_height, + ) + + def _ensure_layout(self, ctx: HookContext, stage: Enum) -> bool: + if not self._auto_layout: + return True + if self._layout_selected: + return True + if isinstance(ctx, DynamicsContext) or stage.name == "AFTER_STEP": + self._set_layout(DynamicsRichLayout()) + return True + if isinstance(ctx, TrainContext) or _looks_like_training_context(ctx, stage): + self._set_layout(TrainingRichLayout()) + return True + if self.strict_layout: + raise ValueError( + "RichReporter could not select a layout for " + f"context {type(ctx).__name__} at stage {stage.name!r}." + ) + return False + + def _set_layout(self, layout: RichLayout) -> None: + self.layout = layout + self._layout_selected = True + if self._include_dynamics_scalars_override is None: + self.include_dynamics_scalars = bool( + getattr(self.layout, "include_dynamics_scalars", False) + ) + + def _record_snapshot(self, snapshot: ScalarSnapshot) -> None: + self._latest_snapshot = snapshot + step = self._history_step(snapshot) + for key, value in snapshot.scalars.items(): + if key not in self._history: + self._history[key] = deque(maxlen=self.history_size) + self._history[key].append((step, value)) + + def _history_step(self, snapshot: ScalarSnapshot) -> int: + if snapshot.step_count is not None: + return snapshot.step_count + if snapshot.event_count is not None: + return snapshot.event_count + lengths = [len(values) for values in self._history.values()] + return max(lengths, default=0) + + def _is_rank_zero(self, ctx: HookContext) -> bool: + return ctx.global_rank == 0 + + def _start_live(self, renderable: Layout | None = None) -> None: + if self._live is not None: + return + self._live = Live( + renderable if renderable is not None else self.renderable(), + 
console=self.console, + refresh_per_second=self.refresh_per_second, + screen=self.screen, + transient=self.transient, + ) + self._live.start() + + +def _looks_like_training_context(ctx: HookContext, stage: Enum) -> bool: + if stage.name == "AFTER_OPTIMIZER_STEP": + return True + return any( + hasattr(ctx, name) + for name in ( + "loss", + "losses", + "optimizers", + "lr_schedulers", + "batch_count", + "epoch_step_count", + ) + ) diff --git a/nvalchemi/hooks/reporting/_scalars.py b/nvalchemi/hooks/reporting/_scalars.py new file mode 100644 index 00000000..378cac62 --- /dev/null +++ b/nvalchemi/hooks/reporting/_scalars.py @@ -0,0 +1,726 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Scalar extraction helpers for reporting sinks.""" + +from __future__ import annotations + +import time +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass, field +from enum import Enum +from numbers import Real +from typing import TypeAlias + +import torch + +from nvalchemi.hooks._context import HookContext +from nvalchemi.hooks.reporting._state import ReporterMessage, ReportingState + +ScalarCallback: TypeAlias = Callable[[HookContext, Enum], object] + +_COMPONENT_SCALAR_SPECS = ( + ("per_component_unweighted", "unweighted"), + ("per_component_weight", "weight"), + ("per_component_raw_weight", "raw_weight"), +) + + +@dataclass(frozen=True, kw_only=True) +class ScalarSnapshot: + """Scalar reporting payload for one hook event. + + Attributes + ---------- + stage : str + Hook stage name associated with the snapshot. + scalars : dict[str, float] + Flat scalar mapping using slash-separated semantic keys. + timestamp_s : float + Wall-clock timestamp from :func:`time.time`. + elapsed_s : float | None + Seconds since the reporting state was created, when available. + event_count : int | None + Reporting orchestrator event count, when available. + step_count : int | None + Workflow step count from the hook context, when available. + batch_count : int | None + Training batch count from the hook context, when available. + epoch_step_count : int | None + Training epoch-local batch count from the hook context, when available. + epoch : int | None + Training epoch from the hook context, when available. + global_rank : int + Distributed rank from the hook context. + messages : tuple[ReporterMessage, ...] + Recent reporting messages captured from the shared reporting state. 
+ """ + + stage: str + scalars: dict[str, float] + timestamp_s: float = field(default_factory=time.time) + elapsed_s: float | None = None + event_count: int | None = None + step_count: int | None = None + batch_count: int | None = None + epoch_step_count: int | None = None + epoch: int | None = None + global_rank: int = 0 + messages: tuple[ReporterMessage, ...] = () + + def as_dict(self) -> dict[str, object]: + """Return a JSON-ready dictionary representation. + + Returns + ------- + dict[str, object] + Snapshot metadata plus the scalar mapping. + """ + return { + "stage": self.stage, + "timestamp_s": self.timestamp_s, + "elapsed_s": self.elapsed_s, + "event_count": self.event_count, + "step_count": self.step_count, + "batch_count": self.batch_count, + "epoch_step_count": self.epoch_step_count, + "epoch": self.epoch, + "global_rank": self.global_rank, + "messages": [ + { + "level": message.level, + "message": message.message, + "reporter": message.reporter, + "stage": message.stage, + "step_count": message.step_count, + "global_rank": message.global_rank, + "timestamp_s": message.timestamp_s, + } + for message in self.messages + ], + "scalars": dict(self.scalars), + } + + +def collect_scalars( + ctx: HookContext, + stage: Enum, + state: ReportingState | None = None, + *, + custom_scalars: Mapping[str, ScalarCallback] | None = None, + include_losses: bool = True, + include_optimizer_lrs: bool = True, + include_dynamics: bool = False, + include_progress: bool = False, +) -> ScalarSnapshot: + """Collect scalar values from a hook context. + + Parameters + ---------- + ctx : HookContext + Workflow hook context. + stage : Enum + Hook stage being reported. + state : ReportingState | None, optional + Shared reporting state used for event count and elapsed time metadata. + custom_scalars : Mapping[str, ScalarCallback] | None, optional + Additional scalar callbacks. 
Each callback receives ``(ctx, stage)`` and + may return either a scalar value or a nested mapping of scalar values. + include_losses : bool, default True + When ``True``, extract ``ctx.loss`` and ``ctx.losses`` values. + include_optimizer_lrs : bool, default True + When ``True``, extract learning rates from optimizer parameter groups. + include_dynamics : bool, default False + When ``True``, extract default dynamics observables from the current + batch and dynamics context. + include_progress : bool, default False + When ``True``, extract workflow progress, throughput, and ETA scalars + from context counters and workflow metadata when available. + + Returns + ------- + ScalarSnapshot + Snapshot containing flat scalar keys and hook metadata. + + Raises + ------ + TypeError + If a custom scalar or loss value has an unsupported type. + ValueError + If a value expected to be scalar contains multiple elements. + """ + scalars: dict[str, float] = {} + if include_losses: + scalars.update(extract_loss_scalars(ctx)) + if include_optimizer_lrs: + scalars.update(extract_optimizer_lr_scalars(ctx)) + scalars.update(extract_scheduler_lr_scalars(ctx)) + if include_dynamics: + scalars.update(extract_dynamics_scalars(ctx)) + if include_progress: + scalars.update(extract_progress_scalars(ctx, state)) + if custom_scalars is not None: + for name, callback in custom_scalars.items(): + value = callback(ctx, stage) + if value is None: + continue + if isinstance(value, Mapping): + scalars.update(extract_scalars(value, prefix=name)) + else: + scalars[name] = _to_float(value, name) + + return ScalarSnapshot( + stage=stage.name, + scalars=scalars, + elapsed_s=time.monotonic() - state.started_at_s if state is not None else None, + event_count=state.event_count if state is not None else None, + step_count=getattr(ctx, "step_count", None), + batch_count=getattr(ctx, "batch_count", None), + epoch_step_count=getattr(ctx, "epoch_step_count", None), + epoch=getattr(ctx, "epoch", None), + 
def extract_loss_scalars(ctx: HookContext) -> dict[str, float]:
    """Collect scalar loss values exposed by a hook context.

    Parameters
    ----------
    ctx : HookContext
        Workflow hook context. Training contexts may expose ``loss`` and
        ``losses`` attributes; both are optional.

    Returns
    -------
    dict[str, float]
        Flat loss scalar mapping. Composed-loss outputs use keys such as
        ``loss/energy/unweighted`` and ``loss/energy/weight``.

    Raises
    ------
    TypeError
        If ``ctx.losses`` is not a mapping or a loss value has an
        unsupported type.
    ValueError
        If a value expected to be scalar contains multiple elements.
    """
    collected: dict[str, float] = {}

    total = getattr(ctx, "loss", None)
    if total is not None:
        collected["loss/total"] = _to_float(total, "loss/total")

    loss_map = getattr(ctx, "losses", None)
    if loss_map is None:
        return collected
    if not isinstance(loss_map, Mapping):
        raise TypeError(f"ctx.losses must be a mapping, got {type(loss_map).__name__}.")

    # A "total_loss" entry takes precedence over ``ctx.loss`` set above.
    if "total_loss" in loss_map:
        collected["loss/total"] = _to_float(loss_map["total_loss"], "loss/total")
    for source_key, target_suffix in _COMPONENT_SCALAR_SPECS:
        _extract_component_scalars(
            loss_map,
            source_key=source_key,
            target_suffix=target_suffix,
            scalars=collected,
        )
    _extract_component_sample_means(loss_map, collected)

    # Entries already handled above are skipped so they are not reported twice.
    reserved = _composed_loss_keys()
    for name, value in loss_map.items():
        if name in reserved or value is None:
            continue
        target = f"loss/{name}"
        if isinstance(value, Mapping):
            collected.update(extract_scalars(value, prefix=target))
        else:
            collected[target] = _to_float(value, target)
    return collected
def extract_optimizer_lr_scalars(ctx: HookContext) -> dict[str, float]:
    """Read learning rates from optimizer parameter groups.

    Parameters
    ----------
    ctx : HookContext
        Workflow hook context. Training contexts may expose an ``optimizers``
        sequence; a missing or empty value yields an empty result.

    Returns
    -------
    dict[str, float]
        Flat optimizer learning-rate mapping. Parameter groups without an
        ``"lr"`` entry are ignored.
    """
    optimizers = getattr(ctx, "optimizers", None)
    if not optimizers:
        return {}
    found: dict[str, float] = {}
    for optimizer_idx, optimizer in enumerate(optimizers):
        groups = getattr(optimizer, "param_groups", ())
        for group_idx, group in enumerate(groups):
            if "lr" in group:
                key = _optimizer_lr_key(
                    optimizer_count=len(optimizers),
                    optimizer_idx=optimizer_idx,
                    group_count=len(groups),
                    group_idx=group_idx,
                )
                found[key] = _to_float(group["lr"], key)
    return found


def extract_scheduler_lr_scalars(ctx: HookContext) -> dict[str, float]:
    """Read the last learning rates reported by schedulers.

    Parameters
    ----------
    ctx : HookContext
        Workflow hook context. Training contexts may expose an
        ``lr_schedulers`` sequence; a missing or empty value yields an empty
        result.

    Returns
    -------
    dict[str, float]
        Flat scheduler learning-rate mapping. Slots that are ``None``, lack a
        callable ``get_last_lr``, or return a non-sequence are skipped.
    """
    schedulers = getattr(ctx, "lr_schedulers", None)
    if not schedulers:
        return {}
    slots = list(schedulers)
    optimizer_total = len(getattr(ctx, "optimizers", None) or [])
    # Key naming uses the larger of the two counts.
    slot_total = max(len(slots), optimizer_total)
    found: dict[str, float] = {}
    for scheduler_idx, scheduler in enumerate(slots):
        if scheduler is None:
            continue
        reader = getattr(scheduler, "get_last_lr", None)
        if not callable(reader):
            continue
        last_lrs = reader()
        if not isinstance(last_lrs, Sequence):
            continue
        for group_idx, lr in enumerate(last_lrs):
            key = _scheduler_lr_key(
                scheduler_count=slot_total,
                scheduler_idx=scheduler_idx,
                group_count=len(last_lrs),
                group_idx=group_idx,
            )
            found[key] = _to_float(lr, key)
    return found
+ + Returns + ------- + dict[str, float] + Flat scalar mapping containing available progress metrics. + """ + scalars: dict[str, float] = {} + workflow = getattr(ctx, "workflow", None) + elapsed_s = ( + time.monotonic() - state.started_at_s + if state is not None and state.started_at_s is not None + else None + ) + step_count = _nonnegative_int(getattr(ctx, "step_count", None)) + batch_count = _nonnegative_int(getattr(ctx, "batch_count", None)) + + if _is_training_context(ctx): + _add_rate_scalar(scalars, "training/steps_per_s", step_count, elapsed_s) + _add_rate_scalar(scalars, "training/batches_per_s", batch_count, elapsed_s) + target_steps = _positive_int_attr(workflow, "num_steps") + if target_steps is not None: + _add_target_progress( + scalars, + prefix="training", + completed=step_count, + target=target_steps, + elapsed_s=elapsed_s, + ) + num_epochs = _positive_int_attr(workflow, "num_epochs") + epoch = _nonnegative_int(getattr(ctx, "epoch", None)) + if num_epochs is not None and epoch is not None: + scalars["training/target_epochs"] = float(num_epochs) + scalars["training/epoch_fraction"] = min(epoch / num_epochs, 1.0) + return scalars + + if _is_dynamics_context(ctx): + _add_rate_scalar(scalars, "dynamics/steps_per_s", step_count, elapsed_s) + target_steps = _positive_int_attr(workflow, "n_steps") + if target_steps is not None: + _add_target_progress( + scalars, + prefix="dynamics", + completed=step_count, + target=target_steps, + elapsed_s=elapsed_s, + ) + return scalars + + +def extract_scalars( + values: Mapping[str, object], + *, + prefix: str | None = None, +) -> dict[str, float]: + """Extract scalar leaves from a nested mapping. + + Parameters + ---------- + values : Mapping[str, object] + Mapping whose leaves must be scalar Python numbers or scalar tensors. + prefix : str | None, optional + Optional key prefix added before every extracted scalar. + + Returns + ------- + dict[str, float] + Flat slash-separated scalar mapping. 
+ + Raises + ------ + TypeError + If a key is not a string or a leaf has an unsupported type. + ValueError + If a tensor leaf contains multiple elements. + """ + scalars: dict[str, float] = {} + root = prefix.strip("/") if prefix else "" + for name, value in values.items(): + if not isinstance(name, str): + raise TypeError(f"Scalar keys must be strings, got {type(name).__name__}.") + if value is None: + continue + key = _join_key(root, name) + if isinstance(value, Mapping): + scalars.update(extract_scalars(value, prefix=key)) + else: + scalars[key] = _to_float(value, key) + return scalars + + +def _extract_component_scalars( + losses: Mapping[str, object], + *, + source_key: str, + target_suffix: str, + scalars: dict[str, float], +) -> None: + values = losses.get(source_key) + if values is None: + return + if not isinstance(values, Mapping): + raise TypeError(f"losses[{source_key!r}] must be a mapping.") + for component, value in values.items(): + if not isinstance(component, str): + raise TypeError( + f"Loss component names must be strings, got {type(component).__name__}." + ) + key = f"loss/{component}/{target_suffix}" + scalars[key] = _to_float(value, key) + + +def _extract_component_sample_means( + losses: Mapping[str, object], + scalars: dict[str, float], +) -> None: + values = losses.get("per_component_sample") + if values is None: + return + if not isinstance(values, Mapping): + raise TypeError("losses['per_component_sample'] must be a mapping.") + for component, value in values.items(): + if not isinstance(component, str): + raise TypeError( + f"Loss component names must be strings, got {type(component).__name__}." 
+ ) + key = f"loss/{component}/sample_mean" + scalars[key] = _tensor_mean_to_float(value, key) + + +def _optimizer_lr_key( + *, + optimizer_count: int, + optimizer_idx: int, + group_count: int, + group_idx: int, +) -> str: + if optimizer_count == 1 and group_count == 1: + return "optimizer/lr" + if optimizer_count == 1: + return f"optimizer/group_{group_idx}/lr" + if group_count == 1: + return f"optimizer/{optimizer_idx}/lr" + return f"optimizer/{optimizer_idx}/group_{group_idx}/lr" + + +def _scheduler_lr_key( + *, + scheduler_count: int, + scheduler_idx: int, + group_count: int, + group_idx: int, +) -> str: + if scheduler_count == 1 and group_count == 1: + return "scheduler/lr" + if scheduler_count == 1: + return f"scheduler/group_{group_idx}/lr" + if group_count == 1: + return f"scheduler/{scheduler_idx}/lr" + return f"scheduler/{scheduler_idx}/group_{group_idx}/lr" + + +def _join_key(prefix: str, name: str) -> str: + clean_name = name.strip("/") + return clean_name if not prefix else f"{prefix}/{clean_name}" + + +def _get_tensor_attr(obj: object, name: str) -> torch.Tensor | None: + value = getattr(obj, name, None) + return value if isinstance(value, torch.Tensor) else None + + +def _temperature_scalar(batch: object) -> float | None: + velocities = _get_tensor_attr(batch, "velocities") + atomic_masses = _get_tensor_attr(batch, "atomic_masses") + batch_idx = _get_tensor_attr(batch, "batch_idx") + num_nodes_per_graph = _get_tensor_attr(batch, "num_nodes_per_graph") + num_graphs = getattr(batch, "num_graphs", None) + if ( + velocities is None + or atomic_masses is None + or batch_idx is None + or num_nodes_per_graph is None + or not isinstance(num_graphs, int) + ): + return None + from nvalchemi.dynamics.hooks._utils import temperature_per_graph # noqa: PLC0415 + + temperature = temperature_per_graph( + velocities, + atomic_masses, + batch_idx, + num_graphs, + num_nodes_per_graph, + ) + return _tensor_mean_to_float(temperature, "temperature") + + +def 
_active_fraction_scalar(ctx: HookContext) -> float | None: + status = _get_tensor_attr(ctx.batch, "status") + exit_status = getattr(getattr(ctx, "workflow", None), "exit_status", None) + num_graphs = getattr(ctx.batch, "num_graphs", None) + if ( + status is None + or not isinstance(exit_status, int) + or not isinstance(num_graphs, int) + ): + return None + status = status.squeeze(-1) if status.dim() == 2 else status + active_mask = status[:num_graphs] < exit_status + return _tensor_mean_to_float(active_mask.float(), "active_fraction") + + +def _status_count_scalars(ctx: HookContext) -> dict[str, float]: + status = _get_tensor_attr(ctx.batch, "status") + num_graphs = getattr(ctx.batch, "num_graphs", None) + if status is None or not isinstance(num_graphs, int): + return {} + status = status.squeeze(-1) if status.dim() == 2 else status + status = status[:num_graphs].detach().to(device="cpu", dtype=torch.long) + scalars: dict[str, float] = {"dynamics/num_graphs": float(num_graphs)} + if status.numel() == 0: + return scalars + values, counts = torch.unique(status, return_counts=True) + for value, count in zip(values.tolist(), counts.tolist(), strict=True): + scalars[f"dynamics/status/{value}/count"] = float(count) + exit_status = getattr(getattr(ctx, "workflow", None), "exit_status", None) + if isinstance(exit_status, int) and num_graphs > 0: + active_count = int((status < exit_status).sum().item()) + graduated_count = int((status >= exit_status).sum().item()) + scalars["dynamics/active_count"] = float(active_count) + scalars["dynamics/graduated_count"] = float(graduated_count) + scalars["dynamics/graduated_fraction"] = graduated_count / num_graphs + return scalars + + +def _is_training_context(ctx: HookContext) -> bool: + return any( + hasattr(ctx, name) + for name in ("batch_count", "epoch_step_count", "epoch", "optimizers") + ) + + +def _is_dynamics_context(ctx: HookContext) -> bool: + return hasattr(ctx, "converged_mask") or hasattr( + getattr(ctx, "workflow", 
None), "n_steps" + ) + + +def _nonnegative_int(value: object) -> int | None: + if isinstance(value, bool) or not isinstance(value, int) or value < 0: + return None + return value + + +def _positive_int_attr(obj: object, name: str) -> int | None: + value = getattr(obj, name, None) + if isinstance(value, bool) or not isinstance(value, int) or value <= 0: + return None + return value + + +def _add_rate_scalar( + scalars: dict[str, float], + key: str, + count: int | None, + elapsed_s: float | None, +) -> None: + if count is None or count <= 0 or elapsed_s is None or elapsed_s <= 0: + return + scalars[key] = count / elapsed_s + + +def _add_target_progress( + scalars: dict[str, float], + *, + prefix: str, + completed: int | None, + target: int, + elapsed_s: float | None, +) -> None: + if completed is None: + return + remaining = max(target - completed, 0) + scalars[f"{prefix}/target_steps"] = float(target) + scalars[f"{prefix}/remaining_steps"] = float(remaining) + scalars[f"{prefix}/progress_fraction"] = min(completed / target, 1.0) + if elapsed_s is None or elapsed_s <= 0 or completed <= 0: + return + scalars[f"{prefix}/eta_s"] = remaining / (completed / elapsed_s) + + +def _composed_loss_keys() -> frozenset[str]: + reporting_keys = frozenset( + ("total_loss", "per_component_sample") + + tuple(source_key for source_key, _ in _COMPONENT_SCALAR_SPECS) + ) + try: + from nvalchemi.training.losses.composition import ComposedLossOutput + except ImportError: + return reporting_keys + return reporting_keys | frozenset(ComposedLossOutput.__annotations__) + + +def _to_float(value: object, name: str) -> float: + if isinstance(value, bool): + return float(value) + if isinstance(value, Real): + return float(value) + if isinstance(value, torch.Tensor): + if value.numel() != 1: + raise ValueError( + f"{name!r} must be scalar, got tensor with shape {tuple(value.shape)}." 
+ ) + return float(value.detach().reshape(-1)[0].item()) + raise TypeError( + f"{name!r} must be a scalar number or scalar tensor, " + f"got {type(value).__name__}." + ) + + +def _tensor_mean_to_float(value: object, name: str) -> float: + if not isinstance(value, torch.Tensor): + raise TypeError(f"{name!r} must be a tensor, got {type(value).__name__}.") + if value.numel() == 0: + raise ValueError(f"{name!r} cannot reduce an empty tensor.") + tensor = value.detach() + if not torch.is_floating_point(tensor) and not torch.is_complex(tensor): + tensor = tensor.float() + return float(tensor.mean().item()) diff --git a/nvalchemi/hooks/reporting/_state.py b/nvalchemi/hooks/reporting/_state.py new file mode 100644 index 00000000..c4e46e3c --- /dev/null +++ b/nvalchemi/hooks/reporting/_state.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared reporting runtime state.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + from nvalchemi.hooks._context import HookContext + +MessageLevel = Literal["info", "warning", "error"] + + +@dataclass(frozen=True, kw_only=True) +class ReporterMessage: + """Message emitted by reporting infrastructure. 
+ + Attributes + ---------- + level : {"info", "warning", "error"} + Severity level for the message. + message : str + Human-readable message. + reporter : str | None + Reporter class name associated with the message, when available. + stage : str | None + Hook stage name associated with the message, when available. + step_count : int | None + Workflow step count associated with the message, when available. + global_rank : int | None + Distributed rank associated with the message, when available. + timestamp_s : float + Wall-clock timestamp from :func:`time.time`. + """ + + level: MessageLevel + message: str + reporter: str | None = None + stage: str | None = None + step_count: int | None = None + global_rank: int | None = None + timestamp_s: float = field(default_factory=time.time) + + +@dataclass(kw_only=True) +class ReportingState: + """Mutable state shared by a reporting orchestrator and its reporters. + + The state object intentionally stores only orchestration metadata: + counters, timestamps, recent messages, and an extensible metadata mapping. + Workflow values such as losses, schedulers, or dynamics counters should be + read from the hook context rather than duplicated here. + + Attributes + ---------- + max_messages : int + Maximum number of recent messages retained. + started_at_s : float + Monotonic time when the state was created. + event_count : int + Number of reporting events dispatched by the orchestrator. + last_event_at_s : float | None + Monotonic time of the latest reporting event. + last_stage : str | None + Name of the latest reported hook stage. + last_step_count : int | None + Step count from the latest reported context, when available. + last_global_rank : int | None + Rank from the latest reported context, when available. + messages : list[ReporterMessage] + Bounded list of recent reporting messages. + metadata : dict[str, Any] + Scratch space for reporters that need shared per-run state. 
+ """ + + max_messages: int = 100 + started_at_s: float = field(default_factory=time.monotonic) + event_count: int = 0 + last_event_at_s: float | None = None + last_stage: str | None = None + last_step_count: int | None = None + last_global_rank: int | None = None + messages: list[ReporterMessage] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + def mark_event(self, ctx: HookContext, stage: Enum) -> None: + """Record that a reporting event was dispatched. + + Parameters + ---------- + ctx : HookContext + Workflow context passed to the reporting orchestrator. + stage : Enum + Hook stage being reported. + """ + self.event_count += 1 + self.last_event_at_s = time.monotonic() + self.last_stage = stage.name + self.last_step_count = getattr(ctx, "step_count", None) + self.last_global_rank = ctx.global_rank + + def add_message( + self, + level: MessageLevel, + message: str, + *, + reporter: object | None = None, + ctx: HookContext | None = None, + stage: Enum | None = None, + ) -> ReporterMessage: + """Append a bounded recent message. + + Parameters + ---------- + level : {"info", "warning", "error"} + Message severity. + message : str + Human-readable message. + reporter : object | None, optional + Reporter associated with the message. + ctx : HookContext | None, optional + Context associated with the message. + stage : Enum | None, optional + Hook stage associated with the message. + + Returns + ------- + ReporterMessage + The message object appended to :attr:`messages`. 
+ """ + entry = ReporterMessage( + level=level, + message=message, + reporter=type(reporter).__name__ if reporter is not None else None, + stage=stage.name if stage is not None else None, + step_count=getattr(ctx, "step_count", None) if ctx is not None else None, + global_rank=ctx.global_rank if ctx is not None else None, + ) + self.messages.append(entry) + if len(self.messages) > self.max_messages: + del self.messages[: len(self.messages) - self.max_messages] + return entry diff --git a/nvalchemi/hooks/reporting/_tensorboard.py b/nvalchemi/hooks/reporting/_tensorboard.py new file mode 100644 index 00000000..04737ac6 --- /dev/null +++ b/nvalchemi/hooks/reporting/_tensorboard.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TensorBoardWriter(Protocol):
    """Structural type for the ``SummaryWriter`` methods the reporter calls.

    Any object providing ``add_scalar``, ``flush``, and ``close`` with
    compatible signatures satisfies this protocol.
    """

    def add_scalar(
        self,
        tag: str,
        scalar_value: float,
        global_step: int | None = None,
    ) -> None:
        """Record a single scalar event.

        Parameters
        ----------
        tag : str
            TensorBoard scalar tag.
        scalar_value : float
            Scalar value to write.
        global_step : int | None, optional
            Step associated with the scalar.
        """
        ...

    def flush(self) -> None:
        """Flush any pending TensorBoard events."""
        ...

    def close(self) -> None:
        """Close the writer."""
        ...
String + values are normalized to :class:`torch.distributed.ReduceOp`. Reduction + requires every rank to call this reporter; only rank zero writes the + reduced snapshot. + tag_prefix : str | None, optional + Optional prefix prepended to every TensorBoard tag. + flush : bool, default True + Flush the writer after every report event. + rank_zero_only : bool, default True + Request rank-zero-only dispatch from :class:`ReportingOrchestrator`. + When ``False`` and ``rank_reduction="none"``, ``log_dir`` must contain + ``"{rank}"`` or ``"{global_rank}"`` so every rank writes its own event + directory. + writer : TensorBoardWriter | None, optional + Preconstructed writer. This is mainly useful for tests or integrations + that own writer construction. + """ + + def __init__( + self, + log_dir: str | Path, + *, + custom_scalars: Mapping[str, ScalarCallback] | None = None, + include_losses: bool = True, + include_optimizer_lrs: bool = True, + rank_reduction: dist.ReduceOp | str | None = None, + tag_prefix: str | None = None, + flush: bool = True, + rank_zero_only: bool = True, + writer: TensorBoardWriter | None = None, + ) -> None: + self.rank_reduction = rank_reduction + self._rank_reduction_op, _ = normalize_rank_reduction(rank_reduction) + self.log_dir = Path(log_dir) + self.custom_scalars = custom_scalars + self.include_losses = include_losses + self.include_optimizer_lrs = include_optimizer_lrs + self.tag_prefix = tag_prefix.strip("/") if tag_prefix is not None else None + self.flush = flush + self._write_rank_zero_only = ( + rank_zero_only or self._rank_reduction_op is not None + ) + self.rank_zero_only = rank_zero_only and self._rank_reduction_op is None + self.requires_all_ranks = self._rank_reduction_op is not None + self._writer = writer + self._external_writer = writer is not None + self._open_log_dir: Path | None = None + if not self._write_rank_zero_only and not self._has_rank_token: + raise ValueError( + "TensorBoardReporter log_dir must contain '{rank}' or " 
+ "'{global_rank}' when rank_zero_only=False and " + "rank_reduction='none'." + ) + + def __enter__(self) -> TensorBoardReporter: + """Return this reporter; writers are opened lazily on first write.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + """Close the TensorBoard writer.""" + self.close() + + def close(self) -> None: + """Close the writer if it is open.""" + if self._writer is None: + return + self._writer.close() + self._writer = None + self._open_log_dir = None + + def report(self, ctx: HookContext, stage: Enum, state: ReportingState) -> None: + """Write one scalar snapshot to TensorBoard. + + Parameters + ---------- + ctx : HookContext + Workflow hook context. + stage : Enum + Hook stage being reported. + state : ReportingState + Shared reporting state from the orchestrator. + """ + snapshot = collect_scalars( + ctx, + stage, + state, + custom_scalars=self.custom_scalars, + include_losses=self.include_losses, + include_optimizer_lrs=self.include_optimizer_lrs, + ) + if self._rank_reduction_op is not None: + snapshot = reduce_scalar_snapshot( + snapshot, + self.rank_reduction, + reporter_name=type(self).__name__, + ) + if not self._is_rank_zero(ctx): + return + elif self._write_rank_zero_only and not self._is_rank_zero(ctx): + return + + writer = self._open(self._resolve_log_dir(ctx.global_rank)) + step = snapshot.step_count if snapshot.step_count is not None else None + if step is None: + step = snapshot.event_count + for key, value in sorted(snapshot.scalars.items()): + writer.add_scalar(self._tag(key), value, global_step=step) + if self.flush: + writer.flush() + + @property + def _has_rank_token(self) -> bool: + path = str(self.log_dir) + return "{rank}" in path or "{global_rank}" in path + + def _open(self, log_dir: Path) -> TensorBoardWriter: + if self._writer is not None and self._external_writer: + return self._writer + if self._writer is not 
None and self._open_log_dir == log_dir: + return self._writer + if self._writer is not None: + self.close() + from torch.utils.tensorboard import SummaryWriter + + self._writer = SummaryWriter(log_dir=str(log_dir)) + self._open_log_dir = log_dir + return self._writer + + def _resolve_log_dir(self, global_rank: int) -> Path: + path = str(self.log_dir) + path = path.replace("{global_rank}", str(global_rank)) + path = path.replace("{rank}", str(global_rank)) + return Path(path) + + def _tag(self, key: str) -> str: + return key if self.tag_prefix is None else f"{self.tag_prefix}/{key}" + + def _is_rank_zero(self, ctx: HookContext) -> bool: + return ctx.global_rank == 0 diff --git a/nvalchemi/hooks/reporting/layouts/__init__.py b/nvalchemi/hooks/reporting/layouts/__init__.py new file mode 100644 index 00000000..ecec8d87 --- /dev/null +++ b/nvalchemi/hooks/reporting/layouts/__init__.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Rich reporting layouts.""" + +from __future__ import annotations + +from nvalchemi.hooks.reporting.layouts.base import ( + BaseRichLayout, + RichLayout, + RichLayoutName, + RichMetricHistory, + RichPreviewHistory, +) +from nvalchemi.hooks.reporting.layouts.dynamics import DynamicsRichLayout +from nvalchemi.hooks.reporting.layouts.train import TrainingRichLayout + +__all__ = [ + "BaseRichLayout", + "DynamicsRichLayout", + "RichLayout", + "RichLayoutName", + "RichMetricHistory", + "RichPreviewHistory", + "TrainingRichLayout", + "resolve_rich_layout", +] + +_REQUIRED_RICH_LAYOUT_METHODS = ( + "default_preview_history", + "default_preview_stage", + "default_preview_epoch", + "default_preview_batch_count", + "render", +) + + +def resolve_rich_layout(layout: RichLayout | RichLayoutName | str | None) -> RichLayout: + """Resolve a Rich layout name or instance to a layout object. + + Parameters + ---------- + layout : RichLayout | {"training", "dynamics"} | str | None + Layout instance or concrete built-in layout name. ``"auto"`` and + ``None`` are handled by :class:`~nvalchemi.hooks.RichReporter` before + this resolver is called. + + Returns + ------- + RichLayout + Resolved layout policy. + + Raises + ------ + ValueError + If a string layout name is not recognized. + TypeError + If an object does not implement the layout protocol. + """ + if layout is None or layout == "training": + return TrainingRichLayout() + if layout == "dynamics": + return DynamicsRichLayout() + if isinstance(layout, str): + raise ValueError( + "RichReporter layout must be 'auto', 'training', 'dynamics', " + "or a layout object." + ) + missing = [ + method + for method in _REQUIRED_RICH_LAYOUT_METHODS + if not callable(getattr(layout, method, None)) + ] + if missing: + raise TypeError( + "RichReporter layout objects must define " + f"{', '.join(f'{method}()' for method in _REQUIRED_RICH_LAYOUT_METHODS)}. " + f"Missing: {', '.join(f'{method}()' for method in missing)}." 
+ ) + return layout diff --git a/nvalchemi/hooks/reporting/layouts/base.py b/nvalchemi/hooks/reporting/layouts/base.py new file mode 100644 index 00000000..2b5fdcce --- /dev/null +++ b/nvalchemi/hooks/reporting/layouts/base.py @@ -0,0 +1,369 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base classes and protocols for Rich reporting layouts.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Literal, Protocol, TypeAlias + +import plotext as plt +from rich import box +from rich.ansi import AnsiDecoder +from rich.console import Console, ConsoleOptions, Group, RenderResult +from rich.layout import Layout +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from nvalchemi.hooks.reporting._scalars import ScalarSnapshot + +RichMetricHistory: TypeAlias = Mapping[str, Sequence[tuple[int, float]]] +RichPreviewHistory: TypeAlias = Mapping[str, Sequence[float]] +RichLayoutName: TypeAlias = Literal["auto", "training", "dynamics"] + + +class RichLayout(Protocol): + """Layout policy used by :class:`~nvalchemi.hooks.reporting.RichReporter`.""" + + def default_preview_history(self) -> RichPreviewHistory: + """Return synthetic metric curves for static dashboard previews.""" + ... 
+ + def default_preview_stage(self) -> str: + """Return the hook stage label used by static dashboard previews.""" + ... + + def default_preview_epoch(self) -> int | None: + """Return the epoch metadata used by static dashboard previews.""" + ... + + def default_preview_batch_count(self) -> int | None: + """Return the batch metadata used by static dashboard previews.""" + ... + + def render( + self, + snapshot: ScalarSnapshot | None, + history: RichMetricHistory, + *, + title: str, + precision: int, + max_scalars: int | None, + plot_keys: Sequence[str] | None, + max_plots: int, + plot_height: int, + ) -> Layout: + """Build the Rich layout for one reporter snapshot.""" + ... + + +class BaseRichLayout: + """Reusable Rich dashboard layout for scalar tables and plot panels. + + Attributes + ---------- + name : str + Short layout name displayed in the dashboard header. + include_dynamics_scalars : bool + Whether :class:`~nvalchemi.hooks.reporting.RichReporter` should collect + default dynamics observables when this layout is selected. + """ + + def __init__( + self, + *, + name: str, + preferred_plot_keys: Sequence[str], + latest_title: str, + history_title: str, + include_dynamics_scalars: bool = False, + ) -> None: + self.name = name + self._preferred_plot_keys = tuple(preferred_plot_keys) + self._latest_title = latest_title + self._history_title = history_title + self.include_dynamics_scalars = include_dynamics_scalars + + def render( + self, + snapshot: ScalarSnapshot | None, + history: RichMetricHistory, + *, + title: str, + precision: int, + max_scalars: int | None, + plot_keys: Sequence[str] | None, + max_plots: int, + plot_height: int, + ) -> Layout: + """Build the Rich layout for one reporter snapshot. + + Parameters + ---------- + snapshot : ScalarSnapshot | None + Latest scalar snapshot, or ``None`` before the first report. + history : RichMetricHistory + Retained scalar history keyed by metric name. + title : str + Dashboard title. 
+ precision : int + Significant digits used for scalar values. + max_scalars : int | None + Maximum number of latest scalar rows. + plot_keys : Sequence[str] | None + Explicit plot key ordering override. + max_plots : int + Maximum number of plot panels. + plot_height : int + Plot height in terminal rows. + + Returns + ------- + Layout + Renderable Rich layout. + """ + layout = Layout(name="root") + layout.split_column( + Layout(name="header", size=3), + Layout(name="body"), + ) + layout["body"].split_row( + Layout(name="latest", ratio=2), + Layout(name="plots", ratio=3), + ) + layout["header"].update(self._build_header(snapshot, title)) + layout["latest"].update( + Panel( + self._build_table(snapshot, precision, max_scalars), + title=self._latest_title, + ) + ) + layout["plots"].update( + Panel( + self._build_plots( + history, + precision=precision, + plot_keys=plot_keys, + max_plots=max_plots, + plot_height=plot_height, + ), + title=self._history_title, + ) + ) + return layout + + def default_preview_history(self) -> RichPreviewHistory: + """Return synthetic metric curves for static dashboard previews.""" + raise NotImplementedError + + def default_preview_stage(self) -> str: + """Return the hook stage label used by static dashboard previews.""" + return "AFTER_OPTIMIZER_STEP" + + def default_preview_epoch(self) -> int | None: + """Return the epoch metadata used by static dashboard previews.""" + return 3 + + def default_preview_batch_count(self) -> int | None: + """Return the batch metadata used by static dashboard previews.""" + return 128 + + def _build_header( + self, + snapshot: ScalarSnapshot | None, + title: str, + ) -> Panel: + if snapshot is None: + body = f"{title} | {self.name} | waiting for metrics" + else: + body = f"{title} | {self.name} | {snapshot.stage}" + if snapshot.step_count is not None: + body = f"{body} | step {snapshot.step_count}" + return Panel(Text(body, overflow="fold"), box=box.SIMPLE) + + def _build_table( + self, + snapshot: 
ScalarSnapshot | None, + precision: int, + max_scalars: int | None, + ) -> Table: + table = Table(box=box.SIMPLE_HEAD, show_lines=False, expand=True) + table.add_column("Metric", overflow="fold") + table.add_column("Latest", justify="right", no_wrap=True) + if snapshot is None or not snapshot.scalars: + table.add_row("(no scalars)", "") + return table + items = self._scalar_table_items(snapshot) + visible_items = items[:max_scalars] if max_scalars is not None else items + for key, value in visible_items: + table.add_row(key, self._format_value(value, precision)) + if len(visible_items) < len(items): + table.add_row("...", f"{len(items) - len(visible_items)} omitted") + table.caption = self._caption(snapshot) + return table + + def _scalar_table_items(self, snapshot: ScalarSnapshot) -> list[tuple[str, float]]: + preferred = [ + (key, snapshot.scalars[key]) + for key in self._preferred_plot_keys + if key in snapshot.scalars + ] + seen = {key for key, _ in preferred} + preferred.extend( + (key, value) + for key, value in sorted(snapshot.scalars.items()) + if key not in seen + ) + return preferred + + def _build_plots( + self, + history: RichMetricHistory, + *, + precision: int, + plot_keys: Sequence[str] | None, + max_plots: int, + plot_height: int, + ) -> Group | Text: + keys = self._selected_plot_keys( + history, + plot_keys=plot_keys, + max_plots=max_plots, + ) + if not keys: + return Text("No scalar history yet.") + panels = [ + Panel( + _PlotextSeries( + key=key, + series=tuple(history[key]), + precision=precision, + height=plot_height, + ), + title=key, + box=box.SIMPLE, + ) + for key in keys + ] + return Group(*panels) + + def _selected_plot_keys( + self, + history: RichMetricHistory, + *, + plot_keys: Sequence[str] | None, + max_plots: int, + ) -> tuple[str, ...]: + if max_plots == 0: + return () + available = [key for key, values in history.items() if values] + if plot_keys is not None: + keys = [key for key in plot_keys if key in available] + else: + keys = 
[key for key in self._preferred_plot_keys if key in available] + keys.extend(sorted(key for key in available if key not in keys)) + return tuple(keys[:max_plots]) + + def _format_value(self, value: float, precision: int) -> str: + return f"{value:.{precision}g}" + + def _caption(self, snapshot: ScalarSnapshot) -> str: + parts = [f"rank={snapshot.global_rank}"] + if snapshot.event_count is not None: + parts.append(f"event={snapshot.event_count}") + if snapshot.epoch is not None: + parts.append(f"epoch={snapshot.epoch}") + if snapshot.batch_count is not None: + parts.append(f"batch={snapshot.batch_count}") + return " | ".join(parts) + + def _build_messages(self, snapshot: ScalarSnapshot | None) -> Table: + table = Table.grid(expand=True) + table.add_column("Level", no_wrap=True) + table.add_column("Message", overflow="fold") + if snapshot is None or not snapshot.messages: + table.add_row("info", "No reporter messages.") + return table + for message in snapshot.messages[-3:]: + prefix = message.level + if message.reporter is not None: + prefix = f"{prefix}/{message.reporter}" + table.add_row(prefix, message.message) + return table + + def _format_duration(self, seconds: float) -> str: + if seconds < 60: + return f"{seconds:.1f}s" + minutes, remaining_seconds = divmod(int(seconds), 60) + if minutes < 60: + return f"{minutes}m {remaining_seconds}s" + hours, remaining_minutes = divmod(minutes, 60) + return f"{hours}h {remaining_minutes}m" + + def _add_scalar_row( + self, + table: Table, + snapshot: ScalarSnapshot, + key: str, + label: str, + precision: int, + *, + suffix: str = "", + scale: float = 1.0, + ) -> None: + if key not in snapshot.scalars: + return + value = snapshot.scalars[key] * scale + table.add_row(label, f"{self._format_value(value, precision)}{suffix}") + + +class _PlotextSeries: + def __init__( + self, + *, + key: str, + series: Sequence[tuple[int, float]], + precision: int, + height: int, + ) -> None: + self.key = key + self.series = series + 
self.precision = precision + self.height = height + self.decoder = AnsiDecoder() + + def __rich_console__( + self, + console: Console, + options: ConsoleOptions, + ) -> RenderResult: + width = max(20, options.max_width or console.width) + canvas = self._build_canvas(width) + yield Group(*self.decoder.decode(canvas)) + + def _build_canvas(self, width: int) -> str: + plt.clf() + steps = [step for step, _ in self.series] + values = [value for _, value in self.series] + plt.plotsize(width, self.height) + plt.theme("dark") + plt.title(self.key) + plt.xlabel("step") + if len(values) == 1: + plt.scatter(steps, values) + else: + plt.plot(steps, values) + return plt.build() diff --git a/nvalchemi/hooks/reporting/layouts/dynamics.py b/nvalchemi/hooks/reporting/layouts/dynamics.py new file mode 100644 index 00000000..221e597e --- /dev/null +++ b/nvalchemi/hooks/reporting/layouts/dynamics.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# --- nvalchemi/hooks/reporting/layouts/dynamics.py: dynamics Rich layout ---


class DynamicsRichLayout(BaseRichLayout):
    """Rich dashboard layout for dynamics workflows."""

    # Observables listed first in the "Observables" panel.
    _observable_keys = ("energy", "fmax", "temperature", "energy_drift")
    # Status scalars shown in the pipeline/status panels instead.
    _status_keys = ("active_fraction", "converged_fraction")

    def __init__(self) -> None:
        plot_order = (
            "energy",
            "fmax",
            "temperature",
            "energy_drift",
            "converged_fraction",
            "active_fraction",
        )
        super().__init__(
            name="dynamics",
            preferred_plot_keys=plot_order,
            latest_title="State",
            history_title="Traces",
            include_dynamics_scalars=True,
        )

    def render(
        self,
        snapshot: ScalarSnapshot | None,
        history: RichMetricHistory,
        *,
        title: str,
        precision: int,
        max_scalars: int | None,
        plot_keys: Sequence[str] | None,
        max_plots: int,
        plot_height: int,
    ) -> Layout:
        """Build a dynamics-specific Rich dashboard.

        Parameters
        ----------
        snapshot : ScalarSnapshot | None
            Latest scalar snapshot, or ``None`` before the first report.
        history : RichMetricHistory
            Retained scalar history keyed by metric name.
        title : str
            Dashboard title.
        precision : int
            Significant digits used for scalar values.
        max_scalars : int | None
            Maximum number of observable rows.
        plot_keys : Sequence[str] | None
            Explicit plot key ordering override.
        max_plots : int
            Maximum number of plot panels.
        plot_height : int
            Plot height in terminal rows.

        Returns
        -------
        Layout
            Renderable Rich layout with dynamics observables, status, and traces.
        """
        # Skeleton: header on top; below it a state column (observables,
        # pipeline, messages) next to the trace plots.
        root = Layout(name="root")
        root.split_column(Layout(name="header", size=3), Layout(name="body"))
        root["body"].split_row(
            Layout(name="state", ratio=2),
            Layout(name="traces", ratio=3),
        )
        root["state"].split_column(
            Layout(name="observables", ratio=2),
            Layout(name="pipeline", ratio=2),
            Layout(name="messages", size=4),
        )

        observables = self._build_observables(snapshot, precision, max_scalars)
        pipeline = self._build_pipeline(snapshot, precision)
        traces = self._build_plots(
            history,
            precision=precision,
            plot_keys=plot_keys,
            max_plots=max_plots,
            plot_height=plot_height,
        )

        root["header"].update(self._build_header(snapshot, title))
        root["observables"].update(Panel(observables, title="Observables"))
        root["pipeline"].update(Panel(pipeline, title="Convergence / Pipeline"))
        root["messages"].update(
            Panel(self._build_messages(snapshot), title="Messages")
        )
        root["traces"].update(Panel(traces, title="Dynamics Traces"))
        return root

    def default_preview_history(self) -> RichPreviewHistory:
        """Return representative dynamics metrics for preview rendering."""
        preview: dict[str, tuple[float, ...]] = {}
        preview["energy"] = (-15.2, -15.18, -15.21, -15.19, -15.2, -15.18)
        preview["fmax"] = (0.42, 0.31, 0.22, 0.18, 0.12, 0.08)
        preview["temperature"] = (297.0, 301.0, 299.0, 300.0, 302.0, 300.0)
        preview["energy_drift"] = (0.0, 0.02, -0.01, 0.01, 0.0, 0.02)
        preview["converged_fraction"] = (0.05, 0.12, 0.25, 0.41, 0.68, 0.92)
        preview["active_fraction"] = (1.0, 1.0, 0.95, 0.9, 0.72, 0.5)
        return preview

    def default_preview_stage(self) -> str:
        """Return the dynamics hook stage label used by static previews."""
        return "AFTER_STEP"

    def default_preview_epoch(self) -> None:
        """Return no epoch metadata for dynamics previews."""
        return None

    def default_preview_batch_count(self) -> None:
        """Return no batch metadata for dynamics previews."""
        return None

    def _build_observables(
        self,
        snapshot: ScalarSnapshot | None,
        precision: int,
        max_scalars: int | None,
    ) -> Table:
        """Build the observables table, known observables first."""
        grid = Table(box=box.SIMPLE_HEAD, show_lines=False, expand=True)
        grid.add_column("Observable", overflow="fold")
        grid.add_column("Latest", justify="right", no_wrap=True)
        if snapshot is None or not snapshot.scalars:
            grid.add_row("(waiting)", "")
            return grid

        scalars = snapshot.scalars
        leading = [name for name in self._observable_keys if name in scalars]
        # Remaining scalars follow alphabetically; status scalars are left to
        # the pipeline panel.
        excluded = set(leading).union(self._status_keys)
        ordered = leading + sorted(name for name in scalars if name not in excluded)

        shown = ordered if max_scalars is None else ordered[:max_scalars]
        for name in shown:
            grid.add_row(name, self._format_value(scalars[name], precision))
        hidden = len(ordered) - len(shown)
        if hidden > 0:
            grid.add_row("...", f"{hidden} omitted")
        return grid

    def _build_status(self, snapshot: ScalarSnapshot | None, precision: int) -> Table:
        """Build a compact status grid: status scalars, rank, event and step."""
        grid = Table.grid(expand=True)
        grid.add_column("Field", overflow="fold")
        grid.add_column("Value", justify="right", no_wrap=True)
        if snapshot is None:
            grid.add_row("state", Text("waiting"))
            return grid

        scalars = snapshot.scalars
        for name in self._status_keys:
            if name in scalars:
                grid.add_row(name, self._format_value(scalars[name], precision))
        grid.add_row("rank", str(snapshot.global_rank))
        counters = (("event", snapshot.event_count), ("step", snapshot.step_count))
        for label, count in counters:
            if count is not None:
                grid.add_row(label, str(count))
        return grid

    def _build_pipeline(
        self,
        snapshot: ScalarSnapshot | None,
        precision: int,
    ) -> Table:
        """Build the convergence/pipeline grid from ``dynamics/*`` scalars."""
        grid = Table.grid(expand=True)
        grid.add_column("Field", overflow="fold")
        grid.add_column("Value", justify="right", no_wrap=True)
        if snapshot is None:
            grid.add_row("state", "waiting")
            return grid

        scalars = snapshot.scalars
        for name in self._status_keys:
            if name in scalars:
                grid.add_row(name, self._format_value(scalars[name], precision))

        # Fixed population counters, each shown under a short label.
        population_rows = (
            ("dynamics/num_graphs", "systems"),
            ("dynamics/active_count", "active"),
            ("dynamics/graduated_count", "graduated"),
            ("dynamics/converged_count", "converged"),
        )
        for scalar_key, label in population_rows:
            if scalar_key in scalars:
                grid.add_row(label, self._format_value(scalars[scalar_key], precision))

        # Per-status counters are named "dynamics/status/<status>/count".
        status_prefix = "dynamics/status/"
        count_suffix = "/count"
        for scalar_key, value in sorted(scalars.items()):
            if not scalar_key.startswith(status_prefix):
                continue
            if not scalar_key.endswith(count_suffix):
                continue
            status = scalar_key[len(status_prefix) : -len(count_suffix)]
            grid.add_row(f"status {status}", self._format_value(value, precision))

        self._add_scalar_row(
            grid,
            snapshot,
            "dynamics/progress_fraction",
            "progress",
            precision,
            suffix="%",
            scale=100.0,
        )
        self._add_scalar_row(
            grid, snapshot, "dynamics/steps_per_s", "steps/s", precision
        )
        if "dynamics/eta_s" in scalars:
            grid.add_row("eta", self._format_duration(scalars["dynamics/eta_s"]))

        grid.add_row("rank", str(snapshot.global_rank))
        counters = (("event", snapshot.event_count), ("step", snapshot.step_count))
        for label, count in counters:
            if count is not None:
                grid.add_row(label, str(count))
        return grid
# --- nvalchemi/hooks/reporting/layouts/train.py: training Rich layout ---


class TrainingRichLayout(BaseRichLayout):
    """Rich dashboard layout for training workflows."""

    def __init__(self) -> None:
        plot_order = (
            "loss/total",
            "optimizer/lr",
            "scheduler/lr",
            "loss/energy/unweighted",
            "loss/forces/unweighted",
        )
        super().__init__(
            name="training",
            preferred_plot_keys=plot_order,
            latest_title="Latest",
            history_title="History",
        )

    def render(
        self,
        snapshot: ScalarSnapshot | None,
        history: RichMetricHistory,
        *,
        title: str,
        precision: int,
        max_scalars: int | None,
        plot_keys: Sequence[str] | None,
        max_plots: int,
        plot_height: int,
    ) -> Layout:
        """Build a training-specific Rich dashboard.

        Parameters
        ----------
        snapshot : ScalarSnapshot | None
            Latest scalar snapshot, or ``None`` before the first report.
        history : RichMetricHistory
            Retained scalar history keyed by metric name.
        title : str
            Dashboard title.
        precision : int
            Significant digits used for scalar values.
        max_scalars : int | None
            Maximum number of latest scalar rows.
        plot_keys : Sequence[str] | None
            Explicit plot key ordering override.
        max_plots : int
            Maximum number of plot panels.
        plot_height : int
            Plot height in terminal rows.

        Returns
        -------
        Layout
            Renderable Rich layout with training metrics and progress.
        """
        # Skeleton: header, body (metrics/progress column beside the plots),
        # and a messages strip at the bottom.
        root = Layout(name="root")
        root.split_column(
            Layout(name="header", size=3),
            Layout(name="body"),
            Layout(name="messages", size=5),
        )
        root["body"].split_row(
            Layout(name="left", ratio=2),
            Layout(name="plots", ratio=3),
        )
        root["left"].split_column(
            Layout(name="latest", ratio=3),
            Layout(name="progress", size=9),
        )

        metrics = self._build_table(snapshot, precision, max_scalars)
        progress = self._build_progress(snapshot, precision)
        curves = self._build_plots(
            history,
            precision=precision,
            plot_keys=plot_keys,
            max_plots=max_plots,
            plot_height=plot_height,
        )

        root["header"].update(self._build_header(snapshot, title))
        root["latest"].update(Panel(metrics, title="Latest Metrics"))
        root["progress"].update(Panel(progress, title="Progress"))
        root["plots"].update(Panel(curves, title="Training Curves"))
        root["messages"].update(
            Panel(self._build_messages(snapshot), title="Messages")
        )
        return root

    def default_preview_history(self) -> RichPreviewHistory:
        """Return representative training metrics for preview rendering."""
        # The optimizer and scheduler previews use the same decay curve.
        learning_rate = (1e-3, 1e-3, 8e-4, 5e-4, 2e-4, 1e-4)
        preview: dict[str, tuple[float, ...]] = {}
        preview["loss/total"] = (1.2, 0.86, 0.61, 0.43, 0.31, 0.24)
        preview["loss/energy/unweighted"] = (0.54, 0.39, 0.27, 0.19, 0.14, 0.11)
        preview["loss/forces/unweighted"] = (0.66, 0.47, 0.34, 0.24, 0.17, 0.13)
        preview["optimizer/lr"] = learning_rate
        preview["scheduler/lr"] = learning_rate
        return preview

    def _build_progress(
        self,
        snapshot: ScalarSnapshot | None,
        precision: int,
    ) -> Table:
        """Build the progress grid: counters, throughput, ETA and learning rate."""
        grid = Table.grid(expand=True)
        grid.add_column("Field", overflow="fold")
        grid.add_column("Value", justify="right", no_wrap=True)
        if snapshot is None:
            grid.add_row("state", "waiting")
            return grid

        grid.add_row("rank", str(snapshot.global_rank))
        # Counters are optional; only those that are set get a row.
        counters = (
            ("event", snapshot.event_count),
            ("step", snapshot.step_count),
            ("batch", snapshot.batch_count),
            ("epoch", snapshot.epoch),
            ("epoch batch", snapshot.epoch_step_count),
        )
        for label, count in counters:
            if count is not None:
                grid.add_row(label, str(count))

        self._add_scalar_row(
            grid,
            snapshot,
            "training/progress_fraction",
            "progress",
            precision,
            suffix="%",
            scale=100.0,
        )
        throughput_rows = (
            ("training/steps_per_s", "steps/s"),
            ("training/batches_per_s", "batches/s"),
        )
        for scalar_key, label in throughput_rows:
            self._add_scalar_row(grid, snapshot, scalar_key, label, precision)

        scalars = snapshot.scalars
        if "training/eta_s" in scalars:
            grid.add_row("eta", self._format_duration(scalars["training/eta_s"]))

        # Show one learning rate: the scheduler's if present, otherwise the
        # optimizer's.
        lr_candidates = (
            ("scheduler/lr", "scheduler lr"),
            ("optimizer/lr", "optimizer lr"),
        )
        for scalar_key, label in lr_candidates:
            if scalar_key in scalars:
                grid.add_row(label, self._format_value(scalars[scalar_key], precision))
                break
        return grid
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Per-stage wall-clock timing for hook-enabled workflows.""" + +from __future__ import annotations + +import csv +import io +import statistics +import time +from enum import Enum +from pathlib import Path +from typing import Literal + +import torch +from loguru import logger + +from nvalchemi.data import Batch +from nvalchemi.hooks._context import HookContext + +try: + import nvtx +except ImportError: + nvtx = None + +__all__ = ["StageTimingHook"] + + +def _get_dynamics_stage_type() -> type[Enum]: + """Import and return the dynamics stage enum on demand.""" + from nvalchemi.dynamics.base import DynamicsStage + + return DynamicsStage + + +def _stage_domain(stage: Enum) -> str: + """Return a stable domain label for a stage enum.""" + stage_type = type(stage) + if ( + stage_type.__name__ == "DynamicsStage" + and stage_type.__module__ == "nvalchemi.dynamics.base" + ): + return "dynamics" + return "custom" + + +def _sort_stages(stages: set[Enum]) -> list[Enum]: + """Sort stage enum members by their integer value.""" + return sorted(stages, key=lambda s: s.value) + + +class StageTimingHook: + """Per-stage timing hook for hook-enabled workflows. + + A single ``StageTimingHook`` instance registers itself at every + requested stage. On each call it records a timestamp; when the + last profiled stage in a step fires, it computes the elapsed time + between consecutive stages and (optionally) writes to CSV / console. + + The hook implements ``_runs_on_stage`` so hook registries can dispatch it + at every selected stage. + + The hook supports dynamics presets and custom enum stage sets. + Contexts should provide ``batch``, ``global_rank``, and ``step_count`` attributes. + + Parameters + ---------- + profiled_stages : set[Enum] | {"all", "step", "detailed"} + Which stages to instrument. + + * ``"all"`` (default): every dynamics stage except ``ON_CONVERGE``. 
+ * ``"step"``: ``BEFORE_STEP`` and ``AFTER_STEP`` only. + * ``"detailed"``: all stages from ``BEFORE_STEP`` through + ``AFTER_STEP`` (excluding ``ON_CONVERGE``). + * A custom ``set[Enum]`` for fine-grained control. + frequency : int, optional + Profile every ``frequency`` steps. Default ``1``. + enable_nvtx : bool, optional + Emit NVTX push/pop ranges for Nsight Systems. Default ``True``. + timer_backend : {"cuda_event", "perf_counter", "auto"}, optional + Timing backend. ``"auto"`` selects ``cuda_event`` on GPU + devices and ``perf_counter`` on CPU. Default ``"auto"``. + log_path : str | Path | None, optional + Path to a CSV file for persistent timing logs. Each row + records the rank, step, stage transition, wall-clock offset, + and delta. Default ``None`` (no file). + show_console : bool, optional + Print a formatted timing table via ``loguru`` at each + profiled step. Default ``False``. + console_frequency : int, optional + When ``show_console`` is ``True``, print every + ``console_frequency`` profiled steps. Default ``1``. + + Attributes + ---------- + _profiled_stages : list[Enum] + Profiled stages in execution order (private). + frequency : int + Execution frequency in steps. + timings : dict[Enum, list[float]] + Accumulated per-transition timing data (seconds). + + Examples + -------- + >>> from nvalchemi.hooks import StageTimingHook + >>> profiler = StageTimingHook() + >>> dynamics = DemoDynamics(model=model, n_steps=100, dt=0.5, hooks=[profiler]) + >>> dynamics.run(batch) + >>> print(profiler.summary()) + + With CSV logging and console output: + + >>> profiler = StageTimingHook( + ... "detailed", + ... log_path="profiler.csv", + ... show_console=True, + ... console_frequency=10, + ... 
) + >>> dynamics = DemoDynamics(model=model, n_steps=1000, dt=0.5, hooks=[profiler]) + >>> dynamics.run(batch) + """ + + def __init__( + self, + profiled_stages: set[Enum] | Literal["all", "step", "detailed"] = "all", + *, + frequency: int = 1, + enable_nvtx: bool = True, + timer_backend: Literal["cuda_event", "perf_counter", "auto"] = "auto", + log_path: str | Path | None = None, + show_console: bool = False, + console_frequency: int = 1, + stage: Enum | None = None, + ) -> None: + # Init file handle early so __del__ is safe on validation errors. + self._csv_file: io.TextIOWrapper | None = None + self._csv_writer: csv.DictWriter | None = None + + if isinstance(profiled_stages, str): + dynamics_stage = _get_dynamics_stage_type() + if profiled_stages == "all": + resolved = { + s for s in dynamics_stage if s != dynamics_stage.ON_CONVERGE + } + elif profiled_stages == "step": + resolved = {dynamics_stage.BEFORE_STEP, dynamics_stage.AFTER_STEP} + elif profiled_stages == "detailed": + resolved = { + dynamics_stage.BEFORE_STEP, + dynamics_stage.BEFORE_PRE_UPDATE, + dynamics_stage.AFTER_PRE_UPDATE, + dynamics_stage.BEFORE_COMPUTE, + dynamics_stage.AFTER_COMPUTE, + dynamics_stage.BEFORE_POST_UPDATE, + dynamics_stage.AFTER_POST_UPDATE, + dynamics_stage.AFTER_STEP, + } + else: + raise ValueError( + f"Unknown stages preset {profiled_stages!r}. " + f"Use 'all', 'step', 'detailed', or a set of Enum." + ) + else: + resolved = set(profiled_stages) + + if len(resolved) < 2: + raise ValueError( + "At least two stages are required to measure timing deltas." + ) + + # Sorted by execution order — private profiled stages list. + self._profiled_stages: list[Enum] = _sort_stages(resolved) + # Primary stage for protocol compliance. 
+ self.stage = stage or self._profiled_stages[0] + self.frequency = frequency + self.enable_nvtx = enable_nvtx + self.timer_backend = timer_backend + self.log_path = Path(log_path) if log_path is not None else None + self.show_console = show_console + self.console_frequency = console_frequency + + # Per-step scratch — separate dicts for type safety. + self._current_step: int = -1 + self._step_cuda_events: dict[Enum, torch.cuda.Event] = {} + self._step_cpu_timestamps: dict[Enum, int] = {} + + # Accumulated timing: transition endpoint -> list of delta_s. + self.timings: dict[Enum, list[float]] = {s: [] for s in self._profiled_stages} + + self._t0_ns: int = time.perf_counter_ns() + self._backend_resolved: str | None = None + self._steps_recorded: int = 0 + + # ------------------------------------------------------------------ + # Hook entry point + # ------------------------------------------------------------------ + + def _runs_on_stage(self, stage: Enum) -> bool: + """Check if this hook should run on the given stage. + + Parameters + ---------- + stage : Enum + The stage to check. + + Returns + ------- + bool + True if this hook runs on the given stage. + """ + return stage in set(self._profiled_stages) + + @torch.compiler.disable + def _record( + self, + batch: Batch, + current_stage: Enum, + step_count: int, + global_rank: int, + domain: str = "dynamics", + ) -> None: + """Record a timestamp for the current stage. + + Parameters + ---------- + batch : Batch + The current batch of atomic data. + current_stage : Enum + The current dynamics stage being executed. + step_count : int + The current step number. + global_rank : int + The distributed rank of this process. + domain : str, optional + The domain for NVTX annotation (e.g., "dynamics", "custom"). + Default ``"dynamics"``. + """ + # New step: flush the previous one, then reset scratch. 
+ if step_count != self._current_step: + if self._current_step >= 0: + self._flush_step(global_rank) + self._current_step = step_count + self._step_cuda_events.clear() + self._step_cpu_timestamps.clear() + + # NVTX annotation. + if self.enable_nvtx and nvtx is not None: + idx = self._profiled_stages.index(current_stage) + if idx > 0: + nvtx.pop_range() + nvtx.push_range(f"{domain}/{current_stage.name}/{step_count}") + + # Timestamp. + dev = batch.device + if isinstance(dev, str): + dev = torch.device(dev) + if self._backend_resolved is None: + self._backend_resolved = self._resolve_backend(dev) + if self._backend_resolved == "cuda_event": + event = torch.cuda.Event(enable_timing=True) + event.record() + self._step_cuda_events[current_stage] = event + else: + self._step_cpu_timestamps[current_stage] = time.perf_counter_ns() + + # If this is the last profiled stage in the step, flush now. + if current_stage == self._profiled_stages[-1]: + self._flush_step(global_rank) + self._current_step = -1 + self._step_cuda_events.clear() + self._step_cpu_timestamps.clear() + + def __call__(self, ctx: HookContext, stage: Enum) -> None: + """Record timing for a profiled stage. + + Parameters + ---------- + ctx : HookContext + Workflow context. Dynamics and training contexts provide + ``step_count``; base contexts default to step ``0``. + stage : Enum + Current workflow stage. 
+ """ + batch = ctx.batch + if batch is None: + raise ValueError("StageTimingHook requires ctx.batch to record timing.") + step_count = int(getattr(ctx, "step_count", 0)) + domain = _stage_domain(stage) + self._record(batch, stage, step_count, ctx.global_rank or 0, domain=domain) + + # ------------------------------------------------------------------ + # Backend resolution + # ------------------------------------------------------------------ + + def _resolve_backend(self, device: torch.device) -> str: + """Resolve the timing backend based on configuration and device.""" + if self.timer_backend != "auto": + return self.timer_backend + if device.type == "cuda": + return "cuda_event" + return "perf_counter" + + # ------------------------------------------------------------------ + # Step flush — compute deltas, log + # ------------------------------------------------------------------ + + def _flush_step(self, rank: int) -> None: + """Compute per-transition deltas for the current step and log.""" + use_cuda = self._backend_resolved == "cuda_event" + + if use_cuda: + ordered = [s for s in self._profiled_stages if s in self._step_cuda_events] + else: + ordered = [ + s for s in self._profiled_stages if s in self._step_cpu_timestamps + ] + + if len(ordered) < 2: + return + + if use_cuda: + torch.cuda.synchronize() + + deltas: dict[Enum, float] = {} + for i in range(1, len(ordered)): + prev_stage, curr_stage = ordered[i - 1], ordered[i] + if use_cuda: + prev_ev = self._step_cuda_events[prev_stage] + curr_ev = self._step_cuda_events[curr_stage] + delta_s = prev_ev.elapsed_time(curr_ev) / 1000.0 + else: + prev_ts = self._step_cpu_timestamps[prev_stage] + curr_ts = self._step_cpu_timestamps[curr_stage] + delta_s = (curr_ts - prev_ts) / 1e9 + deltas[curr_stage] = delta_s + self.timings[curr_stage].append(delta_s) + + t_since_init_s = (time.perf_counter_ns() - self._t0_ns) / 1e9 + self._steps_recorded += 1 + + if self.log_path is not None: + self._write_csv(rank, 
self._current_step, t_since_init_s, ordered, deltas) + + if self.show_console and (self._steps_recorded % self.console_frequency == 0): + self._print_console( + rank, self._current_step, t_since_init_s, ordered, deltas + ) + + # Close NVTX range for the last stage in this step. + if self.enable_nvtx and nvtx is not None: + nvtx.pop_range() + + # ------------------------------------------------------------------ + # CSV output + # ------------------------------------------------------------------ + + def _write_csv( + self, + rank: int, + step: int, + t_since_init: float, + ordered: list[Enum], + deltas: dict[Enum, float], + ) -> None: + """Append one row per transition to the CSV log.""" + rows = [] + for i, stage in enumerate(ordered[1:], start=1): + rows.append( + { + "rank": rank, + "step": step, + "stage": f"{ordered[i - 1].name}->{stage.name}", + "t_since_init_s": f"{t_since_init:.6f}", + "delta_s": f"{deltas[stage]:.6f}", + } + ) + if self._csv_writer is None: + log_path = self.log_path + if log_path is None: + return + fh = open(log_path, "w", newline="") # noqa: SIM115 + self._csv_file = fh + self._csv_writer = csv.DictWriter( + fh, + fieldnames=["rank", "step", "stage", "t_since_init_s", "delta_s"], + ) + self._csv_writer.writeheader() + self._csv_writer.writerows(rows) + if self._csv_file is not None: + self._csv_file.flush() + + # ------------------------------------------------------------------ + # Console output + # ------------------------------------------------------------------ + + def _print_console( + self, + rank: int, + step: int, + t_since_init: float, + ordered: list[Enum], + deltas: dict[Enum, float], + ) -> None: + """Print a formatted timing table for the current step.""" + lines = [f"[Profiler] rank={rank} step={step} t={t_since_init:.3f}s"] + for i, stage in enumerate(ordered[1:], start=1): + prev_name = ordered[i - 1].name + lines.append( + f" {prev_name} -> {stage.name}: {deltas[stage] * 1000:.3f} ms" + ) + 
logger.info("\n".join(lines)) + + # ------------------------------------------------------------------ + # Summary / reset / close + # ------------------------------------------------------------------ + + def summary(self) -> dict[str, dict[str, float]]: + """Return per-transition timing statistics. + + Returns + ------- + dict[str, dict[str, float]] + Mapping from ``"PREV_STAGE->STAGE"`` label to a stats dict + with keys ``mean_s``, ``std_s``, ``min_s``, ``max_s``, + ``total_s``, ``n_samples``. + """ + result: dict[str, dict[str, float]] = {} + for idx, stage in enumerate(self._profiled_stages): + samples = self.timings[stage] + if not samples: + continue + prev_name = self._profiled_stages[idx - 1].name + label = f"{prev_name}->{stage.name}" + n = len(samples) + result[label] = { + "mean_s": statistics.mean(samples), + "std_s": statistics.stdev(samples) if n > 1 else 0.0, + "min_s": min(samples), + "max_s": max(samples), + "total_s": sum(samples), + "n_samples": float(n), + } + return result + + def reset(self) -> None: + """Clear all accumulated timing data.""" + for stage in self.timings: + self.timings[stage].clear() + self._step_cuda_events.clear() + self._step_cpu_timestamps.clear() + self._current_step = -1 + self._backend_resolved = None + self._t0_ns = time.perf_counter_ns() + self._steps_recorded = 0 + + def close(self) -> None: + """Flush and close the CSV log file, if open.""" + if self._csv_file is not None: + self._csv_file.close() + self._csv_file = None + self._csv_writer = None + + def __del__(self) -> None: + self.close() diff --git a/nvalchemi/models/mace.py b/nvalchemi/models/mace.py index c83bf1d2..42903e35 100644 --- a/nvalchemi/models/mace.py +++ b/nvalchemi/models/mace.py @@ -61,7 +61,7 @@ import warnings from importlib.metadata import version from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import torch from torch import nn @@ -76,6 +76,9 @@ NeighborListFormat, ) +if TYPE_CHECKING: + from 
def checkpoint_spec(self) -> "BaseSpec | None":
    """Return the spec stored for reconstructing this wrapper, if any.

    The value is whatever was passed to the constructor as
    ``reconstruction_spec``. :meth:`from_checkpoint` passes a spec built
    from its own arguments, so strategy checkpoints can rebuild the wrapper
    through that factory instead of introspecting the inner module's
    constructor. Wrappers built directly around a live module without a
    spec return ``None``.

    Returns
    -------
    BaseSpec | None
        The stored reconstruction spec, or ``None`` when none was given.
    """
    stored_spec = self._checkpoint_spec
    return stored_spec
+ self._configure_sub_models() + step_id = id(step) + override = self._step_active_overrides.get(step_id) + needs_neighbor_adapt = self._step_needs_neighbor_adapt.get(step_id, False) saved_neighbors: dict[str, Any] | None = None saved_active: set[str] | None = None diff --git a/nvalchemi/training/__init__.py b/nvalchemi/training/__init__.py new file mode 100644 index 00000000..d8e19ad0 --- /dev/null +++ b/nvalchemi/training/__init__.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Training framework for ALCHEMI — stages, specs, losses, and checkpoint I/O.""" + +from __future__ import annotations + +from nvalchemi.training._checkpoint import ( + CheckpointManifest, + CheckpointValidator, + load_checkpoint, + save_checkpoint, +) +from nvalchemi.training._spec import ( + BaseSpec, + create_model_spec, + create_model_spec_from_json, + register_type_serializer, +) +from nvalchemi.training._stages import TrainingStage +from nvalchemi.training._validation import ( + BatchValidationCallback, + ValidationConfig, + ValidationLoop, +) +from nvalchemi.training.hooks import ( + CheckpointHook, + DDPHook, + EMAHook, +) +from nvalchemi.training.losses import ( + BaseLossFunction, + ComposedLossFunction, + ComposedLossOutput, + ConstantWeight, + CosineWeight, + EnergyHuberLoss, + EnergyMAELoss, + EnergyMSELoss, + ForceHuberLoss, + ForceL2NormLoss, + ForceMSELoss, + LinearWeight, + LossWeightSchedule, + PiecewiseWeight, + ReductionContext, + StressHuberLoss, + StressMSELoss, + loss_component_to_spec, +) +from nvalchemi.training.optimizers import ( + OptimizerConfig, + setup_optimizers, + step_lr_schedulers, + step_optimizers, + zero_gradients, +) +from nvalchemi.training.runtime import ( + configure_dataloader, + configure_parallelism, + freeze_unconfigured_models, + move_to_devices, +) +from nvalchemi.training.strategy import TrainingStrategy, default_training_fn + +__all__ = [ + "BaseLossFunction", + "BaseSpec", + "CheckpointManifest", + "CheckpointHook", + "CheckpointValidator", + "ComposedLossFunction", + "ComposedLossOutput", + "ConstantWeight", + "CosineWeight", + "EnergyHuberLoss", + "EnergyMAELoss", + "EnergyMSELoss", + "ForceHuberLoss", + "ForceL2NormLoss", + "ForceMSELoss", + "DDPHook", + "EMAHook", + "LinearWeight", + "LossWeightSchedule", + "OptimizerConfig", + "PiecewiseWeight", + "ReductionContext", + "StressHuberLoss", + "StressMSELoss", + "TrainingStage", + "TrainingStrategy", + "BatchValidationCallback", + "ValidationConfig", + 
"ValidationLoop", + "configure_dataloader", + "configure_parallelism", + "create_model_spec", + "create_model_spec_from_json", + "default_training_fn", + "freeze_unconfigured_models", + "loss_component_to_spec", + "load_checkpoint", + "move_to_devices", + "register_type_serializer", + "save_checkpoint", + "setup_optimizers", + "step_lr_schedulers", + "step_optimizers", + "zero_gradients", +] diff --git a/nvalchemi/training/_checkpoint.py b/nvalchemi/training/_checkpoint.py new file mode 100644 index 00000000..d23bcaba --- /dev/null +++ b/nvalchemi/training/_checkpoint.py @@ -0,0 +1,1760 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-component, manifest-based checkpoint layer. + +This module saves and loads checkpoints for multiple named models, +optimizers, and schedulers without relying on :mod:`pickle`. A top-level +``manifest.json`` coordinates all components and their associations. 
+ +Layout +------ +A single call to :func:`save_checkpoint` writes:: + + {root_folder}/ + manifest.json + models/{name}/ + spec.json + checkpoints/{N}.pt + optimizers/{name}/ # optional + spec.json + checkpoints/{N}.pt + schedulers/{name}/ # optional + spec.json + checkpoints/{N}.pt + +The ``manifest.json`` records which components are present, the latest +checkpoint index, and optional associations that wire optimizers to models +and schedulers to optimizers:: + + { + "checkpoint_index": 0, + "models": ["student", "teacher"], + "optimizers": ["student_opt"], + "schedulers": ["student_sched"], + "associations": { + "student": { + "optimizers": ["student_opt"], + "schedulers": ["student_sched"] + } + } + } + +The ``associations`` key specifies connectivity between models and +their respective optimizer(s) and LR scheduler(s). This can be explicitly +provided by the user, or automatically inferred by matching parameters +with optimizers/LR schedulers. + +Examples +-------- +Single model:: + + save_checkpoint("runs/exp1", models={"main": (model, spec)}) + result = load_checkpoint("runs/exp1") + model, spec = result.models["main"] + +Knowledge distillation (two models + optimizer + scheduler):: + + save_checkpoint( + "runs/kd", + models={"student": (student, s_spec), "teacher": (teacher, t_spec)}, + optimizers={"s_opt": (optimizer, opt_spec)}, + schedulers={"s_sched": (scheduler, sched_spec)}, + # associations can be inferred automatically from param_groups + ) + result = load_checkpoint("runs/kd") + student, _ = result.models["student"] +""" + +from __future__ import annotations + +import itertools +import json +import warnings +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from pathlib import Path +from typing import Annotated, Any + +import torch +import torch.nn as nn +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, PlainSerializer + +from nvalchemi.hooks._protocol import CheckpointableHook +from 
nvalchemi.training._spec import ( + BaseSpec, + create_model_spec, + create_model_spec_from_json, +) + +CheckpointValidator = Callable[[str, Mapping[str, Any], Mapping[str, Any]], None] +"""Callable used to validate a loaded model entry. + +Validators receive ``(model_name, model_entry, loaded_checkpoint)`` and should +raise an exception with an actionable message when compatibility checks fail. +""" + +# --------------------------------------------------------------------------- +# Dual-mode field helpers +# --------------------------------------------------------------------------- + + +def _component_before(v: Any) -> dict[str, Any]: + """Accept ``list[str]`` (from JSON) or ``dict`` (from code) for component fields.""" + if isinstance(v, list): + # From disk: list of names → placeholder dict (values populated later) + return {name: None for name in v} + return v + + +def _component_serialize(d: dict[str, Any]) -> list[str]: + """Serialize a component dict to a sorted list of its keys.""" + return sorted(d.keys()) + + +def _is_fsdp_wrapped(module: nn.Module) -> bool: + """Return whether ``module`` is wrapped by FSDP or FSDP2.""" + fsdp_types: list[type[nn.Module]] = [] + try: + from torch.distributed.fsdp import FullyShardedDataParallel + + fsdp_types.append(FullyShardedDataParallel) + except (ImportError, AttributeError): + pass + try: + from torch.distributed._composable.fsdp import FSDPModule + + fsdp_types.append(FSDPModule) + except (ImportError, AttributeError): + pass + return bool(fsdp_types) and isinstance(module, tuple(fsdp_types)) + + +def _checkpoint_model(module: nn.Module) -> nn.Module: + """Return a model suitable for native checkpoint state and spec extraction.""" + if isinstance(module, torch.nn.parallel.DistributedDataParallel): + return module.module + if _is_fsdp_wrapped(module): + recipe_url = ( + "https://docs.pytorch.org/tutorials/recipes/" + "distributed_checkpoint_recipe.html" + ) + raise NotImplementedError( + "Native nvalchemi 
checkpoints do not yet support FSDP/FSDP2-wrapped " + "models. Use torch.distributed.checkpoint with PyTorch's distributed " + f"checkpoint recipe instead: {recipe_url}" + ) + return module + + +def _checkpoint_model_components( + models: Mapping[str, tuple[nn.Module, BaseSpec]], +) -> dict[str, tuple[nn.Module, BaseSpec]]: + """Unwrap supported distributed model wrappers before checkpointing.""" + return { + name: (_checkpoint_model(module), spec) + for name, (module, spec) in models.items() + } + + +# --------------------------------------------------------------------------- +# Manifest schema + runtime container (unified) +# --------------------------------------------------------------------------- + +_SCHEMA_VERSION = 1 +"""Current manifest schema version. Bump when manifest structure changes.""" + +_STRATEGY_FILENAME = "strategy.json" +"""File containing strategy recipe and runtime counters for native checkpoints.""" + +_STRATEGY_CHECKPOINT_DIR = Path("strategy") / "checkpoints" +"""Directory containing per-index strategy checkpoint metadata.""" + +_HOOK_CHECKPOINT_DIR = Path("hooks") / "checkpoints" +"""Directory containing per-index runtime hook state.""" + +_SCHEDULER_OPTIMIZERS_KEY = "scheduler_optimizers" +"""Association key mapping scheduler component names to optimizer names.""" + +# Type aliases for the runtime dict shapes +_ModelDict = dict[str, tuple[nn.Module, BaseSpec] | None] +_OptimizerDict = dict[str, tuple[torch.optim.Optimizer, BaseSpec] | None] +_SchedulerDict = dict[str, tuple[torch.optim.lr_scheduler.LRScheduler, BaseSpec] | None] +_Associations = dict[str, dict[str, Any]] + + +class CheckpointManifest(BaseModel): + """Unified checkpoint manifest and runtime container. + + This Pydantic model serves a dual role: + + 1. **On-disk schema** — ``manifest.json`` stores component names as + sorted string lists together with metadata and associations. + 2. 
class CheckpointManifest(BaseModel):
    """Unified checkpoint manifest and runtime container.

    This Pydantic model serves a dual role:

    1. **On-disk schema** — ``manifest.json`` stores component names as
       sorted string lists together with metadata and associations.
    2. **Runtime container** — after :func:`load_checkpoint` hydrates the
       components, the same instance carries live ``(object, spec)`` tuples.

    The ``models``, ``optimizers``, and ``schedulers`` fields accept
    either a ``list[str]`` (from JSON) or a ``dict[str, tuple]`` (from
    code). Serialization always produces sorted name lists via
    :class:`~pydantic.PlainSerializer`.

    Attributes
    ----------
    schema_version
        Schema version. Defaults to the current ``_SCHEMA_VERSION``.
        When manifest structure changes, bump ``_SCHEMA_VERSION`` and
        add a migration step in :meth:`_migrate`.
    checkpoint_index
        The latest checkpoint index written.
    models
        Component dict keyed by name. At runtime each value is a
        ``(nn.Module, BaseSpec)`` tuple; on disk, serialized as a
        sorted ``list[str]`` of names.
    optimizers
        Same dual-mode dict for optimizers (empty by default).
    schedulers
        Same dual-mode dict for schedulers (empty by default).
    associations
        Model-centric linkage: maps a model name to
        ``{"optimizers": [...], "schedulers": [...]}``.

    Examples
    --------
    >>> manifest = CheckpointManifest(
    ...     checkpoint_index=0, models={"main": None},
    ... )
    >>> manifest.model_dump()["models"]
    ['main']
    """

    # Component values are live torch objects, which pydantic cannot validate.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    schema_version: Annotated[
        int, Field(default=_SCHEMA_VERSION, description="Manifest schema version.")
    ]
    checkpoint_index: Annotated[
        int, Field(description="Latest checkpoint index written.")
    ]
    models: Annotated[
        _ModelDict,
        BeforeValidator(_component_before),
        PlainSerializer(_component_serialize, return_type=list[str]),
        Field(description="Model components keyed by name."),
    ]
    optimizers: Annotated[
        _OptimizerDict,
        BeforeValidator(_component_before),
        PlainSerializer(_component_serialize, return_type=list[str]),
        Field(default_factory=dict, description="Optimizer components keyed by name."),
    ]
    schedulers: Annotated[
        _SchedulerDict,
        BeforeValidator(_component_before),
        PlainSerializer(_component_serialize, return_type=list[str]),
        Field(default_factory=dict, description="Scheduler components keyed by name."),
    ]
    associations: Annotated[
        _Associations,
        Field(
            default_factory=dict,
            description="Model-centric linkage to optimizers/schedulers.",
        ),
    ]

    @staticmethod
    def _migrate(raw: dict[str, Any]) -> dict[str, Any]:
        """Migrate an older manifest dict to the current schema version.

        Parameters
        ----------
        raw
            Parsed ``manifest.json`` content. It is not modified.

        Returns
        -------
        dict[str, Any]
            New dict conforming to the current ``_SCHEMA_VERSION``, ready
            for :meth:`pydantic.BaseModel.model_validate`.

        Raises
        ------
        ValueError
            If the manifest's schema version is newer than supported.
        """
        # A manifest without a version key is treated as version 0.
        version = raw.get("schema_version", 0)
        if version > _SCHEMA_VERSION:
            raise ValueError(
                f"Checkpoint schema version {version} is newer than supported "
                f"({_SCHEMA_VERSION}). Upgrade nvalchemi to load this checkpoint."
            )
        # Work on a shallow copy so the caller's dict is left untouched;
        # only the top-level version key is rewritten here.
        migrated = dict(raw)
        # Future migrations chain here:
        # if version < 1:
        #     migrated = _migrate_v0_to_v1(migrated)
        migrated["schema_version"] = _SCHEMA_VERSION
        return migrated

    @classmethod
    def read(cls, root: Path) -> CheckpointManifest:
        """Read, migrate, and validate ``manifest.json`` from *root*.

        Parameters
        ----------
        root
            Checkpoint root directory containing ``manifest.json``.

        Returns
        -------
        CheckpointManifest
            Validated manifest instance. Component dicts contain
            placeholder ``None`` values until hydrated by
            :func:`load_checkpoint`.

        Raises
        ------
        FileNotFoundError
            If ``manifest.json`` does not exist.
        ValueError
            If the manifest's schema version is newer than supported.
        pydantic.ValidationError
            If the manifest JSON does not conform to the schema.
        """
        manifest_path = root / "manifest.json"
        if not manifest_path.exists():
            raise FileNotFoundError(
                f"No manifest.json found in {root}. Use save_checkpoint to "
                f"create a checkpoint first."
            )
        # JSON is read as UTF-8 explicitly so the result does not depend on
        # the platform's locale encoding.
        raw = json.loads(manifest_path.read_text(encoding="utf-8"))
        return cls.model_validate(cls._migrate(raw))

    def write(self, root: Path) -> None:
        """Write this manifest to ``{root}/manifest.json``.

        Parameters
        ----------
        root
            Checkpoint root directory.
        """
        # Written as UTF-8 to match read(); component names may be non-ASCII.
        (root / "manifest.json").write_text(
            self.model_dump_json(indent=2), encoding="utf-8"
        )
+ """ + (root / "manifest.json").write_text(self.model_dump_json(indent=2)) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _ckpt_indices(ckpt_dir: Path) -> list[int]: + """Return sorted integer stems from ``*.pt`` files in *ckpt_dir*.""" + return sorted(int(p.stem) for p in ckpt_dir.glob("*.pt") if p.stem.isdigit()) + + +def _without_spec_timestamps(value: Any) -> Any: + """Return JSON-like *value* with BaseSpec timestamps removed recursively.""" + if isinstance(value, dict): + return { + key: _without_spec_timestamps(item) + for key, item in value.items() + if not (key == "timestamp" and "cls_path" in value) + } + if isinstance(value, list): + return [_without_spec_timestamps(item) for item in value] + return value + + +def _check_spec_consistency(spec_path: Path, spec: BaseSpec) -> None: + """Write *spec* to *spec_path* on first call; raise on mismatch thereafter. + + Parameters + ---------- + spec_path + Path to the ``spec.json`` file. + spec + The spec to write or compare against the existing file. + + Raises + ------ + ValueError + If the existing ``spec.json`` disagrees with *spec* on any field + other than ``timestamp``. + """ + spec_json = spec.model_dump_json(indent=2) + if spec_path.exists(): + existing = _without_spec_timestamps(json.loads(spec_path.read_text())) + new_spec = _without_spec_timestamps(json.loads(spec_json)) + if existing != new_spec: + diffs = sorted( + k + for k in set(existing) | set(new_spec) + if existing.get(k) != new_spec.get(k) + ) + preview = ", ".join( + f"{k}: {existing.get(k)!r} -> {new_spec.get(k)!r}" for k in diffs[:3] + ) + suffix = f" (+{len(diffs) - 3} more)" if len(diffs) > 3 else "" + raise ValueError( + f"spec.json at {spec_path} disagrees with the spec being " + f"saved. Differing fields: {preview}{suffix}." 
+ ) + else: + spec_path.write_text(spec_json) + + +def _save_component( + root: Path, + category: str, + name: str, + state_dict: dict[str, Any], + spec: BaseSpec, + checkpoint_index: int, +) -> None: + """Write *spec* and *state_dict* under ``root/category/name/``.""" + comp_dir = root / category / name + ckpt_dir = comp_dir / "checkpoints" + ckpt_dir.mkdir(parents=True, exist_ok=True) + _check_spec_consistency(comp_dir / "spec.json", spec) + torch.save(state_dict, ckpt_dir / f"{checkpoint_index}.pt") + + +def _snapshot_state_value(value: Any) -> Any: + """Return a CPU copy of tensors nested inside a state-dict value.""" + if isinstance(value, torch.Tensor): + return value.detach().to(device="cpu", copy=True) + if isinstance(value, Mapping): + return {key: _snapshot_state_value(item) for key, item in value.items()} + if isinstance(value, list): + return [_snapshot_state_value(item) for item in value] + if isinstance(value, tuple): + return tuple(_snapshot_state_value(item) for item in value) + return value + + +def _snapshot_state_dict(state_dict: Mapping[str, Any]) -> dict[str, Any]: + """Return a CPU-only state dict detached from live training objects.""" + return {key: _snapshot_state_value(value) for key, value in state_dict.items()} + + +def _snapshot_components( + components: Mapping[str, tuple[Any, BaseSpec]], +) -> dict[str, tuple[dict[str, Any], BaseSpec]]: + """Capture component state dicts and specs for asynchronous writing.""" + return { + name: (_snapshot_state_dict(component.state_dict()), spec) + for name, (component, spec) in components.items() + } + + +def _hook_state_key(hook: object, occurrence: int) -> str: + """Return the stable class-occurrence key used for hook state matching.""" + return f"{type(hook).__module__}.{type(hook).__qualname__}:{occurrence}" + + +def _iter_checkpointable_hooks(hooks: Iterable[object]) -> Iterator[CheckpointableHook]: + """Yield hooks that explicitly opt into checkpointed runtime state.""" + for hook in hooks: + 
children = getattr(hook, "_hooks", None) + if isinstance(children, Sequence) and not isinstance(children, (str, bytes)): + yield from _iter_checkpointable_hooks(children) + if isinstance(hook, CheckpointableHook): + yield hook + + +def _snapshot_hook_states(strategy: Any) -> dict[str, dict[str, Any]]: + """Capture checkpointable runtime hook state detached from live tensors.""" + states: dict[str, dict[str, Any]] = {} + occurrences: dict[str, int] = {} + for hook in _iter_checkpointable_hooks(strategy.hooks): + class_name = f"{type(hook).__module__}.{type(hook).__qualname__}" + occurrence = occurrences.get(class_name, 0) + occurrences[class_name] = occurrence + 1 + states[_hook_state_key(hook, occurrence)] = _snapshot_state_dict( + hook.state_dict() + ) + return states + + +def _hook_state_path(root: Path, checkpoint_index: int) -> Path: + """Return the hook-state checkpoint path for ``checkpoint_index``.""" + return root / _HOOK_CHECKPOINT_DIR / f"{checkpoint_index}.pt" + + +def _save_hook_states( + root: Path, + hook_states: Mapping[str, Mapping[str, Any]], + checkpoint_index: int, +) -> None: + """Write hook state for a checkpoint when checkpointable hooks are present.""" + if not hook_states: + return + path = _hook_state_path(root, checkpoint_index) + path.parent.mkdir(parents=True, exist_ok=True) + state_dict = dict(hook_states) + torch.save(state_dict, path) + + +def _load_hook_states( + root: Path, + strategy: Any, + checkpoint_index: int, + *, + map_location: str | torch.device | None, +) -> None: + """Restore matching checkpointable hook state into a loaded strategy.""" + path = _hook_state_path(root, checkpoint_index) + if not path.exists(): + return + saved_states = torch.load( + path, + weights_only=True, + map_location=map_location, + ) + occurrences: dict[str, int] = {} + for hook in _iter_checkpointable_hooks(strategy.hooks): + class_name = f"{type(hook).__module__}.{type(hook).__qualname__}" + occurrence = occurrences.get(class_name, 0) + 
occurrences[class_name] = occurrence + 1 + state = saved_states.get(_hook_state_key(hook, occurrence)) + if state is not None: + hook.load_state_dict(state) + + +def _resolve_checkpoint_index(root: Path, checkpoint_index: int) -> int: + """Return an explicit checkpoint index, resolving ``-1`` by auto-increment.""" + if checkpoint_index != -1: + return checkpoint_index + manifest_path = root / "manifest.json" + if manifest_path.exists(): + prev = CheckpointManifest.read(root) + return prev.checkpoint_index + 1 + return 0 + + +def _create_checkpoint_snapshot( + root_folder: Path | str, + *, + checkpoint_index: int = -1, + strategy: Any, +) -> dict[str, Any]: + """Capture a strategy checkpoint payload detached from live tensors. + + The snapshot is intended for background filesystem writes. It still runs + on the caller thread and copies tensors to CPU so later training updates + cannot mutate data while :func:`torch.save` serializes it. + """ + from nvalchemi.training.strategy import TrainingStrategy + + if not isinstance(strategy, TrainingStrategy): + raise TypeError( + "strategy must be a TrainingStrategy instance; got " + f"{type(strategy).__name__}." 
+ ) + root = Path(root_folder) + models, optimizers, schedulers, associations, strategy_metadata = ( + _strategy_components(strategy) + ) + return { + "checkpoint_index": _resolve_checkpoint_index(root, checkpoint_index), + "models": _snapshot_components(models), + "optimizers": _snapshot_components(optimizers), + "schedulers": _snapshot_components(schedulers), + "associations": _copy_associations(associations), + "strategy_metadata": dict(strategy_metadata), + "hook_states": _snapshot_hook_states(strategy), + } + + +def _write_checkpoint_snapshot( + root_folder: Path | str, snapshot: Mapping[str, Any] +) -> int: + """Write a detached checkpoint snapshot to disk.""" + root = Path(root_folder) + checkpoint_index = int(snapshot["checkpoint_index"]) + models = snapshot["models"] + optimizers = snapshot["optimizers"] + schedulers = snapshot["schedulers"] + associations = snapshot["associations"] + strategy_metadata = snapshot.get("strategy_metadata") + hook_states = snapshot.get("hook_states", {}) + + for name, (state_dict, spec) in models.items(): + _save_component( + root, + "models", + name, + state_dict, + spec, + checkpoint_index, + ) + for name, (state_dict, spec) in optimizers.items(): + _save_component( + root, + "optimizers", + name, + state_dict, + spec, + checkpoint_index, + ) + for name, (state_dict, spec) in schedulers.items(): + _save_component( + root, + "schedulers", + name, + state_dict, + spec, + checkpoint_index, + ) + + manifest = CheckpointManifest( + checkpoint_index=checkpoint_index, + models={name: None for name in models}, + optimizers={name: None for name in optimizers}, + schedulers={name: None for name in schedulers}, + associations=associations, + ) + manifest.write(root) + _save_hook_states(root, hook_states, checkpoint_index) + if strategy_metadata is not None: + _write_strategy_metadata( + root, strategy_metadata, checkpoint_index=checkpoint_index + ) + return checkpoint_index + + +def _assoc_names(assoc: Mapping[str, Any], key: str) -> 
list[str]: + """Return an association list field, tolerating older or malformed entries.""" + raw = assoc.get(key, []) + return list(raw) if isinstance(raw, list) else [] + + +def _assoc_scheduler_optimizers(assoc: Mapping[str, Any]) -> dict[str, str]: + """Return scheduler-to-optimizer association edges from *assoc*.""" + raw = assoc.get(_SCHEDULER_OPTIMIZERS_KEY, {}) + if not isinstance(raw, Mapping): + return {} + return {str(scheduler): str(optimizer) for scheduler, optimizer in raw.items()} + + +def _copy_associations(associations: Mapping[str, Mapping[str, Any]]) -> _Associations: + """Return a shallow JSON-like copy of association entries.""" + copied: _Associations = {} + for model_name, assoc in associations.items(): + entry: dict[str, Any] = { + "optimizers": _assoc_names(assoc, "optimizers"), + "schedulers": _assoc_names(assoc, "schedulers"), + } + scheduler_optimizers = _assoc_scheduler_optimizers(assoc) + if scheduler_optimizers: + entry[_SCHEDULER_OPTIMIZERS_KEY] = scheduler_optimizers + copied[model_name] = entry + return copied + + +def _scheduler_optimizer_edges( + optimizers: Mapping[str, tuple[torch.optim.Optimizer, BaseSpec]], + schedulers: Mapping[str, tuple[torch.optim.lr_scheduler.LRScheduler, BaseSpec]], +) -> dict[str, str]: + """Return scheduler component names keyed to their optimizer component names.""" + edges: dict[str, str] = {} + for scheduler_name, (scheduler, _) in schedulers.items(): + for optimizer_name, (optimizer, _) in optimizers.items(): + if scheduler.optimizer is optimizer: # type: ignore[attr-defined] + edges[scheduler_name] = optimizer_name + break + return edges + + +def _with_scheduler_optimizer_edges( + associations: Mapping[str, Mapping[str, Any]], + optimizers: Mapping[str, tuple[torch.optim.Optimizer, BaseSpec]], + schedulers: Mapping[str, tuple[torch.optim.lr_scheduler.LRScheduler, BaseSpec]], +) -> _Associations: + """Attach explicit scheduler-to-optimizer edges to model associations.""" + enriched = 
_copy_associations(associations) + edges = _scheduler_optimizer_edges(optimizers, schedulers) + if not edges: + return enriched + + for assoc in enriched.values(): + optimizer_names = set(_assoc_names(assoc, "optimizers")) + scheduler_names = set(_assoc_names(assoc, "schedulers")) + model_edges = { + scheduler_name: optimizer_name + for scheduler_name, optimizer_name in edges.items() + if scheduler_name in scheduler_names and optimizer_name in optimizer_names + } + if model_edges: + assoc[_SCHEDULER_OPTIMIZERS_KEY] = { + **_assoc_scheduler_optimizers(assoc), + **model_edges, + } + return enriched + + +def _infer_associations( + models: dict[str, tuple[nn.Module, BaseSpec]], + optimizers: dict[str, tuple[torch.optim.Optimizer, BaseSpec]], + schedulers: dict[str, tuple[torch.optim.lr_scheduler.LRScheduler, BaseSpec]], +) -> _Associations: + """Infer model-centric associations from optimizer ``param_groups``. + + For each optimizer, collect the ``data_ptr()`` values of every parameter + in its ``param_groups`` and match against each model's ``parameters()``. + The optimizer is associated with every model that owns at least one of + those parameters. + + Schedulers are linked to their optimizer via + ``scheduler.optimizer is optimizer`` identity checks. + + Parameters + ---------- + models + ``{name: (module, spec)}`` mapping. + optimizers + ``{name: (optimizer, spec)}`` mapping. + schedulers + ``{name: (scheduler, spec)}`` mapping. + + Returns + ------- + dict[str, dict[str, list[str]]] + Model-centric associations, e.g. + ``{"student": {"optimizers": ["s_opt"], "schedulers": ["s_sched"]}}``. 
+ """ + # Build data_ptr → model_name index + ptr_to_model: dict[int, str] = {} + for model_name, (module, _) in models.items(): + for p in module.parameters(): + ptr_to_model[p.data_ptr()] = model_name + + # Map each optimizer to every model that owns at least one parameter + opt_to_models: dict[str, list[str]] = {} + for opt_name, (optimizer, _) in optimizers.items(): + matched: dict[str, bool] = {} + for group in optimizer.param_groups: + for p in group["params"]: + owner = ptr_to_model.get(p.data_ptr()) + if owner is not None: + matched[owner] = True + if matched: + opt_to_models[opt_name] = list(matched) + + # Map each scheduler to its optimizer (identity check) + sched_to_opt = _scheduler_optimizer_edges(optimizers, schedulers) + + # Build model-centric structure + assoc: _Associations = {} + for opt_name, model_names in opt_to_models.items(): + for model_name in model_names: + assoc.setdefault(model_name, {"optimizers": [], "schedulers": []}) + assoc[model_name]["optimizers"].append(opt_name) + for sched_name, opt_name in sched_to_opt.items(): + model_names = opt_to_models.get(opt_name, []) + for model_name in model_names: + assoc.setdefault(model_name, {"optimizers": [], "schedulers": []}) + assoc[model_name]["schedulers"].append(sched_name) + scheduler_optimizers = assoc[model_name].setdefault( + _SCHEDULER_OPTIMIZERS_KEY, {} + ) + scheduler_optimizers[sched_name] = opt_name + + return assoc + + +def _find_associated_model_params( + optimizer_name: str, + associations: _Associations, + models: dict[str, tuple[nn.Module, BaseSpec]], +) -> Iterator[torch.nn.Parameter]: + """Return chained parameters from all models associated with *optimizer_name*.""" + matched: list[str] = [] + for model_name, assoc in associations.items(): + if optimizer_name in _assoc_names(assoc, "optimizers"): + matched.append(model_name) + if matched: + return itertools.chain.from_iterable( + models[name][0].parameters() for name in matched + ) + # Fallback: if exactly one model 
exists, use it + if len(models) == 1: + return next(iter(models.values()))[0].parameters() + raise ValueError( + f"Cannot determine which model's parameters to use for optimizer " + f"{optimizer_name!r}. Provide associations or use a single model." + ) + + +def _find_associated_optimizer( + scheduler_name: str, + associations: _Associations, + optimizers: dict[str, tuple[torch.optim.Optimizer, BaseSpec]], +) -> torch.optim.Optimizer: + """Return the optimizer whose associations include *scheduler_name*.""" + for assoc in associations.values(): + edge = _assoc_scheduler_optimizers(assoc).get(scheduler_name) + if edge is not None: + if edge in optimizers: + return optimizers[edge][0] + raise ValueError( + f"Scheduler {scheduler_name!r} is associated with optimizer " + f"{edge!r}, but that optimizer was not loaded." + ) + + scheduler_names = _assoc_names(assoc, "schedulers") + optimizer_names = _assoc_names(assoc, "optimizers") + if scheduler_name in scheduler_names: + scheduler_index = scheduler_names.index(scheduler_name) + if scheduler_index < len(optimizer_names): + optimizer_name = optimizer_names[scheduler_index] + if optimizer_name in optimizers: + return optimizers[optimizer_name][0] + # Fallback: if exactly one optimizer exists, use it + if len(optimizers) == 1: + return next(iter(optimizers.values()))[0] + raise ValueError( + f"Cannot determine which optimizer to use for scheduler " + f"{scheduler_name!r}. Provide associations or use a single optimizer." 
+ ) + + +def _strategy_metadata_path(root: Path) -> Path: + """Return the checkpoint strategy metadata path under ``root``.""" + return root / _STRATEGY_FILENAME + + +def _indexed_strategy_metadata_path(root: Path, checkpoint_index: int) -> Path: + """Return the per-index strategy metadata path under ``root``.""" + return root / _STRATEGY_CHECKPOINT_DIR / f"{checkpoint_index}.json" + + +def _read_strategy_metadata( + root: Path, + *, + checkpoint_index: int, + latest_checkpoint_index: int, +) -> dict[str, Any] | None: + """Read strategy checkpoint metadata if the checkpoint contains it.""" + indexed_path = _indexed_strategy_metadata_path(root, checkpoint_index) + if indexed_path.exists(): + return json.loads(indexed_path.read_text()) + + path = _strategy_metadata_path(root) + if not path.exists(): + return None + if checkpoint_index != latest_checkpoint_index: + raise FileNotFoundError( + "This checkpoint has root-level strategy metadata only, so " + f"checkpoint_index={checkpoint_index} cannot be loaded coherently. " + f"Load the latest index ({latest_checkpoint_index}) or recreate the " + "checkpoint with per-index strategy metadata." 
+ ) + return json.loads(path.read_text()) + + +def _write_strategy_metadata( + root: Path, + metadata: Mapping[str, Any], + *, + checkpoint_index: int, +) -> None: + """Write latest and per-index JSON strategy metadata.""" + root.mkdir(parents=True, exist_ok=True) + payload = json.dumps(metadata, indent=2) + _strategy_metadata_path(root).write_text(payload) + indexed_path = _indexed_strategy_metadata_path(root, checkpoint_index) + indexed_path.parent.mkdir(parents=True, exist_ok=True) + indexed_path.write_text(payload) + + +def _component_name(model_name: str, kind: str, index: int, count: int) -> str: + """Return a stable optimizer/scheduler component name for a model config.""" + suffix = kind if count == 1 else f"{kind}_{index}" + return f"{model_name}_{suffix}" + + +def _models_from_strategy_metadata( + strategy: Any, + metadata: Mapping[str, Any], +) -> dict[str, tuple[nn.Module, BaseSpec]]: + """Collect model components and specs from a strategy checkpoint payload.""" + raw_specs = metadata.get("model_specs", {}) + if not isinstance(raw_specs, Mapping): + raise ValueError("strategy checkpoint metadata has invalid 'model_specs'.") + + models: dict[str, tuple[nn.Module, BaseSpec]] = {} + missing: list[str] = [] + for name, module in strategy.models.items(): + checkpoint_module = _checkpoint_model(module) + raw = raw_specs.get(name) + if raw is None: + missing.append(name) + continue + models[name] = (checkpoint_module, create_model_spec_from_json(dict(raw))) + if missing: + raise ValueError( + "Cannot save strategy checkpoint because model spec generation " + f"failed for model(s) {missing!r}. Ensure these models can be " + "reconstructed from BaseSpec before checkpointing." 
+ ) + return models + + +def _strategy_components( + strategy: Any, +) -> tuple[ + dict[str, tuple[nn.Module, BaseSpec]], + dict[str, tuple[torch.optim.Optimizer, BaseSpec]], + dict[str, tuple[torch.optim.lr_scheduler.LRScheduler, BaseSpec]], + _Associations, + dict[str, Any], +]: + """Extract manifest components from a :class:`TrainingStrategy` instance.""" + metadata = strategy.to_checkpoint_dict() + models = _models_from_strategy_metadata(strategy, metadata) + flat_opts, flat_scheds = strategy._setup_runtime_optimizers(rebuild=False) + + optimizers: dict[str, tuple[torch.optim.Optimizer, BaseSpec]] = {} + schedulers: dict[str, tuple[torch.optim.lr_scheduler.LRScheduler, BaseSpec]] = {} + associations: _Associations = {} + + cursor = 0 + for model_name, configs in strategy.optimizer_configs.items(): + assoc = associations.setdefault( + model_name, {"optimizers": [], "schedulers": []} + ) + for index, config in enumerate(configs): + try: + optimizer = flat_opts[cursor] + scheduler = flat_scheds[cursor] + except IndexError as exc: + raise RuntimeError( + "Strategy optimizer state is inconsistent with optimizer_configs." + ) from exc + + optimizer_name = _component_name( + model_name, "optimizer", index, len(configs) + ) + optimizers[optimizer_name] = ( + optimizer, + create_model_spec(config.optimizer_cls, **config.optimizer_kwargs), + ) + assoc["optimizers"].append(optimizer_name) + + if scheduler is not None: + if config.scheduler_cls is None: + raise RuntimeError( + f"Strategy has scheduler state for {optimizer_name!r}, " + "but its OptimizerConfig has scheduler_cls=None." 
+ ) + scheduler_name = _component_name( + model_name, "scheduler", index, len(configs) + ) + schedulers[scheduler_name] = ( + scheduler, + create_model_spec(config.scheduler_cls, **config.scheduler_kwargs), + ) + assoc["schedulers"].append(scheduler_name) + scheduler_optimizers = assoc.setdefault(_SCHEDULER_OPTIMIZERS_KEY, {}) + scheduler_optimizers[scheduler_name] = optimizer_name + cursor += 1 + + return models, optimizers, schedulers, associations, metadata + + +def _loaded_model_objects( + manifest: CheckpointManifest, +) -> dict[str, nn.Module]: + """Return loaded models from a hydrated manifest.""" + return {name: pair[0] for name, pair in manifest.models.items() if pair is not None} + + +def _install_strategy_optimizer_state( + strategy: Any, manifest: CheckpointManifest +) -> None: + """Attach loaded optimizer/scheduler objects to a strategy for restart.""" + flat_opts: list[torch.optim.Optimizer] = [] + flat_scheds: list[torch.optim.lr_scheduler.LRScheduler | None] = [] + for model_name, configs in strategy.optimizer_configs.items(): + for index, config in enumerate(configs): + optimizer_name = _component_name( + model_name, "optimizer", index, len(configs) + ) + optimizer_pair = manifest.optimizers.get(optimizer_name) + if optimizer_pair is None: + raise ValueError( + f"Checkpoint strategy expects optimizer {optimizer_name!r}, " + "but it was not loaded from the manifest." + ) + flat_opts.append(optimizer_pair[0]) + + scheduler_name = _component_name( + model_name, "scheduler", index, len(configs) + ) + scheduler_pair = manifest.schedulers.get(scheduler_name) + if config.scheduler_cls is not None and scheduler_pair is None: + raise ValueError( + f"Checkpoint strategy expects scheduler {scheduler_name!r}, " + "but it was not loaded from the manifest." 
+ ) + flat_scheds.append( + scheduler_pair[0] if scheduler_pair is not None else None + ) + + strategy._optimizers = flat_opts + strategy._lr_schedulers = flat_scheds + strategy._resume_optimizer_state = bool(flat_opts) + + +def _restore_strategy_runtime_state( + strategy: Any, + metadata: Mapping[str, Any] | None, +) -> None: + """Restore saved runtime counters into a live strategy.""" + if metadata is None: + return + runtime_state = metadata.get("runtime_state", {}) + if runtime_state is None: + return + if not isinstance(runtime_state, Mapping): + raise ValueError( + "strategy checkpoint metadata has invalid 'runtime_state'; " + f"got {type(runtime_state).__name__}." + ) + for key in ("step_count", "batch_count", "epoch_count", "epoch_step_count"): + if key in runtime_state: + value = int(runtime_state[key]) + if value < 0: + raise ValueError( + f"strategy checkpoint runtime counter {key!r} must be " + f"non-negative; got {value}." + ) + setattr(strategy, key, value) + + +def _optimizer_scheduler_maps_from_strategy( + strategy: Any, +) -> tuple[ + dict[str, torch.optim.Optimizer], + dict[str, torch.optim.lr_scheduler.LRScheduler], +]: + """Return checkpoint component-name maps for a live strategy runtime.""" + flat_opts, flat_scheds = strategy._setup_runtime_optimizers(rebuild=False) + optimizers: dict[str, torch.optim.Optimizer] = {} + schedulers: dict[str, torch.optim.lr_scheduler.LRScheduler] = {} + + cursor = 0 + for model_name, configs in strategy.optimizer_configs.items(): + for index, config in enumerate(configs): + try: + optimizer = flat_opts[cursor] + scheduler = flat_scheds[cursor] + except IndexError as exc: + raise RuntimeError( + "Strategy optimizer state is inconsistent with optimizer_configs." 
+ ) from exc + + optimizers[ + _component_name(model_name, "optimizer", index, len(configs)) + ] = optimizer + if scheduler is not None: + schedulers[ + _component_name(model_name, "scheduler", index, len(configs)) + ] = scheduler + cursor += 1 + + return optimizers, schedulers + + +def _restore_checkpoint_into_strategy( + root: Path, + manifest: CheckpointManifest, + *, + checkpoint_index: int, + strategy: Any, + strategy_metadata: Mapping[str, Any] | None, + map_location: str | torch.device | None, +) -> dict[str, Any]: + """Load checkpoint state into an already-constructed strategy.""" + from nvalchemi.training.strategy import TrainingStrategy + + if not isinstance(strategy, TrainingStrategy): + raise TypeError( + "strategy must be a TrainingStrategy instance; got " + f"{type(strategy).__name__}." + ) + + missing_models = sorted(set(manifest.models) - set(strategy.models)) + if missing_models: + raise KeyError( + "Checkpoint contains model(s) not present in the live strategy: " + f"{missing_models!r}." + ) + + loaded_models: dict[str, tuple[nn.Module, BaseSpec | None]] = {} + for name in manifest.models: + model = _checkpoint_model(strategy.models[name]) + weights = torch.load( + root / "models" / name / "checkpoints" / f"{checkpoint_index}.pt", + weights_only=True, + map_location=map_location, + ) + model.load_state_dict(weights) + spec_path = root / "models" / name / "spec.json" + spec = _load_spec(spec_path) if spec_path.exists() else None + loaded_models[name] = (model, spec) + + live_optimizers, live_schedulers = _optimizer_scheduler_maps_from_strategy(strategy) + + loaded_optimizers: dict[str, tuple[torch.optim.Optimizer, BaseSpec | None]] = {} + missing_optimizers = sorted(set(manifest.optimizers) - set(live_optimizers)) + if missing_optimizers: + raise KeyError( + "Checkpoint contains optimizer(s) not present in the live strategy: " + f"{missing_optimizers!r}." 
+ ) + for name in manifest.optimizers: + optimizer = live_optimizers[name] + state = torch.load( + root / "optimizers" / name / "checkpoints" / f"{checkpoint_index}.pt", + weights_only=True, + map_location=map_location, + ) + optimizer.load_state_dict(state) + spec_path = root / "optimizers" / name / "spec.json" + spec = _load_spec(spec_path) if spec_path.exists() else None + loaded_optimizers[name] = (optimizer, spec) + + loaded_schedulers: dict[ + str, tuple[torch.optim.lr_scheduler.LRScheduler, BaseSpec | None] + ] = {} + missing_schedulers = sorted(set(manifest.schedulers) - set(live_schedulers)) + if missing_schedulers: + raise KeyError( + "Checkpoint contains scheduler(s) not present in the live strategy: " + f"{missing_schedulers!r}." + ) + for name in manifest.schedulers: + scheduler = live_schedulers[name] + state = torch.load( + root / "schedulers" / name / "checkpoints" / f"{checkpoint_index}.pt", + weights_only=True, + map_location=map_location, + ) + scheduler.load_state_dict(state) + spec_path = root / "schedulers" / name / "spec.json" + spec = _load_spec(spec_path) if spec_path.exists() else None + loaded_schedulers[name] = (scheduler, spec) + + strategy._resume_optimizer_state = bool(loaded_optimizers) + _restore_strategy_runtime_state(strategy, strategy_metadata) + _load_hook_states( + root, + strategy, + checkpoint_index, + map_location=map_location, + ) + + manifest.models = loaded_models + manifest.optimizers = loaded_optimizers + manifest.schedulers = loaded_schedulers + manifest.checkpoint_index = checkpoint_index + return _manifest_to_loaded_checkpoint(manifest, root=root, strategy=strategy) + + +def _manifest_to_loaded_checkpoint( + manifest: CheckpointManifest, + *, + root: Path, + strategy: Any = None, + source_format: str = "native", +) -> dict[str, Any]: + """Convert a hydrated manifest into the high-level builtin dict shape.""" + models: dict[str, dict[str, Any]] = {} + for model_name, pair in manifest.models.items(): + if pair is None: 
+ continue + model, spec = pair + assoc = manifest.associations.get( + model_name, {"optimizers": [], "schedulers": []} + ) + model_optimizers = { + name: {"optimizer": opt_pair[0], "spec": opt_pair[1]} + for name in _assoc_names(assoc, "optimizers") + if (opt_pair := manifest.optimizers.get(name)) is not None + } + model_schedulers = { + name: {"scheduler": sched_pair[0], "spec": sched_pair[1]} + for name in _assoc_names(assoc, "schedulers") + if (sched_pair := manifest.schedulers.get(name)) is not None + } + models[model_name] = { + "model": model, + "spec": spec, + "optimizers": model_optimizers, + "schedulers": model_schedulers, + "metadata": {"associations": assoc}, + } + + return { + "strategy": strategy, + "models": models, + "manifest": manifest, + "checkpoint_index": manifest.checkpoint_index, + "source": {"format": source_format, "path": str(root)}, + } + + +def _run_validators( + loaded: Mapping[str, Any], + validators: Sequence[CheckpointValidator] | None, +) -> None: + """Run caller-supplied validators against each loaded model entry.""" + if not validators: + return + source = loaded.get("source", {}) + source_path = ( + source.get("path", "") if isinstance(source, Mapping) else source + ) + for model_name, entry in loaded.get("models", {}).items(): + for validator in validators: + validator_name = getattr(validator, "__name__", type(validator).__name__) + try: + validator(model_name, entry, loaded) + except Exception as exc: + raise ValueError( + f"Checkpoint validator {validator_name!r} failed for model " + f"{model_name!r} loaded from {source_path}: {exc}" + ) from exc + + +def _load_mace_checkpoint( + checkpoint_path: Path, + *, + map_location: str | torch.device | None, + adapter_kwargs: Mapping[str, Any] | None, +) -> dict[str, Any]: + """Load a local MACE checkpoint through :class:`MACEWrapper`.""" + kwargs = dict(adapter_kwargs or {}) + allowed = {"model_name", "dtype", "enable_cueq", "compile_model", "compile_kwargs"} + unknown = 
sorted(set(kwargs) - allowed) + if unknown: + raise ValueError(f"Unknown MACE adapter option(s): {unknown}.") + + if not checkpoint_path.is_file(): + raise FileNotFoundError( + "The MACE checkpoint adapter only accepts local checkpoint files; " + f"{checkpoint_path} does not exist." + ) + + model_name = kwargs.pop("model_name", "main") + dtype = kwargs.pop("dtype", None) + enable_cueq = kwargs.pop("enable_cueq", False) + compile_model = kwargs.pop("compile_model", False) + compile_kwargs = kwargs.pop("compile_kwargs", {}) + if not isinstance(compile_kwargs, Mapping): + raise TypeError("MACE adapter option 'compile_kwargs' must be a mapping.") + + device = torch.device("cpu") if map_location is None else torch.device(map_location) + warnings.warn( + "Loading MACE .pt checkpoints requires the MACE full-model pickle " + "loader under the hood. Only load local MACE checkpoints from trusted " + "sources.", + UserWarning, + stacklevel=2, + ) + from nvalchemi.models.mace import MACEWrapper + + model = MACEWrapper.from_checkpoint( + checkpoint_path, + device=device, + dtype=dtype, + enable_cueq=enable_cueq, + compile_model=compile_model, + **dict(compile_kwargs), + ) + return { + "strategy": None, + "models": { + model_name: { + "model": model, + "spec": None, + "optimizers": {}, + "schedulers": {}, + "metadata": {"adapter": "mace"}, + } + }, + "manifest": None, + "checkpoint_index": None, + "source": {"format": "mace", "path": str(checkpoint_path)}, + } + + +def _strategy_target_device( + strategy_metadata: Mapping[str, Any] | None, + map_location: str | torch.device | None, +) -> torch.device | None: + """Return the model/optimizer load device for a strategy checkpoint.""" + if map_location is not None: + return torch.device(map_location) + if strategy_metadata is None: + return None + + raw_devices = strategy_metadata.get("devices") + if not isinstance(raw_devices, Sequence) or isinstance(raw_devices, str): + return None + if not raw_devices: + return None + return 
torch.device(raw_devices[0]) + + +def _with_strategy_device_override( + strategy_metadata: Mapping[str, Any], + map_location: str | torch.device | None, +) -> dict[str, Any]: + """Return strategy metadata with runtime devices overridden when requested.""" + metadata = dict(strategy_metadata) + if map_location is not None: + metadata["devices"] = [str(torch.device(map_location))] + return metadata + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def save_checkpoint( + root_folder: Path | str, + models: dict[str, tuple[nn.Module, BaseSpec]] | Any | None = None, + optimizers: dict[str, tuple[torch.optim.Optimizer, BaseSpec]] | None = None, + schedulers: ( + dict[str, tuple[torch.optim.lr_scheduler.LRScheduler, BaseSpec]] | None + ) = None, + associations: _Associations | None = None, + checkpoint_index: int = -1, + strategy: Any | None = None, +) -> int: + """Save a checkpoint with a manifest. + + The low-level component form accepts explicit ``models``, ``optimizers``, + and ``schedulers`` mappings. The strategy-aware form accepts + ``strategy=TrainingStrategy(...)`` (or the strategy as the second + positional argument) and writes additional ``strategy.json`` metadata with + the serializable recipe and restart counters. + + Parameters + ---------- + root_folder + Root directory for the checkpoint tree. + models + Mapping of model name to ``(module, spec)`` pairs, or a + :class:`~nvalchemi.training.strategy.TrainingStrategy` instance. + optimizers + Optional mapping of optimizer name to ``(optimizer, spec)`` pairs. + schedulers + Optional mapping of scheduler name to ``(scheduler, spec)`` pairs. + associations + Optional model-centric linkage mapping a model name to + ``{"optimizers": [...], "schedulers": [...]}``. 
When ``None`` + (default), associations are inferred automatically by matching + optimizer ``param_groups`` to model parameters via ``data_ptr()`` + identity, and schedulers to optimizers via object identity. + checkpoint_index + Index for the checkpoint files. ``-1`` (default) auto-increments + from the manifest's last index, or starts at ``0``. + strategy + Optional training strategy to save as a restartable checkpoint. + + Returns + ------- + int + The checkpoint index that was written. + + Raises + ------ + ValueError + If an existing ``spec.json`` disagrees with the spec being saved + (ignoring ``timestamp``). + + Examples + -------- + >>> import tempfile, torch.nn as nn + >>> from nvalchemi.training._spec import create_model_spec + >>> with tempfile.TemporaryDirectory() as tmp: + ... spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + ... save_checkpoint(tmp, models={"main": (nn.Linear(4, 2), spec)}) + 0 + """ + from nvalchemi.training.strategy import TrainingStrategy + + root = Path(root_folder) + strategy_metadata: dict[str, Any] | None = None + if strategy is None and isinstance(models, TrainingStrategy): + strategy = models + models = None + if strategy is not None: + if not isinstance(strategy, TrainingStrategy): + raise TypeError( + "strategy must be a TrainingStrategy instance; got " + f"{type(strategy).__name__}." + ) + ( + models, + optimizers, + schedulers, + associations, + strategy_metadata, + ) = _strategy_components(strategy) + if models is None: + raise ValueError("save_checkpoint requires models=... 
or strategy=....") + + models = _checkpoint_model_components(models) + optimizers = optimizers or {} + schedulers = schedulers or {} + if associations is None: + associations = _infer_associations(models, optimizers, schedulers) + else: + associations = _with_scheduler_optimizer_edges( + associations, optimizers, schedulers + ) + + checkpoint_index = _resolve_checkpoint_index(root, checkpoint_index) + + # Save each component category + for name, (module, spec) in models.items(): + _save_component( + root, "models", name, module.state_dict(), spec, checkpoint_index + ) + + for name, (opt, spec) in optimizers.items(): + _save_component( + root, "optimizers", name, opt.state_dict(), spec, checkpoint_index + ) + + for name, (sched, spec) in schedulers.items(): + _save_component( + root, "schedulers", name, sched.state_dict(), spec, checkpoint_index + ) + + # Write manifest — pass live dicts directly; PlainSerializer extracts keys + manifest = CheckpointManifest( + checkpoint_index=checkpoint_index, + models=models, + optimizers=optimizers, + schedulers=schedulers, + associations=associations, + ) + manifest.write(root) + if strategy_metadata is not None: + _write_strategy_metadata( + root, strategy_metadata, checkpoint_index=checkpoint_index + ) + if strategy is not None: + _save_hook_states(root, _snapshot_hook_states(strategy), checkpoint_index) + return checkpoint_index + + +def load_checkpoint( + root_folder: Path | str, + checkpoint_index: int = -1, + map_location: str | torch.device | None = None, + model_names: Iterable[str] | None = None, + *, + adapter: str | None = None, + adapter_kwargs: Mapping[str, Any] | None = None, + validators: Sequence[CheckpointValidator] | None = None, + hooks: Sequence[Any] | None = None, + training_fn: Any = None, + strategy: Any | None = None, +) -> CheckpointManifest | dict[str, Any]: + """Load a multi-component checkpoint written by :func:`save_checkpoint`. 
+ + Components are rebuilt in dependency order: models first, then + optimizers (which need model parameters), then schedulers (which need + an optimizer instance). Associations from the manifest wire each + optimizer to the correct model and each scheduler to the correct + optimizer. + + Parameters + ---------- + root_folder + Root directory containing ``manifest.json``. + checkpoint_index + Index of the checkpoint to load. ``-1`` (default) loads the + latest index recorded in the manifest. + map_location + Forwarded to every :func:`torch.load` call. When not ``None``, + each loaded model is additionally moved via + ``model.to(map_location)``. Optimizers and schedulers have their + state placed by ``torch.load`` alone (they lack a standard + ``.to()`` API). + model_names + If given, load only the models with these names together with the + optimizers and schedulers wired to them through + ``manifest.associations``. Accepts any iterable of strings + (typically a set). ``None`` (default) loads every component on + disk. The returned manifest's ``associations`` still reflects the + full on-disk mapping, so callers can inspect what was not loaded. + adapter + Optional foreign-checkpoint adapter name. V1 supports ``"mace"`` for + trusted local MACE ``.pt`` files. + adapter_kwargs + Adapter-specific options. For ``adapter="mace"``, accepted keys are + ``model_name``, ``dtype``, ``enable_cueq``, ``compile_model``, and + ``compile_kwargs``. + validators + Optional callbacks invoked as ``validator(model_name, entry, loaded)`` + for each high-level loaded model entry. Use these for model-specific + chemistry or topology compatibility checks. + hooks + Runtime hooks supplied when reconstructing a saved strategy. + training_fn + Runtime training function override supplied when reconstructing a + saved strategy. + strategy + Optional already-constructed strategy to hydrate from the checkpoint. 
+ This mode restores model, optimizer, scheduler, runtime-counter, and + checkpointable hook state into the live objects instead of rebuilding + models from saved specs. + + Returns + ------- + CheckpointManifest + For legacy component-only checkpoints, a hydrated manifest is + returned. + dict[str, Any] + For strategy checkpoints or adapter loads, a builtin dict containing + ``strategy``, ``models``, ``manifest``, ``checkpoint_index``, and + ``source`` is returned. + + Raises + ------ + FileNotFoundError + If ``manifest.json`` is missing or a checkpoint ``.pt`` file + does not exist. + KeyError + If any name in ``model_names`` does not appear in + ``manifest.models``. + RuntimeError + If a model spec does not build an :class:`~torch.nn.Module`. + + Examples + -------- + >>> import tempfile, torch.nn as nn + >>> from nvalchemi.training._spec import create_model_spec + >>> with tempfile.TemporaryDirectory() as tmp: + ... spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + ... _ = save_checkpoint(tmp, models={"main": (nn.Linear(4, 2), spec)}) + ... result = load_checkpoint(tmp) + ... isinstance(result.models["main"][0], nn.Linear) + True + + Loading onto CPU regardless of the original device:: + + result = load_checkpoint("runs/exp1", map_location="cpu") + + Selecting a subset of models (e.g., teacher and student but not the + third auxiliary model):: + + result = load_checkpoint("runs/kd", model_names={"teacher", "student"}) + """ + root = Path(root_folder) + if adapter is not None: + if strategy is not None: + raise ValueError("load_checkpoint does not support strategy with adapter.") + if adapter != "mace": + raise ValueError( + f"Unsupported checkpoint adapter {adapter!r}; supported: ['mace']." 
+ ) + loaded = _load_mace_checkpoint( + root, + map_location=map_location, + adapter_kwargs=adapter_kwargs, + ) + _run_validators(loaded, validators) + return loaded + + manifest = CheckpointManifest.read(root) + + if checkpoint_index == -1: + checkpoint_index = manifest.checkpoint_index + + associations = manifest.associations + strategy_metadata = _read_strategy_metadata( + root, + checkpoint_index=checkpoint_index, + latest_checkpoint_index=manifest.checkpoint_index, + ) + load_location = _strategy_target_device(strategy_metadata, map_location) + + if strategy is not None: + if model_names is not None: + raise ValueError( + "load_checkpoint(strategy=...) restores the complete live strategy; " + "model_names is not supported in this mode." + ) + loaded = _restore_checkpoint_into_strategy( + root, + manifest, + checkpoint_index=checkpoint_index, + strategy=strategy, + strategy_metadata=strategy_metadata, + map_location=load_location, + ) + if strategy_metadata is not None: + loaded["strategy_metadata"] = _with_strategy_device_override( + strategy_metadata, map_location + ) + _run_validators(loaded, validators) + return loaded + + # determine what models to load + selected_models = set(manifest.models) if model_names is None else set(model_names) + unknown = selected_models - set(manifest.models) + if unknown: + raise KeyError( + f"Unknown model(s) {sorted(unknown)!r}. " + f"Available: {sorted(manifest.models)!r}" + ) + + # Build the load set as the union of each selected model's associations. + # When ``model_names is None`` this is equivalent to loading every + # component listed in the manifest. 
+ models_to_load = [n for n in manifest.models if n in selected_models] + if model_names is None: + optimizers_to_load = list(manifest.optimizers) + schedulers_to_load = list(manifest.schedulers) + else: + wanted_optimizers: set[str] = set() + wanted_schedulers: set[str] = set() + for n in selected_models: + assoc = associations.get(n, {}) + wanted_optimizers.update(_assoc_names(assoc, "optimizers")) + wanted_schedulers.update(_assoc_names(assoc, "schedulers")) + optimizers_to_load = [n for n in manifest.optimizers if n in wanted_optimizers] + schedulers_to_load = [n for n in manifest.schedulers if n in wanted_schedulers] + + # --- Models --- + loaded_models: dict[str, tuple[nn.Module, BaseSpec]] = {} + for name in models_to_load: + spec = _load_spec(root / "models" / name / "spec.json") + build_kwargs = ( + {"device": load_location} + if load_location is not None and spec.accepts_kwarg("device") + else {} + ) + model = spec.build(**build_kwargs) + if not isinstance(model, nn.Module): + raise RuntimeError( + f"Model spec for {name!r} built {type(model)!r}, expected nn.Module." + ) + # Move models whose factories do not accept device after construction. + # Factory-loaded models such as MACE + cuEq need the device during + # construction so conversion happens on the intended accelerator. 
+ if load_location is not None and not build_kwargs: + model.to(load_location) + weights = torch.load( + root / "models" / name / "checkpoints" / f"{checkpoint_index}.pt", + weights_only=True, + map_location=load_location, + ) + model.load_state_dict(weights) + loaded_models[name] = (model, spec) + + # --- Optimizers --- + loaded_optimizers: dict[str, tuple[torch.optim.Optimizer, BaseSpec]] = {} + for name in optimizers_to_load: + spec = _load_spec(root / "optimizers" / name / "spec.json") + params = _find_associated_model_params(name, associations, loaded_models) + optimizer = spec.build(params) + state = torch.load( + root / "optimizers" / name / "checkpoints" / f"{checkpoint_index}.pt", + weights_only=True, + map_location=load_location, + ) + optimizer.load_state_dict(state) + loaded_optimizers[name] = (optimizer, spec) + + # --- Schedulers --- + loaded_schedulers: dict[ + str, tuple[torch.optim.lr_scheduler.LRScheduler, BaseSpec] + ] = {} + for name in schedulers_to_load: + spec = _load_spec(root / "schedulers" / name / "spec.json") + assoc_optimizer = _find_associated_optimizer( + name, associations, loaded_optimizers + ) + scheduler = spec.build(assoc_optimizer) + state = torch.load( + root / "schedulers" / name / "checkpoints" / f"{checkpoint_index}.pt", + weights_only=True, + map_location=load_location, + ) + scheduler.load_state_dict(state) + loaded_schedulers[name] = (scheduler, spec) + + # Hydrate manifest with live objects + manifest.models = loaded_models + manifest.optimizers = loaded_optimizers + manifest.schedulers = loaded_schedulers + manifest.checkpoint_index = checkpoint_index + if strategy_metadata is None: + if validators is not None: + loaded = _manifest_to_loaded_checkpoint(manifest, root=root) + _run_validators(loaded, validators) + return manifest + + strategy = None + if strategy_metadata is not None and model_names is None: + from nvalchemi.training.strategy import TrainingStrategy + + loaded_strategy_models: Any = 
_loaded_model_objects(manifest) + if strategy_metadata.get("single_model_input") is True and set( + loaded_strategy_models + ) == {"main"}: + loaded_strategy_models = loaded_strategy_models["main"] + + runtime_strategy_metadata = _with_strategy_device_override( + strategy_metadata, map_location + ) + strategy = TrainingStrategy.from_checkpoint_dict( + runtime_strategy_metadata, + models=loaded_strategy_models, + hooks=hooks, + training_fn=training_fn, + ) + _install_strategy_optimizer_state(strategy, manifest) + _load_hook_states( + root, + strategy, + checkpoint_index, + map_location=load_location, + ) + + loaded = _manifest_to_loaded_checkpoint( + manifest, + root=root, + strategy=strategy, + ) + if strategy_metadata is not None: + loaded["strategy_metadata"] = _with_strategy_device_override( + strategy_metadata, map_location + ) + _run_validators(loaded, validators) + return loaded + + +def _load_spec(spec_path: Path) -> BaseSpec: + """Read and rehydrate a :class:`BaseSpec` from *spec_path*.""" + if not spec_path.exists(): + raise FileNotFoundError(f"Expected spec at {spec_path} but file not found.") + return create_model_spec_from_json(json.loads(spec_path.read_text())) diff --git a/nvalchemi/training/_spec.py b/nvalchemi/training/_spec.py new file mode 100644 index 00000000..be5d571c --- /dev/null +++ b/nvalchemi/training/_spec.py @@ -0,0 +1,574 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Reproducible, no-pickle serialization of MLIP hyperparameters. + +This module provides :class:`BaseSpec`, a Pydantic model that captures the +keyword arguments of any importable target callable --- typically an MLIP +constructor, model factory, optimizer, or learning-rate scheduler --- and +serializes them to plain JSON. Spec reconstruction imports the target callable +by its dotted path and invokes it with the stored kwargs. This approach ensures that ``pickle`` is not needed +to recreate objects at runtime: + +- Hyperparameters are stored as plain JSON (strings, numbers, lists, dicts). +- :class:`torch.Tensor` is serialized as ``{dtype, shape, data}`` — a data + structure, not a bytecode payload. +- :class:`torch.dtype` is serialized as its string name and rehydrated with + an :func:`isinstance` guard so that an attacker-controlled string cannot + smuggle arbitrary ``torch.*`` attributes through :func:`getattr`. +- Model weights (stored separately) must be loaded with + ``torch.load(..., weights_only=True)`` — the only pickle-free code path + that PyTorch offers for weight bundles. + +Custom (de)serializers for additional types are registered via +:func:`register_type_serializer`. The module pre-registers handlers for +:class:`torch.dtype`, :class:`torch.device`, and :class:`torch.Tensor`. 
+""" + +from __future__ import annotations + +import inspect +from datetime import datetime, timezone +from typing import Annotated, Any, get_args, get_origin + +import torch +from pydantic import ( + AfterValidator, + BaseModel, + ConfigDict, + Field, + SerializeAsAny, + create_model, +) + +from nvalchemi._serialization import ( + _TYPE_SERIALIZERS, + SerializableTaggedClass, + _callable_path_of, + _callable_signature, + _constructor_signature, + _deserialize_tagged_type, + _import_callable, + _is_serializable_class_annotation, + _is_tagged_type, + _wrap_class_type_annotation, + _wrap_custom_type, +) +from nvalchemi._serialization import ( + _dtype_deserialize as _dtype_deserialize, +) +from nvalchemi._serialization import ( + _import_cls as _import_cls, +) +from nvalchemi._serialization import ( + register_type_serializer as register_type_serializer, +) + +_META_FIELDS: frozenset[str] = frozenset({"cls_path", "timestamp"}) +"""Field names reserved by :class:`BaseSpec` itself; never forwarded to ``build``.""" + + +def _ensure_importable(cls_path: str) -> str: + """Pydantic validator: ensure the target path is importable and callable.""" + _import_callable(cls_path) + return cls_path + + +# --------------------------------------------------------------------------- +# Signature introspection +# --------------------------------------------------------------------------- + + +def _signature(target: Any) -> inspect.Signature: + """Return the string-annotation-resolved signature for ``target``.""" + if isinstance(target, type): + return _constructor_signature(target) + return _callable_signature(target) + + +def _check_no_positional_only(target: Any) -> None: + """Raise :class:`TypeError` if ``target`` has positional-only params.""" + for name, p in _signature(target).parameters.items(): + if p.kind is inspect.Parameter.POSITIONAL_ONLY: + raise TypeError( + f"{_callable_path_of(target)} has positional-only param {name!r}; " + "create_model_spec only supports kwargs." 
+ ) + + +# --------------------------------------------------------------------------- +# BaseSpec +# --------------------------------------------------------------------------- + + +class BaseSpec(BaseModel): + """Base class for JSON-serializable, no-pickle hyperparameter specs. + + Concrete spec classes are built dynamically by :func:`create_model_spec` + via :func:`pydantic.create_model`; each carries one field per + ``__init__`` kwarg of its target class plus the two metadata fields + defined here. + + Attributes + ---------- + cls_path + Dotted path (``"module.submodule.QualName"``) identifying the target + callable. Validated at assignment time by :func:`_import_callable`. + timestamp + ISO-8601 UTC timestamp recording when the spec was created. + + Notes + ----- + ``revalidate_instances="never"`` is deliberate: specs are immutable + records of past state; revalidating on access would reject any + already-typed field values (e.g. rehydrated :class:`torch.Tensor` + objects) that were stored through a :class:`~pydantic.BeforeValidator`. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + revalidate_instances="never", + ) + + cls_path: Annotated[ + str, + AfterValidator(_ensure_importable), + Field(description="Dotted import path of the target callable."), + ] + timestamp: Annotated[ + str, + Field(description="ISO-8601 UTC timestamp of spec creation."), + ] + + def accepts_kwarg(self, name: str) -> bool: + """Return whether the target callable accepts ``name`` as a keyword.""" + target = _import_callable(self.cls_path) + sig = _signature(target) + return name in sig.parameters or any( + p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + + def build(self, *args: Any, strict: bool = False, **extra_kwargs: Any) -> object: + """Invoke the target callable with the stored hyperparameters. 
+ + Positional ``*args`` and ``**extra_kwargs`` inject runtime-only + values that cannot be serialized into the spec --- for example, + ``model.parameters()`` for an optimizer or an ``optimizer`` instance + for a learning-rate scheduler. + + Nested :class:`BaseSpec` field values are built recursively before + forwarding to the target constructor. Non-empty ``list``/``tuple`` + fields that contain :class:`BaseSpec` items are built item-wise, + preserving non-spec items and the container type. Nested + collections (e.g. ``list[list[BaseSpec]]``) are not traversed; + wrap them in a serializable spec object or flatten the + collection. A JSON round-trip preserves tuple-valued spec sequences + when the target constructor annotates the parameter as a tuple; + otherwise JSON arrays rehydrate as lists. + + Parameters + ---------- + *args + Positional arguments forwarded to the target class constructor + (runtime-only, not stored in the spec). + strict + Reserved for future use; currently a no-op retained to preserve + the public API. Accepts any value without effect. + **extra_kwargs + Extra keyword arguments forwarded to the target callable, overriding + any spec-stored kwargs of the same name. + + Returns + ------- + object + A freshly constructed object from the callable at :attr:`cls_path`. + + Raises + ------ + TypeError + If the target callable cannot be invoked with the resolved + kwargs. + """ + del strict # reserved for future use + target = _import_callable(self.cls_path) + sig = _signature(target) + resolved: dict[str, Any] = {} + for name in type(self).model_fields: + if name in _META_FIELDS: + continue + v = getattr(self, name) + # Nested spec: build unless target expects the spec itself. 
+ if isinstance(v, BaseSpec): + param = sig.parameters.get(name) + ann = param.annotation if param is not None else None + wants_spec = isinstance(ann, type) and issubclass(ann, BaseSpec) + resolved[name] = v if wants_spec else v.build() + elif _is_basespec_sequence(v): + resolved[name] = _build_sequence_of_specs(v) + else: + resolved[name] = v + resolved.update(extra_kwargs) + try: + return target(*args, **resolved) + except TypeError as e: + raise TypeError( + f"Failed to build {self.cls_path} from spec " + f"(saved at {self.timestamp}): {e}. The callable signature " + "may have changed since the spec was created." + ) from e + + +# --------------------------------------------------------------------------- +# Type annotation resolution +# --------------------------------------------------------------------------- + + +def _try_deserialize(name: str, value: Any, sig: inspect.Signature) -> Any: + """Probe registered deserializers to rehydrate a raw JSON value. + + Returns the first successfully deserialized typed instance, or the + original ``value`` unchanged if no safe deserializer accepts it. This + covers the case where ``__init__`` has no annotation for well-known + parameters whose stored value is a serialized custom type (e.g. + ``torch.dtype`` as a str for a ``dtype`` parameter). + + Only tagged class dictionaries, unannotated ``dtype`` / ``device`` strings, + and tensor-shaped dicts are probed. Broad string deserializers such as raw + class dotted-path resolution are deliberately skipped here so ordinary + string fields remain strings. 
+ """ + if not isinstance(value, (str, dict)): + return value + + param = sig.parameters.get(name) + sig_ann = param.annotation if param is not None else inspect.Parameter.empty + if sig_ann is not inspect.Parameter.empty and sig_ann is not Any: + return value + + deserializer: Any | None = None + if isinstance(value, str): + if name == "dtype": + deserializer = _TYPE_SERIALIZERS[torch.dtype][1] + elif name == "device": + deserializer = _TYPE_SERIALIZERS[torch.device][1] + elif _is_tagged_type(value): + deserializer = _deserialize_tagged_type + elif set(value) == {"data", "dtype", "shape"}: + deserializer = _TYPE_SERIALIZERS[torch.Tensor][1] + + if deserializer is None: + return value + + try: + return deserializer(value) + except (TypeError, ValueError, KeyError, AttributeError, RuntimeError): + return value + + +def _maybe_class_annotation(annotation: Any) -> Any | None: + """Return a dotted-path serializer annotation for class types if applicable.""" + if not _is_serializable_class_annotation(annotation): + return None + return _wrap_class_type_annotation(annotation) + + +def _maybe_registered_type_annotation(annotation: Any) -> Any | None: + """Return a serializer annotation for registered types and optional variants.""" + if annotation in _TYPE_SERIALIZERS: + return _wrap_custom_type(annotation) + args = get_args(annotation) + if len(args) != 2 or type(None) not in args: + return None + registered = [arg for arg in args if arg in _TYPE_SERIALIZERS] + if len(registered) != 1: + return None + return _wrap_custom_type(registered[0]) | None + + +def _expects_tuple_sequence(name: str, sig: inspect.Signature) -> bool: + """Return whether ``name`` is annotated as a tuple-valued parameter.""" + param = sig.parameters.get(name) + if param is None: + return False + annotation = param.annotation + return annotation is tuple or get_origin(annotation) is tuple + + +def _is_basespec_sequence(value: Any) -> bool: + """Return whether value is a non-empty list/tuple containing 
BaseSpec items.""" + return ( + isinstance(value, (list, tuple)) + and len(value) > 0 + and any(isinstance(v, BaseSpec) for v in value) + ) + + +def _is_spec_dict(value: Any) -> bool: + """Return whether value is a JSON-dict representation of a BaseSpec.""" + return isinstance(value, dict) and "cls_path" in value + + +def _is_spec_dict_sequence(value: Any) -> bool: + """Return whether value is a non-empty list containing spec-dicts.""" + return ( + isinstance(value, list) + and len(value) > 0 + and any(_is_spec_dict(v) for v in value) + ) + + +def _build_sequence_of_specs(value: Any) -> Any: + """Rebuild :class:`BaseSpec` items in a list/tuple, preserving other items.""" + return type(value)( + item.build() if isinstance(item, BaseSpec) else item for item in value + ) + + +def _rehydrate_spec_sequence( + name: str, + value: list[Any], + sig: inspect.Signature, +) -> list[Any] | tuple[Any, ...]: + """Rehydrate spec-dict items in a JSON list, preserving other items.""" + spec_items = [ + create_model_spec_from_json(item) + if _is_spec_dict(item) + else _try_deserialize(name, item, sig) + for item in value + ] + return tuple(spec_items) if _expects_tuple_sequence(name, sig) else spec_items + + +def _resolve_annotation(name: str, value: Any, sig: inspect.Signature) -> Any: + """Pick the Pydantic field annotation for ``(name, value)`` in ``sig``. + + Order of precedence: + + 1. ``value`` is a :class:`BaseSpec` → ``SerializeAsAny[BaseSpec]`` + (preserves the concrete dynamic schema under + :meth:`~pydantic.BaseModel.model_dump_json`). + 2. ``value`` is a non-empty ``list``/``tuple`` containing + :class:`BaseSpec` items → ``SerializeAsAny[list[Any]]`` or + ``SerializeAsAny[tuple[Any, ...]]``. This lets collection fields + (e.g. ``ComposedLossFunction.components`` and mixed scalar/spec + weight lists) round-trip by preserving each item's dynamic schema. + 3. 
The ``__init__`` signature annotates this parameter as a class type + (``type``, ``type[T]``, or optional variants) → wrap with dotted-path + class serialization hooks. + 4. The ``__init__`` signature annotates this parameter with a registered + custom type → wrap via :func:`_wrap_custom_type`. + 5. The ``__init__`` signature has any non-``Any`` annotation → use it. + 6. Otherwise infer from ``type(value)``; if the inferred type is in the + registry, wrap it; ``None`` values fall back to :class:`typing.Any`. + """ + if isinstance(value, BaseSpec): + return SerializeAsAny[BaseSpec] + + if _is_basespec_sequence(value): + return ( + SerializeAsAny[list[Any]] + if isinstance(value, list) + else SerializeAsAny[tuple[Any, ...]] + ) + + param = sig.parameters.get(name) + sig_ann = param.annotation if param is not None else inspect.Parameter.empty + has_sig_ann = sig_ann is not inspect.Parameter.empty and sig_ann is not Any + + if has_sig_ann: + class_annotation = _maybe_class_annotation(sig_ann) + if class_annotation is not None: + return class_annotation + registered_annotation = _maybe_registered_type_annotation(sig_ann) + if registered_annotation is not None: + return registered_annotation + if has_sig_ann: + return sig_ann + + if isinstance(value, type): + return SerializableTaggedClass + + vt = type(value) + if vt in _TYPE_SERIALIZERS: + return _wrap_custom_type(vt) + return vt if value is not None else Any + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def create_model_spec(target: Any, **kwargs: Any) -> BaseSpec: + """Build a :class:`BaseSpec` instance for ``target`` with the given kwargs. + + A new Pydantic model class is dynamically created via + :func:`pydantic.create_model`, one field per kwarg, each annotated by + :func:`_resolve_annotation`. 
The resulting spec is JSON-serializable + with :meth:`~pydantic.BaseModel.model_dump_json` and reconstructible + with :func:`create_model_spec_from_json`. + + Non-empty ``list``/``tuple`` kwargs containing :class:`BaseSpec` + items are annotated so each dynamic spec schema survives JSON dump + and rehydration, and :meth:`BaseSpec.build` then rebuilds each spec + item while preserving non-spec items. Empty collections are stored + as-is. Nested collections (e.g. ``list[list[BaseSpec]]``) are not + traversed; wrap them in a serializable spec object or flatten the + collection. A JSON round-trip preserves tuple-valued spec sequences + when the target constructor annotates the parameter as a tuple; + otherwise JSON arrays rehydrate as lists. + + Parameters + ---------- + target + The target importable callable. Must accept all ``**kwargs`` as keyword + arguments and must not declare any positional-only parameters. + **kwargs + Hyperparameters for ``target``. Registered types + (:class:`torch.Tensor`, :class:`torch.dtype`, :class:`torch.device`, + and any user-registered types) are handled via the type-serializer + registry. Other values must themselves be JSON-serializable by + Pydantic. + + Returns + ------- + BaseSpec + A dynamically subclassed :class:`BaseSpec` instance named + ``"{target.__name__}Spec"`` with one field per kwarg plus the two + metadata fields. + + Raises + ------ + TypeError + If ``target`` has positional-only parameters, or if ``**kwargs`` + contains names absent from the signature while the signature has no + ``**kwargs`` parameter. 
def create_model_spec_from_json(spec: dict[str, Any]) -> BaseSpec:
    """Reconstruct a :class:`BaseSpec` from a dumped spec dictionary.

    The two metadata entries (``"cls_path"`` and ``"timestamp"``) are split
    off a shallow copy of ``spec``; every remaining entry is treated as a
    keyword argument for the target resolved from ``cls_path``. Each value
    is routed through exactly one of three paths:

    * values accepted by ``_is_spec_dict`` are rebuilt recursively with this
      function;
    * values accepted by ``_is_spec_dict_sequence`` are rebuilt by
      ``_rehydrate_spec_sequence``;
    * everything else is passed through ``_try_deserialize``.

    The spec is then re-created with :func:`create_model_spec` and the
    stored ``timestamp`` is written back with :func:`object.__setattr__`, so
    the rebuilt spec carries the source's timestamp instead of a fresh one.

    Parameters
    ----------
    spec
        A :class:`dict` such as the one produced by
        :meth:`~pydantic.BaseModel.model_dump`, or by :func:`json.loads`
        applied to :meth:`~pydantic.BaseModel.model_dump_json` output.
        The dictionary itself is not modified.

    Returns
    -------
    BaseSpec
        The rebuilt spec, carrying the ``timestamp`` stored in ``spec``.

    Raises
    ------
    ValueError
        If ``"cls_path"`` or ``"timestamp"`` is absent from ``spec``, or if
        resolving ``cls_path`` raises. The triggering exception is attached
        as ``__cause__``.

    Examples
    --------
    >>> import json, torch.nn as nn
    >>> s = create_model_spec(nn.Linear, in_features=4, out_features=2)
    >>> dumped = json.loads(s.model_dump_json())
    >>> s2 = create_model_spec_from_json(dumped)
    >>> s2.timestamp == s.timestamp
    True
    """
    # Work on a shallow copy so popping the metadata keys leaves the
    # caller's dictionary untouched.
    remaining = dict(spec)
    try:
        cls_path = remaining.pop("cls_path")
        original_timestamp = remaining.pop("timestamp")
    except KeyError as missing:
        raise ValueError(
            f"Spec JSON missing required field {missing.args[0]!r}; "
            f"present keys: {sorted(spec)}"
        ) from missing

    try:
        target = _import_callable(cls_path)
    except Exception as err:
        raise ValueError(
            f"Could not resolve cls_path={cls_path!r} while rehydrating spec JSON: {err}"
        ) from err

    sig = _signature(target)

    def _rehydrate(field: str, payload: Any) -> Any:
        """Rebuild one serialized kwarg value for ``target``."""
        if _is_spec_dict(payload):
            return create_model_spec_from_json(payload)
        if _is_spec_dict_sequence(payload):
            return _rehydrate_spec_sequence(field, payload, sig)
        # Remaining values may be serializer payloads (the original comment
        # lists tagged class dicts, dtype/device strings and tensor dicts);
        # the helper decides what to convert and what to leave as-is.
        return _try_deserialize(field, payload, sig)

    kwargs: dict[str, Any] = {
        field: _rehydrate(field, payload) for field, payload in remaining.items()
    }

    restored = create_model_spec(target, **kwargs)
    # create_model_spec stamps a new timestamp; put the stored one back.
    object.__setattr__(restored, "timestamp", original_timestamp)
    return restored
+ ) + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as exc: + missing = exc.name or "" + if missing == module_path or module_path.startswith(f"{missing}."): + raise ValueError( + f"Cannot resolve training_fn from dotted path {path!r}: " + f"module {module_path!r} not found. Expected 'module.attribute'." + ) from exc + raise ValueError( + f"Imported module {module_path!r} failed while resolving " + f"training_fn {path!r}: missing transitive dependency " + f"{missing!r}. Install it or fix the import inside " + f"{module_path!r}." + ) from exc + except ImportError as exc: + raise ValueError( + f"Imported module {module_path!r} failed while resolving " + f"training_fn {path!r}: {exc}. Check imports and dependencies " + "inside that module." + ) from exc + try: + obj = getattr(module, attr) + except AttributeError as exc: + raise ValueError( + f"Cannot resolve training_fn from dotted path {path!r}: " + f"module {module_path!r} has no attribute {attr!r}." + ) from exc + if not callable(obj): + raise ValueError( + f"{path!r} resolves to {type(obj).__name__}, which is not callable." + ) + return obj + + +def _callable_dotted_path(fn: Callable[..., Any]) -> str: + """Return ``"module.name"`` for a module-level callable or raise ``ValueError``.""" + module = getattr(fn, "__module__", None) + qualname = getattr(fn, "__qualname__", None) + name = getattr(fn, "__name__", None) + if not module or not qualname: + raise ValueError( + f"training_fn is not serializable — {type(fn).__name__} " + "lacks __module__ / __qualname__. Only importable " + "module-level callables can be written to spec." + ) + if "" in qualname or "" in qualname: + raise ValueError( + f"training_fn is not serializable — {qualname!r} is a lambda " + "or local function. Only importable module-level callables " + "can be written to spec." 
+ ) + if name is None or qualname != name: + raise ValueError( + f"training_fn is not serializable — {qualname!r} is not a " + "module-level callable (nested class/function or bound method). " + "Only importable module-level callables can be written to spec." + ) + return f"{module}.{qualname}" + + +def _model_specs_from_models( + models: dict[str, BaseModelMixin], +) -> dict[str, dict[str, Any]]: + """Best-effort ``BaseSpec`` dumps for importable model constructors.""" + specs: dict[str, dict[str, Any]] = {} + for key, model in models.items(): + try: + spec = _model_provided_checkpoint_spec(model) + validate_rebuild = spec is None + if spec is None: + spec = _module_spec_from_attrs(model) + if validate_rebuild: + rebuilt = create_model_spec_from_json(spec.model_dump()).build() + if not isinstance(rebuilt, BaseModelMixin): + raise TypeError( + f"rebuilt {type(rebuilt).__name__}, expected BaseModelMixin" + ) + rebuilt.to(torch.device("cpu")) + specs[key] = spec.model_dump() + except (TypeError, ValueError, AttributeError) as exc: + warnings.warn( + f"Omitting model spec for {key!r}: {exc}", + UserWarning, + stacklevel=2, + ) + return specs + + +def _model_provided_checkpoint_spec(module: torch.nn.Module) -> BaseSpec | None: + """Return an explicit model-provided checkpoint spec, if available.""" + if isinstance(module, torch.nn.parallel.DistributedDataParallel): + module = module.module + checkpoint_spec = getattr(module, "checkpoint_spec", None) + if not callable(checkpoint_spec): + return None + spec = checkpoint_spec() + if spec is None: + return None + if not isinstance(spec, BaseSpec): + raise TypeError( + "checkpoint_spec() must return a BaseSpec or None; got " + f"{type(spec).__name__}." 
+ ) + return spec + + +def _module_spec_from_attrs(module: torch.nn.Module) -> BaseSpec: + """Build a recursive spec from constructor-matching module attributes.""" + if isinstance(module, torch.nn.parallel.DistributedDataParallel): + module = module.module + kwargs = _extract_init_kwargs_from_attrs(module) + for name, value in list(kwargs.items()): + if isinstance(value, torch.nn.Module): + kwargs[name] = _module_spec_from_attrs(value) + return create_model_spec(type(module), **kwargs) + + +def _models_from_spec_dict( + spec_models: Mapping[str, Any], +) -> dict[str, BaseModelMixin]: + """Build serialized model specs, omitting entries that fail to rebuild.""" + models: dict[str, BaseModelMixin] = {} + for key, raw in spec_models.items(): + if not isinstance(raw, Mapping): + warnings.warn( + f"Omitting model spec for {key!r}: expected BaseSpec dict, " + f"got {type(raw).__name__}.", + UserWarning, + stacklevel=2, + ) + continue + try: + model = create_model_spec_from_json(dict(raw)).build() + except (TypeError, ValueError, AttributeError) as exc: + warnings.warn( + f"Omitting model spec for {key!r}: {exc}", + UserWarning, + stacklevel=2, + ) + continue + if not isinstance(model, BaseModelMixin): + warnings.warn( + f"Omitting model spec for {key!r}: built " + f"{type(model).__name__}, expected BaseModelMixin.", + UserWarning, + stacklevel=2, + ) + continue + models[key] = model + return models + + +def _optimizer_configs_from_spec(raw: Any) -> dict[str, list[OptimizerConfig]]: + """Rebuild named optimizer configs from a serialized spec field.""" + if not isinstance(raw, Mapping): + raise ValueError( + "from_spec_dict: 'optimizer_configs' must be a mapping of " + f"str -> list[dict]; got {type(raw).__name__}." 
+ ) + optimizer_configs: dict[str, list[OptimizerConfig]] = {} + for raw_key, entries in raw.items(): + if not isinstance(raw_key, str): + raise ValueError( + "from_spec_dict: 'optimizer_configs' keys must be strings; " + f"got key of type {type(raw_key).__name__}." + ) + if not isinstance(entries, list) or not all( + isinstance(entry, Mapping) for entry in entries + ): + raise ValueError( + f"from_spec_dict: 'optimizer_configs[{raw_key!r}]' must " + "be a list of OptimizerConfig spec dicts." + ) + key = "main" if raw_key == "0" else raw_key + optimizer_configs[key] = [ + OptimizerConfig.from_spec(create_model_spec_from_json(entry)) + for entry in entries + ] + return optimizer_configs + + +def _devices_from_spec(raw: Any) -> list[torch.device]: + """Rebuild device strings from a serialized spec field.""" + if not isinstance(raw, list) or not all(isinstance(device, str) for device in raw): + raise ValueError( + "from_spec_dict: 'devices' must be a list of device strings; " + f"got {type(raw).__name__}." + ) + return [torch.device(device) for device in raw] + + +def _loss_fn_from_spec(raw: Any) -> ComposedLossFunction: + """Rebuild the composed loss from a serialized spec field.""" + if not isinstance(raw, Mapping): + raise ValueError( + "from_spec_dict: 'loss_fn_spec' must be a BaseSpec dump dict; " + f"got {type(raw).__name__}." + ) + loss_fn = create_model_spec_from_json(raw).build() + if not isinstance(loss_fn, ComposedLossFunction): + raise ValueError( + f"loss_fn_spec built {type(loss_fn).__name__}, expected " + "ComposedLossFunction." 
+ ) + return loss_fn + + +def _training_fn_from_spec( + spec: Mapping[str, Any], + override: Callable[..., Mapping[str, torch.Tensor]] | str | None, +) -> Callable[..., Mapping[str, torch.Tensor]] | str: + """Resolve runtime or serialized training function input.""" + if override is not None: + return override + raw = spec.get("training_fn") + if raw is None: + raise ValueError( + "from_spec_dict: no training_fn was supplied and the spec does " + "not contain one. Pass training_fn=... explicitly." + ) + if not isinstance(raw, str): + raise ValueError( + "from_spec_dict: 'training_fn' must be a dotted-path string " + f"('module.attribute'); got {type(raw).__name__}." + ) + return _resolve_dotted_callable(raw) + + +def _models_from_spec_and_overrides( + spec_models_raw: Any, + runtime_models: ModelInput | None, + *, + single_model_input: bool | None = None, +) -> ModelInput: + """Build spec models, apply runtime overrides, and preserve call mode.""" + if not isinstance(spec_models_raw, Mapping): + raise ValueError( + "from_spec_dict: 'model_specs' must be a mapping when present; " + f"got {type(spec_models_raw).__name__}." + ) + merged = _models_from_spec_dict(spec_models_raw) + if runtime_models is not None: + merged.update(_normalize_models(runtime_models)) + # Return shape intentionally preserves the public call-mode distinction: + # ``models=model`` means ``training_fn(model, batch)``, while + # ``models={"main": model}`` means ``training_fn(models, batch)``. 
+ if isinstance(runtime_models, BaseModelMixin) and set(merged) == {"main"}: + return merged["main"] + if runtime_models is not None: + return merged + if single_model_input is True and set(merged) == {"main"}: + return merged["main"] + if single_model_input is False: + return merged + if set(merged) == {"main"}: + return merged["main"] + return merged + + +def _single_model_input_from_spec(raw: Any) -> bool | None: + """Return serialized call mode or ``None`` for legacy specs.""" + if raw is None: + return None + if not isinstance(raw, bool): + raise ValueError( + "from_spec_dict: 'single_model_input' must be a bool when present; " + f"got {type(raw).__name__}." + ) + return raw diff --git a/nvalchemi/training/_stages.py b/nvalchemi/training/_stages.py new file mode 100644 index 00000000..7841a3a2 --- /dev/null +++ b/nvalchemi/training/_stages.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Training-lifecycle stage enum.""" + +from __future__ import annotations + +from enum import Enum, auto + +__all__ = ["TrainingStage"] + + +class TrainingStage(Enum): + """Stages of the training lifecycle at which hooks can fire. + + Parallel to :class:`nvalchemi.dynamics.base.DynamicsStage`, this enum + marks the points before and after each operation in a training run. 
+ Members are paired ``BEFORE_*`` / ``AFTER_*`` around each lifecycle + event, from the once-per-run ``BEFORE_TRAINING`` / ``AFTER_TRAINING`` + outer pair down to the per-batch forward, loss, backward, and + optimizer-step phases. + + Attributes + ---------- + SETUP : TrainingStage + Fires once before optimizer construction, after runtime device + placement has been resolved. Setup hooks may mutate workflow state + such as model wrappers and dataloaders before training begins. + BEFORE_TRAINING : TrainingStage + Fires once before the epoch loop, after the model is on device + and optimizers are constructed. + BEFORE_EPOCH : TrainingStage + Fires at the start of each epoch, before the first batch. + BEFORE_BATCH : TrainingStage + Fires at the start of each batch, before the default gradient + zeroing path. A training-update orchestrator may claim this stage + to decide whether zeroing should run for the batch. + BEFORE_FORWARD : TrainingStage + Fires before the model forward pass. + AFTER_FORWARD : TrainingStage + Fires after the model forward pass; predictions are available. + BEFORE_LOSS : TrainingStage + Fires before the loss computation. + AFTER_LOSS : TrainingStage + Fires after the loss computation; the loss tensor is populated. + BEFORE_BACKWARD : TrainingStage + Fires before the backward pass. + DO_BACKWARD : TrainingStage + Replacement slot for the backward pass. At most one hook may claim + this stage; when claimed, ``TrainingStrategy`` skips its default + ``loss.backward()`` and the claiming hook is responsible for + performing (and scaling, if needed) the backward. Observers should + use ``BEFORE_BACKWARD``/``AFTER_BACKWARD``. + AFTER_BACKWARD : TrainingStage + Fires after the backward pass has made gradients available; typical + slot for gradient clipping or gradient-norm logging. 
+ BEFORE_OPTIMIZER_STEP : TrainingStage + Fires immediately before the optimizer step and remains distinct from + ``AFTER_BACKWARD`` as the public last pre-step point; typical slot for + observers that need to see unscaled gradients (see ``DO_BACKWARD``). + DO_OPTIMIZER_STEP : TrainingStage + Replacement slot for the optimizer and LR-scheduler step. At most + one hook may claim this stage; when claimed, ``TrainingStrategy`` + skips its default optimizer and scheduler stepping and the claiming + hook must step each optimizer in ``ctx.optimizers`` (and its + corresponding scheduler if present). Observers should use + ``BEFORE_OPTIMIZER_STEP``/``AFTER_OPTIMIZER_STEP``. + AFTER_OPTIMIZER_STEP : TrainingStage + Fires after the optimizer and scheduler step path completes; + typical slot for EMA updates, skip-aware training updates, and + post-step logging. + AFTER_BATCH : TrainingStage + Fires at the end of each batch for generic batch cleanup, distinct + from optimizer-step-aware ``AFTER_OPTIMIZER_STEP`` hooks. + AFTER_EPOCH : TrainingStage + Fires at the end of each epoch, after the last batch. + AFTER_TRAINING : TrainingStage + Fires once after the final epoch. + AFTER_VALIDATION : TrainingStage + Fires from inside ``TrainingStrategy.validate()`` immediately after a + validation pass produces its summary and before any metric-driven LR + schedulers consume it. Because validation runs at multiple cadences + (step, epoch, and once at end of training), this is an event-defined + stage rather than a fixed loop position; it is the reliable slot for + loggers and observers that need the latest validation summary + (available via ``ctx.workflow.last_validation``). 
+ """ + + SETUP = auto() + BEFORE_TRAINING = auto() + BEFORE_EPOCH = auto() + BEFORE_BATCH = auto() + BEFORE_FORWARD = auto() + AFTER_FORWARD = auto() + BEFORE_LOSS = auto() + AFTER_LOSS = auto() + BEFORE_BACKWARD = auto() + DO_BACKWARD = auto() + AFTER_BACKWARD = auto() + BEFORE_OPTIMIZER_STEP = auto() + DO_OPTIMIZER_STEP = auto() + AFTER_OPTIMIZER_STEP = auto() + AFTER_BATCH = auto() + AFTER_EPOCH = auto() + AFTER_TRAINING = auto() + AFTER_VALIDATION = auto() diff --git a/nvalchemi/training/_strategy_validation.py b/nvalchemi/training/_strategy_validation.py new file mode 100644 index 00000000..083cee47 --- /dev/null +++ b/nvalchemi/training/_strategy_validation.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Validation helpers for :mod:`nvalchemi.training.strategy`.""" + +from __future__ import annotations + +import inspect +from collections.abc import Callable, Mapping +from typing import Any, TypeAlias, get_origin, get_type_hints + +from torch.nn import ModuleDict + +from nvalchemi.models.base import BaseModelMixin + +ModelInput: TypeAlias = BaseModelMixin | dict[str, BaseModelMixin] | ModuleDict +_TRAINING_FN_REQUIRED_MESSAGE = ( + "training_fn must be provided explicitly. 
To opt into the stock " + "single-model behavior, use `from nvalchemi.training import " + "default_training_fn` or `from nvalchemi.training.strategy import " + "default_training_fn`." +) + + +def _normalize_models(value: Any) -> Any: + """Normalize model inputs to a plain named-model dict.""" + if isinstance(value, BaseModelMixin): + return {"main": value} + if isinstance(value, ModuleDict): + value = dict(value.items()) + if isinstance(value, dict): + invalid = { + key: type(model).__name__ + for key, model in value.items() + if not isinstance(model, BaseModelMixin) + } + if invalid: + raise ValueError( + "models must map names to BaseModelMixin instances; " + f"invalid entries: {invalid}." + ) + return dict(value) + return value + + +def _callable_accepts_two_args(fn: Callable[..., Any]) -> bool: + """Return whether ``fn`` can be called with exactly two positional args.""" + sig = inspect.signature(fn) + try: + sig.bind(object(), object()) + except TypeError: + return False + return True + + +def _first_parameter_annotation(fn: Callable[..., Any]) -> Any: + """Return the first parameter annotation, resolving type hints when possible.""" + sig = inspect.signature(fn) + try: + first = next(iter(sig.parameters.values())) + except StopIteration: + return inspect.Parameter.empty + try: + hints = get_type_hints(fn) + except (NameError, TypeError, AttributeError): + hints = getattr(fn, "__annotations__", {}) + return hints.get(first.name, first.annotation) + + +def _is_mapping_model_annotation(annotation: Any) -> bool: + """Return whether annotation clearly means named model mapping.""" + if annotation in (Any, inspect.Parameter.empty): + return False + origin = get_origin(annotation) + if origin is dict or origin is Mapping: + args = getattr(annotation, "__args__", ()) + return len(args) == 2 and args[0] is str and _is_model_annotation(args[1]) + try: + return isinstance(annotation, type) and issubclass(annotation, ModuleDict) + except TypeError: + return False + + 
+def _is_model_annotation(annotation: Any) -> bool: + """Return whether annotation clearly means ``BaseModelMixin`` or subclass.""" + if annotation in (Any, inspect.Parameter.empty): + return False + try: + return isinstance(annotation, type) and issubclass(annotation, BaseModelMixin) + except TypeError: + return False + + +def _validate_training_fn_call_shape( + fn: Callable[..., Any], + *, + single_model_input: bool, +) -> None: + """Validate ``training_fn`` arity and obvious first-argument mismatches.""" + if not _callable_accepts_two_args(fn): + raise ValueError( + "training_fn must accept exactly the two arguments " + "(model_or_models, batch) without requiring additional args." + ) + annotation = _first_parameter_annotation(fn) + if single_model_input and _is_mapping_model_annotation(annotation): + raise ValueError( + "single-model strategies call training_fn(model, batch), but the " + "first parameter is annotated as a model mapping." + ) + if not single_model_input and _is_model_annotation(annotation): + raise ValueError( + "named-model strategies call training_fn(models, batch), but the " + "first parameter is annotated as a single BaseModelMixin. Pass " + "models=model for single-model behavior, or define " + "training_fn(models: dict[str, BaseModelMixin], batch)." + ) diff --git a/nvalchemi/training/_validation.py b/nvalchemi/training/_validation.py new file mode 100644 index 00000000..9ce1fe23 --- /dev/null +++ b/nvalchemi/training/_validation.py @@ -0,0 +1,991 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Validation configuration, shared helpers, and the :class:`ValidationLoop` orchestrator. + +This module contains :class:`ValidationConfig`, :class:`ValidationLoop`, +and the low-level utilities used by +:meth:`~nvalchemi.training.TrainingStrategy.validate` validation passes. +""" + +from __future__ import annotations + +import contextlib +import dataclasses +from collections.abc import Callable, Iterable, Mapping +from contextlib import AbstractContextManager +from types import TracebackType +from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, runtime_checkable + +import torch +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PlainValidator, + field_validator, + model_validator, +) +from torch import nn + +from nvalchemi.data import Batch +from nvalchemi.training.distributed import ( + all_reduce as distributed_all_reduce, +) +from nvalchemi.training.distributed import ( + get_rank as get_distributed_rank, +) +from nvalchemi.training.distributed import ( + is_distributed_initialized, +) +from nvalchemi.training.losses.composition import ( + ComposedLossFunction, + ComposedLossOutput, + as_composed_loss, + compute_supervised_loss, +) + +if TYPE_CHECKING: + from nvalchemi.training.strategy import TrainingStrategy + +__all__ = ["BatchValidationCallback", "ValidationConfig", "ValidationLoop"] + + +@runtime_checkable +class BatchValidationCallback(Protocol): + """Protocol for an optional per-batch validation callback. 
+ + A user-supplied object implementing this protocol is invoked once per + validation batch inside :meth:`ValidationLoop.execute`, immediately + after predictions and the per-batch loss are computed. It is the + extension point for streaming per-batch outputs (e.g. predictions or + diagnostics) to a custom logging or storage system. + + Summary-level logging does not require this callback: register a hook + on :attr:`~nvalchemi.training.TrainingStage.AFTER_VALIDATION` and read + the validation summary from ``ctx.validation``. + + Notes + ----- + No concrete implementation is provided. Users supply their own. + """ + + def __call__( + self, + *, + batch: Batch, + predictions: Mapping[str, torch.Tensor], + loss: ComposedLossOutput, + batch_count: int, + step_count: int, + epoch: int, + ) -> None: + """Consume one validation batch's predictions and loss. + + Parameters + ---------- + batch : Batch + The validation batch that was evaluated. + predictions : Mapping[str, torch.Tensor] + The output of the validation function for this batch. + loss : ComposedLossOutput + The per-batch composed loss output. + batch_count : int + Zero-based index of this batch within the validation pass. + step_count : int + Training step count at which this validation pass runs. + epoch : int + Training epoch at which this validation pass runs. + """ + ... + + +def _ensure_reiterable_validation_data(value: Any) -> Any: + """Reject one-shot iterators so validation can restart each pass. + + Parameters + ---------- + value : Any + Candidate ``validation_data``. Must be a re-iterable container + (e.g. ``list``, ``DataLoader``, ``Dataset``) whose ``__iter__`` + returns a fresh iterator each call. + + Returns + ------- + Any + The value unchanged when it is re-iterable. + + Raises + ------ + ValueError + When ``value`` is not iterable at all, or when it is a one-shot + iterator (e.g. a generator) that cannot be re-iterated across + repeated validation passes. 
+ """ + try: + iterator = iter(value) + except TypeError as exc: + raise ValueError( + "validation_data must be iterable (e.g. a list, DataLoader, or " + f"Dataset of Batch); got {type(value).__name__}." + ) from exc + if iterator is value: + raise ValueError( + "validation_data must be a re-iterable container, not a one-shot " + "iterator/generator. Validation runs multiple times and must " + "restart from the beginning each pass; pass a list (or a " + "re-iterable DataLoader/Dataset) instead of a generator." + ) + return value + + +class ValidationConfig(BaseModel): + """Configuration for strategy-owned validation passes. + + ``ValidationConfig`` is a plain data object consumed by + ``TrainingStrategy.validate()`` via :class:`ValidationLoop`. + It does NOT drive hook dispatch — the strategy reads it directly. + + Attributes + ---------- + validation_data : Iterable[Batch] + Re-iterable container (e.g. ``list``, ``DataLoader``, ``Dataset``) + yielding :class:`~nvalchemi.data.Batch` instances. The strategy + re-iterates this on every validation pass; one-shot generators + and bare iterators are rejected at construction time. + validation_fn : Callable | None + Validation forward callable. ``None`` means use the strategy's + ``training_fn`` with the same single-model or named-model call + convention. + loss_fn : ComposedLossFunction | None + Validation loss function. ``None`` means use the strategy's + ``loss_fn``. Leaf losses are auto-normalized to a + :class:`ComposedLossFunction` via :func:`as_composed_loss`. + every_n_epochs : int | None + Run validation after every *n*-th completed epoch. Mutually + exclusive with ``every_n_steps``. + every_n_steps : int | None + Run validation after every *n*-th completed optimizer step. + Mutually exclusive with ``every_n_epochs``. + grad_mode : {"auto", "enabled", "disabled"} + Autograd policy during validation. 
``"auto"`` enables gradients + when any loss component has ``requires_eval_grad=True`` and + disables them when all components report ``False``. + set_eval : bool + If ``True``, set validation modules to eval mode and restore + their original training modes afterward. + use_ema : {"auto", "always", "never"} + Whether the strategy's ``inference_model`` slot (populated by + EMA) should replace live training weights for validation. + use_mixed_precision : {"auto", "always", "never"} + Whether to reuse a registered :class:`MixedPrecisionHook` + autocast context for validation inference. + batch_callback : BatchValidationCallback | None + Optional user-supplied callable invoked once per validation + batch with the batch, predictions, and per-batch loss output. + Use it to stream per-sample diagnostics to a custom logging or + storage backend. ``None`` disables per-batch callbacks. For + epoch-level (summary) logging, register a hook on the + ``AFTER_VALIDATION`` stage and read ``ctx.validation`` instead. + name : str + Name stored in the validation summary dictionary. 
+ """ + + validation_data: Annotated[ + Iterable[Batch], PlainValidator(_ensure_reiterable_validation_data) + ] + validation_fn: Callable[..., Any] | None = None + loss_fn: ComposedLossFunction | None = None + every_n_epochs: int | None = Field(default=None, ge=1) + every_n_steps: int | None = Field(default=None, ge=1) + grad_mode: Literal["auto", "enabled", "disabled"] = "auto" + set_eval: bool = True + use_ema: Literal["auto", "always", "never"] = "auto" + use_mixed_precision: Literal["auto", "always", "never"] = "auto" + batch_callback: BatchValidationCallback | None = None + name: str = Field(default="validation", min_length=1) + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + + @field_validator("loss_fn", mode="before") + @classmethod + def _normalize_loss_fn(cls, value: Any) -> ComposedLossFunction | None: + """Normalize a leaf loss into a one-component composed loss.""" + return None if value is None else as_composed_loss(value) + + @model_validator(mode="after") + def _validate_schedule(self) -> ValidationConfig: + """Enforce mutual exclusion of ``every_n_epochs`` and ``every_n_steps``.""" + if self.every_n_epochs is not None and self.every_n_steps is not None: + raise ValueError("Only one of every_n_epochs or every_n_steps may be set.") + return self + + +# ------------------------------------------------------------------ +# Shared validation utilities +# ------------------------------------------------------------------ + + +def _unique_modules(modules: Iterable[nn.Module]) -> tuple[nn.Module, ...]: + """Return unique modules while preserving first-seen order.""" + seen: set[int] = set() + unique: list[nn.Module] = [] + for module in modules: + if id(module) in seen: + continue + seen.add(id(module)) + unique.append(module) + return tuple(unique) + + +def _module_training_modes( + modules: Iterable[nn.Module], +) -> dict[int, tuple[nn.Module, bool]]: + """Snapshot unique module training modes for later restoration.""" 
class _LossAccumulator:
    """Running totals of composed-loss outputs across validation batches.

    Totals are stored as detached tensors and are only converted into
    averages by :meth:`summary`, after an optional distributed sum-reduction.
    """

    def __init__(self, device: torch.device) -> None:
        self.device = device
        self.batch_count = 0
        self.total_sum: torch.Tensor | None = None
        self.per_component_unweighted_sum: dict[str, torch.Tensor] = {}
        self.per_component_sample_sum: dict[str, torch.Tensor] = {}
        self.per_component_sample_count: dict[str, int] = {}
        self.per_component_weight: dict[str, float] = {}
        self.per_component_raw_weight: dict[str, float] = {}

    def update(self, loss_out: ComposedLossOutput) -> None:
        """Fold one batch's loss output into the running totals."""
        self.batch_count += 1

        batch_total = loss_out["total_loss"].detach()
        if self.total_sum is None:
            self.total_sum = batch_total
        else:
            self.total_sum = self.total_sum + batch_total

        unweighted = self.per_component_unweighted_sum
        for component, value in loss_out["per_component_unweighted"].items():
            contribution = value.detach()
            if component in unweighted:
                contribution = unweighted[component] + contribution
            unweighted[component] = contribution

        sample_sums = self.per_component_sample_sum
        sample_counts = self.per_component_sample_count
        for component, sample in loss_out["per_component_sample"].items():
            contribution = sample.detach().sum()
            if component in sample_sums:
                contribution = sample_sums[component] + contribution
            sample_sums[component] = contribution
            sample_counts[component] = sample_counts.get(component, 0) + sample.numel()

        # Weights are not accumulated: the most recent batch's values win.
        self.per_component_weight = dict(loss_out["per_component_weight"])
        self.per_component_raw_weight = dict(loss_out["per_component_raw_weight"])

    def summary(
        self,
        *,
        name: str,
        model_source: str,
        ema_model_keys: tuple[str, ...],
        precision: str,
        publish: bool,
        distributed_manager: Any | None = None,
    ) -> dict[str, Any] | None:
        """Build the averaged summary, sum-reducing across ranks when active.

        All sums and counts are packed into one float64 vector so a single
        reduction call covers them. The reduction runs before the
        ``publish`` check, so non-publishing callers still take part in it
        and then receive ``None``.

        Raises
        ------
        ValueError
            If :meth:`update` was never called.
        """
        if self.total_sum is None or self.batch_count == 0:
            raise ValueError("validation_data produced no batches.")

        device = self.device

        def as_count(count: int) -> torch.Tensor:
            return torch.tensor(float(count), device=device, dtype=torch.float64)

        component_keys = tuple(sorted(self.per_component_unweighted_sum))
        sample_keys = tuple(sorted(self.per_component_sample_sum))

        # Packed layout:
        #   [total, n_batches, *component_sums, (sample_sum, sample_count) per key]
        slots = [
            _as_float64_scalar(self.total_sum, device),
            as_count(self.batch_count),
        ]
        for key in component_keys:
            slots.append(
                _as_float64_scalar(self.per_component_unweighted_sum[key], device)
            )
        for key in sample_keys:
            slots.append(_as_float64_scalar(self.per_component_sample_sum[key], device))
            slots.append(as_count(self.per_component_sample_count[key]))

        packed = torch.stack(slots)
        distributed_reduced = _distributed_sum_in_place(packed, distributed_manager)
        if not publish:
            return None

        # Walk the packed vector in exactly the order it was assembled above.
        cursor = iter(packed.unbind(0))
        total_sum = next(cursor)
        batch_count = next(cursor)
        reduced_batch_count = int(batch_count.item())

        per_component_unweighted: dict[str, torch.Tensor] = {
            key: _tensor_to_cpu(next(cursor) / batch_count) for key in component_keys
        }

        per_component_sample: dict[str, torch.Tensor] = {}
        sample_counts: dict[str, int] = {}
        for key in sample_keys:
            sample_sum = next(cursor)
            sample_count = next(cursor)
            sample_counts[key] = int(sample_count.item())
            per_component_sample[key] = _tensor_to_cpu(sample_sum / sample_count)

        return {
            "name": name,
            "total_loss": _tensor_to_cpu(total_sum / batch_count),
            "per_component_unweighted": per_component_unweighted,
            "per_component_weight": dict(self.per_component_weight),
            "per_component_raw_weight": dict(self.per_component_raw_weight),
            "per_component_sample": per_component_sample,
            "num_batches": reduced_batch_count,
            "per_component_sample_count": sample_counts,
            "model_source": model_source,
            "ema_model_keys": list(ema_model_keys),
            "precision": precision,
            "distributed_reduced": distributed_reduced,
        }
def _resolve_model_arg(
    strategy: TrainingStrategy,
    config: ValidationConfig,
) -> tuple[Any, tuple[nn.Module, ...], tuple[str, ...]]:
    """Resolve the model argument for a strategy-integrated validation pass.

    Reads the strategy-owned ``inference_model`` slot and falls back
    to live training models for keys not covered by the slot.

    Parameters
    ----------
    strategy : TrainingStrategy
        The training strategy owning the validation pass.
    config : ValidationConfig
        The resolved validation configuration. Only ``use_ema`` is read.

    Returns
    -------
    tuple[Any, tuple[nn.Module, ...], tuple[str, ...]]
        A three-element tuple:

        * **model_arg** -- The value passed to the validation forward
          callable. A single :class:`nn.Module` for single-model
          strategies, or a ``dict[str, ...]`` for named-model
          strategies.
        * **modules** -- All unique :class:`nn.Module` instances
          participating in the forward pass (for training-mode
          management).
        * **ema_keys** -- Sorted tuple of model keys that were
          sourced from the ``inference_model`` slot rather than
          live training weights.

    Raises
    ------
    RuntimeError
        When ``use_ema='always'`` and the ``inference_model`` slot
        cannot satisfy the requirement: the slot is empty, it is not a
        plain module for a single-model strategy, or it misses keys for
        a named-model strategy.
    """
    use_ema = config.use_ema
    slot = strategy.inference_model

    # ``never`` masks the slot so every model resolves to live weights.
    if use_ema == "never":
        slot = None

    if use_ema == "always" and slot is None:
        raise RuntimeError(
            "ValidationConfig use_ema='always' requires a populated "
            "inference_model slot (e.g. via EMAHook)."
        )

    if strategy.single_model_input:
        live = strategy.models["main"]
        if isinstance(slot, nn.Module) and not isinstance(slot, nn.ModuleDict):
            return slot, (slot,), ("main",)
        # The slot is missing or has a shape this path cannot use (e.g. an
        # ``nn.ModuleDict``). Silently validating live weights here would
        # break the 'always' contract that the named-model path enforces.
        if use_ema == "always":
            raise RuntimeError(
                "ValidationConfig use_ema='always' requires the "
                "inference_model slot to hold a single nn.Module for "
                f"single-model strategies; got {type(slot).__name__}."
            )
        return live, (live,), ()

    # Named-model path: start from live models and overlay slot entries.
    resolved: dict[str, Any] = dict(strategy.models)
    used_ema_keys: list[str] = []

    if isinstance(slot, nn.ModuleDict):
        for key in list(slot.keys()):
            if key in resolved:
                resolved[key] = slot[key]
                used_ema_keys.append(key)
    elif isinstance(slot, nn.Module):
        # A bare module in the slot is treated as the "main" model.
        if "main" in resolved:
            resolved["main"] = slot
            used_ema_keys.append("main")

    if use_ema == "always":
        missing = sorted(set(resolved) - set(used_ema_keys))
        if missing:
            raise RuntimeError(
                "ValidationConfig use_ema='always' requires the "
                "inference_model slot to cover every model key; "
                f"missing: {missing}."
            )

    modules = tuple(
        value for value in resolved.values() if isinstance(value, nn.Module)
    )
    return resolved, _unique_modules(modules), tuple(sorted(used_ema_keys))
+ autocast : Callable[[], AbstractContextManager[None]] | None + Precision context factory. ``None`` uses + :func:`contextlib.nullcontext` and precision label ``"float32"``. + grad_enabled : bool | None + Autograd policy. ``None`` infers from ``config.grad_mode`` + and ``loss_fn.requires_eval_grad()``. + distributed_manager : Any | None + Optional distributed manager for all-reduce and barrier ops. + step_count : int + Optimizer step counter for sink metadata. + epoch : int + Epoch counter for sink metadata. + + Raises + ------ + ValueError + When both or neither of ``model``/``models`` are supplied, + or when required arguments (``loss_fn``, ``validation_fn``) + are missing. + """ + + def __init__( + self, + *, + validation_data: Iterable[Batch], + config: ValidationConfig, + device: torch.device, + model: nn.Module | None = None, + models: dict[str, nn.Module] | None = None, + loss_fn: ComposedLossFunction | None = None, + validation_fn: Callable[..., Any] | None = None, + inference_model: nn.Module | nn.ModuleDict | None = None, + autocast: Callable[[], AbstractContextManager[None]] | None = None, + grad_enabled: bool | None = None, + distributed_manager: Any | None = None, + step_count: int = 0, + epoch: int = 0, + ) -> None: + have_model = model is not None + have_models = models is not None + if have_model == have_models: + raise ValueError("Exactly one of 'model' or 'models' must be provided.") + + resolved_loss_fn = loss_fn if loss_fn is not None else config.loss_fn + if resolved_loss_fn is None: + raise ValueError( + "loss_fn must be provided either directly or via " + "config.loss_fn in standalone mode." 
+ ) + resolved_loss_fn = as_composed_loss(resolved_loss_fn) + + if validation_fn is None: + raise ValueError("validation_fn is required in standalone mode.") + + if autocast is not None: + self._precision_context = autocast + self._precision = "mixed" + else: + self._precision_context: Callable[[], AbstractContextManager[None]] = ( + contextlib.nullcontext + ) + self._precision = "float32" + + if grad_enabled is None: + grad_enabled = _resolve_grad_from_config(config, resolved_loss_fn) + + self._validation_data = validation_data + self._config = config + self._device = device + self._loss_fn = resolved_loss_fn + self._validation_fn = validation_fn + self._grad_enabled = grad_enabled + + # Resolve model_arg, modules, ema_model_keys for standalone path + if have_model: + assert model is not None # noqa: S101 # narrowing + self._single_model_input = True + ema_keys: tuple[str, ...] = () + if ( + inference_model is not None + and isinstance(inference_model, nn.Module) + and not isinstance(inference_model, nn.ModuleDict) + ): + effective_model = inference_model + ema_keys = ("main",) + else: + effective_model = model + self._model_arg: Any = effective_model + self._modules = _unique_modules((effective_model,)) + self._ema_model_keys = ema_keys + self._num_models = 1 + else: + assert models is not None # noqa: S101 # narrowing + self._single_model_input = False + resolved: dict[str, Any] = dict(models) + used_ema_keys: list[str] = [] + if isinstance(inference_model, nn.ModuleDict): + for key in list(inference_model.keys()): + if key in resolved: + resolved[key] = inference_model[key] + used_ema_keys.append(key) + elif isinstance(inference_model, nn.Module): + if "main" in resolved: + resolved["main"] = inference_model + used_ema_keys.append("main") + mods = tuple(v for v in resolved.values() if isinstance(v, nn.Module)) + self._model_arg = resolved + self._modules = _unique_modules(mods) + self._ema_model_keys = tuple(sorted(used_ema_keys)) + self._num_models = 
len(models) + + # Standalone context: fixed values + self._strategy: TrainingStrategy | None = None + self._standalone_context = _LoopContext( + step_count=step_count, + epoch=epoch, + distributed_manager=distributed_manager, + num_models=self._num_models, + ) + self._successful = False + self._entered = False + self._modes: dict[int, tuple[nn.Module, bool]] = {} + self._grad_snapshot: dict[int, tuple[nn.Parameter, torch.Tensor | None]] = {} + + @classmethod + def from_training_strategy( + cls, + strategy: TrainingStrategy, + config: ValidationConfig | None = None, + ) -> ValidationLoop: + """Build a :class:`ValidationLoop` from a :class:`TrainingStrategy`. + + Reads capabilities through the strategy's introspection methods + and holds a live reference for counter/model access during + :meth:`execute`. + + Parameters + ---------- + strategy : TrainingStrategy + The training strategy owning the validation pass. + config : ValidationConfig | None + Override validation config. ``None`` uses + ``strategy.validation_config``. + + Returns + ------- + ValidationLoop + A loop instance ready to be used as a context manager. + + Raises + ------ + RuntimeError + When ``strategy.validation_config`` is ``None`` and no + ``config`` override is provided. + """ + resolved_config = config if config is not None else strategy.validation_config + if resolved_config is None: + raise RuntimeError( + "ValidationLoop.from_training_strategy() requires a " + "validation_config on the strategy or as an argument." 
+ ) + + device = strategy.devices[0] + + # -- loss resolution (was _resolve_validation_loss_fn) -- + if resolved_config.loss_fn is not None: + loss_fn = resolved_config.loss_fn + else: + loss_fn = as_composed_loss(strategy.loss_fn) + + validation_fn = resolved_config.validation_fn or strategy.training_fn + + # -- grad resolution (was _resolve_validation_grad) -- + grad_enabled = _resolve_grad_from_config(resolved_config, loss_fn) + + # -- model resolution (was _validation_model_arg) -- + model_arg, modules, ema_model_keys = _resolve_model_arg( + strategy, resolved_config + ) + + precision_context, precision = strategy._inference_autocast(device) + + loop = cls.__new__(cls) + loop._validation_data = resolved_config.validation_data + loop._config = resolved_config + loop._device = device + loop._loss_fn = loss_fn + loop._validation_fn = validation_fn + loop._grad_enabled = grad_enabled + loop._precision_context = precision_context + loop._precision = precision + loop._model_arg = model_arg + loop._modules = _unique_modules(modules) + loop._ema_model_keys = ema_model_keys + loop._single_model_input = strategy.single_model_input + loop._num_models = len(strategy.models) + loop._strategy = strategy + loop._standalone_context = None + loop._successful = False + loop._entered = False + loop._modes = {} + loop._grad_snapshot = {} + return loop + + def _context(self) -> _LoopContext: + """Return live counters and handles for the current execution. + + Returns + ------- + _LoopContext + Context snapshot. Strategy-integrated loops read live + values from the held strategy reference; standalone loops + return stored values. 
+ """ + if self._strategy is not None: + return _LoopContext( + step_count=self._strategy.step_count, + epoch=self._strategy.epoch_count, + distributed_manager=self._strategy.distributed_manager, + num_models=len(self._strategy.models), + ) + assert self._standalone_context is not None # noqa: S101 # narrowing + return self._standalone_context + + def __enter__(self) -> ValidationLoop: + """Set up the validation pass. + + Snapshots training modes, sets eval mode (if configured), and + snapshots and clears parameter gradients (if grad-enabled). + + Returns + ------- + ValidationLoop + The loop handle. + """ + # Snapshot + set eval + self._modes = _module_training_modes(self._modules) + if self._config.set_eval: + for module, _training in self._modes.values(): + module.eval() + + # Snapshot + clear grads + if self._grad_enabled: + self._grad_snapshot = _snapshot_parameter_grads(self._modules) + _clear_parameter_grads(self._modules) + + self._entered = True + self._successful = False + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: + """Tear down the validation pass. + + Restores parameter gradients (if grad-enabled) and restores + module training modes (if ``set_eval``). + + Returns ``False`` so exceptions are not suppressed. + + Parameters + ---------- + exc_type : type[BaseException] | None + Exception type, if any. + exc_val : BaseException | None + Exception instance, if any. + exc_tb : TracebackType | None + Exception traceback, if any. + + Returns + ------- + bool + Always ``False``. 
    def execute(self) -> dict[str, Any] | None:
        """Run the validation loop over all batches and return the summary.

        Iterates ``validation_data``, runs the forward pass and loss
        computation per batch, invokes the optional per-batch callback,
        accumulates results, computes the distributed-reduced summary,
        and returns the summary dictionary.

        Returns
        -------
        dict[str, Any] | None
            The validation summary on rank 0, ``None`` on
            non-publishing distributed ranks.

        Raises
        ------
        RuntimeError
            When called outside the context manager.
        ValueError
            When ``validation_data`` produces no batches.
        """
        # ``_entered`` is only True between __enter__ and __exit__, which own
        # the eval-mode and gradient snapshot/restore around this method.
        if not self._entered:
            raise RuntimeError(
                "ValidationLoop.execute() must be called inside a 'with' block."
            )

        # Counters are read once, so every batch reports the same step/epoch.
        ctx = self._context()
        device = self._device
        accumulator = _LossAccumulator(device)

        # Per-batch loop
        # NOTE(review): the nesting of the statements after the ``with`` block
        # was reconstructed from a whitespace-collapsed source -- confirm that
        # ``accumulator.update`` and the callback sit outside it.
        for batch_count, batch in enumerate(self._validation_data):
            validation_batch = batch.to(device, non_blocking=True)
            # Drop gradients left by the previous batch before a grad-enabled pass.
            if self._grad_enabled:
                _clear_parameter_grads(self._modules)
            grad_ctx = torch.enable_grad() if self._grad_enabled else torch.no_grad()
            with grad_ctx, self._precision_context():
                predictions = self._validation_fn(self._model_arg, validation_batch)
                loss_out = compute_supervised_loss(
                    self._loss_fn,
                    predictions,
                    validation_batch,
                    step=ctx.step_count,
                    epoch=ctx.epoch,
                    batch_label="Validation batch",
                )
            accumulator.update(loss_out)
            # call the per-batch callback; this allows for user-defined operations
            # on the scope, e.g. log as much as you'd like
            # (``batch_count`` is the zero-based index produced by ``enumerate``)
            if self._config.batch_callback is not None:
                self._config.batch_callback(
                    batch=validation_batch,
                    predictions=predictions,
                    loss=loss_out,
                    batch_count=batch_count,
                    step_count=ctx.step_count,
                    epoch=ctx.epoch,
                )

        # Build summary
        # "ema" when every model came from the inference slot, "mixed" when
        # only some did, "live" when none did.
        num_models = ctx.num_models
        model_source = (
            "ema"
            if (self._ema_model_keys and len(self._ema_model_keys) == num_models)
            else "mixed"
            if self._ema_model_keys
            else "live"
        )
        # Only rank 0 publishes; other ranks still call ``summary`` and get None.
        summary = accumulator.summary(
            name=self._config.name,
            model_source=model_source,
            ema_model_keys=self._ema_model_keys,
            precision=self._precision,
            publish=get_distributed_rank(ctx.distributed_manager) == 0,
            distributed_manager=ctx.distributed_manager,
        )

        self._successful = True
        return summary
def _call_manager(manager: Any, *names: str, **kwargs: Any) -> bool:
    """Call the first callable manager attribute among ``names``.

    The method is invoked with ``kwargs`` when its signature accepts them
    and with no arguments otherwise. Compatibility is decided by binding
    against the signature rather than by catching ``TypeError`` from the
    call. A ``TypeError`` raised *inside* the method therefore propagates
    instead of triggering a second invocation, which would duplicate side
    effects such as process-group initialisation.

    Parameters
    ----------
    manager : Any
        Object searched for a matching method.
    *names : str
        Candidate method names, tried in order.
    **kwargs : Any
        Keyword arguments forwarded when the method accepts them.

    Returns
    -------
    bool
        ``True`` when a method was called, ``False`` when none of
        ``names`` resolved to a callable.
    """
    import inspect  # local import: keeps the module-level import block unchanged

    for name in names:
        method = getattr(manager, name, None)
        if not callable(method):
            continue
        try:
            signature = inspect.signature(method)
        except (TypeError, ValueError):
            signature = None

        if signature is None:
            # No introspectable signature (e.g. some C callables): fall back
            # to the original best-effort call-then-retry behaviour.
            try:
                method(**kwargs)
            except TypeError:
                method()
            return True

        try:
            signature.bind(**kwargs)
        except TypeError:
            accepts_kwargs = False
        else:
            accepts_kwargs = True

        if accepts_kwargs:
            method(**kwargs)
        else:
            method()
        return True
    return False
def distributed_device(
    manager: DistributedManager | None,
    fallback: torch.device | str,
    *,
    prefer_cuda: bool = True,
) -> torch.device:
    """Pick the device this rank should use.

    Resolution order: a device reported by ``manager`` (``device``
    attribute or ``get_device`` method); otherwise the CUDA device indexed
    by the local rank when ``prefer_cuda`` is set and CUDA is available;
    otherwise ``fallback``.
    """
    reported = (
        None
        if manager is None
        else _read_attr_or_call(manager, "device", "get_device")
    )
    if reported is not None:
        return torch.device(reported)

    # Convert first so an invalid ``fallback`` is rejected on every path.
    fallback_device = torch.device(fallback)
    use_cuda = prefer_cuda and torch.cuda.is_available()
    return torch.device("cuda", get_local_rank(manager)) if use_cuda else fallback_device
def all_reduce(
    tensor: torch.Tensor,
    manager: DistributedManager | None = None,
    *,
    op: dist.ReduceOp = dist.ReduceOp.SUM,
) -> torch.Tensor:
    """Reduce ``tensor`` across ranks and return the reduced tensor.

    A callable ``manager.all_reduce`` takes precedence; its return value is
    used when it is not ``None``, otherwise ``tensor`` is returned. Without
    such a method, ``torch.distributed.all_reduce`` is applied when a
    process group is initialised. With no distributed communication at all
    the input is returned as-is.
    """
    manager_reduce = (
        getattr(manager, "all_reduce", None) if manager is not None else None
    )
    if callable(manager_reduce):
        reduced = manager_reduce(tensor, op=op)
        return reduced if reduced is not None else tensor

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=op)
    return tensor
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Training hooks bundled with :mod:`nvalchemi.training`.""" + +from __future__ import annotations + +from nvalchemi.hooks import TorchProfilerHook +from nvalchemi.training.hooks.checkpoint import CheckpointHook +from nvalchemi.training.hooks.ddp import DDPHook +from nvalchemi.training.hooks.ema import EMAHook +from nvalchemi.training.hooks.mixed_precision import MixedPrecisionHook +from nvalchemi.training.hooks.update import ( + TrainingUpdateHook, + TrainingUpdateOrchestrator, +) + +__all__ = [ + "CheckpointHook", + "DDPHook", + "EMAHook", + "MixedPrecisionHook", + "TorchProfilerHook", + "TrainingUpdateHook", + "TrainingUpdateOrchestrator", +] diff --git a/nvalchemi/training/hooks/checkpoint.py b/nvalchemi/training/hooks/checkpoint.py new file mode 100644 index 00000000..4129f0b7 --- /dev/null +++ b/nvalchemi/training/hooks/checkpoint.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Periodic checkpoint-saving training hook.""" + +from __future__ import annotations + +from concurrent.futures import Future, ThreadPoolExecutor +from pathlib import Path +from types import TracebackType +from typing import Annotated, ClassVar + +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator + +from nvalchemi.hooks._context import TrainContext +from nvalchemi.training._checkpoint import ( + _create_checkpoint_snapshot, + _write_checkpoint_snapshot, +) +from nvalchemi.training._stages import TrainingStage + +__all__ = ["CheckpointHook"] + + +class CheckpointHook(BaseModel): + """Periodically save restartable training strategy checkpoints. + + The hook observes completed training counters and saves + :class:`~nvalchemi.training.strategy.TrainingStrategy` checkpoints through + the same manifest layout as :func:`nvalchemi.training.save_checkpoint`. + It fires either every ``step_interval`` completed optimizer steps or every + ``epoch_interval`` completed epochs. The two cadences are mutually + exclusive so each hook owns one clear checkpoint policy. + + With ``async_save=True`` (default), the hook first captures an immutable + CPU snapshot of model, optimizer, scheduler, and strategy metadata on the + training thread, then writes that snapshot on a single background thread. + This avoids racing against live training tensors while still moving the + filesystem work off the critical path. If a later checkpoint is due while + the previous background write is still running, the hook waits for the + previous write before capturing the next snapshot so manifest indices stay + ordered. + + Parameters + ---------- + checkpoint_dir : Path | str + Directory where checkpoint manifests and component state files are + written. + step_interval : int | None, optional + Save every N completed optimizer steps. Skipped optimizer steps do not + advance this cadence. Exactly one of ``step_interval`` or + ``epoch_interval`` must be provided. 
+ epoch_interval : int | None, optional + Save every N completed epochs. Exactly one of ``step_interval`` or + ``epoch_interval`` must be provided. + async_save : bool, optional + If ``True``, write captured snapshots on a background thread. If + ``False``, write synchronously during hook dispatch. Default ``True``. + rank_zero_only : bool, optional + If ``True``, only distributed rank 0 writes checkpoints. Default + ``True``. + + Attributes + ---------- + last_checkpoint_index : int | None + Most recent checkpoint index known to have been written. In async mode, + this updates when the background future completes. + + Raises + ------ + ValueError + If neither interval is provided, or an interval is not positive. + RuntimeError + If the hook is called without a strategy workflow in ``TrainContext``. + + Examples + -------- + >>> from nvalchemi.training import CheckpointHook, TrainingStrategy + >>> hook = CheckpointHook("runs/example/checkpoints", step_interval=1000) + >>> strategy = TrainingStrategy(..., hooks=[hook]) # doctest: +SKIP + >>> strategy.run(train_loader) # doctest: +SKIP + """ + + checkpoint_dir: Annotated[ + Path, + Field(description="Root directory for restartable training checkpoints."), + ] + step_interval: Annotated[ + int | None, + Field(default=None, gt=0, description="Completed-step save interval."), + ] = None + epoch_interval: Annotated[ + int | None, + Field(default=None, gt=0, description="Completed-epoch save interval."), + ] = None + async_save: Annotated[ + bool, + Field(description="Write checkpoint snapshots on a background thread."), + ] = True + rank_zero_only: Annotated[ + bool, + Field(description="Restrict checkpoint writes to distributed rank 0."), + ] = True + last_checkpoint_index: Annotated[ + int | None, + Field(default=None, ge=0, exclude=True), + ] = None + + frequency: ClassVar[int] = 1 + stage: ClassVar[TrainingStage | None] = None + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + _executor: 
ThreadPoolExecutor | None = PrivateAttr(default=None) + _future: Future[int] | None = PrivateAttr(default=None) + + def __init__( + self, checkpoint_dir: Path | str | None = None, **data: object + ) -> None: + """Initialize the hook, accepting ``checkpoint_dir`` positionally.""" + if checkpoint_dir is not None: + if "checkpoint_dir" in data: + raise TypeError( + "CheckpointHook got checkpoint_dir both positionally and " + "as a keyword argument." + ) + data["checkpoint_dir"] = checkpoint_dir + super().__init__(**data) + + @model_validator(mode="after") + def _validate_cadence(self) -> CheckpointHook: + """Require exactly one save cadence.""" + if self.epoch_interval and self.step_interval: + raise ValueError( + "CheckpointHook requires exactly one of step_interval or " + "epoch_interval." + ) + return self + + def _runs_on_stage(self, stage: TrainingStage) -> bool: + """Return whether this hook observes a training stage.""" + return ( + self.step_interval is not None and stage is TrainingStage.AFTER_BATCH + ) or (self.epoch_interval is not None and stage is TrainingStage.AFTER_EPOCH) + + def __enter__(self) -> CheckpointHook: + """Create the background writer when async checkpointing is enabled.""" + if self.async_save and self._executor is None: + self._executor = ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="nvalchemi-checkpoint", + ) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + """Flush any pending checkpoint write before leaving training.""" + del exc, tb + try: + self.close() + except Exception: + if exc_type is None: + raise + + def close(self) -> None: + """Wait for pending async writes and close the background writer.""" + try: + self._finish_pending(block=True) + finally: + if self._executor is not None: + self._executor.shutdown(wait=True) + self._executor = None + + def _finish_pending(self, *, block: bool) -> None: + """Collect a 
pending async result, optionally waiting for it.""" + if self._future is None: + return + if not block and not self._future.done(): + return + self.last_checkpoint_index = self._future.result() + self._future = None + + def _should_save(self, ctx: TrainContext, stage: TrainingStage) -> bool: + """Return whether ``ctx`` reaches the configured save cadence.""" + if self.rank_zero_only and ctx.global_rank != 0: + return False + if ( + stage is TrainingStage.AFTER_BATCH + and self.step_interval is not None + and ctx.step_count > 0 + ): + return ctx.step_count % self.step_interval == 0 + if ( + stage is TrainingStage.AFTER_EPOCH + and self.epoch_interval is not None + and ctx.epoch > 0 + ): + return ctx.epoch % self.epoch_interval == 0 + return False + + def _save_checkpoint(self, ctx: TrainContext) -> None: + """Capture and write one strategy checkpoint.""" + if ctx.workflow is None: + raise RuntimeError( + "CheckpointHook requires TrainContext.workflow to reference " + "the active TrainingStrategy." + ) + self._finish_pending(block=False) + if self._future is not None: + self._finish_pending(block=True) + + snapshot = _create_checkpoint_snapshot( + self.checkpoint_dir, + strategy=ctx.workflow, + ) + if not self.async_save: + self.last_checkpoint_index = _write_checkpoint_snapshot( + self.checkpoint_dir, + snapshot, + ) + return + + if self._executor is None: + raise RuntimeError( + "CheckpointHook async writer is not initialized. Run it through " + "TrainingStrategy so hook contexts are entered, or call " + "__enter__() before invoking the hook directly." 
+ ) + self._future = self._executor.submit( + _write_checkpoint_snapshot, + self.checkpoint_dir, + snapshot, + ) + + def __call__(self, ctx: TrainContext, stage: TrainingStage) -> None: + """Save a checkpoint when the configured cadence is reached.""" + if self._should_save(ctx, stage): + self._save_checkpoint(ctx) diff --git a/nvalchemi/training/hooks/ddp.py b/nvalchemi/training/hooks/ddp.py new file mode 100644 index 00000000..f6b18b75 --- /dev/null +++ b/nvalchemi/training/hooks/ddp.py @@ -0,0 +1,432 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""DistributedDataParallel setup hook for training strategies.""" + +from __future__ import annotations + +from collections.abc import Callable +from inspect import Parameter, signature +from typing import TYPE_CHECKING, Any, ClassVar + +import torch +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from torch.utils.data import DistributedSampler, RandomSampler + +from nvalchemi.data.datapipes.samplers import DistributedSamplerProtocol +from nvalchemi.hooks._context import TrainContext +from nvalchemi.training._stages import TrainingStage +from nvalchemi.training.distributed import ( + destroy_distributed, + distributed_device, + get_rank, + get_world_size, + init_distributed, + is_distributed_initialized, +) + +if TYPE_CHECKING: + from collections.abc import Iterable + + from nvalchemi.data.batch import Batch + from nvalchemi.distributed import DistributedManager + from nvalchemi.training.strategy import TrainingStrategy + +__all__ = ["DDPHook"] + + +def _manager_process_group(manager: DistributedManager | None) -> Any: + """Return a process group exposed by a structural manager, if any.""" + if manager is None: + return None + for name in ("process_group", "group", "get_process_group"): + if not hasattr(manager, name): + continue + value = getattr(manager, name) + if callable(value): + try: + return value() + except TypeError: + continue + return value + return None + + +def _sampler_is_distributed( + sampler: Any, sampler_cls: Callable[..., Any] = DistributedSampler +) -> bool: + """Return whether ``sampler`` is already a configured distributed sampler.""" + if isinstance(sampler, DistributedSamplerProtocol): + return True + return isinstance(sampler_cls, type) and isinstance(sampler, sampler_cls) + + +def _accepts_distributed_sampler_defaults(sampler_cls: Callable[..., Any]) -> bool: + """Return whether a sampler factory accepts PyTorch distributed kwargs.""" + if sampler_cls is DistributedSampler or ( + isinstance(sampler_cls, type) and 
issubclass(sampler_cls, DistributedSampler) + ): + return True + try: + parameters = signature(sampler_cls).parameters + except (TypeError, ValueError): + return False + if any( + parameter.kind is Parameter.VAR_KEYWORD for parameter in parameters.values() + ): + return True + return {"num_replicas", "rank"}.issubset(parameters) + + +def _infer_shuffle(dataloader: Any, configured: bool | None) -> bool: + """Infer sampler shuffling from the original dataloader when unspecified.""" + if configured is not None: + return configured + return isinstance(getattr(dataloader, "sampler", None), RandomSampler) + + +class DDPHook(BaseModel): + """Wrap training models with ``DistributedDataParallel`` at setup time. + + ``DDPHook`` is a standard training hook that runs at + :attr:`~nvalchemi.training.TrainingStage.SETUP`. It initializes + ``torch.distributed`` from torchrun environment variables when needed, + optionally uses ``TrainingStrategy.distributed_manager`` for rank/device + metadata, wraps selected models in + :class:`torch.nn.parallel.DistributedDataParallel`, and injects the + configured distributed sampler into dataloaders with ``dataset`` and + ``sampler`` attributes. + + Parameters + ---------- + model_keys : tuple[str, ...] | None, optional + Named models to wrap. ``None`` wraps all models that have optimizer + configs. + find_unused_parameters : bool | None, optional + Forwarded to ``DistributedDataParallel``. ``None`` uses the external + manager's setting when present, otherwise ``False``. + broadcast_buffers : bool | None, optional + Forwarded to ``DistributedDataParallel``. ``None`` uses the external + manager's setting when present, otherwise ``False``. + static_graph : bool, optional + Forwarded to ``DistributedDataParallel``. + process_group : Any, optional + Explicit process group. Defaults to a process group exposed by the + external distributed manager or PyTorch's default group. 
+ backend : str | None, optional + Backend used when this hook initializes ``torch.distributed``. + auto_init : bool, optional + If ``True``, initialize ``torch.distributed`` when ``WORLD_SIZE > 1`` + and no manager/process group has already initialized communication. + sampler_cls : Callable[..., Any], optional + Sampler class or factory used for supported dataloaders. The callable is + invoked as ``sampler_cls(dataset, **sampler_kwargs)``. The default is + :class:`torch.utils.data.DistributedSampler`. + sampler_kwargs : dict[str, Any], optional + Keyword arguments forwarded to ``sampler_cls``. For the default + ``DistributedSampler`` and sampler callables that accept PyTorch's + distributed sampler keywords, missing ``num_replicas``, ``rank``, + ``shuffle``, ``seed``, and ``drop_last`` values are inferred from the + manager and dataloader before user-provided kwargs are applied. + """ + + model_keys: tuple[str, ...] | None = None + find_unused_parameters: bool | None = None + broadcast_buffers: bool | None = None + static_graph: bool = False + process_group: Any | None = None + backend: str | None = None + auto_init: bool = True + sampler_cls: Callable[..., Any] = DistributedSampler + sampler_kwargs: dict[str, Any] = Field(default_factory=dict) + + frequency: ClassVar[int] = 1 + stage: ClassVar[TrainingStage] = TrainingStage.SETUP + + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_assignment=False, + extra="forbid", + ) + + _original_models: dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict) + _initialized_process_group: bool = PrivateAttr(default=False) + _manager: DistributedManager | None = PrivateAttr(default=None) + _strategy: Any | None = PrivateAttr(default=None) + _is_wrapped: bool = PrivateAttr(default=False) + + def prepare_strategy(self, strategy: TrainingStrategy) -> None: + """Prepare rank/device state before the strategy moves models.""" + manager = strategy.distributed_manager + self._manager = manager + if 
self.auto_init: + self._initialized_process_group = init_distributed( + manager, + backend=self.backend, + ) + world_size = get_world_size(manager) + if world_size <= 1: + return + device = distributed_device( + manager, + strategy.devices[0], + prefer_cuda=self.backend != "gloo", + ) + if device.type == "cuda": + torch.cuda.set_device(device) + strategy.devices = [device] + + def __call__(self, ctx: TrainContext, stage: TrainingStage) -> None: + """Run DDP setup when the strategy dispatches ``TrainingStage.SETUP``.""" + if stage is not TrainingStage.SETUP: + return + strategy = ctx.workflow + if strategy is None: + raise RuntimeError("DDPHook requires a TrainContext.workflow.") + self._wrap_models(strategy) + strategy.active_dataloader = self.prepare_dataloader(strategy.active_dataloader) + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any, + ) -> None: + """Restore original models and clean up process groups owned by this hook.""" + self.close() + + def close(self) -> None: + """Restore wrapped models and destroy process group if this hook created it.""" + if self._original_models: + strategy = self._strategy + for key, model in self._original_models.items(): + if strategy is not None: + strategy.models[key] = model + self._original_models.clear() + self._strategy = None + self._is_wrapped = False + if self._initialized_process_group: + destroy_distributed(self._manager) + self._initialized_process_group = False + + def _target_model_keys(self, strategy: TrainingStrategy) -> tuple[str, ...]: + """Return model keys this hook should wrap.""" + if self.model_keys is not None: + keys = self.model_keys + else: + keys = tuple(strategy.optimizer_configs) + missing = [key for key in keys if key not in strategy.models] + if missing: + raise KeyError( + f"DDPHook model_keys include unknown model(s) {missing}; " + f"available model keys: {sorted(strategy.models)}." 
+ ) + return keys + + def _wrap_models(self, strategy: TrainingStrategy) -> None: + """Wrap selected strategy models in DistributedDataParallel.""" + if self._is_wrapped: + return + manager = strategy.distributed_manager + world_size = get_world_size(manager) + initialized = is_distributed_initialized(manager) + if world_size <= 1: + return + if not initialized: + raise RuntimeError( + "DDPHook requires initialized distributed communication when " + "world_size > 1. Launch with torchrun, initialize " + "torch.distributed before strategy.run(), or provide an " + "initialized distributed_manager." + ) + + process_group = self.process_group or _manager_process_group(manager) + self._strategy = strategy + for key in self._target_model_keys(strategy): + model = strategy.models[key] + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + continue + self._original_models[key] = model + strategy.models[key] = self._build_ddp(model, process_group) + self._is_wrapped = True + + def _build_ddp( + self, + model: torch.nn.Module, + process_group: Any | None, + ) -> torch.nn.parallel.DistributedDataParallel: + """Construct a DDP wrapper for ``model``.""" + kwargs: dict[str, Any] = { + "find_unused_parameters": self._resolve_ddp_flag( + "find_unused_parameters", + default=False, + ), + "broadcast_buffers": self._resolve_ddp_flag( + "broadcast_buffers", + default=False, + ), + "static_graph": self.static_graph, + } + if process_group is not None: + kwargs["process_group"] = process_group + device = next(model.parameters()).device + if device.type == "cuda": + device_index = 0 if device.index is None else device.index + kwargs["device_ids"] = [device_index] + kwargs["output_device"] = device_index + return torch.nn.parallel.DistributedDataParallel(model, **kwargs) + + def _resolve_ddp_flag(self, name: str, *, default: bool) -> bool: + """Resolve a DDP boolean option from hook field, manager, or default.""" + value = getattr(self, name) + if value is not None: + return 
bool(value) + if self._manager is not None and hasattr(self._manager, name): + return bool(getattr(self._manager, name)) + return default + + def prepare_dataloader( + self, + dataloader: Iterable[Batch] | None, + ) -> Iterable[Batch] | None: + """Inject the configured sampler into dataloaders that expose one.""" + if dataloader is None: + return None + manager = self._manager + world_size = get_world_size(manager) + if world_size <= 1: + return dataloader + if not hasattr(dataloader, "sampler"): + return dataloader + if not hasattr(dataloader, "dataset"): + raise ValueError( + "DDPHook cannot inject a distributed sampler into a dataloader " + "with no dataset attribute." + ) + + sampler = getattr(dataloader, "sampler", None) + if _sampler_is_distributed(sampler, self.sampler_cls): + return dataloader + nested_sampler = getattr( + getattr(dataloader, "batch_sampler", None), "sampler", None + ) + if _sampler_is_distributed(nested_sampler, self.sampler_cls): + return dataloader + + drop_last = self._dataloader_drop_last(dataloader) + sampler = self._build_sampler(dataloader, drop_last=drop_last) + if self._assign_dataloader_sampler(dataloader, sampler): + return dataloader + return self._rebuild_dataloader_with_sampler( + dataloader, + sampler, + drop_last=drop_last, + ) + + def _uses_distributed_sampler_defaults(self) -> bool: + """Return whether sampler construction should apply torch defaults.""" + return _accepts_distributed_sampler_defaults(self.sampler_cls) + + def _build_sampler_kwargs( + self, dataloader: Any, *, drop_last: bool + ) -> dict[str, Any]: + """Return kwargs for the configured sampler class or factory.""" + kwargs: dict[str, Any] = {} + if self._uses_distributed_sampler_defaults(): + manager = self._manager + configured_shuffle = self.sampler_kwargs.get("shuffle") + kwargs.update( + { + "num_replicas": get_world_size(manager), + "rank": get_rank(manager), + "shuffle": _infer_shuffle(dataloader, configured_shuffle), + "seed": 0, + "drop_last": 
drop_last, + } + ) + kwargs.update(self.sampler_kwargs) + return kwargs + + def _build_sampler(self, dataloader: Any, *, drop_last: bool) -> Any: + """Create the configured distributed sampler for ``dataloader``.""" + return self.sampler_cls( + dataloader.dataset, + **self._build_sampler_kwargs(dataloader, drop_last=drop_last), + ) + + def _dataloader_drop_last(self, dataloader: Any) -> bool: + """Infer whether the dataloader drops incomplete batches.""" + batch_sampler = getattr(dataloader, "batch_sampler", None) + if hasattr(batch_sampler, "drop_last"): + return bool(batch_sampler.drop_last) + return bool(getattr(dataloader, "drop_last", False)) + + def _assign_dataloader_sampler(self, dataloader: Any, sampler: Any) -> bool: + """Try to assign ``sampler`` directly to ``dataloader.sampler``.""" + try: + dataloader.sampler = sampler + except (AttributeError, ValueError): + return False + return getattr(dataloader, "sampler", None) is sampler + + def _rebuild_dataloader_with_sampler( + self, + dataloader: Any, + sampler: Any, + *, + drop_last: bool, + ) -> Any: + """Return a replacement dataloader when the sampler attribute is immutable.""" + if getattr(dataloader, "batch_size", None) is None: + raise ValueError( + "DDPHook cannot inject DistributedSampler into a DataLoader " + "constructed with batch_sampler. Pass a distributed-aware " + "batch_sampler instead." 
+ ) + kwargs: dict[str, Any] = { + "batch_size": dataloader.batch_size, + "sampler": sampler, + "drop_last": drop_last, + } + for name in ( + "num_workers", + "collate_fn", + "pin_memory", + "timeout", + "worker_init_fn", + "generator", + "persistent_workers", + ): + if hasattr(dataloader, name): + kwargs[name] = getattr(dataloader, name) + if hasattr(dataloader, "multiprocessing_context"): + multiprocessing_context = getattr(dataloader, "multiprocessing_context") + if multiprocessing_context is not None: + kwargs["multiprocessing_context"] = multiprocessing_context + if getattr(dataloader, "num_workers", 0) > 0: + prefetch_factor = getattr(dataloader, "prefetch_factor", None) + if prefetch_factor is not None: + kwargs["prefetch_factor"] = prefetch_factor + pin_memory_device = getattr(dataloader, "pin_memory_device", "") + if pin_memory_device: + kwargs["pin_memory_device"] = pin_memory_device + if hasattr(dataloader, "in_order"): + kwargs["in_order"] = dataloader.in_order + try: + return type(dataloader)(dataloader.dataset, **kwargs) + except TypeError as exc: + raise ValueError( + "DDPHook could not assign dataloader.sampler and could not " + "rebuild the dataloader with the configured sampler." + ) from exc diff --git a/nvalchemi/training/hooks/ema.py b/nvalchemi/training/hooks/ema.py new file mode 100644 index 00000000..d5fa02bb --- /dev/null +++ b/nvalchemi/training/hooks/ema.py @@ -0,0 +1,372 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Exponential-moving-average (EMA) training hook.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Annotated, Any, ClassVar + +import torch +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, StringConstraints +from torch import nn +from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn + +from nvalchemi.training._stages import TrainingStage +from nvalchemi.training.hooks.update import TrainingUpdateHook + +if TYPE_CHECKING: + import torch + + from nvalchemi.hooks._context import TrainContext + + +__all__ = ["EMAHook"] + + +def _unwrap_model(m: nn.Module) -> nn.Module: + """Returns a nested module if it exists, otherwise no-op""" + return m.module if hasattr(m, "module") else m + + +def _module_tensors(module: nn.Module) -> dict[str, torch.Tensor]: + """Return registered parameters and buffers by name.""" + tensors = { + name: param + for name, param in module.named_parameters(recurse=True, remove_duplicate=False) + } + tensors.update( + { + name: buffer + for name, buffer in module.named_buffers( + recurse=True, remove_duplicate=False + ) + } + ) + return tensors + + +def _align_tensor_to_source(tensor: torch.Tensor, source: torch.Tensor) -> None: + """Align a registered tensor to the source tensor's device and dtype.""" + dtype = source.dtype if tensor.is_floating_point() else tensor.dtype + if tensor.device == source.device and tensor.dtype == dtype: + return + with torch.no_grad(): + tensor.data = tensor.data.to(device=source.device, dtype=dtype) + if 
tensor.grad is not None: + tensor.grad.data = tensor.grad.data.to(device=source.device, dtype=dtype) + + +def _align_to_source_tensors( + target: nn.Module, source_tensors: Mapping[str, torch.Tensor] +) -> None: + """Align target parameters and buffers to their corresponding source tensors.""" + for name, param in target.named_parameters(recurse=True, remove_duplicate=False): + if name in source_tensors: + _align_tensor_to_source(param, source_tensors[name]) + for name, buffer in target.named_buffers(recurse=True, remove_duplicate=False): + if name in source_tensors: + _align_tensor_to_source(buffer, source_tensors[name]) + + +class EMAHook(BaseModel, TrainingUpdateHook): + """Hook maintaining an exponential moving average of a training model. + + Runs through :class:`~nvalchemi.training.hooks.TrainingUpdateOrchestrator` + and updates at :attr:`TrainingStage.AFTER_OPTIMIZER_STEP`. It lazily builds a + :class:`~torch.optim.swa_utils.AveragedModel` wrapped around + ``ctx.models[model_key]`` on the first eligible step, and updates it + via :func:`~torch.optim.swa_utils.get_ema_multi_avg_fn` — no manual + parameter arithmetic. The hook is a pure observer: it never calls + ``backward()``, touches gradients, drives any optimizer / scheduler / + ``GradScaler``, or mutates ``ctx.models``. If an earlier update hook + vetoes :attr:`TrainingStage.DO_OPTIMIZER_STEP`, the orchestrator passes + ``will_skip=True`` and EMA does not update on that batch. + + Access the averaged wrapper via :meth:`get_averaged_model`, which raises + a :class:`RuntimeError` if no eligible step has yet triggered lazy + initialization. A ``device``/``dtype`` field is omitted by design; after + :class:`~torch.optim.swa_utils.AveragedModel` deep-copies the source, + EMAHook aligns each averaged parameter and buffer to the corresponding + source tensor's device and floating-point dtype. 
This keeps generated or + monkey-patched modules whose deepcopy/load path materializes registered + tensors on CPU or in a default dtype usable without model-specific hooks. + + Parameters + ---------- + model_key : str, optional + Key identifying the source model inside ``ctx.models``. Default ``"main"``. + decay : float, optional + EMA decay factor in ``[0.0, 1.0)``. Default ``0.999``. + update_every : int, optional + Positive step stride for averaging updates. Default ``1``. + start_step : int, optional + Non-negative minimum completed step before updates begin. Default ``0``. + use_buffers : bool, optional + Forwarded to :class:`AveragedModel`; when ``True`` also averages + module buffers. Default ``True``. + + Raises + ------ + pydantic.ValidationError + If any field violates its declared bounds or an unknown kwarg is passed. + KeyError + On first eligible call, if ``model_key`` is missing from ``ctx.models``. + RuntimeError + From :meth:`get_averaged_model` when called before lazy init. + + See Also + -------- + torch.optim.swa_utils.AveragedModel : Underlying averaging wrapper. + torch.optim.swa_utils.get_ema_multi_avg_fn : Factory for the EMA averaging function. + + Examples + -------- + Checkpoint recipe for **inference / eval reload** of the EMA-averaged + weights. Save ``hook.get_averaged_model().module`` alongside the base + model and rebuild the :class:`~torch.optim.swa_utils.AveragedModel` + wrapper after loading, because + :func:`~nvalchemi.training.create_model_spec` only reconstructs plain + :class:`~torch.nn.Module` objects: + + >>> from torch import nn # doctest: +SKIP + >>> from torch.optim.swa_utils import AveragedModel # doctest: +SKIP + >>> from nvalchemi.training import ( # doctest: +SKIP + ... EMAHook, create_model_spec, load_checkpoint, save_checkpoint, + ... ) + >>> base = nn.Linear(4, 2) # doctest: +SKIP + >>> hook = EMAHook(model_key="main", decay=0.99) # doctest: +SKIP + >>> # ... training loop drives `hook` via TrainingStrategy ... 
+ >>> spec = create_model_spec(nn.Linear, in_features=4, out_features=2) # doctest: +SKIP + >>> save_checkpoint( # doctest: +SKIP + ... "ckpt/", + ... models={ + ... "main": (base, spec), + ... "main_ema": (hook.get_averaged_model().module, spec), + ... }, + ... ) + >>> loaded = load_checkpoint("ckpt/") # doctest: +SKIP + >>> reconstructed_ema = AveragedModel(loaded.models["main_ema"][0]) # doctest: +SKIP + + To **resume training with EMA continuing** from a checkpoint, use + :meth:`state_dict` / :meth:`load_state_dict`, which round-trip + ``num_updates`` and the averaged weights into a freshly constructed + hook. + + Notes + ----- + The default deepcopy-based construction does not support + ``fully_shard`` (FSDP2) / DTensor models; override + :meth:`_build_averaged_model` to supply a pre-built sharded copy. + """ + + model_key: Annotated[ + str, + StringConstraints(strip_whitespace=True, min_length=1), + Field(description="Key identifying the source model in ctx.models."), + ] = "main" + decay: Annotated[ + float, Field(ge=0.0, lt=1.0, description="EMA decay factor in [0.0, 1.0).") + ] = 0.999 + update_every: Annotated[ + int, + Field( + gt=0, + description="Completed-step interval between EMA updates (global-modulo).", + ), + ] = 1 + start_step: Annotated[ + int, Field(ge=0, description="First completed step eligible for EMA updates.") + ] = 0 + use_buffers: Annotated[ + bool, + Field( + description="If True, also average module buffers (e.g. BN running stats)." + ), + ] = True + num_updates: Annotated[ + int, + Field( + ge=0, + description="Number of EMA updates performed; restored from checkpoints.", + ), + ] = 0 + + # Runs after lower-priority update hooks have made step/veto decisions. 
+ priority: ClassVar[int] = 50 + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + _averaged_model: AveragedModel | None = PrivateAttr(default=None) + _pending_averaged_state: dict[str, Any] | None = PrivateAttr(default=None) + + def _build_averaged_model(self, source: nn.Module) -> AveragedModel: + """Build the :class:`AveragedModel` wrapping ``source``. + + Override point: a caller that owns model sharding can return a + pre-built copy instead (the default deepcopy fails on a + ``fully_shard``-ed source). + """ + averaged = AveragedModel( + source, + multi_avg_fn=get_ema_multi_avg_fn(self.decay), + use_buffers=self.use_buffers, + ) + _align_to_source_tensors(averaged.module, _module_tensors(source)) + return averaged + + def _ensure_initialized(self, ctx: TrainContext) -> None: + if self._averaged_model is not None: + return + try: + source = ctx.models[self.model_key] + except KeyError as exc: + available = sorted(ctx.models.keys()) + raise KeyError( + f"EMAHook could not resolve model_key={self.model_key!r}; " + f"available keys in TrainContext.models: {available}" + ) from exc + + self._averaged_model = self._build_averaged_model(_unwrap_model(source)) + if self._pending_averaged_state is not None: + source_tensors = _module_tensors(_unwrap_model(source)) + self._averaged_model.load_state_dict(self._pending_averaged_state) + _align_to_source_tensors(self._averaged_model.module, source_tensors) + self._pending_averaged_state = None + + def _publish_averaged_model(self, ctx: TrainContext) -> None: + """Publish averaged weights into the strategy inference-model slot.""" + setter = getattr(ctx.workflow, "set_inference_model", None) + if setter is not None: + setter(self.get_averaged_model().module, model_key=self.model_key) + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool = False, + ) -> tuple[bool, torch.Tensor | None]: + """Initialize or update the averaged model at the relevant stages.""" + 
match stage: + case TrainingStage.SETUP: + # Build the EMA copy early so validation can use restored weights. + self._ensure_initialized(ctx) + self._publish_averaged_model(ctx) + case TrainingStage.AFTER_OPTIMIZER_STEP: + if will_skip: + return True, getattr(ctx, "loss", None) + completed_step = ctx.step_count + 1 + if ( + completed_step < self.start_step + or completed_step % self.update_every + ): + return True, getattr(ctx, "loss", None) + # Apply the actual EMA update only after an eligible optimizer step. + self._ensure_initialized(ctx) + source = ctx.models[self.model_key] + self.get_averaged_model().update_parameters(_unwrap_model(source)) + self.num_updates += 1 + self._publish_averaged_model(ctx) + case _: + # Other training stages do not affect EMA state. + pass + return True, getattr(ctx, "loss", None) + + def get_averaged_model(self) -> AveragedModel: + """Return the :class:`AveragedModel` wrapper or raise if uninitialized. + + Raises + ------ + RuntimeError + If neither setup nor an eligible training step has initialized EMA. + """ + if self._averaged_model is None: + raise RuntimeError( + "EMAHook has not initialized an averaged model yet. " + "The hook initializes during TrainingStage.SETUP or the first " + f"eligible AFTER_OPTIMIZER_STEP (start_step={self.start_step}, " + f"update_every={self.update_every})." + ) + return self._averaged_model + + def state_dict(self) -> dict[str, Any]: + """Return a serializable snapshot of hook state. + + Returns + ------- + dict[str, Any] + Contains the config fields, ``num_updates``, and — if + available — ``averaged_model_state`` sourced from the live + :class:`AveragedModel` or, before lazy init, from any + stashed pending state. No ``device`` key is emitted. 
def load_state_dict(self, state: Mapping[str, Any]) -> None:
    """Restore counters and averaged weights from a :meth:`state_dict` mapping.

    Parameters
    ----------
    state : Mapping[str, Any]
        Snapshot to restore. Config keys that are absent are skipped;
        config keys that are present must equal this hook's current
        value. ``num_updates`` is restored when present. When
        ``averaged_model_state`` is absent, both the live averaged model
        and any stashed state are cleared.

    Raises
    ------
    ValueError
        If a config field present in ``state`` differs from the value on
        this hook.

    Notes
    -----
    If no averaged model exists yet, ``averaged_model_state`` is only
    stashed; it is applied later by :meth:`_ensure_initialized`. When a
    live model exists, its module tensors are captured before loading and
    passed to ``_align_to_source_tensors`` afterwards.
    """
    for name in type(self).model_fields:
        # num_updates is a counter, not config; it is restored below.
        if name == "num_updates" or name not in state:
            continue
        current = getattr(self, name)
        if state[name] != current:
            raise ValueError(
                f"EMAHook checkpoint conflict: {name}={state[name]!r} vs "
                f"constructor {name}={current!r}; construct the hook "
                "with matching config or load into a fresh instance"
            )

    if "num_updates" in state:
        self.num_updates = int(state["num_updates"])

    if "averaged_model_state" not in state:
        # Config-only reload: drop averaged weights so stale state cannot survive.
        self._averaged_model = None
        self._pending_averaged_state = None
        return

    live = self._averaged_model
    if live is None:
        # Nothing to load into yet; keep the payload for lazy initialization.
        self._pending_averaged_state = state["averaged_model_state"]
        return

    placement = _module_tensors(live.module)
    live.load_state_dict(state["averaged_model_state"])
    _align_to_source_tensors(live.module, placement)
    self._pending_averaged_state = None
def _deserialize_precision(value: Any) -> Any:
    """Deserialize canonical dtype strings plus supported shorthand aliases.

    Parameters
    ----------
    value : Any
        Raw field input. Non-string values are returned unchanged so the
        downstream validators can handle them.

    Returns
    -------
    Any
        ``value`` itself when it is not a string; otherwise whatever
        ``_dtype_deserialize`` returns for the normalized name.

    Raises
    ------
    ValueError
        If ``_dtype_deserialize`` rejects the normalized name with
        ``TypeError`` or ``ValueError``.
    """
    if not isinstance(value, str):
        return value
    # Lower-case BEFORE stripping the prefix so the prefix match is
    # case-insensitive too ("Torch.Float16", "TORCH.FP16"); stripping first
    # only recognized an exactly lower-case "torch." prefix.
    normalized = value.lower().removeprefix("torch.")
    normalized = _PRECISION_ALIASES.get(normalized, normalized)
    try:
        return _dtype_deserialize(normalized)
    except (TypeError, ValueError) as exc:
        supported = _supported_precision_names()
        raise ValueError(
            f"MixedPrecisionHook.precision must be one of ({supported}); got {value!r}."
        ) from exc
+ + Precision modes: + + * :data:`torch.float32` — no autocast context or scaler is created; the + hook is a functional no-op aside from participating in the orchestrated + update path. + * :data:`torch.bfloat16` — autocast casts eligible ops to ``bfloat16``. + No gradient scaling because bf16's exponent range matches fp32. + * :data:`torch.float16` — autocast casts eligible ops to ``float16`` + during forward and loss computation. The scaler scales the loss before + the orchestrator calls ``backward()``, unscales gradients just before + optimizer stepping, + and skips optimizer steps that would otherwise consume ``inf``/``nan`` + gradients. + + Parameters + ---------- + precision : torch.dtype + Autocast dtype and scaler policy. Accepts either a + :class:`torch.dtype` (e.g. ``torch.float16``) or the canonical + string name (``"float32"``, ``"bfloat16"``, ``"float16"``), or a + shorthand alias (``"fp32"``, ``"bf16"``, ``"fp16"``). + + Attributes + ---------- + precision : torch.dtype + Active autocast dtype. + priority : int + Training-update priority. Fixed at ``20`` so loss-scaling runs + after gradient accumulation transforms and before gradient + clipping / spike-skip hooks. + + Raises + ------ + pydantic.ValidationError + If ``precision`` is not one of the supported dtypes. + + Examples + -------- + >>> import torch + >>> from nvalchemi.training.hooks import MixedPrecisionHook + >>> MixedPrecisionHook(precision=torch.bfloat16).precision + torch.bfloat16 + >>> MixedPrecisionHook(precision="float16").precision + torch.float16 + >>> MixedPrecisionHook(precision="bf16").precision + torch.bfloat16 + + Notes + ----- + * When multiple optimizers are configured, every optimizer in + ``ctx.optimizers`` is unscaled in list order immediately before + stepping. The orchestrator advances each scheduler in + ``ctx.lr_schedulers`` only when its paired optimizer step was not + skipped by the scaler. 
def __exit__(
    self,
    exc_type: type[BaseException] | None,
    exc: BaseException | None,
    tb: TracebackType | None,
) -> None:
    """Leave any active autocast region and drop the scaler.

    Parameters
    ----------
    exc_type : type[BaseException] | None
        Exception class raised inside the managed block, if any.
    exc : BaseException | None
        Exception instance raised inside the managed block, if any.
    tb : TracebackType | None
        Traceback associated with ``exc``, if any.
    """
    # Forward the exception triple untouched to the autocast teardown.
    exc_info = (exc_type, exc, tb)
    self._exit_autocast(*exc_info)
    # Unlike the per-batch autocast exit, leaving the hook's context also
    # discards the scaler so the instance can be reused from scratch.
    self._scaler = None
def _enter_autocast(self, ctx: TrainContext) -> None:
    """Enter the forward/loss autocast region for this batch.

    ``float32`` precision returns immediately without creating any
    context. For ``float16`` the scaler from :meth:`_ensure_scaler` is
    published on ``ctx.grad_scaler`` on every call. The autocast context
    is built for the device type of ``ctx.workflow.devices[0]`` and then
    entered.

    The call is idempotent while a region is already active:
    :meth:`_exit_autocast` exits the stored context exactly once, so
    entering the same context object a second time would leave autocast
    enabled after that single exit.

    Parameters
    ----------
    ctx : TrainContext
        Training context providing ``workflow.devices`` and receiving
        ``grad_scaler`` for ``float16``.
    """
    if self.precision == torch.float32:
        return
    if self.precision == torch.float16:
        ctx.grad_scaler = self._ensure_scaler(ctx)
    if self._active and self._autocast_ctx is not None:
        # A previous BEFORE_BATCH already opened the region and nothing has
        # closed it yet; re-entering would unbalance enter/exit.
        return
    device_type = ctx.workflow.devices[0].type
    if self._autocast_ctx is None:
        self._autocast_ctx = torch.amp.autocast(
            device_type=device_type,
            dtype=self.precision,
            enabled=True,
        )
    self._autocast_ctx.__enter__()
    self._active = True
def _hook_claims_stage(hook: Any, stage: TrainingStage) -> bool:
    """Return whether ``hook`` fires on ``stage``.

    A hook exposing ``_runs_on_stage`` is asked directly; any other hook
    is matched by comparing its ``stage`` attribute (``None`` when the
    attribute is missing) with ``stage``.
    """
    predicate = getattr(hook, "_runs_on_stage", None)
    if predicate is None:
        # No predicate: fall back to the plain single-stage attribute.
        return getattr(hook, "stage", None) == stage
    return predicate(stage)
def _check_veto(decision: object, hook: object, stage: TrainingStage) -> None:
    """Ensure a hook's ``proceed`` value is a real ``bool``.

    Parameters
    ----------
    decision : object
        First element of the tuple returned by the hook's ``__call__``.
    hook : object
        The hook that produced ``decision``; only its type name is used,
        for the error message.
    stage : TrainingStage
        Stage being dispatched; only its ``name`` is used, for the error
        message.

    Raises
    ------
    TypeError
        If ``decision`` is not an instance of ``bool``.
    """
    if isinstance(decision, bool):
        return
    hook_name = type(hook).__name__
    got = type(decision).__name__
    raise TypeError(
        f"{hook_name}.__call__(stage={stage.name}) must return "
        f"(bool, Tensor | None); proceed got {got}. "
        "Return True to proceed or False to skip."
    )
def _step_optimizers_with_context(ctx: TrainContext) -> bool:
    """Step optimizers/schedulers and return whether optimizer stepping ran.

    Three paths, selected from ``ctx``:

    1. ``ctx.grad_scaler`` is ``None``: optimizers and schedulers are
       stepped directly and ``True`` is returned.
    2. A scaler is present but there are no schedulers (empty, or all
       ``None``): every optimizer is stepped through the scaler, the
       scaler is updated, and the result is ``post_scale >= pre_scale``
       (``True`` when either scale could not be read).
    3. A scaler and at least one scheduler are present: each optimizer is
       stepped through the scaler and a per-optimizer "skipped" flag is
       collected. If no step was skipped, all schedulers are stepped and
       ``True`` is returned. Otherwise only schedulers paired (by index)
       with a non-skipped optimizer are stepped, and ``False`` is returned.

    Parameters
    ----------
    ctx : TrainContext
        Provides ``grad_scaler``, ``optimizers`` and ``lr_schedulers``.

    Returns
    -------
    bool
        ``True`` when the optimizer step is considered to have run,
        ``False`` when at least one step was detected as skipped.
    """
    # Path 1: no scaler, plain stepping.
    if ctx.grad_scaler is None:
        step_optimizers(ctx.optimizers)
        step_lr_schedulers(ctx.lr_schedulers)
        return True

    # Path 2: scaler without any scheduler to pair with.
    if not ctx.lr_schedulers or all(sched is None for sched in ctx.lr_schedulers):
        pre_scale = _get_scaler_scale(ctx.grad_scaler)
        for opt in ctx.optimizers:
            ctx.grad_scaler.step(opt)
        ctx.grad_scaler.update()
        post_scale = _get_scaler_scale(ctx.grad_scaler)
        # A scale that shrank across update() is read as "step was skipped";
        # presumably the scaler backs off after inf/nan gradients -- confirm
        # against the GradScaler in use.
        return (
            True if pre_scale is None or post_scale is None else post_scale >= pre_scale
        )

    # Path 3: scaler plus schedulers. Collect one flag per optimizer:
    # True/False when the scaler reports it, None when it cannot be determined.
    skipped_flags: list[bool | None] = []
    for opt in ctx.optimizers:
        ctx.grad_scaler.step(opt)
        skipped_flags.append(_grad_scaler_step_skipped(ctx.grad_scaler, opt))

    # Only pay for the scale comparison when some flag is unknown.
    need_fallback = any(flag is None for flag in skipped_flags)
    pre_scale = _get_scaler_scale(ctx.grad_scaler) if need_fallback else None
    ctx.grad_scaler.update()
    post_scale = _get_scaler_scale(ctx.grad_scaler) if need_fallback else None
    fallback_skipped = (
        need_fallback
        and pre_scale is not None
        and post_scale is not None
        and post_scale < pre_scale
    )
    # Pad schedulers with None so every optimizer flag has a partner.
    # NOTE(review): nothing trims the opposite case; with more schedulers than
    # optimizers the strict zip below would raise ValueError -- confirm that
    # configuration is rejected upstream.
    schedulers = list(ctx.lr_schedulers)
    if len(schedulers) < len(skipped_flags):
        schedulers.extend([None] * (len(skipped_flags) - len(schedulers)))
    # Unknown flags inherit the scale-based fallback verdict.
    step_skipped_flags = [
        skipped is True or (fallback_skipped and skipped is None)
        for skipped in skipped_flags
    ]
    if not any(step_skipped_flags):
        step_lr_schedulers(ctx.lr_schedulers)
        return True
    # At least one optimizer step was skipped: advance only the schedulers
    # whose paired optimizer actually stepped.
    for sched, step_skipped in zip(schedulers, step_skipped_flags, strict=True):
        if sched is None:
            continue
        if _is_metric_driven(sched):
            continue
        if step_skipped:
            continue
        sched.step()
    return False
It is ``True`` when + an earlier, higher-priority hook has already requested that the current + stage's gated operation be skipped. The orchestrator still calls later + hooks after a veto so they can observe the decision, update bookkeeping, or + emit diagnostics, but those hooks should avoid side effects that assume the + gated operation will run. A hook may also return ``False`` to veto the + operation for lower-priority hooks. + + This signal is intended for composable pipeline behavior. For example, a + gradient-accumulation hook can veto ``DO_OPTIMIZER_STEP`` on non-step + microbatches; later hooks then receive ``will_skip=True`` and can skip + work such as gradient clipping, scaler updates, or expensive parameter + scans. ``will_skip`` is reset for each stage dispatch and should not be + interpreted as a global training-step status unless the orchestrator also + records that state on ``ctx``. + + Each ``__call__`` returns ``(proceed, loss)``: + + - ``proceed`` is a strict ``bool`` (``int``/``None`` raise + ``TypeError``). On ``BEFORE_BATCH`` and ``DO_OPTIMIZER_STEP`` the + orchestrator applies any-veto-wins composition: if any hook returns + ``False`` the gated operation (``zero_gradients`` or + ``optimizer/scheduler.step``) is skipped. On ``DO_BACKWARD`` and + ``AFTER_OPTIMIZER_STEP`` the value is unused; return ``True``. + - ``loss`` is the loss tensor the hook would use, transformed or not. + Default is ``ctx.loss`` unchanged. The orchestrator threads it + through hooks in priority order during ``DO_BACKWARD`` so each hook + sees its predecessor's transform; ``backward()`` runs once on the + final loss. Hooks that run on stages other than ``DO_BACKWARD`` may + return ``None`` for ``loss`` because the orchestrator ignores it + there. + + Examples + -------- + >>> import torch + >>> from nvalchemi.training._stages import TrainingStage + >>> class ClipGrads(TrainingUpdateHook): + ... priority = 30 + ... def __init__(self, max_norm): + ... 
def __add__(
    self, other: TrainingUpdateHook | TrainingUpdateOrchestrator
) -> TrainingUpdateOrchestrator:
    """Combine this hook with ``other`` into a new orchestrator.

    Parameters
    ----------
    other : TrainingUpdateHook | TrainingUpdateOrchestrator
        Right-hand operand of ``+``.

    Returns
    -------
    TrainingUpdateOrchestrator
        A new orchestrator built from ``self`` and ``other``.
        ``NotImplemented`` is returned instead when ``other`` is neither
        an update hook nor an orchestrator, so Python can try the
        reflected operation.
    """
    composable = (TrainingUpdateHook, TrainingUpdateOrchestrator)
    if isinstance(other, composable):
        return TrainingUpdateOrchestrator(self, other)
    return NotImplemented
are + NOT directly Protocol-compliant on their own; they must be composed + into an orchestrator before registration. The training strategy + auto-wraps a bare :class:`TrainingUpdateHook` for convenience. + + On ``DO_BACKWARD`` each hook returns ``(_, loss)``; the orchestrator + assigns ``ctx.loss = loss`` between hooks so the next hook sees the + transformed value. ``backward()`` is called once on the final + ``ctx.loss``. Example: a ``*0.5`` hook followed by a ``*2.0`` hook + leaves ``ctx.loss`` equal to the original loss before backward. + """ + + frequency: int = 1 + stage = None + + def __init__(self, *hooks: TrainingUpdateHook | TrainingUpdateOrchestrator) -> None: + flattened: list[TrainingUpdateHook] = [] + for i, h in enumerate(hooks): + if isinstance(h, TrainingUpdateOrchestrator): + flattened.extend(h._hooks) + elif isinstance(h, TrainingUpdateHook): + flattened.append(h) + else: + raise TypeError( + f"argument {i} must be TrainingUpdateHook or " + f"TrainingUpdateOrchestrator; got {type(h).__name__}. " + "If you have an iterable, call " + "TrainingUpdateOrchestrator(*hooks)." + ) + flattened.sort(key=lambda h: h.priority) + exclusive_hooks: dict[str, TrainingUpdateHook] = {} + for hook in flattened: + key = hook._exclusive_update_key + if key is None: + continue + if key in exclusive_hooks: + first = type(exclusive_hooks[key]).__name__ + second = type(hook).__name__ + raise ValueError( + f"Only one update hook with exclusive key {key!r} may be " + f"registered; got {first} and {second}." 
def __call__(self, ctx: TrainContext, stage: TrainingStage) -> None:
    """Dispatch ``stage`` to the child hooks and run the gated operation.

    ``SETUP`` and ``AFTER_OPTIMIZER_STEP`` only notify the children.
    ``BEFORE_BATCH`` zeroes gradients unless a child vetoes.
    ``DO_BACKWARD`` threads the loss through every child, storing each
    result on ``ctx.loss``, then calls ``backward()`` once.
    ``DO_OPTIMIZER_STEP`` steps optimizers/schedulers unless vetoed and
    records whether the step was skipped. Any other stage is ignored.

    Parameters
    ----------
    ctx : TrainContext
        Mutable training context shared with the child hooks.
    stage : TrainingStage
        Stage currently being dispatched.
    """
    if stage == TrainingStage.SETUP:
        for child in self._hooks:
            child(ctx, stage, False)
    elif stage == TrainingStage.BEFORE_BATCH:
        # Gradient accumulation is the typical reason to veto here; without
        # a veto gradients are cleared for the new batch.
        if self._should_run_gated_stage(ctx, stage):
            zero_gradients(ctx.optimizers)
    elif stage == TrainingStage.DO_BACKWARD:
        for child in self._hooks:
            _, transformed = child(ctx, stage, False)
            # Publish each transform so the next child sees it on ctx.loss.
            ctx.loss = _require_loss(transformed, child, stage)
        _require_loss(ctx.loss, self, stage).backward()
    elif stage == TrainingStage.DO_OPTIMIZER_STEP:
        # Vetoes typically come from gradient accumulation or spike skipping.
        stepped = self._should_run_gated_stage(ctx, stage)
        if stepped:
            stepped = _step_optimizers_with_context(ctx)
        self._optimizer_step_skipped = not stepped
    elif stage == TrainingStage.AFTER_OPTIMIZER_STEP:
        for child in self._hooks:
            child(ctx, stage, self._optimizer_step_skipped)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Loss-function abstractions, schedules, terms, and reductions. + +Loss terms are :class:`BaseLossFunction` instances that consume +prediction and target tensors directly and return raw (unweighted) loss +tensors. :class:`ComposedLossFunction` owns the per-component weighting +— either plain floats or :class:`LossWeightSchedule` instances — and, +by default, renormalizes the effective weights to sum to ``1.0``. +Operator sugar (``3.0 * EnergyMSELoss() + 2.0 * ForceMSELoss()``) builds a +composition in one expression. Schedule instances attached to a +composition's weights are reconstructed by ``TrainingStrategy`` from +their ``(instance, spec)`` pair, mirroring the pattern used for models +and optimizers. 
+""" + +from __future__ import annotations + +from nvalchemi.training.losses.base import LossWeightSchedule +from nvalchemi.training.losses.composition import ( + BaseLossFunction, + ComposedLossFunction, + ComposedLossOutput, + ReductionContext, + assert_same_shape, + loss_component_to_spec, +) +from nvalchemi.training.losses.reductions import ( + frobenius_mse, + per_graph_mean, + per_graph_sum, +) +from nvalchemi.training.losses.schedules import ( + ConstantWeight, + CosineWeight, + LinearWeight, + PiecewiseWeight, +) +from nvalchemi.training.losses.terms import ( + EnergyHuberLoss, + EnergyMAELoss, + EnergyMSELoss, + ForceHuberLoss, + ForceL2NormLoss, + ForceMSELoss, + StressHuberLoss, + StressMSELoss, +) + +__all__ = [ + "BaseLossFunction", + "ComposedLossFunction", + "ComposedLossOutput", + "ConstantWeight", + "CosineWeight", + "EnergyHuberLoss", + "EnergyMAELoss", + "EnergyMSELoss", + "ForceHuberLoss", + "ForceL2NormLoss", + "ForceMSELoss", + "LinearWeight", + "LossWeightSchedule", + "PiecewiseWeight", + "ReductionContext", + "StressHuberLoss", + "StressMSELoss", + "assert_same_shape", + "frobenius_mse", + "loss_component_to_spec", + "per_graph_mean", + "per_graph_sum", +] diff --git a/nvalchemi/training/losses/base.py b/nvalchemi/training/losses/base.py new file mode 100644 index 00000000..4433d1ba --- /dev/null +++ b/nvalchemi/training/losses/base.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Base protocols and models for loss-function schedules. + +This module also re-exports :class:`BaseLossFunction` and +:class:`ComposedLossFunction` from :mod:`.composition` for discoverability: +subclass authors can do ``from nvalchemi.training.losses.base import +BaseLossFunction`` without tracking the internal module layout. The +canonical home of the leaf base class, keyed composition aggregator, and +composition output type remains :mod:`.composition`. +""" + +from __future__ import annotations + +from typing import Annotated, Protocol, runtime_checkable + +from pydantic import BaseModel, Field + + +@runtime_checkable +class LossWeightSchedule(Protocol): + """Runtime-checkable protocol for loss-weight schedules. + + Any object callable with signature ``(step: int, epoch: int) -> float`` + and exposing a ``per_epoch`` attribute satisfies this protocol, and is + therefore accepted inside + :class:`~nvalchemi.training.losses.ComposedLossFunction`'s ``weights`` + sequence or as the left-hand side of ``schedule * leaf``. Concrete + Pydantic schedules live in + :mod:`~nvalchemi.training.losses.schedules`. + + Attributes + ---------- + per_epoch + If ``True``, the schedule should advance by ``epoch`` instead of + by ``step``. This aligns loss-weight updates with training loops + that update learning-rate schedules once per epoch. + + Parameters + ---------- + step + Current global training step (0-indexed). + epoch + Current epoch number (0-indexed). + + Returns + ------- + float + Scalar weight to apply to the associated loss term. + """ + + per_epoch: Annotated[ + bool, + "Whether the schedule steps per epoch; if False, schedule will update per step/batch.", + ] + + def __call__(self, step: int, epoch: int) -> float: + """Evaluate the schedule at ``(step, epoch)``.""" + ... 
+ + +class _BaseWeightSchedule(BaseModel): + """Base Pydantic model for serializable loss-weight schedules. + + Attributes + ---------- + per_epoch + If ``False``, schedule windows advance by global step. If + ``True``, they advance by epoch. + """ + + model_config = {"frozen": True} + + per_epoch: Annotated[ + bool, + Field( + default=False, + description=( + "Whether to advance this schedule by epoch instead of by global step." + ), + ), + ] = False + + def _map_schedule_index(self, step: int, epoch: int) -> int: + """Return the counter used to advance this schedule. + + Use this only when the schedule advances by exactly one counter, + chosen by ``per_epoch``; a schedule that reads both step *and* epoch + values does not need it, since this helper only routes between them. + """ + return epoch if self.per_epoch else step + + +# Re-exports for discoverability. Import at the bottom to avoid a circular +# import: ``composition`` imports ``_BaseWeightSchedule`` indirectly through +# ``schedules``, which imports this module. +from nvalchemi.training.losses.composition import ( # noqa: E402 + BaseLossFunction, + ComposedLossFunction, + ComposedLossOutput, + ReductionContext, +) + +__all__ = [ + "BaseLossFunction", + "ComposedLossFunction", + "ComposedLossOutput", + "LossWeightSchedule", + "ReductionContext", +] diff --git a/nvalchemi/training/losses/composition.py b/nvalchemi/training/losses/composition.py new file mode 100644 index 00000000..9a88ab3d --- /dev/null +++ b/nvalchemi/training/losses/composition.py @@ -0,0 +1,1138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Composable :class:`torch.nn.Module`-based loss-function abstractions. + +Leaf loss terms are tensor-to-tensor :class:`BaseLossFunction` instances +whose :meth:`~BaseLossFunction.forward` returns the raw, unweighted loss +tensor. :class:`ComposedLossFunction` owns the per-component weighting +(either floats or :class:`LossWeightSchedule` instances) and, by default, +normalizes the resolved weights so they sum to ``1.0`` at every call. +This keeps weight scheduling a *relative* knob and leaves the learning +rate as the sole *absolute* magnitude control. +""" + +from __future__ import annotations + +import abc +import math +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Any, TypedDict, cast + +import torch +from torch import nn + +from nvalchemi._serialization import _extract_init_kwargs_from_attrs +from nvalchemi.training._spec import BaseSpec, create_model_spec +from nvalchemi.training.losses.base import LossWeightSchedule + + +class ComposedLossOutput(TypedDict): + """Output returned by :class:`ComposedLossFunction`. + + This is solely used as a type hint, and not as a concrete data + structure; it's used to signal to users that the emitted dict + from composed losses will always at least contain the keys within + this ``TypedDict``. + + The mapping always contains ``total_loss`` and four per-component + sub-mappings keyed by component name. ``per_component_unweighted`` + holds each raw component loss before multiplication by its effective + weight. 
``per_component_weight`` holds the effective (possibly + normalized) weight actually applied to each component at this call; + ``per_component_raw_weight`` holds the pre-normalization resolved + weight — identical to ``per_component_weight`` when + ``normalize_weights=False`` and useful for logging the underlying + schedule value regardless of normalization. ``per_component_sample`` + carries per-component **weighted** per-sample loss tensors of shape + ``(B,)``, detached; see :attr:`BaseLossFunction.per_sample_loss` for + the per-leaf populate-or-skip contract. + """ + + total_loss: torch.Tensor + per_component_unweighted: dict[str, torch.Tensor] + per_component_weight: dict[str, float] + per_component_raw_weight: dict[str, float] + per_component_sample: dict[str, torch.Tensor] + + +def loss_component_to_spec(component: BaseLossFunction) -> BaseSpec: + """Serialize a leaf loss component to a :class:`BaseSpec`. + + Parameters + ---------- + component : BaseLossFunction + Loss component to serialize. Constructor attributes are recovered by + signature introspection, and nested weight schedules are serialized as + nested specs when present. + + Returns + ------- + BaseSpec + JSON-ready spec that rebuilds ``component``. + + Raises + ------ + TypeError + If ``component`` is a composed loss or is not a leaf + :class:`BaseLossFunction`. + """ + if isinstance(component, ComposedLossFunction): + raise TypeError( + "loss_component_to_spec accepts only leaf BaseLossFunction objects; " + "use ComposedLossFunction spec serialization for composed losses." + ) + if not isinstance(component, BaseLossFunction): + raise TypeError( + "loss_component_to_spec accepts only leaf BaseLossFunction objects; " + f"got {type(component).__name__}." 
+ ) + kwargs = _extract_init_kwargs_from_attrs(component) + weight = kwargs.get("weight") + if weight is not None and hasattr(weight, "model_dump"): + kwargs["weight"] = create_model_spec(type(weight), **weight.model_dump()) + return create_model_spec(type(component), **kwargs) + + +def assert_same_shape( + pred: torch.Tensor, + target: torch.Tensor, + *, + name: str, + prediction_key: str | None = None, + target_key: str | None = None, + strict: bool = False, +) -> None: + """Raise :class:`ValueError` when ``pred`` and ``target`` are not compatible. + + Checks dtype equality first (a dtype mismatch is usually a bug + upstream of shape), then the shape compatibility policy selected by + ``strict``. + + Shape policy + ------------ + ``strict=False`` (default) accepts any pair of shapes that is + broadcast-compatible via :func:`torch.broadcast_shapes`. This is + convenient for custom losses that legitimately broadcast (e.g. a + per-graph scale against a per-component target) but is a trap for + elementwise losses: ``(B, 1)`` vs ``(B, 3)`` passes, and the + subsequent ``pred - target`` silently broadcasts into a ``(B, 3)`` + residual — usually not what you intend. + + ``strict=True`` requires ``pred.shape == target.shape`` exactly. All + built-in leaf losses (:class:`EnergyMSELoss`, :class:`ForceMSELoss`, + :class:`StressMSELoss`) pass ``strict=True`` because their elementwise + arithmetic would otherwise corrupt the scalar loss under a + broadcast-compatible-but-unequal pair. Custom + :class:`BaseLossFunction` subclasses that do elementwise arithmetic + should also pass ``strict=True``. + + Parameters + ---------- + pred : torch.Tensor + Prediction tensor. + target : torch.Tensor + Target tensor whose dtype must equal ``pred``'s and whose shape + must be compatible with ``pred``'s under the selected policy. + name : str + Calling loss-term's class name, used as a prefix in the error + message (typically ``type(self).__name__``). 
+ prediction_key : str, optional + Key the prediction tensor was pulled from in the composed + mapping. When provided, included in the error message. + target_key : str, optional + Key the target tensor was pulled from in the composed mapping. + When provided, included in the error message. + strict : bool, default False + When ``True``, require ``pred.shape == target.shape``. When + ``False``, only require broadcast compatibility. + + Raises + ------ + ValueError + If ``pred.dtype != target.dtype``, or if the shape policy is + violated (broadcast-incompatible for ``strict=False``, unequal + for ``strict=True``). + """ + pred_fragment = ( + f"prediction_key={prediction_key!r}" + if prediction_key is not None + else "prediction" + ) + target_fragment = ( + f"target_key={target_key!r}" if target_key is not None else "target" + ) + if pred.dtype != target.dtype: + raise ValueError( + f"{name}: prediction and target dtype mismatch; " + f"{pred_fragment} has dtype {pred.dtype}, " + f"{target_fragment} has dtype {target.dtype}." + ) + if strict: + if pred.shape != target.shape: + raise ValueError( + f"{name}: prediction and target shape must match exactly " + f"for elementwise loss; {pred_fragment} has shape " + f"{tuple(pred.shape)}, {target_fragment} has shape " + f"{tuple(target.shape)}." + ) + return + try: + torch.broadcast_shapes(pred.shape, target.shape) + except RuntimeError as exc: + raise ValueError( + f"{name}: prediction and target shape mismatch; " + f"{pred_fragment} has shape {tuple(pred.shape)}, " + f"{target_fragment} has shape {tuple(target.shape)}." + ) from exc + + +class ReductionContext(dict): + """Lightweight metadata bag flowing through the loss template pipeline. + + A plain ``dict`` subclass used to pass metadata between + :meth:`BaseLossFunction.normalize`, :meth:`~BaseLossFunction.mask`, + and :meth:`~BaseLossFunction.reduce`. 
Using a bare ``dict`` instead + of ``TypedDict(total=False)`` keeps the type ``torch.compile``-safe + (Dynamo rejects ``TypedDict`` with optional keys). + + Conventional keys + ----------------- + ``"weights"`` : torch.Tensor + Per-sample weights for the final reduction. For energy losses + with ``per_atom=True`` this carries atom counts ``(B, 1)``; for + force losses it may carry per-atom or per-component weights. + """ + + +class BaseLossFunction(nn.Module, abc.ABC): + """Abstract :class:`torch.nn.Module` base for ALCHEMI loss functions. + + ``BaseLossFunction`` implements a **template-method** + :meth:`forward` pipeline that orchestrates five overridable hooks: + + 1. :meth:`validate` — shape / dtype checks. + 2. :meth:`normalize` — pre-process ``pred`` and ``target`` + (e.g. per-atom energy division) and return a + :class:`ReductionContext` for downstream hooks. + 3. :meth:`mask` — produce a boolean validity tensor + (e.g. ``torch.isfinite``, padding masks). + 4. :meth:`compute_residual` — **abstract**; the only method every + leaf *must* implement. Receives ``pred``, ``target``, and the + validity ``mask`` produced by step 3. + 5. :meth:`reduce` — collapse the residual tensor and validity mask + into a scalar loss and populate :attr:`per_sample_loss`. + + Loss authors subclass ``BaseLossFunction`` and override + :meth:`compute_residual` at a minimum. Normalization, masking, and + reduction come free via the defaults, or can be overridden + individually for domain-specific behaviour (e.g. per-atom energy + division in :meth:`normalize`, padding-aware force masking in + :meth:`mask`, graph-balanced force reduction in :meth:`reduce`). + + Leaves are weightless — weighting and scheduling live on + :class:`ComposedLossFunction`. Operator sugar + (``scalar * leaf``, ``leaf + leaf``, ``sum([...])``) produces a + composition; see :class:`ComposedLossFunction` for semantics. 
+ + Attributes + ---------- + requires_eval_grad : bool | None + Whether this loss term requires autograd during evaluation. Losses + based on derived outputs such as forces and stress should set this to + ``True``; direct scalar-output losses should set it to ``False``. + ``None`` means callers cannot infer the policy automatically. + per_sample_loss : torch.Tensor | None + Detached per-graph loss tensor of shape ``(B,)`` left as a side + effect of the most recent :meth:`forward` call, or ``None`` when + the loss does not naturally compute a per-graph view (or when + ``forward`` has never been called). Intended for logging and + diagnostics only — gradients flow through the scalar returned by + :meth:`forward`, not through this attribute. + """ + + requires_eval_grad: bool | None = None + + def __init__(self) -> None: + """Initialize the base loss as a stateless :class:`nn.Module`.""" + super().__init__() + self.per_sample_loss: torch.Tensor | None = None + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + """Template-method pipeline: validate → normalize → mask → residual → reduce. + + Subclasses should **not** override this method. Override the + individual hooks instead. Extra keyword arguments (``batch``, + ``batch_idx``, ``num_nodes_per_graph``, etc.) are forwarded to + :meth:`normalize`, :meth:`mask`, and :meth:`reduce` via ``**kwargs``. + """ + self.per_sample_loss = None + self.validate(pred, target) + pred, target, ctx = self.normalize(pred, target, **kwargs) + valid = self.mask(pred, target, ctx, **kwargs) + residual = self.compute_residual(pred, target, valid) + return self.reduce(residual, valid, ctx, **kwargs) + + def validate( + self, + pred: torch.Tensor, + target: torch.Tensor, + ) -> None: + """Check shape and dtype compatibility of ``pred`` and ``target``. + + Default implementation always calls :func:`assert_same_shape` with + ``strict=True``; ``prediction_key`` / ``target_key`` attributes, when + present on the instance, are passed along for the error message. 
+ """ + assert_same_shape( + pred, + target, + name=type(self).__name__, + prediction_key=getattr(self, "prediction_key", None), + target_key=getattr(self, "target_key", None), + strict=True, + ) + + def normalize( + self, + pred: torch.Tensor, + target: torch.Tensor, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, ReductionContext]: + """Pre-process prediction and target before residual computation. + + Returns a ``(pred, target, ctx)`` triple. The default + implementation is the identity — ``ctx`` is empty. + + Override to inject per-atom energy division, or any other + pre-processing that should be available to all loss authors as a + composable step. + """ + return pred, target, ReductionContext() + + def mask( + self, + pred: torch.Tensor, + target: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Return a boolean validity mask for ``target``. + + The default implementation returns an all-``True`` mask matching + ``target``'s shape. Override to exclude non-finite entries, + padding, or any other invalid positions. + """ + return torch.ones_like(target, dtype=torch.bool) + + @abc.abstractmethod + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + """Return the per-element residual tensor. + + This is the only hook that **must** be overridden. The ``valid`` + mask (from :meth:`mask`) is provided so the leaf can zero out + invalid positions before computing the residual (important for + operations like ``vector_norm`` where masking after the + reduction would be incorrect). + """ + + def reduce( + self, + residual: torch.Tensor, + valid: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Collapse a residual tensor to a scalar loss. + + The default implementation computes a validity-weighted mean: + ``(residual * valid_float).sum() / valid_float.sum().clamp_min(1.0)``, + where ``valid_float`` incorporates optional ``ctx["weights"]``. 
+ + Override for domain-specific reductions (graph-balanced force + reduction, RMSD, etc.). Implementations should also populate + :attr:`per_sample_loss` with a detached ``(B,)`` tensor when a + per-graph decomposition is available. + """ + valid_weights = valid.to(dtype=residual.dtype) + weights = ctx.get("weights") + if weights is not None: + valid_weights = valid_weights * weights.expand_as(residual) + scalar = residual.mul(valid_weights).sum() / valid_weights.sum().clamp_min(1.0) + self._populate_per_sample_loss(residual) + return scalar + + def _populate_per_sample_loss(self, residual: torch.Tensor) -> None: + """Set :attr:`per_sample_loss` when the residual has a per-graph shape.""" + if residual.ndim == 1: + self.per_sample_loss = residual.detach() + elif residual.ndim == 2 and residual.shape[-1] == 1: + self.per_sample_loss = residual.squeeze(-1).detach() + + # Arithmetic dunders — return ComposedLossFunction. + def __mul__(self, other: Any) -> ComposedLossFunction: + """Return ``ComposedLossFunction([self], weights=[other])``. + + ``other`` may be a :class:`float`/:class:`int` or a + :class:`LossWeightSchedule`. + """ + match other: + case bool(): + return NotImplemented + case int() | float() | LossWeightSchedule(): + return ComposedLossFunction([self], weights=[other]) + case _: + return NotImplemented + + def __rmul__(self, other: Any) -> ComposedLossFunction: + """Mirror of :meth:`__mul__` for ``scalar * loss``.""" + return self.__mul__(other) + + def __add__(self, other: Any) -> ComposedLossFunction: + """Return ``self + other`` flattening any existing composition. + + Both operands get weight ``1.0`` unless they are themselves + compositions, in which case their existing weights are preserved. 
+ """ + if isinstance(other, ComposedLossFunction): + return ComposedLossFunction( + [self, *other.components], + weights=[1.0, *other._weights], + normalize_weights=other.normalize_weights, + ) + if isinstance(other, BaseLossFunction): + return ComposedLossFunction([self, other], weights=[1.0, 1.0]) + return NotImplemented + + def __radd__(self, other: Any) -> BaseLossFunction | ComposedLossFunction: + """Return ``self`` when seeded with integer ``0`` (for :func:`sum`).""" + if other == 0: + return self + if isinstance(other, (BaseLossFunction, ComposedLossFunction)): + return self.__add__(other) + return NotImplemented + + +def _resolve_weight( + weight: LossWeightSchedule | float, + step: int, + epoch: int | None, + *, + context: str, +) -> float: + """Resolve a single weight (float or schedule) to a finite float. + + Parameters + ---------- + weight + Either a plain scalar or a :class:`LossWeightSchedule`. + step, epoch + Training counters forwarded to the schedule. + context + Caller-supplied name (typically the component's class name) used + in error messages. + + Raises + ------ + ValueError + If a ``per_epoch=True`` schedule is evaluated with + ``epoch is None`` or the schedule returns a non-finite value. + TypeError + If the schedule returns a non-numeric value. + """ + if not isinstance(weight, LossWeightSchedule): + coerced = float(weight) + if not math.isfinite(coerced): + raise ValueError( + f"{context}: weight {weight!r} is not finite; " + "weights must be finite floats." + ) + return coerced + if weight.per_epoch and epoch is None: + raise ValueError( + f"epoch must be provided when the {context} loss weight " + "schedule has per_epoch=True. Pass epoch= to " + "the loss, or set per_epoch=False on the schedule." 
+ ) + try: + value = weight(step, epoch or 0) + except TypeError as exc: + raise TypeError( + f"{type(weight).__name__} does not satisfy the " + "LossWeightSchedule contract: __call__ must accept " + "(step: int, epoch: int) and return a float." + ) from exc + if not isinstance(value, (int, float)): + raise TypeError( + f"{type(weight).__name__} returned {type(value).__name__}; " + "LossWeightSchedule.__call__ must return float." + ) + coerced = float(value) + if not math.isfinite(coerced): + raise ValueError( + f"{type(weight).__name__} for {context} returned non-finite " + f"weight {coerced!r}; schedules must return finite floats." + ) + return coerced + + +def _component_names(components: Sequence[BaseLossFunction]) -> tuple[str, ...]: + """Return class names with suffixes applied to duplicate component types.""" + raw_names = tuple(type(comp).__name__ for comp in components) + counts: dict[str, int] = {} + for name in raw_names: + counts[name] = counts.get(name, 0) + 1 + next_index: dict[str, int] = {} + names: list[str] = [] + for name in raw_names: + if counts[name] > 1: + idx = next_index.get(name, 0) + next_index[name] = idx + 1 + names.append(f"{name}_{idx}") + else: + names.append(name) + return tuple(names) + + +class ComposedLossFunction(nn.Module): + """Weighted sum of :class:`BaseLossFunction` components. + + This class owns the per-component weighting — leaves are weightless. + Weights may be plain floats or :class:`LossWeightSchedule` instances; + they are resolved to floats at call time. By default the resolved + weights are normalized to sum to ``1.0`` so scheduling controls + *relative* contributions while the learning rate controls the + absolute loss magnitude. Opt out with ``normalize_weights=False``. + + Components live in an :class:`torch.nn.ModuleList` for + ``.modules()`` / ``.state_dict()`` / nested-``__repr__`` support. 
+ When a component is itself a :class:`ComposedLossFunction`, its + components and weights are flattened into the parent element-wise so + ``(A + B) + C`` is equivalent to ``A + B + C``. + + Parameters + ---------- + components + Loss terms to combine; must contain at least one element. + weights + Optional per-component weights. When provided, ``weights`` must + have the same length as ``components`` at construction time + (i.e. top-level components — child weights inside nested + compositions are multiplied element-wise by the parent weight + during flattening). A ``None`` entry is shorthand for ``1.0``, + so ``weights=[None, 2.0, None]`` means "component 1 gets 2×, + others default". Passing ``weights=None`` defaults every + component to ``1.0``. + normalize_weights + When ``True`` (default), resolved weights are divided by their + sum at each call so the effective weights sum to ``1.0``. A sum + that is not finite and strictly positive raises :class:`ValueError`. + When ``False``, raw weighted sums are returned. + + Attributes + ---------- + components + :class:`torch.nn.ModuleList` of the flattened leaf components. + normalize_weights + Whether effective weights are renormalized to sum to ``1.0``. 
+ """ + + def __init__( + self, + components: Sequence[BaseLossFunction | ComposedLossFunction], + *, + weights: Sequence[LossWeightSchedule | float | None] | None = None, + normalize_weights: bool = True, + ) -> None: + """Store flattened components, their weights, and the normalization flag.""" + super().__init__() + components = tuple(components) + if len(components) == 0: + raise ValueError("components must contain at least one loss term") + for i, comp in enumerate(components): + if not isinstance(comp, (BaseLossFunction, ComposedLossFunction)): + raise TypeError( + f"components[{i}] must be a BaseLossFunction or " + f"ComposedLossFunction, got " + f"{type(comp).__name__}" + ) + + if weights is None: + raw_weights: list[LossWeightSchedule | float] = [1.0] * len(components) + else: + raw_weights = [1.0 if w is None else w for w in weights] + if len(raw_weights) != len(components): + raise ValueError( + f"weights has length {len(raw_weights)} but components has " + f"length {len(components)}; lengths must match." + ) + for i, w in enumerate(raw_weights): + match w: + case bool(): + valid = False + case int() | float() | LossWeightSchedule(): + valid = True + case _: + valid = False + if not valid: + raise TypeError( + f"weights[{i}] must be a float or LossWeightSchedule, " + f"got {type(w).__name__}." 
+ ) + + flat_components: list[BaseLossFunction] = [] + flat_weights: list[LossWeightSchedule | float] = [] + for comp, parent_w in zip(components, raw_weights, strict=True): + if isinstance(comp, ComposedLossFunction): + for child_comp, child_w in zip( + comp.components, comp._weights, strict=True + ): + flat_components.append(child_comp) + flat_weights.append(_compose_weights(parent_w, child_w)) + else: + flat_components.append(comp) + flat_weights.append(parent_w) + + self.components: nn.ModuleList = nn.ModuleList(flat_components) + self._weights: list[LossWeightSchedule | float] = flat_weights + self.normalize_weights: bool = normalize_weights + + def _resolve_raw_and_effective( + self, step: int, epoch: int | None + ) -> tuple[tuple[str, ...], list[float], list[float]]: + """Resolve raw and effective weights in a single pass. + + Returns a triple ``(names, raw, effective)`` where ``raw`` holds + the per-component resolved floats (pre-normalization) and + ``effective`` holds the weights that will actually be applied — + identical to ``raw`` when :attr:`normalize_weights` is ``False`` + and ``raw / sum(raw)`` otherwise. When normalization is enabled + the raw weights must sum to a strictly positive float; a sum + that is non-positive (negative, zero, or non-finite from + cancellation) is rejected with :class:`ValueError` because the + resulting normalization either flips every contribution's sign + or blows up. Individual raw weights may themselves be negative + as long as their sum is positive. 
+ """ + names = _component_names(tuple(self.components)) + raw = [ + _resolve_weight(w, step, epoch, context=name) + for w, name in zip(self._weights, names, strict=True) + ] + if not self.normalize_weights: + return names, raw, list(raw) + total = sum(raw) + if not math.isfinite(total) or total <= 0.0: + resolved = dict(zip(names, raw, strict=True)) + raise ValueError( + "ComposedLossFunction: cannot normalize weights whose sum " + f"is not strictly positive (sum={total!r}). Resolved " + f"weights at step={step}, epoch={epoch}: {resolved}. " + "Choose weights whose sum is a finite positive float or " + "set normalize_weights=False." + ) + effective = [w / total for w in raw] + return names, raw, effective + + def current_weight(self, step: int = 0, epoch: int | None = None) -> list[float]: + """Resolve each component's weight to a float for ``(step, epoch)``. + + When :attr:`normalize_weights` is ``True`` the returned list sums + to ``1.0``; otherwise it is the raw resolved weights. With + normalization enabled the raw sum must be a strictly positive + float or :class:`ValueError` is raised. + + Parameters + ---------- + step + Current global training step. + epoch + Current training epoch, or ``None`` when unused. + + Returns + ------- + list[float] + One effective weight per component, in order. + + Raises + ------ + ValueError + If normalization is enabled and the raw weights do not sum + to a strictly positive, finite float. + """ + _, _, effective = self._resolve_raw_and_effective(step, epoch) + return effective + + def weight_factors( + self, step: int = 0, epoch: int | None = None + ) -> dict[str, float]: + """Return a flat ``{component_name: effective_weight}`` dict. + + Duplicate class names get numeric suffixes (``_0``, ``_1``, ...) + applied to *all* colliding entries, not only the duplicates. 
+ """ + names = _component_names(tuple(self.components)) + effective = self.current_weight(step=step, epoch=epoch) + return dict(zip(names, effective, strict=True)) + + def requires_eval_grad(self) -> bool: + """Whether evaluating this loss needs autograd enabled. + + Inspects each leaf component's ``requires_eval_grad`` flag. A + component reporting ``True`` (e.g. a force/stress loss that + differentiates the energy) forces gradient-enabled evaluation; + components reporting ``False`` do not. A component reporting + ``None`` is undeclared and cannot be inferred automatically. + + Returns + ------- + bool + ``True`` when at least one component requires gradients, + ``False`` when every component explicitly declares it does + not. + + Raises + ------ + ValueError + When one or more components report ``requires_eval_grad=None`` + and none require gradients, so the requirement is ambiguous. + """ + unknown: list[str] = [] + for component in self.components: + requires_eval_grad = getattr(component, "requires_eval_grad", None) + if requires_eval_grad is True: + return True + if requires_eval_grad is None: + unknown.append(type(component).__name__) + if unknown: + names = ", ".join(unknown) + raise ValueError( + "Cannot infer whether evaluating this loss requires " + f"gradients for component(s): {names}. Set " + "requires_eval_grad on the component(s), or resolve the " + "policy explicitly (e.g. ValidationConfig grad_mode=" + "'enabled' or 'disabled')." + ) + return False + + def forward( + self, + predictions: Mapping[str, torch.Tensor], + targets: Mapping[str, torch.Tensor], + *, + step: int = 0, + epoch: int | None = None, + **kwargs: Any, + ) -> ComposedLossOutput: + """Return the weighted total loss and per-component diagnostics. + + Each component is called with the routed ``pred`` / ``target`` + tensors, then its raw loss is scaled by the effective weight for + this step. 
The output's ``per_component_unweighted`` contains + each raw component loss before effective weighting; + ``per_component_weight`` + holds the scalar weights that were applied (after normalization, + if enabled); ``per_component_raw_weight`` holds the + pre-normalization resolved weights so schedule ramps remain + observable on single-component normalized compositions; see + :attr:`BaseLossFunction.per_sample_loss` for the + ``per_component_sample`` contract. + """ + names, raw_weights, effective = self._resolve_raw_and_effective(step, epoch) + + per_component_unweighted: dict[str, torch.Tensor] = {} + per_component_sample: dict[str, torch.Tensor] = {} + per_component_weight: dict[str, float] = dict( + zip(names, effective, strict=True) + ) + per_component_raw_weight: dict[str, float] = dict( + zip(names, raw_weights, strict=True) + ) + total: torch.Tensor | None = None + + for name, comp, weight in zip(names, self.components, effective, strict=True): + prediction_key = getattr(comp, "prediction_key", None) + target_key = getattr(comp, "target_key", None) + if prediction_key is None: + raise AttributeError( + f"{type(comp).__name__} cannot be used in " + "ComposedLossFunction without a prediction_key attribute." + ) + if target_key is None: + raise AttributeError( + f"{type(comp).__name__} cannot be used in " + "ComposedLossFunction without a target_key attribute." + ) + try: + pred = predictions[prediction_key] + except KeyError as exc: + raise KeyError( + f"{type(comp).__name__}: prediction mapping is missing " + f"key {prediction_key!r}" + ) from exc + try: + target = targets[target_key] + except KeyError as exc: + raise KeyError( + f"{type(comp).__name__}: target mapping is missing " + f"key {target_key!r}" + ) from exc + if not isinstance(pred, torch.Tensor): + raise TypeError( + f"{type(comp).__name__}: prediction mapping key " + f"{prediction_key!r} must resolve to torch.Tensor, " + f"got {type(pred).__name__}." 
def __mul__(self, other: Any) -> ComposedLossFunction:
    """Return a new composition with every weight scaled by ``other``.

    Only plain ``int`` / ``float`` scalars are accepted (``bool`` is
    excluded). A :class:`LossWeightSchedule` operand raises
    :class:`TypeError`; any other operand yields ``NotImplemented`` so
    Python can try the reflected operation.
    """
    is_plain_scalar = isinstance(other, (int, float)) and not isinstance(other, bool)
    if is_plain_scalar:
        factor = float(other)
        # Each existing weight (float or schedule) is combined with the
        # scalar through _compose_weights.
        return ComposedLossFunction(
            list(self.components),
            weights=[_compose_weights(factor, weight) for weight in self._weights],
            normalize_weights=self.normalize_weights,
        )
    if isinstance(other, LossWeightSchedule):
        raise TypeError(
            "Multiplying a ComposedLossFunction by a "
            "LossWeightSchedule is not supported. Scale each "
            "component individually (e.g. schedule * EnergyMSELoss()) "
            "and compose the results, or multiply by a float."
        )
    return NotImplemented
def __radd__(self, other: Any) -> ComposedLossFunction:
    """Return ``other + self`` for a numeric zero seed or a leaf loss.

    A numeric ``0`` on the left (the default seed used by :func:`sum`)
    returns ``self`` unchanged. A :class:`BaseLossFunction` on the left is
    prepended as a new component with weight ``1.0``. Anything else
    yields ``NotImplemented``.
    """
    # Compare against zero only for plain numbers. ``other == 0`` on an
    # array-like operand (e.g. a multi-element tensor) produces an
    # elementwise result whose truth value is ambiguous and would raise
    # here instead of letting Python report an unsupported operand.
    if isinstance(other, (int, float)) and other == 0:
        return self
    if isinstance(other, BaseLossFunction):
        return ComposedLossFunction(
            [other, *self.components],
            weights=[1.0, *self._weights],
            normalize_weights=self.normalize_weights,
        )
    return NotImplemented
def assemble_loss_targets(
    loss_fn: ComposedLossFunction,
    batch: Any,
    *,
    target_keys: Sequence[str] | None = None,
    batch_label: str = "Batch",
) -> dict[str, torch.Tensor]:
    """Collect target tensors required by ``loss_fn`` from ``batch``.

    Parameters
    ----------
    loss_fn : ComposedLossFunction
        Loss whose component ``target_key`` attributes define required targets.
    batch : Any
        Batch-like object exposing target tensors as attributes.
    target_keys : Sequence[str] | None, optional
        Precomputed target keys. Defaults to :func:`loss_target_keys`.
    batch_label : str, default "Batch"
        Human-readable batch label used in missing-target errors.

    Returns
    -------
    dict[str, torch.Tensor]
        Mapping from target key to the attribute read from ``batch``.

    Raises
    ------
    AttributeError
        If a required target is absent from ``batch``.
    """
    keys = target_keys if target_keys is not None else loss_target_keys(loss_fn)
    targets: dict[str, torch.Tensor] = {}
    for key in keys:
        try:
            targets[key] = getattr(batch, key)
        except AttributeError as exc:
            # The owning component is only needed for the error message, so
            # it is looked up here rather than on every successful call.
            # The last component declaring the key wins; the loss type name
            # is the fallback when no component declares it.
            component_name = type(loss_fn).__name__
            for component in loss_fn.components:
                declared = getattr(component, "target_key", None)
                if declared is not None and declared == key:
                    component_name = type(component).__name__
            raise AttributeError(
                f"{batch_label} is missing target attribute {key!r} "
                f"required by {component_name}."
            ) from exc
    return targets
@dataclass(frozen=True)
class _ProductWeight:
    """Deferred product of two weights, each a schedule or a plain float.

    Used when flattening nested compositions: if either factor is a
    callable of ``(step, epoch)``, the product cannot be computed up
    front. Both operands are stored and multiplied on each call. The
    instance exposes ``per_epoch`` and ``__call__`` so it can stand in
    wherever a :class:`LossWeightSchedule` is expected.
    """

    left: LossWeightSchedule | float
    right: LossWeightSchedule | float
    per_epoch: bool = field(init=False)

    def __post_init__(self) -> None:
        """Set ``per_epoch`` when either operand reports a truthy ``per_epoch``."""
        operands = (self.left, self.right)
        follows_epoch = any(getattr(op, "per_epoch", False) for op in operands)
        # The dataclass is frozen, so normal attribute assignment is blocked.
        object.__setattr__(self, "per_epoch", follows_epoch)

    def __call__(self, step: int, epoch: int) -> float:
        """Evaluate both operands at ``(step, epoch)`` and multiply them."""
        factors = [
            float(op(step, epoch)) if callable(op) else float(op)
            for op in (self.left, self.right)
        ]
        return factors[0] * factors[1]
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""Graph-aware reduction primitives for loss functions. + +Scatter reductions (``V ... → B ...``) +-------------------------------------- + +:func:`per_graph_sum` and :func:`per_graph_mean` take a flat per-node +tensor with a ``batch_idx`` mapping each node to its graph and reduce +the leading node dim into a per-graph output, preserving trailing dims +verbatim. + +These helpers only produce per-graph tensors. They do not choose the +final scalar weighting across graphs. For per-graph values :math:`x_i`, +a graph-balanced scalar is :math:`B^{-1} \sum_i x_i`, while an +atom-weighted scalar is :math:`(\sum_i N_i x_i) / (\sum_i N_i)`. + +Matrix reductions (``B ... m n → B ...``) +----------------------------------------- + +:func:`frobenius_mse` is *not* a scatter reduction: it operates on an +already-per-graph tensor and averages the squared residual over the +trailing two matrix dims. It takes neither ``batch_idx`` nor +``num_graphs``. + +Common parameters for scatter reductions +---------------------------------------- + +- ``values``: per-node tensor whose leading dim indexes nodes; trailing + dims are preserved. +- ``batch_idx``: 1-D ``BatchIndices`` mapping each node to its graph. +- ``num_graphs`` (optional): when supplied, trusted without scanning + ``batch_idx`` — the recommended hot-path convention (avoids a GPU→CPU + sync). 
When omitted, inferred as ``batch_idx.max().item() + 1``. + Empty ``batch_idx`` always requires ``num_graphs``. + +Scatter reductions raise :class:`ValueError` on shape mismatch, on +non-positive ``num_graphs``, or on inability to infer ``num_graphs``. + +Migrating from ``per_graph_mse`` +-------------------------------- + +The former ``per_graph_mse`` helper has been removed. The direct +replacement composes :func:`per_graph_mean` with a pointwise squared +error:: + + per_graph_mean((pred - target).pow(2), batch_idx, num_graphs) + +Hot-path callers that want a scalar-per-graph MSE should reduce +trailing dims before scattering:: + + per_graph_mean( + (pred - target).pow(2).mean(dim=tuple(range(1, pred.ndim))), + batch_idx, + num_graphs, + ) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeAlias + +import torch + +from nvalchemi._typing import BatchIndices + +if TYPE_CHECKING: + from jaxtyping import Float, Num + +_NumGraphs: TypeAlias = int | torch.Tensor + + +def _resolve_batch_indices( + batch_idx: BatchIndices, + num_graphs: int | None, + device: torch.device, +) -> tuple[BatchIndices, _NumGraphs]: + """Return ``batch_idx`` on ``device`` and the graph count it implies. + + When ``num_graphs`` is supplied, trust it and skip any scan of + ``batch_idx`` — this is the hot path. When omitted, infer via + ``batch_idx.max()``; under ``torch.compile`` the max stays a tensor, + otherwise it is materialized on the host (forcing a device sync). + """ + batch_idx = batch_idx.to(device=device, dtype=torch.long) + if batch_idx.ndim != 1: + raise ValueError(f"batch_idx must be 1D, got shape {tuple(batch_idx.shape)}") + if num_graphs is not None: + if num_graphs <= 0: + raise ValueError(f"num_graphs must be positive, got {num_graphs}") + return batch_idx, num_graphs + if batch_idx.numel() == 0: + raise ValueError( + "Cannot infer num_graphs from empty batch_idx; " + "pass num_graphs explicitly when reducing an empty batch." 
def _per_graph_sum_resolved(
    values: Float[torch.Tensor, "V ..."],  # noqa: F722
    batch_idx: BatchIndices,
    num_graphs: _NumGraphs,
) -> Float[torch.Tensor, "B ..."]:  # noqa: F722
    """Scatter-add rows of ``values`` into ``num_graphs`` output rows.

    ``batch_idx`` and ``num_graphs`` are used as given; no validation or
    device transfer happens here. The result has shape
    ``(num_graphs, *values.shape[1:])`` with the dtype and device of
    ``values``.
    """
    trailing = values.shape[1:]
    # One singleton axis per trailing dim lets the 1-D index expand to the
    # full shape of ``values``, as scatter_add_ requires.
    index = batch_idx.view(-1, *([1] * len(trailing))).expand_as(values)
    out = torch.zeros(
        (num_graphs, *trailing), dtype=values.dtype, device=values.device
    )
    # TODO: refactor to use warp kernels when backwards ready
    out.scatter_add_(0, index, values)
    return out
Float[torch.Tensor, "V ..."], # noqa: F722 + batch_idx: BatchIndices, + num_graphs: int | None = None, +) -> Float[torch.Tensor, "B ..."]: # noqa: F722 + """Sum per-node values into per-graph values via ``scatter_add_``. + + See the module docstring for ``batch_idx`` / ``num_graphs`` + semantics and error conditions. + + Returns + ------- + Float[torch.Tensor, "B ..."] + Per-graph sums of shape ``(num_graphs, *values.shape[1:])``. + """ + batch_idx, resolved = _prep_reduction(values, batch_idx, num_graphs, name="values") + return _per_graph_sum_resolved(values, batch_idx, resolved) + + +def per_graph_mean( + values: Float[torch.Tensor, "V ..."], # noqa: F722 + batch_idx: BatchIndices, + num_graphs: int | None = None, +) -> Float[torch.Tensor, "B ..."]: # noqa: F722 + r"""Mean of per-node values across each graph. + + This divides each graph's sum by that graph's node count and returns + one value per graph. It does not choose how those graph values are + weighted in a later scalar reduction. For per-graph values + :math:`x_i`, the graph-balanced scalar is :math:`B^{-1} \sum_i x_i`; + the atom-weighted scalar is + :math:`(\sum_i N_i x_i) / (\sum_i N_i)`. + + Empty graphs (zero nodes) are safe: their sum is zero and their + count is clamped to ``1`` before the division, so they yield zero. + See the module docstring for shared parameter / error semantics. + + Returns + ------- + Float[torch.Tensor, "B ..."] + Per-graph means. + """ + batch_idx, resolved = _prep_reduction(values, batch_idx, num_graphs, name="values") + totals = _per_graph_sum_resolved(values, batch_idx, resolved) + counts = _num_nodes_per_graph( + batch_idx, + resolved, + dtype=totals.dtype, + device=totals.device, + ).clamp_min(1.0) + # Broadcast counts across trailing dims of totals. 
def frobenius_mse(
    pred: Float[torch.Tensor, "B 3 3"],  # noqa: F722
    target: Float[torch.Tensor, "B 3 3"],  # noqa: F722
) -> Float[torch.Tensor, "B"]:  # noqa: F722
    """Mean squared residual over the last two (matrix) dims.

    Computes ``((pred - target) ** 2).mean(dim=(-2, -1))``, i.e. the
    squared Frobenius norm of each residual matrix divided by its number
    of entries. All leading dims are kept.

    Parameters
    ----------
    pred, target
        Tensors of identical shape with at least three dims.

    Returns
    -------
    Float[torch.Tensor, "B"]
        One value per leading index; shape is ``pred.shape[:-2]``.

    Raises
    ------
    ValueError
        If the shapes differ or the inputs have fewer than three dims.
    """
    pred_shape = tuple(pred.shape)
    target_shape = tuple(target.shape)
    if pred_shape != target_shape:
        raise ValueError(
            f"pred shape {pred_shape} must equal target shape {target_shape}"
        )
    if len(pred_shape) < 3:
        raise ValueError(
            f"frobenius_mse expects at least 3 dims (B, ..., M1, M2); "
            f"got shape {pred_shape}"
        )
    residual = pred - target
    return residual.pow(2).mean(dim=(-2, -1))
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Concrete weight schedules for loss functions. + +Four Pydantic-validated schedules are provided: :class:`ConstantWeight`, +:class:`LinearWeight`, :class:`CosineWeight`, and :class:`PiecewiseWeight`. +Each satisfies the runtime-checkable +:class:`~nvalchemi.training.losses.base.LossWeightSchedule` protocol and +can be supplied inside :class:`ComposedLossFunction`'s ``weights`` +sequence or on the left of ``schedule * leaf``. + +The concrete schedules always receive both the global step and epoch. +When ``per_epoch=False`` (the default), schedule windows and boundaries +advance by global step. When ``per_epoch=True``, they advance by epoch, +which lets loss weights follow optimizers or learning-rate schedulers +that update once per epoch. + +Serialization note +------------------ + +Schedules live in :class:`ComposedLossFunction`'s ``weights`` argument +rather than on leaves, and are reconstructed by the upstream +``TrainingStrategy`` from their ``(instance, spec)`` pair — the same +pattern used for models and optimizers (see +:mod:`nvalchemi.training._checkpoint`). A concrete schedule class still +round-trips standalone via :func:`~nvalchemi.training.create_model_spec`. + +Adding a new schedule +--------------------- + +You can write any callable ``(step: int, epoch: int) -> float`` with a +``per_epoch`` attribute and it will satisfy the +:class:`~nvalchemi.training.losses.base.LossWeightSchedule` protocol. + +Alternatively, subclass +:class:`~nvalchemi.training.losses.base._BaseWeightSchedule`: + +1. Inherit to pick up ``per_epoch`` and the frozen Pydantic config. +2. Implement ``__call__(step: int, epoch: int) -> float``; use + ``self._map_schedule_index(step, epoch)`` for schedules that advance + over a single training counter. 
class _RampSchedule(_BaseWeightSchedule):
    """Common fields and window logic for ramps from ``start`` to ``end``.

    Subclasses supply the curve applied to the in-window fraction. The
    schedule index is whatever ``_map_schedule_index(step, epoch)``
    returns for the instance.
    """

    start: Annotated[float, Field(description="Weight at schedule index 0.")]
    end: Annotated[float, Field(description="Weight at schedule index `num_steps`.")]
    num_steps: _PositiveSteps

    def _ramp_fraction(self, step: int, epoch: int) -> float | None:
        """Return ``index / num_steps`` inside the window, else ``None``.

        ``None`` is returned when the index is ``<= 0`` or
        ``>= num_steps``. Callers then return the exact boundary value
        (``start`` or ``end``) themselves instead of interpolating. Inside
        the window the result is the raw linear fraction, to which
        subclasses apply their own curve.
        """
        idx = self._map_schedule_index(step, epoch)
        outside_window = idx <= 0 or idx >= self.num_steps
        return None if outside_window else idx / self.num_steps
class CosineWeight(_RampSchedule):
    """Half-cosine ramp from ``start`` to ``end`` across ``num_steps``.

    Inside the window the weight follows
    ``start + (end - start) * 0.5 * (1 - cos(pi * t))`` where ``t`` is the
    fraction returned by ``_ramp_fraction``. At or before index ``0`` the
    value is ``start``; at or after ``num_steps`` it is ``end``.
    """

    def __call__(self, step: int, epoch: int) -> float:
        """Return the half-cosine weight, clamped outside the window."""
        frac = self._ramp_fraction(step, epoch)
        if frac is not None:
            # The blend rises smoothly from 0 (frac=0) to 1 (frac=1).
            blend = 0.5 * (1.0 - math.cos(math.pi * frac))
            return float(self.start + (self.end - self.start) * blend)
        # Outside the window: decide which boundary value applies.
        before_window = self._map_schedule_index(step, epoch) <= 0
        return float(self.start if before_window else self.end)
+ ), + ), + ] + values: Annotated[ + tuple[float, ...], + Field(description="Values for each interval; length len(boundaries) + 1."), + ] + + @model_validator(mode="after") + def _check_boundaries_and_values(self) -> PiecewiseWeight: + """Enforce strictly-increasing non-negative boundaries and correct length.""" + if len(self.values) != len(self.boundaries) + 1: + raise ValueError( + f"values must have length len(boundaries) + 1; got " + f"len(values)={len(self.values)}, " + f"len(boundaries)={len(self.boundaries)}" + ) + prev = -1 + for b in self.boundaries: + if b < 0: + raise ValueError( + f"boundaries must be non-negative; got {self.boundaries}" + ) + if b <= prev: + raise ValueError( + f"boundaries must be strictly increasing; got {self.boundaries}" + ) + prev = b + return self + + def __call__(self, step: int, epoch: int) -> float: + """Return the value of the interval containing the schedule index. + + ``bisect_right`` gives the count of boundaries that the index has + reached or passed, which is the index into :attr:`values`. + """ + idx = bisect.bisect_right( + self.boundaries, self._map_schedule_index(step, epoch) + ) + return float(self.values[idx]) diff --git a/nvalchemi/training/losses/terms.py b/nvalchemi/training/losses/terms.py new file mode 100644 index 00000000..d362c130 --- /dev/null +++ b/nvalchemi/training/losses/terms.py @@ -0,0 +1,1014 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Concrete loss terms for energy, forces, and stress. + +All three accept prediction and target tensors directly. The configurable +``target_key`` / ``prediction_key`` names are used by +:class:`~nvalchemi.training.losses.composition.ComposedLossFunction` +when routing keyed prediction/target mappings into these tensor-first +loss terms. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeAlias + +import torch +from jaxtyping import Bool, Float, Integer +from plum import dispatch, overload + +from nvalchemi._typing import BatchIndices, Energy, Forces +from nvalchemi.training.losses.composition import ( + BaseLossFunction, + ReductionContext, +) +from nvalchemi.training.losses.reductions import per_graph_sum + +if TYPE_CHECKING: + from nvalchemi.data.batch import Batch + +_NodeCounts: TypeAlias = Integer[torch.Tensor, "B"] +_PaddedNodeMask: TypeAlias = Bool[torch.Tensor, "B V_max"] +_PaddedForces: TypeAlias = Float[torch.Tensor, "B V_max 3"] +_ForceTensor: TypeAlias = Forces | _PaddedForces +_DenseForceMask: TypeAlias = Bool[torch.Tensor, "V 3"] +_PaddedForceMask: TypeAlias = Bool[torch.Tensor, "B V_max 3"] +_PerGraphValues: TypeAlias = Float[torch.Tensor, "B"] + + +def _require_metadata(value: Any, name: str, *, loss_name: str) -> Any: + """Return required loss metadata or raise a focused error.""" + if value is None: + raise ValueError(f"{loss_name} requires {name}=... 
def _node_counts(
    num_nodes_per_graph: _NodeCounts | _PaddedNodeMask | None,
    ref: Energy,
) -> Float[torch.Tensor, "B"]:
    """Turn explicit counts or a padded node mask into per-graph counts.

    The metadata is first converted with ``.to(ref)``. A 1-D input is
    treated as counts; a 2-D input is summed over its last dim. The
    result is clamped to a minimum of ``1``.

    Raises
    ------
    ValueError
        If the metadata is missing, is not 1-D or 2-D, or its leading
        dimension differs from ``ref.shape[0]``.
    """
    required = _require_metadata(
        num_nodes_per_graph,
        "num_nodes_per_graph",
        loss_name="per-atom energy loss",
    )
    nodes = required.to(ref)
    rank = nodes.ndim
    if rank != 1 and rank != 2:
        raise ValueError(
            "num_nodes_per_graph must be a 1-D count tensor or a 2-D padded node mask."
        )
    n_rows, batch_size = nodes.shape[0], ref.shape[0]
    if n_rows != batch_size:
        raise ValueError(
            "num_nodes_per_graph leading dimension "
            f"({n_rows}) must match energy batch size ({batch_size})."
        )
    # A 2-D mask holds one row per graph; summing a row counts its nodes.
    per_graph = nodes if rank == 1 else nodes.sum(dim=-1)
    return per_graph.clamp_min(1)
+ ) + counts = nodes.to(device=ref.device) + return torch.arange(max_nodes, device=ref.device).unsqueeze(0) < counts.unsqueeze( + -1 + ) + + +def _huber_loss(residual: torch.Tensor, delta: float) -> torch.Tensor: + """Return elementwise Huber loss for a residual tensor. + + Parameters + ---------- + residual : torch.Tensor + Prediction-minus-target residual. + delta : float + Positive transition point between quadratic and linear regimes. + + Returns + ------- + torch.Tensor + Elementwise Huber loss with the same shape as ``residual``. + """ + abs_residual = residual.abs() + return torch.where( + abs_residual < delta, + 0.5 * abs_residual.pow(2), + delta * (abs_residual - 0.5 * delta), + ) + + +class EnergyMSELoss(BaseLossFunction): + r"""Mean-squared-error loss on per-graph total energy. + + Energies enter this loss as one total-energy value per graph, with + canonical shape ``(B, 1)``. With ``per_atom=False`` the scalar is the + graph-balanced MSE of total-energy residuals, so every graph has equal + weight regardless of size. + + With ``per_atom=True``, prediction and target are first divided by + each graph's atom count. The squared residual is therefore measured in + energy-per-atom units, then reduced with atom-count weights: + + .. math:: + + L = \frac{\sum_i N_i + \left(\frac{E_i^\mathrm{pred} - E_i^\mathrm{target}}{N_i}\right)^2} + {\sum_i N_i}. + + This makes contributions proportional to graph size while + keeping the error quantity in per-atom energy units. Counts may be + supplied directly as ``(B,)`` or recovered from a padded node mask of + shape ``(B, V_max)``. + + Tensor Contract + --------------- + pred, target : Energy + Per-graph energy tensors of shape ``(B, 1)``. Shape validation + requires exact equality; ``(B, 1)`` and ``(B,)`` are rejected + even though they are broadcast-compatible. + num_nodes_per_graph : Integer[torch.Tensor, "B"] | Bool[torch.Tensor, "B V_max"], optional + Required only when ``per_atom=True``. 
May be explicit per-graph + counts or a padded node-validity mask. + + Parameters + ---------- + target_key : str, default "energy" + Target container key for the target tensor. + prediction_key : str, default "predicted_energy" + Prediction container key for the model output. + per_atom : bool, default False + Measure residuals in energy-per-atom units and reduce them with + atom-count weights: larger graphs contribute in proportion to + their atom counts. + ignore_nonfinite : bool, default False + When ``True``, target entries that are ``NaN`` or infinite are + excluded from both loss value and gradient using + :func:`torch.isfinite`. Intended for inputs where some samples + lack an energy label. Implemented with branch-free tensor ops + for ``torch.compile`` compatibility. When ``per_atom=True``, + atom-count weights for invalid targets are also excluded from + the denominator. When every target entry is non-finite the loss + is ``0.0``. + """ + + requires_eval_grad: bool = False + + def __init__( + self, + *, + target_key: str = "energy", + prediction_key: str = "predicted_energy", + per_atom: bool = False, + ignore_nonfinite: bool = False, + ) -> None: + """Configure attribute keys and energy reduction semantics.""" + super().__init__() + self.target_key = target_key + self.prediction_key = prediction_key + self.per_atom = per_atom + self.ignore_nonfinite = ignore_nonfinite + + def normalize( + self, + pred: torch.Tensor, + target: torch.Tensor, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, ReductionContext]: + """Divide by atom counts when ``per_atom=True``.""" + ctx = ReductionContext() + if not self.per_atom: + return pred, target, ctx + batch: Batch | None = kwargs.get("batch") + num_nodes_per_graph = kwargs.get("num_nodes_per_graph") + if batch is not None and num_nodes_per_graph is None: + num_nodes_per_graph = getattr(batch, "num_nodes_per_graph", None) + counts = _node_counts(num_nodes_per_graph, pred).unsqueeze(-1) + ctx["weights"] = 
counts + return pred / counts, target / counts, ctx + + def mask( + self, + pred: torch.Tensor, + target: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Exclude non-finite target entries when ``ignore_nonfinite=True``.""" + if self.ignore_nonfinite: + return torch.isfinite(target) + return torch.ones_like(target, dtype=torch.bool) + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + """Return squared residuals, zeroing invalid entries.""" + residual = torch.where(valid, pred - target, torch.zeros_like(pred)) + return residual.pow(2) + + def extra_repr(self) -> str: + """Human-readable hyperparameter summary for :class:`nn.Module`'s repr.""" + return ( + f"target_key={self.target_key!r}, " + f"prediction_key={self.prediction_key!r}, " + f"per_atom={self.per_atom!r}, " + f"ignore_nonfinite={self.ignore_nonfinite!r}" + ) + + +class EnergyMAELoss(BaseLossFunction): + r"""Mean-absolute-error loss for per-graph energy targets. + + This loss operates on per-graph total energies with identical + prediction and target shapes, commonly ``(B, 1)`` or ``(B,)``. With + ``per_atom=True`` (default), prediction and target energies are first + divided by each graph's atom count, then absolute residuals are + reduced with atom-count weights so that larger graphs contribute + in proportion to their size: + + .. math:: + + L = \frac{\sum_i N_i + \left|\frac{E_i^\mathrm{pred} - E_i^\mathrm{target}}{N_i}\right|} + {\sum_i N_i}. + + Parameters + ---------- + target_key : str, default "energy" + Target container key for the target tensor. + prediction_key : str, default "predicted_energy" + Prediction container key for the model output. + per_atom : bool, default True + Divide prediction and target by ``num_nodes_per_graph`` before + computing absolute residuals. 
+ ignore_nonfinite : bool, default True + When ``True``, target entries that are ``NaN`` or infinite are + excluded from both loss value and gradient using + :func:`torch.isfinite`. + """ + + def __init__( + self, + *, + target_key: str = "energy", + prediction_key: str = "predicted_energy", + per_atom: bool = True, + ignore_nonfinite: bool = True, + ) -> None: + """Configure attribute keys and energy MAE semantics.""" + super().__init__() + self.target_key = target_key + self.prediction_key = prediction_key + self.per_atom = per_atom + self.ignore_nonfinite = ignore_nonfinite + + def normalize( + self, + pred: torch.Tensor, + target: torch.Tensor, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, ReductionContext]: + """Divide by atom counts when ``per_atom=True``.""" + ctx = ReductionContext() + if not self.per_atom: + return pred, target, ctx + batch: Batch | None = kwargs.get("batch") + num_nodes_per_graph = kwargs.get("num_nodes_per_graph") + if batch is not None and num_nodes_per_graph is None: + num_nodes_per_graph = getattr(batch, "num_nodes_per_graph", None) + counts = _node_counts(num_nodes_per_graph, pred).reshape( + (-1,) + (1,) * (pred.ndim - 1) + ) + ctx["weights"] = counts + return pred / counts, target / counts, ctx + + def mask( + self, + pred: torch.Tensor, + target: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Exclude non-finite target entries when ``ignore_nonfinite=True``.""" + if self.ignore_nonfinite: + return torch.isfinite(target) + return torch.ones_like(target, dtype=torch.bool) + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + """Return absolute residuals, zeroing invalid entries.""" + return torch.where(valid, pred - target, torch.zeros_like(pred)).abs() + + def extra_repr(self) -> str: + """Human-readable hyperparameter summary for :class:`nn.Module`'s repr.""" + return ( + f"target_key={self.target_key!r}, " + 
f"prediction_key={self.prediction_key!r}, " + f"per_atom={self.per_atom!r}, " + f"ignore_nonfinite={self.ignore_nonfinite!r}" + ) + + +class EnergyHuberLoss(BaseLossFunction): + """Huber loss on total energy or energy per atom. + + With ``per_atom=True``, energies are divided by each graph's atom + count before the Huber loss is applied. The final reduction averages + labeled structures rather than atom-count weighting the per-graph + values. + + Parameters + ---------- + target_key : str, default "energy" + Target container key for the target tensor. + prediction_key : str, default "predicted_energy" + Prediction container key for the model output. + per_atom : bool, default True + Divide prediction and target by ``num_nodes_per_graph`` before + computing Huber residuals. + delta : float, default 0.01 + Positive transition point between quadratic and linear Huber regimes. + ignore_nonfinite : bool, default True + When ``True``, target entries that are ``NaN`` or infinite are + excluded from both loss value and gradient using + :func:`torch.isfinite`. 
+ """ + + requires_eval_grad: bool = False + + def __init__( + self, + *, + target_key: str = "energy", + prediction_key: str = "predicted_energy", + per_atom: bool = True, + delta: float = 0.01, + ignore_nonfinite: bool = True, + ) -> None: + """Configure energy Huber loss keys and threshold.""" + super().__init__() + self.target_key = target_key + self.prediction_key = prediction_key + self.per_atom = per_atom + self.ignore_nonfinite = ignore_nonfinite + self.delta = float(delta) + + def normalize( + self, + pred: torch.Tensor, + target: torch.Tensor, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor, ReductionContext]: + """Divide by atom counts when ``per_atom=True``.""" + ctx = ReductionContext() + if not self.per_atom: + return pred, target, ctx + batch: Batch | None = kwargs.get("batch") + num_nodes_per_graph = kwargs.get("num_nodes_per_graph") + if batch is not None and num_nodes_per_graph is None: + num_nodes_per_graph = getattr(batch, "num_nodes_per_graph", None) + counts = _node_counts(num_nodes_per_graph, pred).reshape( + (-1,) + (1,) * (pred.ndim - 1) + ) + return pred / counts, target / counts, ctx + + def mask( + self, + pred: torch.Tensor, + target: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Exclude non-finite target entries when ``ignore_nonfinite=True``.""" + if self.ignore_nonfinite: + return torch.isfinite(target) + return torch.ones_like(target, dtype=torch.bool) + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + """Return elementwise Huber losses, zeroing invalid entries.""" + residual = torch.where(valid, pred - target, torch.zeros_like(pred)) + return _huber_loss(residual, self.delta) + + def extra_repr(self) -> str: + """Human-readable hyperparameter summary for :class:`nn.Module`'s repr.""" + return ( + f"target_key={self.target_key!r}, " + f"prediction_key={self.prediction_key!r}, " + f"per_atom={self.per_atom!r}, " + 
f"ignore_nonfinite={self.ignore_nonfinite!r}, " + f"delta={self.delta!r}" + ) + + +class ForceMSELoss(BaseLossFunction): + """Mean-squared-error loss on per-atom forces. + + Forces enter this loss as per-atom vector quantities, unlike energy + totals. The ``normalize_by_atom_count`` flag therefore does not + convert total quantities into per-atom units; it controls how + per-atom force residuals are reduced across a mixed-size batch. + + Dense force tensors use shape ``(V, 3)``. Padded force tensors use + shape ``(B, V_max, 3)`` and ignore padding entries according to + ``num_nodes_per_graph`` supplied either as ``(B,)`` counts or + a ``(B, V_max)`` node mask. + + - ``normalize_by_atom_count=True`` (default): per-graph mean of + squared-component error, then mean over graphs. This is a + graph-balanced reduction: small and large graph contributions + are somewhat normalized. + - ``normalize_by_atom_count=False``: elementwise mean over all valid + force components. This is an atom/component-weighted reduction: + graph contributions are proportional to their number of valid + force components. + + Tensor Contract + --------------- + pred, target : Forces | Float[torch.Tensor, "B V_max 3"] + Dense per-node forces of shape ``(V, 3)`` or padded per-graph + forces of shape ``(B, V_max, 3)``. Shape validation requires + exact equality. + batch_idx : BatchIndices, optional + Required for dense ``(V, 3)`` forces when + ``normalize_by_atom_count=True``. Ignored for padded forces. + num_nodes_per_graph : Integer[torch.Tensor, "B"] | Bool[torch.Tensor, "B V_max"], optional + Required for padded ``(B, V_max, 3)`` forces. May be explicit + per-graph counts or a padded node-validity mask. + + Parameters + ---------- + target_key : str, default "forces" + Target container key for the target tensor. + prediction_key : str, default "predicted_forces" + Prediction container key for the model output. 
+ normalize_by_atom_count : bool, default True + Control the batch reduction for already-per-atom force + residuals. ``True`` computes a graph-balanced mean by dividing + each graph's force-error sum by its valid component count before + averaging over graphs. ``False`` computes one global elementwise + mean over all valid force components. + ignore_nonfinite : bool, default False + When ``True``, target force components that are ``NaN`` or + infinite are excluded from both loss value and gradient using + :func:`torch.isfinite`. Intended for batches where some + atoms/graphs lack force labels. Implemented with branch-free + tensor ops for ``torch.compile`` compatibility. A graph whose + entire force tensor is non-finite contributes ``0.0`` to the + loss. + """ + + requires_eval_grad: bool = True + + def __init__( + self, + *, + target_key: str = "forces", + prediction_key: str = "predicted_forces", + normalize_by_atom_count: bool = True, + ignore_nonfinite: bool = False, + ) -> None: + """Configure attribute keys and per-graph normalization.""" + super().__init__() + self.target_key = target_key + self.prediction_key = prediction_key + self.normalize_by_atom_count = normalize_by_atom_count + self.ignore_nonfinite = ignore_nonfinite + + def mask( + self, + pred: torch.Tensor, + target: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Return component-level validity mask for dense or padded forces.""" + num_nodes_per_graph = kwargs.get("num_nodes_per_graph") + batch: Batch | None = kwargs.get("batch") + if batch is not None and pred.ndim == 3 and num_nodes_per_graph is None: + num_nodes_per_graph = getattr(batch, "num_nodes_per_graph", None) + return self._valid_force_components(pred, target, num_nodes_per_graph) + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + """Return squared force-component residuals, zeroing invalid entries.""" + residual = 
torch.where(valid, pred - target, torch.zeros_like(pred)) + return residual.pow(2) + + def reduce( + self, + residual: torch.Tensor, + valid: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Reduce force-component residuals to a scalar loss.""" + valid_components = valid.to(dtype=residual.dtype) + batch: Batch | None = kwargs.get("batch") + batch_idx: BatchIndices | None = kwargs.get("batch_idx") + num_graphs: int | None = kwargs.get("num_graphs") + if batch is not None and self.normalize_by_atom_count and residual.ndim == 2: + if batch_idx is None: + batch_idx = getattr(batch, "batch_idx", None) + if num_graphs is None: + num_graphs = getattr(batch, "num_graphs", None) + if not self.normalize_by_atom_count: + if residual.ndim == 3: + per_graph_num = residual.sum(dim=(-2, -1)) + per_graph_den = valid_components.sum(dim=(-2, -1)) + self.per_sample_loss = ( + per_graph_num / per_graph_den.clamp_min(1.0) + ).detach() + return per_graph_num.sum() / per_graph_den.sum().clamp_min(1.0) + return residual.sum() / valid_components.sum().clamp_min(1.0) + per_graph_num, per_graph_den = self._per_graph_force_terms( + residual, valid_components, batch_idx, num_graphs + ) + per_sample = per_graph_num / per_graph_den.clamp_min(1.0) + self.per_sample_loss = per_sample.detach() + return per_sample.mean() + + @overload + def _valid_force_components( # noqa: F811 + self, + pred: Forces, # noqa: ARG002 + target: Forces, + num_nodes_per_graph: object, # noqa: ARG002 + ) -> _DenseForceMask: + """Return a valid-component mask for dense forces.""" + valid = torch.ones_like(target, dtype=torch.bool) + if self.ignore_nonfinite: + valid = valid & torch.isfinite(target) + return valid + + @overload + def _valid_force_components( # noqa: F811 + self, + pred: _PaddedForces, + target: _PaddedForces, + num_nodes_per_graph: _NodeCounts | _PaddedNodeMask | None, + ) -> _PaddedForceMask: + """Return a valid-component mask for padded forces.""" + node_mask = 
_padded_node_mask(num_nodes_per_graph, pred, pred.shape[1]) + valid = node_mask.unsqueeze(-1).expand_as(pred) + if self.ignore_nonfinite: + valid = valid & torch.isfinite(target) + return valid + + @dispatch + def _valid_force_components( # noqa: F811 + self, pred: object, target: object, num_nodes_per_graph: object + ) -> _DenseForceMask | _PaddedForceMask: + pass + + @overload + def _per_graph_force_terms( # noqa: F811 + self, + squared_error: Forces, + valid_components: Forces, + batch_idx: BatchIndices | None, + num_graphs: int | None, + ) -> tuple[_PerGraphValues, _PerGraphValues]: + """Return dense-force per-graph numerators and denominators.""" + batch_idx = _require_metadata(batch_idx, "batch_idx", loss_name="ForceMSELoss") + num_graphs = _require_metadata( + num_graphs, "num_graphs", loss_name="ForceMSELoss" + ) + per_atom_se = squared_error.sum(dim=-1) + per_atom_valid = valid_components.sum(dim=-1) + per_graph_se_sum = per_graph_sum(per_atom_se, batch_idx, num_graphs=num_graphs) + per_graph_valid = per_graph_sum( + per_atom_valid, batch_idx, num_graphs=num_graphs + ) + return per_graph_se_sum, per_graph_valid + + @overload + def _per_graph_force_terms( # noqa: F811 + self, + squared_error: _PaddedForces, + valid_components: _PaddedForces, + batch_idx: object, # noqa: ARG002 + num_graphs: object, # noqa: ARG002 + ) -> tuple[_PerGraphValues, _PerGraphValues]: + """Return padded-force per-graph numerators and denominators.""" + return squared_error.sum(dim=(-2, -1)), valid_components.sum(dim=(-2, -1)) + + @dispatch + def _per_graph_force_terms( # noqa: F811 + self, + squared_error: object, + valid_components: object, + batch_idx: object, + num_graphs: object, + ) -> tuple[_PerGraphValues, _PerGraphValues]: + pass + + def extra_repr(self) -> str: + """Human-readable hyperparameter summary for :class:`nn.Module`'s repr.""" + return ( + f"target_key={self.target_key!r}, " + f"prediction_key={self.prediction_key!r}, " + 
f"normalize_by_atom_count={self.normalize_by_atom_count!r}, " + f"ignore_nonfinite={self.ignore_nonfinite!r}" + ) + + +class ForceHuberLoss(ForceMSELoss): + """Huber loss on per-component force residuals. + + Inherits force masking and reduction from :class:`ForceMSELoss`. + + Parameters + ---------- + target_key : str, default "forces" + Target container key for the target tensor. + prediction_key : str, default "predicted_forces" + Prediction container key for the model output. + normalize_by_atom_count : bool, default False + Control the batch reduction for already-per-atom force + residuals. ``True`` computes a graph-balanced mean by dividing + each graph's force-error sum by its valid component count before + averaging over graphs. ``False`` computes one global elementwise + mean over all valid force components. + delta : float, default 0.01 + Positive transition point between quadratic and linear Huber regimes. + ignore_nonfinite : bool, default True + When ``True``, target force components that are ``NaN`` or + infinite are excluded from both loss value and gradient using + :func:`torch.isfinite`. 
+ """ + + def __init__( + self, + *, + target_key: str = "forces", + prediction_key: str = "predicted_forces", + normalize_by_atom_count: bool = False, + delta: float = 0.01, + ignore_nonfinite: bool = True, + ) -> None: + """Configure force Huber loss keys, threshold, and reduction.""" + super().__init__( + target_key=target_key, + prediction_key=prediction_key, + normalize_by_atom_count=normalize_by_atom_count, + ignore_nonfinite=ignore_nonfinite, + ) + self.delta = float(delta) + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + """Return componentwise Huber force losses, zeroing invalid entries.""" + residual = torch.where(valid, pred - target, torch.zeros_like(pred)) + return _huber_loss(residual, self.delta) + + def extra_repr(self) -> str: + """Human-readable hyperparameter summary for :class:`nn.Module`'s repr.""" + return f"{super().extra_repr()}, delta={self.delta!r}" + + +class ForceL2NormLoss(BaseLossFunction): + """Mean per-atom force-vector L2 loss. + + The per-atom residual is the vector norm + ``torch.linalg.vector_norm(pred - target, ord=2, dim=-1)``. Dense + ``(V, 3)`` inputs can be graph-balanced with ``batch_idx`` and + ``num_graphs``. Padded ``(B, V_max, 3)`` inputs require + ``num_nodes_per_graph`` counts or a node mask so padding can be + excluded from the atom-level reduction. + + Parameters + ---------- + target_key : str, default "forces" + Target container key for the target tensor. + prediction_key : str, default "predicted_forces" + Prediction container key for the model output. + normalize_by_atom_count : bool, default True + When ``True``, compute a mean atom L2 norm per graph, then mean + over graphs. When ``False``, compute one global mean over valid + atom L2 norms. + ignore_nonfinite : bool, default True + When ``True``, atoms whose target vector contains ``NaN`` or + infinity are excluded from both loss value and gradient using + :func:`torch.isfinite`. 
+ """ + + def __init__( + self, + *, + target_key: str = "forces", + prediction_key: str = "predicted_forces", + normalize_by_atom_count: bool = True, + ignore_nonfinite: bool = True, + ) -> None: + """Configure attribute keys and force L2 semantics.""" + super().__init__() + self.target_key = target_key + self.prediction_key = prediction_key + self.normalize_by_atom_count = normalize_by_atom_count + self.ignore_nonfinite = ignore_nonfinite + + def mask( + self, + pred: torch.Tensor, + target: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Return atom-level validity mask (not component-level) for forces. + + The mask has shape ``(V,)`` for dense or ``(B, V_max)`` for + padded forces — one validity flag per atom, not per component. + """ + num_nodes_per_graph = kwargs.get("num_nodes_per_graph") + batch: Batch | None = kwargs.get("batch") + if batch is not None and pred.ndim == 3 and num_nodes_per_graph is None: + num_nodes_per_graph = getattr(batch, "num_nodes_per_graph", None) + return self._valid_force_atoms(pred, target, num_nodes_per_graph) + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + """Return per-atom L2 norm of force residuals, zeroing invalid atoms.""" + valid_vectors = valid.unsqueeze(-1) + residual = torch.where(valid_vectors, pred - target, torch.zeros_like(pred)) + return torch.linalg.vector_norm(residual, ord=2, dim=-1) + + def reduce( + self, + residual: torch.Tensor, + valid: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Reduce per-atom L2 norms to a scalar loss.""" + atom_weights = valid.to(dtype=residual.dtype) + batch: Batch | None = kwargs.get("batch") + batch_idx: BatchIndices | None = kwargs.get("batch_idx") + num_graphs: int | None = kwargs.get("num_graphs") + if batch is not None and self.normalize_by_atom_count and residual.ndim == 1: + if batch_idx is None: + batch_idx = getattr(batch, 
"batch_idx", None) + if num_graphs is None: + num_graphs = getattr(batch, "num_graphs", None) + if not self.normalize_by_atom_count: + if residual.ndim == 2: + per_graph_counts = atom_weights.sum(dim=-1).clamp_min(1.0) + self.per_sample_loss = ( + residual.sum(dim=-1) / per_graph_counts + ).detach() + return residual.sum() / atom_weights.sum().clamp_min(1.0) + per_graph_sum_l2, per_graph_counts = self._per_graph_atom_terms( + residual, atom_weights, batch_idx, num_graphs + ) + per_sample = per_graph_sum_l2 / per_graph_counts.clamp_min(1.0) + self.per_sample_loss = per_sample.detach() + return per_sample.mean() + + def _valid_force_atoms( + self, + pred: _ForceTensor, + target: _ForceTensor, + num_nodes_per_graph: _NodeCounts | _PaddedNodeMask | None, + ) -> Bool[torch.Tensor, "V"] | _PaddedNodeMask: + """Return atom-validity mask for dense or padded forces.""" + if pred.ndim == 2: + if self.ignore_nonfinite: + return torch.isfinite(target).all(dim=-1) + return torch.ones_like(target[..., 0], dtype=torch.bool) + node_mask = _padded_node_mask(num_nodes_per_graph, pred, pred.shape[1]) + if self.ignore_nonfinite: + return node_mask & torch.isfinite(target).all(dim=-1) + return node_mask + + def _per_graph_atom_terms( + self, + per_atom_values: Float[torch.Tensor, "..."], + atom_weights: Float[torch.Tensor, "..."], + batch_idx: BatchIndices | None, + num_graphs: int | None, + ) -> tuple[_PerGraphValues, _PerGraphValues]: + """Return per-graph atom-value sums and valid atom counts.""" + if per_atom_values.ndim == 1: + batch_idx = _require_metadata( + batch_idx, "batch_idx", loss_name="ForceL2NormLoss" + ) + num_graphs = _require_metadata( + num_graphs, "num_graphs", loss_name="ForceL2NormLoss" + ) + return ( + per_graph_sum(per_atom_values, batch_idx, num_graphs=num_graphs), + per_graph_sum(atom_weights, batch_idx, num_graphs=num_graphs), + ) + return per_atom_values.sum(dim=-1), atom_weights.sum(dim=-1) + + def extra_repr(self) -> str: + """Human-readable hyperparameter 
summary for :class:`nn.Module`'s repr.""" + return ( + f"target_key={self.target_key!r}, " + f"prediction_key={self.prediction_key!r}, " + f"normalize_by_atom_count={self.normalize_by_atom_count!r}, " + f"ignore_nonfinite={self.ignore_nonfinite!r}" + ) + + +class StressMSELoss(BaseLossFunction): + """Mean-squared-error loss on the per-graph stress tensor. + + Both pred and target are shape ``(B, 3, 3)``. The loss is the mean + of the per-graph squared-Frobenius residual, computed via + :func:`~nvalchemi.training.losses.reductions.frobenius_mse`. + + Tensor Contract + --------------- + pred, target : Stress + Per-graph stress tensors of shape ``(B, 3, 3)``. Shape + validation requires exact equality. + + Parameters + ---------- + target_key : str, default "stress" + Target container key for the target tensor. + prediction_key : str, default "predicted_stress" + Prediction container key for the model output. + ignore_nonfinite : bool, default False + When ``True``, target stress components that are ``NaN`` or + infinite are excluded from both loss value and gradient using + :func:`torch.isfinite`. Intended for inputs that mix samples + with and without stress labels. Implemented with branch-free + tensor ops for ``torch.compile`` compatibility. A graph whose + entire stress tensor is non-finite contributes ``0.0`` to the + loss. 
+ """ + + requires_eval_grad: bool = True + + def __init__( + self, + *, + target_key: str = "stress", + prediction_key: str = "predicted_stress", + ignore_nonfinite: bool = False, + ) -> None: + """Configure attribute keys for target and prediction.""" + super().__init__() + self.target_key = target_key + self.prediction_key = prediction_key + self.ignore_nonfinite = ignore_nonfinite + + def mask( + self, + pred: torch.Tensor, + target: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Exclude non-finite stress components when ``ignore_nonfinite=True``.""" + if self.ignore_nonfinite: + return torch.isfinite(target) + return torch.ones_like(target, dtype=torch.bool) + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + """Return squared stress residuals, zeroing invalid entries.""" + residual = torch.where(valid, pred - target, torch.zeros_like(pred)) + return residual.pow(2) + + def reduce( + self, + residual: torch.Tensor, + valid: torch.Tensor, + ctx: ReductionContext, + **kwargs: Any, + ) -> torch.Tensor: + """Reduce per-component stress residuals to a per-graph mean scalar.""" + per_graph_num = residual.sum(dim=(-2, -1)) + per_graph_den = valid.to(dtype=residual.dtype).sum(dim=(-2, -1)).clamp_min(1.0) + per_sample = per_graph_num / per_graph_den + self.per_sample_loss = per_sample.detach() + return per_sample.mean() + + def extra_repr(self) -> str: + """Human-readable hyperparameter summary for :class:`nn.Module`'s repr.""" + return ( + f"target_key={self.target_key!r}, " + f"prediction_key={self.prediction_key!r}, " + f"ignore_nonfinite={self.ignore_nonfinite!r}" + ) + + +class StressHuberLoss(StressMSELoss): + """Huber loss on per-graph stress tensors. + + Inherits stress masking and reduction from :class:`StressMSELoss`. + + Parameters + ---------- + target_key : str, default "stress" + Target container key for the target tensor. 
+ prediction_key : str, default "predicted_stress" + Prediction container key for the model output. + delta : float, default 0.01 + Positive transition point between quadratic and linear Huber regimes. + ignore_nonfinite : bool, default True + When ``True``, target stress components that are ``NaN`` or + infinite are excluded from both loss value and gradient using + :func:`torch.isfinite`. + """ + + def __init__( + self, + *, + target_key: str = "stress", + prediction_key: str = "predicted_stress", + delta: float = 0.01, + ignore_nonfinite: bool = True, + ) -> None: + """Configure stress Huber loss keys and threshold.""" + super().__init__( + target_key=target_key, + prediction_key=prediction_key, + ignore_nonfinite=ignore_nonfinite, + ) + self.delta = float(delta) + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, + ) -> torch.Tensor: + """Return componentwise Huber stress losses, zeroing invalid entries.""" + residual = torch.where(valid, pred - target, torch.zeros_like(pred)) + return _huber_loss(residual, self.delta) + + def extra_repr(self) -> str: + """Human-readable hyperparameter summary for :class:`nn.Module`'s repr.""" + return f"{super().extra_repr()}, delta={self.delta!r}" diff --git a/nvalchemi/training/optimizers.py b/nvalchemi/training/optimizers.py new file mode 100644 index 00000000..aed867f3 --- /dev/null +++ b/nvalchemi/training/optimizers.py @@ -0,0 +1,407 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Optimizer configuration and stepping helpers for training strategies.""" + +from __future__ import annotations + +import inspect +from collections.abc import Callable, Iterable, Mapping +from typing import Any, TypeAlias + +import torch +from pydantic import ( + BaseModel, + ConfigDict, + Field, + model_validator, +) +from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau + +from nvalchemi._serialization import SerializableClass, SerializableOptionalClass +from nvalchemi.training._spec import ( + BaseSpec, + create_model_spec, +) + +OptSchedPair: TypeAlias = tuple[torch.optim.Optimizer, LRScheduler | None] +SchedulerMetricAdapter: TypeAlias = Callable[[dict[str, Any]], float] | str | None + +_DEFAULT_METRIC_KEY = "total_loss" + +__all__ = [ + "OptSchedPair", + "OptimizerConfig", + "SchedulerMetricAdapter", + "setup_optimizers", + "step_lr_schedulers", + "step_metric_schedulers", + "step_optimizers", + "zero_gradients", +] + + +def _check_kwargs(cls: type, kwargs: Mapping[str, Any], label: str) -> None: + """Raise ``ValueError`` if ``kwargs`` are not accepted by ``cls.__init__``.""" + try: + sig = inspect.signature(cls.__init__) + except (TypeError, ValueError): + return + try: + sig.bind_partial(None, None, **kwargs) + except TypeError as exc: + accepted = { + name + for name, param in sig.parameters.items() + if param.kind + not in { + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + } + } + invalid = sorted(set(kwargs) - accepted) + if not invalid: + raise ValueError( + f"Invalid {label} kwargs for {cls.__name__}: 
{exc}" + ) from None + raise ValueError( + f"Invalid {label} kwargs for {cls.__name__}: {invalid}" + ) from None + + +def _normalize_optimizer_configs( + value: Any, + *, + single_model_input: bool, +) -> Any: + """Normalize accepted optimizer config inputs to named lists.""" + if isinstance(value, OptimizerConfig): + if not single_model_input and value is not None: + raise ValueError( + "Unkeyed optimizer_configs require single-model input; pass " + "{'model_name': [OptimizerConfig(...)]} for named models." + ) + return {"main": [value]} + if isinstance(value, list): + if not single_model_input: + raise ValueError( + "Unkeyed optimizer_configs require single-model input; pass " + "{'model_name': [...]} for named models." + ) + return {"main": value} + if isinstance(value, dict): + if set(value) == {0}: + return {"main": value[0]} + return value + return value + + +class OptimizerConfig(BaseModel): + """Declarative optimizer + optional LR-scheduler bundle. + + Kwargs are validated against each class's ``__init__`` at construction + time so mistakes surface before training starts. Build the concrete + ``(optimizer, scheduler)`` pair via :meth:`build`. + + Attributes + ---------- + optimizer_cls : type[torch.optim.Optimizer] + Optimizer class; ``optimizer_kwargs`` must match its signature. + optimizer_kwargs : dict[str, Any] + scheduler_cls : type | None + Optional LR scheduler. Time-based schedulers (``StepLR``, + ``CosineAnnealingLR``, etc.) step every optimizer step. + Metric-driven schedulers (``ReduceLROnPlateau`` and subclasses) + step only at validation checkpoints via + :func:`step_metric_schedulers`. + scheduler_kwargs : dict[str, Any] + Must be empty unless ``scheduler_cls`` is set. + scheduler_metric_adapter : Callable[[dict], float] | str | None + How a metric-driven scheduler (``ReduceLROnPlateau``) extracts + its scalar metric from the validation summary dict. 
A ``str`` + is treated as a key lookup into the summary; a callable + receives the whole summary dict and returns a ``float``; + ``None`` uses the default extractor (see + :func:`_extract_scheduler_metric`). + + Examples + -------- + >>> import torch + >>> cfg = OptimizerConfig( + ... optimizer_cls=torch.optim.Adam, + ... optimizer_kwargs={"lr": 1e-3}, + ... scheduler_cls=torch.optim.lr_scheduler.StepLR, + ... scheduler_kwargs={"step_size": 10, "gamma": 0.1}, + ... ) + """ + + optimizer_cls: SerializableClass + optimizer_kwargs: dict[str, Any] = Field(default_factory=dict) + scheduler_cls: SerializableOptionalClass = None + scheduler_kwargs: dict[str, Any] = Field(default_factory=dict) + scheduler_metric_adapter: SchedulerMetricAdapter = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @model_validator(mode="after") + def _validate_kwargs(self) -> OptimizerConfig: + """Validate optimizer/scheduler kwargs against their __init__ signatures.""" + _check_kwargs(self.optimizer_cls, self.optimizer_kwargs, "optimizer") + if self.scheduler_cls is None: + if self.scheduler_kwargs: + raise ValueError( + "scheduler_kwargs provided but scheduler_cls is None; " + "set scheduler_cls or remove scheduler_kwargs. " + f"Got: {sorted(self.scheduler_kwargs)}" + ) + if self.scheduler_metric_adapter is not None: + raise ValueError( + "scheduler_metric_adapter provided but scheduler_cls is None." + ) + else: + _check_kwargs(self.scheduler_cls, self.scheduler_kwargs, "scheduler") + return self + + def build(self, params: Iterable[torch.nn.Parameter]) -> OptSchedPair: + """Instantiate the optimizer and optional scheduler for ``params``. 
+ + Parameters + ---------- + params : Iterable[torch.nn.Parameter] + + Returns + ------- + tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler | None] + """ + optimizer = self.optimizer_cls(params, **self.optimizer_kwargs) + scheduler = ( + self.scheduler_cls(optimizer, **self.scheduler_kwargs) + if self.scheduler_cls is not None + else None + ) + return optimizer, scheduler + + def to_spec(self) -> BaseSpec: + """Serialize to a :class:`BaseSpec` via :func:`create_model_spec`. + + Returns + ------- + BaseSpec + A spec instance that rebuilds the original :class:`OptimizerConfig`. + """ + return create_model_spec(type(self), **self.model_dump()) + + @classmethod + def from_spec(cls, spec: BaseSpec) -> OptimizerConfig: + """Rebuild an :class:`OptimizerConfig` from a :class:`BaseSpec`. + + Parameters + ---------- + spec : BaseSpec + A spec produced by :meth:`to_spec`. + + Returns + ------- + OptimizerConfig + A freshly validated instance equivalent to the original. + + Raises + ------ + TypeError + If ``spec`` does not build an :class:`OptimizerConfig`. + """ + instance = spec.build() + if not isinstance(instance, cls): + raise TypeError( + f"Spec at {spec.cls_path!r} built {type(instance).__name__}, " + f"expected {cls.__name__}." + ) + return instance + + +def setup_optimizers( + models: torch.nn.Module | dict[str, torch.nn.Module] | torch.nn.ModuleDict, + optimizer_configs: OptimizerConfig + | list[OptimizerConfig] + | dict[str, list[OptimizerConfig]], +) -> dict[str, list[OptSchedPair]]: + """Build optimizers and schedulers for configured model names. 
def zero_gradients(opts: Iterable[torch.optim.Optimizer]) -> None:
    """Clear gradients on every optimizer in ``opts``.

    Each optimizer is called with ``zero_grad(set_to_none=True)``.

    Parameters
    ----------
    opts : Iterable[torch.optim.Optimizer]
        Optimizers to reset; consumed once, in iteration order.
    """
    for optimizer in opts:
        optimizer.zero_grad(set_to_none=True)
+ + Metric-driven schedulers (``ReduceLROnPlateau`` and subclasses) + require a scalar metric argument for each ``step()`` call and + are therefore stepped only at validation checkpoints, not on + every optimizer step. + + Parameters + ---------- + scheduler : LRScheduler | ReduceLROnPlateau | None + Scheduler instance to check. + + Returns + ------- + bool + ``True`` when ``scheduler`` is an instance of + ``ReduceLROnPlateau``. + """ + return isinstance(scheduler, ReduceLROnPlateau) + + +def _extract_scheduler_metric( + summary: dict[str, Any], + adapter: SchedulerMetricAdapter, +) -> float: + """Extract a scalar metric from a validation summary for a metric-driven scheduler. + + Parameters + ---------- + summary : dict[str, Any] + Validation summary dictionary produced by + :meth:`~nvalchemi.training._validation._LossAccumulator.summary`. + adapter : SchedulerMetricAdapter + Extraction strategy. A callable receives the full summary and + returns a float. A ``str`` is used as a direct key lookup. When + ``None``, the default key ``"total_loss"`` is used (the + aggregate/total validation loss). + + Returns + ------- + float + Scalar metric value. + + Raises + ------ + KeyError + When a string adapter (or the default key) is not present in + ``summary``. + """ + if callable(adapter): + return float(adapter(summary)) + key = adapter if isinstance(adapter, str) else _DEFAULT_METRIC_KEY + if key not in summary: + available = sorted(summary.keys()) + raise KeyError( + f"Scheduler metric key {key!r} not found in validation summary; " + f"available keys: {available}" + ) + return float(summary[key]) + + +def step_lr_schedulers( + schedulers: Iterable[LRScheduler | ReduceLROnPlateau | None], +) -> None: + """Call ``step()`` on each non-``None`` time-based scheduler. + + Metric-driven schedulers (``ReduceLROnPlateau``) are skipped here; + they step at validation checkpoints via :func:`step_metric_schedulers`. 
+ + Parameters + ---------- + schedulers : Iterable[LRScheduler | ReduceLROnPlateau | None] + """ + for scheduler in schedulers: + if scheduler is None or _is_metric_driven(scheduler): + continue + scheduler.step() + + +def step_metric_schedulers( + schedulers: Iterable[LRScheduler | ReduceLROnPlateau | None], + adapters: Iterable[SchedulerMetricAdapter], + summary: dict[str, Any], +) -> None: + """Step metric-driven schedulers using a validation summary. + + Zips ``schedulers`` with ``adapters`` (positional correspondence + must match) and calls ``scheduler.step(metric)`` for each + metric-driven scheduler. Non-metric-driven and ``None`` + schedulers are skipped. + + Parameters + ---------- + schedulers : Iterable[LRScheduler | ReduceLROnPlateau | None] + Flat list of schedulers in the same positional order as + ``adapters``. + adapters : Iterable[SchedulerMetricAdapter] + Per-scheduler metric extraction adapters. + summary : dict[str, Any] + Validation summary dictionary. + """ + for scheduler, adapter in zip(schedulers, adapters, strict=True): + if scheduler is None or not _is_metric_driven(scheduler): + continue + scheduler.step(_extract_scheduler_metric(summary, adapter)) diff --git a/nvalchemi/training/runtime.py b/nvalchemi/training/runtime.py new file mode 100644 index 00000000..e92807d3 --- /dev/null +++ b/nvalchemi/training/runtime.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@contextmanager
def freeze_unconfigured_models(
    models: dict[str, torch.nn.Module] | torch.nn.ModuleDict,
    optimizer_configs: Mapping[str, object],
) -> Iterator[None]:
    """Temporarily eval/freeze models omitted from optimizer configs.

    On exit (normal or exceptional) every touched module gets back the exact
    ``training`` flag it had on entry, and every touched parameter gets back
    its original ``requires_grad`` value.

    Parameters
    ----------
    models : dict[str, torch.nn.Module] | torch.nn.ModuleDict
        Named models participating in a training run.
    optimizer_configs : Mapping[str, object]
        Optimizer configuration keyed by model name. Models absent from this
        mapping are temporarily switched to eval mode and have all parameters
        marked ``requires_grad=False``.

    Yields
    ------
    None
        Control while omitted models are frozen.
    """
    # Keyed by object identity; ``setdefault`` keeps the first value seen, so
    # modules or parameters shared between omitted models keep their true
    # pre-freeze state.
    module_modes: dict[torch.nn.Module, bool] = {}
    grad_flags: dict[torch.nn.Parameter, bool] = {}
    frozen_roots: list[torch.nn.Module] = []
    try:
        # Freezing happens inside the ``try`` so a failure part-way through
        # still restores whatever was already changed.
        for key, model in models.items():
            if key in optimizer_configs:
                continue
            for module in model.modules():
                module_modes.setdefault(module, module.training)
            for param in model.parameters():
                grad_flags.setdefault(param, param.requires_grad)
                param.requires_grad_(False)
            frozen_roots.append(model)
            model.eval()
        yield
    finally:
        # ``train(mode)`` applies ``mode`` to the whole subtree, which would
        # wipe out submodules deliberately kept in a different mode; call it
        # on the root first, then put back every recorded per-module flag.
        for model in frozen_roots:
            model.train(module_modes[model])
        for module, was_training in module_modes.items():
            module.training = was_training
        for param, required_grad in grad_flags.items():
            param.requires_grad_(required_grad)
def configure_dataloader(
    dataset: Any,
    *,
    batch_size: int,
    shuffle: bool | None = None,
    sampler: Any = None,
    batch_sampler: Any = None,
    collate_fn: Callable | None = None,
    **dl_kwargs: Any,
) -> DataLoader:
    """Thin wrapper around :class:`~torch.utils.data.DataLoader`.

    Parameters
    ----------
    dataset : Any
        Dataset forwarded to ``DataLoader``.
    batch_size : int
        Forwarded to ``DataLoader``. ``DataLoader`` treats ``batch_sampler``
        as mutually exclusive with ``batch_size``, so pass ``batch_size=1``
        (its default) together with a ``batch_sampler``.
    shuffle : bool | None, optional
        When ``None``, resolves to ``True`` only if neither ``sampler`` nor
        ``batch_sampler`` is given, and to ``False`` otherwise. Passing
        ``True`` with ``sampler`` raises ``ValueError``.
    sampler : Any, optional
        Optional sample-ordering object forwarded to ``DataLoader``.
    batch_sampler : Any, optional
        Optional batch sampler forwarded to ``DataLoader``.
    collate_fn : Callable | None, optional
        Forwarded to ``DataLoader``.
    **dl_kwargs : Any
        Forwarded to ``DataLoader``.

    Returns
    -------
    torch.utils.data.DataLoader

    Raises
    ------
    ValueError
        If ``shuffle=True`` and ``sampler`` are both provided.
    """
    if shuffle is True and sampler is not None:
        raise ValueError("shuffle=True is incompatible with sampler.")
    if shuffle is None:
        # ``DataLoader`` rejects ``shuffle=True`` alongside either a sampler
        # or a batch sampler, so the implicit default must be False whenever
        # the caller supplies its own ordering object.
        resolved_shuffle = sampler is None and batch_sampler is None
    else:
        resolved_shuffle = shuffle
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=resolved_shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        collate_fn=collate_fn,
        **dl_kwargs,
    )
+ + Parameters + ---------- + models : torch.nn.Module | dict[str, torch.nn.Module] | torch.nn.ModuleDict + strategy : str, optional + + Returns + ------- + torch.nn.Module | dict[str, torch.nn.Module] | torch.nn.ModuleDict + + Raises + ------ + NotImplementedError + For any strategy other than ``"none"``. + """ + if strategy == "none": + return models + raise NotImplementedError( + f"Unsupported parallelism strategy: {strategy!r}; " + "supported strategies: ['none']" + ) diff --git a/nvalchemi/training/strategy.py b/nvalchemi/training/strategy.py new file mode 100644 index 00000000..59a4345f --- /dev/null +++ b/nvalchemi/training/strategy.py @@ -0,0 +1,1642 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Training strategy lifecycle and default forward-pass helper. + +``TrainingStrategy`` wires one named model (``"main"``) or a dictionary-like +collection of named models through a user-supplied ``training_fn``. +Single-model strategies call ``training_fn(model, batch)``; named-model +strategies call ``training_fn(models, batch)`` for distillation or multi-model +workflows. +Models omitted from optimizer configs are temporarily set to eval mode and +frozen during ``run``. 
Named-model training functions that use omitted models as +teacher/auxiliary networks must run those forward passes under +``torch.no_grad()`` or detach returned tensors unless autograd through those +outputs is intentionally required. + +Loss hooks see live autograd-connected losses from ``AFTER_LOSS`` through +``BEFORE_BACKWARD``. From ``AFTER_BACKWARD`` onward the hook context carries +detached loss tensors so logging hooks do not accidentally retain graphs. +""" + +from __future__ import annotations + +import dataclasses +import itertools +import math +import warnings +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from contextlib import AbstractContextManager, nullcontext +from pathlib import Path +from types import TracebackType +from typing import TYPE_CHECKING, Annotated, Any + +import torch +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + SkipValidation, + field_validator, + model_validator, +) +from torch import nn +from torch.optim.lr_scheduler import LRScheduler + +from nvalchemi._serialization import _import_cls +from nvalchemi._typing import ModelOutputs +from nvalchemi.distributed import DistributedManager +from nvalchemi.hooks._context import TrainContext +from nvalchemi.hooks._protocol import Hook +from nvalchemi.hooks._registry import HookRegistryMixin +from nvalchemi.models.base import BaseModelMixin +from nvalchemi.training import _spec_utils as strategy_spec +from nvalchemi.training import _strategy_validation as strategy_validation +from nvalchemi.training import _validation +from nvalchemi.training._spec import create_model_spec +from nvalchemi.training._stages import TrainingStage +from nvalchemi.training._validation import ValidationConfig +from nvalchemi.training.distributed import get_rank as get_distributed_rank +from nvalchemi.training.hooks import TrainingUpdateHook, TrainingUpdateOrchestrator +from nvalchemi.training.hooks.mixed_precision import MixedPrecisionHook +from 
nvalchemi.training.hooks.update import ( + _fold_training_update_hooks, + _hook_claims_stage, +) +from nvalchemi.training.losses.composition import ( + ComposedLossFunction, + ComposedLossOutput, + _ProductWeight, + as_composed_loss, + compute_supervised_loss, + loss_component_to_spec, + loss_target_keys, +) +from nvalchemi.training.optimizers import ( + OptimizerConfig, + SchedulerMetricAdapter, + _normalize_optimizer_configs, + setup_optimizers, + step_lr_schedulers, + step_metric_schedulers, + step_optimizers, + zero_gradients, +) +from nvalchemi.training.runtime import freeze_unconfigured_models, move_to_devices + +if TYPE_CHECKING: + from nvalchemi.data.batch import Batch + from nvalchemi.training._checkpoint import CheckpointValidator + +__all__ = ["TrainingStrategy", "default_training_fn"] + +_RESTART_COUNTER_FIELDS = ( + "step_count", + "batch_count", + "epoch_count", + "epoch_step_count", +) + + +@dataclasses.dataclass(frozen=True) +class _RuntimeOptimizer: + """Bind an optimizer to its scheduler and metric adapter as one unit. + + Users pass aligned ``optimizer_configs`` and the strategy keeps the + derived optimizer, scheduler, and scheduler-metric adapter together + in a single record so the three can never drift out of positional + correspondence internally. + + Attributes + ---------- + optimizer : torch.optim.Optimizer + The built optimizer. + scheduler : LRScheduler | None + The built LR scheduler, or ``None`` when the config declared no + scheduler. + adapter : SchedulerMetricAdapter + The metric adapter (callable, summary-key string, or ``None``) + used to extract a scalar for a metric-driven scheduler. 
def _loss_weight_to_spec(weight: Any) -> Any:
    """Convert a composed-loss weight into a spec, passing plain values through.

    A ``_ProductWeight`` is converted recursively: its ``left`` and ``right``
    operands are serialized first and handed to ``create_model_spec``. Any
    other object exposing ``model_dump`` is serialized from that dump. All
    remaining values are returned unchanged.
    """
    if isinstance(weight, _ProductWeight):
        # Serialize operands in left-then-right order before building the spec.
        operands = {
            side: _loss_weight_to_spec(getattr(weight, side))
            for side in ("left", "right")
        }
        return create_model_spec(type(weight), **operands)
    if not hasattr(weight, "model_dump"):
        return weight
    return create_model_spec(type(weight), **weight.model_dump())
def default_training_fn(model: BaseModelMixin, batch: Batch) -> dict[str, torch.Tensor]:
    """Run one forward pass and rename each output to ``predicted_<key>``.

    Parameters
    ----------
    model : BaseModelMixin
        Callable model; ``model(batch)`` must return a mapping of outputs.
    batch : Batch
        Input passed straight to ``model``.

    Returns
    -------
    dict[str, torch.Tensor]
        One entry per model output whose value is not ``None``, keyed by
        ``predicted_<key>`` where ``<key>`` is the original output name.
    """
    predictions: dict[str, torch.Tensor] = {}
    outputs: ModelOutputs = model(batch)
    for name, tensor in outputs.items():
        # Only ``None`` marks an absent output; other falsy values are kept.
        if tensor is None:
            continue
        predictions[f"predicted_{name}"] = tensor
    return predictions
Duplicate + hook object instances are rejected, and the list is **not** + expected to be mutated once the ``TrainingStrategy`` context + manager has been entered. + training_fn : Callable[..., Mapping[str, torch.Tensor]] + Explicit forward-pass callable. Single-model strategies call + ``(model, batch)``; named-model strategies call ``(models, batch)``. + loss_fn : ComposedLossFunction + Composed loss whose components drive target collection. Leaf losses are + accepted and normalized to one-component composed losses. + devices : list[torch.device] + One device shared by all models, or one device per model for helper + placement. Named-model ``run`` currently supports one device only. + distributed_manager : DistributedManager | None + Optional external distributed manager. The strategy passes this through + hook contexts for distributed-aware hooks. + step_count : int + Runtime optimizer-step counter, excluded from specs. Batches whose + optimizer step is skipped by update hooks do not advance this counter. + batch_count : int + Runtime batch counter, excluded from specs. This advances for every + completed batch, including batches whose optimizer step is skipped. + epoch_count : int + Runtime epoch counter, excluded from specs. + epoch_step_count : int + Runtime counter for batches consumed within the current epoch, + excluded from specs. + + Notes + ----- + Use :meth:`to_spec_dict` / :meth:`from_spec_dict` for JSON-based save/load. + Optimizer configs, loss specs, devices, importable training functions, and + best-effort model specs are serialized. Runtime ``models`` and + ``training_fn`` overrides passed to :meth:`from_spec_dict` take precedence; + the serialized model call mode is used only when no runtime model override + is supplied. ``hooks``, ``step_count``, ``batch_count``, ``epoch_count``, + and ``epoch_step_count`` remain runtime-only. 
+ + Bare :class:`TrainingUpdateHook` instances are auto-wrapped into a single + :class:`TrainingUpdateOrchestrator` on registration; the orchestrator owns + the ``zero_gradients`` / ``backward`` / ``optimizer.step`` / + ``scheduler.step`` calls that the strategy otherwise issues by default. + Construction-time hook validation errors surface as + :class:`pydantic.ValidationError`; :meth:`register_hook` raises + :class:`ValueError` directly. + """ + + models: dict[str, BaseModelMixin] + optimizer_configs: dict[str, list[OptimizerConfig]] = Field(default_factory=dict) + num_epochs: int | None = Field(default=None, ge=1) + num_steps: int | None = Field(default=None, ge=1) + epoch_step_modifier: float = Field(default=1.0, gt=0, allow_inf_nan=False) + hooks: list[Hook | TrainingUpdateHook | TrainingUpdateOrchestrator] = Field( + default_factory=list, + description=( + "Hooks to run at training stages. Accepts ``Hook`` Protocol " + "instances, bare ``TrainingUpdateHook`` instances (auto-wrapped " + "into a single ``TrainingUpdateOrchestrator``), or an explicit " + "``TrainingUpdateOrchestrator``. Example: " + "``hooks=[CheckpointHook(...), MyClipGradHook()]``." 
+ ), + ) + training_fn: Callable[..., Mapping[str, torch.Tensor]] | None = None + loss_fn: ComposedLossFunction + devices: list[torch.device] = Field(default_factory=lambda: [torch.device("cpu")]) + distributed_manager: Annotated[DistributedManager | None, SkipValidation()] = Field( + default=None, + exclude=True, + ) + step_count: int = Field(default=0, ge=0, exclude=True) + batch_count: int = Field(default=0, ge=0, exclude=True) + epoch_count: int = Field(default=0, ge=0, exclude=True) + epoch_step_count: int = Field(default=0, ge=0, exclude=True) + single_model_input: bool = Field(default=False, exclude=True) + last_validation: dict[str, Any] | None = Field(default=None, exclude=True) + inference_model: nn.Module | nn.ModuleDict | None = Field( + default=None, exclude=True + ) + validation_config: ValidationConfig | None = Field(default=None, exclude=True) + + _context_depth: int = PrivateAttr(default=0) + _ctx: TrainContext | None = PrivateAttr(default=None) + _has_do_backward_claim: bool = PrivateAttr(default=False) + _has_do_optimizer_step_claim: bool = PrivateAttr(default=False) + _has_update_orchestrator: bool = PrivateAttr(default=False) + _resume_optimizer_state: bool = PrivateAttr(default=False) + _runtime_optimizers: list[_RuntimeOptimizer] = PrivateAttr(default_factory=list) + + _active_dataloader: Any = PrivateAttr(default=None) + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + # To minimize overhead, validation is only performed at the + # initial construction + validate_assignment=False, + revalidate_instances="never", + ) + + _stage_type = TrainingStage + + @property + def epoch(self) -> int: + """Backward-compatible alias for :attr:`epoch_count`.""" + return self.epoch_count + + @epoch.setter + def epoch(self, value: int) -> None: + self.epoch_count = value + + @property + def active_dataloader(self) -> Any: + """Return the dataloader currently owned by the training workflow.""" + return self._active_dataloader + + 
@active_dataloader.setter + def active_dataloader(self, dataloader: Any) -> None: + """Set the dataloader currently owned by the training workflow.""" + self._active_dataloader = dataloader + + @model_validator(mode="before") + @classmethod + def _normalize_inputs(cls, data: Any) -> Any: + """Normalize model and optimizer input shapes before field validation.""" + if not isinstance(data, dict): + return data + normalized = dict(data) + raw_models = normalized.get("models") + single_model_input = isinstance(raw_models, BaseModelMixin) + if "models" in normalized: + normalized["models"] = strategy_validation._normalize_models(raw_models) + if "optimizer_configs" in normalized: + normalized["optimizer_configs"] = _normalize_optimizer_configs( + normalized["optimizer_configs"], single_model_input=single_model_input + ) + if "epoch" in normalized and "epoch_count" not in normalized: + normalized["epoch_count"] = normalized.pop("epoch") + normalized["single_model_input"] = single_model_input + return normalized + + @field_validator("loss_fn", mode="before") + @classmethod + def _normalize_loss_fn(cls, value: Any) -> Any: + """Normalize a leaf loss into a one-component composed loss.""" + try: + return as_composed_loss(value) + except TypeError as exc: + raise RuntimeError( + "Only loss functions that inherit `BaseLossFunction` or" + " a composition of loss functions is accepted." + ) from exc + + @field_validator("training_fn", mode="before") + @classmethod + def _resolve_training_fn(cls, value: Any) -> Any: + """Resolve a dotted-path string to a callable, or accept a callable as-is.""" + if isinstance(value, str): + value = strategy_spec._resolve_dotted_callable(value) + if value is None: + raise ValueError(strategy_validation._TRAINING_FN_REQUIRED_MESSAGE) + if not callable(value): + raise ValueError( + f"training_fn must be callable or a dotted path string, got " + f"{type(value).__name__}." 
+ ) + return value + + @field_validator("hooks", mode="before") + @classmethod + def _autowrap_update_hooks(cls, value: Any) -> Any: + """Fold bare ``TrainingUpdateHook`` instances into a single orchestrator.""" + if isinstance(value, (str, bytes)) or not isinstance(value, Sequence): + return value + return _order_update_orchestrator_before_dependent_hooks( + _fold_training_update_hooks(value) + ) + + @model_validator(mode="after") + def _validate_strategy(self) -> TrainingStrategy: + """Enforce model, duration, optimizer, and device consistency.""" + have_epochs = self.num_epochs is not None + have_steps = self.num_steps is not None + if have_epochs == have_steps: + raise ValueError( + "Exactly one of num_epochs or num_steps must be set; " + f"got num_epochs={self.num_epochs!r}, num_steps={self.num_steps!r}." + ) + if not self.models: + raise ValueError("models must contain at least one BaseModelMixin.") + if not self.optimizer_configs: + raise ValueError( + "optimizer_configs must configure at least one model; " + "got an empty mapping." + ) + for idx, cfgs in self.optimizer_configs.items(): + if idx not in self.models: + raise ValueError( + f"optimizer_configs key {idx!r} is not present in models; " + f"available model keys: {sorted(self.models)}." + ) + if not cfgs: + raise ValueError( + f"optimizer_configs[{idx!r}] must contain at least one " + "OptimizerConfig." + ) + if not self.devices: + raise ValueError("devices must contain at least one torch.device.") + n_devices = len(self.devices) + if n_devices not in (1, len(self.models)): + raise ValueError( + f"devices must have length 1 or len(models)={len(self.models)}; " + f"got {n_devices}." 
+ ) + if self.training_fn is None: + raise ValueError(strategy_validation._TRAINING_FN_REQUIRED_MESSAGE) + strategy_validation._validate_training_fn_call_shape( + self.training_fn, single_model_input=self.single_model_input + ) + hook_ids = [id(hook) for hook in self.hooks] + if len(hook_ids) != len(set(hook_ids)): + raise ValueError( + "hooks must not contain duplicate hook instances; pass distinct " + "hook objects instead." + ) + _validate_single_do_claimants(self.hooks) + _validate_hook_dependencies(self.hooks) + return self + + def model_post_init(self, __context: Any) -> None: + """Initialize hook storage, per-run counters, and cached target keys.""" + self._init_hooks(list(self.hooks)) + self._refresh_hook_claim_flags() + self._last_batch: Batch | None = None + self._last_losses: ComposedLossOutput | None = None + self._last_loss: torch.Tensor | None = None + self._optimizers: list[torch.optim.Optimizer] = [] + self._lr_schedulers: list[LRScheduler | None] = [] + self._context_depth = 0 + self._ctx = None + self._target_keys: tuple[str, ...] 
= loss_target_keys(self.loss_fn) + + def _refresh_hook_claim_flags(self) -> None: + """Recompute cached DO-stage claim and orchestrator-presence flags.""" + self._has_do_backward_claim = ( + sum( + 1 + for hook in self.hooks + if _hook_claims_stage(hook, TrainingStage.DO_BACKWARD) + ) + == 1 + ) + self._has_do_optimizer_step_claim = ( + sum( + 1 + for hook in self.hooks + if _hook_claims_stage(hook, TrainingStage.DO_OPTIMIZER_STEP) + ) + == 1 + ) + self._has_update_orchestrator = any( + isinstance(hook, TrainingUpdateOrchestrator) for hook in self.hooks + ) + + def _replace_hooks_with_registry_validation(self, hooks: Sequence[Hook]) -> None: + """Replace hook storage after validating each hook through the base registry.""" + previous_hooks = self.hooks + self.hooks = [] + try: + for hook in hooks: + HookRegistryMixin.register_hook(self, hook) + except Exception: + self.hooks = previous_hooks + raise + + def register_hook( + self, + hook: Hook | TrainingUpdateHook | TrainingUpdateOrchestrator, + stage: TrainingStage | None = None, + ) -> None: + """Register a hook, auto-wrapping bare update hooks when needed.""" + is_update = isinstance(hook, (TrainingUpdateHook, TrainingUpdateOrchestrator)) + if is_update and stage is not None: + raise ValueError( + "stage= is not supported for TrainingUpdateHook or " + "TrainingUpdateOrchestrator registration. Update hooks declare " + "their stages through _runs_on_stage and are auto-wrapped into " + "one TrainingUpdateOrchestrator." 
+ ) + if not is_update: + _validate_single_do_claimants( + self.hooks, extra_hook=hook, extra_stage=stage + ) + previous_hooks = list(self.hooks) + try: + super().register_hook(hook, stage=stage) + _validate_hook_dependencies(self.hooks) + except Exception: + self.hooks = previous_hooks + raise + self._refresh_hook_claim_flags() + return + folded = _order_update_orchestrator_before_dependent_hooks( + _fold_training_update_hooks([*self.hooks, hook]) + ) + _validate_single_do_claimants(folded) + _validate_hook_dependencies(folded) + self._replace_hooks_with_registry_validation(folded) + self._refresh_hook_claim_flags() + + def _build_context(self, batch: Batch | None) -> TrainContext: + """Build a TrainContext, reusing the per-batch cache when populated.""" + if self._ctx is not None: + return self._ctx + global_rank = get_distributed_rank(self.distributed_manager) + return TrainContext( + batch=batch, + model=self.models.get("main"), + global_rank=global_rank, + workflow=self, + step_count=self.step_count, + batch_count=self.batch_count, + epoch_step_count=self.epoch_step_count, + models=self.models, + epoch=self.epoch_count, + loss=self._last_loss, + losses=self._last_losses, + optimizers=self._optimizers, + lr_schedulers=self._lr_schedulers, + validation=self.last_validation, + ) + + def _run_hooks(self, stage: TrainingStage, batch: Batch) -> None: + """Dispatch hooks for ``stage`` with an early-return fast path.""" + if not self.hooks: + return + self._call_hooks(stage, batch) + + def _refresh_hook_counters(self) -> None: + """Mirror current strategy counters into the cached hook context.""" + if self._ctx is None: + return + self._ctx.step_count = self.step_count + self._ctx.batch_count = self.batch_count + self._ctx.epoch_step_count = self.epoch_step_count + self._ctx.epoch = self.epoch_count + self._ctx.validation = self.last_validation + + def __enter__(self) -> TrainingStrategy: + """Enter hook context managers registered on this strategy.""" + if 
self._context_depth > 0: + self._context_depth += 1 + return self + for hook in self.hooks: + if hasattr(hook, "__enter__"): + hook.__enter__() + self._context_depth = 1 + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + """Exit or close hook contexts registered on this strategy.""" + if self._context_depth == 0: + return + self._context_depth -= 1 + if self._context_depth > 0: + return + for hook in reversed(self.hooks): + if hasattr(hook, "__exit__"): + hook.__exit__(exc_type, exc, tb) + elif hasattr(hook, "close"): + hook.close() + + def _prepare_setup_hooks(self) -> None: + """Allow hooks to prepare runtime state before device placement.""" + for hook in self.hooks: + prepare = getattr(hook, "prepare_strategy", None) + if callable(prepare): + prepare(self) + + def _run_setup_hooks(self, dataloader: Any = None) -> Any: + """Run setup-stage hooks and return the active dataloader.""" + if not self.hooks: + return dataloader + self.active_dataloader = dataloader + ctx = self._build_context(None) + for hook in self.hooks: + should_run_setup = _hook_claims_stage( + hook, TrainingStage.SETUP + ) or isinstance(hook, TrainingUpdateOrchestrator) + if not should_run_setup: + continue + if self.step_count % hook.frequency != 0: + continue + hook(ctx, TrainingStage.SETUP) + return self.active_dataloader + + def _validate_runtime_devices(self) -> None: + """Raise for runtime device layouts that cannot be executed.""" + if not self.single_model_input and len(self.devices) > 1: + raise ValueError( + "Named-model training with multiple devices is unsupported: " + "training_fn(models, batch) receives one batch on one device. " + "Use a single shared device or pass models=model for " + "single-model behavior." 
+ ) + + def _setup_runtime_optimizers( + self, *, rebuild: bool = False + ) -> tuple[list[torch.optim.Optimizer], list[LRScheduler | None]]: + """Build or reuse flattened runtime optimizer/scheduler lists.""" + if not rebuild and self._runtime_optimizers: + return self._optimizers, self._lr_schedulers + + records: list[_RuntimeOptimizer] = [] + built = setup_optimizers(self.models, self.optimizer_configs) + # Iterate configs in the same key order and list order as + # setup_optimizers to bind each optimizer to its scheduler and + # metric adapter as a single _RuntimeOptimizer record. Building + # the records here (rather than three parallel lists) is the one + # place positional correspondence is established; the flat lists + # below are derived views that cannot drift from each other. + for key, cfgs in _normalize_optimizer_configs( + self.optimizer_configs, single_model_input=self.single_model_input + ).items(): + pairs = built[key] + for cfg, (opt, sched) in zip(cfgs, pairs, strict=True): + records.append( + _RuntimeOptimizer( + optimizer=opt, + scheduler=sched, + adapter=cfg.scheduler_metric_adapter, + ) + ) + self._runtime_optimizers = records + self._optimizers = [record.optimizer for record in records] + self._lr_schedulers = [record.scheduler for record in records] + return self._optimizers, self._lr_schedulers + + def train_batch(self, batch: Batch) -> None: + """Train on a single batch using the configured training flow. + + This public one-batch API is intended for interactive workflows and + tests where the caller already has a batch in hand. It runs the + per-batch stages from ``BEFORE_BATCH`` through ``AFTER_BATCH``, but it + does not run the outer ``BEFORE_TRAINING``/``AFTER_TRAINING`` or + epoch-level hooks and does not enforce ``num_epochs``/``num_steps``. + It still advances runtime counters: ``batch_count`` and + ``epoch_step_count`` advance for every completed batch, while + ``step_count`` advances only when the optimizer step executes. 
+ + Optimizers and schedulers are built from ``optimizer_configs`` on first + use and then reused by subsequent ``train_batch`` calls. Full + :meth:`run` calls continue to rebuild optimizer state at the start of + the run. + + Parameters + ---------- + batch : Batch + Batch to train on. + """ + strategy_context = nullcontext(self) if self._context_depth > 0 else self + with strategy_context: + self._prepare_setup_hooks() + self._validate_runtime_devices() + self.models = move_to_devices(self.models, self.devices) + self._run_setup_hooks() + flat_opts, flat_scheds = self._setup_runtime_optimizers() + batch = batch.to(self.devices[0], non_blocking=True) + self._update_hook_snapshot(batch=batch, loss_out=None) + + with freeze_unconfigured_models(self.models, self.optimizer_configs): + self._train_batch_with_optimizers(batch, flat_opts, flat_scheds) + + def _train_batch_with_optimizers( + self, + batch: Batch, + flat_opts: list[torch.optim.Optimizer], + flat_scheds: list[LRScheduler | None], + ) -> None: + """Forward-backward-optimize a single batch with hook dispatch.""" + self._optimizers = flat_opts + self._lr_schedulers = flat_scheds + self._ctx = self._build_context(batch) if self.hooks else None + + try: + self._run_hooks(TrainingStage.BEFORE_BATCH, batch) + if not self._has_update_orchestrator: + zero_gradients(flat_opts) + self._run_hooks(TrainingStage.BEFORE_FORWARD, batch) + model_arg = self.models["main"] if self.single_model_input else self.models + predictions = self.training_fn(model_arg, batch) + self._run_hooks(TrainingStage.AFTER_FORWARD, batch) + + self._run_hooks(TrainingStage.BEFORE_LOSS, batch) + loss_out = self._compute_losses( + predictions, + batch, + step=self.step_count, + epoch=self.epoch_count, + ) + self._update_hook_snapshot(loss_out=loss_out) + self._run_hooks(TrainingStage.AFTER_LOSS, batch) + + self._run_hooks(TrainingStage.BEFORE_BACKWARD, batch) + if self._has_do_backward_claim: + self._run_hooks(TrainingStage.DO_BACKWARD, batch) + 
elif self._ctx is not None and self._ctx.loss is not None: + self._ctx.loss.backward() + else: + loss_out["total_loss"].backward() + self._run_backward_completion(batch, loss_out) + optimizer_step_ran = self._run_optimizer_step_phase( + batch, flat_opts, flat_scheds + ) + + self.batch_count += 1 + self.epoch_step_count += 1 + if optimizer_step_ran: + self.step_count += 1 + self._refresh_hook_counters() + self._run_hooks(TrainingStage.AFTER_BATCH, batch) + finally: + self._ctx = None + + def _run_backward_completion( + self, batch: Batch, loss_out: ComposedLossOutput + ) -> None: + """Publish detached losses, then fire the gradient-available stage.""" + if self.hooks: + self._update_hook_snapshot(loss_out=loss_out, detach=True) + self._run_hooks(TrainingStage.AFTER_BACKWARD, batch) + + def _run_optimizer_step_phase( + self, + batch: Batch, + flat_opts: list[torch.optim.Optimizer], + flat_scheds: list[LRScheduler | None], + ) -> bool: + """Run the last pre-step hook, step owner, and step-aware post hook.""" + self._run_hooks(TrainingStage.BEFORE_OPTIMIZER_STEP, batch) + if self._has_do_optimizer_step_claim: + self._run_hooks(TrainingStage.DO_OPTIMIZER_STEP, batch) + optimizer_step_ran = self._optimizer_step_ran_after_do_stage() + else: + step_optimizers(flat_opts) + step_lr_schedulers(flat_scheds) + optimizer_step_ran = True + self._run_hooks(TrainingStage.AFTER_OPTIMIZER_STEP, batch) + return optimizer_step_ran + + def _optimizer_step_ran_after_do_stage(self) -> bool: + """Return whether the DO optimizer-step owner reported an executed step.""" + for hook in self.hooks: + if isinstance(hook, TrainingUpdateOrchestrator): + return not hook.optimizer_step_skipped + return True + + def _compute_losses( + self, + predictions: Mapping[str, torch.Tensor], + batch: Batch, + *, + step: int, + epoch: int, + ) -> ComposedLossOutput: + """Run ``loss_fn`` with graph metadata threaded as keyword kwargs.""" + return compute_supervised_loss( + self.loss_fn, + predictions, + batch, 
+ step=step, + epoch=epoch, + target_keys=self._target_keys, + ) + + def _update_hook_snapshot( + self, + *, + batch: Batch | None = None, + loss_out: ComposedLossOutput | None = None, + detach: bool = False, + ) -> None: + """Single mutation point for hook-visible transient state.""" + if batch is not None: + self._last_batch = batch + if loss_out is None: + self._last_loss = None + self._last_losses = None + elif detach: + self._last_loss = loss_out["total_loss"].detach() + self._last_losses = { + "total_loss": loss_out["total_loss"].detach(), + "per_component_unweighted": { + k: v.detach() + for k, v in loss_out["per_component_unweighted"].items() + }, + "per_component_weight": dict(loss_out["per_component_weight"]), + "per_component_raw_weight": dict(loss_out["per_component_raw_weight"]), + "per_component_sample": { + k: v.detach() for k, v in loss_out["per_component_sample"].items() + }, + } + else: + self._last_loss = loss_out["total_loss"] + self._last_losses = loss_out + if self._ctx is not None: + if batch is not None: + self._ctx.batch = batch + self._ctx.loss = self._last_loss + self._ctx.losses = self._last_losses + self._refresh_hook_counters() + + def _dataloader_length(self, dataloader: Iterable[Batch]) -> int | None: + """Return ``len(dataloader)`` when available without iterating it.""" + try: + return len(dataloader) # type: ignore[arg-type] + except TypeError: + return None + + def _resolve_target_step_count(self, batches_per_epoch: int | None) -> int: + """Resolve ``num_steps``/``num_epochs`` to an absolute step target.""" + if self.num_steps is not None: + return self.num_steps + + if batches_per_epoch is None: + raise ValueError( + "num_epochs requires a sized dataloader so epochs can be " + "converted to a target step count. Use num_steps for unsized " + "iterables." + ) + + if batches_per_epoch <= 0: + raise ValueError( + "dataloader must contain at least one batch when num_epochs " + "is configured." 
+ ) + if self.num_epochs is None: + raise RuntimeError("TrainingStrategy has neither num_epochs nor num_steps.") + return math.ceil(self.num_epochs * batches_per_epoch * self.epoch_step_modifier) + + def _set_sampler_epoch(self, dataloader: Iterable[Batch]) -> None: + """Set distributed/data-parallel sampler epoch when supported.""" + candidates = ( + getattr(dataloader, "sampler", None), + getattr(getattr(dataloader, "batch_sampler", None), "sampler", None), + ) + seen: set[int] = set() + for sampler in candidates: + if sampler is None or id(sampler) in seen: + continue + seen.add(id(sampler)) + set_epoch = getattr(sampler, "set_epoch", None) + if callable(set_epoch): + set_epoch(self.epoch_count) + return + + def _prepare_epoch_step_count(self, batches_per_epoch: int | None) -> None: + """Infer or normalize intra-epoch progress for restartable runs.""" + if batches_per_epoch is None or batches_per_epoch <= 0: + return + if self.epoch_step_count >= batches_per_epoch: + extra_epochs, self.epoch_step_count = divmod( + self.epoch_step_count, batches_per_epoch + ) + self.epoch_count += extra_epochs + + completed_epoch_batches = self.epoch_count * batches_per_epoch + raw_progress = self.batch_count or self.step_count + if self.epoch_step_count: + expected_progress = completed_epoch_batches + self.epoch_step_count + if raw_progress and raw_progress != expected_progress: + raise ValueError( + "restart counters are inconsistent: batch_count or " + "step_count does not match epoch_count * len(dataloader) " + "+ epoch_step_count." + ) + self.batch_count = max(self.batch_count, expected_progress) + return + + if raw_progress < completed_epoch_batches: + raise ValueError( + "restart counters are inconsistent: batch_count or step_count " + "is smaller " + "than epoch_count * len(dataloader)." 
+ ) + elapsed_epoch_steps = raw_progress - completed_epoch_batches + extra_epochs, self.epoch_step_count = divmod( + elapsed_epoch_steps, batches_per_epoch + ) + self.epoch_count += extra_epochs + self.batch_count = max(self.batch_count, raw_progress) + + def run( + self, + dataloader: Iterable[Batch], + ) -> None: + """Execute the training loop over ``dataloader``. + + Parameters + ---------- + dataloader : Iterable[Batch] + Any iterable of batches; need not be a ``DataLoader``. + The configured duration targets effective optimizer/scheduler + steps. Batches whose optimizer step is skipped still advance the + dataloader-position counters. + + Raises + ------ + ValueError + If named-model training is configured with multiple devices, or if + the dataloader produces no batches before the configured target + step count is reached. + """ + training_started = False + strategy_context = nullcontext(self) if self._context_depth > 0 else self + with strategy_context: + # --- Setup phase: prepare hooks, devices, dataloader, targets --- + self._prepare_setup_hooks() + self._validate_runtime_devices() + self.models = move_to_devices(self.models, self.devices) + dataloader = self._run_setup_hooks(dataloader) + batches_per_epoch = self._dataloader_length(dataloader) + target_step_count = self._resolve_target_step_count(batches_per_epoch) + if self.step_count >= target_step_count: + return + self._prepare_epoch_step_count(batches_per_epoch) + + primary_device = self.devices[0] + flat_opts, flat_scheds = self._setup_runtime_optimizers( + rebuild=not self._resume_optimizer_state + ) + + with freeze_unconfigured_models(self.models, self.optimizer_configs): + # --- Epoch loop: recycles the dataloader until target reached --- + for _epoch_idx in itertools.count(): + self._set_sampler_epoch(dataloader) + processed_epoch_batch = False + exhausted_dataloader = True + # --- Batch loop --- + for batch_idx, batch in enumerate(dataloader): + # Skip batches already consumed on a resumed 
epoch. + if batch_idx < self.epoch_step_count: + continue + if self.step_count >= target_step_count: + exhausted_dataloader = False + break + batch = batch.to(primary_device, non_blocking=True) + self._update_hook_snapshot(batch=batch, loss_out=None) + # BEFORE_TRAINING: fires once, on the first batch overall. + if not training_started: + self._run_hooks(TrainingStage.BEFORE_TRAINING, batch) + training_started = True + # BEFORE_EPOCH: fires at the start of each epoch. + if self.epoch_step_count == 0: + self._run_hooks(TrainingStage.BEFORE_EPOCH, batch) + + # Per-batch train: BEFORE_BATCH..AFTER_OPTIMIZER_STEP..AFTER_BATCH. + self._train_batch_with_optimizers(batch, flat_opts, flat_scheds) + # Step-cadence validation checkpoint (every_n_steps); runs + # after the completed step so EMA weights are current. + self._validation_checkpoint(TrainingStage.AFTER_OPTIMIZER_STEP) + processed_epoch_batch = True + # End the epoch once the per-epoch batch budget is hit. + if ( + batches_per_epoch is not None + and self.epoch_step_count >= batches_per_epoch + ): + exhausted_dataloader = True + break + if self.step_count >= target_step_count: + exhausted_dataloader = False + break + + if ( + not processed_epoch_batch + and self.step_count < target_step_count + ): + raise ValueError( + "dataloader produced no batches before reaching " + "the target step count; ensure the dataloader is " + "non-empty, re-iterable, and compatible with the " + "restored epoch_step_count." + ) + + # --- Epoch boundary: advance counters then fire AFTER_EPOCH --- + if exhausted_dataloader: + self.epoch_count += 1 + self.epoch_step_count = 0 + self._refresh_hook_counters() + self._run_hooks(TrainingStage.AFTER_EPOCH, self._last_batch) + # Epoch-cadence validation checkpoint (every_n_epochs). 
+ self._validation_checkpoint(TrainingStage.AFTER_EPOCH) + if self.step_count >= target_step_count: + break + + # --- End of training: AFTER_TRAINING, then a final validation pass --- + if self._last_batch is not None: + self._update_hook_snapshot(loss_out=None) + self._run_hooks(TrainingStage.AFTER_TRAINING, self._last_batch) + # Always validate once at the end when configured (no cadence + # gate); metric-driven LR schedulers then consume the summary. + if self.validation_config is not None: + self.validate() + self._step_metric_schedulers() + + def to_spec_dict(self) -> dict[str, Any]: + """Serialize declarative training knobs to a JSON-ready dict. + + Returns + ------- + dict[str, Any] + JSON-ready bundle suitable for :func:`json.dumps`. + """ + component_specs = [ + loss_component_to_spec(comp) for comp in self.loss_fn.components + ] + loss_fn_spec = create_model_spec( + type(self.loss_fn), + components=component_specs, + weights=[_loss_weight_to_spec(weight) for weight in self.loss_fn._weights], + normalize_weights=self.loss_fn.normalize_weights, + ) + spec = { + "optimizer_configs": { + key: [cfg.to_spec().model_dump() for cfg in cfgs] + for key, cfgs in self.optimizer_configs.items() + }, + "num_epochs": self.num_epochs, + "num_steps": self.num_steps, + "epoch_step_modifier": self.epoch_step_modifier, + "devices": [str(device) for device in self.devices], + "loss_fn_spec": loss_fn_spec.model_dump(), + "model_specs": strategy_spec._model_specs_from_models(self.models), + "single_model_input": self.single_model_input, + } + try: + spec["training_fn"] = strategy_spec._callable_dotted_path(self.training_fn) + except ValueError as exc: + warnings.warn( + f"Omitting non-importable training_fn from spec: {exc}", + UserWarning, + stacklevel=2, + ) + return spec + + def to_checkpoint_dict(self) -> dict[str, Any]: + """Serialize strategy recipe and restart counters for checkpoints. + + Returns + ------- + dict[str, Any] + JSON-ready checkpoint metadata. 
Model weights and optimizer state + remain outside this payload in checkpoint ``state_dict`` files. + """ + runtime_state = {key: getattr(self, key) for key in _RESTART_COUNTER_FIELDS} + return { + **self.to_spec_dict(), + "strategy_cls": f"{type(self).__module__}.{type(self).__qualname__}", + "runtime_state": runtime_state, + } + + def save_checkpoint( + self, + root_folder: Path | str, + *, + checkpoint_index: int = -1, + ) -> int: + """Save this strategy as a restartable checkpoint. + + Parameters + ---------- + root_folder : Path | str + Root directory for checkpoint files. + checkpoint_index : int, optional + Checkpoint index to write. ``-1`` auto-increments from the latest + manifest index, or starts at ``0`` when no manifest exists. + + Returns + ------- + int + The checkpoint index that was written. + """ + from nvalchemi.training._checkpoint import save_checkpoint + + return save_checkpoint( + root_folder, + checkpoint_index=checkpoint_index, + strategy=self, + ) + + def restore_checkpoint( + self, + root_folder: Path | str, + checkpoint_index: int = -1, + map_location: str | torch.device | None = None, + *, + validators: Sequence[CheckpointValidator] | None = None, + ) -> Mapping[str, Any]: + """Restore checkpoint state into this already-constructed strategy. + + Parameters + ---------- + root_folder : Path | str + Root directory containing checkpoint files. + checkpoint_index : int, optional + Checkpoint index to load. ``-1`` loads the latest manifest index. + map_location : str | torch.device | None, optional + Device override passed through to :func:`torch.load`. + validators : Sequence[CheckpointValidator] | None, optional + Optional loaded-checkpoint validators forwarded to the lower-level + loader. + + Returns + ------- + Mapping[str, Any] + Loaded checkpoint payload from :func:`nvalchemi.training.load_checkpoint`. 
+ """ + from nvalchemi.training._checkpoint import load_checkpoint + + loaded = load_checkpoint( + root_folder, + checkpoint_index=checkpoint_index, + map_location=map_location, + validators=validators, + strategy=self, + ) + if not isinstance(loaded, Mapping) or loaded.get("strategy") is not self: + raise ValueError( + "TrainingStrategy.restore_checkpoint could not restore into " + "this strategy." + ) + return loaded + + @classmethod + def load_checkpoint( + cls, + root_folder: Path | str, + checkpoint_index: int = -1, + map_location: str | torch.device | None = None, + *, + hooks: Sequence[Hook | TrainingUpdateHook | TrainingUpdateOrchestrator] + | None = None, + training_fn: Callable[..., Mapping[str, torch.Tensor]] | str | None = None, + validators: Sequence[CheckpointValidator] | None = None, + ) -> TrainingStrategy: + """Load a restartable strategy checkpoint. + + This is the strategy-focused convenience wrapper around + :func:`nvalchemi.training.load_checkpoint`. Use the module-level + function when callers need the full manifest, component dictionaries, + partial component loads, or foreign checkpoint adapters. + + Parameters + ---------- + root_folder : Path | str + Root directory containing checkpoint files. + checkpoint_index : int, optional + Checkpoint index to load. ``-1`` loads the latest manifest index. + map_location : str | torch.device | None, optional + Device override passed through to :func:`torch.load` and the + restored strategy metadata. + hooks : Sequence[Hook | TrainingUpdateHook | TrainingUpdateOrchestrator] | None, optional + Runtime hooks to attach to the restored strategy. + training_fn : Callable[..., Mapping[str, torch.Tensor]] | str | None, optional + Runtime training function override. This is required when the saved + strategy used a local or otherwise non-importable training + function. + validators : Sequence[CheckpointValidator] | None, optional + Optional loaded-checkpoint validators forwarded to the lower-level + loader. 
+ + Returns + ------- + TrainingStrategy + Restored strategy with model, optimizer, scheduler, and runtime + counters loaded. + + Raises + ------ + ValueError + If the checkpoint does not contain restartable strategy metadata. + TypeError + If the restored strategy is not an instance of ``cls``. + """ + from nvalchemi.training._checkpoint import load_checkpoint + + loaded = load_checkpoint( + root_folder, + checkpoint_index=checkpoint_index, + map_location=map_location, + hooks=hooks, + training_fn=training_fn, + validators=validators, + ) + if not isinstance(loaded, Mapping) or loaded.get("strategy") is None: + raise ValueError( + "TrainingStrategy.load_checkpoint requires a checkpoint saved " + "from a TrainingStrategy. Use nvalchemi.training.load_checkpoint " + "for component-only checkpoints." + ) + strategy = loaded["strategy"] + if not isinstance(strategy, cls): + raise TypeError( + f"Loaded strategy has type {type(strategy).__name__}, expected " + f"{cls.__name__}." + ) + return strategy + + @classmethod + def from_spec_dict( + cls, + spec: Mapping[str, Any], + *, + models: strategy_validation.ModelInput | None = None, + hooks: Sequence[Hook | TrainingUpdateHook | TrainingUpdateOrchestrator] + | None = None, + training_fn: Callable[..., Mapping[str, torch.Tensor]] | str | None = None, + ) -> TrainingStrategy: + """Rebuild a :class:`TrainingStrategy` from a :meth:`to_spec_dict` bundle. + + Parameters + ---------- + spec : Mapping[str, Any] + A dict produced by :meth:`to_spec_dict`, optionally after a JSON round-trip. + models : BaseModelMixin | dict[str, BaseModelMixin] | torch.nn.ModuleDict | None, optional + Runtime model override(s). + hooks : Sequence[Hook | TrainingUpdateHook | TrainingUpdateOrchestrator] | None, optional + Runtime hooks; defaults to an empty list. Bare update hooks are + auto-wrapped into a single orchestrator. + training_fn : Callable[..., Mapping[str, torch.Tensor]] | str | None, optional + Runtime callable or dotted-path override. 
+ + Returns + ------- + TrainingStrategy + A freshly validated strategy ready to :meth:`run`. + """ + required = ("optimizer_configs", "devices", "loss_fn_spec") + missing = [k for k in required if k not in spec] + if missing: + raise ValueError( + f"from_spec_dict: spec is missing required key(s) {missing}. " + f"Expected keys: {list(required)}." + ) + model_input = strategy_spec._models_from_spec_and_overrides( + spec.get("model_specs", {}), + models, + single_model_input=strategy_spec._single_model_input_from_spec( + spec.get("single_model_input") + ), + ) + return cls( + models=model_input, + optimizer_configs=strategy_spec._optimizer_configs_from_spec( + spec["optimizer_configs"] + ), + num_epochs=spec.get("num_epochs"), + num_steps=spec.get("num_steps"), + epoch_step_modifier=spec.get("epoch_step_modifier", 1.0), + hooks=list(hooks) if hooks is not None else [], + training_fn=strategy_spec._training_fn_from_spec(spec, training_fn), + loss_fn=strategy_spec._loss_fn_from_spec(spec["loss_fn_spec"]), + devices=strategy_spec._devices_from_spec(spec["devices"]), + ) + + @classmethod + def from_checkpoint_dict( + cls, + spec: Mapping[str, Any], + *, + models: strategy_validation.ModelInput | None = None, + hooks: Sequence[Hook | TrainingUpdateHook | TrainingUpdateOrchestrator] + | None = None, + training_fn: Callable[..., Mapping[str, torch.Tensor]] | str | None = None, + ) -> TrainingStrategy: + """Rebuild a strategy from checkpoint metadata. + + Parameters + ---------- + spec : Mapping[str, Any] + A dict produced by :meth:`to_checkpoint_dict`. + models : BaseModelMixin | dict[str, BaseModelMixin] | torch.nn.ModuleDict | None, optional + Runtime model override(s), normally the models loaded from the + checkpoint weight files. + hooks : Sequence[Hook | TrainingUpdateHook | TrainingUpdateOrchestrator] | None, optional + Runtime hooks appended by the caller. 
+ training_fn : Callable[..., Mapping[str, torch.Tensor]] | str | None, optional + Runtime callable or dotted-path override. + + Returns + ------- + TrainingStrategy + A strategy with declarative fields and restart counters restored. + """ + strategy_cls = cls + raw_strategy_cls = spec.get("strategy_cls") + if raw_strategy_cls is not None: + if not isinstance(raw_strategy_cls, str): + raise ValueError( + "from_checkpoint_dict: 'strategy_cls' must be a dotted " + f"class path string; got {type(raw_strategy_cls).__name__}." + ) + imported = _import_cls(raw_strategy_cls) + if not issubclass(imported, cls): + raise ValueError( + f"from_checkpoint_dict: {raw_strategy_cls!r} must resolve " + f"to a {cls.__name__} subclass." + ) + strategy_cls = imported + + strategy = strategy_cls.from_spec_dict( + spec, + models=models, + hooks=hooks, + training_fn=training_fn, + ) + runtime_state = spec.get("runtime_state", {}) + if runtime_state is None: + runtime_state = {} + if not isinstance(runtime_state, Mapping): + raise ValueError( + "from_checkpoint_dict: 'runtime_state' must be a mapping when " + f"present; got {type(runtime_state).__name__}." + ) + for key in _RESTART_COUNTER_FIELDS: + if key in runtime_state: + value = int(runtime_state[key]) + if value < 0: + raise ValueError( + "from_checkpoint_dict: runtime counter " + f"{key!r} must be non-negative; got {value}." + ) + setattr(strategy, key, value) + return strategy + + def _inference_autocast( + self, device: torch.device + ) -> tuple[Callable[[], AbstractContextManager[None]], str]: + """Return validation autocast context factory and precision label. + + Scans registered hooks for a :class:`MixedPrecisionHook` and + returns an autocast context factory and a precision label string. + + Parameters + ---------- + device : torch.device + Primary workflow device for the validation pass. + + Returns + ------- + tuple[Callable[[], AbstractContextManager[None]], str] + A ``(context_factory, precision_label)`` pair. 
The factory is + called once per validation pass to enter/exit the autocast + region. + + Raises + ------ + RuntimeError + When ``use_mixed_precision='always'`` but no + :class:`MixedPrecisionHook` is registered. + """ + use_mixed_precision = ( + self.validation_config.use_mixed_precision + if self.validation_config is not None + else "auto" + ) + if use_mixed_precision == "never": + return nullcontext, "float32" + for hook in _iter_registered_hooks(self.hooks): + if isinstance(hook, MixedPrecisionHook): + precision = str(hook.precision).removeprefix("torch.") + return lambda: hook.inference_autocast(device), precision + if use_mixed_precision == "always": + raise RuntimeError( + "ValidationConfig use_mixed_precision='always' requires a " + "registered MixedPrecisionHook." + ) + return nullcontext, "float32" + + # ------------------------------------------------------------------ + # Inference-model write interface (Phase C) + # ------------------------------------------------------------------ + + def set_inference_model( + self, module: nn.Module, *, model_key: str | None = None + ) -> None: + """Publish a module into the strategy's inference-model slot. + + EMA hooks (and future SWA/distillation hooks) call this after + updating their averaged weights so that + :meth:`validate` reads current inference weights. + + Parameters + ---------- + module : nn.Module + The averaged / inference-ready module to publish. + model_key : str | None + Identifies the target model in named-model strategies. + Ignored for single-model strategies, which always store + a bare :class:`nn.Module`. + + Notes + ----- + The published module is moved to the strategy's primary device before + it is stored so validation can safely pair it with batches moved to + the same device. + For single-model strategies (``single_model_input=True``), + ``model_key`` is ignored and the slot stores a bare + :class:`nn.Module`. 
For named-model strategies with a + ``model_key``, the slot is promoted to an + :class:`nn.ModuleDict` so that multiple hooks can each + write their own key. + """ + module.to(self.devices[0], non_blocking=True) + if model_key is None or self.single_model_input: + self.inference_model = module + return + if not isinstance(self.inference_model, nn.ModuleDict): + self.inference_model = nn.ModuleDict() + self.inference_model[model_key] = module + + # ------------------------------------------------------------------ + # Validation schedule predicates (Phase C) + # ------------------------------------------------------------------ + + def _should_validate(self, stage: TrainingStage) -> bool: + """Return whether a schedule-triggered validation should fire now. + + Parameters + ---------- + stage : TrainingStage + The lifecycle stage being evaluated. + + Returns + ------- + bool + ``True`` when the current counters match the configured + ``every_n_steps`` or ``every_n_epochs`` cadence. + """ + if self.validation_config is None: + return False + cfg = self.validation_config + if cfg.every_n_steps is not None: + # Vetoed optimizer steps (accumulation, spike skipping) leave + # step_count parked on a multiple; fire only when the step ran. + return ( + stage is TrainingStage.AFTER_OPTIMIZER_STEP + and self.step_count > 0 + and self.step_count % cfg.every_n_steps == 0 + and self._optimizer_step_ran_after_do_stage() + ) + if cfg.every_n_epochs is not None: + return ( + stage is TrainingStage.AFTER_EPOCH + and self.epoch_count % cfg.every_n_epochs == 0 + ) + return False + + def _validation_checkpoint(self, stage: TrainingStage) -> bool: + """Run validation if scheduled and return whether it fired. + + Centralizes the validation-trigger logic for both step and + epoch cadences. After a successful validation pass, any + metric-driven LR schedulers are stepped with the fresh + validation summary and the gate is consumed. 
+ + Parameters + ---------- + stage : TrainingStage + The lifecycle stage that triggered this checkpoint. + + Returns + ------- + bool + ``True`` if a validation pass ran at this checkpoint, + ``False`` otherwise. + """ + if self.validation_config is None: + return False + if not self._should_validate(stage): + return False + self.validate() + self._step_metric_schedulers() + return True + + def _step_metric_schedulers(self) -> None: + """Step metric-driven schedulers with the last validation summary. + + Consumes :attr:`last_validation` after stepping so that + subsequent non-validation iterations do not re-step the + metric-driven schedulers. This implements the + ``last_validation`` gate/consume pattern: the field is set by + :meth:`validate` and cleared here after metric schedulers + have consumed the summary. The gate is only consumed when at + least one metric-driven scheduler is present; time-based-only + workflows preserve the summary for downstream consumers. + """ + if self.last_validation is None: + return + from nvalchemi.training.optimizers import _is_metric_driven + + has_metric = any( + _is_metric_driven(record.scheduler) for record in self._runtime_optimizers + ) + if not has_metric: + return + step_metric_schedulers( + [record.scheduler for record in self._runtime_optimizers], + [record.adapter for record in self._runtime_optimizers], + self.last_validation, + ) + self.last_validation = None + + # ------------------------------------------------------------------ + # Validation execution (Phase B) + # ------------------------------------------------------------------ + + def validate(self) -> dict[str, Any] | None: + """Run a validation pass using the strategy's :attr:`validation_config`. + + Delegates to :class:`~nvalchemi.training._validation.ValidationLoop` + to evaluate the model on the configured validation data and loss + function. 
Uses the strategy's own counters (``step_count``, + ``epoch_count``) for loss-schedule evaluation and sink metadata. + + Returns + ------- + dict[str, Any] | None + The validation summary dictionary on rank 0, or ``None`` on + non-publishing ranks. The summary is also stored on + :attr:`last_validation`. + + Raises + ------ + RuntimeError + When ``validation_config`` is ``None`` or when required hooks + (e.g. :class:`MixedPrecisionHook`) are missing. + """ + if self.validation_config is None: + raise RuntimeError( + "TrainingStrategy.validate() requires a validation_config." + ) + with _validation.ValidationLoop.from_training_strategy(self) as loop: + self.last_validation = loop.execute() + # Fire AFTER_VALIDATION while the summary is still live, before any + # metric-driven LR schedulers consume (and clear) last_validation. + if self._last_batch is not None: + self._refresh_hook_counters() + self._run_hooks(TrainingStage.AFTER_VALIDATION, self._last_batch) + return self.last_validation diff --git a/pyproject.toml b/pyproject.toml index 415f20c9..0c110dce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,8 @@ dependencies = [ "zarr>=3", "periodictable==2.0.2", "rich>=13.0.0", - "nvidia-physicsnemo>=2.0.0", + "plotext", + "nvidia-physicsnemo>=2.1.0", ] keywords = [ "machine learning", @@ -65,16 +66,18 @@ ase = [ ] cu12 = [ "nvalchemi-toolkit-ops[torch-cu12]>=0.3.1; sys_platform != 'darwin'", - "nvidia-physicsnemo[cu12]>=2.0.0; sys_platform != 'darwin'", + "nvidia-physicsnemo[cu12]>=2.1.0; sys_platform != 'darwin'", "cuml-cu12>=25.6.0; sys_platform != 'darwin'", "torch; sys_platform != 'darwin'", + "torchvision; sys_platform != 'darwin'", "cuequivariance-ops-torch-cu12>=0.8.0; sys_platform != 'darwin'", ] cu13 = [ "nvalchemi-toolkit-ops[torch-cu13]>=0.3.1; sys_platform != 'darwin'", - "nvidia-physicsnemo[cu13]>=2.0.0; sys_platform != 'darwin'", + "nvidia-physicsnemo[cu13]>=2.1.0; sys_platform != 'darwin'", "cuml-cu13>=25.6.0; sys_platform != 'darwin'", 
"torch; sys_platform != 'darwin'", + "torchvision; sys_platform != 'darwin'", "cuequivariance-ops-torch-cu13>=0.8.0; sys_platform != 'darwin'", ] pymatgen = [ @@ -84,6 +87,9 @@ mace = [ "cuequivariance-torch>=0.8.0", "mace-torch==0.3.15", ] +tensorboard = [ + "tensorboard", +] [tool.hatch.build.targets.wheel] packages = ["nvalchemi"] @@ -119,6 +125,10 @@ torch = [ { index = "pytorch-cu126", extra = "cu12", marker = "sys_platform != 'darwin'" }, { index = "pytorch-cu130", extra = "cu13", marker = "sys_platform != 'darwin'" }, ] +torchvision = [ + { index = "pytorch-cu126", extra = "cu12", marker = "sys_platform != 'darwin'" }, + { index = "pytorch-cu130", extra = "cu13", marker = "sys_platform != 'darwin'" }, +] # these are intended to be developer facing [dependency-groups] diff --git a/test/conftest.py b/test/conftest.py index 89489087..d3647882 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -32,3 +32,9 @@ def gpu_device(request) -> str: if not torch.cuda.is_available(): pytest.skip("No CUDA device available for GPU test.") return request.param + + +@pytest.fixture +def fixed_torch_seed() -> None: + """Set a fixed PyTorch RNG seed for tests that compare random tensors.""" + torch.manual_seed(0) diff --git a/test/data/test_batch.py b/test/data/test_batch.py index 8b75ed04..34b62e19 100644 --- a/test/data/test_batch.py +++ b/test/data/test_batch.py @@ -1167,3 +1167,153 @@ def capture_td_irecv( assert send_tags_after_meta == recv_tags, ( f"Tag mismatch: send (after meta)={send_tags_after_meta}, recv={recv_tags}" ) + + +class TestBatchFromRawDicts: + """Tests for Batch.from_raw_dicts (validation-free batch construction).""" + + def test_matches_from_data_list(self): + """from_raw_dicts produces identical tensors to from_data_list.""" + data_list = [_atomic_data_with_edges_and_system(3, 4) for _ in range(5)] + ref = Batch.from_data_list(data_list, skip_validation=True) + + raw_dicts = [ + { + "positions": d.positions, + "atomic_numbers": d.atomic_numbers, + 
"neighbor_list": d.neighbor_list, + "energy": d.energy, + } + for d in data_list + ] + result = Batch.from_raw_dicts(raw_dicts) + + assert result.num_graphs == ref.num_graphs + assert result.num_nodes == ref.num_nodes + assert result.num_edges == ref.num_edges + torch.testing.assert_close(result.positions, ref.positions) + torch.testing.assert_close(result.atomic_numbers, ref.atomic_numbers) + torch.testing.assert_close(result.neighbor_list, ref.neighbor_list) + torch.testing.assert_close(result.energy, ref.energy) + + def test_empty_raises(self): + with pytest.raises(ValueError, match="empty data list"): + Batch.from_raw_dicts([]) + + def test_node_offset_applied_to_neighbor_list(self): + """neighbor_list indices are offset by cumulative node count.""" + d0 = { + "atomic_numbers": torch.tensor([1, 2]), + "neighbor_list": torch.tensor([[0, 1]]), + } + d1 = { + "atomic_numbers": torch.tensor([3]), + "neighbor_list": torch.tensor([[0, 0]]), + } + batch = Batch.from_raw_dicts([d0, d1]) + # d1's neighbor_list should be offset by 2 (num_nodes in d0) + assert batch.neighbor_list[-1, 0].item() == 2 + assert batch.neighbor_list[-1, 1].item() == 2 + + def test_keys_tracking(self): + """Batch.keys correctly reports node/edge/system sets.""" + raw = [ + { + "positions": torch.randn(2, 3), + "atomic_numbers": torch.tensor([1, 2]), + "energy": torch.tensor([[0.5]]), + "neighbor_list": torch.zeros(1, 2, dtype=torch.long), + } + ] + batch = Batch.from_raw_dicts(raw) + assert "positions" in batch.keys["node"] + assert "neighbor_list" in batch.keys["edge"] + assert "energy" in batch.keys["system"] + + def test_segment_lengths(self): + """Per-graph node/edge counts are correct.""" + d0 = { + "atomic_numbers": torch.tensor([1, 2, 3]), + "positions": torch.randn(3, 3), + "neighbor_list": torch.zeros(2, 2, dtype=torch.long), + } + d1 = { + "atomic_numbers": torch.tensor([4]), + "positions": torch.randn(1, 3), + "neighbor_list": torch.zeros(5, 2, dtype=torch.long), + } + batch = 
Batch.from_raw_dicts([d0, d1]) + assert batch.num_nodes_list == [3, 1] + assert batch.num_edges_list == [2, 5] + + def test_custom_key_preserved_as_system(self) -> None: + """Keys not in _default_*_keys are preserved as system-level.""" + d0 = { + "atomic_numbers": torch.tensor([1, 2]), + "positions": torch.randn(2, 3), + "my_custom_scalar": torch.tensor([42.0]), + } + d1 = { + "atomic_numbers": torch.tensor([3]), + "positions": torch.randn(1, 3), + "my_custom_scalar": torch.tensor([99.0]), + } + batch = Batch.from_raw_dicts([d0, d1]) + assert "my_custom_scalar" in batch.keys["system"] + assert batch.my_custom_scalar.shape == (2,) + assert batch.my_custom_scalar[0].item() == 42.0 + assert batch.my_custom_scalar[1].item() == 99.0 + + def test_field_levels_classifies_custom_atom_key(self) -> None: + """field_levels routes custom per-atom tensors to atom level.""" + d0 = { + "atomic_numbers": torch.tensor([1, 2, 3]), + "positions": torch.randn(3, 3), + "partial_charges": torch.tensor([0.1, 0.2, 0.3]), + } + d1 = { + "atomic_numbers": torch.tensor([4, 5]), + "positions": torch.randn(2, 3), + "partial_charges": torch.tensor([0.4, 0.5]), + } + batch = Batch.from_raw_dicts([d0, d1], field_levels={"partial_charges": "atom"}) + assert "partial_charges" in batch.keys["node"] + assert batch.partial_charges.shape == (5,) + torch.testing.assert_close( + batch.partial_charges, + torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]), + ) + + def test_field_levels_classifies_custom_edge_key(self) -> None: + """field_levels routes custom per-edge tensors to edge level.""" + d0 = { + "atomic_numbers": torch.tensor([1, 2]), + "positions": torch.randn(2, 3), + "neighbor_list": torch.tensor([[0, 1], [1, 0]]), + "edge_weights": torch.tensor([1.0, 2.0]), + } + d1 = { + "atomic_numbers": torch.tensor([3]), + "positions": torch.randn(1, 3), + "neighbor_list": torch.tensor([[0, 0]]), + "edge_weights": torch.tensor([3.0]), + } + batch = Batch.from_raw_dicts([d0, d1], field_levels={"edge_weights": "edge"}) 
+ assert "edge_weights" in batch.keys["edge"] + assert batch.edge_weights.shape == (3,) + + def test_field_levels_fallback_still_system(self) -> None: + """Keys absent from both default sets and field_levels fall back to system.""" + d0 = { + "atomic_numbers": torch.tensor([1, 2]), + "positions": torch.randn(2, 3), + "unknown_scalar": torch.tensor([1.0]), + } + d1 = { + "atomic_numbers": torch.tensor([3]), + "positions": torch.randn(1, 3), + "unknown_scalar": torch.tensor([2.0]), + } + # field_levels is provided but doesn't mention unknown_scalar + batch = Batch.from_raw_dicts([d0, d1], field_levels={"some_other_key": "atom"}) + assert "unknown_scalar" in batch.keys["system"] diff --git a/test/data/test_io_test.py b/test/data/test_io_test.py new file mode 100644 index 00000000..aae33b24 --- /dev/null +++ b/test/data/test_io_test.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for the nvalchemi I/O benchmark CLI helpers.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from nvalchemi.data.io_test import ( + _build_read_indices, + _expand_read_modes, + _make_atomic_data, + _run_benchmark, + _run_read_benchmark, +) + + +def test_expand_read_modes_defaults_to_batch() -> None: + """No explicit read mode uses the batch readback fast path.""" + assert _expand_read_modes(()) == ("batch",) + + +def test_expand_read_modes_supports_both() -> None: + """The convenience mode expands into batch and single readback paths.""" + assert _expand_read_modes(("both",)) == ("batch", "single") + + +def test_build_read_indices_supports_sequential_order() -> None: + """Sequential read order preserves logical storage order.""" + assert _build_read_indices(5, "sequential", seed=123, read_order_block_size=2) == [ + 0, + 1, + 2, + 3, + 4, + ] + + +def test_build_read_indices_supports_full_shuffle() -> None: + """Shuffle read order randomizes individual sample indices.""" + indices = _build_read_indices(8, "shuffle", seed=123, read_order_block_size=4) + + assert sorted(indices) == list(range(8)) + assert indices != list(range(8)) + + +def test_build_read_indices_supports_block_shuffle() -> None: + """Block shuffle preserves locality inside shuffled contiguous blocks.""" + indices = _build_read_indices(8, "block-shuffle", seed=123, read_order_block_size=2) + blocks = [indices[start : start + 2] for start in range(0, len(indices), 2)] + + assert sorted(indices) == list(range(8)) + assert indices != list(range(8)) + assert all(block[1] == block[0] + 1 for block in blocks) + + +def test_make_atomic_data_generates_edge_rows() -> None: + """Generated edge tensors use edge-major row layout.""" + data = _make_atomic_data(num_atoms=4, num_edges=7) + + assert data.neighbor_list.shape == (7, 2) + assert data.shifts.shape == (7, 3) + + +def test_run_benchmark_profiles_readback(tmp_path: Path) -> None: + """Benchmark results 
include a timed full-store readback.""" + results = _run_benchmark( + num_systems_list=[2], + min_atoms=3, + max_atoms=4, + seed=42, + config=None, + store_dir=tmp_path, + ) + + result = results[0] + assert result["read_mode"] == "batch" + assert result["read_order"] == "sequential" + assert result["batch_size"] == 64 + assert result["prefetch_factor"] == 16 + assert result["effective_read_window"] == 1024 + assert result["read_bytes"] >= result["raw_bytes"] + assert result["read_time"] >= 0 + assert result["profile_time"] == pytest.approx( + result["write_time"] + result["read_time"] + ) + assert result["read_throughput"] >= 0 + assert result["profile_throughput"] >= 0 + + +def test_run_benchmark_can_compare_batch_and_single_readback(tmp_path: Path) -> None: + """Benchmark can report batch and single-sample readback rows.""" + results = _run_benchmark( + num_systems_list=[2], + min_atoms=3, + max_atoms=4, + seed=42, + config=None, + store_dir=tmp_path, + read_modes=("batch", "single"), + batch_size=2, + prefetch_factor=3, + read_order="shuffle", + read_seed=123, + ) + + assert [result["read_mode"] for result in results] == ["batch", "single"] + assert [result["read_order"] for result in results] == ["shuffle", "shuffle"] + assert [result["batch_size"] for result in results] == [2, 1] + assert [result["prefetch_factor"] for result in results] == [3, 0] + assert [result["effective_read_window"] for result in results] == [6, 1] + assert {result["num_systems"] for result in results} == {2} + + +def test_run_benchmark_records_block_shuffle_settings(tmp_path: Path) -> None: + """Benchmark rows record block-shuffle readback settings.""" + results = _run_benchmark( + num_systems_list=[4], + min_atoms=3, + max_atoms=4, + seed=42, + config=None, + store_dir=tmp_path, + read_order="block-shuffle", + read_seed=123, + read_order_block_size=2, + batch_size=2, + prefetch_factor=2, + ) + + result = results[0] + assert result["read_order"] == "block-shuffle" + assert 
result["read_order_block_size"] == 2 + + +@pytest.fixture() +def small_zarr_store(tmp_path: Path) -> Path: + """Write a 4-system Zarr store for read-only benchmarking.""" + from nvalchemi.data.datapipes.backends.zarr import AtomicDataZarrWriter + + store_path = tmp_path / "small.zarr" + data_list = [_make_atomic_data(num_atoms=5, num_edges=8) for _ in range(4)] + writer = AtomicDataZarrWriter(store_path) + writer.write(data_list) + return store_path + + +def test_run_read_benchmark_reads_existing_store(small_zarr_store: Path) -> None: + """Read benchmark discovers sample count and reports read throughput.""" + results = _run_read_benchmark(store_path=small_zarr_store) + + assert len(results) == 1 + result = results[0] + assert result["num_systems"] == 4 + assert result["read_mode"] == "batch" + assert result["read_order"] == "sequential" + assert result["read_time"] >= 0 + assert result["read_bytes"] > 0 + assert result["read_throughput"] >= 0 + assert result["store_path"] == str(small_zarr_store) + + +def test_run_read_benchmark_supports_shuffle(small_zarr_store: Path) -> None: + """Read benchmark works with shuffled access order.""" + results = _run_read_benchmark( + store_path=small_zarr_store, + read_order="shuffle", + read_seed=42, + ) + + result = results[0] + assert result["read_order"] == "shuffle" + assert result["read_order_block_size"] is None + assert result["read_bytes"] > 0 + + +def test_run_read_benchmark_compares_batch_and_single(small_zarr_store: Path) -> None: + """Read benchmark can report both batch and single readback modes.""" + results = _run_read_benchmark( + store_path=small_zarr_store, + read_modes=("batch", "single"), + batch_size=2, + prefetch_factor=3, + ) + + assert [r["read_mode"] for r in results] == ["batch", "single"] + assert [r["batch_size"] for r in results] == [2, 1] + assert [r["prefetch_factor"] for r in results] == [3, 0] + assert [r["effective_read_window"] for r in results] == [6, 1] + assert all(r["num_systems"] == 4 for 
r in results) + assert all(r["read_bytes"] > 0 for r in results) diff --git a/test/data/test_multidataset_samplers.py b/test/data/test_multidataset_samplers.py new file mode 100644 index 00000000..d4d7bd2d --- /dev/null +++ b/test/data/test_multidataset_samplers.py @@ -0,0 +1,346 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for multidataset samplers.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import pytest +import torch +from torch.utils.data import DistributedSampler + +from nvalchemi.data.atomic_data import AtomicData +from nvalchemi.data.datapipes import ( + DataLoader, + Dataset, + DistributedSamplerProtocol, + MultiDataset, + MultiDatasetBatchSampler, + MultiDatasetSampler, +) + + +def _make_ordered_atomic_data(label: int) -> AtomicData: + """Create one-atom AtomicData with an order-identifying atomic number.""" + return AtomicData( + atomic_numbers=torch.tensor([label], dtype=torch.long), + positions=torch.tensor([[float(label), 0.0, 0.0]]), + cell=torch.eye(3).unsqueeze(0), + pbc=torch.tensor([[True, True, True]]), + ) + + +class _OrderedReadManyReader: + """Minimal reader that records read_many calls for DataLoader tests.""" + + def __init__(self, n: int = 5) -> None: + self._n = n + self.pin_memory = False + + def _load_sample(self, index: int) -> dict[str, 
torch.Tensor]: + return _make_ordered_atomic_data(index + 1).to_dict() + + @property + def field_names(self) -> list[str]: + return list(self._load_sample(0)) if self._n > 0 else [] + + def read_many( + self, indices: Sequence[int] + ) -> list[tuple[dict[str, torch.Tensor], dict[str, int]]]: + return [(self._load_sample(index), {"src_index": index}) for index in indices] + + def __len__(self) -> int: + return self._n + + def close(self) -> None: + """Release reader resources.""" + + +class _FakeDistributedManager: + """Structural distributed manager for sampler tests.""" + + def __init__(self, *, world_size: int, rank: int) -> None: + self.world_size = world_size + self.rank = rank + self.initialized = True + + def is_initialized(self) -> bool: + """Return whether the manager is initialized.""" + return self.initialized + + +def test_torch_distributed_sampler_satisfies_protocol() -> None: + """Verify native PyTorch distributed samplers satisfy the shared protocol.""" + sampler = DistributedSampler(range(4), num_replicas=2, rank=0) + + assert isinstance(sampler, DistributedSamplerProtocol) + + +def test_multidataset_sampler_shards_across_distributed_ranks() -> None: + """Verify regular multi-dataset sampling emits a rank-local shard.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=3), device="cpu"), + Dataset(_OrderedReadManyReader(n=3), device="cpu"), + ) + rank0 = MultiDatasetSampler( + dataset, + num_replicas=2, + rank=0, + replacement=False, + shuffle=False, + ) + rank1 = MultiDatasetSampler( + dataset, + num_replicas=2, + rank=1, + replacement=False, + shuffle=False, + ) + + assert isinstance(rank0, DistributedSamplerProtocol) + assert len(rank0) == 3 + assert len(rank1) == 3 + assert list(rank0) == [0, 2, 4] + assert list(rank1) == [1, 3, 5] + + +def test_multidataset_sampler_infers_rank_from_distributed_manager() -> None: + """Verify distributed manager metadata configures sampler sharding.""" + dataset = MultiDataset( + 
Dataset(_OrderedReadManyReader(n=3), device="cpu"), + Dataset(_OrderedReadManyReader(n=3), device="cpu"), + ) + manager = _FakeDistributedManager(world_size=2, rank=1) + sampler = MultiDatasetSampler( + dataset, + distributed_manager=manager, + replacement=False, + shuffle=False, + ) + + assert sampler.num_replicas == 2 + assert sampler.rank == 1 + assert list(sampler) == [1, 3, 5] + + +def test_multidataset_sampler_set_epoch_changes_owned_shuffle() -> None: + """Verify set_epoch changes deterministic shuffling when no generator is passed.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=8), device="cpu"), + Dataset(_OrderedReadManyReader(n=8), device="cpu"), + ) + sampler = MultiDatasetSampler( + dataset, + num_samples=12, + replacement=False, + shuffle=True, + seed=17, + ) + + epoch0 = list(sampler) + assert epoch0 == list(sampler) + + sampler.set_epoch(1) + + assert list(sampler) != epoch0 + + +def test_multidataset_batch_sampler_shards_batches_across_distributed_ranks() -> None: + """Verify multi-dataset batch sampling shards whole batches by rank.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=6), device="cpu"), + Dataset(_OrderedReadManyReader(n=6), device="cpu"), + ) + rank0 = MultiDatasetBatchSampler( + dataset, + batch_size=4, + samples_per_dataset=[2, 2], + num_batches=3, + num_replicas=2, + rank=0, + replacement=False, + shuffle=False, + ) + rank1 = MultiDatasetBatchSampler( + dataset, + batch_size=4, + samples_per_dataset=[2, 2], + num_batches=3, + num_replicas=2, + rank=1, + replacement=False, + shuffle=False, + ) + + assert isinstance(rank0, DistributedSamplerProtocol) + assert len(rank0) == 2 + assert len(rank1) == 2 + assert list(rank0) == [[0, 1, 6, 7], [4, 5, 10, 11]] + assert list(rank1) == [[2, 3, 8, 9], [0, 1, 6, 7]] + + +def test_multidataset_sampler_uses_custom_rates_without_replacement() -> None: + """Verify regular MultiDataset sampling emits global indices at given rates.""" + dataset = MultiDataset( + 
Dataset(_OrderedReadManyReader(n=3), device="cpu"), + Dataset(_OrderedReadManyReader(n=8), device="cpu"), + ) + sampler = MultiDatasetSampler( + dataset, + weights=[1.0, 3.0], + num_samples=8, + replacement=False, + shuffle=False, + ) + + indices = list(sampler) + + assert indices == [0, 1, 3, 4, 5, 6, 7, 8] + assert [dataset.to_local_index(index)[0] for index in indices] == [ + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + ] + + +def test_balanced_multidataset_batch_sampler_forms_balanced_batches() -> None: + """Verify balanced batches include equal samples from each child dataset.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=4), device="cpu"), + Dataset(_OrderedReadManyReader(n=6), device="cpu"), + ) + sampler = MultiDatasetBatchSampler.balanced( + dataset, + batch_size=4, + num_batches=2, + replacement=False, + shuffle=False, + ) + + assert list(sampler) == [[0, 1, 4, 5], [2, 3, 6, 7]] + + loader = DataLoader( + dataset, + batch_sampler=sampler, + prefetch_factor=0, + use_streams=False, + ) + batches = list(loader) + + assert [batch.atomic_numbers.tolist() for batch in batches] == [ + [1, 2, 1, 2], + [3, 4, 3, 4], + ] + + +def test_weighted_multidataset_batch_sampler_uses_dataset_rates() -> None: + """Verify weighted batch sampling allocates batch slots by dataset rate.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=8), device="cpu"), + Dataset(_OrderedReadManyReader(n=8), device="cpu"), + ) + sampler = MultiDatasetBatchSampler( + dataset, + batch_size=5, + weights=[4.0, 1.0], + num_batches=2, + replacement=False, + shuffle=False, + ) + + assert sampler.samples_per_dataset == [4, 1] + assert list(sampler) == [[0, 1, 2, 3, 8], [4, 5, 6, 7, 9]] + + +def test_samples_per_dataset_floats_are_relative_rates() -> None: + """Verify float samples_per_dataset entries allocate by relative ratio.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=8), device="cpu"), + Dataset(_OrderedReadManyReader(n=8), device="cpu"), + ) + sampler 
= MultiDatasetBatchSampler( + dataset, + batch_size=8, + samples_per_dataset=[1.0, 3.0], + num_batches=1, + replacement=False, + shuffle=False, + ) + + assert sampler.samples_per_dataset == [2, 6] + assert list(sampler) == [[0, 1, 8, 9, 10, 11, 12, 13]] + + +def test_batch_sampler_min_size_epoch_policy_stops_at_smallest_dataset() -> None: + """Verify min_size avoids oversampling smaller contributing datasets.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=2), device="cpu"), + Dataset(_OrderedReadManyReader(n=6), device="cpu"), + ) + sampler = MultiDatasetBatchSampler.balanced( + dataset, + batch_size=4, + epoch_policy="min_size", + replacement=True, + shuffle=False, + ) + + assert len(sampler) == 1 + assert list(sampler) == [[0, 1, 2, 3]] + + +def test_batch_sampler_max_size_epoch_policy_oversamples_smaller_dataset() -> None: + """Verify max_size can balance batches across the largest dataset span.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=2), device="cpu"), + Dataset(_OrderedReadManyReader(n=6), device="cpu"), + ) + sampler = MultiDatasetBatchSampler.balanced( + dataset, + batch_size=4, + epoch_policy="max_size", + replacement=True, + shuffle=False, + ) + + assert len(sampler) == 3 + assert list(sampler) == [ + [0, 1, 2, 3], + [0, 1, 4, 5], + [0, 1, 6, 7], + ] + + +def test_batch_sampler_max_size_epoch_policy_requires_replacement() -> None: + """Verify max_size fails without replacement when oversampling is required.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=2), device="cpu"), + Dataset(_OrderedReadManyReader(n=6), device="cpu"), + ) + + with pytest.raises(ValueError, match="replacement=True"): + MultiDatasetBatchSampler.balanced( + dataset, + batch_size=4, + epoch_policy="max_size", + replacement=False, + shuffle=False, + ) diff --git a/test/data/test_reader_base.py b/test/data/test_reader_base.py index af893d56..1cdd0cc4 100644 --- a/test/data/test_reader_base.py +++ b/test/data/test_reader_base.py @@ 
-24,6 +24,7 @@ import pytest import torch +from physicsnemo.datapipes.readers.base import Reader as PhysicsNeMoReader from nvalchemi.data.datapipes.backends.base import Reader @@ -60,6 +61,29 @@ def __len__(self) -> int: return len(self._data) +class ManyOnlyReader(Reader): + """Reader implementation that only supports batch-oriented raw loading.""" + + def __init__(self) -> None: + super().__init__() + self._data = [{"x": torch.tensor([float(i)])} for i in range(3)] + self.calls: list[list[int]] = [] + + def _load_many_samples(self, indices) -> list[dict[str, torch.Tensor]]: + self.calls.append(list(indices)) + return [self._data[index] for index in indices] + + def __len__(self) -> int: + return len(self._data) + + +class NoLoadHookReader(Reader): + """Reader implementation without a raw loading hook.""" + + def __len__(self) -> int: + return 1 + + class FailingReader(MinimalReader): """Reader that raises ``ValueError`` when a specific index is loaded.""" @@ -190,6 +214,31 @@ def test_pin_memory_true_pins_tensors(self): assert data_dict["x"].is_pinned() +class TestReaderLoadHooks: + """Tests for optional single-sample and multi-sample loading hooks.""" + + def test_read_uses_many_sample_hook_when_available(self): + """A reader can implement only _load_many_samples.""" + reader = ManyOnlyReader() + data_dict, metadata = reader.read(1) + assert torch.allclose(data_dict["x"], torch.tensor([1.0])) + assert metadata["index"] == 1 + assert reader.calls == [[1]] + + def test_read_many_uses_many_sample_hook_once(self): + """read_many delegates one request to _load_many_samples.""" + reader = ManyOnlyReader() + samples = reader.read_many([2, -3]) + assert [metadata["index"] for _, metadata in samples] == [2, -3] + assert reader.calls == [[2, -3]] + + def test_reader_without_load_hook_raises_not_implemented(self): + """A concrete reader still needs at least one raw loading hook.""" + reader = NoLoadHookReader() + with pytest.raises(NotImplementedError): + reader.read(0) + 
+ # --------------------------------------------------------------------------- # TestReaderGetSampleMetadata # --------------------------------------------------------------------------- @@ -282,3 +331,11 @@ def test_repr_contains_pin_memory(self): reader_pin = MinimalReader(pin_memory=True) assert "False" in repr(reader_no_pin) assert "True" in repr(reader_pin) + + +class TestReaderPhysicsNeMoInheritance: + """Tests for PhysicsNeMo reader compatibility.""" + + def test_reader_is_physicsnemo_reader(self): + """nvalchemi readers inherit from the PhysicsNeMo reader base.""" + assert isinstance(MinimalReader(), PhysicsNeMoReader) diff --git a/test/data/test_zarr_datapipe.py b/test/data/test_zarr_datapipe.py index b44f371d..3f1010e6 100644 --- a/test/data/test_zarr_datapipe.py +++ b/test/data/test_zarr_datapipe.py @@ -17,7 +17,7 @@ from __future__ import annotations import random -from collections.abc import Generator +from collections.abc import Generator, Iterator, Sequence from math import floor from pathlib import Path from unittest.mock import MagicMock, patch @@ -25,6 +25,11 @@ import pytest import torch import zarr +from physicsnemo.datapipes.dataloader import DataLoader as PhysicsNeMoDataLoader +from physicsnemo.datapipes.dataset import Dataset as PhysicsNeMoDataset +from physicsnemo.datapipes.multi_dataset import MultiDataset as PhysicsNeMoMultiDataset +from torch.utils.data import Sampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler from nvalchemi.data.atomic_data import AtomicData from nvalchemi.data.batch import Batch @@ -33,7 +38,9 @@ AtomicDataZarrWriter, DataLoader, Dataset, + MultiDataset, ) +from nvalchemi.data.datapipes.backends.base import Reader from nvalchemi.data.datapipes.backends.zarr import ( ZarrArrayConfig, ZarrWriteConfig, @@ -86,6 +93,16 @@ def _data_generator(num_samples: int, seed: int = 5136) -> Generator: yield _make_atomic_data(num_atoms, num_edges) +def _make_ordered_atomic_data(label: int) -> 
AtomicData: + """Create one-atom AtomicData with an order-identifying atomic number.""" + return AtomicData( + atomic_numbers=torch.tensor([label], dtype=torch.long), + positions=torch.tensor([[float(label), 0.0, 0.0]]), + cell=torch.eye(3).unsqueeze(0), + pbc=torch.tensor([[True, True, True]]), + ) + + class TestAtomicDataZarrWriter: """Tests for AtomicDataZarrWriter.""" @@ -904,6 +921,91 @@ def test_reader_full_roundtrip(tmp_path: Path) -> None: ) +def test_reader_read_many_matches_single_sample_reads(tmp_path: Path) -> None: + """Verify read_many preserves per-sample reader semantics and order.""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + indices = [2, 0, 3] + many = reader.read_many(indices) + singles = [reader.read(index) for index in indices] + + assert len(many) == len(indices) + for (many_data, many_metadata), (single_data, single_metadata) in zip( + many, singles, strict=True + ): + assert many_metadata["physical_index"] == single_metadata["physical_index"] + for key, many_tensor in many_data.items(): + single_tensor = single_data[key] + if many_tensor.dtype.is_floating_point: + assert torch.allclose(many_tensor, single_tensor), key + else: + assert torch.equal(many_tensor, single_tensor), key + + +def test_reader_read_many_skips_deleted_and_supports_negative_indices( + tmp_path: Path, +) -> None: + """Verify read_many maps logical indices through the active sample mask.""" + data_list = [_make_ordered_atomic_data(i + 1) for i in range(5)] + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + writer.delete([1]) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + samples = reader.read_many([2, 0, -1]) + + labels = [data["atomic_numbers"].item() for data, _ in samples] + physical_indices = [metadata["physical_index"] for _, metadata in samples] + logical_indices = 
[metadata["index"] for _, metadata in samples] + + assert labels == [4, 1, 5] + assert physical_indices == ["3", "0", "4"] + assert logical_indices == [2, 0, 3] + + +def test_reader_read_many_empty_returns_empty(tmp_path: Path) -> None: + """Verify read_many([]) returns an empty list.""" + data_list = list(_data_generator(3)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + result = reader.read_many([]) + assert result == [] + + +def test_reader_read_many_single_element(tmp_path: Path) -> None: + """Verify read_many([i]) matches reader.read(i).""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + many = reader.read_many([2]) + single = reader.read(2) + + many_data, many_meta = many[0] + single_data, single_meta = single + assert many_meta["physical_index"] == single_meta["physical_index"] + for key in many_data: + assert torch.equal(many_data[key], single_data[key]), key + + +def test_dataset_metadata_delegates_to_zarr_reader_pointers(tmp_path: Path) -> None: + """Verify Zarr metadata lookup does not load full samples.""" + data_list = [_make_atomic_data(3, 2), _make_atomic_data(5, 7)] + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device="cpu") + with patch.object(reader, "_load_sample", side_effect=AssertionError): + assert dataset.get_metadata(1) == (5, 7) + + def test_reader_optional_fields_only(tmp_path: Path) -> None: """Verify minimal AtomicData loads without error. 
@@ -1235,6 +1337,327 @@ def test_dataset_roundtrip_values(tmp_path: Path) -> None: assert torch.allclose(loaded.shifts, original.shifts) +class _OrderedReadManyReader: + """Minimal reader that records read_many calls for DataLoader tests.""" + + def __init__(self, n: int = 5) -> None: + self._n = n + self.read_many_calls: list[list[int]] = [] + self.pin_memory = False + + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + return _make_ordered_atomic_data(index + 1).to_dict() + + def _get_sample_metadata(self, index: int) -> dict[str, int]: + return {"src_index": index} + + @property + def field_names(self) -> list[str]: + return list(self._load_sample(0)) if self._n > 0 else [] + + def read_many( + self, indices: Sequence[int] + ) -> list[tuple[dict[str, torch.Tensor], dict[str, int]]]: + self.read_many_calls.append(list(indices)) + return [ + (self._load_sample(index), self._get_sample_metadata(index)) + for index in indices + ] + + def __len__(self) -> int: + return self._n + + def close(self) -> None: + pass + + +def test_dataset_load_batches_uses_reader_read_many() -> None: + """Verify Dataset.load_batches delegates batch reads to the reader.""" + reader = _OrderedReadManyReader() + dataset = Dataset(reader, device="cpu") + + batch = dataset.load_batches([[3, 1]])[0] + + assert reader.read_many_calls == [[3, 1]] + assert batch.atomic_numbers.tolist() == [4, 2] + + +def test_dataset_and_dataloader_are_physicsnemo_subclasses() -> None: + """Verify nvalchemi datapipes inherit PhysicsNeMo datapipes.""" + dataset = Dataset(_OrderedReadManyReader(), device="cpu") + loader = DataLoader(dataset, batch_size=1, use_streams=False) + multidataset = MultiDataset(dataset) + + assert isinstance(dataset, PhysicsNeMoDataset) + assert isinstance(loader, PhysicsNeMoDataLoader) + assert isinstance(multidataset, PhysicsNeMoMultiDataset) + + +def test_dataloader_fused_prefetches_sampler_batches_without_streams() -> None: + """Verify DataLoader fuses sampler batches 
even without CUDA streams.""" + reader = _OrderedReadManyReader() + dataset = Dataset(reader, device="cpu") + + class FixedSampler(Sampler[int]): + """Sampler with a deterministic non-sequential order.""" + + def __iter__(self) -> Iterator[int]: + return iter([4, 2, 0]) + + def __len__(self) -> int: + return 3 + + loader = DataLoader( + dataset, + batch_size=2, + sampler=FixedSampler(), + use_streams=False, + ) + + batches = list(loader) + + assert reader.read_many_calls == [[4, 2, 0]] + assert [batch.atomic_numbers.tolist() for batch in batches] == [[5, 3], [1]] + + +def test_dataloader_fused_prefetches_batch_sampler_without_streams() -> None: + """Verify pre-batched indices are fused by default without CUDA streams.""" + reader = _OrderedReadManyReader() + dataset = Dataset(reader, device="cpu") + + class FixedBatchSampler(Sampler[list[int]]): + """Sampler that yields pre-batched indices.""" + + def __iter__(self) -> Iterator[list[int]]: + return iter([[3, 1], [0, 2]]) + + def __len__(self) -> int: + return 2 + + loader = DataLoader(dataset, batch_sampler=FixedBatchSampler(), use_streams=False) + + batches = list(loader) + + assert len(loader) == 2 + assert reader.read_many_calls == [[3, 1, 0, 2]] + assert [batch.atomic_numbers.tolist() for batch in batches] == [[4, 2], [1, 3]] + + +def test_dataloader_prefetch_factor_zero_uses_simple_batches() -> None: + """Verify prefetch_factor=0 preserves one read_many call per batch.""" + reader = _OrderedReadManyReader() + dataset = Dataset(reader, device="cpu") + + loader = DataLoader( + dataset, + batch_size=2, + use_streams=False, + prefetch_factor=0, + ) + + batches = list(loader) + + assert reader.read_many_calls == [[0, 1], [2, 3], [4]] + assert [batch.atomic_numbers.tolist() for batch in batches] == [[1, 2], [3, 4], [5]] + + +def test_dataloader_prefetch_factor_controls_read_window() -> None: + """Verify prefetch_factor controls the fused read_many window.""" + reader = _OrderedReadManyReader(n=10) + dataset = 
Dataset(reader, device="cpu") + loader = DataLoader( + dataset, + batch_size=2, + prefetch_factor=3, + use_streams=False, + ) + + batches = list(loader) + + assert loader.effective_read_window == 6 + assert reader.read_many_calls == [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]] + assert [batch.atomic_numbers.tolist() for batch in batches] == [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + ] + + +def test_multidataset_getitem_enriches_metadata() -> None: + """Verify MultiDataset sample access reports source dataset metadata.""" + reader_a = _OrderedReadManyReader(n=3) + reader_b = _OrderedReadManyReader(n=4) + dataset_a = Dataset(reader_a, device="cpu") + dataset_b = Dataset(reader_b, device="cpu") + dataset = MultiDataset(dataset_a, dataset_b) + + data, metadata = dataset[4] + + assert reader_a.read_many_calls == [] + assert reader_b.read_many_calls == [[1]] + assert data.atomic_numbers.item() == 2 + assert metadata["dataset_index"] == 1 + assert metadata["src_index"] == 1 + + +def test_multidataset_load_batches_routes_mixed_indices_to_child_batches() -> None: + """Verify mixed MultiDataset batches route through child load_batches methods.""" + dataset_a = Dataset(_OrderedReadManyReader(n=3), device="cpu") + dataset_b = Dataset(_OrderedReadManyReader(n=4), device="cpu") + dataset = MultiDataset(dataset_a, dataset_b) + + with ( + patch.object(dataset_a, "load_batches", wraps=dataset_a.load_batches) as load_a, + patch.object(dataset_b, "load_batches", wraps=dataset_b.load_batches) as load_b, + ): + batch = dataset.load_batches([[0, 3, 2, 6]])[0] + + assert [ + [list(batch_indices) for batch_indices in call.args[0]] + for call in load_a.call_args_list + ] == [[[0, 2]]] + assert [ + [list(batch_indices) for batch_indices in call.args[0]] + for call in load_b.call_args_list + ] == [[[0, 3]]] + assert batch.atomic_numbers.tolist() == [1, 1, 3, 4] + + +def test_multidataset_dataloader_delegates_single_child_fused_prefetch() -> None: + """Verify same-child fused chunks use the 
child fused read path.""" + reader_a = _OrderedReadManyReader(n=6) + reader_b = _OrderedReadManyReader(n=4) + dataset = MultiDataset( + Dataset(reader_a, device="cpu"), + Dataset(reader_b, device="cpu"), + ) + loader = DataLoader( + dataset, + batch_size=2, + prefetch_factor=2, + sampler=SequentialSampler(range(4)), + use_streams=False, + ) + + batches = list(loader) + + assert reader_a.read_many_calls == [[0, 1, 2, 3]] + assert reader_b.read_many_calls == [] + assert [batch.atomic_numbers.tolist() for batch in batches] == [[1, 2], [3, 4]] + + +def test_multidataset_dataloader_groups_mixed_fused_prefetch_by_child() -> None: + """Verify mixed fused chunks still issue one read_many per child.""" + reader_a = _OrderedReadManyReader(n=3) + reader_b = _OrderedReadManyReader(n=4) + dataset_a = Dataset(reader_a, device="cpu") + dataset_b = Dataset(reader_b, device="cpu") + dataset = MultiDataset( + dataset_a, + dataset_b, + ) + + class MixedSampler(Sampler[int]): + """Sampler that alternates between child datasets.""" + + def __iter__(self) -> Iterator[int]: + return iter([0, 3, 2, 6]) + + def __len__(self) -> int: + return 4 + + with ( + patch.object(dataset_a, "load_batches", wraps=dataset_a.load_batches) as load_a, + patch.object(dataset_b, "load_batches", wraps=dataset_b.load_batches) as load_b, + ): + loader = DataLoader( + dataset, + batch_size=2, + prefetch_factor=2, + sampler=MixedSampler(), + use_streams=False, + ) + batches = list(loader) + + assert [ + [list(batch_indices) for batch_indices in call.args[0]] + for call in load_a.call_args_list + ] == [[[0], [2]]] + assert [ + [list(batch_indices) for batch_indices in call.args[0]] + for call in load_b.call_args_list + ] == [[[0], [3]]] + assert reader_a.read_many_calls == [[0, 2]] + assert reader_b.read_many_calls == [[0, 3]] + assert [batch.atomic_numbers.tolist() for batch in batches] == [[1, 1], [3, 4]] + + +def test_dataloader_rejects_negative_prefetch_factor() -> None: + """Verify negative prefetch factors 
fail instead of disabling prefetching.""" + reader = _OrderedReadManyReader() + dataset = Dataset(reader, device="cpu") + + with pytest.raises(ValueError, match="prefetch_factor"): + DataLoader(dataset, prefetch_factor=-1) + + +def test_dataloader_pin_memory_enables_reader_pin_memory() -> None: + """Verify DataLoader pin_memory requests pinned reads from the reader.""" + reader = _OrderedReadManyReader() + dataset = Dataset(reader, device="cpu") + + loader = DataLoader(dataset, batch_size=2, use_streams=False, pin_memory=True) + + assert loader.pin_memory is True + assert dataset.pin_memory is True + assert reader.pin_memory is True + + +def test_dataloader_pin_memory_does_not_mutate_multidataset_children() -> None: + """Verify MultiDataset leaves child pin-memory policy to each dataset.""" + reader_a = _OrderedReadManyReader() + reader_b = _OrderedReadManyReader() + dataset = MultiDataset( + Dataset(reader_a, device="cpu"), + Dataset(reader_b, device="cpu"), + ) + + loader = DataLoader(dataset, batch_size=2, use_streams=False, pin_memory=True) + + assert loader.pin_memory is True + assert reader_a.pin_memory is False + assert reader_b.pin_memory is False + + +def test_multidataset_output_strict_uses_first_nonempty_field_names() -> None: + """Verify strict field validation ignores empty leading datasets.""" + empty_dataset = Dataset(_OrderedReadManyReader(n=0), device="cpu") + nonempty_dataset = Dataset(_OrderedReadManyReader(n=2), device="cpu") + + dataset = MultiDataset(empty_dataset, nonempty_dataset, output_strict=True) + + assert dataset.field_names == nonempty_dataset.field_names + assert dataset.validate_field_names() == nonempty_dataset.field_names + + +def test_multidataset_cancel_prefetch_canonicalizes_negative_indices() -> None: + """Verify cancel_prefetch(-1) clears delegated child sample prefetch.""" + dataset = MultiDataset( + Dataset(_OrderedReadManyReader(n=3), device="cpu"), + Dataset(_OrderedReadManyReader(n=2), device="cpu"), + ) + + 
dataset.prefetch(-1) + assert dataset.prefetch_count == 1 + + dataset.cancel_prefetch(-1) + + assert dataset.prefetch_count == 0 + dataset.close() + + @pytest.mark.parametrize("batch_size", [1, 4, 8, 16, 32]) @pytest.mark.parametrize("sample_scale", [0.9, 1.0, 1.1]) def test_dataloader_yields_batch( @@ -1303,6 +1726,61 @@ def test_dataloader_shuffle(tmp_path: Path) -> None: assert order1 != order2, "Shuffle should produce different order across loaders" +def test_dataloader_custom_sampler(tmp_path: Path) -> None: + """Verify DataLoader respects a minimal custom sampler order.""" + + class ReverseOddSampler(Sampler[int]): + """Yield a fixed non-sequential subset to exercise custom sampling.""" + + def __init__(self, indices: list[int]) -> None: + self.indices = indices + + def __iter__(self) -> Iterator[int]: + return iter(self.indices) + + def __len__(self) -> int: + return len(self.indices) + + data_list = [_make_ordered_atomic_data(i + 1) for i in range(5)] + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + sampler = ReverseOddSampler([4, 2, 0]) + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device="cpu") + loader = DataLoader(dataset, batch_size=2, sampler=sampler, use_streams=False) + + with patch.object(reader, "read_many", wraps=reader.read_many) as read_many: + batches = list(loader) + + assert [batch.atomic_numbers.tolist() for batch in batches] == [[5, 3], [1]] + assert [list(call.args[0]) for call in read_many.call_args_list] == [[4, 2, 0]] + + +def test_dataloader_distributed_sampler(tmp_path: Path) -> None: + """Verify DataLoader works with PyTorch's DistributedSampler.""" + data_list = [_make_ordered_atomic_data(i + 1) for i in range(6)] + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device="cpu") + sampler = DistributedSampler( + dataset, + 
num_replicas=2, + rank=1, + shuffle=False, + drop_last=False, + ) + loader = DataLoader(dataset, batch_size=2, sampler=sampler, use_streams=False) + + with patch.object(reader, "read_many", wraps=reader.read_many) as read_many: + batches = list(loader) + + assert [batch.atomic_numbers.tolist() for batch in batches] == [[2, 4], [6]] + assert [list(call.args[0]) for call in read_many.call_args_list] == [[1, 3, 5]] + + class TestDatasetPrefetch: """Tests for Dataset prefetch mechanics (CPU thread-pool path).""" @@ -1529,9 +2007,9 @@ def test_load_and_transform_captures_error(self, tmp_path: Path) -> None: with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: dataset = Dataset(reader, device="cpu") - # Mock reader._load_sample to raise an error + # Mock the raw batch hook to raise an error with patch.object( - reader, "_load_sample", side_effect=RuntimeError("test error") + reader, "_load_many_samples", side_effect=RuntimeError("test error") ): result = dataset._load_and_transform(0) @@ -1550,8 +2028,10 @@ def test_prefetch_error_propagation(self, tmp_path: Path) -> None: with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: dataset = Dataset(reader, device="cpu") - # Mock reader._load_sample to raise an error - with patch.object(reader, "_load_sample", side_effect=RuntimeError("boom")): + # Mock the raw batch hook to raise an error + with patch.object( + reader, "_load_many_samples", side_effect=RuntimeError("boom") + ): # Prefetch the sample (error will be captured) dataset.prefetch(0) @@ -1580,6 +2060,408 @@ def test_dataset_close_with_inflight_prefetch(self, tmp_path: Path) -> None: assert reader._root is None +class TestFusedBatchPrefetch: + """Tests for fused multi-batch prefetch (prefetch_fused_batches / get_fused_batches).""" + + def test_fused_prefetch_yields_correct_batches( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Fused read yields the same batches as individual reads.""" + data_list = list(_data_generator(12)) + writer = 
AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device) + + # Read synchronously for reference + ref_b0, ref_b1, ref_b2 = dataset.load_batches( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + ) + + # Read via fused prefetch + dataset.prefetch_fused_batches([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) + fused_batches = list(dataset.get_fused_batches()) + + assert len(fused_batches) == 3 + for fused, ref in zip(fused_batches, [ref_b0, ref_b1, ref_b2], strict=True): + assert fused.num_graphs == ref.num_graphs + torch.testing.assert_close(fused.positions, ref.positions) + torch.testing.assert_close(fused.atomic_numbers, ref.atomic_numbers) + + def test_fused_prefetch_variable_batch_sizes( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Handles sub-batches of different sizes (e.g. last batch is smaller).""" + data_list = list(_data_generator(7)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device) + + dataset.prefetch_fused_batches([[0, 1, 2], [3, 4, 5], [6]]) + fused_batches = list(dataset.get_fused_batches()) + + assert len(fused_batches) == 3 + assert fused_batches[0].num_graphs == 3 + assert fused_batches[1].num_graphs == 3 + assert fused_batches[2].num_graphs == 1 + + def test_fused_prefetch_raises_without_pending( + self, tmp_path: Path, gpu_device: str + ) -> None: + """get_fused_batches raises RuntimeError when no prefetch is pending.""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device) + with pytest.raises(RuntimeError, match="No fused batch prefetch pending"): + list(dataset.get_fused_batches()) + + def 
test_fused_prefetch_queues_two(self, tmp_path: Path, gpu_device: str) -> None: + """Two prefetch_fused_batches calls queue; a third fails explicitly.""" + data_list = list(_data_generator(12)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device) + + dataset.prefetch_fused_batches([[0, 1], [2, 3]]) + dataset.prefetch_fused_batches([[4, 5], [6, 7]]) + with pytest.raises(RuntimeError, match="queue is full"): + dataset.prefetch_fused_batches([[8, 9], [10, 11]]) + + # First get_fused_batches returns chunk 1 + batches_1 = list(dataset.get_fused_batches()) + assert len(batches_1) == 2 + assert batches_1[0].num_graphs == 2 + + # Second get_fused_batches returns chunk 2 + batches_2 = list(dataset.get_fused_batches()) + assert len(batches_2) == 2 + assert batches_2[0].num_graphs == 2 + + def test_cancel_clears_fused_prefetch( + self, tmp_path: Path, gpu_device: str + ) -> None: + """cancel_prefetch clears the fused prefetch future.""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device) + + dataset.prefetch_fused_batches([[0, 1], [2, 3]]) + dataset.cancel_prefetch() + + # Should now raise because the future was cleared + with pytest.raises(RuntimeError, match="No fused batch prefetch pending"): + list(dataset.get_fused_batches()) + + def test_dataloader_amortized_completeness( + self, tmp_path: Path, gpu_device: str + ) -> None: + """DataLoader with amortized prefetch yields all samples.""" + num_samples = 20 + data_list = list(_data_generator(num_samples)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device) + loader = 
DataLoader( + dataset, + batch_size=3, + prefetch_factor=4, + use_streams=True, + ) + + total = sum(batch.num_graphs for batch in loader) + assert total == num_samples + + def test_dataloader_amortized_shuffle( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Shuffled DataLoader with amortized prefetch yields all samples.""" + num_samples = 16 + data_list = list(_data_generator(num_samples)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device) + loader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + prefetch_factor=3, + use_streams=True, + ) + + total = sum(batch.num_graphs for batch in loader) + assert total == num_samples + + def test_skip_validation_matches_validated( + self, tmp_path: Path, gpu_device: str + ) -> None: + """skip_validation=True fused prefetch yields same data as validated path.""" + data_list = list(_data_generator(12)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + batch_lists = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + ds_val = Dataset(reader, device=gpu_device, skip_validation=False) + ds_val.prefetch_fused_batches(batch_lists) + ref_batches = list(ds_val.get_fused_batches()) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + ds_raw = Dataset(reader, device=gpu_device, skip_validation=True) + ds_raw.prefetch_fused_batches(batch_lists) + raw_batches = list(ds_raw.get_fused_batches()) + + assert len(raw_batches) == len(ref_batches) + for raw, ref in zip(raw_batches, ref_batches, strict=True): + assert raw.num_graphs == ref.num_graphs + assert raw.num_nodes == ref.num_nodes + assert raw.num_edges == ref.num_edges + torch.testing.assert_close(raw.positions, ref.positions) + torch.testing.assert_close(raw.atomic_numbers, ref.atomic_numbers) + + def 
test_skip_validation_dataloader_completeness( + self, tmp_path: Path, gpu_device: str + ) -> None: + """DataLoader with skip_validation yields all samples.""" + num_samples = 20 + data_list = list(_data_generator(num_samples)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device, skip_validation=True) + loader = DataLoader( + dataset, + batch_size=3, + prefetch_factor=4, + use_streams=True, + ) + total = sum(batch.num_graphs for batch in loader) + assert total == num_samples + + def test_skip_validation_dataloader_shuffle( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Shuffled DataLoader with skip_validation yields all samples.""" + num_samples = 16 + data_list = list(_data_generator(num_samples)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device, skip_validation=True) + loader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + prefetch_factor=3, + use_streams=True, + ) + total = sum(batch.num_graphs for batch in loader) + assert total == num_samples + + def test_fused_prefetch_error_propagation( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Verify background read errors propagate through get_fused_batches.""" + data_list = list(_data_generator(6)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device) + # Patch _read_raw_samples to raise + with patch.object( + dataset, "_read_raw_samples", side_effect=RuntimeError("boom") + ): + dataset.prefetch_fused_batches([[0, 1], [2, 3]]) + with pytest.raises(RuntimeError, match="boom"): + list(dataset.get_fused_batches()) + + def test_skip_validation_custom_key_roundtrip( + self, 
tmp_path: Path, gpu_device: str + ) -> None: + """Custom Zarr fields survive the skip_validation + from_raw_dicts path.""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + # Add a custom system-level field + custom = torch.arange(4, dtype=torch.float32).unsqueeze(1) + writer.add_custom("my_flag", custom, "system") + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device, skip_validation=True) + batch = dataset.load_batches([list(range(4))])[0] + + assert "my_flag" in batch.keys["system"] + assert batch.my_flag.shape[0] == 4 + + def test_skip_validation_custom_atom_key_roundtrip( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Custom atom-level Zarr fields are classified correctly with skip_validation. + + Reproduces the bug where from_raw_dicts misclassified custom + per-atom tensors as system-level, causing a shape crash in + UniformLevelStorage. + """ + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + # Add a custom atom-level field (variable size per sample). + total_atoms = sum(d.num_nodes for d in data_list) + embeddings = torch.randn(total_atoms, 8) + writer.add_custom("atom_embedding", embeddings, "atom") + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + # Verify reader exposes field_levels with the custom field. 
+ assert reader.field_levels.get("atom_embedding") == "atom" + + dataset = Dataset(reader, device=gpu_device, skip_validation=True) + batch = dataset.load_batches([list(range(4))])[0] + + assert "atom_embedding" in batch.keys["node"] + assert batch.atom_embedding.shape == (total_atoms, 8) + + def test_skip_validation_custom_edge_key_roundtrip( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Custom edge-level Zarr fields survive skip_validation path.""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + total_edges = sum(d.num_edges for d in data_list) + distances = torch.randn(total_edges) + writer.add_custom("pair_distance", distances, "edge") + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + assert reader.field_levels.get("pair_distance") == "edge" + + dataset = Dataset(reader, device=gpu_device, skip_validation=True) + batch = dataset.load_batches([list(range(4))])[0] + + assert "pair_distance" in batch.keys["edge"] + assert batch.pair_distance.shape == (total_edges,) + + def test_validated_custom_key_roundtrip( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Custom system-level Zarr fields survive validated batching.""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + custom = torch.arange(4, dtype=torch.float32).unsqueeze(1) + writer.add_custom("my_flag", custom, "system") + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device, skip_validation=False) + batch = dataset.get_batch(list(range(4))) + + assert "my_flag" in batch.keys["system"] + assert batch.my_flag.shape[0] == 4 + + def test_validated_custom_atom_key_roundtrip( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Custom atom-level Zarr fields are classified in validated batches.""" + data_list = list(_data_generator(4)) + writer = 
AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + total_atoms = sum(d.num_nodes for d in data_list) + embeddings = torch.randn(total_atoms, 8) + writer.add_custom("atom_embedding", embeddings, "atom") + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + assert reader.field_levels.get("atom_embedding") == "atom" + + dataset = Dataset(reader, device=gpu_device, skip_validation=False) + batch = dataset.get_batch(list(range(4))) + + assert "atom_embedding" in batch.keys["node"] + assert batch.atom_embedding.shape == (total_atoms, 8) + + def test_validated_prefetch_custom_atom_key_roundtrip( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Custom atom-level fields survive validated get_batch prefetch.""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + total_atoms = sum(d.num_nodes for d in data_list) + embeddings = torch.randn(total_atoms, 8) + writer.add_custom("atom_embedding", embeddings, "atom") + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device, skip_validation=False) + dataset.prefetch_many(list(range(4))) + batch = dataset.get_batch(list(range(4))) + + assert "atom_embedding" in batch.keys["node"] + assert batch.atom_embedding.shape == (total_atoms, 8) + + def test_validated_custom_edge_key_roundtrip( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Custom edge-level Zarr fields survive validated batching.""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + total_edges = sum(d.num_edges for d in data_list) + distances = torch.randn(total_edges) + writer.add_custom("pair_distance", distances, "edge") + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + assert reader.field_levels.get("pair_distance") == "edge" + + dataset = Dataset(reader, device=gpu_device, skip_validation=False) + 
batch = dataset.get_batch(list(range(4))) + + assert "pair_distance" in batch.keys["edge"] + assert batch.pair_distance.shape == (total_edges,) + + def test_validated_fused_prefetch_custom_atom_key_roundtrip( + self, tmp_path: Path, gpu_device: str + ) -> None: + """Custom atom-level fields survive validated fused prefetch.""" + data_list = list(_data_generator(4)) + writer = AtomicDataZarrWriter(tmp_path / "test.zarr") + writer.write(data_list) + + total_atoms = sum(d.num_nodes for d in data_list) + embeddings = torch.randn(total_atoms, 8) + writer.add_custom("atom_embedding", embeddings, "atom") + + with AtomicDataZarrReader(tmp_path / "test.zarr") as reader: + dataset = Dataset(reader, device=gpu_device, skip_validation=False) + dataset.prefetch_fused_batches([list(range(4))]) + batches = list(dataset.get_fused_batches()) + + assert len(batches) == 1 + batch = batches[0] + assert "atom_embedding" in batch.keys["node"] + assert batch.atom_embedding.shape == (total_atoms, 8) + + class TestDataLoaderPrefetch: """Tests for DataLoader prefetch iteration path.""" @@ -1688,7 +2570,16 @@ def test_prefetch_pipeline_completeness( def test_prefetch_consumes_batches_lazily( self, tmp_path: Path, gpu_device: str ) -> None: - """Generator is not fully materialised; only the fill window is consumed.""" + """Generator is not fully materialised; only the pipeline window is consumed. + + True double-buffering primes two queue slots, then refills + one slot after consuming the oldest chunk. 
By the first + yield, at most ``3 * prefetch_factor`` batch-index lists have + been pulled from the sampler: + - chunk_a (pf) — primed and consumed + - chunk_b (pf) — primed, still in flight + - next_chunk (pf) — collected and submitted after chunk_a + """ data_list = list(_data_generator(20)) writer = AtomicDataZarrWriter(tmp_path / "test.zarr") writer.write(data_list) @@ -1719,7 +2610,9 @@ def _counting_generate(): gen = loader._iter_prefetch() next(gen) - assert batches_pulled <= prefetch_factor + # True double-buffer: 2 primed chunks + 1 refill after + # consuming the oldest = 3 * prefetch_factor pulled. + assert batches_pulled <= 3 * prefetch_factor gen.close() @@ -2037,10 +2930,11 @@ def test_append_dispatches_to_memory_store_set(self, num_samples: int) -> None: # --------------------------------------------------------------------------- -class _SimpleReader: +class _SimpleReader(Reader): """Minimal duck-typed reader for Dataset tests (no zarr required).""" def __init__(self, n: int = 3) -> None: + super().__init__() self._n = n def _load_sample(self, index: int) -> dict: @@ -2090,6 +2984,8 @@ def test_default_device_is_set_when_none_given(self): @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_getitem_returns_atomic_data_and_metadata(self, device: str): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("No CUDA device available.") reader = _SimpleReader() ds = Dataset(reader, device=device) data, meta = ds[0] @@ -2098,6 +2994,8 @@ def test_getitem_returns_atomic_data_and_metadata(self, device: str): @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_getitem_transfers_to_target_device(self, device: str): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("No CUDA device available.") reader = _SimpleReader() ds = Dataset(reader, device=device) data, _ = ds[0] diff --git a/test/dynamics/test_profiler_hook.py b/test/dynamics/test_profiler_hook.py index 7620dc85..55435020 100644 --- 
a/test/dynamics/test_profiler_hook.py +++ b/test/dynamics/test_profiler_hook.py @@ -12,334 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for ProfilerHook.""" +"""Compatibility tests for removed dynamics ProfilerHook imports.""" from __future__ import annotations -import csv -from unittest.mock import MagicMock, patch - import pytest -import torch - -from nvalchemi.data import AtomicData, Batch -from nvalchemi.dynamics.base import DynamicsStage -from nvalchemi.dynamics.demo import DemoDynamics -from nvalchemi.dynamics.hooks.profiling import ProfilerHook -from nvalchemi.models.demo import DemoModel, DemoModelWrapper - - -def _make_batch( - n_graphs: int = 2, atoms_per_graph: int = 3, device: str = "cpu" -) -> Batch: - data_list = [ - AtomicData( - atomic_numbers=torch.tensor([6] * atoms_per_graph, dtype=torch.long), - positions=torch.randn(atoms_per_graph, 3), - ) - for _ in range(n_graphs) - ] - batch = Batch.from_data_list(data_list).to(device) - batch.__dict__["forces"] = torch.randn(batch.num_nodes, 3, device=device) - batch.__dict__["energy"] = torch.randn(batch.num_graphs, 1, device=device) - batch.__dict__["velocities"] = torch.randn(batch.num_nodes, 3, device=device) * 0.01 - batch.__dict__["atomic_masses"] = torch.full( - (batch.num_nodes,), 12.0, device=device - ) - return batch - - -def _make_dynamics(hooks=None, n_steps: int = 5, device: str = "cpu") -> DemoDynamics: - model = DemoModelWrapper(DemoModel()) - if device != "cpu": - model = model.to(device) - return DemoDynamics( - model=model, n_steps=n_steps, dt=1.0, hooks=hooks, device_type=device - ) - - -# ------------------------------------------------------------------ -# Construction / presets -# ------------------------------------------------------------------ - - -class TestConstruction: - def test_step_preset(self) -> None: - profiler = 
ProfilerHook("step") - assert set(profiler._profiled_stages) == { - DynamicsStage.BEFORE_STEP, - DynamicsStage.AFTER_STEP, - } - - def test_detailed_preset(self) -> None: - profiler = ProfilerHook("detailed") - expected = { - DynamicsStage.BEFORE_STEP, - DynamicsStage.BEFORE_PRE_UPDATE, - DynamicsStage.AFTER_PRE_UPDATE, - DynamicsStage.BEFORE_COMPUTE, - DynamicsStage.AFTER_COMPUTE, - DynamicsStage.BEFORE_POST_UPDATE, - DynamicsStage.AFTER_POST_UPDATE, - DynamicsStage.AFTER_STEP, - } - assert set(profiler._profiled_stages) == expected - - def test_all_preset(self) -> None: - profiler = ProfilerHook("all") - assert DynamicsStage.ON_CONVERGE not in profiler._profiled_stages - assert len(profiler._profiled_stages) == len(DynamicsStage) - 1 - - def test_custom_stages(self) -> None: - S = DynamicsStage - custom = {S.BEFORE_COMPUTE, S.AFTER_COMPUTE} - profiler = ProfilerHook(custom) - assert set(profiler._profiled_stages) == custom - - def test_unknown_preset_raises(self) -> None: - with pytest.raises(ValueError, match="Unknown stages preset"): - ProfilerHook("bogus") # type: ignore[arg-type] - - def test_single_stage_raises(self) -> None: - with pytest.raises(ValueError, match="At least two stages"): - ProfilerHook({DynamicsStage.BEFORE_STEP}) - - def test_stages_sorted_by_execution_order(self) -> None: - profiler = ProfilerHook("detailed") - values = [s.value for s in profiler._profiled_stages] - assert values == sorted(values) - - -# ------------------------------------------------------------------ -# Registration -# ------------------------------------------------------------------ - - -class TestRegistration: - def test_registers_at_all_stages(self, device: str) -> None: - profiler = ProfilerHook("step") - dynamics = _make_dynamics(hooks=[profiler], n_steps=1, device=device) - assert profiler in dynamics.hooks - # Verify _runs_on_stage covers the expected stages - assert profiler._runs_on_stage(DynamicsStage.BEFORE_STEP) - assert 
profiler._runs_on_stage(DynamicsStage.AFTER_STEP) - - def test_does_not_register_at_other_stages(self, device: str) -> None: - profiler = ProfilerHook("step") - _make_dynamics(hooks=[profiler], n_steps=1, device=device) - assert not profiler._runs_on_stage(DynamicsStage.BEFORE_COMPUTE) - - def test_composable_with_other_hooks(self, device: str) -> None: - from nvalchemi.dynamics.hooks.safety import NaNDetectorHook - - profiler = ProfilerHook("step") - nan_hook = NaNDetectorHook() - dynamics = _make_dynamics(hooks=[profiler, nan_hook], n_steps=3, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - assert len(profiler.summary()) > 0 - - -# ------------------------------------------------------------------ -# CPU timing -# ------------------------------------------------------------------ - - -class TestTiming: - def test_records_values(self, device: str) -> None: - profiler = ProfilerHook("step") - dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - summary = profiler.summary() - key = "BEFORE_STEP->AFTER_STEP" - assert key in summary - assert summary[key]["n_samples"] == 5 - - def test_summary_keys(self, device: str) -> None: - profiler = ProfilerHook("step") - dynamics = _make_dynamics(hooks=[profiler], n_steps=3, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - summary = profiler.summary() - key = next(iter(summary)) - expected_keys = {"mean_s", "std_s", "min_s", "max_s", "total_s", "n_samples"} - assert set(summary[key].keys()) == expected_keys - - def test_positive_deltas(self, device: str) -> None: - profiler = ProfilerHook("step") - dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - for stats in profiler.summary().values(): - assert stats["mean_s"] >= 0 - assert stats["min_s"] >= 0 - - def test_frequency_gating(self, device: str) -> None: - profiler = 
ProfilerHook("step", frequency=3) - dynamics = _make_dynamics(hooks=[profiler], n_steps=9, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - summary = profiler.summary() - assert summary["BEFORE_STEP->AFTER_STEP"]["n_samples"] == 3 - - def test_detailed_timing(self, device: str) -> None: - profiler = ProfilerHook("detailed") - dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - summary = profiler.summary() - # 8 stages -> 7 transitions. - assert len(summary) == 7 - for stats in summary.values(): - assert stats["n_samples"] == 5 - - def test_reset(self, device: str) -> None: - profiler = ProfilerHook("step") - dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - assert len(profiler.summary()) > 0 - profiler.reset() - assert profiler.summary() == {} - - -# ------------------------------------------------------------------ -# Auto backend -# ------------------------------------------------------------------ - - -class TestAutoBackend: - def test_auto_selects_perf_counter_on_cpu(self) -> None: - profiler = ProfilerHook("step", timer_backend="auto") - dynamics = _make_dynamics(hooks=[profiler], n_steps=1, device="cpu") - batch = _make_batch(device="cpu") - dynamics.run(batch) - assert profiler._backend_resolved == "perf_counter" - - def test_auto_selects_cuda_event_on_gpu(self, gpu_device: str) -> None: - profiler = ProfilerHook("step", timer_backend="auto") - dynamics = _make_dynamics(hooks=[profiler], n_steps=1, device=gpu_device) - batch = _make_batch(device=gpu_device) - dynamics.run(batch) - assert profiler._backend_resolved == "cuda_event" - - -# ------------------------------------------------------------------ -# NVTX -# ------------------------------------------------------------------ - - -class TestNVTX: - def test_nvtx_push_pop_called(self) -> None: - try: - import nvtx # noqa: 
F401 - except ImportError: - pytest.skip("nvtx not available") - - with patch("nvalchemi.dynamics.hooks.profiling.nvtx") as mock_nvtx: - mock_nvtx.push_range = MagicMock() - mock_nvtx.pop_range = MagicMock() - - profiler = ProfilerHook("step", enable_nvtx=True) - dynamics = _make_dynamics(hooks=[profiler], n_steps=1) - batch = _make_batch() - dynamics.run(batch) - - assert mock_nvtx.push_range.call_count >= 1 - assert mock_nvtx.pop_range.call_count >= 1 - - def test_nvtx_disabled(self) -> None: - with patch("nvalchemi.dynamics.hooks.profiling.nvtx") as mock_nvtx: - mock_nvtx.push_range = MagicMock() - mock_nvtx.pop_range = MagicMock() - - profiler = ProfilerHook("step", enable_nvtx=False) - dynamics = _make_dynamics(hooks=[profiler], n_steps=1) - batch = _make_batch() - dynamics.run(batch) - - mock_nvtx.push_range.assert_not_called() - mock_nvtx.pop_range.assert_not_called() - - -# ------------------------------------------------------------------ -# CSV logging -# ------------------------------------------------------------------ - - -class TestCSVLogging: - def test_writes_csv(self, tmp_path, device: str) -> None: - log_file = tmp_path / "profiler.csv" - profiler = ProfilerHook("step", log_path=log_file) - dynamics = _make_dynamics(hooks=[profiler], n_steps=3, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - profiler.close() - - with open(log_file) as f: - rows = list(csv.DictReader(f)) - assert len(rows) == 3 - assert "rank" in rows[0] - assert "step" in rows[0] - assert "stage" in rows[0] - assert "t_since_init_s" in rows[0] - assert "delta_s" in rows[0] - - def test_detailed_csv_rows(self, tmp_path, device: str) -> None: - log_file = tmp_path / "detailed.csv" - profiler = ProfilerHook("detailed", log_path=log_file) - dynamics = _make_dynamics(hooks=[profiler], n_steps=2, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - profiler.close() - - with open(log_file) as f: - rows = list(csv.DictReader(f)) - # 8 
stages -> 7 transitions per step, 2 steps -> 14 rows. - assert len(rows) == 14 - - -# ------------------------------------------------------------------ -# Console output -# ------------------------------------------------------------------ -class TestConsoleOutput: - def test_show_console(self, device: str) -> None: - with patch("nvalchemi.dynamics.hooks.profiling.logger") as mock_logger: - profiler = ProfilerHook("step", show_console=True) - dynamics = _make_dynamics(hooks=[profiler], n_steps=2, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - assert mock_logger.info.call_count == 2 +def test_removed_profiler_hook_import_from_package_points_to_replacements() -> None: + """Old package-level ProfilerHook imports raise a targeted migration error.""" + with pytest.raises(ImportError, match="TorchProfilerHook.*StageTimingHook"): + from nvalchemi.dynamics.hooks import ProfilerHook # noqa: F401 - def test_console_frequency(self, device: str) -> None: - with patch("nvalchemi.dynamics.hooks.profiling.logger") as mock_logger: - profiler = ProfilerHook( - "step", - show_console=True, - console_frequency=3, - ) - dynamics = _make_dynamics(hooks=[profiler], n_steps=9, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - assert mock_logger.info.call_count == 3 +def test_removed_profiler_hook_import_from_module_points_to_replacements() -> None: + """Old module-level ProfilerHook imports raise a targeted migration error.""" + with pytest.raises(ImportError, match="TorchProfilerHook.*StageTimingHook"): + from nvalchemi.dynamics.hooks.profiling import ProfilerHook # noqa: F401 -# ------------------------------------------------------------------ -# Integration -# ------------------------------------------------------------------ +def test_stage_timing_hook_still_imports_from_dynamics_package() -> None: + """StageTimingHook remains discoverable next to dynamics hooks.""" + from nvalchemi.dynamics.hooks import StageTimingHook + from 
nvalchemi.hooks import StageTimingHook as SharedStageTimingHook -class TestIntegration: - def test_full_loop(self, device: str) -> None: - profiler = ProfilerHook("step") - dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) - batch = _make_batch(device=device) - dynamics.run(batch) - summary = profiler.summary() - assert len(summary) > 0 - for stats in summary.values(): - assert "mean_s" in stats - assert stats["n_samples"] == 5 + assert StageTimingHook is SharedStageTimingHook diff --git a/test/hooks/test_context.py b/test/hooks/test_context.py index 4c14be47..0ca0d954 100644 --- a/test/hooks/test_context.py +++ b/test/hooks/test_context.py @@ -103,10 +103,13 @@ def test_create_with_training_fields(self): mock_optimizer = MagicMock() mock_scheduler = MagicMock() mock_gradients = {"param": torch.tensor([1.0, 2.0])} + mock_scaler = MagicMock(spec=torch.amp.GradScaler) ctx = TrainContext( batch=mock_batch, step_count=42, + batch_count=43, + epoch_step_count=2, epoch=5, loss=mock_loss, losses=mock_losses, @@ -114,11 +117,14 @@ def test_create_with_training_fields(self): optimizers=[mock_optimizer], lr_schedulers=[mock_scheduler], gradients=mock_gradients, + grad_scaler=mock_scaler, global_rank=2, ) assert ctx.batch is mock_batch assert ctx.step_count == 42 + assert ctx.batch_count == 43 + assert ctx.epoch_step_count == 2 assert ctx.epoch == 5 assert ctx.loss is mock_loss assert ctx.losses is mock_losses @@ -126,6 +132,7 @@ def test_create_with_training_fields(self): assert ctx.optimizers == [mock_optimizer] assert ctx.lr_schedulers == [mock_scheduler] assert ctx.gradients is mock_gradients + assert ctx.grad_scaler is mock_scaler assert ctx.global_rank == 2 def test_default_values_for_training_fields(self): @@ -133,10 +140,19 @@ def test_default_values_for_training_fields(self): ctx = TrainContext(batch=mock_batch) assert ctx.step_count == 0 + assert ctx.batch_count == 0 + assert ctx.epoch_step_count == 0 assert ctx.epoch == 0 assert ctx.loss is None 
assert ctx.losses is None assert ctx.models is None - assert ctx.optimizers is None - assert ctx.lr_schedulers is None + assert ctx.optimizers == [] + assert ctx.lr_schedulers == [] assert ctx.gradients is None + assert ctx.grad_scaler is None + + def test_optimizers_default_is_independent_per_instance(self): + ctx_a = TrainContext(batch=MagicMock()) + ctx_b = TrainContext(batch=MagicMock()) + ctx_a.optimizers.append(MagicMock()) + assert ctx_b.optimizers == [] diff --git a/test/hooks/test_physicsnemo_profiler_hook.py b/test/hooks/test_physicsnemo_profiler_hook.py new file mode 100644 index 00000000..35269647 --- /dev/null +++ b/test/hooks/test_physicsnemo_profiler_hook.py @@ -0,0 +1,350 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for the PhysicsNeMo PyTorch profiler hook.""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any + +import pytest +import torch + +from nvalchemi import distributed as distributed_module +from nvalchemi.dynamics.base import DynamicsStage +from nvalchemi.hooks import HookContext, TorchProfilerHook +from nvalchemi.hooks import physicsnemo_profiling as profiling_module +from nvalchemi.training import TrainingStage + + +class _CustomStage(Enum): + """Custom enum with names that should not be claimed by the hook.""" + + BEFORE_TRAINING = 0 + AFTER_BATCH = 1 + BEFORE_STEP = 2 + AFTER_STEP = 3 + + +@dataclass +class _FakeConfig: + """Minimal stand-in for PhysicsNeMo TorchProfilerConfig.""" + + name: str = "torch" + torch_prof_activities: tuple[Any, ...] | None = None + record_shapes: bool = True + with_stack: bool = False + profile_memory: bool = True + with_flops: bool = True + schedule: Any = None + on_trace_ready_path: Path | None = None + + +class _FakeTorchProfileWrapper: + """Minimal stand-in for PhysicsNeMo TorchProfileWrapper.""" + + last_instance: _FakeTorchProfileWrapper | None = None + + def __init__(self, config: _FakeConfig) -> None: + self.config = config + self.enabled = False + type(self).last_instance = self + + +class _FakeProfiler: + """Minimal stand-in for PhysicsNeMo Profiler.""" + + def __init__(self) -> None: + self.initialized = False + self.enabled = False + self.output_path: Path | None = None + self.wrapper: _FakeTorchProfileWrapper | None = None + self.enter_count = 0 + self.exit_count = 0 + self.step_count = 0 + self.finalize_count = 0 + + def enable( + self, wrapper: _FakeTorchProfileWrapper | str + ) -> _FakeTorchProfileWrapper: + self.enabled = True + if isinstance(wrapper, str): + resolved = _FakeTorchProfileWrapper.last_instance + if resolved is None: + raise RuntimeError("Fake torch profiler was 
not configured.") + wrapper = resolved + wrapper.enabled = True + self.wrapper = wrapper + return wrapper + + def __enter__(self) -> _FakeProfiler: + self.initialized = True + self.enter_count += 1 + return self + + def __exit__(self, *exc: object) -> None: + self.exit_count += 1 + + def step(self) -> None: + self.step_count += 1 + + def finalize(self) -> None: + self.finalize_count += 1 + + +@dataclass +class _FakeManager: + """Structural distributed manager for rank layout tests.""" + + world_size: int = 1 + + +@dataclass +class _FakeWorkflow: + """Workflow object carrying a distributed manager.""" + + distributed_manager: _FakeManager | None = None + + +def _ctx(rank: int = 0, world_size: int = 1) -> HookContext: + """Build a base hook context for profiler tests.""" + return HookContext( + batch=None, + global_rank=rank, + workflow=_FakeWorkflow(_FakeManager(world_size=world_size)), + ) + + +@pytest.fixture() +def fake_profiler(monkeypatch: pytest.MonkeyPatch) -> _FakeProfiler: + """Patch PhysicsNeMo profiler classes with fakes.""" + profiler = _FakeProfiler() + + monkeypatch.setattr(profiling_module, "Profiler", lambda: profiler) + monkeypatch.setattr( + profiling_module, "TorchProfileWrapper", _FakeTorchProfileWrapper + ) + monkeypatch.setattr(profiling_module, "TorchProfilerConfig", _FakeConfig) + return profiler + + +def _reset_physicsnemo_profiler_state() -> None: + """Reset PhysicsNeMo profiler singleton state for smoke tests.""" + try: + from physicsnemo.utils.profiling import Profiler, TorchProfileWrapper + except ImportError: + return + Profiler._profilers.clear() + Profiler._decoration_registry.clear() + Profiler._output_top = Path("./physicsnemo_profiling_outputs/") + Profiler._initialized = False + Profiler._clear_instance() + TorchProfileWrapper._clear_instance() + + +@pytest.fixture(autouse=True) +def reset_physicsnemo_profiler_state() -> Iterator[None]: + """Keep real PhysicsNeMo profiler singletons isolated between tests.""" + 
_reset_physicsnemo_profiler_state() + yield + _reset_physicsnemo_profiler_state() + + +class TestTorchProfilerHookConstruction: + """TorchProfilerHook construction and stage dispatch.""" + + def test_activity_aliases_are_normalized(self, tmp_path: Path) -> None: + """String activities are normalized to PyTorch profiler enums.""" + hook = TorchProfilerHook(output_dir=tmp_path, activities=("cpu",)) + assert hook.activities == (torch.profiler.ProfilerActivity.CPU,) + + def test_unknown_activity_raises(self, tmp_path: Path) -> None: + """Unknown activity strings fail validation.""" + with pytest.raises(ValueError, match="Unknown profiler activity"): + TorchProfilerHook(output_dir=tmp_path, activities=("bogus",)) + + def test_runs_on_training_and_dynamics_stages(self, tmp_path: Path) -> None: + """The hook claims training and dynamics profiler stages only.""" + hook = TorchProfilerHook(output_dir=tmp_path) + assert hook._runs_on_stage(TrainingStage.BEFORE_TRAINING) + assert hook._runs_on_stage(TrainingStage.BEFORE_BATCH) + assert hook._runs_on_stage(TrainingStage.AFTER_BATCH) + assert hook._runs_on_stage(TrainingStage.AFTER_TRAINING) + assert hook._runs_on_stage(DynamicsStage.BEFORE_STEP) + assert hook._runs_on_stage(DynamicsStage.AFTER_STEP) + assert not hook._runs_on_stage(DynamicsStage.BEFORE_COMPUTE) + assert not hook._runs_on_stage(_CustomStage.AFTER_BATCH) + assert not hook._runs_on_stage(_CustomStage.AFTER_STEP) + + +class TestTorchProfilerHookLifecycle: + """TorchProfilerHook lifecycle behavior with fake PhysicsNeMo objects.""" + + def test_training_lifecycle_starts_steps_and_finalizes( + self, tmp_path: Path, fake_profiler: _FakeProfiler + ) -> None: + """Training stages drive start, step, and finalization.""" + hook = TorchProfilerHook(output_dir=tmp_path) + ctx = _ctx(rank=1, world_size=2) + + hook(ctx, TrainingStage.BEFORE_TRAINING) + assert fake_profiler.enter_count == 1 + assert fake_profiler.output_path == tmp_path / "rank_1" + + hook(ctx, 
TrainingStage.AFTER_BATCH) + assert fake_profiler.step_count == 1 + + hook(ctx, TrainingStage.AFTER_TRAINING) + hook.close() + assert fake_profiler.exit_count == 1 + assert fake_profiler.finalize_count == 1 + + def test_train_batch_fallback_starts_on_before_batch( + self, tmp_path: Path, fake_profiler: _FakeProfiler + ) -> None: + """Standalone train_batch calls can start without BEFORE_TRAINING.""" + hook = TorchProfilerHook(output_dir=tmp_path) + ctx = _ctx() + + with hook: + hook(ctx, TrainingStage.BEFORE_BATCH) + hook(ctx, TrainingStage.AFTER_BATCH) + + assert fake_profiler.enter_count == 1 + assert fake_profiler.step_count == 1 + assert fake_profiler.finalize_count == 1 + + def test_dynamics_lifecycle_starts_steps_and_finalizes( + self, tmp_path: Path, fake_profiler: _FakeProfiler + ) -> None: + """Dynamics stages drive start, step, and context finalization.""" + hook = TorchProfilerHook(output_dir=tmp_path) + ctx = _ctx(rank=0, world_size=1) + + with hook: + hook(ctx, DynamicsStage.BEFORE_STEP) + hook(ctx, DynamicsStage.AFTER_STEP) + + assert fake_profiler.output_path == tmp_path / "rank_0" + assert fake_profiler.step_count == 1 + assert fake_profiler.exit_count == 1 + assert fake_profiler.finalize_count == 1 + + def test_trace_ready_path_is_rank_suffixed( + self, tmp_path: Path, fake_profiler: _FakeProfiler + ) -> None: + """TensorBoard trace handler paths get an explicit rank directory.""" + hook = TorchProfilerHook( + output_dir=tmp_path / "out", + on_trace_ready_path=tmp_path / "traces", + activities=("cpu",), + ) + hook(_ctx(rank=2, world_size=4), DynamicsStage.BEFORE_STEP) + + assert fake_profiler.wrapper is not None + assert fake_profiler.wrapper.config.on_trace_ready_path == ( + tmp_path / "traces" / "rank_2" + ) + assert fake_profiler.output_path == tmp_path / "out" / "rank_2" + + def test_rank_subdirs_can_be_disabled_for_single_process( + self, tmp_path: Path, fake_profiler: _FakeProfiler + ) -> None: + """Single-process runs can write directly 
under output_dir.""" + hook = TorchProfilerHook(output_dir=tmp_path, rank_subdirs=False) + + hook(_ctx(rank=0, world_size=1), DynamicsStage.BEFORE_STEP) + + assert fake_profiler.output_path == tmp_path + + def test_rank_subdirs_disabled_still_suffixes_distributed_runs( + self, + tmp_path: Path, + fake_profiler: _FakeProfiler, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Distributed runs keep rank suffixing even when single-rank layout opts out.""" + monkeypatch.setenv("WORLD_SIZE", "2") + hook = TorchProfilerHook(output_dir=tmp_path, rank_subdirs=False) + + hook(_ctx(rank=1, world_size=2), DynamicsStage.BEFORE_STEP) + + assert fake_profiler.output_path == tmp_path / "rank_1" + + def test_physicsnemo_single_process_manager_uses_base_output_dir( + self, + tmp_path: Path, + fake_profiler: _FakeProfiler, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """When PhysicsNeMo is initialized but not distributed, use output_dir.""" + + class _InitializedManager: + distributed = False + world_size = 1 + rank = 0 + + @classmethod + def is_initialized(cls) -> bool: + return True + + monkeypatch.setattr(profiling_module, "DistributedManager", _InitializedManager) + monkeypatch.setattr( + distributed_module, "DistributedManager", _InitializedManager + ) + hook = TorchProfilerHook(output_dir=tmp_path) + + hook(_ctx(rank=1, world_size=2), DynamicsStage.BEFORE_STEP) + + assert fake_profiler.output_path == tmp_path + + def test_already_initialized_physicsnemo_profiler_raises( + self, tmp_path: Path, fake_profiler: _FakeProfiler + ) -> None: + """The hook refuses to reconfigure an active PhysicsNeMo profiler.""" + fake_profiler.initialized = True + hook = TorchProfilerHook(output_dir=tmp_path) + + with pytest.raises(RuntimeError, match="already initialized or enabled"): + hook(_ctx(), DynamicsStage.BEFORE_STEP) + + +class TestTorchProfilerHookSmoke: + """Smoke tests with the real PhysicsNeMo profiler.""" + + def test_cpu_trace_is_written(self, tmp_path: Path) -> None: + """A 
CPU-only profile writes PhysicsNeMo torch outputs.""" + pytest.importorskip("physicsnemo") + hook = TorchProfilerHook( + output_dir=tmp_path, + activities=("cpu",), + record_shapes=False, + profile_memory=False, + with_flops=False, + ) + ctx = _ctx() + + with hook: + hook(ctx, TrainingStage.BEFORE_TRAINING) + with torch.profiler.record_function("nvalchemi_profiler_smoke"): + (torch.ones(4) + 1).sum().item() + hook(ctx, TrainingStage.AFTER_BATCH) + + out_dir = tmp_path / "rank_0" / "torch" + assert (out_dir / "trace.json").is_file() + assert (out_dir / "cpu_time.txt").is_file() diff --git a/test/hooks/test_reporting.py b/test/hooks/test_reporting.py new file mode 100644 index 00000000..f9223122 --- /dev/null +++ b/test/hooks/test_reporting.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class _Reporter:
    """Recording reporter double used by the orchestrator tests.

    Every successful ``report`` call is stored in ``calls``; when an
    ``events`` list is supplied, a formatted entry is appended to it as well.
    With ``fail_report=True`` the reporter raises instead of recording.
    """

    def __init__(
        self,
        name: str = "reporter",
        events: list[str] | None = None,
        *,
        rank_zero_only: bool = False,
        requires_all_ranks: bool = False,
        fail_report: bool = False,
    ) -> None:
        # Recorded (ctx, stage, state) triples, in call order.
        self.calls: list[tuple[HookContext, Enum, ReportingState]] = []
        self.name = name
        self.events = events
        self.fail_report = fail_report
        # Flags that are only stored here; nothing in this class reads them.
        self.rank_zero_only = rank_zero_only
        self.requires_all_ranks = requires_all_ranks

    def report(self, ctx: HookContext, stage: Enum, state: ReportingState) -> None:
        """Record one report call, or raise ``RuntimeError`` if configured to fail."""
        if self.fail_report:
            raise RuntimeError("report failed")

        self.calls.append((ctx, stage, state))

        log = self.events
        if log is None:
            return
        log.append(f"report:{self.name}:{stage.name}:{state.event_count}")
class _CloseOnlyReporter:
    """Reporter double that offers ``close`` but no context-manager protocol.

    Both ``close`` and ``report`` append a tagged entry to the shared
    ``events`` list; ``close`` can additionally be told to fail afterwards.
    """

    def __init__(
        self,
        name: str,
        events: list[str],
        *,
        fail_close: bool = False,
    ) -> None:
        self.fail_close = fail_close
        self.events = events
        self.name = name

    def report(self, ctx: HookContext, stage: Enum, state: ReportingState) -> None:
        """Log the report call; the arguments themselves are ignored."""
        entry = f"report:{self.name}"
        self.events.append(entry)

    def close(self) -> None:
        """Log the close call, then raise ``RuntimeError`` when ``fail_close`` is set."""
        entry = f"close:{self.name}"
        self.events.append(entry)
        # The event is logged before failing, so a failed close is still visible.
        if not self.fail_close:
            return
        raise RuntimeError("close failed")
hook._runs_on_stage(_ReportStage.AFTER_STEP) + + def test_stage_name_strings_match_enum_names(self) -> None: + hook = ReportingOrchestrator([], stages={"EXACT"}) + + assert hook._runs_on_stage(_ReportStage.EXACT) + assert not hook._runs_on_stage(_ReportStage.AFTER_STEP) + + def test_reporters_receive_original_context_stage_and_shared_state(self) -> None: + events: list[str] = [] + first = _Reporter("first", events) + second = _Reporter("second", events) + hook = ReportingOrchestrator([first, second]) + ctx = _ctx(step_count=11) + + hook(ctx, _ReportStage.AFTER_STEP) + + assert events == [ + "report:first:AFTER_STEP:1", + "report:second:AFTER_STEP:1", + ] + assert first.calls == [(ctx, _ReportStage.AFTER_STEP, hook.state)] + assert second.calls == [(ctx, _ReportStage.AFTER_STEP, hook.state)] + assert hook.state.last_stage == "AFTER_STEP" + assert hook.state.last_step_count == 11 + + def test_frequency_gating_comes_from_hook_registry(self) -> None: + reporter = _Reporter("reporter") + hook = ReportingOrchestrator([reporter], frequency=2) + engine = _Engine([hook]) + + engine.step_count = 1 + engine._call_hooks(_ReportStage.AFTER_STEP, object()) + engine.step_count = 2 + engine._call_hooks(_ReportStage.AFTER_STEP, object()) + + assert len(reporter.calls) == 1 + assert reporter.calls[0][0].step_count == 2 + + def test_orchestrator_rank_zero_only_skips_state_and_reporters(self) -> None: + reporter = _Reporter("reporter") + nonzero = _RankedReportingOrchestrator( + [reporter], + global_rank=1, + rank_zero_only=True, + ) + + nonzero(_ctx(global_rank=0), _ReportStage.AFTER_STEP) + assert nonzero.state.event_count == 0 + assert reporter.calls == [] + + rank_zero = _RankedReportingOrchestrator( + [reporter], + global_rank=0, + rank_zero_only=True, + ) + rank_zero(_ctx(global_rank=1), _ReportStage.AFTER_STEP) + assert rank_zero.state.event_count == 1 + assert len(reporter.calls) == 1 + + def test_orchestrator_rank_zero_only_dispatches_all_rank_reporters(self) -> None: + gated 
class TestReportingOrchestratorFailures:
    """How a failing reporter interacts with each ``ReportingErrorPolicy``."""

    @pytest.mark.parametrize(
        ("policy", "expected_later_calls"),
        [
            (ReportingErrorPolicy.RAISE, 0),
            (ReportingErrorPolicy.WARN, 1),
            (ReportingErrorPolicy.IGNORE, 1),
        ],
    )
    def test_report_failure_policy_records_message_and_controls_fanout(
        self,
        policy: ReportingErrorPolicy,
        expected_later_calls: int,
    ) -> None:
        """A failing first reporter is recorded; the policy decides whether fan-out continues."""
        stage = _ReportStage.AFTER_STEP
        trailing = _Reporter("later")
        orchestrator = ReportingOrchestrator(
            [_Reporter(fail_report=True), trailing],
            error_policy=policy,
        )
        context = _ctx(global_rank=2)

        if policy == ReportingErrorPolicy.RAISE:
            expect_error = pytest.raises(RuntimeError, match="report failed")
            with expect_error:
                orchestrator(context, stage)
        elif policy == ReportingErrorPolicy.WARN:
            expect_warning = pytest.warns(UserWarning, match="failed during report")
            with expect_warning:
                orchestrator(context, stage)
        else:
            # Remaining policy: the failure must not surface as any warning.
            with warnings.catch_warnings(record=True) as recorded_warnings:
                warnings.simplefilter("always")
                orchestrator(context, stage)
            assert recorded_warnings == []

        assert len(trailing.calls) == expected_later_calls

        messages = orchestrator.state.messages
        assert len(messages) == 1
        failure = messages[0]
        assert failure.message.startswith("_Reporter failed during report")
        assert failure.stage == "AFTER_STEP"
        assert failure.step_count == 7
        assert failure.global_rank == 2
pytest.raises(RuntimeError, match="close failed"): + hook.close() + + assert events == ["close:second", "close:first"] + assert hook.state.messages[-1].message.startswith( + "_CloseOnlyReporter failed during close" + ) + + def test_cleanup_failure_warns_without_replacing_workflow_exception(self) -> None: + events: list[str] = [] + hook = ReportingOrchestrator( + [_ContextReporter("reporter", events, fail_exit=True)] + ) + + with pytest.warns(UserWarning, match="failed during close"): + with pytest.raises(ValueError, match="workflow failed"): + with hook: + raise ValueError("workflow failed") + + assert events == ["enter:reporter", "exit:reporter"] + + def test_failed_enter_reporter_is_disabled_under_non_raising_policy(self) -> None: + events: list[str] = [] + failed = _ContextReporter("failed", events, fail_enter=True) + active = _Reporter("active", events) + hook = ReportingOrchestrator( + [failed, active], + error_policy=ReportingErrorPolicy.WARN, + ) + + with pytest.warns(UserWarning, match="failed during enter"): + with hook: + hook(_ctx(), _ReportStage.AFTER_STEP) + + assert events == [ + "enter:failed", + "report:active:AFTER_STEP:1", + ] + + def test_rank_zero_only_orchestrator_skips_lifecycle_on_nonzero_rank( + self, + ) -> None: + events: list[str] = [] + hook = _RankedReportingOrchestrator( + [_ContextReporter("reporter", events)], + global_rank=1, + rank_zero_only=True, + ) + + with hook: + pass + hook.close() + + assert events == [] + + def test_rank_zero_only_orchestrator_enters_all_rank_reporters_on_nonzero_rank( + self, + ) -> None: + events: list[str] = [] + hook = _RankedReportingOrchestrator( + [_ContextReporter("reporter", events, requires_all_ranks=True)], + global_rank=1, + rank_zero_only=True, + ) + + with hook: + hook(_ctx(global_rank=1), _ReportStage.AFTER_STEP) + + assert events == ["enter:reporter", "report:reporter", "exit:reporter"] + + def test_rank_zero_only_reporter_skips_lifecycle_on_nonzero_rank( + self, + ) -> None: + events: 
class TestReportingState:
    """Direct checks of ``ReportingState`` bookkeeping."""

    def test_state_tracks_event_metadata_and_bounds_messages(self) -> None:
        """``mark_event`` records metadata and ``max_messages=2`` keeps only the newest two."""
        context = _ctx(global_rank=3, step_count=19)
        state = ReportingState(max_messages=2)

        state.mark_event(context, _ReportStage.AFTER_STEP)

        assert state.event_count == 1
        assert state.last_stage == "AFTER_STEP"
        assert state.last_step_count == 19
        assert state.last_global_rank == 3

        # Three messages go in; with the bound of two, "first" is expected to drop out.
        submissions = (
            ("info", "first", _ReportStage.AFTER_STEP),
            ("warning", "second", _ReportStage.BEFORE_STEP),
            ("error", "third", _ReportStage.EXACT),
        )
        for level, text, stage in submissions:
            state.add_message(level, text, ctx=context, stage=stage)

        retained = [message.message for message in state.messages]
        assert retained == ["second", "third"]
        assert state.messages[-1].stage == "EXACT"
        assert state.messages[-1].step_count == 19
        assert state.messages[-1].global_rank == 3
class _RecordingLive:
    """Stand-in for a live display that only remembers each ``refresh`` flag."""

    def __init__(self) -> None:
        # One entry per ``update`` call, in call order.
        self.refresh_values: list[bool] = []

    def update(self, renderable: object, *, refresh: bool = False) -> None:
        """Store ``refresh``; ``renderable`` is accepted but not used."""
        history = self.refresh_values
        history.append(refresh)
def test_rich_reporter_prints_live_dashboard() -> None:
    """One report writes the dashboard sections and scalar names to the console buffer."""
    sink = StringIO()
    ctx = _ctx(loss=torch.tensor(2.5))
    reporter = RichReporter(
        custom_scalars={"metric": lambda context, stage: 9.0},  # noqa: ARG005
        title="training",
        console=_console(sink),
    )

    reporter.report(ctx, _ReportStage.AFTER_OPTIMIZER_STEP, _state(ctx))

    rendered = sink.getvalue()
    # Checked in this order so the first missing fragment is the one reported.
    expected_fragments = (
        "training",
        "AFTER_OPTIMIZER_STEP",
        "step 17",
        "loss/total",
        "2.5",
        "metric",
        "9",
        "rank",
        "event",
        "Progress",
        "Messages",
        "Training Curves",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
    assert reporter.history["loss/total"] == ((17, 2.5),)
def test_rich_reporter_seed_history_supports_preview_data() -> None:
    """``seed_history`` fills the history and returns a snapshot of the latest values."""
    sink = StringIO()
    reporter = RichReporter(title="preview", console=_console(sink))
    series = {
        "loss/total": [1.0, 0.5, 0.25],
        "optimizer/lr": [1e-3, 5e-4, 1e-4],
    }

    snapshot = reporter.seed_history(
        series,
        steps=[10, 20, 30],
        epoch=2,
        batch_count=64,
    )
    # Render explicitly: no ``report`` call is made in this test.
    reporter.console.print(reporter.renderable())

    rendered = sink.getvalue()
    assert snapshot.scalars == {"loss/total": 0.25, "optimizer/lr": 1e-4}
    assert reporter.history["loss/total"] == ((10, 1.0), (20, 0.5), (30, 0.25))
    for fragment in ("preview", "loss/total", "optimizer/lr"):
        assert fragment in rendered
def test_rich_reporter_dynamics_layout_collects_default_metrics() -> None:
    """The dynamics layout renders its panels and records the default scalar history."""
    sink = StringIO()
    ctx = _dynamics_ctx()
    stage = _ReportStage.AFTER_STEP
    reporter = RichReporter(
        layout="dynamics",
        console=_console(sink),
        max_plots=0,
    )

    reporter.report(ctx, stage, _state(ctx, stage))

    rendered = sink.getvalue()
    expected_fragments = (
        "dynamics",
        "Observables",
        "Convergence",
        "Pipeline",
        "Messages",
        "Dynamics Traces",
        "energy",
        "fmax",
        "temperature",
        "converged_fraction",
        "active_fraction",
        "graduated",
        "converged",
        "status 0",
    )
    for fragment in expected_fragments:
        assert fragment in rendered

    # Each series holds a single point at step 23, the step count set by _dynamics_ctx.
    expected_latest = {
        "energy": -2.0,
        "fmax": 3.0,
        "temperature": 0.0,
        "converged_fraction": 0.5,
        "dynamics/converged_count": 1.0,
        "active_fraction": 0.5,
    }
    for key, value in expected_latest.items():
        assert reporter.history[key] == ((23, value),)
({"max_scalars": 0}, "max_scalars"), + ({"history_size": 0}, "history_size"), + ({"max_plots": -1}, "max_plots"), + ({"plot_height": 3}, "plot_height"), + ({"refresh_per_second": 0}, "refresh_per_second"), + ], +) +def test_rich_reporter_validates_formatting_options( + kwargs: dict[str, int], + message: str, +) -> None: + with pytest.raises(ValueError, match=message): + RichReporter(**kwargs) diff --git a/test/hooks/test_reporting_scalars.py b/test/hooks/test_reporting_scalars.py new file mode 100644 index 00000000..b9c9f782 --- /dev/null +++ b/test/hooks/test_reporting_scalars.py @@ -0,0 +1,420 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _install_fake_physicsnemo_manager(
    monkeypatch: pytest.MonkeyPatch | None = None,
    *,
    device: str | torch.device = "cpu",
    initialized: bool = True,
) -> None:
    """Register fake ``physicsnemo`` and ``physicsnemo.distributed`` modules.

    The fake ``DistributedManager`` reports ``initialized`` from
    ``is_initialized()`` and exposes ``torch.device(device)`` as ``device``.

    Parameters
    ----------
    monkeypatch
        When given, the modules are installed through ``monkeypatch.setitem``;
        when ``None``, ``sys.modules`` is assigned directly and the entries
        are left in place.
    device
        Device reported by instances of the fake manager.
    initialized
        Value returned by the fake ``is_initialized`` classmethod.
    """

    class FakeDistributedManager:
        def __init__(self) -> None:
            self.device = torch.device(device)

        @classmethod
        def is_initialized(cls) -> bool:
            return initialized

    package = ModuleType("physicsnemo")
    subpackage = ModuleType("physicsnemo.distributed")
    subpackage.DistributedManager = FakeDistributedManager
    package.distributed = subpackage

    # Parent package first, then the submodule, keyed by their module names.
    for module in (package, subpackage):
        if monkeypatch is None:
            sys.modules[module.__name__] = module
        else:
            monkeypatch.setitem(sys.modules, module.__name__, module)
def test_extract_scalars_flattens_nested_mapping() -> None:
    """Nested keys are joined with ``/`` under the prefix and values come back as floats."""
    nested_values = {
        "outer": {
            "inner": torch.tensor(2.0),
            "flag": True,
        },
        "plain": 3,
    }
    expected = {
        "custom/outer/inner": 2.0,
        "custom/outer/flag": 1.0,
        "custom/plain": 3.0,
    }

    flattened = extract_scalars(nested_values, prefix="custom")

    assert flattened == expected
assert snapshot.event_count == 1 + assert snapshot.step_count == 17 + assert snapshot.batch_count == 19 + assert snapshot.epoch_step_count == 3 + assert snapshot.epoch == 5 + assert snapshot.global_rank == 2 + assert snapshot.elapsed_s is not None + assert snapshot.scalars == pytest.approx( + { + "loss/total": 2.5, + "optimizer/lr": 0.125, + "metric": 4.5, + "nested/value": 6.0, + } + ) + + +def test_collect_scalars_extracts_scheduler_lrs() -> None: + parameter = torch.nn.Parameter(torch.tensor(1.0)) + optimizer = torch.optim.SGD([parameter], lr=0.125) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5) + ctx = _ctx(optimizers=[optimizer], lr_schedulers=[scheduler]) + + snapshot = collect_scalars(ctx, _ReportStage.AFTER_OPTIMIZER_STEP) + + assert snapshot.scalars == pytest.approx( + { + "optimizer/lr": 0.125, + "scheduler/lr": 0.125, + } + ) + + +def test_collect_scalars_preserves_scheduler_slot_indices() -> None: + first_parameter = torch.nn.Parameter(torch.tensor(1.0)) + second_parameter = torch.nn.Parameter(torch.tensor(2.0)) + first_optimizer = torch.optim.SGD([first_parameter], lr=0.125) + second_optimizer = torch.optim.SGD([second_parameter], lr=0.25) + scheduler = torch.optim.lr_scheduler.StepLR( + second_optimizer, + step_size=1, + gamma=0.5, + ) + ctx = _ctx( + optimizers=[first_optimizer, second_optimizer], + lr_schedulers=[None, scheduler], + ) + + snapshot = collect_scalars(ctx, _ReportStage.AFTER_OPTIMIZER_STEP) + + assert snapshot.scalars == pytest.approx( + { + "optimizer/0/lr": 0.125, + "optimizer/1/lr": 0.25, + "scheduler/1/lr": 0.25, + } + ) + + +@pytest.mark.skipif( + not dist.is_available() or not dist.is_gloo_available(), + reason="torch.distributed gloo backend is required", +) +def test_reduce_scalar_snapshot_uses_initialized_process_group(tmp_path) -> None: + output_dir = tmp_path / "distributed-results" + output_dir.mkdir() + + mp.spawn( + _distributed_reduce_worker, + args=(str(tmp_path / "distributed-init"), 
str(output_dir)), + nprocs=2, + join=True, + ) + + rank_results = [ + json.loads((output_dir / f"rank-{rank}.json").read_text(encoding="utf-8")) + for rank in range(2) + ] + expected = { + "mean": {"loss/total": 1.5, "metric": 15.0}, + "sum": {"loss/total": 3.0, "metric": 30.0}, + "min": {"loss/total": 1.0, "metric": 10.0}, + "max": {"loss/total": 2.0, "metric": 20.0}, + } + for results in rank_results: + for reduction, expected_scalars in expected.items(): + assert results[reduction] == pytest.approx(expected_scalars) + assert "same scalar keys" in results["mismatch"] + + +def test_reduce_scalar_snapshot_batches_scalar_collective(monkeypatch) -> None: + all_reduce_sizes: list[int] = [] + + def fake_all_gather_object( + gathered_keys: list[tuple[str, ...]], + keys: tuple[str, ...], + ) -> None: + gathered_keys[:] = [keys, keys] + + def fake_all_reduce(values: torch.Tensor, op: dist.ReduceOp) -> None: + all_reduce_sizes.append(values.numel()) + values.mul_(2.0) + + monkeypatch.setattr(reporting_distributed.dist, "is_available", lambda: True) + monkeypatch.setattr(reporting_distributed.dist, "is_initialized", lambda: True) + monkeypatch.setattr(reporting_distributed.dist, "get_world_size", lambda: 2) + monkeypatch.setattr( + reporting_distributed.dist, + "all_gather_object", + fake_all_gather_object, + ) + monkeypatch.setattr(reporting_distributed.dist, "all_reduce", fake_all_reduce) + monkeypatch.setattr( + reporting_distributed, + "_collective_device", + lambda: torch.device("cpu"), + ) + snapshot = ScalarSnapshot( + stage="AFTER_OPTIMIZER_STEP", + scalars={"a": 1.0, "b": 2.0, "c": 3.0}, + ) + + reduced = reduce_scalar_snapshot( + snapshot, + dist.ReduceOp.SUM, + reporter_name="TestReporter", + ) + + assert all_reduce_sizes == [3] + assert reduced.scalars == pytest.approx({"a": 2.0, "b": 4.0, "c": 6.0}) + + +def test_collective_device_uses_physicsnemo_distributed_manager(monkeypatch) -> None: + _install_fake_physicsnemo_manager(monkeypatch, device="cpu") + + assert 
reporting_distributed._collective_device() == torch.device("cpu") + + +def test_collective_device_requires_initialized_physicsnemo_manager( + monkeypatch, +) -> None: + _install_fake_physicsnemo_manager( + monkeypatch, + initialized=False, + ) + + with pytest.raises(RuntimeError, match="DistributedManager to be initialized"): + reporting_distributed._collective_device() + + +def test_collect_scalars_can_include_training_progress() -> None: + ctx = _ctx(loss=torch.tensor(2.5)) + ctx.workflow = SimpleNamespace(num_steps=20, num_epochs=10) + state = ReportingState(started_at_s=time.monotonic() - 10.0) + state.mark_event(ctx, _ReportStage.AFTER_OPTIMIZER_STEP) + + snapshot = collect_scalars( + ctx, + _ReportStage.AFTER_OPTIMIZER_STEP, + state, + include_progress=True, + ) + + assert snapshot.scalars["training/progress_fraction"] == pytest.approx(17 / 20) + assert snapshot.scalars["training/remaining_steps"] == pytest.approx(3.0) + assert snapshot.scalars["training/target_epochs"] == pytest.approx(10.0) + assert snapshot.scalars["training/steps_per_s"] > 0 + assert snapshot.scalars["training/eta_s"] > 0 diff --git a/test/hooks/test_reporting_tensorboard.py b/test/hooks/test_reporting_tensorboard.py new file mode 100644 index 00000000..a8056899 --- /dev/null +++ b/test/hooks/test_reporting_tensorboard.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for TensorBoard reporting.""" + +from __future__ import annotations + +from enum import Enum, auto + +import pytest +import torch + +from nvalchemi._optional import OptionalDependency, OptionalDependencyError +from nvalchemi.hooks import TrainContext +from nvalchemi.hooks.reporting import ( + ReportingState, + TensorBoardReporter, +) + + +class _ReportStage(Enum): + AFTER_OPTIMIZER_STEP = auto() + + +class _RecordingWriter: + def __init__(self) -> None: + self.scalars: list[tuple[str, float, int | None]] = [] + self.flushed = 0 + self.closed = 0 + + def add_scalar( + self, + tag: str, + scalar_value: float, + global_step: int | None = None, + ) -> None: + self.scalars.append((tag, scalar_value, global_step)) + + def flush(self) -> None: + self.flushed += 1 + + def close(self) -> None: + self.closed += 1 + + +def _ctx(*, global_rank: int = 0, loss: torch.Tensor | None = None) -> TrainContext: + return TrainContext( + batch=object(), + global_rank=global_rank, + step_count=17, + batch_count=19, + epoch_step_count=3, + epoch=5, + loss=loss, + ) + + +def _state(ctx: TrainContext) -> ReportingState: + state = ReportingState() + state.mark_event(ctx, _ReportStage.AFTER_OPTIMIZER_STEP) + return state + + +@pytest.fixture(autouse=True) +def _tensorboard_available(monkeypatch) -> None: + dep = OptionalDependency.TENSORBOARD + monkeypatch.setattr(dep, "_available", True) + monkeypatch.setattr(dep, "_import_error", None) + + +def test_tensorboard_reporter_writes_scalar_tags_with_step(tmp_path) -> None: + writer = _RecordingWriter() + ctx = _ctx(loss=torch.tensor(2.5)) + reporter = TensorBoardReporter( + tmp_path / "runs", + custom_scalars={"metric": lambda context, stage: 9.0}, # noqa: ARG005 + tag_prefix="train", + writer=writer, + ) + + reporter.report(ctx, _ReportStage.AFTER_OPTIMIZER_STEP, _state(ctx)) + + assert writer.scalars == [ + ("train/loss/total", 
2.5, 17), + ("train/metric", 9.0, 17), + ] + assert writer.flushed == 1 + + +def test_tensorboard_reporter_defaults_to_rank_zero_only(tmp_path) -> None: + writer = _RecordingWriter() + ctx = _ctx(global_rank=1, loss=torch.tensor(2.5)) + reporter = TensorBoardReporter(tmp_path / "runs", writer=writer) + + reporter.report(ctx, _ReportStage.AFTER_OPTIMIZER_STEP, _state(ctx)) + + assert reporter.rank_zero_only is True + assert writer.scalars == [] + + +def test_tensorboard_reporter_requires_rank_token_for_all_rank_writes( + tmp_path, +) -> None: + with pytest.raises(ValueError, match="must contain '\\{rank\\}'"): + TensorBoardReporter(tmp_path / "runs", rank_zero_only=False) + + +def test_tensorboard_reporter_all_rank_write_accepts_rank_safe_log_dir( + tmp_path, +) -> None: + writer = _RecordingWriter() + ctx = _ctx(global_rank=3, loss=torch.tensor(2.5)) + reporter = TensorBoardReporter( + tmp_path / "runs-rank-{rank}", + rank_zero_only=False, + writer=writer, + ) + + reporter.report(ctx, _ReportStage.AFTER_OPTIMIZER_STEP, _state(ctx)) + + assert writer.scalars == [("loss/total", 2.5, 17)] + + +def test_tensorboard_reduction_uses_all_rank_dispatch_and_rank_zero_write( + tmp_path, +) -> None: + writer = _RecordingWriter() + ctx = _ctx(global_rank=0, loss=torch.tensor(2.5)) + reporter = TensorBoardReporter( + tmp_path / "runs", + rank_reduction="mean", + writer=writer, + ) + + reporter.report(ctx, _ReportStage.AFTER_OPTIMIZER_STEP, _state(ctx)) + + assert reporter.rank_zero_only is False + assert writer.scalars == [("loss/total", 2.5, 17)] + + +def test_tensorboard_close_closes_writer(tmp_path) -> None: + writer = _RecordingWriter() + reporter = TensorBoardReporter(tmp_path / "runs", writer=writer) + + reporter.close() + + assert writer.closed == 1 + + +def test_tensorboard_missing_extra_uses_optional_dependency_error( + tmp_path, + monkeypatch, +) -> None: + dep = OptionalDependency.TENSORBOARD + monkeypatch.setattr(dep, "_available", False) + monkeypatch.setattr(dep, 
"_import_error", ImportError("missing tensorboard")) + + with pytest.raises( + OptionalDependencyError, match="nvalchemi-toolkit\\[tensorboard\\]" + ): + TensorBoardReporter(tmp_path / "runs") diff --git a/test/hooks/test_shared_hooks.py b/test/hooks/test_shared_hooks.py index ee46101b..ce421a1e 100644 --- a/test/hooks/test_shared_hooks.py +++ b/test/hooks/test_shared_hooks.py @@ -21,7 +21,6 @@ from nvalchemi.data import AtomicData, Batch from nvalchemi.dynamics.base import DynamicsStage from nvalchemi.dynamics.hooks.logging import LoggingHook -from nvalchemi.dynamics.hooks.profiling import ProfilerHook from nvalchemi.dynamics.hooks.safety import MaxForceClampHook, NaNDetectorHook from nvalchemi.dynamics.hooks.snapshot import SnapshotHook from nvalchemi.dynamics.sinks import HostMemory @@ -29,6 +28,7 @@ BiasedPotentialHook, DynamicsContext, NeighborListHook, + StageTimingHook, WrapPeriodicHook, ) from nvalchemi.models.base import NeighborConfig @@ -203,6 +203,27 @@ def test_dynamics_stage_wraps(self) -> None: assert batch.positions[1, 0].item() > 0.0 +# =========================================================================== +# StageTimingHook +# =========================================================================== + + +class TestStageTimingHook: + """StageTimingHook records transitions under DynamicsStage.""" + + def test_dynamics_stage_records_transition(self) -> None: + """A shared StageTimingHook can time dynamics stages.""" + hook = StageTimingHook( + {DynamicsStage.BEFORE_STEP, DynamicsStage.AFTER_STEP}, + enable_nvtx=False, + ) + batch = _make_batch() + ctx = _make_ctx(batch) + hook(ctx, DynamicsStage.BEFORE_STEP) + hook(ctx, DynamicsStage.AFTER_STEP) + assert hook.summary()["BEFORE_STEP->AFTER_STEP"]["n_samples"] == 1 + + # =========================================================================== # NeighborListHook # =========================================================================== @@ -230,23 +251,3 @@ def 
test_dynamics_stage_builds_neighbors(self) -> None: ctx = _make_ctx(batch) hook(ctx, DynamicsStage.BEFORE_COMPUTE) assert batch.neighbor_list is not None - - -# =========================================================================== -# ProfilerHook -# =========================================================================== - - -class TestProfilerHook: - """ProfilerHook records timing under DynamicsStage.""" - - def test_dynamics_stage_records(self) -> None: - """Timing is recorded under DynamicsStage.""" - profiler = ProfilerHook({DynamicsStage.BEFORE_STEP, DynamicsStage.AFTER_STEP}) - batch = _make_batch() - ctx = _make_ctx(batch, step_count=0) - profiler(ctx, DynamicsStage.BEFORE_STEP) - profiler(ctx, DynamicsStage.AFTER_STEP) - summary = profiler.summary() - assert "BEFORE_STEP->AFTER_STEP" in summary - assert summary["BEFORE_STEP->AFTER_STEP"]["n_samples"] == 1 diff --git a/test/hooks/test_stage_timing_hook.py b/test/hooks/test_stage_timing_hook.py new file mode 100644 index 00000000..a35ea9cb --- /dev/null +++ b/test/hooks/test_stage_timing_hook.py @@ -0,0 +1,358 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unit tests for StageTimingHook.""" + +from __future__ import annotations + +import csv +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from nvalchemi.data import AtomicData, Batch +from nvalchemi.dynamics.base import DynamicsStage +from nvalchemi.dynamics.demo import DemoDynamics +from nvalchemi.hooks import StageTimingHook, TrainContext +from nvalchemi.models.demo import DemoModel, DemoModelWrapper +from nvalchemi.training import TrainingStage + + +def _make_batch( + n_graphs: int = 2, atoms_per_graph: int = 3, device: str = "cpu" +) -> Batch: + data_list = [ + AtomicData( + atomic_numbers=torch.tensor([6] * atoms_per_graph, dtype=torch.long), + positions=torch.randn(atoms_per_graph, 3), + ) + for _ in range(n_graphs) + ] + batch = Batch.from_data_list(data_list).to(device) + batch.__dict__["forces"] = torch.randn(batch.num_nodes, 3, device=device) + batch.__dict__["energy"] = torch.randn(batch.num_graphs, 1, device=device) + batch.__dict__["velocities"] = torch.randn(batch.num_nodes, 3, device=device) * 0.01 + batch.__dict__["atomic_masses"] = torch.full( + (batch.num_nodes,), 12.0, device=device + ) + return batch + + +def _make_dynamics(hooks=None, n_steps: int = 5, device: str = "cpu") -> DemoDynamics: + model = DemoModelWrapper(DemoModel()) + if device != "cpu": + model = model.to(device) + return DemoDynamics( + model=model, n_steps=n_steps, dt=1.0, hooks=hooks, device_type=device + ) + + +# ------------------------------------------------------------------ +# Construction / presets +# ------------------------------------------------------------------ + + +class TestConstruction: + def test_step_preset(self) -> None: + profiler = StageTimingHook("step") + assert set(profiler._profiled_stages) == { + DynamicsStage.BEFORE_STEP, + DynamicsStage.AFTER_STEP, + } + + def test_detailed_preset(self) -> None: + profiler = StageTimingHook("detailed") + expected = { + DynamicsStage.BEFORE_STEP, + DynamicsStage.BEFORE_PRE_UPDATE, + 
DynamicsStage.AFTER_PRE_UPDATE, + DynamicsStage.BEFORE_COMPUTE, + DynamicsStage.AFTER_COMPUTE, + DynamicsStage.BEFORE_POST_UPDATE, + DynamicsStage.AFTER_POST_UPDATE, + DynamicsStage.AFTER_STEP, + } + assert set(profiler._profiled_stages) == expected + + def test_all_preset(self) -> None: + profiler = StageTimingHook("all") + assert DynamicsStage.ON_CONVERGE not in profiler._profiled_stages + assert len(profiler._profiled_stages) == len(DynamicsStage) - 1 + + def test_custom_stages(self) -> None: + S = DynamicsStage + custom = {S.BEFORE_COMPUTE, S.AFTER_COMPUTE} + profiler = StageTimingHook(custom) + assert set(profiler._profiled_stages) == custom + + def test_unknown_preset_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown stages preset"): + StageTimingHook("bogus") # type: ignore[arg-type] + + def test_single_stage_raises(self) -> None: + with pytest.raises(ValueError, match="At least two stages"): + StageTimingHook({DynamicsStage.BEFORE_STEP}) + + def test_stages_sorted_by_execution_order(self) -> None: + profiler = StageTimingHook("detailed") + values = [s.value for s in profiler._profiled_stages] + assert values == sorted(values) + + def test_training_stages_are_supported_explicitly(self) -> None: + """Explicit training stages can be timed through TrainContext.""" + hook = StageTimingHook( + {TrainingStage.BEFORE_BATCH, TrainingStage.AFTER_BATCH}, + enable_nvtx=False, + ) + batch = _make_batch() + ctx = TrainContext(batch=batch, step_count=7) + hook(ctx, TrainingStage.BEFORE_BATCH) + hook(ctx, TrainingStage.AFTER_BATCH) + assert hook.summary()["BEFORE_BATCH->AFTER_BATCH"]["n_samples"] == 1 + + +# ------------------------------------------------------------------ +# Registration +# ------------------------------------------------------------------ + + +class TestRegistration: + def test_registers_at_all_stages(self, device: str) -> None: + profiler = StageTimingHook("step") + dynamics = _make_dynamics(hooks=[profiler], n_steps=1, 
device=device) + assert profiler in dynamics.hooks + # Verify _runs_on_stage covers the expected stages + assert profiler._runs_on_stage(DynamicsStage.BEFORE_STEP) + assert profiler._runs_on_stage(DynamicsStage.AFTER_STEP) + + def test_does_not_register_at_other_stages(self, device: str) -> None: + profiler = StageTimingHook("step") + _make_dynamics(hooks=[profiler], n_steps=1, device=device) + assert not profiler._runs_on_stage(DynamicsStage.BEFORE_COMPUTE) + + def test_composable_with_other_hooks(self, device: str) -> None: + from nvalchemi.dynamics.hooks.safety import NaNDetectorHook + + profiler = StageTimingHook("step") + nan_hook = NaNDetectorHook() + dynamics = _make_dynamics(hooks=[profiler, nan_hook], n_steps=3, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + assert len(profiler.summary()) > 0 + + +# ------------------------------------------------------------------ +# CPU timing +# ------------------------------------------------------------------ + + +class TestTiming: + def test_records_values(self, device: str) -> None: + profiler = StageTimingHook("step") + dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + summary = profiler.summary() + key = "BEFORE_STEP->AFTER_STEP" + assert key in summary + assert summary[key]["n_samples"] == 5 + + def test_summary_keys(self, device: str) -> None: + profiler = StageTimingHook("step") + dynamics = _make_dynamics(hooks=[profiler], n_steps=3, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + summary = profiler.summary() + key = next(iter(summary)) + expected_keys = {"mean_s", "std_s", "min_s", "max_s", "total_s", "n_samples"} + assert set(summary[key].keys()) == expected_keys + + def test_positive_deltas(self, device: str) -> None: + profiler = StageTimingHook("step") + dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) + batch = _make_batch(device=device) + 
dynamics.run(batch) + for stats in profiler.summary().values(): + assert stats["mean_s"] >= 0 + assert stats["min_s"] >= 0 + + def test_frequency_gating(self, device: str) -> None: + profiler = StageTimingHook("step", frequency=3) + dynamics = _make_dynamics(hooks=[profiler], n_steps=9, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + summary = profiler.summary() + assert summary["BEFORE_STEP->AFTER_STEP"]["n_samples"] == 3 + + def test_detailed_timing(self, device: str) -> None: + profiler = StageTimingHook("detailed") + dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + summary = profiler.summary() + # 8 stages -> 7 transitions. + assert len(summary) == 7 + for stats in summary.values(): + assert stats["n_samples"] == 5 + + def test_reset(self, device: str) -> None: + profiler = StageTimingHook("step") + dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + assert len(profiler.summary()) > 0 + profiler.reset() + assert profiler.summary() == {} + + +# ------------------------------------------------------------------ +# Auto backend +# ------------------------------------------------------------------ + + +class TestAutoBackend: + def test_auto_selects_perf_counter_on_cpu(self) -> None: + profiler = StageTimingHook("step", timer_backend="auto") + dynamics = _make_dynamics(hooks=[profiler], n_steps=1, device="cpu") + batch = _make_batch(device="cpu") + dynamics.run(batch) + assert profiler._backend_resolved == "perf_counter" + + def test_auto_selects_cuda_event_on_gpu(self, gpu_device: str) -> None: + profiler = StageTimingHook("step", timer_backend="auto") + dynamics = _make_dynamics(hooks=[profiler], n_steps=1, device=gpu_device) + batch = _make_batch(device=gpu_device) + dynamics.run(batch) + assert profiler._backend_resolved == "cuda_event" + + +# 
------------------------------------------------------------------ +# NVTX +# ------------------------------------------------------------------ + + +class TestNVTX: + def test_nvtx_push_pop_called(self) -> None: + try: + import nvtx # noqa: F401 + except ImportError: + pytest.skip("nvtx not available") + + with patch("nvalchemi.hooks.stage_timing.nvtx") as mock_nvtx: + mock_nvtx.push_range = MagicMock() + mock_nvtx.pop_range = MagicMock() + + profiler = StageTimingHook("step", enable_nvtx=True) + dynamics = _make_dynamics(hooks=[profiler], n_steps=1) + batch = _make_batch() + dynamics.run(batch) + + assert mock_nvtx.push_range.call_count >= 1 + assert mock_nvtx.pop_range.call_count >= 1 + + def test_nvtx_disabled(self) -> None: + with patch("nvalchemi.hooks.stage_timing.nvtx") as mock_nvtx: + mock_nvtx.push_range = MagicMock() + mock_nvtx.pop_range = MagicMock() + + profiler = StageTimingHook("step", enable_nvtx=False) + dynamics = _make_dynamics(hooks=[profiler], n_steps=1) + batch = _make_batch() + dynamics.run(batch) + + mock_nvtx.push_range.assert_not_called() + mock_nvtx.pop_range.assert_not_called() + + +# ------------------------------------------------------------------ +# CSV logging +# ------------------------------------------------------------------ + + +class TestCSVLogging: + def test_writes_csv(self, tmp_path, device: str) -> None: + log_file = tmp_path / "profiler.csv" + profiler = StageTimingHook("step", log_path=log_file) + dynamics = _make_dynamics(hooks=[profiler], n_steps=3, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + profiler.close() + + with open(log_file) as f: + rows = list(csv.DictReader(f)) + assert len(rows) == 3 + assert "rank" in rows[0] + assert "step" in rows[0] + assert "stage" in rows[0] + assert "t_since_init_s" in rows[0] + assert "delta_s" in rows[0] + + def test_detailed_csv_rows(self, tmp_path, device: str) -> None: + log_file = tmp_path / "detailed.csv" + profiler = StageTimingHook("detailed", 
log_path=log_file) + dynamics = _make_dynamics(hooks=[profiler], n_steps=2, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + profiler.close() + + with open(log_file) as f: + rows = list(csv.DictReader(f)) + # 8 stages -> 7 transitions per step, 2 steps -> 14 rows. + assert len(rows) == 14 + + +# ------------------------------------------------------------------ +# Console output +# ------------------------------------------------------------------ + + +class TestConsoleOutput: + def test_show_console(self, device: str) -> None: + with patch("nvalchemi.hooks.stage_timing.logger") as mock_logger: + profiler = StageTimingHook("step", show_console=True) + dynamics = _make_dynamics(hooks=[profiler], n_steps=2, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + assert mock_logger.info.call_count == 2 + + def test_console_frequency(self, device: str) -> None: + with patch("nvalchemi.hooks.stage_timing.logger") as mock_logger: + profiler = StageTimingHook( + "step", + show_console=True, + console_frequency=3, + ) + dynamics = _make_dynamics(hooks=[profiler], n_steps=9, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + assert mock_logger.info.call_count == 3 + + +# ------------------------------------------------------------------ +# Integration +# ------------------------------------------------------------------ + + +class TestIntegration: + def test_full_loop(self, device: str) -> None: + profiler = StageTimingHook("step") + dynamics = _make_dynamics(hooks=[profiler], n_steps=5, device=device) + batch = _make_batch(device=device) + dynamics.run(batch) + summary = profiler.summary() + assert len(summary) > 0 + for stats in summary.values(): + assert "mean_s" in stats + assert stats["n_samples"] == 5 diff --git a/test/models/test_mace.py b/test/models/test_mace.py index df11a722..733ee4fc 100644 --- a/test/models/test_mace.py +++ b/test/models/test_mace.py @@ -22,6 +22,8 @@ from __future__ import 
annotations +from types import SimpleNamespace + import pytest import torch @@ -31,6 +33,14 @@ from nvalchemi.data import AtomicData, Batch # noqa: E402 from nvalchemi.models.base import NeighborListFormat # noqa: E402 from nvalchemi.models.mace import MACEWrapper # noqa: E402 +from nvalchemi.training import EnergyMSELoss, ValidationConfig # noqa: E402 +from nvalchemi.training._stages import TrainingStage # noqa: E402 +from nvalchemi.training.hooks import EMAHook # noqa: E402 +from nvalchemi.training.optimizers import OptimizerConfig # noqa: E402 +from nvalchemi.training.strategy import ( # noqa: E402 + TrainingStrategy, + default_training_fn, +) # --------------------------------------------------------------------------- # Shared constants @@ -171,6 +181,20 @@ def _make_single_atom(device: str = "cpu") -> AtomicData: ) +def _make_ema_ctx( + model: torch.nn.Module, + *, + step_count: int, +) -> SimpleNamespace: + """Build the minimal TrainContext surface EMAHook reads.""" + return SimpleNamespace( + models={"main": model}, + step_count=step_count, + loss=None, + workflow=object(), + ) + + def _make_pbc_water(device: str = "cpu") -> AtomicData: """H2O in a periodic cubic box with integer neighbor_list_shifts on edges.""" positions = torch.tensor( @@ -635,6 +659,79 @@ def test_exported_model_matches_wrapper(self, wrapper, tmp_path): ) +# --------------------------------------------------------------------------- +# EMA checkpointing +# --------------------------------------------------------------------------- + + +class TestEMAIntegration: + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_strategy_checkpoint_round_trip_restores_ema_cuda_wrapper( + self, + tmp_path, + ) -> None: + """Strategy checkpoints restore EMA state into the runtime hook. + + User checkpoint restarts go through ``TrainingStrategy`` and + ``load_checkpoint`` rather than saving an EMA hook directly. 
This test + saves a strategy checkpoint with pending EMA weights loaded on CPU, then + restores them into a caller-supplied runtime hook and materializes the + averaged ``MACEWrapper`` against the restored CUDA model. + """ + device = torch.device("cuda", torch.cuda.current_device()) + source = MACEWrapper(MockMACEModel().to(device)) + batch = Batch.from_data_list([_make_water(device="cuda")]) + + ema = EMAHook(model_key="main", decay=0.0) + strategy = TrainingStrategy( + models=source, + optimizer_configs=OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + ), + loss_fn=EnergyMSELoss(), + num_steps=1, + devices=[device], + training_fn=default_training_fn, + hooks=[ema], + ) + + # Seed EMA before saving so the checkpoint contains hook-owned tensor state. + ema( + _make_ema_ctx(source, step_count=0), + TrainingStage.AFTER_OPTIMIZER_STEP, + ) + strategy.save_checkpoint(tmp_path) + + # Users restore strategy checkpoints through the strategy convenience API; + # hooks are runtime objects supplied fresh and hydrated by the loader. + restored_ema = EMAHook(model_key="main", decay=0.0) + restored_strategy = TrainingStrategy.load_checkpoint( + tmp_path, + map_location=device, + hooks=[restored_ema], + training_fn=default_training_fn, + ) + assert restored_ema._averaged_model is None + assert restored_ema._pending_averaged_state is not None + + # The pending EMA state is materialized lazily once the restored model exists. 
+ restored_ema( + _make_ema_ctx(restored_strategy.models["main"], step_count=1), + TrainingStage.AFTER_OPTIMIZER_STEP, + ) + + averaged = restored_ema.get_averaged_model().module + assert isinstance(averaged, MACEWrapper) + assert averaged._node_emb.device == device + assert next(averaged.parameters()).device == device + + expected = source.forward(batch) + actual = averaged.forward(batch) + torch.testing.assert_close(actual["energy"], expected["energy"]) + torch.testing.assert_close(actual["forces"], expected["forces"]) + + # --------------------------------------------------------------------------- # from_checkpoint error path (no network required) # --------------------------------------------------------------------------- @@ -696,6 +793,15 @@ def _water_batch(dtype: torch.dtype = torch.float64, device: str = "cpu") -> Bat return Batch.from_data_list([data]) +def _water_batch_with_energy( + dtype: torch.dtype = torch.float32, device: str = "cpu" +) -> Batch: + """Return a water batch with a supervised energy target for training tests.""" + batch = _water_batch(dtype=dtype, device=device) + batch.energy = torch.zeros(1, 1, dtype=dtype, device=device) + return batch + + @pytest.fixture(scope="session") def real_wrapper_cpu(): """Load the MACE-MP small checkpoint once per session (requires network). @@ -848,6 +954,155 @@ def test_cueq_conversion(self): assert out["energy"].shape == (1, 1) assert out["forces"].shape == (3, 3) + def test_cueq_strategy_ema_checkpoint_round_trip(self, tmp_path): + """Strategy checkpoints restore MACE + cuEq models and EMA hook state. + + This follows the documented user restart path: a strategy owns a + MACEWrapper loaded from an existing checkpoint with cuEquivariance + enabled, saves a restartable checkpoint, and is reconstructed through + ``TrainingStrategy.load_checkpoint`` with a fresh runtime EMA hook. 
+ """ + pytest.importorskip( + "cuequivariance", reason="cuequivariance not installed; skipping cuEq test" + ) + if not torch.cuda.is_available(): + pytest.skip("CUDA required for cuEquivariance EMA checkpoint test") + device = torch.device("cuda", torch.cuda.current_device()) + try: + source = MACEWrapper.from_checkpoint( + "small-0b", + device=device, + dtype=torch.float32, + enable_cueq=True, + ) + except Exception as e: + pytest.skip(f"Checkpoint unavailable or cuEq failed: {e}") + + ema = EMAHook(model_key="main", decay=0.0) + strategy = TrainingStrategy( + models=source, + optimizer_configs=OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + ), + loss_fn=EnergyMSELoss(), + num_steps=1, + devices=[device], + training_fn=default_training_fn, + hooks=[ema], + ) + ema( + _make_ema_ctx(source, step_count=0), + TrainingStage.AFTER_OPTIMIZER_STEP, + ) + strategy.save_checkpoint(tmp_path) + + restored_ema = EMAHook(model_key="main", decay=0.0) + restored = TrainingStrategy.load_checkpoint( + tmp_path, + map_location=device, + hooks=[restored_ema], + training_fn=default_training_fn, + ) + assert restored_ema._averaged_model is None + assert restored_ema._pending_averaged_state is not None + + restored_model = restored.models["main"] + restored_ema( + _make_ema_ctx(restored_model, step_count=restored.step_count + 1), + TrainingStage.AFTER_OPTIMIZER_STEP, + ) + + averaged = restored_ema.get_averaged_model().module + cpu_buffers = [name for name, buf in averaged.named_buffers() if buf.is_cpu] + assert cpu_buffers == [] + + batch = _water_batch(dtype=torch.float32, device="cuda") + expected = source.forward(batch) + actual = averaged.forward(batch) + torch.testing.assert_close(actual["energy"], expected["energy"]) + torch.testing.assert_close(actual["forces"], expected["forces"]) + + def test_cueq_strategy_ema_checkpoint_round_trip_after_optimizer_step( + self, tmp_path + ): + """Reloaded MACE + cuEq checkpoints validate through post-step 
EMA weights. + + The reported failure happens after a real training update, when + validation switches from the live cuEq model to the EMA-published + cuEq model. This test exercises that full lifecycle instead of + manually seeding the EMA hook state before checkpointing. + """ + pytest.importorskip( + "cuequivariance", reason="cuequivariance not installed; skipping cuEq test" + ) + if not torch.cuda.is_available(): + pytest.skip("CUDA required for cuEquivariance EMA checkpoint test") + device = torch.device("cuda", torch.cuda.current_device()) + try: + source = MACEWrapper.from_checkpoint( + "small-0b", + device=device, + dtype=torch.float32, + enable_cueq=True, + ) + except Exception as e: + pytest.skip(f"Checkpoint unavailable or cuEq failed: {e}") + + train_batch = _water_batch_with_energy(dtype=torch.float32, device="cuda") + val_batch = _water_batch_with_energy(dtype=torch.float32, device="cuda") + loss = EnergyMSELoss() + ema = EMAHook(model_key="main", decay=0.0) + strategy = TrainingStrategy( + models=source, + optimizer_configs=OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + ), + loss_fn=loss, + validation_config=ValidationConfig( + validation_data=[val_batch], + loss_fn=loss, + use_ema="always", + grad_mode="enabled", + ), + num_steps=1, + devices=[device], + training_fn=default_training_fn, + hooks=[ema], + ) + + # Run the real training lifecycle so optimizer state, updated model + # weights, EMA publication, and validation all happen in strategy order. 
+ strategy.run([train_batch]) + assert strategy.inference_model is not None + assert strategy.last_validation is not None + assert strategy.last_validation["model_source"] == "ema" + strategy.save_checkpoint(tmp_path) + + restored_ema = EMAHook(model_key="main", decay=0.0) + restored = TrainingStrategy.load_checkpoint( + tmp_path, + map_location=device, + hooks=[restored_ema], + training_fn=default_training_fn, + ) + restored.validation_config = ValidationConfig( + validation_data=[val_batch], + loss_fn=loss, + use_ema="always", + grad_mode="enabled", + ) + + # The restored strategy has already reached num_steps=1, so run() + # executes SETUP hooks and returns before another optimizer step. That + # setup pass should materialize/publish restored EMA state for validation. + restored.run([train_batch]) + + summary = restored.validate() + assert summary is not None + assert summary["model_source"] == "ema" + def test_energy_and_forces_match_ase_calculator(self, real_wrapper_cpu, tmp_path): """MACEWrapper E+F must agree with the MACE ASE MACECalculator. 
diff --git a/test/models/test_pipeline.py b/test/models/test_pipeline.py index da7bb157..83803975 100644 --- a/test/models/test_pipeline.py +++ b/test/models/test_pipeline.py @@ -25,6 +25,7 @@ from __future__ import annotations +import copy from collections import OrderedDict import pytest @@ -572,6 +573,29 @@ def test_autograd_does_not_mutate_sub_model_config(self, simple_batch): assert a.model_config.active_outputs == {"energy", "forces"} assert b.model_config.active_outputs == {"energy", "forces"} + def test_deepcopy_refreshes_step_caches_on_forward(self, simple_batch): + """Deep-copied pipelines rebuild id-keyed step caches on first forward.""" + model = MockAutogradEnergyModel() + model.model_config.active_outputs = {"energy", "forces"} + pipe = PipelineModelWrapper( + groups=[PipelineGroup(steps=[model], use_autograd=True)] + ) + original_step_id = id(pipe.groups[0].steps[0]) + assert original_step_id in pipe._step_needs_neighbor_adapt + assert "forces" not in pipe._step_active_overrides[original_step_id] + + ema_pipe = copy.deepcopy(pipe) + copied_step_id = id(ema_pipe.groups[0].steps[0]) + assert copied_step_id != original_step_id + assert copied_step_id not in ema_pipe._step_needs_neighbor_adapt + + ema_pipe.model_config.active_outputs = {"energy", "forces"} + ema_pipe(simple_batch) + + assert copied_step_id in ema_pipe._step_needs_neighbor_adapt + assert "forces" not in ema_pipe._step_active_overrides[copied_step_id] + assert model.model_config.active_outputs == {"energy", "forces"} + class TestPipelineDependentAutograd: """Case 2b: A predicts charges+energy, B uses charges for energy. 
@@ -988,6 +1012,24 @@ def test_tighter_cutoff_filters_matrix(self): f"atom {atom_idx} should not see atom 3 at cutoff 4" ) + def test_deepcopy_preserves_neighbor_adaptation(self): + """Deep-copied pipelines still filter neighbors per sub-model cutoff.""" + wide = _MatrixModel10() + tight = _MatrixModel4() + pipe = PipelineModelWrapper(groups=[PipelineGroup(steps=[wide, tight])]) + tight_step_id = id(pipe.groups[0].steps[1]) + assert pipe._step_needs_neighbor_adapt[tight_step_id] is True + + copied = copy.deepcopy(pipe) + tight_copy = copied.groups[0].steps[1].model + assert id(copied.groups[0].steps[1]) not in copied._step_needs_neighbor_adapt + + batch = _make_neighbor_batch() + copied(batch) + + expected_nn = torch.tensor([2, 2, 2, 0], dtype=torch.int32) + torch.testing.assert_close(tight_copy.captured_num_neighbors, expected_nn) + def test_matrix_to_coo_conversion(self): """COO model in a MATRIX pipeline receives converted neighbor list.""" matrix_model = _MatrixModel10() diff --git a/test/training/__init__.py b/test/training/__init__.py new file mode 100644 index 00000000..6b377ba5 --- /dev/null +++ b/test/training/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations diff --git a/test/training/conftest.py b/test/training/conftest.py new file mode 100644 index 00000000..499bd9b4 --- /dev/null +++ b/test/training/conftest.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared fixtures and builders for ``test/training/``. + +Fixtures are pure-value — they return built objects, not callables. +Tests that need non-default variants either import the ``_build_*`` +helpers directly or construct their objects inline. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest +import torch + +from nvalchemi.data import AtomicData, Batch +from nvalchemi.training import EnergyMSELoss, ForceMSELoss +from nvalchemi.training.optimizers import OptimizerConfig +from nvalchemi.training.strategy import TrainingStrategy + +if TYPE_CHECKING: + from nvalchemi.models.base import BaseModelMixin + + +@pytest.fixture(autouse=True) +def _seed_torch() -> None: + """Seed ``torch`` (and CUDA, when visible) to ``0`` before every test.""" + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + +def _build_atomic_data(n_atoms: int = 3, seed: int = 0) -> AtomicData: + g = torch.Generator().manual_seed(seed) + positions = torch.randn(n_atoms, 3, generator=g) + atomic_numbers = torch.randint(1, 10, (n_atoms,), dtype=torch.long, generator=g) + energy = torch.randn(1, 1, generator=g) + forces = torch.randn(n_atoms, 3, generator=g) + return AtomicData( + positions=positions, + atomic_numbers=atomic_numbers, + atomic_masses=torch.ones(n_atoms), + energy=energy, + forces=forces, + ) + + +def _build_batch(n_systems: int = 2, n_atoms_each: int = 3, seed: int = 0) -> Batch: + data_list = [ + _build_atomic_data(n_atoms_each, seed=seed + i) for i in range(n_systems) + ] + return Batch.from_data_list(data_list) + + +def _build_dataset( + n_batches: int = 3, + n_systems: int = 2, + n_atoms_each: int = 3, + base_seed: int = 100, +) -> list[Batch]: + return [ + _build_batch( + n_systems=n_systems, + n_atoms_each=n_atoms_each, + seed=base_seed + i * 10, + ) + for i in range(n_batches) + ] + + +def _build_demo_model() -> Any: + from nvalchemi.models.demo import DemoModel, DemoModelWrapper + + torch.manual_seed(0) + return DemoModelWrapper(DemoModel(num_atom_types=20, hidden_dim=8)) + + +def _build_adam_optimizer_configs( + lr: float = 1e-3, +) -> dict[str, list[OptimizerConfig]]: + return { + "main": [ + OptimizerConfig( + 
optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": lr}, + ) + ] + } + + +def _build_baseline_strategy_kwargs( + models: BaseModelMixin | dict[str, BaseModelMixin] | None = None, +) -> dict[str, Any]: + # Import locally so identity is preserved for spec round-trip tests. + from test.training.test_strategy import demo_training_fn + + if models is None: + models = _build_demo_model() + return { + "models": models, + "optimizer_configs": OptimizerConfig(optimizer_cls=torch.optim.Adam), + "num_epochs": 1, + "training_fn": demo_training_fn, + "loss_fn": EnergyMSELoss() + ForceMSELoss(normalize_by_atom_count=True), + } + + +@pytest.fixture +def atomic_data() -> AtomicData: + """Return a default :class:`AtomicData` — 3 atoms, ``seed=0``.""" + return _build_atomic_data() + + +@pytest.fixture +def batch() -> Batch: + """Return a default :class:`Batch` — 2 systems, 3 atoms each, ``seed=0``.""" + return _build_batch() + + +@pytest.fixture +def dataset() -> list[Batch]: + """Return a default dataset of 3 batches (``base_seed=100``).""" + return _build_dataset() + + +@pytest.fixture +def demo_model() -> Any: + """Return a freshly-seeded :class:`DemoModelWrapper`.""" + return _build_demo_model() + + +@pytest.fixture +def adam_optimizer_configs() -> dict[str, list[OptimizerConfig]]: + """Return a default Adam :class:`OptimizerConfig` mapping keyed by ``main``.""" + return _build_adam_optimizer_configs() + + +@pytest.fixture +def baseline_strategy_kwargs(demo_model: Any) -> dict[str, Any]: + """Return default kwargs suitable for ``TrainingStrategy(**kwargs)``.""" + return _build_baseline_strategy_kwargs(models=demo_model) + + +@pytest.fixture +def strategy(baseline_strategy_kwargs: dict[str, Any]) -> TrainingStrategy: + """Return a default :class:`TrainingStrategy` built from ``baseline_strategy_kwargs``.""" + return TrainingStrategy(**baseline_strategy_kwargs) diff --git a/test/training/test_checkpoint.py b/test/training/test_checkpoint.py new file mode 100644 index 
00000000..dd7420bc --- /dev/null +++ b/test/training/test_checkpoint.py @@ -0,0 +1,1559 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for :mod:`nvalchemi.training._checkpoint`.""" + +from __future__ import annotations + +import ast +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest +import torch +import torch.nn as nn + +from nvalchemi.data import AtomicData, Batch +from nvalchemi.models.base import BaseModelMixin +from nvalchemi.training import EnergyMSELoss, TrainingStage +from nvalchemi.training._checkpoint import ( + CheckpointManifest, + load_checkpoint, + save_checkpoint, +) +from nvalchemi.training._spec import create_model_spec +from nvalchemi.training.optimizers import OptimizerConfig +from nvalchemi.training.strategy import TrainingStrategy, default_training_fn + +# --------------------------------------------------------------------------- +# Helper classes (reused from original file) +# --------------------------------------------------------------------------- + + +class SwiGLU(nn.Module): + """Custom activation: SwiGLU-style gated activation with a learnable scale. + + Splits the input channel dimension in half, applies SiLU to one half and + multiplies by the other, then scales by a learnable parameter. 
Exercises + :func:`create_model_spec`/:func:`save_checkpoint` against a module that + owns its own :class:`~torch.nn.Parameter` rather than delegating to a + stock layer. + """ + + def __init__(self, init_scale: float = 1.0) -> None: + super().__init__() + self.scale = nn.Parameter(torch.tensor(float(init_scale))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a, b = x.chunk(2, dim=-1) + return self.scale * (a * torch.nn.functional.silu(b)) + + +class CustomMLPBlock(nn.Module): + """Pre-norm MLP with a custom activation, a SwiGLU expansion, and dropout. + + The block is deliberately non-trivial: it stacks :class:`LayerNorm`, an + expansion :class:`Linear` that feeds :class:`SwiGLU` (which halves the + channel dimension), a projection :class:`Linear`, and :class:`Dropout`, + with an optional residual connection. It stress tests the spec layer in + three ways: + + 1. ``__init__`` takes a mix of ints, floats, booleans, and a + :class:`torch.dtype` --- the latter routed through the custom type + serializer registry. + 2. The module owns parameters at multiple nesting depths (top-level + :class:`Linear` weights, plus the :class:`SwiGLU` scale parameter). + 3. The forward pass is stateless w.r.t. shape up to the channel + dimension, so round-tripped weights must reproduce outputs exactly + (modulo dropout, which we disable by calling ``.eval()``). 
+ """ + + def __init__( + self, + in_features: int, + hidden_features: int, + dropout: float = 0.1, + eps: float = 1e-5, + activation_scale: float = 1.0, + use_residual: bool = True, + dtype: torch.dtype = torch.float32, + ) -> None: + super().__init__() + if hidden_features % 2 != 0: + raise ValueError( + f"hidden_features must be even for SwiGLU, got {hidden_features}" + ) + self.in_features = in_features + self.hidden_features = hidden_features + self.use_residual = use_residual + + self.norm = nn.LayerNorm(in_features, eps=eps, dtype=dtype) + self.expand = nn.Linear(in_features, hidden_features, dtype=dtype) + self.activation = SwiGLU(init_scale=activation_scale) + self.project = nn.Linear(hidden_features // 2, in_features, dtype=dtype) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm(x) + h = self.expand(h) + h = self.activation(h) + h = self.project(h) + h = self.dropout(h) + return x + h if self.use_residual else h + + +@dataclass +class NotModule: + """Non-module class used to verify load rejects non-nn.Module specs.""" + + arg_a: int + arg_b: str + + +def checkpoint_training_fn( + model: BaseModelMixin, + batch: Batch, +) -> dict[str, torch.Tensor]: + """Importable training function used by strategy checkpoint tests.""" + return default_training_fn(model, batch) + + +class _NoOpCheckpointHook: + """Simple observer hook used by checkpoint restart tests.""" + + stage = TrainingStage.BEFORE_BATCH + frequency = 1 + + def __call__(self, ctx: Any, stage: TrainingStage) -> None: + """Observe a training stage without mutating state.""" + del ctx, stage + + +def _make_checkpoint_batch(n_atoms: int = 3, seed: int = 0) -> Batch: + """Build a small batch with energy targets for strategy checkpoint tests.""" + generator = torch.Generator().manual_seed(seed) + data = AtomicData( + positions=torch.randn(n_atoms, 3, generator=generator), + atomic_numbers=torch.randint( + 1, 10, (n_atoms,), dtype=torch.long, 
generator=generator + ), + atomic_masses=torch.ones(n_atoms), + energy=torch.randn(1, 1, generator=generator), + ) + return Batch.from_data_list([data]) + + +def _make_checkpoint_strategy(num_steps: int = 4) -> TrainingStrategy: + """Create a serializable demo training strategy for checkpoint tests.""" + from nvalchemi.models.demo import DemoModel, DemoModelWrapper + + torch.manual_seed(0) + model = DemoModelWrapper(DemoModel(num_atom_types=20, hidden_dim=8)) + return TrainingStrategy( + models=model, + optimizer_configs=OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + scheduler_cls=torch.optim.lr_scheduler.StepLR, + scheduler_kwargs={"step_size": 2, "gamma": 0.5}, + ), + num_steps=num_steps, + training_fn=checkpoint_training_fn, + loss_fn=EnergyMSELoss(), + devices=[torch.device("cpu")], + ) + + +def _make_multi_optimizer_checkpoint_strategy() -> TrainingStrategy: + """Create a serializable strategy with two optimizer/scheduler pairs.""" + from nvalchemi.models.demo import DemoModel, DemoModelWrapper + + torch.manual_seed(0) + model = DemoModelWrapper(DemoModel(num_atom_types=20, hidden_dim=8)) + return TrainingStrategy( + models=model, + optimizer_configs=[ + OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + scheduler_cls=torch.optim.lr_scheduler.StepLR, + scheduler_kwargs={"step_size": 2, "gamma": 0.5}, + ), + OptimizerConfig( + optimizer_cls=torch.optim.SGD, + optimizer_kwargs={"lr": 1e-2}, + scheduler_cls=torch.optim.lr_scheduler.ExponentialLR, + scheduler_kwargs={"gamma": 0.9}, + ), + ], + num_steps=1, + training_fn=checkpoint_training_fn, + loss_fn=EnergyMSELoss(), + devices=[torch.device("cpu")], + ) + + +# --------------------------------------------------------------------------- +# Test classes +# --------------------------------------------------------------------------- + + +class TestSaveCheckpointSingleModel: + """Basic dict-based single-model checkpoint save/load behavior.""" + 
+ def test_save_creates_manifest_and_model_dir(self, tmp_path: Path) -> None: + """Directory layout: manifest.json + models/{name}/ created on save.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + assert (tmp_path / "manifest.json").is_file() + assert (tmp_path / "models" / "main" / "spec.json").is_file() + assert (tmp_path / "models" / "main" / "checkpoints" / "0.pt").is_file() + + def test_save_load_basic_roundtrip(self, tmp_path: Path) -> None: + """Save one model, load it, verify weights match.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + idx = save_checkpoint(tmp_path, models={"main": (model, spec)}) + assert idx == 0 + + result = load_checkpoint(tmp_path) + assert isinstance(result, CheckpointManifest) + assert "main" in result.models + reloaded, reloaded_spec = result.models["main"] + assert isinstance(reloaded, nn.Linear) + assert torch.equal(reloaded.weight, model.weight) + assert torch.equal(reloaded.bias, model.bias) + assert reloaded_spec.timestamp == spec.timestamp + + def test_autoincrement_from_manifest(self, tmp_path: Path) -> None: + """Three sequential saves auto-increment indices 0, 1, 2.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + idx0 = save_checkpoint(tmp_path, models={"main": (model, spec)}) + idx1 = save_checkpoint(tmp_path, models={"main": (model, spec)}) + idx2 = save_checkpoint(tmp_path, models={"main": (model, spec)}) + assert (idx0, idx1, idx2) == (0, 1, 2) + + def test_explicit_index(self, tmp_path: Path) -> None: + """Explicit ``checkpoint_index=5`` writes to ``checkpoints/5.pt``.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + idx = save_checkpoint( + tmp_path, models={"main": (model, spec)}, checkpoint_index=5 + ) + assert idx == 5 + assert (tmp_path / "models" 
/ "main" / "checkpoints" / "5.pt").is_file() + + def test_spec_consistency_check_raises_on_mismatch(self, tmp_path: Path) -> None: + """Saving a different spec under the same model name raises ValueError.""" + model_a = nn.Linear(4, 2) + spec_a = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"main": (model_a, spec_a)}) + + model_b = nn.Linear(8, 2) + spec_b = create_model_spec(nn.Linear, in_features=8, out_features=2) + with pytest.raises(ValueError, match="in_features"): + save_checkpoint(tmp_path, models={"main": (model_b, spec_b)}) + + def test_load_latest_from_manifest(self, tmp_path: Path) -> None: + """Default load returns the latest checkpoint from the manifest.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + for i in (0, 2, 5): + save_checkpoint( + tmp_path, models={"main": (model, spec)}, checkpoint_index=i + ) + + # Overwrite index 5 with mutated weights. + mutated = nn.Linear(4, 2) + with torch.no_grad(): + mutated.weight.copy_(mutated.weight + 100.0) + save_checkpoint(tmp_path, models={"main": (mutated, spec)}, checkpoint_index=5) + + result = load_checkpoint(tmp_path) + reloaded, _ = result.models["main"] + assert torch.allclose(reloaded.weight, mutated.weight) + + def test_load_explicit_index(self, tmp_path: Path) -> None: + """Loading specific checkpoint indices returns the correct weights.""" + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + + model_a = nn.Linear(4, 2) + model_b = nn.Linear(4, 2) + with torch.no_grad(): + model_b.weight.copy_(model_b.weight + 10.0) + + save_checkpoint(tmp_path, models={"main": (model_a, spec)}, checkpoint_index=1) + save_checkpoint(tmp_path, models={"main": (model_b, spec)}, checkpoint_index=2) + + loaded_a = load_checkpoint(tmp_path, checkpoint_index=1).models["main"][0] + loaded_b = load_checkpoint(tmp_path, checkpoint_index=2).models["main"][0] + assert torch.allclose(loaded_a.weight, 
model_a.weight) + assert torch.allclose(loaded_b.weight, model_b.weight) + + def test_load_missing_manifest_raises(self, tmp_path: Path) -> None: + """FileNotFoundError when no manifest.json exists.""" + with pytest.raises(FileNotFoundError, match="manifest.json"): + load_checkpoint(tmp_path) + + def test_non_module_build_raises(self, tmp_path: Path) -> None: + """Spec for a non-nn.Module class raises RuntimeError on load.""" + spec = create_model_spec(NotModule, arg_a=5, arg_b="hello") + + # Manually stage the directory layout so load_checkpoint can parse it. + model_dir = tmp_path / "models" / "main" + ckpt_dir = model_dir / "checkpoints" + ckpt_dir.mkdir(parents=True) + (model_dir / "spec.json").write_text(spec.model_dump_json(indent=2)) + torch.save({}, ckpt_dir / "0.pt") + manifest = { + "schema_version": 1, + "checkpoint_index": 0, + "models": ["main"], + "optimizers": [], + "schedulers": [], + "associations": {}, + } + (tmp_path / "manifest.json").write_text(json.dumps(manifest)) + + with pytest.raises(RuntimeError, match="expected nn.Module"): + load_checkpoint(tmp_path) + + def test_load_weights_only_true(self, tmp_path: Path) -> None: + """Every ``torch.load`` call uses ``weights_only=True``.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + import nvalchemi.training._checkpoint as ckpt_mod + + real_load = ckpt_mod.torch.load + with patch.object(ckpt_mod.torch, "load", wraps=real_load) as mock_load: + load_checkpoint(tmp_path) + + assert mock_load.call_count >= 1 + for call in mock_load.call_args_list: + assert call.kwargs.get("weights_only") is True, ( + f"torch.load called without weights_only=True: {call}" + ) + + +class TestMultiModel: + """Two or more named models in a single checkpoint.""" + + @staticmethod + def _save_student_teacher( + tmp_path: Path, + ) -> tuple[nn.Module, nn.Module]: + """Save a student/teacher pair and return the 
original modules.""" + student = nn.Linear(4, 2) + teacher = nn.Linear(4, 2) + with torch.no_grad(): + teacher.weight.copy_(teacher.weight + 5.0) + s_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + t_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint( + tmp_path, + models={"student": (student, s_spec), "teacher": (teacher, t_spec)}, + ) + return student, teacher + + def test_save_load_two_models(self, tmp_path: Path) -> None: + """Both student and teacher round-trip correctly.""" + student, teacher = self._save_student_teacher(tmp_path) + result = load_checkpoint(tmp_path) + assert set(result.models) == {"student", "teacher"} + + loaded_student, _ = result.models["student"] + loaded_teacher, _ = result.models["teacher"] + assert torch.equal(loaded_student.weight, student.weight) + assert torch.equal(loaded_teacher.weight, teacher.weight) + + def test_models_have_independent_weights(self, tmp_path: Path) -> None: + """Perturbed models are distinguishable after load.""" + student, teacher = self._save_student_teacher(tmp_path) + result = load_checkpoint(tmp_path) + loaded_s = result.models["student"][0] + loaded_t = result.models["teacher"][0] + # The teacher was shifted by +5.0 so they should differ. 
+ assert not torch.equal(loaded_s.weight, loaded_t.weight) + + def test_model_names_in_manifest(self, tmp_path: Path) -> None: + """Manifest lists both model names.""" + self._save_student_teacher(tmp_path) + manifest = json.loads((tmp_path / "manifest.json").read_text()) + assert sorted(manifest["models"]) == ["student", "teacher"] + + def test_model_subdirectories_exist(self, tmp_path: Path) -> None: + """Per-model subdirectories are created under ``models/``.""" + self._save_student_teacher(tmp_path) + assert (tmp_path / "models" / "student").is_dir() + assert (tmp_path / "models" / "teacher").is_dir() + + +class TestOptimizerCheckpoint: + """Optimizer state round-trip through the checkpoint layer.""" + + @staticmethod + def _train_steps( + model: nn.Module, optimizer: torch.optim.Optimizer, n_steps: int + ) -> None: + """Run *n_steps* fake training steps to build up optimizer state.""" + for _ in range(n_steps): + x = torch.randn(2, model.in_features) + loss = model(x).sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + def test_save_load_optimizer_state_dict(self, tmp_path: Path) -> None: + """Optimizer state_dict round-trips through save/load.""" + model = nn.Linear(4, 2) + m_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + opt_spec = create_model_spec(torch.optim.SGD, lr=0.01, momentum=0.9) + + self._train_steps(model, optimizer, 3) + original_state = optimizer.state_dict() + + save_checkpoint( + tmp_path, + models={"main": (model, m_spec)}, + optimizers={"opt": (optimizer, opt_spec)}, + ) + result = load_checkpoint(tmp_path) + loaded_opt, _ = result.optimizers["opt"] + loaded_state = loaded_opt.state_dict() + + # Compare param_groups (excluding 'params' which are tensor ids). 
+ for orig_pg, loaded_pg in zip( + original_state["param_groups"], loaded_state["param_groups"] + ): + for key in ("lr", "momentum", "weight_decay"): + assert orig_pg[key] == loaded_pg[key] + + def test_optimizer_param_groups_preserved(self, tmp_path: Path) -> None: + """LR, momentum, weight_decay survive round-trip.""" + model = nn.Linear(4, 2) + m_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + optimizer = torch.optim.SGD( + model.parameters(), lr=0.05, momentum=0.8, weight_decay=1e-4 + ) + opt_spec = create_model_spec( + torch.optim.SGD, lr=0.05, momentum=0.8, weight_decay=1e-4 + ) + self._train_steps(model, optimizer, 1) + + save_checkpoint( + tmp_path, + models={"main": (model, m_spec)}, + optimizers={"opt": (optimizer, opt_spec)}, + ) + result = load_checkpoint(tmp_path) + loaded_pg = result.optimizers["opt"][0].param_groups[0] + assert loaded_pg["lr"] == pytest.approx(0.05) + assert loaded_pg["momentum"] == pytest.approx(0.8) + assert loaded_pg["weight_decay"] == pytest.approx(1e-4) + + def test_optimizer_step_state_preserved(self, tmp_path: Path) -> None: + """Momentum buffers match: original and reloaded produce same results.""" + torch.manual_seed(42) + model = nn.Linear(4, 2) + m_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + opt_spec = create_model_spec(torch.optim.SGD, lr=0.01, momentum=0.9) + + self._train_steps(model, optimizer, 5) + + save_checkpoint( + tmp_path, + models={"main": (model, m_spec)}, + optimizers={"opt": (optimizer, opt_spec)}, + ) + result = load_checkpoint(tmp_path) + loaded_model, _ = result.models["main"] + loaded_opt, _ = result.optimizers["opt"] + + # Run M more steps on both and verify weights converge identically. 
+ torch.manual_seed(99) + inputs = [torch.randn(2, 4) for _ in range(3)] + + for x in inputs: + loss = model(x).sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + for x in inputs: + loss = loaded_model(x).sum() + loss.backward() + loaded_opt.step() + loaded_opt.zero_grad() + + for p_orig, p_loaded in zip(model.parameters(), loaded_model.parameters()): + assert torch.allclose(p_orig, p_loaded, atol=1e-6) + + +class TestSchedulerCheckpoint: + """Scheduler state round-trip through the checkpoint layer.""" + + def test_save_load_scheduler_state_dict(self, tmp_path: Path) -> None: + """CosineAnnealingLR state_dict round-trips.""" + model = nn.Linear(4, 2) + m_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + opt_spec = create_model_spec(torch.optim.SGD, lr=0.1) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) + sched_spec = create_model_spec( + torch.optim.lr_scheduler.CosineAnnealingLR, T_max=10 + ) + + for _ in range(5): + scheduler.step() + original_state = scheduler.state_dict() + + associations = {"main": {"optimizers": ["opt"], "schedulers": ["sched"]}} + save_checkpoint( + tmp_path, + models={"main": (model, m_spec)}, + optimizers={"opt": (optimizer, opt_spec)}, + schedulers={"sched": (scheduler, sched_spec)}, + associations=associations, + ) + result = load_checkpoint(tmp_path) + loaded_sched, _ = result.schedulers["sched"] + loaded_state = loaded_sched.state_dict() + + for key in original_state: + assert original_state[key] == loaded_state[key], ( + f"scheduler state key {key!r} differs" + ) + + def test_lr_trajectory_preserved(self, tmp_path: Path) -> None: + """LR trajectory matches after reload: step N, save, reload, step M more.""" + model = nn.Linear(4, 2) + m_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + opt_spec = create_model_spec(torch.optim.SGD, 
lr=0.1) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20) + sched_spec = create_model_spec( + torch.optim.lr_scheduler.CosineAnnealingLR, T_max=20 + ) + + # Step N=5 times before save. + for _ in range(5): + scheduler.step() + + associations = {"main": {"optimizers": ["opt"], "schedulers": ["sched"]}} + save_checkpoint( + tmp_path, + models={"main": (model, m_spec)}, + optimizers={"opt": (optimizer, opt_spec)}, + schedulers={"sched": (scheduler, sched_spec)}, + associations=associations, + ) + result = load_checkpoint(tmp_path) + loaded_sched, _ = result.schedulers["sched"] + + # Step M=10 more times on both; LR must match at every step. + for step in range(10): + scheduler.step() + loaded_sched.step() + lr_orig = scheduler.get_last_lr()[0] + lr_loaded = loaded_sched.get_last_lr()[0] + assert lr_orig == pytest.approx(lr_loaded), ( + f"LR mismatch at step {step}: {lr_orig} vs {lr_loaded}" + ) + + +class TestAssociations: + """Model-to-optimizer-to-scheduler linkage via associations.""" + + def test_associations_stored_in_manifest(self, tmp_path: Path) -> None: + """Associations dict appears in manifest.json.""" + model = nn.Linear(4, 2) + m_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + opt_spec = create_model_spec(torch.optim.SGD, lr=0.01) + + associations = {"main": {"optimizers": ["opt"], "schedulers": []}} + save_checkpoint( + tmp_path, + models={"main": (model, m_spec)}, + optimizers={"opt": (optimizer, opt_spec)}, + associations=associations, + ) + manifest = json.loads((tmp_path / "manifest.json").read_text()) + assert manifest["associations"] == associations + + def test_load_wires_optimizer_to_correct_model(self, tmp_path: Path) -> None: + """Optimizer param groups reference the associated model's parameters.""" + student = nn.Linear(4, 2) + teacher = nn.Linear(4, 2) + s_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + t_spec = 
create_model_spec(nn.Linear, in_features=4, out_features=2) + optimizer = torch.optim.SGD(student.parameters(), lr=0.01) + opt_spec = create_model_spec(torch.optim.SGD, lr=0.01) + + associations = {"student": {"optimizers": ["s_opt"], "schedulers": []}} + save_checkpoint( + tmp_path, + models={ + "student": (student, s_spec), + "teacher": (teacher, t_spec), + }, + optimizers={"s_opt": (optimizer, opt_spec)}, + associations=associations, + ) + result = load_checkpoint(tmp_path) + loaded_opt, _ = result.optimizers["s_opt"] + loaded_student, _ = result.models["student"] + + # The optimizer's param groups should reference the loaded student's + # parameters (the same ``Parameter`` objects, compared via ``id``). + opt_param_ids = {id(p) for pg in loaded_opt.param_groups for p in pg["params"]} + student_param_ids = {id(p) for p in loaded_student.parameters()} + assert opt_param_ids == student_param_ids + + def test_load_wires_scheduler_to_correct_optimizer(self, tmp_path: Path) -> None: + """Scheduler references the correct optimizer on load.""" + model = nn.Linear(4, 2) + m_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + opt_spec = create_model_spec(torch.optim.SGD, lr=0.1) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) + sched_spec = create_model_spec( + torch.optim.lr_scheduler.CosineAnnealingLR, T_max=10 + ) + + associations = {"main": {"optimizers": ["opt"], "schedulers": ["sched"]}} + save_checkpoint( + tmp_path, + models={"main": (model, m_spec)}, + optimizers={"opt": (optimizer, opt_spec)}, + schedulers={"sched": (scheduler, sched_spec)}, + associations=associations, + ) + result = load_checkpoint(tmp_path) + loaded_sched, _ = result.schedulers["sched"] + loaded_opt, _ = result.optimizers["opt"] + + # The scheduler's internal optimizer should be the loaded one. 
+ assert loaded_sched.optimizer is loaded_opt + + def test_single_model_fallback_no_associations(self, tmp_path: Path) -> None: + """One model + one optimizer, no associations: load succeeds via fallback.""" + model = nn.Linear(4, 2) + m_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + opt_spec = create_model_spec(torch.optim.SGD, lr=0.01) + + # No associations at all. + save_checkpoint( + tmp_path, + models={"main": (model, m_spec)}, + optimizers={"opt": (optimizer, opt_spec)}, + ) + result = load_checkpoint(tmp_path) + assert "opt" in result.optimizers + + def test_scheduler_attaches_to_second_optimizer(self, tmp_path: Path) -> None: + """Regression: scheduler must wrap the optimizer it was saved with. + + With multiple optimizers on the same model, the scheduler must attach + to the specific optimizer it was constructed with, not simply the + first optimizer in the manifest's association list. Auto-inference + cannot disambiguate two optimizers sharing parameters, so explicit + associations are required to pin the scheduler to the SGD optimizer. + """ + model = nn.Linear(4, 2) + m_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + + # Two optimizers with different classes to make identity unambiguous. + opt_adam = torch.optim.Adam(model.parameters(), lr=0.01) + opt_sgd = torch.optim.SGD(model.parameters(), lr=0.1) + adam_spec = create_model_spec(torch.optim.Adam, lr=0.01) + sgd_spec = create_model_spec(torch.optim.SGD, lr=0.1) + + # Scheduler on the SGD optimizer (NOT Adam). + scheduler = torch.optim.lr_scheduler.StepLR(opt_sgd, step_size=10, gamma=0.5) + sched_spec = create_model_spec( + torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5 + ) + + # Explicit associations: the scheduler must wrap ``sgd``. Listing + # ``sgd`` first among optimizers ensures the load-time wiring logic + # picks it as the scheduler's optimizer. 
+ associations = { + "m": {"optimizers": ["sgd", "adam"], "schedulers": ["step"]}, + } + save_checkpoint( + tmp_path, + models={"m": (model, m_spec)}, + optimizers={"adam": (opt_adam, adam_spec), "sgd": (opt_sgd, sgd_spec)}, + schedulers={"step": (scheduler, sched_spec)}, + associations=associations, + ) + + result = load_checkpoint(tmp_path) + loaded_scheduler, _ = result.schedulers["step"] + loaded_sgd, _ = result.optimizers["sgd"] + loaded_adam, _ = result.optimizers["adam"] + + assert loaded_scheduler.optimizer is loaded_sgd, ( + "scheduler should wrap the SGD optimizer it was saved with" + ) + assert loaded_scheduler.optimizer is not loaded_adam + + # Verify the manifest recorded the association correctly. + assert "step" in result.associations.get("m", {}).get("schedulers", []) + + +class TestCheckpointCustomMLPBlock: + """Stress tests: serialize a non-trivial custom block end-to-end. + + The target is :class:`CustomMLPBlock` --- a pre-norm MLP wrapping a custom + :class:`SwiGLU` activation, an expansion/projection :class:`Linear` pair, + :class:`LayerNorm`, and :class:`Dropout`. These tests exercise the spec + + checkpoint pipeline against a module that mixes several param types + (weights, bias, learnable activation scale, LayerNorm affine params) and + several kwarg types (int, float, bool, :class:`torch.dtype`). 
+ """ + + @staticmethod + def _make_spec(**overrides: object) -> tuple[dict[str, object], object]: + """Build default CustomMLPBlock kwargs + spec with optional overrides.""" + kwargs: dict[str, object] = { + "in_features": 8, + "hidden_features": 16, + "dropout": 0.25, + "eps": 1e-6, + "activation_scale": 0.5, + "use_residual": True, + "dtype": torch.float32, + } + kwargs.update(overrides) + return kwargs, create_model_spec(CustomMLPBlock, **kwargs) + + def test_roundtrip_preserves_all_params(self, tmp_path: Path) -> None: + """All named parameters survive a save/load round-trip bit-exactly.""" + kwargs, spec = self._make_spec() + model = CustomMLPBlock(**kwargs) + with torch.no_grad(): + for p in model.parameters(): + p.add_(torch.randn_like(p) * 0.1) + + save_checkpoint(tmp_path, models={"main": (model, spec)}) + result = load_checkpoint(tmp_path) + reloaded, reloaded_spec = result.models["main"] + assert isinstance(reloaded, CustomMLPBlock) + assert reloaded_spec.timestamp == spec.timestamp + + original_params = dict(model.named_parameters()) + reloaded_params = dict(reloaded.named_parameters()) + assert set(original_params) == set(reloaded_params) + for name, tensor in original_params.items(): + assert torch.equal(reloaded_params[name], tensor), ( + f"parameter {name!r} differs after round-trip" + ) + + def test_roundtrip_preserves_forward_output(self, tmp_path: Path) -> None: + """Forward pass output is identical after round-trip.""" + kwargs, spec = self._make_spec(dropout=0.5) + model = CustomMLPBlock(**kwargs).eval() + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + result = load_checkpoint(tmp_path) + reloaded, _ = result.models["main"] + reloaded.eval() + + x = torch.randn(4, kwargs["in_features"]) + with torch.no_grad(): + y_original = model(x) + y_reloaded = reloaded(x) + assert torch.equal(y_original, y_reloaded) + + def test_spec_json_is_pure_json(self, tmp_path: Path) -> None: + """spec.json contains only JSON-native types; no pickled 
blobs.""" + kwargs, spec = self._make_spec() + model = CustomMLPBlock(**kwargs) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + raw = (tmp_path / "models" / "main" / "spec.json").read_text() + parsed = json.loads(raw) + assert parsed["dtype"] == "torch.float32" + for key in ("cls_path", "timestamp"): + assert key in parsed + assert parsed["cls_path"].endswith(".CustomMLPBlock") + + def test_dtype_kwarg_round_trips(self, tmp_path: Path) -> None: + """torch.float64 dtype kwarg survives JSON round-trip.""" + kwargs, spec = self._make_spec(dtype=torch.float64) + model = CustomMLPBlock(**kwargs) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + result = load_checkpoint(tmp_path) + _, reloaded_spec = result.models["main"] + reloaded = result.models["main"][0] + assert reloaded_spec.dtype is torch.float64 + assert reloaded.expand.weight.dtype is torch.float64 + assert reloaded.norm.weight.dtype is torch.float64 + + def test_activation_parameter_checkpointed(self, tmp_path: Path) -> None: + """Learnable ``SwiGLU.scale`` survives the round-trip.""" + kwargs, spec = self._make_spec(activation_scale=0.5) + model = CustomMLPBlock(**kwargs) + with torch.no_grad(): + model.activation.scale.fill_(7.5) + + save_checkpoint(tmp_path, models={"main": (model, spec)}) + result = load_checkpoint(tmp_path) + reloaded, _ = result.models["main"] + assert torch.equal(reloaded.activation.scale, torch.tensor(7.5)) + + def test_autoincrement_multiple_checkpoints(self, tmp_path: Path) -> None: + """Autoincrement + per-index reload preserves correct weights.""" + kwargs, spec = self._make_spec() + model = CustomMLPBlock(**kwargs) + + snapshots: list[torch.Tensor] = [] + for step in range(3): + with torch.no_grad(): + model.project.weight.add_(float(step) + 1.0) + snapshots.append(model.project.weight.detach().clone()) + idx = save_checkpoint(tmp_path, models={"main": (model, spec)}) + assert idx == step + + for step, snapshot in enumerate(snapshots): + result = 
load_checkpoint(tmp_path, checkpoint_index=step) + reloaded, _ = result.models["main"] + assert torch.equal(reloaded.project.weight, snapshot), ( + f"checkpoint {step} did not reload its own weights" + ) + + def test_hyperparameter_mismatch_raises(self, tmp_path: Path) -> None: + """Saving a second spec with different hyperparameters must fail.""" + kwargs_a, spec_a = self._make_spec(hidden_features=16) + model_a = CustomMLPBlock(**kwargs_a) + save_checkpoint(tmp_path, models={"main": (model_a, spec_a)}) + + kwargs_b, spec_b = self._make_spec(hidden_features=32) + model_b = CustomMLPBlock(**kwargs_b) + with pytest.raises(ValueError, match="hidden_features"): + save_checkpoint(tmp_path, models={"main": (model_b, spec_b)}) + + +class TestSecurityAST: + """AST-level security invariants for ``_checkpoint.py``.""" + + _CHECKPOINT_PATH = ( + Path(__file__).resolve().parents[2] + / "nvalchemi" + / "training" + / "_checkpoint.py" + ) + _FORBIDDEN_MODULES = frozenset({"pickle", "cloudpickle", "dill", "marshal"}) + + def _tree(self) -> ast.AST: + return ast.parse(self._CHECKPOINT_PATH.read_text()) + + def test_no_pickle_imports(self) -> None: + """No imports of pickle, cloudpickle, dill, or marshal.""" + tree = self._tree() + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + root = alias.name.split(".")[0] + assert root not in self._FORBIDDEN_MODULES, ( + f"_checkpoint.py:{node.lineno} imports forbidden " + f"module {alias.name!r}" + ) + elif isinstance(node, ast.ImportFrom): + if node.module is None: + continue + root = node.module.split(".")[0] + assert root not in self._FORBIDDEN_MODULES, ( + f"_checkpoint.py:{node.lineno} imports from forbidden " + f"module {node.module!r}" + ) + + def test_torch_load_always_weights_only(self) -> None: + """Every ``torch.load(...)`` call has ``weights_only=True``.""" + tree = self._tree() + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + func = node.func + if not ( + 
isinstance(func, ast.Attribute) + and func.attr == "load" + and isinstance(func.value, ast.Name) + and func.value.id == "torch" + ): + continue + kw = {k.arg: k.value for k in node.keywords if k.arg is not None} + assert "weights_only" in kw, ( + f"_checkpoint.py:{node.lineno} torch.load() missing weights_only= kwarg" + ) + val = kw["weights_only"] + assert isinstance(val, ast.Constant) and val.value is True, ( + f"_checkpoint.py:{node.lineno} torch.load(weights_only=...) " + f"must be literal True, got {ast.dump(val)}" + ) + + def test_torch_save_uses_state_dict(self) -> None: + """``torch.save`` never receives a raw module/optimizer object. + + The implementation extracts ``state_dict()`` in the caller and + passes the dict to ``_save_component``, which calls ``torch.save`` + with a plain variable. As a heuristic, we verify that no + ``torch.save`` call has a first argument that is a bare attribute + access on ``self`` (e.g., ``torch.save(self.model, ...)``). Plain + names (like ``state_dict``), subscripts, and calls are accepted, + so a raw object bound to a local name is not detected. + """ + tree = self._tree() + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + func = node.func + if not ( + isinstance(func, ast.Attribute) + and func.attr == "save" + and isinstance(func.value, ast.Name) + and func.value.id == "torch" + ): + continue + assert node.args, ( + f"_checkpoint.py:{node.lineno} torch.save() called with no args" + ) + first = node.args[0] + # Must NOT be a bare attribute on self (which would indicate + # saving a raw object). Acceptable: Name, Subscript, or Call. 
+ if isinstance(first, ast.Attribute): + assert not ( + isinstance(first.value, ast.Name) and first.value.id == "self" + ), ( + f"_checkpoint.py:{node.lineno} torch.save() first arg " + f"appears to be a raw object (self.{first.attr}), " + f"expected a state_dict result" + ) + + +class TestSchemaVersion: + """Manifest schema versioning and forward-compatibility guard.""" + + def test_save_writes_schema_version(self, tmp_path: Path) -> None: + """``manifest.json`` contains ``schema_version`` after save.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + manifest = json.loads((tmp_path / "manifest.json").read_text()) + assert "schema_version" in manifest + assert manifest["schema_version"] == 1 + + def test_load_v0_manifest_without_schema_key(self, tmp_path: Path) -> None: + """A manifest missing ``schema_version`` (v0) loads successfully.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + # Strip schema_version to simulate a v0 manifest. + manifest_path = tmp_path / "manifest.json" + manifest = json.loads(manifest_path.read_text()) + manifest.pop("schema_version") + manifest_path.write_text(json.dumps(manifest, indent=2)) + + result = load_checkpoint(tmp_path) + assert "main" in result.models + reloaded, _ = result.models["main"] + assert torch.equal(reloaded.weight, model.weight) + + def test_future_schema_version_raises(self, tmp_path: Path) -> None: + """A manifest with a newer schema version raises ``ValueError``.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + # Bump to a future version. 
+ manifest_path = tmp_path / "manifest.json" + manifest = json.loads(manifest_path.read_text()) + manifest["schema_version"] = 999 + manifest_path.write_text(json.dumps(manifest, indent=2)) + + with pytest.raises(ValueError, match="newer than supported"): + load_checkpoint(tmp_path) + + def test_schema_version_preserved_across_saves(self, tmp_path: Path) -> None: + """Successive saves always write the current schema version.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + for _ in range(3): + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + manifest = json.loads((tmp_path / "manifest.json").read_text()) + assert manifest["schema_version"] == 1 + + def test_manifest_pydantic_validation(self, tmp_path: Path) -> None: + """Malformed manifest.json triggers Pydantic ``ValidationError``.""" + from pydantic import ValidationError + + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + # Corrupt the manifest: models should be list[str], not a string. 
+ manifest_path = tmp_path / "manifest.json" + raw = json.loads(manifest_path.read_text()) + raw["models"] = "not-a-list" + manifest_path.write_text(json.dumps(raw)) + + with pytest.raises(ValidationError): + load_checkpoint(tmp_path) + + def test_manifest_model_dump_roundtrip(self) -> None: + """``CheckpointManifest`` round-trips through JSON serialization.""" + original = CheckpointManifest( + checkpoint_index=3, + models=["a", "b"], + optimizers=["opt_a"], + schedulers=[], + associations={"a": {"optimizers": ["opt_a"], "schedulers": []}}, + ) + dumped = original.model_dump_json() + restored = CheckpointManifest.model_validate_json(dumped) + assert restored == original + + +class TestCheckpointGPU: + """GPU-specific checkpoint round-trip tests.""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_save_load_model_on_gpu(self, tmp_path: Path) -> None: + """Round-trip a model whose parameters live on CUDA.""" + device = torch.device("cuda") + model = nn.Linear(4, 2).to(device) + with torch.no_grad(): + model.weight.add_(1.25) + model.bias.add_(-0.5) + + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + idx = save_checkpoint(tmp_path, models={"main": (model, spec)}) + assert idx == 0 + + # Verify saved tensors are CUDA-resident. + saved = torch.load( + tmp_path / "models" / "main" / "checkpoints" / "0.pt", + weights_only=True, + ) + assert saved["weight"].is_cuda + assert saved["bias"].is_cuda + + result = load_checkpoint(tmp_path) + reloaded, _ = result.models["main"] + assert torch.allclose(reloaded.weight.cpu(), model.weight.cpu()) + assert torch.allclose(reloaded.bias.cpu(), model.bias.cpu()) + + +class TestDtypeRoundtrip: + """Regression: ``torch.dtype`` kwargs rehydrate as real dtype objects.""" + + def test_dtype_kwarg_roundtrip(self, tmp_path: Path) -> None: + """A spec with a ``torch.dtype`` kwarg round-trips bit-exactly. + + Atul's concern A3: when a spec carries a ``torch.dtype`` (e.g. 
+ ``torch.float32``), the saved ``spec.json`` uses a string + representation, but on load the field must rehydrate to the + actual :class:`torch.dtype` so that ``spec.build()`` can hand it + to modules expecting a dtype (not a string). + """ + spec = create_model_spec( + nn.Linear, in_features=4, out_features=2, dtype=torch.float32 + ) + model = spec.build() + assert model.weight.dtype == torch.float32 + + save_checkpoint(tmp_path, models={"m": (model, spec)}) + result = load_checkpoint(tmp_path) + loaded_model, loaded_spec = result.models["m"] + + assert loaded_model.weight.dtype == torch.float32 + # The spec's dtype field must rehydrate as a torch.dtype (not a string). + assert loaded_spec.dtype == torch.float32 + assert isinstance(loaded_spec.dtype, torch.dtype) + + +class TestEMACheckpoint: + """Tests for round-tripping EMA (``AveragedModel``) wrappers. + + ``torch.optim.swa_utils.AveragedModel`` takes the base model as a + positional ``__init__`` argument; :func:`create_model_spec` only + accepts kwargs. The supported workflow is therefore to save the base + model (and optionally the inner averaged module) as ordinary + :class:`nn.Module`\\ s, then reconstruct the ``AveragedModel`` wrapper + in user code after loading. + """ + + def test_ema_base_model_roundtrip(self, tmp_path: Path) -> None: + """Save base + EMA inner module, reconstruct EMA wrapper on load.""" + from torch.optim.swa_utils import AveragedModel + + base = nn.Linear(4, 2) + ema = AveragedModel(base) + # Simulate training: perturb base weights, then update EMA. + for _ in range(3): + with torch.no_grad(): + base.weight.add_(torch.randn_like(base.weight) * 0.1) + ema.update_parameters(base) + + base_spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + + # Save both the base and the EMA's inner module. ``ema.module`` is + # a plain ``nn.Linear`` with the averaged state_dict. 
+ save_checkpoint( + tmp_path, + models={ + "base": (base, base_spec), + "ema_inner": (ema.module, base_spec), + }, + ) + + result = load_checkpoint(tmp_path) + loaded_base, _ = result.models["base"] + loaded_ema_inner, _ = result.models["ema_inner"] + + # Base and EMA inner weights round-trip. + assert torch.allclose(loaded_base.weight, base.weight) + assert torch.allclose(loaded_ema_inner.weight, ema.module.weight) + + # Reconstruct the EMA wrapper: copy averaged weights into the new + # ``AveragedModel``'s inner module so forward-pass output matches. + reconstructed_ema = AveragedModel(loaded_base) + reconstructed_ema.module.load_state_dict(loaded_ema_inner.state_dict()) + x = torch.randn(1, 4) + with torch.no_grad(): + assert torch.allclose(reconstructed_ema(x), ema(x)) + + +class TestLoadCheckpointKwargs: + """Tests for the ``map_location`` and ``model_names`` kwargs.""" + + def test_map_location_cpu(self, tmp_path: Path) -> None: + """``map_location='cpu'`` places the loaded model on CPU.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"m": (model, spec)}) + + result = load_checkpoint(tmp_path, map_location="cpu") + loaded, _ = result.models["m"] + assert loaded.weight.device.type == "cpu" + # State-dict tensors must also be on CPU. 
+ for v in loaded.state_dict().values(): + assert v.device.type == "cpu" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_map_location_cuda(self, tmp_path: Path) -> None: + """``map_location='cuda'`` places the loaded model on CUDA.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"m": (model, spec)}) + + result = load_checkpoint(tmp_path, map_location="cuda") + loaded, _ = result.models["m"] + assert loaded.weight.device.type == "cuda" + + def test_model_names_loads_only_specified_model(self, tmp_path: Path) -> None: + """``model_names`` restricts loading to the selected models only.""" + m1 = nn.Linear(4, 2) + m2 = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint( + tmp_path, + models={"student": (m1, spec), "teacher": (m2, spec)}, + ) + + result = load_checkpoint(tmp_path, model_names={"teacher"}) + assert list(result.models.keys()) == ["teacher"] + assert result.optimizers == {} + assert result.schedulers == {} + # Associations on the result remain informational (reflect on-disk state). 
+ assert isinstance(result.associations, dict) + + def test_model_names_multi_select(self, tmp_path: Path) -> None: + """``model_names`` with multiple names loads all of them and their + associated optimizers/schedulers (union).""" + m1 = nn.Linear(4, 2) + m2 = nn.Linear(4, 2) + m3 = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + + opt1 = torch.optim.Adam(m1.parameters(), lr=0.01) + opt2 = torch.optim.Adam(m2.parameters(), lr=0.02) + opt_spec_1 = create_model_spec(torch.optim.Adam, lr=0.01) + opt_spec_2 = create_model_spec(torch.optim.Adam, lr=0.02) + + save_checkpoint( + tmp_path, + models={"a": (m1, spec), "b": (m2, spec), "c": (m3, spec)}, + optimizers={"a_opt": (opt1, opt_spec_1), "b_opt": (opt2, opt_spec_2)}, + ) + + result = load_checkpoint(tmp_path, model_names={"a", "b"}) + assert set(result.models.keys()) == {"a", "b"} + assert "c" not in result.models + # Both associated optimizers come along, c's (nonexistent) does not. + assert set(result.optimizers.keys()) == {"a_opt", "b_opt"} + + def test_model_names_includes_associated_components(self, tmp_path: Path) -> None: + """``model_names`` also loads the associated optimizers/schedulers.""" + m1 = nn.Linear(4, 2) + m2 = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + + # Give the student an optimizer; teacher gets nothing. + opt1 = torch.optim.Adam(m1.parameters(), lr=0.01) + opt_spec = create_model_spec(torch.optim.Adam, lr=0.01) + + save_checkpoint( + tmp_path, + models={"student": (m1, spec), "teacher": (m2, spec)}, + optimizers={"s_opt": (opt1, opt_spec)}, + ) + + # Loading student pulls in its associated optimizer. + result = load_checkpoint(tmp_path, model_names={"student"}) + assert list(result.models.keys()) == ["student"] + assert "s_opt" in result.optimizers + + # Loading teacher picks up no optimizer (none associated). 
+ result = load_checkpoint(tmp_path, model_names={"teacher"}) + assert list(result.models.keys()) == ["teacher"] + assert result.optimizers == {} + + def test_model_names_unknown_raises_keyerror(self, tmp_path: Path) -> None: + """Unknown names in ``model_names`` raise :class:`KeyError` listing them.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint( + tmp_path, + models={"student": (model, spec), "teacher": (model, spec)}, + ) + + with pytest.raises(KeyError, match="nonexistent"): + load_checkpoint(tmp_path, model_names={"nonexistent"}) + + # Multiple unknowns — both should be reported. + with pytest.raises(KeyError) as excinfo: + load_checkpoint(tmp_path, model_names={"nonexistent", "ghost"}) + msg = str(excinfo.value) + assert "nonexistent" in msg + assert "ghost" in msg + # The error message must list the available model names. + assert "student" in msg + assert "teacher" in msg + + +class TestStrategyCheckpoint: + """High-level strategy checkpoint round-trip behavior.""" + + def test_strategy_checkpoint_loads_builtin_result_and_restart_state( + self, tmp_path: Path + ) -> None: + """Saving a strategy returns a restartable high-level load dict.""" + strategy = _make_checkpoint_strategy(num_steps=4) + strategy.train_batch(_make_checkpoint_batch(seed=1)) + assert strategy.step_count == 1 + assert strategy.batch_count == 1 + + idx = save_checkpoint(tmp_path, strategy=strategy) + assert idx == 0 + assert (tmp_path / "strategy.json").is_file() + + loaded = load_checkpoint(tmp_path) + assert isinstance(loaded, dict) + assert loaded["checkpoint_index"] == 0 + assert loaded["source"]["format"] == "native" + + restored = loaded["strategy"] + assert isinstance(restored, TrainingStrategy) + assert restored.step_count == 1 + assert restored.batch_count == 1 + assert restored.epoch_count == strategy.epoch_count + assert restored.epoch_step_count == strategy.epoch_step_count + assert 
restored.single_model_input is True + + main_entry = loaded["models"]["main"] + assert main_entry["model"] is restored.models["main"] + assert set(main_entry["optimizers"]) == {"main_optimizer"} + assert set(main_entry["schedulers"]) == {"main_scheduler"} + assert restored._resume_optimizer_state is True + assert restored._optimizers[0].state_dict()["state"] + assert restored._lr_schedulers[0].state_dict()["last_epoch"] == 1 + + restored.run( + [ + _make_checkpoint_batch(seed=2), + _make_checkpoint_batch(seed=3), + _make_checkpoint_batch(seed=4), + ] + ) + assert restored.step_count == 4 + + def test_restored_strategy_can_save_next_checkpoint(self, tmp_path: Path) -> None: + """A resumed strategy can continue writing checkpoints in the same root.""" + strategy = _make_checkpoint_strategy(num_steps=4) + strategy.train_batch(_make_checkpoint_batch(seed=1)) + save_checkpoint(tmp_path, strategy=strategy) + + loaded = load_checkpoint(tmp_path) + restored = loaded["strategy"] + restored.run( + [ + _make_checkpoint_batch(seed=2), + _make_checkpoint_batch(seed=3), + _make_checkpoint_batch(seed=4), + ] + ) + + idx = save_checkpoint(tmp_path, strategy=restored) + assert idx == 1 + assert (tmp_path / "models" / "main" / "checkpoints" / "1.pt").is_file() + assert ( + tmp_path / "optimizers" / "main_optimizer" / "checkpoints" / "1.pt" + ).is_file() + assert ( + tmp_path / "schedulers" / "main_scheduler" / "checkpoints" / "1.pt" + ).is_file() + + manifest = json.loads((tmp_path / "manifest.json").read_text()) + assert manifest["checkpoint_index"] == 1 + strategy_metadata = json.loads((tmp_path / "strategy.json").read_text()) + assert strategy_metadata["runtime_state"]["step_count"] == 4 + indexed_metadata = json.loads( + (tmp_path / "strategy" / "checkpoints" / "1.json").read_text() + ) + assert indexed_metadata["runtime_state"]["step_count"] == 4 + + reloaded = load_checkpoint(tmp_path, checkpoint_index=1) + assert reloaded["strategy"].step_count == 4 + + def 
test_strategy_metadata_is_loaded_by_checkpoint_index( + self, tmp_path: Path + ) -> None: + """Explicit checkpoint indices restore matching strategy counters.""" + strategy = _make_checkpoint_strategy(num_steps=4) + strategy.train_batch(_make_checkpoint_batch(seed=1)) + save_checkpoint(tmp_path, strategy=strategy) + + loaded = load_checkpoint(tmp_path) + restored = loaded["strategy"] + restored.run( + [ + _make_checkpoint_batch(seed=2), + _make_checkpoint_batch(seed=3), + _make_checkpoint_batch(seed=4), + ] + ) + save_checkpoint(tmp_path, strategy=restored) + + loaded_first = load_checkpoint(tmp_path, checkpoint_index=0) + loaded_second = load_checkpoint(tmp_path, checkpoint_index=1) + + assert loaded_first["strategy"].step_count == 1 + assert loaded_second["strategy"].step_count == 4 + assert (tmp_path / "strategy" / "checkpoints" / "0.json").is_file() + assert (tmp_path / "strategy" / "checkpoints" / "1.json").is_file() + + def test_multi_optimizer_strategy_schedulers_keep_optimizer_edges( + self, tmp_path: Path + ) -> None: + """Schedulers in multi-optimizer strategies reload on the right optimizer.""" + strategy = _make_multi_optimizer_checkpoint_strategy() + save_checkpoint(tmp_path, strategy=strategy) + + loaded = load_checkpoint(tmp_path) + restored = loaded["strategy"] + + assert set(loaded["models"]["main"]["optimizers"]) == { + "main_optimizer_0", + "main_optimizer_1", + } + assert set(loaded["models"]["main"]["schedulers"]) == { + "main_scheduler_0", + "main_scheduler_1", + } + assert restored._lr_schedulers[0] is not None + assert restored._lr_schedulers[1] is not None + assert restored._lr_schedulers[0].optimizer is restored._optimizers[0] + assert restored._lr_schedulers[1].optimizer is restored._optimizers[1] + + def test_register_hook_preserves_loaded_optimizer_state( + self, tmp_path: Path + ) -> None: + """Adding observer hooks after load does not discard optimizer state.""" + strategy = _make_checkpoint_strategy(num_steps=2) + 
strategy.train_batch(_make_checkpoint_batch(seed=1)) + save_checkpoint(tmp_path, strategy=strategy) + + restored = load_checkpoint(tmp_path)["strategy"] + assert restored._resume_optimizer_state is True + + restored.register_hook(_NoOpCheckpointHook()) + + assert restored._resume_optimizer_state is True + + def test_map_location_overrides_restored_strategy_device( + self, tmp_path: Path + ) -> None: + """A CPU map_location keeps the restored strategy runnable on CPU.""" + strategy = _make_checkpoint_strategy(num_steps=2) + strategy.train_batch(_make_checkpoint_batch(seed=1)) + save_checkpoint(tmp_path, strategy=strategy) + + for metadata_path in ( + tmp_path / "strategy.json", + tmp_path / "strategy" / "checkpoints" / "0.json", + ): + metadata = json.loads(metadata_path.read_text()) + metadata["devices"] = ["cuda"] + metadata_path.write_text(json.dumps(metadata)) + + restored = load_checkpoint(tmp_path, map_location="cpu")["strategy"] + + assert restored.devices == [torch.device("cpu")] + assert all( + parameter.device.type == "cpu" + for parameter in restored.models["main"].parameters() + ) + optimizer_params = restored._optimizers[0].param_groups[0]["params"] + assert all(parameter.device.type == "cpu" for parameter in optimizer_params) + + def test_strategy_can_be_second_positional_argument(self, tmp_path: Path) -> None: + """``save_checkpoint(path, strategy)`` is accepted as high-level UX.""" + strategy = _make_checkpoint_strategy(num_steps=2) + idx = save_checkpoint(tmp_path, strategy) + assert idx == 0 + loaded = load_checkpoint(tmp_path) + assert isinstance(loaded["strategy"], TrainingStrategy) + + def test_strategy_methods_save_and_load_restartable_checkpoint( + self, tmp_path: Path + ) -> None: + """``TrainingStrategy`` exposes one-off save/load checkpoint helpers.""" + strategy = _make_checkpoint_strategy(num_steps=3) + strategy.train_batch(_make_checkpoint_batch(seed=1)) + + idx = strategy.save_checkpoint(tmp_path) + + assert idx == 0 + restored = 
TrainingStrategy.load_checkpoint(tmp_path, map_location="cpu") + assert isinstance(restored, TrainingStrategy) + assert restored.step_count == 1 + assert restored.batch_count == 1 + assert restored._resume_optimizer_state is True + + restored.run( + [ + _make_checkpoint_batch(seed=2), + _make_checkpoint_batch(seed=3), + ] + ) + assert restored.step_count == 3 + + idx = restored.save_checkpoint(tmp_path) + assert idx == 1 + reloaded = TrainingStrategy.load_checkpoint(tmp_path, checkpoint_index=1) + assert reloaded.step_count == 3 + assert reloaded.batch_count == 3 + + def test_strategy_load_checkpoint_requires_strategy_metadata( + self, tmp_path: Path + ) -> None: + """The strategy convenience loader rejects component-only checkpoints.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + with pytest.raises( + ValueError, match="checkpoint saved from a TrainingStrategy" + ): + TrainingStrategy.load_checkpoint(tmp_path) + + def test_validator_callback_wraps_failures(self, tmp_path: Path) -> None: + """Validators receive model entries and errors name the checkpoint/model.""" + model = nn.Linear(4, 2) + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + + seen: list[str] = [] + + def passing_validator( + model_name: str, + entry: dict[str, Any], + loaded: dict[str, Any], + ) -> None: + seen.append(model_name) + assert isinstance(entry["model"], nn.Linear) + assert loaded["source"]["path"] == str(tmp_path) + + loaded = load_checkpoint(tmp_path, validators=[passing_validator]) + assert isinstance(loaded, CheckpointManifest) + assert seen == ["main"] + + def failing_validator( + model_name: str, + entry: dict[str, Any], + loaded: dict[str, Any], + ) -> None: + del model_name, entry, loaded + raise RuntimeError("unsupported elements") + + with pytest.raises(ValueError, match="unsupported 
elements"): + load_checkpoint(tmp_path, validators=[failing_validator]) + + def test_mace_adapter_uses_training_safe_defaults( + self, + tmp_path: Path, + ) -> None: + """MACE adapter loads local files with cuEq/compile disabled by default.""" + checkpoint_path = tmp_path / "mace.pt" + checkpoint_path.write_bytes(b"placeholder") + wrapper = nn.Linear(1, 1) + + with patch( + "nvalchemi.models.mace.MACEWrapper.from_checkpoint", + return_value=wrapper, + ) as mocked: + with pytest.warns(UserWarning, match="trusted"): + loaded = load_checkpoint( + checkpoint_path, + adapter="mace", + adapter_kwargs={"dtype": torch.float32}, + ) + + assert loaded["strategy"] is None + assert loaded["models"]["main"]["model"] is wrapper + assert loaded["models"]["main"]["spec"] is None + assert loaded["source"]["format"] == "mace" + mocked.assert_called_once() + kwargs = mocked.call_args.kwargs + assert kwargs["enable_cueq"] is False + assert kwargs["compile_model"] is False + assert kwargs["dtype"] is torch.float32 diff --git a/test/training/test_checkpoint_hook.py b/test/training/test_checkpoint_hook.py new file mode 100644 index 00000000..827dc96a --- /dev/null +++ b/test/training/test_checkpoint_hook.py @@ -0,0 +1,408 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for periodic training checkpoint hooks.""" + +from __future__ import annotations + +import json +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +import pytest +import torch +from torch import distributed as dist + +from nvalchemi.training import ( + CheckpointHook, + EMAHook, + OptimizerConfig, + TrainingStage, + TrainingStrategy, + load_checkpoint, +) +from test.training.conftest import _build_baseline_strategy_kwargs + + +def _model_parameter_vector(strategy: TrainingStrategy) -> torch.Tensor: + """Return a detached flat parameter vector for the strategy's main model.""" + return torch.cat( + [ + param.detach().cpu().reshape(-1) + for param in strategy.models["main"].parameters() + ] + ) + + +def _ema_state_dict(hook: EMAHook) -> dict[str, Any]: + """Return a detached CPU snapshot of an initialized EMA wrapper.""" + return { + key: value.detach().cpu().clone() if isinstance(value, torch.Tensor) else value + for key, value in hook.get_averaged_model().state_dict().items() + } + + +def _assert_state_dict_close( + actual: Mapping[str, Any], + expected: Mapping[str, Any], +) -> None: + """Assert two state dictionaries contain equal scalar and tensor values.""" + assert actual.keys() == expected.keys() + for key, value in actual.items(): + if isinstance(value, torch.Tensor): + torch.testing.assert_close(value, expected[key], msg=f"state {key!r}") + else: + assert value == expected[key] + + +def _ema_restart_strategy_kwargs() -> dict[str, Any]: + """Return deterministic strategy kwargs for interrupted-run comparisons.""" + return { + **_build_baseline_strategy_kwargs(), + "optimizer_configs": OptimizerConfig( + optimizer_cls=torch.optim.SGD, + optimizer_kwargs={"lr": 1e-3}, + ), + } + + +def _init_single_process_group(tmp_path: Path) -> None: + """Initialize a single-rank process group for CPU DDP tests.""" + init_file = tmp_path / "ddp_init" + dist.init_process_group( + "gloo", + init_method=f"file://{init_file}", + 
rank=0, + world_size=1, + ) + + +class TestCheckpointHookConstruction: + """Validate checkpoint hook configuration.""" + + def test_without_interval_claims_no_stage(self, tmp_path: Path) -> None: + """A checkpoint hook without a cadence is a no-op observer.""" + hook = CheckpointHook(tmp_path) + assert not hook._runs_on_stage(TrainingStage.AFTER_BATCH) + assert not hook._runs_on_stage(TrainingStage.AFTER_EPOCH) + + def test_rejects_step_and_epoch_interval_together(self, tmp_path: Path) -> None: + """A single checkpoint hook owns one cadence policy.""" + with pytest.raises(ValueError, match="exactly one"): + CheckpointHook(tmp_path, step_interval=10, epoch_interval=1) + + @pytest.mark.parametrize("field", ["step_interval", "epoch_interval"]) + def test_interval_must_be_positive(self, tmp_path: Path, field: str) -> None: + """Configured checkpoint cadences must be positive.""" + with pytest.raises(ValueError, match="greater than 0"): + CheckpointHook(tmp_path, **{field: 0}) + + +class TestCheckpointHookCadence: + """Verify periodic checkpoint saves from a running strategy.""" + + def test_step_interval_saves_restartable_checkpoints( + self, + tmp_path: Path, + baseline_strategy_kwargs: dict[str, Any], + dataset: list[Any], + ) -> None: + """Step cadence writes restart checkpoints at completed optimizer steps.""" + hook = CheckpointHook(tmp_path, step_interval=2, async_save=False) + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "num_epochs": None, + "num_steps": 4, + "hooks": [hook], + } + ) + + strategy.run(dataset) + + assert hook.last_checkpoint_index == 1 + assert (tmp_path / "models" / "main" / "checkpoints" / "0.pt").is_file() + assert (tmp_path / "models" / "main" / "checkpoints" / "1.pt").is_file() + first = load_checkpoint(tmp_path, checkpoint_index=0)["strategy"] + second = load_checkpoint(tmp_path, checkpoint_index=1)["strategy"] + assert first.step_count == 2 + assert second.step_count == 4 + + def 
test_epoch_interval_saves_completed_epoch_state( + self, + tmp_path: Path, + baseline_strategy_kwargs: dict[str, Any], + dataset: list[Any], + ) -> None: + """Epoch cadence saves after epoch counters have advanced.""" + hook = CheckpointHook(tmp_path, epoch_interval=1, async_save=False) + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "num_epochs": 2, + "hooks": [hook], + } + ) + + strategy.run(dataset) + + assert hook.last_checkpoint_index == 1 + first_metadata = json.loads( + (tmp_path / "strategy" / "checkpoints" / "0.json").read_text() + ) + second_metadata = json.loads( + (tmp_path / "strategy" / "checkpoints" / "1.json").read_text() + ) + assert first_metadata["runtime_state"]["epoch_count"] == 1 + assert first_metadata["runtime_state"]["epoch_step_count"] == 0 + assert second_metadata["runtime_state"]["epoch_count"] == 2 + assert second_metadata["runtime_state"]["epoch_step_count"] == 0 + + def test_async_save_flushes_on_strategy_exit( + self, + tmp_path: Path, + baseline_strategy_kwargs: dict[str, Any], + dataset: list[Any], + ) -> None: + """Async checkpoint writes finish before ``TrainingStrategy.run`` returns.""" + hook = CheckpointHook(tmp_path, step_interval=1) + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "num_epochs": None, + "num_steps": 1, + "hooks": [hook], + } + ) + + strategy.run(dataset) + + assert hook.last_checkpoint_index == 0 + restored = load_checkpoint(tmp_path)["strategy"] + assert restored.step_count == 1 + + def test_restarted_strategy_continues_periodic_checkpoint_round_trip( + self, + tmp_path: Path, + baseline_strategy_kwargs: dict[str, Any], + dataset: list[Any], + ) -> None: + """Repeated save-load cycles preserve updated restart state.""" + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "num_epochs": None, + "num_steps": 1, + "hooks": [ + CheckpointHook(tmp_path, step_interval=1, async_save=False), + ], + } + ) + previous_params = _model_parameter_vector(strategy) + + 
for checkpoint_index in range(3): + strategy.num_steps = strategy.step_count + 1 + strategy.run([dataset[checkpoint_index]]) + + current_params = _model_parameter_vector(strategy) + assert not torch.allclose(current_params, previous_params) + + loaded = load_checkpoint( + tmp_path, + checkpoint_index=checkpoint_index, + hooks=[ + CheckpointHook(tmp_path, step_interval=1, async_save=False), + ], + ) + restored = loaded["strategy"] + + assert loaded["checkpoint_index"] == checkpoint_index + assert restored.step_count == checkpoint_index + 1 + assert restored.batch_count == checkpoint_index + 1 + assert restored._resume_optimizer_state is True + assert restored._optimizers[0].state_dict()["state"] + torch.testing.assert_close( + _model_parameter_vector(restored), + current_params, + ) + + strategy = restored + previous_params = current_params + + def test_periodic_checkpoint_restores_ema_hook_state( + self, + tmp_path: Path, + baseline_strategy_kwargs: dict[str, Any], + dataset: list[Any], + ) -> None: + """Periodic checkpoints restore checkpointable EMA hook state.""" + ema = EMAHook(model_key="main", decay=0.5) + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "num_epochs": None, + "num_steps": 2, + "hooks": [ + ema, + CheckpointHook(tmp_path, step_interval=1, async_save=False), + ], + } + ) + + strategy.run(dataset) + saved_state = ema.state_dict() + + restored_ema = EMAHook(model_key="main", decay=0.5) + restored = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "num_epochs": None, + "num_steps": 2, + "hooks": [ + restored_ema, + CheckpointHook(tmp_path, step_interval=1, async_save=False), + ], + } + ) + restored.restore_checkpoint(tmp_path, checkpoint_index=1) + + assert restored.step_count == 2 + assert restored_ema.num_updates == ema.num_updates + assert restored_ema._averaged_model is None + assert restored_ema._pending_averaged_state is not None + + saved_average = saved_state["averaged_model_state"] + for key, value in 
restored_ema._pending_averaged_state.items(): + torch.testing.assert_close(value, saved_average[key]) + + def test_restarted_training_matches_uninterrupted_ema_average( + self, + tmp_path: Path, + dataset: list[Any], + ) -> None: + """EMA average continues exactly across a strategy checkpoint restart.""" + full_dataset = dataset[:3] + decay = 0.5 + + reference_ema = EMAHook(model_key="main", decay=decay) + reference = TrainingStrategy( + **{ + **_ema_restart_strategy_kwargs(), + "num_epochs": None, + "num_steps": 3, + "hooks": [reference_ema], + } + ) + reference.run(full_dataset) + expected_params = _model_parameter_vector(reference) + expected_ema_state = _ema_state_dict(reference_ema) + + checkpoint_ema = EMAHook(model_key="main", decay=decay) + checkpointed = TrainingStrategy( + **{ + **_ema_restart_strategy_kwargs(), + "num_epochs": None, + "num_steps": 2, + "hooks": [ + checkpoint_ema, + CheckpointHook(tmp_path, step_interval=2, async_save=False), + ], + } + ) + checkpointed.run(full_dataset) + + assert checkpoint_ema.num_updates == 2 + assert (tmp_path / "hooks" / "checkpoints" / "0.pt").is_file() + + restored_ema = EMAHook(model_key="main", decay=decay) + restored = TrainingStrategy.load_checkpoint( + tmp_path, + checkpoint_index=0, + hooks=[restored_ema], + ) + assert restored.step_count == 2 + assert restored_ema._averaged_model is None + assert restored_ema._pending_averaged_state is not None + + restored.num_steps = 3 + restored.run(full_dataset) + + assert restored.step_count == reference.step_count + assert restored_ema.num_updates == reference_ema.num_updates + torch.testing.assert_close(_model_parameter_vector(restored), expected_params) + _assert_state_dict_close(_ema_state_dict(restored_ema), expected_ema_state) + + @pytest.mark.skipif(not dist.is_gloo_available(), reason="gloo backend required") + def test_ddp_wrapped_strategy_saves_unwrapped_model_state( + self, + tmp_path: Path, + baseline_strategy_kwargs: dict[str, Any], + dataset: list[Any], 
+ ) -> None: + """DDP checkpoints save the underlying model, not ``module.`` keys.""" + if dist.is_initialized(): + pytest.skip("test requires ownership of the process group") + checkpoint_dir = tmp_path / "checkpoints" + _init_single_process_group(tmp_path) + try: + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "num_epochs": None, + "num_steps": 1, + "hooks": [ + CheckpointHook( + checkpoint_dir, + step_interval=1, + async_save=False, + ), + ], + } + ) + strategy.models["main"] = torch.nn.parallel.DistributedDataParallel( + strategy.models["main"] + ) + + strategy.run([dataset[0]]) + + weights = torch.load( + checkpoint_dir / "models" / "main" / "checkpoints" / "0.pt", + weights_only=True, + ) + assert all(not key.startswith("module.") for key in weights) + restored = load_checkpoint(checkpoint_dir)["strategy"] + torch.testing.assert_close( + _model_parameter_vector(restored), + _model_parameter_vector(strategy), + ) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + def test_native_checkpoint_rejects_fsdp_wrapped_model( + self, + monkeypatch: pytest.MonkeyPatch, + demo_model: torch.nn.Module, + ) -> None: + """FSDP/FSDP2 models fail clearly until DCP support is implemented.""" + from nvalchemi.training import _checkpoint + + monkeypatch.setattr(_checkpoint, "_is_fsdp_wrapped", lambda module: True) + + with pytest.raises( + NotImplementedError, + match="torch.distributed.checkpoint", + ): + _checkpoint._checkpoint_model(demo_model) diff --git a/test/training/test_ddp_hook.py b/test/training/test_ddp_hook.py new file mode 100644 index 00000000..be69669f --- /dev/null +++ b/test/training/test_ddp_hook.py @@ -0,0 +1,649 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for DDPHook and distributed manager integration.""" + +from __future__ import annotations + +import os +import queue +import socket +from enum import Enum +from typing import Any, Sequence + +import pytest +import torch +from torch import distributed as dist +from torch.utils.data import ( + BatchSampler, + DataLoader, + DistributedSampler, + Sampler, + SequentialSampler, +) + +from nvalchemi.data.atomic_data import AtomicData +from nvalchemi.hooks._context import HookContext, TrainContext +from nvalchemi.training import TrainingStage +from nvalchemi.training.hooks import DDPHook +from nvalchemi.training.strategy import TrainingStrategy +from test.training.conftest import ( + _build_baseline_strategy_kwargs, + _build_batch, + _build_dataset, +) + + +class _FakeManager: + """Structural distributed manager used by hook tests.""" + + def __init__(self, *, world_size: int = 2, rank: int = 0) -> None: + self.world_size = world_size + self.rank = rank + self.global_rank = rank + self.local_rank = rank + self.initialized = world_size > 1 + self.device = torch.device("cpu") + self.broadcast_buffers = False + self.find_unused_parameters = False + + def is_initialized(self) -> bool: + return self.initialized + + +class _FakeDDP(torch.nn.Module): + """Small DDP stand-in that records constructor kwargs.""" + + calls: list[dict[str, Any]] = [] + + def __init__(self, module: torch.nn.Module, **kwargs: Any) -> None: + super().__init__() + self.module = module + self.kwargs = kwargs + type(self).calls.append(kwargs) + + def forward(self, *args: Any, **kwargs: 
Any) -> Any: + return self.module(*args, **kwargs) + + +class _CustomDistributedSampler(Sampler[int]): + """Sampler with non-DistributedSampler constructor argument names.""" + + def __init__( + self, + data_source: Any, + *, + shards: int, + position: int, + token: object, + ) -> None: + self.data_source = data_source + self.shards = shards + self.position = position + self.token = token + + def __iter__(self) -> Any: + return iter(range(self.position, len(self.data_source), self.shards)) + + def __len__(self) -> int: + return len(range(self.position, len(self.data_source), self.shards)) + + +class _TorchKeywordDistributedSampler(Sampler[int]): + """Sampler that follows PyTorch DistributedSampler constructor keywords.""" + + def __init__( + self, + data_source: Any, + *, + num_replicas: int, + rank: int, + shuffle: bool = False, + seed: int = 0, + drop_last: bool = False, + ) -> None: + self.data_source = data_source + self.num_replicas = num_replicas + self.rank = rank + self.shuffle = shuffle + self.seed = seed + self.drop_last = drop_last + self.epoch = 0 + + def __iter__(self) -> Any: + return iter(range(self.rank, len(self.data_source), self.num_replicas)) + + def __len__(self) -> int: + return len(range(self.rank, len(self.data_source), self.num_replicas)) + + def set_epoch(self, epoch: int) -> None: + """Record the sampler epoch.""" + self.epoch = epoch + + +class _MutableSamplerDataloader: + """Minimal dataloader-like object with a mutable sampler attribute.""" + + def __init__( + self, + dataset: Any, + *, + sampler: Any | None = None, + drop_last: bool = False, + ) -> None: + self.dataset = dataset + self.sampler = sampler + self.drop_last = drop_last + + +class _ContextCaptureHook: + """Capture contexts observed at a given stage.""" + + frequency = 1 + + def __init__(self, stage: TrainingStage) -> None: + self.stage = stage + self.contexts: list[TrainContext] = [] + + def __call__(self, ctx: HookContext, stage: Enum) -> None: + assert isinstance(ctx, 
TrainContext) + self.contexts.append(ctx) + + +class _OptimizerParamHook: + """Assert optimizers are constructed after DDP wrapping.""" + + frequency = 1 + stage = TrainingStage.BEFORE_TRAINING + + def __init__(self) -> None: + self.saw_wrapped_model = False + + def __call__(self, ctx: HookContext, stage: Enum) -> None: + assert isinstance(ctx, TrainContext) + assert ctx.models is not None + model = ctx.models["main"] + self.saw_wrapped_model = isinstance(model, _FakeDDP) + model_param_ids = {id(param) for param in model.parameters()} + optimizer_param_ids = { + id(param) + for optimizer in ctx.optimizers + for group in optimizer.param_groups + for param in group["params"] + } + assert optimizer_param_ids <= model_param_ids + + +class _Reader: + """Minimal datapipe reader for sampler mutation tests.""" + + def __len__(self) -> int: + return 4 + + def _load_sample(self, index: int) -> dict[str, torch.Tensor]: + return { + "positions": torch.zeros(1, 3), + "atomic_numbers": torch.ones(1, dtype=torch.long), + "atomic_masses": torch.ones(1), + } + + def _get_sample_metadata(self, index: int) -> dict[str, Any]: + return {} + + def read_many( + self, indices: Sequence[int] + ) -> list[tuple[dict[str, torch.Tensor], dict[str, Any]]]: + """Load multiple samples and metadata records.""" + return [ + (self._load_sample(index), self._get_sample_metadata(index)) + for index in indices + ] + + def close(self) -> None: + pass + + +def _make_strategy(**overrides: Any) -> TrainingStrategy: + """Build a baseline TrainingStrategy with local overrides.""" + kwargs = _build_baseline_strategy_kwargs() + if "num_steps" in overrides and "num_epochs" not in overrides: + kwargs["num_epochs"] = None + kwargs.update(overrides) + return TrainingStrategy(**kwargs) + + +def _free_port() -> int: + """Return an available localhost TCP port for process-group setup.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) 
+ + +def _state_dict_cpu(strategy: TrainingStrategy) -> dict[str, torch.Tensor]: + """Return the main model state dict detached on CPU.""" + return { + key: value.detach().cpu().clone() + for key, value in strategy.models["main"].state_dict().items() + } + + +def _run_ddp_worker( + rank: int, + world_size: int, + port: int, + result_queue: Any, +) -> None: + """Run one CPU DDP training step and send final parameters to the parent.""" + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": str(port), + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank), + } + ) + strategy = _make_strategy( + hooks=[DDPHook(backend="gloo", find_unused_parameters=True)], + num_steps=1, + ) + strategy.run([_build_batch(n_systems=1, seed=5)]) + result_queue.put( + ( + rank, + {key: value.tolist() for key, value in _state_dict_cpu(strategy).items()}, + ) + ) + + +@pytest.fixture(autouse=True) +def _reset_fake_ddp() -> None: + """Reset fake DDP call history before every test.""" + _FakeDDP.calls.clear() + + +class TestDistributedManagerField: + def test_nvalchemi_distributed_reexports_physicsnemo_manager(self) -> None: + from physicsnemo.distributed import DistributedManager as PhysicsNeMoManager + + from nvalchemi.distributed import DistributedManager + + assert DistributedManager is PhysicsNeMoManager + + def test_resolves_rank_and_world_size_from_environment( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from nvalchemi import distributed + + class _UninitializedManager: + @classmethod + def is_initialized(cls) -> bool: + return False + + monkeypatch.setattr(distributed, "DistributedManager", _UninitializedManager) + monkeypatch.setenv("RANK", "3") + monkeypatch.setenv("WORLD_SIZE", "8") + + assert distributed.resolve_global_rank() == 3 + assert distributed.resolve_world_size() == 8 + + def test_explicit_rank_overrides_runtime_state(self) -> None: + from nvalchemi.distributed import resolve_global_rank + + assert resolve_global_rank(5) 
== 5 + + def test_manager_is_runtime_only_and_visible_to_context(self) -> None: + manager = _FakeManager(world_size=1) + capture = _ContextCaptureHook(TrainingStage.BEFORE_BATCH) + strategy = _make_strategy( + distributed_manager=manager, + hooks=[capture], + num_steps=1, + ) + + assert "distributed_manager" not in strategy.to_spec_dict() + strategy.run([_build_batch()]) + + assert capture.contexts + assert capture.contexts[0].workflow.distributed_manager is manager + assert capture.contexts[0].global_rank == manager.rank + + +class TestDDPHookWrapping: + def test_wraps_before_optimizer_construction_and_restores( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr(torch.nn.parallel, "DistributedDataParallel", _FakeDDP) + ddp = DDPHook(find_unused_parameters=True, broadcast_buffers=False) + recorder = _OptimizerParamHook() + strategy = _make_strategy( + distributed_manager=_FakeManager(), + hooks=[ddp, recorder], + num_steps=1, + ) + original = strategy.models["main"] + + strategy.run([_build_batch()]) + + assert recorder.saw_wrapped_model + assert strategy.models["main"] is original + assert _FakeDDP.calls == [ + { + "find_unused_parameters": True, + "broadcast_buffers": False, + "static_graph": False, + } + ] + + def test_defaults_to_manager_ddp_flags( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr(torch.nn.parallel, "DistributedDataParallel", _FakeDDP) + manager = _FakeManager() + manager.find_unused_parameters = True + manager.broadcast_buffers = True + strategy = _make_strategy( + distributed_manager=manager, + hooks=[DDPHook()], + num_steps=1, + ) + + strategy.run([_build_batch()]) + + assert _FakeDDP.calls == [ + { + "find_unused_parameters": True, + "broadcast_buffers": True, + "static_graph": False, + } + ] + + def test_unknown_model_key_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(torch.nn.parallel, "DistributedDataParallel", _FakeDDP) + strategy = _make_strategy( + 
distributed_manager=_FakeManager(), + hooks=[DDPHook(model_keys=("missing",))], + num_steps=1, + ) + + with pytest.raises(KeyError, match="unknown model"): + strategy.run([_build_batch()]) + + +class TestDDPHookDataloaderMutation: + def test_sets_sampler_on_generic_dataloader_with_sampler_attribute(self) -> None: + hook = DDPHook() + hook._manager = _FakeManager(rank=1) + loader = _MutableSamplerDataloader(list(range(8)), drop_last=True) + + prepared = hook.prepare_dataloader(loader) + + assert prepared is loader + assert isinstance(loader.sampler, DistributedSampler) + assert loader.sampler.rank == 1 + assert loader.sampler.num_replicas == 2 + assert loader.sampler.drop_last is True + + def test_strategy_setup_uses_workflow_dataloader( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setattr(torch.nn.parallel, "DistributedDataParallel", _FakeDDP) + manager = _FakeManager(rank=1) + loader = DataLoader( + _build_dataset(n_batches=4), + batch_size=1, + shuffle=True, + collate_fn=lambda x: x[0], + ) + strategy = _make_strategy( + distributed_manager=manager, + hooks=[DDPHook()], + num_steps=1, + ) + + strategy.run(loader) + + assert strategy.active_dataloader is not loader + assert isinstance(strategy.active_dataloader.sampler, DistributedSampler) + assert strategy.active_dataloader.sampler.rank == manager.rank + + def test_replaces_torch_dataloader_sampler(self) -> None: + hook = DDPHook() + hook._manager = _FakeManager(rank=1) + dataset = list(range(8)) + loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0) + + prepared = hook.prepare_dataloader(loader) + + assert prepared is not loader + assert isinstance(prepared, DataLoader) + assert isinstance(prepared.sampler, DistributedSampler) + assert prepared.sampler.rank == 1 + assert prepared.sampler.num_replicas == 2 + assert prepared.sampler.shuffle is False + + def test_sampler_kwargs_override_default_sampler_args(self) -> None: + hook = DDPHook( + sampler_kwargs={ + "shuffle": 
True, + "seed": 17, + "drop_last": True, + } + ) + hook._manager = _FakeManager(rank=1) + dataset = list(range(8)) + loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0) + + prepared = hook.prepare_dataloader(loader) + + assert isinstance(prepared, DataLoader) + assert isinstance(prepared.sampler, DistributedSampler) + assert prepared.sampler.shuffle is True + assert prepared.sampler.seed == 17 + assert prepared.sampler.drop_last is True + + def test_uses_custom_sampler_cls_and_kwargs(self) -> None: + token = object() + hook = DDPHook( + sampler_cls=_CustomDistributedSampler, + sampler_kwargs={ + "shards": 4, + "position": 2, + "token": token, + }, + ) + hook._manager = _FakeManager(rank=1) + dataset = list(range(8)) + loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0) + + prepared = hook.prepare_dataloader(loader) + + assert isinstance(prepared, DataLoader) + assert isinstance(prepared.sampler, _CustomDistributedSampler) + assert prepared.sampler.shards == 4 + assert prepared.sampler.position == 2 + assert prepared.sampler.token is token + + def test_keeps_existing_custom_sampler(self) -> None: + token = object() + hook = DDPHook( + sampler_cls=_CustomDistributedSampler, + sampler_kwargs={ + "shards": 2, + "position": 1, + "token": token, + }, + ) + hook._manager = _FakeManager() + dataset = list(range(8)) + sampler = _CustomDistributedSampler( + dataset, + shards=2, + position=1, + token=token, + ) + loader = DataLoader(dataset, batch_size=2, sampler=sampler) + + prepared = hook.prepare_dataloader(loader) + + assert prepared is loader + assert prepared.sampler is sampler + + def test_keeps_existing_protocol_distributed_sampler(self) -> None: + hook = DDPHook() + hook._manager = _FakeManager(rank=1) + dataset = list(range(8)) + sampler = _TorchKeywordDistributedSampler( + dataset, + num_replicas=2, + rank=1, + ) + loader = DataLoader(dataset, batch_size=2, sampler=sampler) + + prepared = hook.prepare_dataloader(loader) + + 
assert prepared is loader + assert prepared.sampler is sampler + + def test_injects_defaults_into_protocol_compatible_sampler_cls(self) -> None: + hook = DDPHook( + sampler_cls=_TorchKeywordDistributedSampler, + sampler_kwargs={"seed": 23}, + ) + hook._manager = _FakeManager(rank=1) + dataset = list(range(8)) + loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0) + + prepared = hook.prepare_dataloader(loader) + + assert isinstance(prepared.sampler, _TorchKeywordDistributedSampler) + assert prepared.sampler.num_replicas == 2 + assert prepared.sampler.rank == 1 + assert prepared.sampler.shuffle is False + assert prepared.sampler.seed == 23 + assert prepared.sampler.drop_last is False + + def test_keeps_existing_distributed_sampler(self) -> None: + hook = DDPHook() + hook._manager = _FakeManager() + dataset = list(range(8)) + sampler = DistributedSampler(dataset, num_replicas=2, rank=0) + loader = DataLoader(dataset, batch_size=2, sampler=sampler) + + prepared = hook.prepare_dataloader(loader) + + assert prepared is loader + assert prepared.sampler is sampler + + def test_rejects_custom_batch_sampler(self) -> None: + hook = DDPHook() + hook._manager = _FakeManager() + dataset = list(range(8)) + batch_sampler = BatchSampler( + SequentialSampler(dataset), + batch_size=2, + drop_last=False, + ) + loader = DataLoader(dataset, batch_sampler=batch_sampler) + + with pytest.raises(ValueError, match="batch_sampler"): + hook.prepare_dataloader(loader) + + def test_mutates_nvalchemi_datapipe_sampler(self) -> None: + from nvalchemi.data.datapipes.dataloader import DataLoader as NVCDataLoader + from nvalchemi.data.datapipes.dataset import Dataset + + hook = DDPHook() + hook._manager = _FakeManager(rank=1) + dataset = Dataset(_Reader(), device="cpu") + loader = NVCDataLoader(dataset, batch_size=2, use_streams=False) + + prepared = hook.prepare_dataloader(loader) + + assert prepared is loader + assert isinstance(loader.sampler, DistributedSampler) + assert 
loader.sampler.rank == 1 + + +def test_single_process_ddp_hook_is_noop(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(torch.nn.parallel, "DistributedDataParallel", _FakeDDP) + strategy = _make_strategy( + distributed_manager=_FakeManager(world_size=1), + hooks=[DDPHook()], + num_steps=1, + ) + + strategy.run([_build_batch()]) + + assert _FakeDDP.calls == [] + + +def test_torch_distributed_sampler_epoch_is_preserved() -> None: + hook = DDPHook() + hook._manager = _FakeManager() + dataset = _build_dataset(n_batches=4) + loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x[0]) + strategy = _make_strategy(num_steps=1) + + prepared = hook.prepare_dataloader(loader) + assert isinstance(prepared.sampler, DistributedSampler) + strategy._set_sampler_epoch(prepared) + + assert prepared.sampler.epoch == 0 + + +def test_reader_protocol_builds_atomic_data() -> None: + reader = _Reader() + sample = AtomicData(**reader._load_sample(0)) + assert sample.positions.shape == (1, 3) + + +@pytest.mark.skipif(not dist.is_gloo_available(), reason="gloo backend required") +def test_two_process_cpu_ddp_matches_single_process_baseline() -> None: + baseline = _make_strategy(num_steps=1) + baseline.run([_build_batch(n_systems=1, seed=5)]) + expected = _state_dict_cpu(baseline) + + ctx = torch.multiprocessing.get_context("spawn") + result_queue = ctx.Queue() + port = _free_port() + procs = [ + ctx.Process( + target=_run_ddp_worker, + args=(rank, 2, port, result_queue), + ) + for rank in range(2) + ] + for proc in procs: + proc.start() + for proc in procs: + proc.join(timeout=30) + for proc in procs: + assert proc.exitcode == 0 + + results: dict[int, dict[str, Any]] = {} + for _ in range(2): + rank, state = result_queue.get(timeout=5) + results[rank] = state + + assert set(results) == {0, 1} + for state in results.values(): + for key, expected_value in expected.items(): + actual = torch.as_tensor(state[key], dtype=expected_value.dtype) + assert 
torch.allclose(actual, expected_value, atol=1e-6, rtol=1e-6) + + try: + result_queue.close() + except (AttributeError, OSError, queue.Empty): + pass diff --git a/test/training/test_ema_hook.py b/test/training/test_ema_hook.py new file mode 100644 index 00000000..19edeaac --- /dev/null +++ b/test/training/test_ema_hook.py @@ -0,0 +1,1004 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for :class:`nvalchemi.training.hooks.EMAHook`."""

from __future__ import annotations

from typing import Any
from unittest.mock import MagicMock, Mock

import pytest
import torch
from pydantic import ValidationError
from torch import nn
from torch.optim.swa_utils import AveragedModel

from nvalchemi.hooks._context import TrainContext
from nvalchemi.training._stages import TrainingStage
from nvalchemi.training._validation import ValidationConfig
from nvalchemi.training.hooks import EMAHook, TrainingUpdateHook
from nvalchemi.training.strategy import TrainingStrategy
from test.training.conftest import _build_baseline_strategy_kwargs, _build_batch

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_linear(
    in_f: int = 4, out_f: int = 4, *, seed: int | None = None
) -> nn.Linear:
    """Build ``nn.Linear(in_f, out_f)``, seeding torch's global RNG first if asked.

    When ``seed`` is ``None`` the RNG is left untouched, so the layer's
    initial weights depend on whatever RNG state the caller already has.
    """
    if seed is None:
        return nn.Linear(in_f, out_f)
    torch.manual_seed(seed)
    return nn.Linear(in_f, out_f)


def _make_ctx(
    models: dict[str, nn.Module],
    step_count: int,
    *,
    optimizers: list[Any] | None = None,
) -> Mock:
    """Return a ``Mock`` specced on ``TrainContext`` carrying the given fields.

    ``optimizers`` falls back to a fresh empty list when omitted, and
    ``loss`` is always set to ``None``.
    """
    optimizer_list = [] if optimizers is None else optimizers
    return Mock(
        spec=TrainContext,
        models=models,
        step_count=step_count,
        optimizers=optimizer_list,
        loss=None,
    )


def _params_equal(a: nn.Module, b: nn.Module) -> bool:
    """Return ``True`` iff both modules have the same parameters, pairwise.

    Parameters are compared in ``parameters()`` order with ``torch.equal``;
    a differing parameter count is reported as ``False`` rather than raised.
    """
    left = list(a.parameters())
    right = list(b.parameters())
    if len(left) != len(right):
        return False
    for lhs, rhs in zip(left, right, strict=True):
        if not torch.equal(lhs, rhs):
            return False
    return True


def _clone_state(model: nn.Module) -> dict[str, torch.Tensor]:
    """Snapshot ``model.state_dict()`` as detached, cloned tensors."""
    snapshot: dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        snapshot[name] = tensor.detach().clone()
    return snapshot


def _drive(
    hook: EMAHook,
    source: nn.Module,
    *,
    n_calls: int,
    start_step_count: int = 0,
) -> None:
    """Invoke ``hook`` ``n_calls`` times at ``AFTER_OPTIMIZER_STEP``.

    Each call gets a fresh context whose ``step_count`` walks from
    ``start_step_count`` up to ``start_step_count + n_calls - 1`` inclusive,
    with ``source`` registered under the ``"main"`` model key.
    """
    for offset in range(n_calls):
        step = start_step_count + offset
        hook(
            _make_ctx({"main": source}, step_count=step),
            TrainingStage.AFTER_OPTIMIZER_STEP,
        )


def _initialized_hook_and_state(
    *,
    seed: int = 0,
    decay: float = 0.5,
) -> tuple[nn.Module, EMAHook, dict[str, Any]]:
    """Return ``(source, hook, hook.state_dict())`` after one hook call.

    The source is a seeded linear layer; the hook is keyed on ``"main"`` and
    is called once at ``AFTER_OPTIMIZER_STEP`` with ``step_count=0`` before
    its state dict is captured.
    """
    source = _make_linear(seed=seed)
    hook = EMAHook(model_key="main", decay=decay)
    hook(
        _make_ctx({"main": source}, step_count=0),
        TrainingStage.AFTER_OPTIMIZER_STEP,
    )
    return source, hook, hook.state_dict()


class _VetoFirstOptimizerStepHook(TrainingUpdateHook):
    """Update hook that vetoes the first optimizer step and allows the rest."""

    priority = 10

    def __init__(self) -> None:
        # Number of DO_OPTIMIZER_STEP invocations seen so far.
        self.optimizer_step_calls = 0

    def __call__(
        self,
        ctx: TrainContext,
        stage: TrainingStage,
        will_skip: bool,
    ) -> tuple[bool, torch.Tensor | None]:
        """Return ``(flag, ctx.loss)``; the flag is ``False`` only once.

        Stages other than ``DO_OPTIMIZER_STEP`` always get ``True``. On
        ``DO_OPTIMIZER_STEP`` the call counter is bumped and the flag is
        ``True`` from the second such call onward.
        """
        if stage is not TrainingStage.DO_OPTIMIZER_STEP:
            return True, ctx.loss
        self.optimizer_step_calls += 1
        allow_step = self.optimizer_step_calls > 1
        return allow_step, ctx.loss


class _CudaBufferResetOnDeepcopy(nn.Module):
    """Fixture whose deepcopy drops the registered buffer back onto CPU.

    ``AveragedModel`` builds its EMA state by deep-copying the source
    ``nn.Module``. Some generated or monkey-patched modules can rebuild
    their registered buffers on CPU during that copy even though the live
    training module sits on CUDA. This class reproduces that failure mode
    directly, so EMA tests check device repair against a real module copy
    instead of a bare tensor dictionary.
    """

    def __init__(
        self,
        parameter_device: torch.device,
        buffer_device: torch.device,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones((), device=parameter_device))
        self.register_buffer("constant", torch.ones((), device=buffer_device))

    def __deepcopy__(self, memo: dict[int, Any]) -> _CudaBufferResetOnDeepcopy:
        """Copy the weight onto its current device but rebuild the buffer on CPU."""
        cpu = torch.device("cpu")
        duplicate = type(self)(self.weight.device, cpu)
        with torch.no_grad():
            duplicate.weight.copy_(self.weight)
        memo[id(self)] = duplicate
        return duplicate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x * weight + constant``; needs weight and buffer co-located."""
        scaled = x * self.weight
        return scaled + self.constant


class _CudaBufferOnlyResetOnDeepcopy(nn.Module):
    """Parameter-free fixture whose deepcopy resets its only buffer to CPU.

    Not every valid ``nn.Module`` owns trainable parameters; wrappers,
    lookup tables, normalizers, or generated helper modules may keep all of
    their device-sensitive state in buffers. The deepcopy path here puts the
    buffer on CPU so tests confirm EMA device repair does not rely on
    finding a parameter first.
    """

    def __init__(self, buffer_device: torch.device) -> None:
        super().__init__()
        self.register_buffer("constant", torch.ones((), device=buffer_device))

    def __deepcopy__(self, memo: dict[int, Any]) -> _CudaBufferOnlyResetOnDeepcopy:
        """Return a new instance whose buffer is freshly created on CPU."""
        duplicate = type(self)(torch.device("cpu"))
        memo[id(self)] = duplicate
        return duplicate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x + constant``; the buffer must share the input's device."""
        return x + self.constant


class _MixedDeviceBufferOnDeepcopy(nn.Module):
    """Fixture that deliberately keeps a parameter and a buffer on different devices.

    The EMA copy is expected to follow each corresponding source tensor
    rather than collapsing the whole module onto the first parameter's
    device. A CUDA parameter next to a CPU buffer guards monkey-patched or
    third-party modules that intentionally keep side tables on the host
    while computing with device parameters.
    """

    def __init__(
        self,
        parameter_device: torch.device,
        buffer_device: torch.device,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones((), device=parameter_device))
        self.register_buffer("cpu_table", torch.ones((), device=buffer_device))

    def __deepcopy__(self, memo: dict[int, Any]) -> _MixedDeviceBufferOnDeepcopy:
        """Copy the weight onto its current device; the buffer is rebuilt on CPU."""
        cpu = torch.device("cpu")
        duplicate = type(self)(self.weight.device, cpu)
        with torch.no_grad():
            duplicate.weight.copy_(self.weight)
        memo[id(self)] = duplicate
        return duplicate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x * weight``; the ``cpu_table`` buffer is not read here."""
        return x * self.weight


class _Float64ResetOnDeepcopy(nn.Module):
    """Fixture whose deepcopy always rebuilds its tensors as ``float64``."""

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones((), dtype=dtype))
        self.register_buffer("constant", torch.ones((), dtype=dtype))

    def __deepcopy__(self, memo: dict[int, Any]) -> _Float64ResetOnDeepcopy:
        """Return a ``float64`` instance carrying this module's current values."""
        duplicate = type(self)(torch.float64)
        with torch.no_grad():
            duplicate.weight.copy_(self.weight)
            duplicate.constant.copy_(self.constant)
        memo[id(self)] = duplicate
        return duplicate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x * weight + constant``; needs matching weight/buffer dtypes."""
        scaled = x * self.weight
        return scaled + self.constant


def _cpu_averaged_state(state: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of ``state`` with averaged tensors moved to CPU.

    Only entries of ``state["averaged_model_state"]`` that are tensors are
    moved; non-tensor entries and all other top-level keys pass through.
    """
    moved: dict[str, Any] = {}
    for name, entry in state["averaged_model_state"].items():
        moved[name] = entry.cpu() if torch.is_tensor(entry) else entry
    result = dict(state)
    result["averaged_model_state"] = moved
    return result


# ---------------------------------------------------------------------------
# Construction & validation
#
# ---------------------------------------------------------------------------


class TestEMAHookConstruction:
    """Constructor defaults and field validation of ``EMAHook``."""

    def test_defaults(self) -> None:
        """A hook built with no arguments has the default values asserted here."""
        default_hook = EMAHook()
        assert default_hook.model_key == "main"
        assert default_hook.decay == pytest.approx(0.999)
        assert default_hook.update_every == 1
        assert default_hook.start_step == 0
        assert default_hook.use_buffers is True
        assert default_hook.num_updates == 0
        assert EMAHook.priority == 50
        assert isinstance(default_hook, TrainingUpdateHook)
        assert default_hook._averaged_model is None
        assert default_hook._pending_averaged_state is None

    @pytest.mark.parametrize(
        ("kwargs", "field"),
        [
            pytest.param({"decay": 1.0}, "decay", id="decay_eq_1_rejected"),
            pytest.param({"decay": -0.1}, "decay", id="decay_negative_rejected"),
            pytest.param(
                {"update_every": 0}, "update_every", id="update_every_zero_rejected"
            ),
            pytest.param(
                {"update_every": -1},
                "update_every",
                id="update_every_negative_rejected",
            ),
            pytest.param(
                {"start_step": -1}, "start_step", id="start_step_negative_rejected"
            ),
            pytest.param({"model_key": ""}, "model_key", id="model_key_empty_rejected"),
            pytest.param(
                {"model_key": " "}, "model_key", id="model_key_whitespace_rejected"
            ),
            pytest.param(
                {"num_updates": -1}, "num_updates", id="num_updates_negative_rejected"
            ),
        ],
    )
    def test_invalid_field_values_raise(
        self, kwargs: dict[str, Any], field: str
    ) -> None:
        """Each invalid value raises ``ValidationError`` naming the bad field."""
        with pytest.raises(ValidationError) as excinfo:
            EMAHook(**kwargs)
        # At least one reported error location must mention the offending field.
        locations = [err["loc"] for err in excinfo.value.errors()]
        assert any(field in loc for loc in locations)

    def test_extra_kwargs_rejected(self) -> None:
        """A misspelled keyword (``decya``) is rejected rather than ignored."""
        with pytest.raises(ValidationError):
            EMAHook(decya=0.9)


class TestEMAHookBuildOverride:
    """The ``_build_averaged_model`` seam lets a caller inject a copy."""

    def test_default_build_deepcopies_source(self) -> None:
        """After ``SETUP`` the averaged module is a separate, equal-weight copy."""
        source = _make_linear(seed=0)
        hook = EMAHook(model_key="main", decay=0.5)
        setup_ctx = _make_ctx({"main": source}, step_count=0)
        hook(setup_ctx, TrainingStage.SETUP)
        averaged_module = hook.get_averaged_model().module
        # Distinct object, but its weights mirror the source.
        assert averaged_module is not source
        assert _params_equal(averaged_module, source)

    def test_override_adopts_prebuilt_without_deepcopy(self) -> None:
        """A subclass returning a pre-built model has it adopted by identity."""
        source = _make_linear(seed=0)
        # Seeded differently from ``source`` on purpose: if the hook copied
        # ``source`` instead of adopting this object, the identity check fails.
        other_module = _make_linear(seed=1)
        prebuilt = AveragedModel(other_module, multi_avg_fn=None, use_buffers=True)

        class _InjectedEMAHook(EMAHook):
            def _build_averaged_model(self, src: nn.Module) -> AveragedModel:
                return prebuilt

        hook = _InjectedEMAHook(model_key="main", decay=0.5)
        setup_ctx = _make_ctx({"main": source}, step_count=0)
        hook(setup_ctx, TrainingStage.SETUP)
        # Same object as injected, so no copy was made.
        assert hook.get_averaged_model() is prebuilt


# ---------------------------------------------------------------------------
# Single-model update behavior
# ---------------------------------------------------------------------------


class TestEMAHookSingleModelUpdate:
    def setup_method(self) -> None:
        """Give every test a seeded source layer and a snapshot of its state."""
        self.source = _make_linear(seed=0)
        self.source_snapshot = _clone_state(self.source)

    def test_single_call_initializes_and_increments(self) -> None:
        """One ``AFTER_OPTIMIZER_STEP`` call builds the EMA and counts one update."""
        hook = EMAHook(model_key="main", decay=0.5)
        hook(
            _make_ctx({"main": self.source}, step_count=0),
            TrainingStage.AFTER_OPTIMIZER_STEP,
        )
        assert hook.num_updates == 1
        assert hook._averaged_model is not None
        # The source must be bit-identical to its pre-call snapshot.
        for name, tensor in self.source.state_dict().items():
            assert torch.equal(tensor, self.source_snapshot[name])

    def test_decay_zero_matches_source_after_one_update(self) -> None:
        """With ``decay=0.0`` the averaged parameters equal the source's."""
        hook = EMAHook(model_key="main", decay=0.0)
        hook(
            _make_ctx({"main": self.source}, step_count=0),
            TrainingStage.AFTER_OPTIMIZER_STEP,
        )
        averaged = hook.get_averaged_model().module
        paired = zip(
            self.source.named_parameters(),
            averaged.parameters(),
            strict=True,
        )
        for (n, p_src), p_avg in paired:
            torch.testing.assert_close(p_src, p_avg, msg=f"param {n} differs")

    def test_no_storage_sharing_with_source(self) -> None:
        """Mutating source after init must not change averaged params."""
        hook = EMAHook(model_key="main", decay=0.0)
        hook(
            _make_ctx({"main": self.source}, step_count=0),
            TrainingStage.AFTER_OPTIMIZER_STEP,
        )
        averaged = hook.get_averaged_model().module
        averaged_snapshot = _clone_state(averaged)
        paired = zip(self.source.parameters(), averaged.parameters(), strict=True)
        for p_src, p_avg in paired:
            # Neither the Python object nor the underlying storage is shared.
            assert id(p_src) != id(p_avg)
            assert p_src.data_ptr() != p_avg.data_ptr()
        with torch.no_grad():
            for param in self.source.parameters():
                param.add_(100.0)
        for name, tensor in averaged.state_dict().items():
            assert torch.equal(tensor, averaged_snapshot[name])
def test_setup_initializes_without_update(self) -> None: + hook = EMAHook(model_key="main") + ctx = _make_ctx({"main": self.source}, step_count=0) + + hook(ctx, TrainingStage.SETUP) + + assert hook.num_updates == 0 + assert hook._averaged_model is not None + + def test_non_update_stages_after_setup_do_not_update(self) -> None: + hook = EMAHook(model_key="main") + ctx = _make_ctx({"main": self.source}, step_count=0) + hook(ctx, TrainingStage.SETUP) + for stage in TrainingStage: + if stage in (TrainingStage.SETUP, TrainingStage.AFTER_OPTIMIZER_STEP): + continue + hook(ctx, stage) + assert hook.num_updates == 0 + + def test_get_averaged_model_before_init_raises(self) -> None: + hook = EMAHook(model_key="main") + with pytest.raises(RuntimeError, match="has not initialized"): + hook.get_averaged_model() + + +# --------------------------------------------------------------------------- +# model_key selection across multiple models +# --------------------------------------------------------------------------- + + +class TestEMAHookModelKeySelection: + def setup_method(self) -> None: + # Different shapes so we can assert structural identity. 
+ self.model_a = _make_linear(in_f=4, out_f=4, seed=0) + self.model_b = _make_linear(in_f=4, out_f=8, seed=1) + self.snapshot_a = _clone_state(self.model_a) + self.snapshot_b = _clone_state(self.model_b) + + def test_selects_only_intended_model(self) -> None: + hook = EMAHook(model_key="ema_target", decay=0.0) + ctx = _make_ctx( + {"main": self.model_a, "ema_target": self.model_b}, + step_count=0, + ) + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + for k, v in self.model_a.state_dict().items(): + assert torch.equal(v, self.snapshot_a[k]) + for k, v in self.model_b.state_dict().items(): + assert torch.equal(v, self.snapshot_b[k]) + + averaged = hook.get_averaged_model().module + assert averaged.weight.shape == self.model_b.weight.shape + torch.testing.assert_close(averaged.weight, self.model_b.weight) + torch.testing.assert_close(averaged.bias, self.model_b.bias) + + def test_unmatched_models_untouched(self) -> None: + hook = EMAHook(model_key="ema_target", decay=0.5) + ctx = _make_ctx( + {"main": self.model_a, "ema_target": self.model_b}, + step_count=0, + ) + a_param_ids_before = {id(p) for p in self.model_a.parameters()} + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + a_param_ids_after = {id(p) for p in self.model_a.parameters()} + assert a_param_ids_before == a_param_ids_after + for k, v in self.model_a.state_dict().items(): + assert torch.equal(v, self.snapshot_a[k]) + + def test_two_hooks_average_independently(self) -> None: + hook1 = EMAHook(model_key="m1", decay=0.0) + hook2 = EMAHook(model_key="m2", decay=0.0) + ctx = _make_ctx({"m1": self.model_a, "m2": self.model_b}, step_count=0) + hook1(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + hook2(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + avg1 = hook1.get_averaged_model() + avg2 = hook2.get_averaged_model() + assert avg1 is not avg2 + assert avg1.module.weight.shape == self.model_a.weight.shape + assert avg2.module.weight.shape == self.model_b.weight.shape + ids1 = {p.data_ptr() for p in avg1.parameters()} 
+ ids2 = {p.data_ptr() for p in avg2.parameters()} + assert ids1.isdisjoint(ids2) + assert hook1.num_updates == 1 + assert hook2.num_updates == 1 + + def test_missing_model_key_raises_keyerror(self) -> None: + hook = EMAHook(model_key="ghost") + ctx = _make_ctx({"main": self.model_a}, step_count=0) + with pytest.raises(KeyError) as excinfo: + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + msg = str(excinfo.value) + assert "'ghost'" in msg + assert "['main']" in msg + + +# --------------------------------------------------------------------------- +# Step filtering: update_every and start_step +# --------------------------------------------------------------------------- + + +class TestEMAHookStepFiltering: + def setup_method(self) -> None: + self.source = _make_linear(seed=0) + + def test_update_every_skips_intermediate_steps(self) -> None: + # step_count=0..6 => completed=1..7; multiples of 3 are 3, 6 => 2 updates. + hook = EMAHook(model_key="main", update_every=3) + _drive(hook, self.source, n_calls=7) + assert hook.num_updates == 2 + + def test_update_every_one_fires_every_step(self) -> None: + hook = EMAHook(model_key="main", update_every=1) + _drive(hook, self.source, n_calls=5) + assert hook.num_updates == 5 + + def test_start_step_delays_first_update(self) -> None: + hook = EMAHook(model_key="main", start_step=5, update_every=1) + # step_count=0..3 => completed=1..4 < 5: no-op. + _drive(hook, self.source, n_calls=4) + assert hook.num_updates == 0 + assert hook._averaged_model is None + # step_count=4 => completed=5: first update fires. + _drive(hook, self.source, n_calls=1, start_step_count=4) + assert hook.num_updates == 1 + assert hook._averaged_model is not None + # step_count=5..9 => completed=6..10: 5 more updates, total 6. 
+ _drive(hook, self.source, n_calls=5, start_step_count=5) + assert hook.num_updates == 6 + + def test_global_modulo_with_start_step_and_update_every(self) -> None: + """``update_every`` is a *global* modulo on completed_step, not relative to start_step.""" + hook = EMAHook(model_key="main", start_step=5, update_every=10) + # completed=1..15: only completed=10 is eligible. + _drive(hook, self.source, n_calls=15) + assert hook.num_updates == 1 + # completed=16..20: completed=20 is the next eligible step. + _drive(hook, self.source, n_calls=5, start_step_count=15) + assert hook.num_updates == 2 + + +# --------------------------------------------------------------------------- +# No mutation of grads / optimizer / scaler +# --------------------------------------------------------------------------- + + +class TestEMAHookSideEffects: + def test_gradients_and_optimizer_state_untouched(self) -> None: + source = _make_linear(seed=0) + x = torch.randn(2, 4) + target = torch.randn(2, 4) + loss = ((source(x) - target) ** 2).mean() + loss.backward() + grad_snapshots = { + n: p.grad.detach().clone() for n, p in source.named_parameters() + } + + optimizer_mock = MagicMock(spec=torch.optim.Optimizer) + hook = EMAHook(model_key="main") + ctx = _make_ctx({"main": source}, step_count=0, optimizers=[optimizer_mock]) + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + for n, p in source.named_parameters(): + torch.testing.assert_close(p.grad, grad_snapshots[n]) + # Optimizer mock was never called or method-accessed in any way. 
+ assert optimizer_mock.method_calls == [] + assert optimizer_mock.mock_calls == [] + + def test_amp_autocast_smoke(self) -> None: + """EMAHook runs without error under torch.amp.autocast (no AMP-API coupling).""" + source = _make_linear(seed=0) + hook = EMAHook(model_key="main", decay=0.5) + ctx = _make_ctx({"main": source}, step_count=0) + + with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16): + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + assert hook.num_updates == 1 + + def test_skipped_optimizer_step_does_not_update_ema(self) -> None: + source = _make_linear(seed=0) + hook = EMAHook(model_key="main", decay=0.5) + ctx = _make_ctx({"main": source}, step_count=0) + + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP, will_skip=True) + + assert hook.num_updates == 0 + assert hook._averaged_model is None + + def test_averaged_copy_follows_source_floating_dtypes(self) -> None: + source = _Float64ResetOnDeepcopy(torch.float32) + hook = EMAHook(model_key="main", decay=0.5) + + hook( + _make_ctx({"main": source}, step_count=0), + TrainingStage.AFTER_OPTIMIZER_STEP, + ) + with torch.no_grad(): + source.weight.fill_(3.0) + source.constant.fill_(5.0) + hook( + _make_ctx({"main": source}, step_count=1), + TrainingStage.AFTER_OPTIMIZER_STEP, + ) + + averaged = hook.get_averaged_model().module + assert averaged.weight.dtype is torch.float32 + assert averaged.constant.dtype is torch.float32 + out = averaged(torch.ones((), dtype=torch.float32)) + assert out.dtype is torch.float32 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_averaged_copy_and_state_restore_follow_source_tensor_devices( + self, + ) -> None: + device = torch.device("cuda:0") + source = _CudaBufferResetOnDeepcopy(device, device) + + # First prove the lazy AveragedModel construction path repairs the + # deepcopy artifact: the source buffer is CUDA, but this test module's + # __deepcopy__ reconstructs the averaged buffer on CPU. 
+ hook = EMAHook(model_key="main", decay=0.0) + ctx = _make_ctx({"main": source}, step_count=0) + ctx.workflow = object() + + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + averaged = hook.get_averaged_model().module + assert averaged.constant.device == device + out = averaged(torch.ones((), device=device)) + torch.testing.assert_close(out, torch.tensor(2.0, device=device)) + + # Simulate a checkpoint loaded on CPU before EMA has seen the live + # training model. load_state_dict must stash this as pending state, + # then first EMA update must build the averaged model and reapply the + # source tensor devices after loading that CPU state. + cpu_state = _cpu_averaged_state(hook.state_dict()) + restored = EMAHook(model_key="main", decay=0.0) + restored.load_state_dict(cpu_state) + restored_ctx = _make_ctx({"main": source}, step_count=1) + restored_ctx.workflow = object() + + restored(restored_ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + restored_averaged = restored.get_averaged_model().module + assert restored_averaged.constant.device == device + restored_out = restored_averaged(torch.ones((), device=device)) + torch.testing.assert_close(restored_out, torch.tensor(2.0, device=device)) + + # Also cover the already-initialized restore path. This is the branch + # used when an EMA hook has a live AveragedModel and then receives a + # checkpoint state whose tensors were materialized on CPU. 
+ initialized = EMAHook(model_key="main", decay=0.0) + initialized_ctx = _make_ctx({"main": source}, step_count=0) + initialized_ctx.workflow = object() + initialized(initialized_ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + initialized.load_state_dict(cpu_state) + + initialized_averaged = initialized.get_averaged_model().module + assert initialized_averaged.constant.device == device + initialized_out = initialized_averaged(torch.ones((), device=device)) + torch.testing.assert_close(initialized_out, torch.tensor(2.0, device=device)) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_buffer_only_averaged_copy_follows_source_device(self) -> None: + device = torch.device("cuda:0") + source = _CudaBufferOnlyResetOnDeepcopy(device) + hook = EMAHook(model_key="main", decay=0.0) + ctx = _make_ctx({"main": source}, step_count=0) + ctx.workflow = object() + + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + averaged = hook.get_averaged_model().module + assert averaged.constant.device == device + out = averaged(torch.ones((), device=device)) + torch.testing.assert_close(out, torch.tensor(2.0, device=device)) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_mixed_device_averaged_copy_preserves_source_buffer_device( + self, + ) -> None: + device = torch.device("cuda:0") + source = _MixedDeviceBufferOnDeepcopy(device, torch.device("cpu")) + hook = EMAHook(model_key="main", decay=0.0) + ctx = _make_ctx({"main": source}, step_count=0) + ctx.workflow = object() + + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + averaged = hook.get_averaged_model().module + assert averaged.weight.device == device + assert averaged.cpu_table.device == torch.device("cpu") + out = averaged(torch.ones((), device=device)) + torch.testing.assert_close(out, torch.tensor(1.0, device=device)) + + +class TestEMAHookStrategyIntegration: + def test_strategy_autowrap_updates_after_successful_optimizer_steps(self) -> None: + ema = 
EMAHook(model_key="main", decay=0.0) + veto_first = _VetoFirstOptimizerStepHook() + strategy = TrainingStrategy( + **{ + **_build_baseline_strategy_kwargs(), + "num_epochs": None, + "num_steps": 1, + "hooks": [veto_first, ema], + } + ) + + strategy.run([_build_batch(seed=0), _build_batch(seed=10)]) + + assert strategy.batch_count == 2 + assert strategy.step_count == 1 + assert veto_first.optimizer_step_calls == 2 + assert ema.num_updates == 1 + averaged = ema.get_averaged_model().module + for source_param, averaged_param in zip( + strategy.models["main"].parameters(), + averaged.parameters(), + strict=True, + ): + torch.testing.assert_close(averaged_param, source_param) + + +# --------------------------------------------------------------------------- +# Checkpointing: state_dict / load_state_dict +# --------------------------------------------------------------------------- + + +class TestEMAHookCheckpoint: + def test_state_dict_contains_config_and_averaged_state(self) -> None: + source = _make_linear(seed=0) + hook = EMAHook(model_key="main", decay=0.5, update_every=2, start_step=1) + ctx = _make_ctx({"main": source}, step_count=1) # completed=2 + hook(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + state = hook.state_dict() + assert { + "model_key", + "decay", + "update_every", + "start_step", + "use_buffers", + "num_updates", + } <= state.keys() + assert state["num_updates"] == 1 + assert "averaged_model_state" in state + assert isinstance(state["averaged_model_state"], dict) + + def test_round_trip_num_updates_and_weights(self) -> None: + source_a, hook_a, _ = _initialized_hook_and_state(seed=0, decay=0.5) + # Run a second update with perturbed source so EMA != source. + with torch.no_grad(): + for p in source_a.parameters(): + p.add_(0.5) + _drive(hook_a, source_a, n_calls=1, start_step_count=1) + assert hook_a.num_updates == 2 + state = hook_a.state_dict() + + # Build B with same config, init via a call on a different source, then load. 
+ source_b, hook_b, _ = _initialized_hook_and_state(seed=99, decay=0.5) + avg_a = hook_a.get_averaged_model().module + avg_b = hook_b.get_averaged_model().module + assert not _params_equal(avg_a, avg_b) + + hook_b.load_state_dict(state) + assert hook_b.num_updates == hook_a.num_updates + avg_b = hook_b.get_averaged_model().module + for k in avg_a.state_dict(): + torch.testing.assert_close(avg_b.state_dict()[k], avg_a.state_dict()[k]) + assert hook_b._pending_averaged_state is None + + def test_pending_state_applied_on_first_call(self) -> None: + """Pending weights must be loaded BEFORE the first update, not after.""" + decay = 0.5 + source_a, hook_a, state_a = _initialized_hook_and_state(seed=0, decay=decay) + loaded_pending = { + k: v.detach().clone() + for k, v in hook_a.get_averaged_model().module.state_dict().items() + } + + hook_b = EMAHook(model_key="main", decay=decay) + hook_b.load_state_dict(state_a) + assert hook_b._averaged_model is None + assert hook_b._pending_averaged_state is not None + + source_b = _make_linear(seed=99) + source_b_snapshot = _clone_state(source_b) + ctx_b = _make_ctx({"main": source_b}, step_count=10_000) + hook_b(ctx_b, TrainingStage.AFTER_OPTIMIZER_STEP) + + assert hook_b._averaged_model is not None + assert hook_b._pending_averaged_state is None + assert hook_b.num_updates == hook_a.num_updates + 1 + + # Verify avg = decay * pending + (1 - decay) * source_b on parameters. + # Buffers may use a different averaging rule when use_buffers=True. + averaged = hook_b.get_averaged_model().module + param_keys = {n for n, _ in averaged.named_parameters()} + avg_state = averaged.state_dict() + for key in param_keys: + expected = ( + decay * loaded_pending[key] + (1.0 - decay) * source_b_snapshot[key] + ) + torch.testing.assert_close( + avg_state[key], expected, msg=f"EMA formula mismatch on {key!r}" + ) + # If pending were ignored, the first AveragedModel update would copy + # source_b verbatim regardless of multi_avg_fn. 
+ for key in param_keys: + assert not torch.equal(avg_state[key], source_b_snapshot[key]) + + def test_save_before_init_emits_pending_state(self) -> None: + _, hook_a, state_a = _initialized_hook_and_state(seed=0, decay=0.5) + + hook_b = EMAHook(model_key="main", decay=0.5) + hook_b.load_state_dict(state_a) + state_b = hook_b.state_dict() + assert "averaged_model_state" in state_b + # Verify by content, not identity. + emitted = state_b["averaged_model_state"] + original = state_a["averaged_model_state"] + assert emitted.keys() == original.keys() + for k in emitted: + torch.testing.assert_close(emitted[k], original[k]) + + def test_partial_load_preserves_num_updates(self) -> None: + hook = EMAHook(model_key="main", decay=0.999) + hook.num_updates = 5 + hook.load_state_dict({"decay": 0.999}) + assert hook.num_updates == 5 + + def test_load_clears_averaged_state_when_absent(self) -> None: + _, hook_a, state_a = _initialized_hook_and_state(seed=0, decay=0.5) + + pending_hook = EMAHook(model_key="main", decay=0.5) + pending_hook.load_state_dict(state_a) + assert pending_hook._pending_averaged_state is not None + + # Subsequent load that omits averaged_model_state should clear pending state. 
+ pending_hook.load_state_dict({"decay": 0.5}) + assert pending_hook._averaged_model is None + assert pending_hook._pending_averaged_state is None + + _, initialized_hook, _ = _initialized_hook_and_state(seed=99, decay=0.5) + assert initialized_hook._averaged_model is not None + + initialized_hook.load_state_dict({"decay": 0.5}) + assert initialized_hook._averaged_model is None + assert initialized_hook._pending_averaged_state is None + + def test_config_conflict_raises_value_error_with_format(self) -> None: + hook = EMAHook(model_key="main", decay=0.999) + with pytest.raises(ValueError) as excinfo: + hook.load_state_dict({"decay": 0.9}) + msg = str(excinfo.value) + assert "EMAHook checkpoint conflict:" in msg + assert "decay=0.9" in msg + assert "constructor decay=0.999" in msg + assert "construct the hook with matching config" in msg + + def test_config_conflict_on_model_key(self) -> None: + hook = EMAHook(model_key="main") + with pytest.raises(ValueError, match="EMAHook checkpoint conflict: model_key="): + hook.load_state_dict({"model_key": "ema"}) + + def test_load_after_live_init_overwrites_weights(self) -> None: + _, hook_a, state_a = _initialized_hook_and_state(seed=0, decay=0.5) + _, hook_b, _ = _initialized_hook_and_state(seed=99, decay=0.5) + + avg_a = hook_a.get_averaged_model().module + avg_b = hook_b.get_averaged_model().module + assert not _params_equal(avg_a, avg_b) + + hook_b.load_state_dict(state_a) + avg_b = hook_b.get_averaged_model().module + for k in avg_a.state_dict(): + torch.testing.assert_close(avg_b.state_dict()[k], avg_a.state_dict()[k]) + assert hook_b._pending_averaged_state is None + + +# --------------------------------------------------------------------------- +# Inference-model write via set_inference_model (Phase C) +# --------------------------------------------------------------------------- + + +class TestInferenceModelWrite: + """EMAHook publishes averaged weights into the strategy inference_model slot.""" + + def 
test_single_model_publishes_bare_module(self) -> None: + """After eligible AFTER_OPTIMIZER_STEP, strategy.inference_model is a bare Module.""" + ema = EMAHook(model_key="main", decay=0.0) + strategy = TrainingStrategy( + **{ + **_build_baseline_strategy_kwargs(), + "num_epochs": None, + "num_steps": 1, + "hooks": [ema], + } + ) + assert strategy.inference_model is None + strategy.run([_build_batch(seed=0)]) + assert strategy.inference_model is not None + assert isinstance(strategy.inference_model, nn.Module) + assert not isinstance(strategy.inference_model, nn.ModuleDict) + averaged_module = ema.get_averaged_model().module + assert strategy.inference_model is averaged_module + + def test_two_hooks_produce_moduledict(self) -> None: + """Two EMA hooks with distinct model_keys produce an nn.ModuleDict.""" + model_a = _make_linear(in_f=4, out_f=4, seed=0) + model_b = _make_linear(in_f=4, out_f=4, seed=1) + + ema_a = EMAHook(model_key="m1", decay=0.0) + ema_b = EMAHook(model_key="m2", decay=0.0) + + # Use a lightweight workflow stub that has set_inference_model + # and single_model_input=False, avoiding full strategy construction. 
+ class _WorkflowStub: + single_model_input = False + inference_model: nn.Module | nn.ModuleDict | None = None + + def set_inference_model( + self, module: nn.Module, *, model_key: str | None = None + ) -> None: + if model_key is None or self.single_model_input: + self.inference_model = module + return + if not isinstance(self.inference_model, nn.ModuleDict): + self.inference_model = nn.ModuleDict() + self.inference_model[model_key] = module + + workflow = _WorkflowStub() + ctx = Mock( + spec=TrainContext, + models={"m1": model_a, "m2": model_b}, + step_count=0, + optimizers=[], + loss=None, + workflow=workflow, + ) + ema_a(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + ema_b(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + assert isinstance(workflow.inference_model, nn.ModuleDict) + assert "m1" in workflow.inference_model + assert "m2" in workflow.inference_model + assert workflow.inference_model["m1"] is ema_a.get_averaged_model().module + assert workflow.inference_model["m2"] is ema_b.get_averaged_model().module + + def test_setup_publishes_before_start_step_without_update(self) -> None: + """SETUP publishes an initial EMA model while start_step still gates updates.""" + ema = EMAHook(model_key="main", decay=0.0, start_step=100) + strategy = TrainingStrategy( + **{ + **_build_baseline_strategy_kwargs(), + "num_epochs": None, + "num_steps": 1, + "hooks": [ema], + } + ) + assert strategy.inference_model is None + strategy.run([_build_batch(seed=0)]) + # SETUP initializes the inference model; start_step still prevents updates. 
+ assert ema.num_updates == 0 + assert strategy.inference_model is ema.get_averaged_model().module + + def test_setup_materializes_pending_checkpoint_state(self) -> None: + """SETUP can publish restored EMA state before another train step.""" + source = _build_baseline_strategy_kwargs()["models"] + initialized = EMAHook(model_key="main", decay=0.0) + initialized( + _make_ctx({"main": source}, step_count=0), + TrainingStage.AFTER_OPTIMIZER_STEP, + ) + state = initialized.state_dict() + restored = EMAHook(model_key="main", decay=0.0) + restored.load_state_dict(state) + + strategy = TrainingStrategy( + **{ + **_build_baseline_strategy_kwargs(), + "num_epochs": None, + "num_steps": 1, + "hooks": [restored], + } + ) + strategy.validation_config = ValidationConfig( + validation_data=[_build_batch(seed=1)], + use_ema="always", + ) + + assert strategy.inference_model is None + assert restored._averaged_model is None + assert restored._pending_averaged_state is not None + + # Simulate a restored strategy that has already reached its target; + # run() still executes SETUP hooks, then returns before another train step. 
+ strategy.step_count = 1 + strategy.run([_build_batch(seed=1)]) + + assert strategy.inference_model is restored.get_averaged_model().module + assert restored._pending_averaged_state is None + + summary = strategy.validate() + + assert summary is not None + assert summary["model_source"] == "ema" + assert restored.num_updates == initialized.num_updates + assert strategy.inference_model is restored.get_averaged_model().module + + def test_no_crash_without_set_inference_model(self) -> None: + """EMAHook works when workflow lacks set_inference_model (defensive guard).""" + ema = EMAHook(model_key="main", decay=0.0) + source = _make_linear(seed=0) + ctx = Mock( + spec=TrainContext, + models={"main": source}, + step_count=0, + optimizers=[], + loss=None, + ) + # workflow with no set_inference_model attribute + ctx.workflow = object() + ema(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + assert ema.num_updates == 1 diff --git a/test/training/test_loss_schedules.py b/test/training/test_loss_schedules.py new file mode 100644 index 00000000..2993eae3 --- /dev/null +++ b/test/training/test_loss_schedules.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for :mod:`nvalchemi.training.losses.schedules`.""" + +from __future__ import annotations + +import json +from typing import Any + +import pytest +from pydantic import ValidationError + +from nvalchemi.training import ( + ConstantWeight, + CosineWeight, + LinearWeight, + LossWeightSchedule, + PiecewiseWeight, + create_model_spec, + create_model_spec_from_json, +) + + +class TestSchedules: + """Tests for the weight schedules, protocol, and BaseSpec round-trip.""" + + def test_protocol_runtime_check(self) -> None: + w = ConstantWeight(value=1.0) + assert isinstance(w, LossWeightSchedule) + assert w.per_epoch is False + + def test_constant_weight(self) -> None: + w = ConstantWeight(value=2.5) + assert w(0, 0) == 2.5 + assert w(100, 3) == 2.5 + assert w(100_000, 99) == 2.5 + + @pytest.mark.parametrize("cls", [LinearWeight, CosineWeight]) + def test_ramp_endpoints_and_clamp( + self, cls: type[LinearWeight | CosineWeight] + ) -> None: + w = cls(start=0.0, end=1.0, num_steps=10) + assert w(0, 0) == 0.0 + assert abs(w(10, 0) - 1.0) < 1e-6 + assert w(100, 0) == 1.0 + assert w(-5, 0) == 0.0 + + def test_linear_midpoint(self) -> None: + w = LinearWeight(start=0.0, end=1.0, num_steps=10) + assert abs(w(5, 0) - 0.5) < 1e-6 + + def test_cosine_midpoint(self) -> None: + w = CosineWeight(start=0.0, end=1.0, num_steps=10) + assert abs(w(5, 0) - 0.5) < 1e-6 + + @pytest.mark.parametrize("cls", [LinearWeight, CosineWeight]) + def test_per_epoch_ramps_use_epoch_counter( + self, cls: type[LinearWeight | CosineWeight] + ) -> None: + w = cls(start=0.0, end=1.0, num_steps=10, per_epoch=True) + assert w.per_epoch is True + assert w(step=10, epoch=0) == 0.0 + assert abs(w(step=0, epoch=5) - 0.5) < 1e-6 + assert w(step=0, epoch=10) == 1.0 + + @pytest.mark.parametrize( + "boundaries,values,step,expected", + [ + ((100,), (0.1, 0.9), 0, 0.1), + ((100,), (0.1, 0.9), 99, 0.1), + ((100,), (0.1, 0.9), 100, 0.9), + ((100,), (0.1, 0.9), 500, 0.9), + ((10, 20, 30), (0.0, 0.25, 0.5, 1.0), 5, 
0.0), + ((10, 20, 30), (0.0, 0.25, 0.5, 1.0), 10, 0.25), + ((10, 20, 30), (0.0, 0.25, 0.5, 1.0), 20, 0.5), + ((10, 20, 30), (0.0, 0.25, 0.5, 1.0), 30, 1.0), + ], + ) + def test_piecewise_weight( + self, + boundaries: tuple[int, ...], + values: tuple[float, ...], + step: int, + expected: float, + ) -> None: + w = PiecewiseWeight(boundaries=boundaries, values=values) + assert w(step, 0) == expected + + def test_per_epoch_piecewise_uses_epoch_counter(self) -> None: + w = PiecewiseWeight( + boundaries=(2,), + values=(0.0, 1.0), + per_epoch=True, + ) + assert w(step=100, epoch=1) == 0.0 + assert w(step=0, epoch=2) == 1.0 + + @pytest.mark.parametrize( + "cls,kwargs", + [ + (LinearWeight, {"start": 0.0, "end": 1.0, "num_steps": 0}), + (LinearWeight, {"start": 0.0, "end": 1.0, "num_steps": -3}), + (CosineWeight, {"start": 0.0, "end": 1.0, "num_steps": 0}), + ( + PiecewiseWeight, + {"boundaries": (10, 20), "values": (0.1, 0.5)}, + ), + ( + PiecewiseWeight, + {"boundaries": (10, 5), "values": (0.1, 0.5, 0.9)}, + ), + (PiecewiseWeight, {"boundaries": (-1,), "values": (0.1, 0.5)}), + ], + ) + def test_schedule_validators_reject_bad_input( + self, cls: type, kwargs: dict[str, Any] + ) -> None: + with pytest.raises(ValidationError): + cls(**kwargs) + + def test_schedule_frozen(self) -> None: + w = ConstantWeight(value=1.0) + with pytest.raises(ValidationError): + w.value = 2.0 # type: ignore[misc] + + def test_piecewise_hashable(self) -> None: + w = PiecewiseWeight(boundaries=(10, 20), values=(0.1, 0.5, 0.9)) + assert hash(w) == hash(w) + + @pytest.mark.parametrize( + "cls,kwargs", + [ + (ConstantWeight, {"value": 0.5}), + ( + LinearWeight, + {"start": 0.1, "end": 0.9, "num_steps": 100, "per_epoch": True}, + ), + (CosineWeight, {"start": 1.0, "end": 0.0, "num_steps": 50}), + ( + PiecewiseWeight, + {"boundaries": (10, 20), "values": (0.1, 0.5, 0.9)}, + ), + ], + ) + def test_schedule_basespec_roundtrip( + self, cls: type, kwargs: dict[str, Any] + ) -> None: + spec = 
create_model_spec(cls, **kwargs) + dumped = spec.model_dump_json() + rebuilt_spec = create_model_spec_from_json(json.loads(dumped)) + built = rebuilt_spec.build() + assert isinstance(built, cls) + for k, v in kwargs.items(): + assert getattr(built, k) == v + assert isinstance(built(5, 0), float) diff --git a/test/training/test_losses.py b/test/training/test_losses.py new file mode 100644 index 00000000..57837690 --- /dev/null +++ b/test/training/test_losses.py @@ -0,0 +1,2615 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import Any + +import pytest +import torch +from torch import nn + +from nvalchemi.training import ( + BaseLossFunction, + ComposedLossFunction, + ComposedLossOutput, + ConstantWeight, + EnergyHuberLoss, + EnergyMAELoss, + EnergyMSELoss, + ForceHuberLoss, + ForceL2NormLoss, + ForceMSELoss, + LinearWeight, + StressHuberLoss, + StressMSELoss, + loss_component_to_spec, +) +from nvalchemi.training._spec import create_model_spec, create_model_spec_from_json +from nvalchemi.training.losses import ( + assert_same_shape, + frobenius_mse, + per_graph_mean, + per_graph_sum, +) +from nvalchemi.training.losses.terms import _huber_loss + + +class _ToyLoss(BaseLossFunction): + # Concrete subclass returning a constant tensor — used in composition tests. + + def __init__(self, value: float = 1.0) -> None: + super().__init__() + self.value = float(value) + self.prediction_key = "prediction" + self.target_key = "target" + + def forward( + self, + pred: torch.Tensor, # noqa: ARG002 + target: torch.Tensor, # noqa: ARG002 + **kwargs: Any, # noqa: ARG002 + ) -> torch.Tensor: + return torch.tensor(self.value) + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, # noqa: ARG002 + ) -> torch.Tensor: + return pred - target + + +class _PositionsLoss(BaseLossFunction): + # Toy loss whose ``forward`` sums ``pred`` (gradient-bearing). 
+ + def __init__(self, scale: float = 1.0) -> None: + super().__init__() + self.scale = float(scale) + self.prediction_key = "positions" + self.target_key = "positions" + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, # noqa: ARG002 + **kwargs: Any, # noqa: ARG002 + ) -> torch.Tensor: + return self.scale * pred.sum() + + def compute_residual( + self, + pred: torch.Tensor, + target: torch.Tensor, + valid: torch.Tensor, # noqa: ARG002 + ) -> torch.Tensor: + return pred - target + + +class _ReturnSchedule: + # Schedule whose ``__call__`` returns a configurable value. + + per_epoch: bool = False + + def __init__(self, value: Any) -> None: + self.value = value + + def __call__(self, step: int, epoch: int) -> Any: # noqa: ARG002 + return self.value + + +def _dummy_loss_tensors() -> tuple[torch.Tensor, torch.Tensor]: + return torch.tensor(0.0), torch.tensor(0.0) + + +def _dummy_loss_mappings() -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + pred, target = _dummy_loss_tensors() + return {"prediction": pred}, {"target": target} + + +def _full_loss_batch() -> SimpleNamespace: + # Standard 3-graph layout covering energy + forces + stress. 
+ num_graphs = 3 + num_nodes = 6 + return SimpleNamespace( + batch_idx=torch.tensor([0, 0, 1, 1, 1, 2], dtype=torch.int32), + num_graphs=num_graphs, + num_nodes_per_graph=torch.tensor([2, 3, 1], dtype=torch.long), + energy=torch.tensor([[1.0], [2.0], [3.0]]), + predicted_energy=torch.tensor([[1.5], [2.5], [3.5]]), + forces=torch.zeros(num_nodes, 3), + predicted_forces=torch.ones(num_nodes, 3), + stress=torch.zeros(num_graphs, 3, 3), + predicted_stress=torch.ones(num_graphs, 3, 3), + ) + + +def _loss_metadata(batch: SimpleNamespace) -> dict[str, Any]: + return { + name: getattr(batch, name) + for name in ("batch_idx", "num_graphs", "num_nodes_per_graph") + if hasattr(batch, name) + } + + +def _tensor_mapping(batch: SimpleNamespace) -> dict[str, torch.Tensor]: + return { + name: value + for name, value in vars(batch).items() + if isinstance(value, torch.Tensor) + } + + +def _call_from_batch( + loss: BaseLossFunction | ComposedLossFunction, + batch: SimpleNamespace, + **metadata: Any, +) -> ComposedLossOutput: + composed = ( + loss + if isinstance(loss, ComposedLossFunction) + else ComposedLossFunction(components=(loss,)) + ) + tensors = _tensor_mapping(batch) + return composed(tensors, tensors, **(_loss_metadata(batch) | metadata)) + + +class TestReductions: + def setup_method(self) -> None: + # 3 graphs with 2, 3, 1 atoms respectively. 
+ self.batch_idx = torch.tensor([0, 0, 1, 1, 1, 2], dtype=torch.int32) + + def test_per_graph_sum_matches_manual(self) -> None: + vals = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + got = per_graph_sum(vals, self.batch_idx) + assert torch.allclose(got, torch.tensor([3.0, 12.0, 6.0])) + + def test_per_graph_sum_preserves_shape(self) -> None: + vals = torch.randn(6, 3, requires_grad=True) + got = per_graph_sum(vals, self.batch_idx) + assert got.shape == (3, 3) + assert got.grad_fn is not None + + def test_per_graph_sum_explicit_num_graphs_keeps_trailing_empty(self) -> None: + vals = torch.tensor([1.0, 2.0]) + batch_idx = torch.tensor([0, 0], dtype=torch.int32) + got = per_graph_sum(vals, batch_idx, num_graphs=3) + assert torch.allclose(got, torch.tensor([3.0, 0.0, 0.0])) + + def test_per_graph_mean_matches_manual(self) -> None: + vals = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + got = per_graph_mean(vals, self.batch_idx) + assert torch.allclose(got, torch.tensor([1.5, 4.0, 6.0])) + + def test_frobenius_mse_matches_manual(self) -> None: + pred = torch.tensor( + [ + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]], + ] + ) + target = torch.zeros(2, 3, 3) + # Identity * k contributes 3*k^2 nonzero entries; mean over 9. 
+ got = frobenius_mse(pred, target) + expected = torch.tensor([3.0 / 9.0, 12.0 / 9.0]) + assert torch.allclose(got, expected) + + def test_frobenius_mse_preserves_grad(self) -> None: + pred = torch.randn(2, 3, 3, requires_grad=True) + target = torch.randn(2, 3, 3) + got = frobenius_mse(pred, target) + assert got.grad_fn is not None + got.sum().backward() + assert pred.grad is not None + + def test_per_graph_sum_bad_num_graphs(self) -> None: + with pytest.raises(ValueError, match="num_graphs must be positive"): + per_graph_sum(torch.zeros(3), torch.zeros(3, dtype=torch.int32), 0) + + +class TestReductionsCompile: + @staticmethod + def _compile_kwargs(device: str) -> dict[str, Any]: + kwargs: dict[str, Any] = {"fullgraph": True} + if device == "cuda": + kwargs["backend"] = "cudagraphs" + return kwargs + + @staticmethod + def _batch_idx(device: str) -> torch.Tensor: + return torch.tensor([0, 0, 1, 1, 1, 2], dtype=torch.int32, device=device) + + @pytest.mark.parametrize( + ("fn", "args_factory"), + [ + pytest.param( + per_graph_sum, + lambda device, batch_idx: ( + torch.arange(18, dtype=torch.float32, device=device).reshape(6, 3), + batch_idx, + 3, + ), + id="per_graph_sum", + ), + pytest.param( + per_graph_mean, + lambda device, batch_idx: ( + torch.arange(18, dtype=torch.float32, device=device).reshape(6, 3), + batch_idx, + 3, + ), + id="per_graph_mean", + ), + pytest.param( + frobenius_mse, + lambda device, batch_idx: ( + torch.arange(18, dtype=torch.float32, device=device).reshape( + 2, 3, 3 + ), + torch.arange(18, dtype=torch.float32, device=device) + .reshape(2, 3, 3) + .flip(0), + ), + id="frobenius_mse", + ), + ], + ) + def test_reduction_compiles( + self, + fn: Any, + args_factory: Any, + device: str, + ) -> None: + batch_idx = self._batch_idx(device) + args = args_factory(device, batch_idx) + compiled = torch.compile(fn, **self._compile_kwargs(device)) + + got = compiled(*args) + expected = fn(*args) + + assert torch.allclose(got, expected) + + +class 
TestConcreteLossesCompile: + @staticmethod + def _compile_kwargs(device: str) -> dict[str, Any]: + kwargs: dict[str, Any] = {"fullgraph": True} + if device == "cuda": + kwargs["backend"] = "cudagraphs" + return kwargs + + def test_energy_mae_loss_compiles_fullgraph(self, device: str) -> None: + loss = EnergyMAELoss() + pred = torch.tensor([[6.0], [15.0], [8.0]], device=device) + target = torch.tensor([[3.0], [10.0], [4.0]], device=device) + counts = torch.tensor([3, 5, 2], dtype=torch.long, device=device) + + def fn( + pred: torch.Tensor, target: torch.Tensor, counts: torch.Tensor + ) -> torch.Tensor: + return loss(pred, target, num_nodes_per_graph=counts) + + compiled = torch.compile(fn, **self._compile_kwargs(device)) + torch.testing.assert_close( + compiled(pred, target, counts), fn(pred, target, counts) + ) + + def test_force_l2_norm_loss_dense_compiles_fullgraph(self, device: str) -> None: + loss = ForceL2NormLoss() + pred = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + [1.0, 1.0, 1.0], + [2.0, 0.0, 0.0], + ], + device=device, + ) + target = torch.zeros_like(pred) + batch_idx = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=device) + + def fn( + pred: torch.Tensor, target: torch.Tensor, batch_idx: torch.Tensor + ) -> torch.Tensor: + return loss(pred, target, batch_idx=batch_idx, num_graphs=2) + + compiled = torch.compile(fn, **self._compile_kwargs(device)) + torch.testing.assert_close( + compiled(pred, target, batch_idx), fn(pred, target, batch_idx) + ) + + def test_force_l2_norm_loss_padded_compiles_fullgraph(self, device: str) -> None: + loss = ForceL2NormLoss() + pred = torch.tensor( + [ + [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]], + [[1.0, 1.0, 1.0], [2.0, 0.0, 0.0], [99.0, 99.0, 99.0]], + ], + device=device, + ) + target = torch.zeros_like(pred) + counts = torch.tensor([3, 2], dtype=torch.long, device=device) + + def fn( + pred: torch.Tensor, target: torch.Tensor, counts: torch.Tensor + ) -> torch.Tensor: 
+ return loss(pred, target, num_nodes_per_graph=counts) + + compiled = torch.compile(fn, **self._compile_kwargs(device)) + torch.testing.assert_close( + compiled(pred, target, counts), fn(pred, target, counts) + ) + + +class TestBaseLossFunction: + # ``compute_residual(pred, target, valid)`` is the sole abstract method. + # ``forward`` is a concrete template orchestrating validate → normalize → + # mask → compute_residual → reduce. Weighting lives on + # :class:`ComposedLossFunction`. + + def test_baseloss_abstract_cannot_instantiate(self) -> None: + with pytest.raises(TypeError, match="abstract"): + BaseLossFunction() + + def test_baseloss_is_nn_module(self) -> None: + loss = _ToyLoss(value=1.0) + assert isinstance(loss, nn.Module) + + def test_baseloss_has_no_weight_attribute(self) -> None: + loss = _ToyLoss(value=3.0) + assert not hasattr(loss, "weight") + + def test_baseloss_forward_returns_raw_unweighted_tensor(self) -> None: + # Calling the module must return exactly what ``forward`` returns, + # with no weighting applied at the leaf. + loss = _ToyLoss(value=2.5) + assert torch.allclose(loss(*_dummy_loss_tensors()), torch.tensor(2.5)) + + def test_baseloss_forward_accepts_extra_kwargs(self) -> None: + # Leaves accept arbitrary kwargs so composed losses can forward + # shared graph metadata without forcing every component to consume it. + loss = _ToyLoss(value=4.0) + assert torch.allclose( + loss(*_dummy_loss_tensors(), unused_metadata=7), torch.tensor(4.0) + ) + + def test_baseloss_to_device_smoke(self) -> None: + # Stateless loss still supports ``.to()`` via nn.Module. 
+ loss = EnergyMSELoss() + moved = loss.to("meta") + assert isinstance(moved, nn.Module) + assert moved is loss # .to() is in-place for nn.Module + + def test_baseloss_state_dict_empty(self) -> None: + loss = EnergyMSELoss() + assert len(loss.state_dict()) == 0 + assert list(loss.parameters()) == [] + assert list(loss.buffers()) == [] + + +class TestLossRepr: + @pytest.mark.parametrize( + ("loss_factory", "class_name", "substrings"), + [ + pytest.param( + lambda: EnergyMSELoss(per_atom=True), + "EnergyMSELoss", + ( + "target_key='energy'", + "prediction_key='predicted_energy'", + "per_atom=True", + ), + id="energy", + ), + pytest.param( + lambda: EnergyHuberLoss(delta=0.5), + "EnergyHuberLoss", + ("target_key='energy'", "per_atom=True", "delta=0.5"), + id="energy_huber", + ), + pytest.param( + lambda: ForceMSELoss(normalize_by_atom_count=False), + "ForceMSELoss", + ("normalize_by_atom_count=False",), + id="force", + ), + pytest.param( + lambda: ForceHuberLoss(delta=0.5), + "ForceHuberLoss", + ( + "normalize_by_atom_count=False", + "delta=0.5", + ), + id="force_huber", + ), + pytest.param( + lambda: StressMSELoss(ignore_nonfinite=True), + "StressMSELoss", + ("target_key='stress'", "ignore_nonfinite=True"), + id="stress", + ), + pytest.param( + lambda: StressHuberLoss(delta=0.5), + "StressHuberLoss", + ("target_key='stress'", "ignore_nonfinite=True", "delta=0.5"), + id="stress_huber", + ), + ], + ) + def test_concrete_loss_repr_contains_hyperparameters( + self, + loss_factory: Any, + class_name: str, + substrings: tuple[str, ...], + ) -> None: + text = repr(loss_factory()) + assert class_name in text + for substring in substrings: + assert substring in text, (substring, text) + + def test_concrete_loss_repr_has_no_weight_attribute(self) -> None: + # Weight lives on the composition, not on leaves. 
+ for text in ( + repr(EnergyMSELoss()), + repr(EnergyHuberLoss()), + repr(ForceMSELoss()), + repr(ForceHuberLoss()), + repr(StressMSELoss()), + repr(StressHuberLoss()), + ): + assert "weight" not in text + + def test_composed_repr_shows_nested_components(self) -> None: + composed = EnergyMSELoss() + ForceMSELoss() + text = repr(composed) + assert "ComposedLossFunction" in text + assert "EnergyMSELoss" in text + assert "ForceMSELoss" in text + # nn.ModuleList numbers its children; "(0):" is the first entry. + assert "(0)" in text + + def test_composed_repr_includes_normalize_weights_flag(self) -> None: + text = repr(EnergyMSELoss() + ForceMSELoss()) + assert "normalize_weights=True" in text + text_off = repr( + ComposedLossFunction( + (EnergyMSELoss(), ForceMSELoss()), normalize_weights=False + ) + ) + assert "normalize_weights=False" in text_off + + def test_extra_repr_non_empty_on_concrete(self) -> None: + for loss in (EnergyMSELoss(), ForceMSELoss(), StressMSELoss()): + assert loss.extra_repr() != "" + + +class TestComposedLossFunction: + def setup_method(self) -> None: + self.loss_a = _ToyLoss(value=1.0) + self.loss_b = _ToyLoss(value=1.0) + self.loss_c = _ToyLoss(value=1.0) + + def test_add_two_losses(self) -> None: + composed = self.loss_a + self.loss_b + assert isinstance(composed, ComposedLossFunction) + assert tuple(composed.components) == (self.loss_a, self.loss_b) + + def test_composed_defaults_to_normalize_weights_true(self) -> None: + composed = ComposedLossFunction((EnergyMSELoss(), ForceMSELoss())) + assert composed.normalize_weights is True + # Defaults to all-1.0 weights → normalized to 1/N each. 
+ assert composed.current_weight() == [0.5, 0.5] + + def test_composed_default_weights_are_all_one(self) -> None: + composed = ComposedLossFunction( + (EnergyMSELoss(), ForceMSELoss()), normalize_weights=False + ) + assert composed._weights == [1.0, 1.0] + assert composed.current_weight() == [1.0, 1.0] + + def test_composed_is_nn_module(self) -> None: + composed = self.loss_a + self.loss_b + assert isinstance(composed, nn.Module) + modules = list(composed.modules()) + assert self.loss_a in modules + assert self.loss_b in modules + + def test_composed_components_stored_as_module_list(self) -> None: + composed = self.loss_a + self.loss_b + assert isinstance(composed.components, nn.ModuleList) + + @pytest.mark.parametrize( + "build", + [ + lambda a, b, c: (a + b) + c, + lambda a, b, c: a + (b + c), + ], + ids=["left_assoc", "right_assoc"], + ) + def test_nested_addition_flattens(self, build: Any) -> None: + composed = build(self.loss_a, self.loss_b, self.loss_c) + assert isinstance(composed, ComposedLossFunction) + assert len(composed.components) == 3 + assert all(not isinstance(c, ComposedLossFunction) for c in composed.components) + + def test_sum_over_list(self) -> None: + # sum() seeds with 0 → exercises __radd__. 
+ composed = sum([self.loss_a, self.loss_b, self.loss_c]) + assert isinstance(composed, ComposedLossFunction) + assert len(composed.components) == 3 + + def test_weights_length_must_match_components(self) -> None: + with pytest.raises(ValueError, match="weights has length"): + ComposedLossFunction((self.loss_a, self.loss_b), weights=[1.0, 2.0, 3.0]) + + def test_weights_reject_non_numeric(self) -> None: + with pytest.raises(TypeError, match="weights\\[0\\] must be"): + ComposedLossFunction( + (self.loss_a,), + weights=["not-a-weight"], # type: ignore[list-item] + ) + + def test_weights_none_entry_coerced_to_one(self) -> None: + composed = ComposedLossFunction( + (self.loss_a, self.loss_b), + weights=[None, 3.0], + normalize_weights=False, + ) + assert composed._weights == [1.0, 3.0] + + def test_normalize_weights_zero_sum_raises(self) -> None: + composed = ComposedLossFunction((self.loss_a, self.loss_b), weights=[0.0, 0.0]) + with pytest.raises( + ValueError, + match=( + r"sum is not strictly positive \(sum=0\.0\)\. " + r"Resolved weights at step=0, epoch=None: " + r"\{'_ToyLoss_0': 0\.0, '_ToyLoss_1': 0\.0\}" + ), + ): + composed.current_weight() + + def test_normalize_weights_nonzero_but_zero_sum_raises(self) -> None: + # [1, -1] sums to zero; the error message must reflect that the + # individual resolved weights were non-zero. + composed = ComposedLossFunction((self.loss_a, self.loss_b), weights=[1.0, -1.0]) + with pytest.raises( + ValueError, + match=( + r"sum is not strictly positive \(sum=0\.0\)\. " + r"Resolved weights at step=7, epoch=2: " + r"\{'_ToyLoss_0': 1\.0, '_ToyLoss_1': -1\.0\}" + ), + ): + composed.current_weight(step=7, epoch=2) + + def test_normalize_weights_negative_sum_raises(self) -> None: + # A negative raw sum would flip every effective weight's sign + # after normalization; reject it with the same "not strictly + # positive" error used for zero sums. 
+ composed = ComposedLossFunction((self.loss_a, self.loss_b), weights=[1.0, -3.0]) + with pytest.raises( + ValueError, match=r"sum is not strictly positive \(sum=-2\.0\)" + ): + composed.current_weight() + + def test_normalize_weights_false_returns_raw(self) -> None: + composed = ComposedLossFunction( + (self.loss_a, self.loss_b), + weights=[3.0, 2.0], + normalize_weights=False, + ) + assert composed.current_weight() == [3.0, 2.0] + + def test_weighted_sum_unnormalized_is_pure_weighted_sum(self) -> None: + a, b, v1, v2 = 2.0, 3.0, 5.0, 7.0 + comp1 = _ToyLoss(value=v1) + comp2 = _ToyLoss(value=v2) + composed = ComposedLossFunction( + (comp1, comp2), weights=[a, b], normalize_weights=False + ) + out = composed(*_dummy_loss_mappings(), step=0, epoch=0) + expected = a * v1 + b * v2 + assert torch.allclose(out["total_loss"], torch.tensor(expected), atol=1e-6) + + def test_per_component_unweighted_and_weight_populated(self) -> None: + comp1 = _ToyLoss(value=2.0) + comp2 = _ToyLoss(value=4.0) + composed = ComposedLossFunction( + (comp1, comp2), weights=[3.0, 2.0], normalize_weights=False + ) + out = composed(*_dummy_loss_mappings()) + assert set(out) == { + "total_loss", + "per_component_unweighted", + "per_component_weight", + "per_component_raw_weight", + "per_component_sample", + } + assert torch.allclose( + out["per_component_unweighted"]["_ToyLoss_0"], torch.tensor(2.0) + ) + assert torch.allclose( + out["per_component_unweighted"]["_ToyLoss_1"], torch.tensor(4.0) + ) + assert out["per_component_weight"] == { + "_ToyLoss_0": 3.0, + "_ToyLoss_1": 2.0, + } + # Without normalization raw and effective weights match. 
+ assert out["per_component_raw_weight"] == out["per_component_weight"] + assert torch.allclose(out["total_loss"], torch.tensor(14.0)) + + def test_per_component_weight_reflects_normalization(self) -> None: + composed = ComposedLossFunction( + (_ToyLoss(value=1.0), _ToyLoss(value=1.0)), + weights=[3.0, 2.0], + ) + out = composed(*_dummy_loss_mappings()) + assert out["per_component_weight"] == { + "_ToyLoss_0": 0.6, + "_ToyLoss_1": 0.4, + } + # Raw weights expose the pre-normalization values so a user + # logging a scheduled loss can observe the underlying ramp. + assert out["per_component_raw_weight"] == { + "_ToyLoss_0": 3.0, + "_ToyLoss_1": 2.0, + } + assert torch.allclose( + out["per_component_unweighted"]["_ToyLoss_0"], torch.tensor(1.0) + ) + assert torch.allclose( + out["per_component_unweighted"]["_ToyLoss_1"], torch.tensor(1.0) + ) + + def test_per_component_raw_weight_tracks_schedule_on_single_leaf(self) -> None: + # Single-component normalized composition: effective weight is + # always 1.0, so raw_weight is the only way to observe the + # underlying schedule ramp. 
+ schedule = LinearWeight(start=0.0, end=1.0, num_steps=10) + composed = schedule * _ToyLoss(value=1.0) + out_mid = composed(*_dummy_loss_mappings(), step=5) + assert out_mid["per_component_weight"] == {"_ToyLoss": 1.0} + assert out_mid["per_component_raw_weight"] == {"_ToyLoss": 0.5} + out_end = composed(*_dummy_loss_mappings(), step=10) + assert out_end["per_component_raw_weight"] == {"_ToyLoss": 1.0} + + def test_empty_components_raises(self) -> None: + with pytest.raises(ValueError, match="at least one"): + ComposedLossFunction(components=()) + + def test_non_loss_component_rejected(self) -> None: + with pytest.raises( + TypeError, + match="components\\[0\\] must be a BaseLossFunction or ComposedLossFunction", + ): + ComposedLossFunction( + components=("not-a-loss",), # type: ignore[arg-type] + ) + + def test_requires_eval_grad_true_when_any_component_requires(self) -> None: + class _GradLoss(_ToyLoss): + requires_eval_grad = True + + class _NoGradLoss(_ToyLoss): + requires_eval_grad = False + + composed = ComposedLossFunction((_NoGradLoss(), _GradLoss())) + assert composed.requires_eval_grad() is True + + def test_requires_eval_grad_false_when_all_components_disclaim(self) -> None: + class _NoGradLoss(_ToyLoss): + requires_eval_grad = False + + composed = ComposedLossFunction((_NoGradLoss(), _NoGradLoss())) + assert composed.requires_eval_grad() is False + + def test_requires_eval_grad_raises_on_undeclared_component(self) -> None: + class _UndeclaredLoss(_ToyLoss): + requires_eval_grad = None + + composed = ComposedLossFunction((_UndeclaredLoss(),)) + with pytest.raises(ValueError, match="infer whether"): + composed.requires_eval_grad() + + def test_gradient_flows_through_all_components(self) -> None: + positions = torch.randn(4, 3, requires_grad=True) + loss_a = _PositionsLoss(scale=2.0) + loss_b = _PositionsLoss(scale=3.0) + composed = ComposedLossFunction( + (loss_a, loss_b), weights=[1.0, 1.0], normalize_weights=False + ) + out = composed( + {"positions": 
positions}, + {"positions": torch.zeros_like(positions)}, + step=0, + epoch=0, + ) + out["total_loss"].backward() + # d/dx sum(x) = 1 per element; composed multiplier = 2 + 3 = 5. + expected_grad = torch.full_like(positions, 5.0) + assert positions.grad is not None + assert torch.allclose(positions.grad, expected_grad, atol=1e-6) + + def test_schedule_applied_inside_composition(self) -> None: + # A schedule attached to a component's slot in the composition is + # resolved once per call. + leaf = _ToyLoss(value=4.0) + composed = ComposedLossFunction( + (leaf,), weights=[ConstantWeight(value=2.5)], normalize_weights=False + ) + out = composed(*_dummy_loss_mappings(), step=0, epoch=0) + # 2.5 (schedule) * 4.0 (forward) = 10.0 + assert torch.allclose(out["total_loss"], torch.tensor(10.0), atol=1e-6) + + def test_schedule_counters_are_not_forwarded_to_leaf(self) -> None: + class _StrictLoss(BaseLossFunction): + prediction_key = "prediction" + target_key = "target" + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + return pred + target + + def compute_residual( + self, pred: torch.Tensor, target: torch.Tensor, valid: torch.Tensor + ) -> torch.Tensor: + return pred - target + + composed = ComposedLossFunction( + (_StrictLoss(),), + weights=[LinearWeight(start=0.0, end=1.0, num_steps=10)], + normalize_weights=False, + ) + out = composed( + {"prediction": torch.tensor(2.0)}, + {"target": torch.tensor(3.0)}, + step=5, + epoch=7, + ) + assert torch.allclose(out["total_loss"], torch.tensor(2.5), atol=1e-6) + + def test_linear_schedule_on_component_in_composition(self) -> None: + leaf = _ToyLoss(value=1.0) + composed = ComposedLossFunction( + (leaf,), + weights=[LinearWeight(start=0.0, end=1.0, num_steps=10)], + normalize_weights=False, + ) + assert torch.allclose( + composed(*_dummy_loss_mappings(), step=0)["total_loss"], + torch.tensor(0.0), + atol=1e-6, + ) + assert torch.allclose( + composed(*_dummy_loss_mappings(), step=10)["total_loss"], + 
torch.tensor(1.0), + atol=1e-6, + ) + assert torch.allclose( + composed(*_dummy_loss_mappings(), step=5)["total_loss"], + torch.tensor(0.5), + atol=1e-6, + ) + + def test_per_epoch_schedule_with_none_epoch_raises_in_composition(self) -> None: + leaf = _ToyLoss(value=1.0) + composed = ComposedLossFunction( + (leaf,), + weights=[LinearWeight(start=0.0, end=1.0, num_steps=10, per_epoch=True)], + ) + with pytest.raises(ValueError, match="per_epoch=True"): + composed(*_dummy_loss_mappings(), step=3, epoch=None) + + @pytest.mark.parametrize( + ("bad_value", "match"), + [ + pytest.param(float("nan"), r"non-finite weight nan", id="nan"), + pytest.param(float("inf"), r"non-finite weight inf", id="inf"), + pytest.param(float("-inf"), r"non-finite weight -inf", id="neg_inf"), + ], + ) + def test_schedule_non_finite_weight_raises( + self, bad_value: float, match: str + ) -> None: + composed = ComposedLossFunction( + (_ToyLoss(value=1.0),), + weights=[_ReturnSchedule(bad_value)], + normalize_weights=False, + ) + with pytest.raises(ValueError, match=match): + composed.current_weight(step=0, epoch=0) + + def test_schedule_non_numeric_weight_raises(self) -> None: + composed = ComposedLossFunction( + (_ToyLoss(value=1.0),), + weights=[_ReturnSchedule("oops")], + normalize_weights=False, + ) + with pytest.raises( + TypeError, + match=r"_ReturnSchedule returned str; " + r"LossWeightSchedule\.__call__ must return float", + ): + composed.current_weight(step=0, epoch=0) + + def test_nested_composition_applies_each_weight_exactly_once(self) -> None: + # Nesting must not cause duplicate weight application. 
+ leaf = _ToyLoss(value=2.0) + inner = ComposedLossFunction((leaf,), weights=[3.0], normalize_weights=False) + outer = ComposedLossFunction((inner,), weights=[1.0], normalize_weights=False) + out = outer(*_dummy_loss_mappings(), step=0, epoch=0) + # 1 * 3 * 2 = 6 + assert torch.allclose(out["total_loss"], torch.tensor(6.0), atol=1e-6) + assert torch.allclose( + out["per_component_unweighted"]["_ToyLoss"], torch.tensor(2.0) + ) + + def test_nested_composition_multiplies_weights_elementwise(self) -> None: + leaf1 = _ToyLoss(value=1.0) + leaf2 = _ToyLoss(value=1.0) + inner = ComposedLossFunction( + (leaf1, leaf2), weights=[3.0, 2.0], normalize_weights=False + ) + # Outer wraps the inner composition with weight 5.0. + outer = ComposedLossFunction((inner,), weights=[5.0], normalize_weights=False) + # After flattening the effective per-leaf weights are 5*3=15, 5*2=10. + assert outer.current_weight() == [15.0, 10.0] + + @pytest.mark.parametrize("op", ["add"], ids=["add"]) + def test_not_implemented_for_bad_type(self, op: str) -> None: + if op == "add": + with pytest.raises(TypeError): + _ = self.loss_a + "hello" # type: ignore[operator] + + +class TestOperatorSugar: + """Tests for ``scalar * loss``, ``schedule * loss``, and operator composition.""" + + @pytest.mark.parametrize( + ("side", "weight_kind"), + [ + ("left", "float"), + ("right", "float"), + ("left", "schedule"), + ("right", "schedule"), + ], + ) + def test_multiplication_wraps_leaf_in_composition( + self, side: str, weight_kind: str + ) -> None: + leaf = _ToyLoss(value=1.0) + weight: float | ConstantWeight = ( + 3.0 if weight_kind == "float" else ConstantWeight(value=2.5) + ) + composed = weight * leaf if side == "left" else leaf * weight + assert isinstance(composed, ComposedLossFunction) + assert len(composed.components) == 1 + assert composed._weights == [weight] + + def test_scaled_leaf_plus_scaled_leaf_flattens_and_normalizes(self) -> None: + composed = 3.0 * EnergyMSELoss() + 2.0 * ForceMSELoss() + 
assert isinstance(composed, ComposedLossFunction) + assert len(composed.components) == 2 + # Raw weights preserved on construction; normalization is applied + # only at call time. + assert composed._weights == [3.0, 2.0] + # Default normalize_weights=True: 3/(3+2), 2/(3+2). + assert composed.current_weight() == [0.6, 0.4] + + def test_single_scaled_leaf_normalizes_to_one(self) -> None: + composed = 3.0 * _ToyLoss(value=5.0) + # One component → sum(raw)=3 → effective 1.0. + assert composed.current_weight() == [1.0] + out = composed(*_dummy_loss_mappings()) + assert torch.allclose(out["total_loss"], torch.tensor(5.0), atol=1e-6) + + def test_schedule_times_leaf_participates_in_current_weight(self) -> None: + # Operator-attached schedule is stored on the composition and + # resolved at call time. Step-interpolation detail is covered by + # ``test_linear_schedule_on_component_in_composition``. + schedule = LinearWeight(start=0.0, end=1.0, num_steps=10) + composed = schedule * _ToyLoss(value=1.0) + _ToyLoss(value=1.0) + assert composed._weights[0] is schedule + assert composed.current_weight(step=10) == [0.5, 0.5] + + def test_float_mul_on_composition_scales_every_weight(self) -> None: + base = ComposedLossFunction( + (_ToyLoss(value=1.0), _ToyLoss(value=1.0)), + weights=[3.0, 2.0], + normalize_weights=False, + ) + scaled = 4.0 * base + assert scaled._weights == [12.0, 8.0] + # Normalization flag is inherited. 
+ assert scaled.normalize_weights is False + + def test_schedule_mul_on_composition_raises(self) -> None: + base = ComposedLossFunction((_ToyLoss(value=1.0),)) + schedule = ConstantWeight(value=2.0) + with pytest.raises(TypeError, match="LossWeightSchedule"): + _ = schedule * base + + def test_add_mismatched_normalize_raises(self) -> None: + normalized = ComposedLossFunction( + (_ToyLoss(value=1.0),), normalize_weights=True + ) + unnormalized = ComposedLossFunction( + (_ToyLoss(value=1.0),), normalize_weights=False + ) + with pytest.raises( + ValueError, + match=r"mismatched normalize_weights \(self=True, other=False\)", + ): + _ = normalized + unnormalized + with pytest.raises( + ValueError, + match=r"mismatched normalize_weights \(self=False, other=True\)", + ): + _ = unnormalized + normalized + + def test_bool_multiplication_rejected(self) -> None: + # ``True * loss`` could silently mean "1.0 * loss", which hides + # user bugs. Reject bools explicitly. + with pytest.raises(TypeError): + _ = True * _ToyLoss(value=1.0) # type: ignore[operator] + + def test_radd_bare_leaf_plus_composition(self) -> None: + composition = 2.0 * _ToyLoss(value=1.0) + result = _ToyLoss(value=1.0) + composition + assert isinstance(result, ComposedLossFunction) + assert len(result.components) == 2 + assert result._weights == [1.0, 2.0] + + +class TestWeightFactors: + @pytest.mark.parametrize( + ("factory", "expected"), + [ + pytest.param( + lambda: ComposedLossFunction( + (EnergyMSELoss(), ForceMSELoss()), + normalize_weights=False, + ), + {"EnergyMSELoss": 1.0, "ForceMSELoss": 1.0}, + id="default_weights_unnormalized", + ), + pytest.param( + lambda: ComposedLossFunction( + (EnergyMSELoss(), ForceMSELoss()), + weights=[ConstantWeight(value=2.0), ConstantWeight(value=3.0)], + normalize_weights=False, + ), + {"EnergyMSELoss": 2.0, "ForceMSELoss": 3.0}, + id="schedule_weights_unnormalized", + ), + pytest.param( + lambda: ComposedLossFunction( + (EnergyMSELoss(), ForceMSELoss()), + 
weights=[3.0, 2.0], + ), + {"EnergyMSELoss": 0.6, "ForceMSELoss": 0.4}, + id="float_weights_normalized", + ), + ], + ) + def test_weight_factors_simple_cases( + self, factory: Any, expected: dict[str, float] + ) -> None: + assert factory().weight_factors(step=0, epoch=0) == expected + + def test_weight_factors_no_args_smoke(self) -> None: + # ``weight_factors`` takes default ``step=0, epoch=None`` so + # introspection helpers don't demand args. + composed = ComposedLossFunction( + (EnergyMSELoss(),), weights=[ConstantWeight(value=2.0)] + ) + # Single component + normalization → effective weight is 1.0. + assert composed.weight_factors() == {"EnergyMSELoss": 1.0} + + def test_weight_factors_class_name_collision_gets_indexed_suffix(self) -> None: + composed = ComposedLossFunction( + components=(StressMSELoss(), StressMSELoss()), + ) + got = composed.weight_factors(step=0, epoch=0) + assert set(got) == {"StressMSELoss_0", "StressMSELoss_1"} + # Normalized to 0.5 each. + assert got["StressMSELoss_0"] == 0.5 + assert got["StressMSELoss_1"] == 0.5 + + def test_weight_factors_three_way_collision_across_nested_composition(self) -> None: + # Inner composition contains two ``StressMSELoss`` instances; wrapping in + # an outer composition with another ``StressMSELoss`` must collapse to + # three collision-suffixed keys — NOT to a mix like + # ``{"StressMSELoss_0", "StressMSELoss_1", "StressMSELoss"}`` from per-level + # suffixing. 
+ inner = ComposedLossFunction(components=(StressMSELoss(), StressMSELoss())) + outer = ComposedLossFunction( + components=(inner, StressMSELoss()), normalize_weights=False + ) + got = outer.weight_factors(step=0, epoch=0) + assert set(got) == {"StressMSELoss_0", "StressMSELoss_1", "StressMSELoss_2"} + assert all(v == 1.0 for v in got.values()) + + def test_weight_factors_nested_composition_flattens(self) -> None: + inner = ComposedLossFunction( + (EnergyMSELoss(),), + weights=[ConstantWeight(value=0.5)], + normalize_weights=False, + ) + outer = ComposedLossFunction( + (inner, ForceMSELoss()), + weights=[1.0, ConstantWeight(value=4.0)], + normalize_weights=False, + ) + assert outer.weight_factors(step=0, epoch=0) == { + "EnergyMSELoss": 0.5, + "ForceMSELoss": 4.0, + } + + +class TestConcreteLosses: + def setup_method(self) -> None: + # Mixed-size batch: 3 graphs with 3, 5, 2 atoms respectively. + self.nodes_per_graph = [3, 5, 2] + self.num_graphs = 3 + self.num_nodes = sum(self.nodes_per_graph) + self.batch_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 2, 2], dtype=torch.int32) + self.num_nodes_per_graph = torch.tensor(self.nodes_per_graph, dtype=torch.long) + + def _batch(self, **extra: torch.Tensor) -> SimpleNamespace: + return SimpleNamespace( + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + num_nodes_per_graph=self.num_nodes_per_graph, + **extra, + ) + + @staticmethod + def _force_l2_dense_case() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch_idx = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32) + target = torch.zeros(5, 3) + pred = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + [1.0, 1.0, 1.0], + [2.0, 0.0, 0.0], + ] + ) + return pred, target, batch_idx + + @staticmethod + def _force_l2_padded_case() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + target = torch.zeros(2, 3, 3) + pred = torch.tensor( + [ + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + ], + [ + [1.0, 1.0, 1.0], + [2.0, 
0.0, 0.0], + [99.0, 99.0, 99.0], + ], + ] + ) + counts = torch.tensor([3, 2], dtype=torch.long) + return pred, target, counts + + def test_energy_loss_gradient_matches_analytic( + self, fixed_torch_seed: None + ) -> None: + target = torch.randn(self.num_graphs, 1) + pred = (target + torch.randn_like(target) * 0.1).detach().requires_grad_() + EnergyMSELoss()(pred, target).backward() + # MSE over (B, 1): d/d pred = 2*(pred - target) / B. + expected_grad = 2.0 * (pred.detach() - target) / self.num_graphs + assert pred.grad is not None + assert torch.allclose(pred.grad, expected_grad, atol=1e-6) + + def test_energy_loss_per_atom_weights_by_atom_count(self) -> None: + target = torch.tensor([[3.0], [10.0], [4.0]]) # per-graph energies + pred = torch.tensor([[6.0], [15.0], [8.0]]) + got = EnergyMSELoss(per_atom=True)( + pred, target, num_nodes_per_graph=self.num_nodes_per_graph + ) + # Per-atom pred: [2, 3, 4]; target: [1, 2, 2]; diffs: [1, 1, 2]. + # Atom-count weighted MSE: (3*1 + 5*1 + 2*4) / (3 + 5 + 2) = 1.6. + assert torch.allclose(got, torch.tensor(1.6), atol=1e-6) + + def test_energy_loss_per_atom_accepts_padded_node_mask(self) -> None: + target = torch.tensor([[3.0], [10.0], [4.0]]) # per-graph energies + pred = torch.tensor([[6.0], [15.0], [8.0]]) + node_mask = torch.tensor( + [ + [True, True, True, False, False], + [True, True, True, True, True], + [True, True, False, False, False], + ] + ) + got = EnergyMSELoss(per_atom=True)(pred, target, num_nodes_per_graph=node_mask) + # The padded mask has row counts [3, 5, 2], matching the dense-count test. 
+ assert torch.allclose(got, torch.tensor(1.6), atol=1e-6) + + def test_energy_loss_per_atom_accepts_cpu_counts_on_cuda( + self, gpu_device: str + ) -> None: + target = torch.tensor([[3.0], [10.0], [4.0]], device=gpu_device) + pred = torch.tensor([[6.0], [15.0], [8.0]], device=gpu_device) + got = EnergyMSELoss(per_atom=True)( + pred, target, num_nodes_per_graph=self.num_nodes_per_graph + ) + + assert got.device.type == "cuda" + assert torch.allclose(got, torch.tensor(1.6, device=gpu_device), atol=1e-6) + + def test_energy_mae_loss_matches_manual_per_atom_mean(self) -> None: + target = torch.tensor([[3.0], [10.0], [4.0]]) + pred = torch.tensor([[6.0], [15.0], [8.0]]) + got = EnergyMAELoss()( + pred, target, num_nodes_per_graph=self.num_nodes_per_graph + ) + counts = self.num_nodes_per_graph.to(pred).unsqueeze(-1) + abs_residual = (pred / counts - target / counts).abs() + expected = (abs_residual * counts).sum() / counts.sum() + assert torch.allclose(got, expected, atol=1e-6) + + def test_energy_mae_loss_ignores_nan_and_inf_targets(self) -> None: + target = torch.tensor([[3.0], [float("nan")], [float("inf")], [8.0]]) + pred = torch.tensor([[6.0], [20.0], [30.0], [4.0]]) + counts = torch.tensor([3, 5, 2, 2], dtype=torch.long) + got = EnergyMAELoss()(pred, target, num_nodes_per_graph=counts) + # Valid entries: index 0 (count=3) and index 3 (count=2). 
+ # Per-atom abs residuals: |6/3 - 3/3| = 1.0, |4/2 - 8/2| = 2.0 + # Atom-count weighted: (3*1.0 + 2*2.0) / (3+2) = 7/5 = 1.4 + expected = torch.tensor(7.0 / 5.0) + assert torch.allclose(got, expected, atol=1e-6) + + def test_energy_mae_loss_gradient_flows(self) -> None: + target = torch.tensor([[3.0], [10.0], [4.0]]) + pred = torch.tensor([[6.0], [15.0], [8.0]], requires_grad=True) + EnergyMAELoss()( + pred, target, num_nodes_per_graph=self.num_nodes_per_graph + ).backward() + assert pred.grad is not None + assert pred.grad.shape == pred.shape + + def test_energy_huber_loss_matches_mace_style_graph_mean(self) -> None: + target = torch.tensor([[3.0], [10.0], [4.0]]) + pred = torch.tensor([[6.0], [15.0], [8.0]]) + got = EnergyHuberLoss(delta=1.5)( + pred, target, num_nodes_per_graph=self.num_nodes_per_graph + ) + counts = self.num_nodes_per_graph.to(pred).unsqueeze(-1) + per_graph = _huber_loss(pred / counts - target / counts, 1.5).reshape(-1) + expected = per_graph.mean() + assert torch.allclose(got, expected, atol=1e-6) + + def test_energy_huber_loss_ignores_nan_and_inf_targets(self) -> None: + target = torch.tensor([[3.0], [float("nan")], [float("inf")], [8.0]]) + pred = torch.tensor([[6.0], [20.0], [30.0], [4.0]]) + counts = torch.tensor([3, 5, 2, 2], dtype=torch.long) + got = EnergyHuberLoss(delta=1.5)(pred, target, num_nodes_per_graph=counts) + expected = _huber_loss(torch.tensor([1.0, -2.0]), 1.5).mean() + assert torch.allclose(got, expected, atol=1e-6) + + def test_energy_mae_loss_accepts_vector_shape(self) -> None: + target = torch.tensor([3.0, 10.0, 4.0]) + pred = torch.tensor([6.0, 15.0, 8.0]) + got = EnergyMAELoss()( + pred, target, num_nodes_per_graph=self.num_nodes_per_graph + ) + counts = self.num_nodes_per_graph.to(pred) + abs_residual = (pred / counts - target / counts).abs() + expected = (abs_residual * counts).sum() / counts.sum() + assert torch.allclose(got, expected, atol=1e-6) + + @pytest.mark.parametrize( + "bad_counts", + [ + torch.tensor([3, 
5], dtype=torch.long), + torch.ones(2, 5, dtype=torch.bool), + ], + ids=["count_length", "mask_batch"], + ) + def test_energy_mae_loss_rejects_count_batch_mismatch( + self, bad_counts: torch.Tensor + ) -> None: + target = torch.tensor([[3.0], [10.0], [4.0]]) + pred = torch.tensor([[6.0], [15.0], [8.0]]) + with pytest.raises(ValueError, match="must match energy batch size"): + EnergyMAELoss()(pred, target, num_nodes_per_graph=bad_counts) + + def test_force_loss_matches_hand_computed(self) -> None: + # 2 graphs with 3 and 2 atoms for a small hand-traceable case. + batch_idx = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32) + target = torch.zeros(5, 3) + pred = torch.tensor( + [ + [1.0, 0.0, 0.0], # graph 0 atom 0: |f|^2 = 1 + [0.0, 2.0, 0.0], # graph 0 atom 1: |f|^2 = 4 + [0.0, 0.0, 3.0], # graph 0 atom 2: |f|^2 = 9 + [1.0, 1.0, 1.0], # graph 1 atom 0: |f|^2 = 3 + [2.0, 0.0, 0.0], # graph 1 atom 1: |f|^2 = 4 + ] + ) + # normalize_by_atom_count=True: per-graph mean of |f|^2 then mean + # over graphs, then / 3 for per-component. + # graph 0 mean |f|^2 = (1+4+9)/3 = 14/3 + # graph 1 mean |f|^2 = (3+4)/2 = 7/2 + # mean over graphs = (14/3 + 7/2) / 2 = (28/6 + 21/6) / 2 = 49/12 + # divided by 3 components = 49/36 + got_norm = ForceMSELoss(normalize_by_atom_count=True)( + pred, + target, + batch_idx=batch_idx, + num_graphs=2, + ) + assert torch.allclose(got_norm, torch.tensor(49.0 / 36.0), atol=1e-6) + + # normalize=False: elementwise mean over the (V, 3) tensor. + # sum of squares = 1+4+9+3+4 = 21 across 5*3 = 15 entries -> 21/15 = 1.4. 
+ got_global = ForceMSELoss(normalize_by_atom_count=False)(pred, target) + assert torch.allclose(got_global, torch.tensor(21.0 / 15.0), atol=1e-6) + + def test_force_huber_loss_matches_componentwise_huber(self) -> None: + target = torch.zeros(2, 3) + pred = torch.tensor( + [ + [0.5, 2.0, -3.0], + [1.0, -0.25, 0.0], + ] + ) + got = ForceHuberLoss(normalize_by_atom_count=False, delta=1.0)(pred, target) + expected = _huber_loss(pred - target, 1.0).mean() + assert torch.allclose(got, expected, atol=1e-6) + + def test_force_l2_norm_loss_dense_matches_manual_per_graph_reduction(self) -> None: + pred, target, batch_idx = self._force_l2_dense_case() + got = ForceL2NormLoss()(pred, target, batch_idx=batch_idx, num_graphs=2) + per_atom = torch.linalg.vector_norm(pred - target, ord=2, dim=-1) + graph0 = per_atom[:3].mean() + graph1 = per_atom[3:].mean() + expected = torch.stack((graph0, graph1)).mean() + assert torch.allclose(got, expected, atol=1e-6) + + def test_force_l2_norm_loss_dense_global_mean_when_not_normalized(self) -> None: + pred, target, _ = self._force_l2_dense_case() + got = ForceL2NormLoss(normalize_by_atom_count=False)(pred, target) + expected = torch.linalg.vector_norm(pred - target, ord=2, dim=-1).mean() + assert torch.allclose(got, expected, atol=1e-6) + + def test_force_loss_padded_layout_matches_flat_hand_computed(self) -> None: + target = torch.zeros(2, 3, 3) + pred = torch.tensor( + [ + [ + [1.0, 0.0, 0.0], # graph 0 atom 0: |f|^2 = 1 + [0.0, 2.0, 0.0], # graph 0 atom 1: |f|^2 = 4 + [0.0, 0.0, 3.0], # graph 0 atom 2: |f|^2 = 9 + ], + [ + [1.0, 1.0, 1.0], # graph 1 atom 0: |f|^2 = 3 + [2.0, 0.0, 0.0], # graph 1 atom 1: |f|^2 = 4 + [99.0, 99.0, 99.0], # padding; must be ignored + ], + ] + ) + target[1, 2] = float("nan") + num_nodes_per_graph = torch.tensor([3, 2], dtype=torch.long) + + got_norm = ForceMSELoss(normalize_by_atom_count=True)( + pred, target, num_nodes_per_graph=num_nodes_per_graph + ) + assert torch.allclose(got_norm, torch.tensor(49.0 / 
36.0), atol=1e-6) + + got_global = ForceMSELoss(normalize_by_atom_count=False)( + pred, target, num_nodes_per_graph=num_nodes_per_graph + ) + assert torch.allclose(got_global, torch.tensor(21.0 / 15.0), atol=1e-6) + + def test_force_huber_loss_padded_layout_ignores_padding(self) -> None: + target = torch.zeros(2, 3, 3) + pred = torch.ones(2, 3, 3) + pred[1, 2] = 100.0 + target[1, 2] = float("nan") + counts = torch.tensor([3, 2], dtype=torch.long) + + got = ForceHuberLoss( + normalize_by_atom_count=False, + delta=1.0, + )(pred, target, num_nodes_per_graph=counts) + + assert torch.allclose(got, torch.tensor(0.5), atol=1e-6) + + def test_force_l2_norm_loss_padded_ignores_padding_and_nonfinite_targets( + self, + ) -> None: + pred, target, num_nodes_per_graph = self._force_l2_padded_case() + target[0, 2, 0] = float("inf") + target[1, 2] = float("nan") + + got = ForceL2NormLoss()(pred, target, num_nodes_per_graph=num_nodes_per_graph) + expected = torch.stack( + ( + torch.tensor([1.0, 2.0]).mean(), + torch.tensor([3.0**0.5, 2.0]).mean(), + ) + ).mean() + assert torch.allclose(got, expected, atol=1e-6) + + got_global = ForceL2NormLoss(normalize_by_atom_count=False)( + pred, target, num_nodes_per_graph=num_nodes_per_graph + ) + expected_global = torch.tensor([1.0, 2.0, 3.0**0.5, 2.0]).mean() + assert torch.allclose(got_global, expected_global, atol=1e-6) + + @pytest.mark.parametrize( + "bad_counts", + [ + torch.tensor([3], dtype=torch.long), + torch.ones(1, 3, dtype=torch.bool), + ], + ids=["count_length", "mask_batch"], + ) + def test_force_l2_norm_loss_padded_rejects_count_batch_mismatch( + self, bad_counts: torch.Tensor + ) -> None: + pred, target, _ = self._force_l2_padded_case() + with pytest.raises(ValueError, match="must match force batch size"): + ForceL2NormLoss()(pred, target, num_nodes_per_graph=bad_counts) + + def test_force_loss_padded_layout_accepts_node_mask(self) -> None: + target = torch.zeros(2, 3, 3) + pred = torch.ones(2, 3, 3) + pred[1, 2] = 100.0 + 
target[1, 2] = float("nan") + node_mask = torch.tensor( + [ + [True, True, True], + [True, True, False], + ] + ) + + got = ForceMSELoss(normalize_by_atom_count=True)( + pred, target, num_nodes_per_graph=node_mask + ) + + assert torch.allclose(got, torch.tensor(1.0), atol=1e-6) + + def test_force_loss_gradient_flows(self) -> None: + pred = torch.randn(self.num_nodes, 3, requires_grad=True) + target = torch.randn(self.num_nodes, 3) + ForceMSELoss()( + pred, + target, + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + ).backward() + assert pred.grad is not None + assert pred.grad.shape == pred.shape + + def test_force_l2_norm_loss_gradient_flows(self) -> None: + pred = torch.randn(self.num_nodes, 3, requires_grad=True) + target = torch.randn(self.num_nodes, 3) + ForceL2NormLoss()( + pred, + target, + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + ).backward() + assert pred.grad is not None + assert pred.grad.shape == pred.shape + + def test_stress_loss_matches_elementwise_mse(self, fixed_torch_seed: None) -> None: + pred = torch.randn(self.num_graphs, 3, 3, requires_grad=True) + target = torch.randn(self.num_graphs, 3, 3) + got = StressMSELoss()(pred, target) + # Frobenius MSE averaged over graphs == elementwise MSE. 
+ expected = torch.nn.functional.mse_loss(pred, target) + assert torch.allclose(got, expected, atol=1e-6) + got.backward() + assert pred.grad is not None + + def test_stress_huber_loss_matches_componentwise_huber(self) -> None: + pred = torch.tensor( + [ + [[1.0, 2.0, 0.0], [0.5, -0.5, 0.0], [0.0, 0.0, 0.0]], + [[3.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]], + ], + requires_grad=True, + ) + target = torch.zeros_like(pred) + got = StressHuberLoss(delta=1.0)(pred, target) + expected = _huber_loss(pred - target, 1.0).reshape(2, -1).mean(dim=-1).mean() + assert torch.allclose(got, expected, atol=1e-6) + got.backward() + assert pred.grad is not None + + @pytest.mark.parametrize( + ("loss_factory", "batch_kwargs", "missing_attr"), + [ + pytest.param( + lambda: EnergyMSELoss(), + {"energy": torch.zeros(3, 1)}, # predicted_energy omitted + "predicted_energy", + id="energy_missing_prediction", + ), + pytest.param( + lambda: ForceMSELoss(), + {"predicted_forces": torch.zeros(10, 3)}, # forces omitted + "forces", + id="force_missing_target", + ), + pytest.param( + lambda: StressMSELoss(), + {"stress": torch.zeros(3, 3, 3)}, # predicted_stress omitted + "predicted_stress", + id="stress_missing_prediction", + ), + ], + ) + def test_missing_mapping_key_raises_key_error( + self, + loss_factory: Any, + batch_kwargs: dict[str, torch.Tensor], + missing_attr: str, + ) -> None: + loss = loss_factory() + batch = self._batch(**batch_kwargs) + with pytest.raises(KeyError, match=missing_attr): + _call_from_batch(loss, batch) + + def test_mapping_key_resolving_to_none_raises_type_error(self) -> None: + loss = ComposedLossFunction(components=(EnergyMSELoss(),)) + predictions = {"predicted_energy": None} # type: ignore[dict-item] + targets = {"energy": torch.zeros(3, 1)} + with pytest.raises(TypeError, match="predicted_energy"): + loss(predictions, targets) + + @pytest.mark.parametrize( + ("loss_factory", "batch_kwargs", "loss_name"), + [ + pytest.param( + lambda: EnergyMSELoss(), + { 
+ "energy": torch.zeros(3, 2), # unequal trailing dim (strict) + "predicted_energy": torch.zeros(3, 3), + }, + "EnergyMSELoss", + id="energy_trailing_mismatch", + ), + pytest.param( + lambda: ForceMSELoss(), + { + "forces": torch.zeros(10, 2), # unequal trailing dim (strict) + "predicted_forces": torch.zeros(10, 3), + }, + "ForceMSELoss", + id="force_component_mismatch", + ), + pytest.param( + lambda: StressMSELoss(), + { + "stress": torch.zeros(3, 2), # rank/shape mismatch (strict) + "predicted_stress": torch.zeros(3, 3, 3), + }, + "StressMSELoss", + id="stress_rank_mismatch", + ), + ], + ) + def test_prediction_target_shape_mismatch_raises( + self, + loss_factory: Any, + batch_kwargs: dict[str, torch.Tensor], + loss_name: str, + ) -> None: + loss = loss_factory() + batch = self._batch(**batch_kwargs) + with pytest.raises( + ValueError, + match=rf"{loss_name}: prediction and target shape must match exactly", + ): + _call_from_batch(loss, batch) + + @pytest.mark.parametrize( + ("loss_factory", "tensor_kwargs", "missing_attr"), + [ + pytest.param( + lambda: EnergyMSELoss(per_atom=True), + { + "energy": torch.zeros(3, 1), + "predicted_energy": torch.zeros(3, 1), + }, + "num_nodes_per_graph", + id="energy_per_atom_missing_num_nodes", + ), + pytest.param( + lambda: ForceMSELoss(), + { + "forces": torch.zeros(10, 3), + "predicted_forces": torch.zeros(10, 3), + }, + "batch_idx", + id="force_missing_batch_idx", + ), + pytest.param( + lambda: ForceMSELoss(), + { + "forces": torch.zeros(10, 3), + "predicted_forces": torch.zeros(10, 3), + }, + "num_graphs", + id="force_missing_num_graphs", + ), + pytest.param( + lambda: ForceMSELoss(), + { + "forces": torch.zeros(3, 5, 3), + "predicted_forces": torch.zeros(3, 5, 3), + }, + "num_nodes_per_graph", + id="force_missing_num_nodes", + ), + ], + ) + def test_missing_loss_metadata_raises_value_error( + self, + loss_factory: Any, + tensor_kwargs: dict[str, torch.Tensor], + missing_attr: str, + ) -> None: + loss = loss_factory() + 
batch = self._batch(**tensor_kwargs) + # _batch sets all graph metadata fields; drop the one under test + # to exercise the tensor-first metadata requirement path. + delattr(batch, missing_attr) + with pytest.raises(ValueError, match=missing_attr): + _call_from_batch(loss, batch) + + def test_composed_losses_backprop_to_all_inputs(self) -> None: + batch = _full_loss_batch() + for name in ("energy", "forces", "stress"): + setattr(batch, name, torch.randn_like(getattr(batch, name))) + for name in ("predicted_energy", "predicted_forces", "predicted_stress"): + setattr( + batch, name, torch.randn_like(getattr(batch, name)).requires_grad_() + ) + + composed = ( + EnergyMSELoss() + + ConstantWeight(value=10.0) * ForceMSELoss() + + ConstantWeight(value=0.1) * StressMSELoss() + ) + assert isinstance(composed, ComposedLossFunction) + assert len(composed.components) == 3 + out = _call_from_batch(composed, batch) + assert set(out) == { + "total_loss", + "per_component_unweighted", + "per_component_weight", + "per_component_raw_weight", + "per_component_sample", + } + assert set(out["per_component_unweighted"]) == { + "EnergyMSELoss", + "ForceMSELoss", + "StressMSELoss", + } + out["total_loss"].backward() + for grad in ( + batch.predicted_energy.grad, + batch.predicted_forces.grad, + batch.predicted_stress.grad, + ): + assert grad is not None + assert not torch.all(grad == 0) + + def test_force_loss_reads_from_configured_prediction_key(self) -> None: + target = torch.zeros(self.num_nodes, 3) + renamed_pred = torch.ones(self.num_nodes, 3) + predictions = {"my_model_forces": renamed_pred} + targets = {"forces": target} + loss = ComposedLossFunction( + components=(ForceMSELoss(prediction_key="my_model_forces"),) + ) + got = loss( + predictions, + targets, + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + ) + # |pred - target|^2 sum over 3 components = 3 per atom. + # per-graph mean = 3; mean over graphs = 3; / 3 = 1.0. 
+ # Single-component composition normalizes effective weight to 1.0. + assert torch.allclose(got["total_loss"], torch.tensor(1.0), atol=1e-6) + assert torch.allclose( + got["per_component_unweighted"]["ForceMSELoss"], torch.tensor(1.0) + ) + + def test_force_loss_resolves_from_batch_dense(self) -> None: + pred = torch.randn(self.num_nodes, 3) + target = torch.randn(self.num_nodes, 3) + mini_batch = self._batch() + + got_batch = ForceMSELoss()(pred, target, batch=mini_batch) + got_explicit = ForceMSELoss()( + pred, + target, + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + ) + assert torch.allclose(got_batch, got_explicit, atol=1e-6) + + def test_energy_loss_per_atom_resolves_from_batch(self) -> None: + target = torch.tensor([[3.0], [10.0], [4.0]]) + pred = torch.tensor([[6.0], [15.0], [8.0]]) + mini_batch = self._batch() + + got_batch = EnergyMSELoss(per_atom=True)(pred, target, batch=mini_batch) + got_explicit = EnergyMSELoss(per_atom=True)( + pred, target, num_nodes_per_graph=self.num_nodes_per_graph + ) + assert torch.allclose(got_batch, got_explicit, atol=1e-6) + + def test_force_loss_explicit_kwarg_overrides_batch(self) -> None: + pred = torch.randn(self.num_nodes, 3) + target = torch.randn(self.num_nodes, 3) + mini_batch = self._batch() + + # Collapse to a single graph so the override path produces a + # measurably different loss from the batch-derived grouping. 
+ override_batch_idx = torch.zeros(self.num_nodes, dtype=torch.int32) + override_num_graphs = 1 + + got_override = ForceMSELoss()( + pred, + target, + batch=mini_batch, + batch_idx=override_batch_idx, + num_graphs=override_num_graphs, + ) + got_direct = ForceMSELoss()( + pred, + target, + batch_idx=override_batch_idx, + num_graphs=override_num_graphs, + ) + got_batch_only = ForceMSELoss()(pred, target, batch=mini_batch) + + assert torch.allclose(got_override, got_direct, atol=1e-6) + assert not torch.allclose(got_override, got_batch_only, atol=1e-6) + + def test_energy_loss_per_atom_explicit_override_wins(self) -> None: + target = torch.tensor([[3.0], [10.0], [4.0]]) + pred = torch.tensor([[6.0], [15.0], [8.0]]) + mini_batch = self._batch() + + # Flat 1-atom counts produce a different per-atom scale than the + # batch-derived [3, 5, 2] counts, making the override observable. + override_counts = torch.tensor([1, 1, 1], dtype=torch.long) + + got_override = EnergyMSELoss(per_atom=True)( + pred, target, batch=mini_batch, num_nodes_per_graph=override_counts + ) + got_direct = EnergyMSELoss(per_atom=True)( + pred, target, num_nodes_per_graph=override_counts + ) + got_batch_only = EnergyMSELoss(per_atom=True)(pred, target, batch=mini_batch) + + assert torch.allclose(got_override, got_direct, atol=1e-6) + assert not torch.allclose(got_override, got_batch_only, atol=1e-6) + + +class TestPerSampleLoss: + # ---- Shared helpers ---------------------------------------------- + + @staticmethod + def _assert_per_sample( + loss: BaseLossFunction, expected_shape: tuple[int, ...] 
+ ) -> torch.Tensor: + ps = loss.per_sample_loss + assert ps is not None + assert ps.shape == expected_shape + assert ps.requires_grad is False + return ps + + @pytest.mark.parametrize( + ("kwargs", "extra", "expected_fn"), + [ + pytest.param( + {}, + {}, + lambda pred, target, counts: (pred - target).pow(2).squeeze(-1), + id="default", + ), + pytest.param( + {"per_atom": True}, + {"num_nodes_per_graph": torch.tensor([2, 3, 1], dtype=torch.long)}, + lambda pred, target, counts: ( + ((pred - target) / counts.to(pred).unsqueeze(-1)).pow(2).squeeze(-1) + ), + id="per_atom_residuals", + ), + ], + ) + def test_energy_loss_per_sample_populated_detached_shape_and_value( + self, + kwargs: dict[str, Any], + extra: dict[str, torch.Tensor], + expected_fn: Any, + ) -> None: + torch.manual_seed(0) + b = 3 + pred = torch.randn(b, 1, requires_grad=True) + target = torch.randn(b, 1) + counts = extra.get("num_nodes_per_graph") + loss = EnergyMSELoss(**kwargs) + scalar = loss(pred, target, **extra) + ps = self._assert_per_sample(loss, (b,)) + torch.testing.assert_close(ps, expected_fn(pred, target, counts)) + if counts is None: + # Default energy MSE is graph-balanced. + torch.testing.assert_close(ps.mean(), scalar) + else: + # ``per_atom=True`` stores per-graph squared per-atom residuals + # for diagnostics, while the scalar is atom-count weighted. + weights = counts.to(ps) + torch.testing.assert_close(ps.mul(weights).sum() / weights.sum(), scalar) + + def test_energy_loss_per_sample_ignore_nonfinite_populates(self) -> None: + """``ignore_nonfinite`` populates ``(B,)`` with zero on all-NaN rows. + + Kept as a distinct case: ``per_sample_loss.mean()`` does NOT equal + the scalar return here because the scalar divides by the global + valid-entry count while the per-sample view is per-row residual. 
+ """ + pred = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) + target = torch.tensor([[0.0], [float("nan")], [2.5], [float("nan")]]) + loss = EnergyMSELoss(ignore_nonfinite=True) + loss(pred, target) + ps = self._assert_per_sample(loss, (4,)) + assert ps[1].item() == 0.0 + assert ps[3].item() == 0.0 + torch.testing.assert_close(ps[0], torch.tensor(1.0)) + torch.testing.assert_close(ps[2], torch.tensor(0.25)) + + @pytest.mark.parametrize( + "ignore_nonfinite", [False, True], ids=["default", "ignore_nonfinite"] + ) + def test_stress_loss_per_sample_populated_detached_shape_and_mean( + self, ignore_nonfinite: bool + ) -> None: + torch.manual_seed(0) + b = 3 + pred = torch.randn(b, 3, 3, requires_grad=True) + target = torch.randn(b, 3, 3) + loss = StressMSELoss(ignore_nonfinite=ignore_nonfinite) + scalar = loss(pred, target) + ps = self._assert_per_sample(loss, (b,)) + expected = (pred - target).pow(2).mean(dim=(-2, -1)) + torch.testing.assert_close(ps, expected) + torch.testing.assert_close(ps.mean(), scalar) + + def test_stress_loss_ignore_nonfinite_all_nan_row_is_zero(self) -> None: + torch.manual_seed(0) + pred = torch.randn(3, 3, 3) + target = torch.randn(3, 3, 3) + target[1] = float("nan") + loss = StressMSELoss(ignore_nonfinite=True) + loss(pred, target) + ps = self._assert_per_sample(loss, (3,)) + assert ps[1].item() == 0.0 + for g in (0, 2): + expected = (pred[g] - target[g]).pow(2).mean() + torch.testing.assert_close(ps[g], expected) + + @pytest.mark.parametrize( + ("normalize", "layout"), + [ + pytest.param(True, "dense", id="dense_normalize"), + pytest.param(True, "padded", id="padded_normalize"), + pytest.param(False, "padded", id="padded_no_normalize"), + ], + ) + def test_force_loss_per_sample_populated_detached_shape_and_value( + self, normalize: bool, layout: str + ) -> None: + torch.manual_seed(0) + loss = ForceMSELoss(normalize_by_atom_count=normalize) + if layout == "dense": + v = 5 + batch_idx = torch.tensor([0, 0, 1, 2, 2], dtype=torch.int32) + 
num_graphs = 3 + pred = torch.randn(v, 3, requires_grad=True) + target = torch.randn(v, 3) + loss(pred, target, batch_idx=batch_idx, num_graphs=num_graphs) + ps = self._assert_per_sample(loss, (num_graphs,)) + per_atom_se = (pred - target).pow(2).sum(dim=-1) + per_atom_valid = torch.ones(v, dtype=pred.dtype) * 3 + per_graph_num = per_graph_sum(per_atom_se, batch_idx, num_graphs=num_graphs) + per_graph_den = per_graph_sum( + per_atom_valid, batch_idx, num_graphs=num_graphs + ).clamp_min(1.0) + expected = per_graph_num / per_graph_den + torch.testing.assert_close(ps, expected) + return + # padded layout shared by both normalize=True and normalize=False. + b = 3 + v_max = 4 + pred = torch.randn(b, v_max, 3, requires_grad=True) + target = torch.randn(b, v_max, 3) + num_nodes_per_graph = torch.tensor([2, 1, 4], dtype=torch.long) + scalar = loss(pred, target, num_nodes_per_graph=num_nodes_per_graph) + ps = self._assert_per_sample(loss, (b,)) + node_mask = torch.arange(v_max).unsqueeze(0) < num_nodes_per_graph.unsqueeze(-1) + valid = node_mask.unsqueeze(-1).expand_as(pred).to(dtype=pred.dtype) + squared_error = ((pred - target) * valid).pow(2) + per_graph_num = squared_error.sum(dim=(-2, -1)) + per_graph_den = valid.sum(dim=(-2, -1)).clamp_min(1.0) + expected = per_graph_num / per_graph_den + torch.testing.assert_close(ps, expected) + if normalize: + # Normalized path: per-sample mean equals the scalar return. 
+ torch.testing.assert_close(ps.mean(), scalar) + + def test_force_loss_dense_no_normalize_per_sample_is_none(self) -> None: + torch.manual_seed(0) + pred = torch.randn(5, 3) + target = torch.randn(5, 3) + loss = ForceMSELoss(normalize_by_atom_count=False) + loss(pred, target) + assert loss.per_sample_loss is None + + def test_per_sample_loss_cleared_on_each_forward_call(self) -> None: + torch.manual_seed(0) + loss = ForceMSELoss(normalize_by_atom_count=False) + padded_pred = torch.randn(3, 4, 3) + padded_target = torch.randn(3, 4, 3) + num_nodes_per_graph = torch.tensor([2, 1, 4], dtype=torch.long) + loss(padded_pred, padded_target, num_nodes_per_graph=num_nodes_per_graph) + assert loss.per_sample_loss is not None + loss(torch.randn(5, 3), torch.randn(5, 3)) + assert loss.per_sample_loss is None + + def test_per_sample_loss_cleared_on_exception(self) -> None: + torch.manual_seed(0) + loss = EnergyMSELoss() + loss(torch.randn(3, 1), torch.randn(3, 1)) + assert loss.per_sample_loss is not None + pred = torch.randn(3, 1, dtype=torch.float32) + target = torch.randn(3, 1, dtype=torch.float64) + with pytest.raises(ValueError): + loss(pred, target) + assert loss.per_sample_loss is None + + @staticmethod + def _energy_stress_inputs( + b: int, requires_grad: bool = False + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + predictions = { + "predicted_energy": torch.randn(b, 1, requires_grad=requires_grad), + "predicted_stress": torch.randn(b, 3, 3, requires_grad=requires_grad), + } + targets = { + "energy": torch.randn(b, 1), + "stress": torch.randn(b, 3, 3), + } + return predictions, targets + + def test_composed_output_has_per_component_sample_field(self) -> None: + torch.manual_seed(0) + b = 3 + composed = EnergyMSELoss() + StressMSELoss() + out = composed(*self._energy_stress_inputs(b)) + assert set(out["per_component_sample"]) == set(out["per_component_unweighted"]) + for value in out["per_component_sample"].values(): + assert value.shape == (b,) + 
assert value.requires_grad is False + + def test_composed_per_component_sample_is_weighted_by_effective_weight( + self, + ) -> None: + torch.manual_seed(0) + b = 3 + energy = EnergyMSELoss() + stress = StressMSELoss() + composed = ComposedLossFunction( + (energy, stress), weights=[3.0, 1.0], normalize_weights=False + ) + pred_e = torch.randn(b, 1) + tgt_e = torch.randn(b, 1) + pred_s = torch.randn(b, 3, 3) + tgt_s = torch.randn(b, 3, 3) + out = composed( + {"predicted_energy": pred_e, "predicted_stress": pred_s}, + {"energy": tgt_e, "stress": tgt_s}, + ) + assert energy.per_sample_loss is not None + expected_energy = 3.0 * energy.per_sample_loss + torch.testing.assert_close( + out["per_component_sample"]["EnergyMSELoss"], expected_energy + ) + + def test_composed_component_without_per_sample_is_absent(self) -> None: + torch.manual_seed(0) + v = 5 + b = 3 + composed = ComposedLossFunction( + (EnergyMSELoss(), ForceMSELoss(normalize_by_atom_count=False)) + ) + predictions = { + "predicted_energy": torch.randn(b, 1), + "predicted_forces": torch.randn(v, 3), + } + targets = { + "energy": torch.randn(b, 1), + "forces": torch.randn(v, 3), + } + out = composed(predictions, targets) + assert "EnergyMSELoss" in out["per_component_sample"] + assert "ForceMSELoss" not in out["per_component_sample"] + + def test_composed_per_component_sample_sum_matches_total_loss(self) -> None: + torch.manual_seed(0) + b = 4 + composed = EnergyMSELoss() + StressMSELoss() + predictions, targets = self._energy_stress_inputs(b) + out = composed(predictions, targets) + per_sample_sum = sum(out["per_component_sample"].values()) + torch.testing.assert_close(per_sample_sum.mean(), out["total_loss"]) + + @pytest.mark.parametrize( + ("bad_value", "expected_exc", "expected_msg_fragment"), + [ + (1.0, TypeError, "must be a torch.Tensor or None"), + (torch.zeros(2, 3), ValueError, "must be a 1-D tensor"), + ], + ids=["non_tensor_raises_type_error", "non_1d_tensor_raises_value_error"], + ) + def 
test_composed_rejects_invalid_custom_per_sample_loss( + self, + bad_value: Any, + expected_exc: type[BaseException], + expected_msg_fragment: str, + ) -> None: + class _BadPerSampleLoss(_ToyLoss): + def __init__(self, value: Any) -> None: + super().__init__() + self._value = value + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + out = super().forward(pred, target, **kwargs) + self.per_sample_loss = self._value + return out + + composed = ComposedLossFunction((_BadPerSampleLoss(bad_value),)) + with pytest.raises(expected_exc, match=expected_msg_fragment): + composed({"prediction": torch.tensor(0.0)}, {"target": torch.tensor(0.0)}) + + +class TestIgnoreNaN: + """Tests for the opt-in ``ignore_nonfinite`` masking in concrete losses. + + Targets with ``NaN`` represent missing labels and must not contribute + to loss value or gradient. Predictions are assumed finite. The + implementation uses branch-free tensor ops, so behavior is the same + as the eager path these tests assert. + """ + + def setup_method(self) -> None: + # Reuse the 3,5,2-atom layout from ``TestConcreteLosses`` so per-atom + # normalization and multi-graph masking paths are all exercised. 
+ self.nodes_per_graph = [3, 5, 2] + self.num_graphs = 3 + self.num_nodes = sum(self.nodes_per_graph) + self.batch_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 2, 2], dtype=torch.int32) + self.num_nodes_per_graph = torch.tensor(self.nodes_per_graph, dtype=torch.long) + + def _batch(self, **extra: torch.Tensor) -> SimpleNamespace: + return SimpleNamespace( + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + num_nodes_per_graph=self.num_nodes_per_graph, + **extra, + ) + + # ---- EnergyMSELoss --------------------------------------------------- + + def test_energy_loss_default_propagates_nan(self) -> None: + target = torch.tensor([[1.0], [float("nan")], [3.0]]) + pred = torch.tensor([[1.5], [2.5], [3.5]]) + got = EnergyMSELoss()(pred, target) + assert torch.isnan(got) + + def test_energy_loss_ignore_nonfinite_masks_missing_targets(self) -> None: + target = torch.tensor([[1.0], [float("nan")], [3.0]]) + pred = torch.tensor([[1.5], [2.5], [3.5]]) + got = EnergyMSELoss(ignore_nonfinite=True)(pred, target) + # Valid entries contribute (0.5)^2 and (0.5)^2; two valid entries. + expected = torch.tensor((0.25 + 0.25) / 2.0) + assert torch.allclose(got, expected, atol=1e-6) + + def test_energy_loss_ignore_nonfinite_zero_gradient_at_nan_positions(self) -> None: + target = torch.tensor([[1.0], [float("nan")], [3.0]]) + pred = torch.tensor([[1.5], [10.0], [3.5]], requires_grad=True) + EnergyMSELoss(ignore_nonfinite=True)(pred, target).backward() + assert pred.grad is not None + # The NaN-target entry must receive exactly zero gradient. + assert pred.grad[1].item() == 0.0 + # Other entries must receive finite, non-zero gradient. 
+ assert torch.isfinite(pred.grad).all() + assert pred.grad[0].item() != 0.0 + assert pred.grad[2].item() != 0.0 + + def test_energy_loss_ignore_nonfinite_all_nan_gives_zero(self) -> None: + target = torch.full((self.num_graphs, 1), float("nan")) + pred = torch.randn(self.num_graphs, 1, requires_grad=True) + got = EnergyMSELoss(ignore_nonfinite=True)(pred, target) + assert torch.allclose(got, torch.tensor(0.0)) + got.backward() + assert pred.grad is not None + assert torch.all(pred.grad == 0.0) + + def test_energy_loss_ignore_nonfinite_per_atom_weights_valid_counts(self) -> None: + # Per-atom normalization must be applied before masking so the + # valid-entry MSE is computed on per-atom values, not raw energies. + target = torch.tensor([[3.0], [float("nan")], [4.0]]) # per-atom: 1, -, 2 + pred = torch.tensor([[6.0], [15.0], [8.0]]) # per-atom: 2, 3, 4 + got = EnergyMSELoss(per_atom=True, ignore_nonfinite=True)( + pred, target, num_nodes_per_graph=self.num_nodes_per_graph + ) + # Valid per-atom diffs: (2-1)=1 and (4-2)=2. Only valid graph + # counts enter the denominator: (3*1 + 2*4) / (3 + 2). 
+ expected = torch.tensor(11.0 / 5.0) + assert torch.allclose(got, expected, atol=1e-6) + + def test_energy_loss_ignore_nonfinite_off_matches_baseline(self) -> None: + target = torch.randn(self.num_graphs, 1) + pred = torch.randn(self.num_graphs, 1) + baseline = EnergyMSELoss()(pred, target) + opt_in = EnergyMSELoss(ignore_nonfinite=True)(pred, target) + assert torch.allclose(baseline, opt_in, atol=1e-6) + + # ---- ForceMSELoss ---------------------------------------------------- + + def test_force_loss_default_propagates_nan(self) -> None: + target = torch.zeros(self.num_nodes, 3) + target[4, 1] = float("nan") + pred = torch.ones(self.num_nodes, 3) + assert torch.isnan( + ForceMSELoss(normalize_by_atom_count=True)( + pred, + target, + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + ) + ) + assert torch.isnan(ForceMSELoss(normalize_by_atom_count=False)(pred, target)) + + def test_force_loss_ignore_nonfinite_global_masks_missing_components(self) -> None: + target = torch.zeros(self.num_nodes, 3) + target[4, 1] = float("nan") # one component missing + pred = torch.ones(self.num_nodes, 3) + got = ForceMSELoss(normalize_by_atom_count=False, ignore_nonfinite=True)( + pred, target + ) + # V*3 - 1 = 29 valid entries, each contributing (1 - 0)^2 = 1. + expected = torch.tensor(29.0 / 29.0) + assert torch.allclose(got, expected, atol=1e-6) + + def test_force_loss_ignore_nonfinite_per_graph_all_nan_graph_zero_contribution( + self, + ) -> None: + # Graph 1 (atoms 3..7) has fully-NaN force labels; graphs 0 and 2 + # are fully labeled. The all-NaN graph must contribute zero to the + # mean over graphs. + target = torch.zeros(self.num_nodes, 3) + target[3:8] = float("nan") + pred = torch.ones(self.num_nodes, 3) + got = ForceMSELoss(normalize_by_atom_count=True, ignore_nonfinite=True)( + pred, + target, + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + ) + # Graph 0: 3 atoms * 3 components all valid, each (1-0)^2 = 1, + # per-graph loss = 9/9 = 1. 
Graph 2: 2 atoms * 3 components all + # valid, per-graph loss = 6/6 = 1. Graph 1: all NaN, loss = 0. + # Mean over 3 graphs = (1 + 0 + 1) / 3. + expected = torch.tensor(2.0 / 3.0) + assert torch.allclose(got, expected, atol=1e-6) + + def test_force_loss_ignore_nonfinite_per_graph_partial_mask(self) -> None: + # Single component missing on one atom; check the per-graph + # denominator reflects 3*n_atoms - missing, not 3*n_atoms. + target = torch.zeros(self.num_nodes, 3) + target[0, 0] = float("nan") # graph 0, atom 0, x component + pred = torch.ones(self.num_nodes, 3) + got = ForceMSELoss(normalize_by_atom_count=True, ignore_nonfinite=True)( + pred, + target, + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + ) + # Graph 0: 8 valid components (out of 9), each contributes 1; loss = 8/8 = 1. + # Graph 1: 15/15 = 1. Graph 2: 6/6 = 1. Mean = 1.0. + expected = torch.tensor(1.0) + assert torch.allclose(got, expected, atol=1e-6) + + def test_force_loss_ignore_nonfinite_zero_gradient_at_nan_positions(self) -> None: + target = torch.zeros(self.num_nodes, 3) + target[0, 0] = float("nan") + pred = torch.randn(self.num_nodes, 3, requires_grad=True) + ForceMSELoss(normalize_by_atom_count=True, ignore_nonfinite=True)( + pred, + target, + batch_idx=self.batch_idx, + num_graphs=self.num_graphs, + ).backward() + assert pred.grad is not None + assert pred.grad[0, 0].item() == 0.0 + # At least one other component receives non-zero gradient. 
+ assert torch.isfinite(pred.grad).all() + assert (pred.grad != 0.0).any() + + def test_force_loss_ignore_nonfinite_off_matches_baseline( + self, fixed_torch_seed: None + ) -> None: + target = torch.randn(self.num_nodes, 3) + pred = torch.randn(self.num_nodes, 3) + for norm in (True, False): + metadata = ( + {"batch_idx": self.batch_idx, "num_graphs": self.num_graphs} + if norm + else {} + ) + baseline = ForceMSELoss(normalize_by_atom_count=norm)( + pred, target, **metadata + ) + opt_in = ForceMSELoss(normalize_by_atom_count=norm, ignore_nonfinite=True)( + pred, target, **metadata + ) + assert torch.allclose(baseline, opt_in, atol=1e-6) + + # ---- StressMSELoss --------------------------------------------------- + + def test_stress_loss_default_propagates_nan(self) -> None: + target = torch.zeros(self.num_graphs, 3, 3) + target[1, 2, 2] = float("nan") + pred = torch.ones(self.num_graphs, 3, 3) + assert torch.isnan(StressMSELoss()(pred, target)) + + def test_stress_loss_ignore_nonfinite_all_nan_graph_zero_contribution(self) -> None: + target = torch.zeros(self.num_graphs, 3, 3) + target[1] = float("nan") # full graph 1 unlabeled + pred = torch.ones(self.num_graphs, 3, 3) + got = StressMSELoss(ignore_nonfinite=True)(pred, target) + # Graph 0: 9 valid entries each (1-0)^2 = 1, per-graph loss = 9/9 = 1. + # Graph 1: all NaN, loss = 0. Graph 2: loss = 1. + # Mean = (1 + 0 + 1) / 3. + expected = torch.tensor(2.0 / 3.0) + assert torch.allclose(got, expected, atol=1e-6) + + def test_stress_loss_ignore_nonfinite_partial_mask(self) -> None: + target = torch.zeros(self.num_graphs, 3, 3) + target[0, 0, 0] = float("nan") # one entry missing in graph 0 + pred = torch.ones(self.num_graphs, 3, 3) + got = StressMSELoss(ignore_nonfinite=True)(pred, target) + # Graph 0: 8 valid entries of (1-0)^2 = 1 -> loss = 8/8 = 1. + # Graphs 1, 2: loss = 1. Mean = 1.0. 
+ expected = torch.tensor(1.0) + assert torch.allclose(got, expected, atol=1e-6) + + def test_stress_loss_ignore_nonfinite_zero_gradient_at_nan_positions(self) -> None: + target = torch.zeros(self.num_graphs, 3, 3) + target[0, 0, 0] = float("nan") + pred = torch.randn(self.num_graphs, 3, 3, requires_grad=True) + StressMSELoss(ignore_nonfinite=True)(pred, target).backward() + assert pred.grad is not None + assert pred.grad[0, 0, 0].item() == 0.0 + assert torch.isfinite(pred.grad).all() + + def test_stress_loss_ignore_nonfinite_off_matches_baseline( + self, fixed_torch_seed: None + ) -> None: + target = torch.randn(self.num_graphs, 3, 3) + pred = torch.randn(self.num_graphs, 3, 3) + baseline = StressMSELoss()(pred, target) + opt_in = StressMSELoss(ignore_nonfinite=True)(pred, target) + assert torch.allclose(baseline, opt_in, atol=1e-6) + + # ---- Repr --------------------------------------------------------- + + def test_ignore_nonfinite_appears_in_extra_repr(self) -> None: + for loss in ( + EnergyMSELoss(ignore_nonfinite=True), + ForceMSELoss(ignore_nonfinite=True), + StressMSELoss(ignore_nonfinite=True), + ): + assert "ignore_nonfinite=True" in repr(loss) + + +class TestLossModelSpec: + """Tests for :func:`create_model_spec` round-trip on concrete losses. + + Since leaf losses no longer carry a ``weight`` kwarg (weighting lives + on :class:`ComposedLossFunction`), these tests exercise the generic + spec workflow for plain concrete-loss kwargs only: + ``create_model_spec(cls, **kwargs)`` → ``model_dump_json`` → + ``json.loads`` → :func:`create_model_spec_from_json` → + ``spec.build()``. The rebuilt instance must preserve ``__init__`` + kwargs and stay functionally equivalent on tensor inputs. 
+ """ + + def _roundtrip(self, spec: Any) -> Any: + dumped = json.loads(spec.model_dump_json()) + return create_model_spec_from_json(dumped) + + @pytest.mark.parametrize( + ("cls", "kwargs"), + [ + pytest.param(EnergyMSELoss, {}, id="energy_defaults"), + pytest.param( + EnergyMSELoss, + {"per_atom": True, "ignore_nonfinite": True}, + id="energy_per_atom_ignore_nonfinite", + ), + pytest.param( + EnergyMSELoss, + {"target_key": "u_ref", "prediction_key": "u_hat"}, + id="energy_renamed_keys", + ), + pytest.param( + EnergyHuberLoss, + {"per_atom": True, "delta": 0.5, "ignore_nonfinite": True}, + id="energy_huber", + ), + pytest.param(ForceMSELoss, {}, id="force_defaults"), + pytest.param( + ForceMSELoss, + {"normalize_by_atom_count": False, "ignore_nonfinite": True}, + id="force_global_ignore_nonfinite", + ), + pytest.param( + ForceHuberLoss, + { + "normalize_by_atom_count": False, + "delta": 0.5, + "ignore_nonfinite": True, + }, + id="force_huber", + ), + pytest.param(StressMSELoss, {}, id="stress_defaults"), + pytest.param( + StressMSELoss, {"ignore_nonfinite": True}, id="stress_ignore_nonfinite" + ), + pytest.param( + StressHuberLoss, + {"delta": 0.5, "ignore_nonfinite": True}, + id="stress_huber", + ), + ], + ) + def test_loss_basespec_roundtrip( + self, cls: type[BaseLossFunction], kwargs: dict[str, Any] + ) -> None: + """JSON round-trip rebuilds a loss with matching kwargs.""" + spec = create_model_spec(cls, **kwargs) + rebuilt = self._roundtrip(spec) + built = rebuilt.build() + assert isinstance(built, cls) + for k, v in kwargs.items(): + assert getattr(built, k) == v + # Leaves no longer carry a ``weight`` attribute — weighting lives + # on :class:`ComposedLossFunction`. 
+ assert not hasattr(built, "weight") + + def test_loss_spec_preserves_timestamp(self) -> None: + """Rehydrated spec keeps the original timestamp byte-for-byte.""" + spec = create_model_spec(EnergyMSELoss, per_atom=True) + rebuilt = self._roundtrip(spec) + assert rebuilt.timestamp == spec.timestamp + + def test_rebuilt_loss_is_functionally_equivalent(self) -> None: + """A round-tripped loss produces the same value as the original.""" + pred = torch.randn(3, 1) + target = torch.randn(3, 1) + original = EnergyMSELoss(ignore_nonfinite=True) + spec = create_model_spec(EnergyMSELoss, ignore_nonfinite=True) + rebuilt = self._roundtrip(spec).build() + + assert torch.allclose(original(pred, target), rebuilt(pred, target), atol=1e-6) + + def test_loss_component_to_spec_roundtrip(self) -> None: + """Public loss component spec helper round-trips leaf loss config.""" + spec = loss_component_to_spec( + EnergyMSELoss(per_atom=True, ignore_nonfinite=True) + ) + rebuilt = self._roundtrip(spec).build() + assert isinstance(rebuilt, EnergyMSELoss) + assert rebuilt.per_atom is True + assert rebuilt.ignore_nonfinite is True + + def test_huber_loss_component_to_spec_roundtrip(self) -> None: + """Public loss component spec helper round-trips Huber loss config.""" + spec = loss_component_to_spec(ForceHuberLoss(delta=0.5)) + rebuilt = self._roundtrip(spec).build() + assert isinstance(rebuilt, ForceHuberLoss) + assert rebuilt.delta == 0.5 + + def test_loss_component_to_spec_rejects_composed_loss(self) -> None: + """Public loss component spec helper rejects non-leaf compositions.""" + with pytest.raises( + TypeError, + match="use ComposedLossFunction spec serialization for composed losses", + ): + loss_component_to_spec(ComposedLossFunction([EnergyMSELoss()])) + + def test_loss_component_to_spec_rejects_non_loss(self) -> None: + """Public loss component spec helper rejects non-loss objects clearly.""" + with pytest.raises( + TypeError, + match="loss_component_to_spec accepts only leaf 
BaseLossFunction objects", + ): + loss_component_to_spec(object()) + + +class TestShapeValidationOptIn: + def test_bare_subclass_does_not_shape_check(self) -> None: + loss = _ToyLoss(value=1.0) + pred = torch.randn(3, 1) + target = torch.randn(4, 5, 7) # deliberately mismatched + got = loss(pred, target) + assert torch.allclose(got, torch.tensor(1.0)) + + def test_energy_loss_raises_on_shape_mismatch(self) -> None: + loss = EnergyMSELoss() + pred = torch.zeros(3, 2) + target = torch.zeros(3, 3) # unequal trailing dim (strict) + with pytest.raises( + ValueError, + match="EnergyMSELoss: prediction and target shape must match exactly", + ): + loss(pred, target) + + def test_assert_same_shape_public_helper(self) -> None: + pred = torch.zeros(3, 2) + target = torch.zeros(3, 2) + assert_same_shape( + pred, + target, + name="MyLoss", + prediction_key="p", + target_key="t", + ) + with pytest.raises( + ValueError, + match=r"MyLoss: prediction and target shape mismatch; " + r"prediction_key='p' has shape \(3, 2\), " + r"target_key='t' has shape \(3, 3\)", + ): + assert_same_shape( + pred, + torch.zeros(3, 3), + name="MyLoss", + prediction_key="p", + target_key="t", + ) + + def test_assert_same_shape_omits_key_fragments_when_none(self) -> None: + with pytest.raises( + ValueError, + match=r"MyLoss: prediction and target shape mismatch; " + r"prediction has shape \(3, 2\), target has shape \(3, 3\)", + ): + assert_same_shape( + torch.zeros(3, 2), + torch.zeros(3, 3), + name="MyLoss", + ) + + def test_assert_same_shape_accepts_broadcastable(self) -> None: + # (B, 1) vs (B, 3) is broadcast-compatible; must not raise. 
+ assert_same_shape( + torch.zeros(4, 1), + torch.zeros(4, 3), + name="MyLoss", + prediction_key="p", + target_key="t", + ) + + def test_assert_same_shape_rejects_dtype_mismatch(self) -> None: + pred = torch.zeros(3, 2, dtype=torch.float32) + target = torch.zeros(3, 2, dtype=torch.float64) + with pytest.raises( + ValueError, + match=r"MyLoss: prediction and target dtype mismatch; " + r"prediction_key='p' has dtype torch\.float32, " + r"target_key='t' has dtype torch\.float64", + ): + assert_same_shape( + pred, + target, + name="MyLoss", + prediction_key="p", + target_key="t", + ) + + def test_assert_same_shape_dtype_check_runs_before_shape_check(self) -> None: + # Both dtype AND shape mismatch — must surface dtype error, not shape. + pred = torch.zeros(3, 2, dtype=torch.float32) + target = torch.zeros(3, 3, dtype=torch.float64) + with pytest.raises(ValueError, match="dtype mismatch"): + assert_same_shape( + pred, + target, + name="MyLoss", + ) + + +class TestLeafShapeEqualityGuard: + # The module-level ``assert_same_shape`` helper has two shape + # policies: its default broadcast-compatible mode accepts shapes + # like ``(B, 1)`` vs ``(B, 3)``, and ``strict=True`` requires exact + # equality. All built-in leaf losses opt into ``strict=True`` so + # the broadcast trap cannot silently corrupt their elementwise + # arithmetic. These tests lock that in. + + def test_energy_loss_rejects_broadcast_trap(self) -> None: + # ``(B, 1)`` vs ``(B, 3)`` is broadcast-compatible but broadcasts + # into a ``(B, 3)`` residual — silently triples the loss. 
+ loss = EnergyMSELoss() + pred = torch.zeros(4, 1) + target = torch.zeros(4, 3) + with pytest.raises( + ValueError, + match=( + r"EnergyMSELoss: prediction and target shape must match exactly " + r"for elementwise loss; prediction_key='predicted_energy' has " + r"shape \(4, 1\), target_key='energy' has shape \(4, 3\)" + ), + ): + loss(pred, target) + + def test_energy_loss_rejects_squeezed_vs_unsqueezed(self) -> None: + # ``(B, 1)`` vs ``(B,)`` broadcasts to a ``(B, B)`` outer product. + loss = EnergyMSELoss() + pred = torch.zeros(4, 1) + target = torch.zeros(4) + with pytest.raises(ValueError, match="shape must match exactly"): + loss(pred, target) + + def test_energy_loss_happy_path(self) -> None: + # Regression: canonical ``(B, 1)`` vs ``(B, 1)`` still works. + loss = EnergyMSELoss() + pred = torch.tensor([[1.0], [2.0], [3.0]]) + target = torch.tensor([[1.5], [2.5], [3.5]]) + scalar = loss(pred, target) + # Each residual squared is 0.25; mean is 0.25. + assert torch.allclose(scalar, torch.tensor(0.25)) + + def test_force_loss_dense_rejects_component_broadcast(self) -> None: + # ``(V, 1)`` vs ``(V, 3)`` is broadcast-compatible but semantically + # wrong — a per-atom scalar compared against a 3-component force. + loss = ForceMSELoss(normalize_by_atom_count=False) + pred = torch.zeros(5, 1) + target = torch.zeros(5, 3) + with pytest.raises( + ValueError, + match=( + r"ForceMSELoss: prediction and target shape must match exactly " + r"for elementwise loss; prediction_key='predicted_forces' has " + r"shape \(5, 1\), target_key='forces' has shape \(5, 3\)" + ), + ): + loss(pred, target) + + def test_force_loss_padded_rejects_component_broadcast(self) -> None: + # Padded layout ``(B, V_max, 1)`` vs ``(B, V_max, 3)``. 
+ loss = ForceMSELoss(normalize_by_atom_count=False) + pred = torch.zeros(2, 4, 1) + target = torch.zeros(2, 4, 3) + with pytest.raises(ValueError, match="shape must match exactly"): + loss( + pred, + target, + num_nodes_per_graph=torch.tensor([4, 4]), + ) + + def test_force_loss_dense_happy_path(self) -> None: + # Regression: canonical dense ``(V, 3)`` vs ``(V, 3)`` still works. + loss = ForceMSELoss(normalize_by_atom_count=False) + pred = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]) + target = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + scalar = loss(pred, target) + # Sum of squares 1 + 4 = 5 over 6 components → 5/6. + assert torch.allclose(scalar, torch.tensor(5.0 / 6.0)) + + def test_stress_loss_rejects_component_broadcast(self) -> None: + # ``(B, 1, 3)`` vs ``(B, 3, 3)`` is broadcast-compatible. + loss = StressMSELoss() + pred = torch.zeros(2, 1, 3) + target = torch.zeros(2, 3, 3) + with pytest.raises( + ValueError, + match=( + r"StressMSELoss: prediction and target shape must match exactly " + r"for elementwise loss; prediction_key='predicted_stress' has " + r"shape \(2, 1, 3\), target_key='stress' has shape \(2, 3, 3\)" + ), + ): + loss(pred, target) + + def test_stress_loss_happy_path(self) -> None: + # Regression: canonical ``(B, 3, 3)`` vs ``(B, 3, 3)`` still works. + loss = StressMSELoss() + pred = torch.zeros(2, 3, 3) + target = torch.zeros(2, 3, 3) + scalar = loss(pred, target) + assert torch.allclose(scalar, torch.tensor(0.0)) + + def test_assert_same_shape_strict_rejects_broadcast_compatible(self) -> None: + # Direct test of the public helper's strict policy: shapes that + # pass the default broadcast policy must be rejected. 
+ with pytest.raises( + ValueError, + match=( + r"MyLoss: prediction and target shape must match exactly " + r"for elementwise loss; prediction_key='p' has shape " + r"\(4, 1\), target_key='t' has shape \(4, 3\)" + ), + ): + assert_same_shape( + torch.zeros(4, 1), + torch.zeros(4, 3), + name="MyLoss", + prediction_key="p", + target_key="t", + strict=True, + ) + + def test_assert_same_shape_strict_accepts_equal(self) -> None: + # Strict policy must accept exact-equal shapes. + assert_same_shape( + torch.zeros(4, 3), + torch.zeros(4, 3), + name="MyLoss", + strict=True, + ) + + def test_assert_same_shape_strict_still_checks_dtype_first(self) -> None: + # Strict policy shares the dtype-before-shape ordering. + with pytest.raises(ValueError, match="dtype mismatch"): + assert_same_shape( + torch.zeros(4, 1, dtype=torch.float32), + torch.zeros(4, 3, dtype=torch.float64), + name="MyLoss", + strict=True, + ) diff --git a/test/training/test_mixed_precision.py b/test/training/test_mixed_precision.py new file mode 100644 index 00000000..6b7c44b9 --- /dev/null +++ b/test/training/test_mixed_precision.py @@ -0,0 +1,675 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for :class:`nvalchemi.training.hooks.MixedPrecisionHook`.""" + +from __future__ import annotations + +from collections.abc import Callable +from enum import Enum +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +import pytest +import torch +from pydantic import ValidationError + +from nvalchemi.data import Batch +from nvalchemi.hooks._context import HookContext, TrainContext +from nvalchemi.hooks._protocol import Hook +from nvalchemi.models.base import BaseModelMixin +from nvalchemi.training import EnergyMSELoss, ForceMSELoss +from nvalchemi.training._stages import TrainingStage +from nvalchemi.training.hooks import ( + MixedPrecisionHook, + TrainingUpdateHook, + TrainingUpdateOrchestrator, +) +from nvalchemi.training.hooks.mixed_precision import MixedPrecisionHook as _MP +from nvalchemi.training.optimizers import OptimizerConfig +from nvalchemi.training.strategy import TrainingStrategy, default_training_fn +from test.training.conftest import _build_demo_model + +# --------------------------------------------------------------------------- +# Shared fixtures / helpers +# --------------------------------------------------------------------------- + +ALL_PRECISIONS: list[torch.dtype] = [torch.float32, torch.bfloat16, torch.float16] + + +def _cast_back_training_fn( + model: BaseModelMixin, batch: Batch +) -> dict[str, torch.Tensor]: + """Forward the model and cast predictions back to fp32. + + Autocast casts eligible ops to lower precision inside its region, but + the project's loss terms enforce dtype parity with the ``batch.*`` + targets (which remain fp32). Casting predictions back restores that + parity while still exercising the autocast-covered forward pass. 
+ """ + preds = default_training_fn(model, batch) + return {k: v.to(torch.float32) for k, v in preds.items()} + + +@pytest.fixture(params=ALL_PRECISIONS, ids=lambda p: str(p).replace("torch.", "")) +def precision(request: pytest.FixtureRequest) -> torch.dtype: + """Parametrize over the three supported AMP precisions.""" + return request.param + + +@pytest.fixture +def strategy_factory( + baseline_strategy_kwargs: dict[str, Any], +) -> Callable[..., TrainingStrategy]: + """Return a factory that builds a strategy with the cast-back training_fn. + + The factory is kept because several tests call it with varying ``hooks`` + and ``devices`` kwargs; eliminating it would repeat the same merge + pattern at each site. ``overrides`` are merged last, so they win. + """ + + def _factory(**overrides: Any) -> TrainingStrategy: + kwargs = { + **baseline_strategy_kwargs, + "training_fn": _cast_back_training_fn, + **overrides, + } + return TrainingStrategy(**kwargs) + + return _factory + + +@pytest.fixture +def mocked_scaler() -> Any: + """Patch ``torch.amp.GradScaler`` and yield the mock scaler instance. + + The scaler reports a healthy step (no inf, constant scale) and returns + a ``MagicMock`` from ``scale()`` so ``backward`` on the scaled loss is + observable by the tests. + """ + with patch("torch.amp.GradScaler", autospec=True) as scaler_cls: + scaler = scaler_cls.return_value + scaler.get_scale.return_value = 65536.0 + scaler._found_inf_per_device.return_value = { + torch.device("cpu"): torch.tensor(0.0) + } + scaler.scale.return_value = MagicMock(name="scaled_loss") + yield scaler + + +class _ObserverHook: + """Observer hook that forwards ``(ctx, stage)`` to ``callback`` at ``stage``. + + Attributes + ---------- + stage : TrainingStage + The stage on which the registry should dispatch this hook. + frequency : int + Fixed to ``1``. 
+ """ + + def __init__( + self, + stage: TrainingStage, + callback: Callable[[HookContext, Enum], None], + ) -> None: + self.stage = stage + self.frequency = 1 + self._callback = callback + + def __call__(self, ctx: HookContext, stage: Enum) -> None: + self._callback(ctx, stage) + + +class _ClaimsStagesHook: + """Observer hook that opts into one or more stages via ``_runs_on_stage``. + + Used by exclusivity tests to collide with :class:`MixedPrecisionHook` + on the ``DO_BACKWARD`` stage. + + Attributes + ---------- + stage : None + Explicitly ``None`` — stage selection is delegated to ``_runs_on_stage``. + frequency : int + Fixed to ``1``. + """ + + def __init__(self, claimed: set[TrainingStage]) -> None: + self._claimed = set(claimed) + self.stage = None + self.frequency = 1 + + def _runs_on_stage(self, stage: Enum) -> bool: + return stage in self._claimed + + def __call__(self, ctx: HookContext, stage: Enum) -> None: # noqa: ARG002 + pass + + +class _OptimizerStepVetoHook(TrainingUpdateHook): + """Update hook that vetoes optimizer stepping before AMP unscales grads.""" + + priority = 10 + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, # noqa: ARG002 + ) -> tuple[bool, torch.Tensor]: + return stage is not TrainingStage.DO_OPTIMIZER_STEP, ctx.loss + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +class TestConstruction: + """Constructor validation and Hook-Protocol attribute defaults.""" + + def test_invalid_precision_rejected(self) -> None: + with pytest.raises(ValidationError): + MixedPrecisionHook(precision="fp8") # type: ignore[arg-type] + + @pytest.mark.parametrize( + "bad", [torch.float64, "float64"], ids=["float64_dtype", "float64_str"] + ) + def test_unsupported_dtype_rejected(self, bad: Any) -> None: + with pytest.raises(ValidationError): + MixedPrecisionHook(precision=bad) + + def 
test_precision_accepts_dtype_object(self) -> None: + assert MixedPrecisionHook(precision=torch.float16).precision == torch.float16 + + def test_precision_accepts_canonical_string(self) -> None: + assert MixedPrecisionHook(precision="bfloat16").precision == torch.bfloat16 + + @pytest.mark.parametrize( + ("alias", "expected"), + [ + ("fp32", torch.float32), + ("bf16", torch.bfloat16), + ("fp16", torch.float16), + ], + ) + def test_precision_accepts_common_aliases( + self, alias: str, expected: torch.dtype + ) -> None: + assert MixedPrecisionHook(precision=alias).precision == expected + + def test_unknown_config_key_rejected(self) -> None: + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + MixedPrecisionHook(precision="float16", precison="float32") # type: ignore[call-arg] + + def test_invalid_precision_message_lists_supported_values(self) -> None: + with pytest.raises( + ValidationError, + match="MixedPrecisionHook.precision must be one of " + r"\(float32, bfloat16, float16\)", + ): + MixedPrecisionHook(precision="fp64") # type: ignore[arg-type] + + def test_is_training_update_hook(self, precision: torch.dtype) -> None: + hook = MixedPrecisionHook(precision=precision) + assert isinstance(hook, TrainingUpdateHook) + assert not isinstance(hook, Hook) + assert hook.priority == 20 + + def test_class_identity_across_module_paths(self) -> None: + # Both import paths must resolve to the same class. 
+ assert MixedPrecisionHook is _MP + + def test_strategy_autowraps_into_update_orchestrator( + self, strategy_factory: Callable[..., TrainingStrategy] + ) -> None: + hook = MixedPrecisionHook(precision=torch.float32) + strategy = strategy_factory(hooks=[hook]) + assert len(strategy.hooks) == 1 + assert isinstance(strategy.hooks[0], TrainingUpdateOrchestrator) + assert strategy.hooks[0]._hooks == [hook] + + +# --------------------------------------------------------------------------- +# Stage dispatch +# --------------------------------------------------------------------------- + + +class TestStageDispatch: + """Update-hook fallback keeps unhandled stages a silent no-op.""" + + def test_unclaimed_stage_is_noop(self) -> None: + hook = MixedPrecisionHook(precision=torch.float32) + ctx = Mock(spec=TrainContext) + ctx.loss = Mock(name="loss") + proceed, loss = hook(ctx, TrainingStage.AFTER_BATCH, will_skip=False) + assert proceed is True + assert loss is ctx.loss + assert hook._scaler is None + assert hook._autocast_ctx is None + assert hook._active is False + + +# --------------------------------------------------------------------------- +# Core training (precision × device) +# --------------------------------------------------------------------------- + + +class TestCoreTraining: + """One-step training with the hook enabled under every precision / device. + + Covers autocast state at ``BEFORE_FORWARD`` and ``AFTER_BACKWARD``, the + fp32 bypass, clean completion on the injected ``device``, and the + absence of ``MixedPrecisionHook``-originated warnings. 
+ """ + + def test_one_step_completes_cleanly( + self, + precision: torch.dtype, + device: str, + strategy_factory: Callable[..., TrainingStrategy], + batch: Batch, + recwarn: pytest.WarningsRecorder, + ) -> None: + mp = MixedPrecisionHook(precision=precision) + strategy = strategy_factory(hooks=[mp], devices=[torch.device(device)]) + strategy.run([batch]) + assert strategy.step_count == 1 + assert all("MixedPrecisionHook" not in str(w.message) for w in recwarn.list), [ + str(w.message) for w in recwarn.list + ] + + def test_autocast_state_during_forward( + self, + precision: torch.dtype, + device: str, + strategy_factory: Callable[..., TrainingStrategy], + batch: Batch, + ) -> None: + records: dict[str, Any] = {} + + def _observe(ctx: HookContext, stage: Enum) -> None: # noqa: ARG001 + records["enabled"] = torch.is_autocast_enabled(device) + records["dtype"] = torch.get_autocast_dtype(device) + + mp = MixedPrecisionHook(precision=precision) + observer = _ObserverHook(TrainingStage.BEFORE_FORWARD, _observe) + strategy = strategy_factory( + hooks=[mp, observer], devices=[torch.device(device)] + ) + strategy.run([batch]) + # fp32 bypasses autocast entirely; low-precision modes enable autocast + # with the matching dtype. 
+ expected_enabled = precision != torch.float32 + assert records["enabled"] is expected_enabled + if expected_enabled: + assert records["dtype"] == precision + + def test_autocast_disabled_after_backward( + self, + precision: torch.dtype, + device: str, + strategy_factory: Callable[..., TrainingStrategy], + batch: Batch, + ) -> None: + records: dict[str, bool] = {} + + def _observe(ctx: HookContext, stage: Enum) -> None: # noqa: ARG001 + records["enabled"] = torch.is_autocast_enabled(device) + + mp = MixedPrecisionHook(precision=precision) + observer = _ObserverHook(TrainingStage.AFTER_BACKWARD, _observe) + strategy = strategy_factory( + hooks=[mp, observer], devices=[torch.device(device)] + ) + strategy.run([batch]) + assert records["enabled"] is False + + def test_fp32_precision_does_not_create_amp_state( + self, + device: str, + strategy_factory: Callable[..., TrainingStrategy], + batch: Batch, + ) -> None: + records: dict[str, Any] = {} + + mp = MixedPrecisionHook(precision=torch.float32) + + def _observe(ctx: HookContext, stage: Enum) -> None: # noqa: ARG001 + records["scaler"] = mp._scaler + records["autocast_ctx"] = mp._autocast_ctx + records["active"] = mp._active + + observer = _ObserverHook(TrainingStage.BEFORE_FORWARD, _observe) + strategy = strategy_factory( + hooks=[mp, observer], devices=[torch.device(device)] + ) + strategy.run([batch]) + + assert records == {"scaler": None, "autocast_ctx": None, "active": False} + + +# --------------------------------------------------------------------------- +# GradScaler behavior (mocked): call order + multi-optimizer + scheduler gating +# --------------------------------------------------------------------------- + + +class TestGradScalerBehavior: + """fp16 drives ``GradScaler`` in canonical order and gates schedulers (reqs 10, 25-27).""" + + def test_scaler_call_order( + self, + mocked_scaler: Any, + strategy_factory: Callable[..., TrainingStrategy], + batch: Batch, + ) -> None: + scaled_loss = 
mocked_scaler.scale.return_value + mp = MixedPrecisionHook(precision=torch.float16) + strategy = strategy_factory(hooks=[mp]) + strategy.run([batch]) + + names = [name for name, _, _ in mocked_scaler.method_calls] + assert scaled_loss.backward.called + assert names.index("scale") < names.index("unscale_") + assert names.index("unscale_") < names.index("step") + assert names.index("step") < names.index("update") + assert names.count("scale") == 1 + assert names.count("unscale_") == 1 + assert names.count("step") == 1 + assert names.count("update") == 1 + + def test_multi_optimizer_unscale_and_step( + self, mocked_scaler: Any, batch: Batch + ) -> None: + model = _build_demo_model() + params = list(model.parameters()) + half = len(params) // 2 + group_a, group_b = params[:half], params[half:] + mp_hook = MixedPrecisionHook(precision=torch.float16) + strategy = TrainingStrategy( + models=model, + optimizer_configs=[ + OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3, "foreach": False}, + ), + ], + num_epochs=1, + training_fn=_cast_back_training_fn, + loss_fn=EnergyMSELoss() + ForceMSELoss(normalize_by_atom_count=True), + hooks=[mp_hook], + ) + # Replace the built optimizer list with two optimizers over disjoint + # params — more direct than threading multiple configs/models. 
+ opt_a = torch.optim.Adam(group_a, lr=1e-3) + opt_b = torch.optim.Adam(group_b, lr=1e-3) + with strategy: + strategy._train_batch_with_optimizers(batch, [opt_a, opt_b], [None, None]) + + names = [name for name, _, _ in mocked_scaler.method_calls] + assert names.count("unscale_") == 2 + assert names.count("step") == 2 + assert names.count("update") == 1 + first_step_idx = names.index("step") + last_unscale_idx = max(i for i, n in enumerate(names) if n == "unscale_") + assert last_unscale_idx < first_step_idx + + def test_vetoed_optimizer_step_does_not_unscale_or_update_scaler( + self, + mocked_scaler: Any, + batch: Batch, + ) -> None: + scaled_loss = mocked_scaler.scale.return_value + param = torch.nn.Parameter(torch.ones(())) + opt = torch.optim.SGD([param], lr=1.0) + workflow = Mock() + workflow.devices = [torch.device("cpu")] + ctx = TrainContext( + batch=batch, + workflow=workflow, + loss=param.square(), + optimizers=[opt], + lr_schedulers=[], + ) + mp = MixedPrecisionHook(precision=torch.float16) + orch = TrainingUpdateOrchestrator(_OptimizerStepVetoHook(), mp) + with orch: + orch(ctx, TrainingStage.DO_BACKWARD) + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + + assert scaled_loss.backward.called + mocked_scaler.unscale_.assert_not_called() + mocked_scaler.step.assert_not_called() + mocked_scaler.update.assert_not_called() + + def test_grad_scaler_no_scheduler_fast_path_skips_found_inf_query( + self, mocked_scaler: Any + ) -> None: + param = torch.nn.Parameter(torch.ones(())) + opt = torch.optim.SGD([param], lr=1.0) + ctx = TrainContext( + batch=Mock(spec=Batch), + optimizers=[opt], + lr_schedulers=[], + grad_scaler=mocked_scaler, + ) + orch = TrainingUpdateOrchestrator() + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + + mocked_scaler.step.assert_called_once_with(opt) + mocked_scaler.update.assert_called_once() + mocked_scaler._found_inf_per_device.assert_not_called() + + @pytest.mark.parametrize( + ("found_inf", "expected_step_called"), + [(0.0, True), (1.0, 
False)], + ids=["no_inf_steps_sched", "found_inf_skips_sched"], + ) + def test_scheduler_gating( + self, + found_inf: float, + expected_step_called: bool, + strategy_factory: Callable[..., TrainingStrategy], + batch: Batch, + ) -> None: + with patch("torch.amp.GradScaler", autospec=True) as scaler_cls: + scaler = scaler_cls.return_value + scaler.get_scale.return_value = 65536.0 + scaler._found_inf_per_device.return_value = { + torch.device("cpu"): torch.tensor(found_inf) + } + scaler.scale.return_value = MagicMock(name="scaled_loss") + + sched = MagicMock(name="sched") + mp = MixedPrecisionHook(precision=torch.float16) + strategy = strategy_factory(hooks=[mp]) + opt = torch.optim.Adam(strategy.models["main"].parameters(), lr=1e-3) + with strategy: + strategy._train_batch_with_optimizers(batch, [opt], [sched]) + + if expected_step_called: + sched.step.assert_called_once() + else: + sched.step.assert_not_called() + + +# --------------------------------------------------------------------------- +# Real CUDA end-to-end (no mock) — bf16 / fp16 +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +class TestCUDAEndToEnd: + """Real autocast + real ``GradScaler`` drive a full step without error.""" + + @pytest.mark.parametrize( + "cuda_precision", + [torch.bfloat16, torch.float16], + ids=["bf16", "fp16"], + ) + def test_single_step_runs_cleanly( + self, cuda_precision: torch.dtype, batch: Batch + ) -> None: + device = torch.device("cuda:0") + model = _build_demo_model() + mp = MixedPrecisionHook(precision=cuda_precision) + observed: dict[str, Any] = {} + + def _capture_forward(ctx: HookContext, stage: Enum) -> None: # noqa: ARG001 + observed["autocast_enabled"] = torch.is_autocast_enabled("cuda") + observed["autocast_dtype"] = torch.get_autocast_dtype("cuda") + + def _capture_after_step(ctx: HookContext, stage: Enum) -> None: # noqa: ARG001 + # Capture after the 
update step; the scaler persists across batches + # until ``__exit__`` resets it. + if mp._scaler is not None: + observed["scale"] = mp._scaler.get_scale() + + forward_hook = _ObserverHook(TrainingStage.BEFORE_FORWARD, _capture_forward) + after_hook = _ObserverHook( + TrainingStage.AFTER_OPTIMIZER_STEP, _capture_after_step + ) + strategy = TrainingStrategy( + models=model, + optimizer_configs=OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + ), + num_epochs=1, + training_fn=_cast_back_training_fn, + loss_fn=EnergyMSELoss() + ForceMSELoss(normalize_by_atom_count=True), + devices=[device], + hooks=[mp, forward_hook, after_hook], + ) + strategy.run([batch]) + + assert strategy.step_count == 1 + assert observed["autocast_enabled"] is True + assert observed["autocast_dtype"] == cuda_precision + if cuda_precision == torch.float16: + assert "scale" in observed + assert torch.isfinite(torch.tensor(observed["scale"])) + + def test_real_fp16_overflow_skips_optimizer_and_scheduler(self) -> None: + device = torch.device("cuda:0") + param = torch.nn.Parameter(torch.ones((), device=device)) + opt = torch.optim.SGD([param], lr=1.0) + sched = MagicMock(name="sched") + mp = MixedPrecisionHook(precision=torch.float16) + orch = TrainingUpdateOrchestrator(mp) + workflow = Mock() + workflow.devices = [device] + ctx = TrainContext( + batch=Mock(spec=Batch), + workflow=workflow, + loss=param * torch.tensor(float("inf"), device=device), + optimizers=[opt], + lr_schedulers=[sched], + ) + + try: + param_before = param.detach().clone() + orch(ctx, TrainingStage.DO_BACKWARD) + assert mp._scaler is not None + scale_before = mp._scaler.get_scale() + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + torch.testing.assert_close(param.detach(), param_before) + sched.step.assert_not_called() + assert mp._scaler.get_scale() < scale_before + finally: + orch.__exit__(None, None, None) + + +# --------------------------------------------------------------------------- +# 
DO_ stage exclusivity (integration) +# --------------------------------------------------------------------------- + + +class TestDOStageExclusivity: + """Conflicting hook sets around MixedPrecisionHook raise ``ValueError`` when the strategy is built.""" + + def test_two_mp_hooks_rejected_to_prevent_double_scaling( + self, strategy_factory: Callable[..., TrainingStrategy] + ) -> None: + first = MixedPrecisionHook(precision=torch.float32) + second = MixedPrecisionHook(precision=torch.bfloat16) + with pytest.raises(ValueError, match="MixedPrecisionHook"): + strategy_factory(hooks=[first, second]) + + def test_mp_plus_other_do_backward_claimant_rejected( + self, strategy_factory: Callable[..., TrainingStrategy] + ) -> None: + with pytest.raises(ValueError, match="DO_BACKWARD"): + strategy_factory( + hooks=[ + MixedPrecisionHook(precision=torch.float32), + _ClaimsStagesHook({TrainingStage.DO_BACKWARD}), + ] + ) + + +# --------------------------------------------------------------------------- +# Live-vs-detached loss contract +# --------------------------------------------------------------------------- + + +class TestLiveDetachedLossContract: + """The live-before-backward / detached-after-backward invariant holds (req 22).""" + + def test_loss_graph_state_around_backward( + self, + strategy_factory: Callable[..., TrainingStrategy], + batch: Batch, + ) -> None: + records: dict[TrainingStage, bool] = {} + + def _record(ctx: HookContext, stage: TrainingStage) -> None: + records[stage] = ctx.loss.grad_fn is not None + + hooks = [ + MixedPrecisionHook(precision=torch.float32), + _ObserverHook(TrainingStage.BEFORE_BACKWARD, _record), + _ObserverHook(TrainingStage.AFTER_BACKWARD, _record), + ] + strategy = strategy_factory(hooks=hooks) + strategy.run([batch]) + assert records[TrainingStage.BEFORE_BACKWARD] is True + assert records[TrainingStage.AFTER_BACKWARD] is False + + +# --------------------------------------------------------------------------- +# 
zero_grad(set_to_none=True) regression +# 
--------------------------------------------------------------------------- + + +class TestZeroGradSetToNone: + """Regression: optimizers are zeroed with ``set_to_none=True`` (req 28).""" + + def test_zero_grad_called_with_set_to_none_true( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + captured_kwargs: list[dict[str, Any]] = [] + original = torch.optim.Adam.zero_grad  # unpatched method; the spy delegates to it + + def _spy(self: torch.optim.Adam, **kwargs: Any) -> None: + captured_kwargs.append(dict(kwargs)) + original(self, **kwargs) + + mp = MixedPrecisionHook(precision=torch.float32) + strategy = TrainingStrategy(**{**baseline_strategy_kwargs, "hooks": [mp]}) + with patch.object(torch.optim.Adam, "zero_grad", _spy): + strategy.run([batch]) + assert captured_kwargs, "zero_grad was never called" + for kw in captured_kwargs: + assert kw.get("set_to_none") is True diff --git a/test/training/test_optimizers.py b/test/training/test_optimizers.py new file mode 100644 index 00000000..bd123e3c --- /dev/null +++ b/test/training/test_optimizers.py @@ -0,0 +1,340 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for optimizer configuration and stepping helpers.""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import patch + +import pytest +import torch +from torch import nn + +from nvalchemi.training import register_type_serializer +from nvalchemi.training._spec import create_model_spec_from_json +from nvalchemi.training.optimizers import ( + OptimizerConfig, + _extract_scheduler_metric, + setup_optimizers, + step_lr_schedulers, + step_metric_schedulers, + step_optimizers, + zero_gradients, +) + + +class _CustomPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): + pass + + +_OPTIMIZER_CONFIG_REJECTION_CASES: list[tuple[str, dict[str, Any]]] = [ + ( + "Invalid optimizer kwargs", + { + "optimizer_cls": torch.optim.Adam, + "optimizer_kwargs": {"bogus_kwarg": 0.1}, + }, + ), + ( + "scheduler_kwargs", + { + "optimizer_cls": torch.optim.Adam, + "optimizer_kwargs": {"lr": 1e-3}, + "scheduler_cls": None, + "scheduler_kwargs": {"step_size": 10}, + }, + ), +] + + +class TestOptimizerConfig: + """Building, validation, and spec round-tripping of ``OptimizerConfig``.""" + + def test_public_type_serializer_export_available(self) -> None: + assert callable(register_type_serializer) + + def test_build_adam_no_scheduler(self) -> None: + layer = nn.Linear(4, 2) + cfg = OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + ) + optimizer, scheduler = cfg.build(layer.parameters()) + assert isinstance(optimizer, torch.optim.Adam) + assert scheduler is None + + def test_build_with_step_lr(self) -> None: + layer = nn.Linear(4, 2) + cfg = OptimizerConfig( + optimizer_cls=torch.optim.SGD, + optimizer_kwargs={"lr": 0.1}, + scheduler_cls=torch.optim.lr_scheduler.StepLR, + scheduler_kwargs={"step_size": 10, "gamma": 0.5}, + ) + optimizer, scheduler = cfg.build(layer.parameters()) + assert isinstance(optimizer, torch.optim.SGD) + assert isinstance(scheduler, torch.optim.lr_scheduler.StepLR) + + def 
test_class_fields_accept_dotted_paths(self) -> None: + cfg = OptimizerConfig( + 
optimizer_cls="torch.optim.sgd.SGD", + scheduler_cls="torch.optim.lr_scheduler.StepLR", + scheduler_kwargs={"step_size": 2}, + ) + assert cfg.optimizer_cls is torch.optim.SGD + assert cfg.scheduler_cls is torch.optim.lr_scheduler.StepLR + + @pytest.mark.parametrize( + "kwargs", + [ + {"optimizer_cls": "not.a.real.Optimizer"}, + { + "optimizer_cls": torch.optim.Adam, + "scheduler_cls": "not.a.real.Scheduler", + }, + ], + ids=["bad_optimizer_cls", "bad_scheduler_cls"], + ) + def test_class_fields_reject_bad_dotted_paths(self, kwargs: dict[str, Any]) -> None: + with pytest.raises(ValueError, match="must resolve to an importable class"): + OptimizerConfig(**kwargs) + + @pytest.mark.parametrize( + ("match", "kwargs"), + _OPTIMIZER_CONFIG_REJECTION_CASES, + ids=[ + "invalid_optimizer_kwarg", + "orphan_scheduler_kwargs", + ], + ) + def test_invalid_config_rejected(self, match: str, kwargs: dict[str, Any]) -> None: + with pytest.raises(ValueError, match=match): + OptimizerConfig(**kwargs) + + def test_to_spec_from_spec_roundtrip(self) -> None: + cfg = OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3, "betas": (0.9, 0.95)}, + scheduler_cls=torch.optim.lr_scheduler.StepLR, + scheduler_kwargs={"step_size": 5, "gamma": 0.1}, + ) + spec = cfg.to_spec() + restored = OptimizerConfig.from_spec(spec) + assert restored.optimizer_cls is torch.optim.Adam + assert restored.optimizer_kwargs["lr"] == pytest.approx(1e-3) + assert restored.scheduler_cls is torch.optim.lr_scheduler.StepLR + assert restored.scheduler_kwargs == {"step_size": 5, "gamma": 0.1} + + def test_json_roundtrip_via_spec(self) -> None: + cfg = OptimizerConfig( + optimizer_cls=torch.optim.SGD, + optimizer_kwargs={"lr": 0.01, "momentum": 0.9}, + ) + spec = cfg.to_spec() + spec_json = spec.model_dump_json() + spec_back = create_model_spec_from_json(json.loads(spec_json)) + restored = OptimizerConfig.from_spec(spec_back) + assert restored.optimizer_cls is torch.optim.SGD + assert 
restored.optimizer_kwargs == {"lr": 0.01, "momentum": 0.9} + assert restored.scheduler_cls is None + + +class TestOptimizerHelpers: + """``setup_optimizers`` plus the zero-grad / step helper functions.""" + + def test_setup_optimizers_returns_opt_sched_pairs(self) -> None: + model = nn.Linear(4, 2) + pairs = setup_optimizers( + model, + OptimizerConfig(optimizer_cls=torch.optim.Adam), + ) + assert set(pairs.keys()) == {"main"} + assert len(pairs["main"]) == 1 + optimizer, scheduler = pairs["main"][0] + assert isinstance(optimizer, torch.optim.Adam) + assert scheduler is None + + def test_setup_optimizers_subset_of_models(self) -> None: + student = nn.Linear(4, 2) + teacher = nn.Linear(4, 2) + pairs = setup_optimizers( + {"student": student, "teacher": teacher}, + {"student": [OptimizerConfig(optimizer_cls=torch.optim.Adam)]}, + ) + assert set(pairs) == {"student"} + + def test_setup_optimizers_accepts_moduledict_models(self) -> None: + models = nn.ModuleDict( + { + "student": nn.Linear(4, 2), + "teacher": nn.Linear(4, 2), + } + ) + pairs = setup_optimizers( + models, + {"student": [OptimizerConfig(optimizer_cls=torch.optim.Adam)]}, + ) + + assert set(pairs) == {"student"} + + def test_setup_optimizers_invalid_key_raises(self) -> None: + with pytest.raises(ValueError, match="not present in models"): + setup_optimizers( + {"student": nn.Linear(4, 2)}, + {"teacher": [OptimizerConfig(optimizer_cls=torch.optim.Adam)]}, + ) + + def test_setup_optimizers_frozen_model_raises(self) -> None: + model = nn.Linear(4, 2) + for param in model.parameters(): + param.requires_grad_(False) + with pytest.raises(ValueError, match="no trainable parameters"): + setup_optimizers(model, OptimizerConfig(optimizer_cls=torch.optim.Adam)) + + def test_zero_gradients_zeroes_all_optimizers(self) -> None: + layer_a = nn.Linear(2, 2) + layer_b = nn.Linear(3, 3) + opt_a = torch.optim.SGD(layer_a.parameters(), lr=0.1) + opt_b = torch.optim.SGD(layer_b.parameters(), lr=0.1) + layer_a.weight.grad = 
torch.ones_like(layer_a.weight) + layer_b.weight.grad = 
torch.ones_like(layer_b.weight) + zero_gradients([opt_a, opt_b]) + assert layer_a.weight.grad is None  # grads are cleared to None, not zero-filled + assert layer_b.weight.grad is None + + def test_step_optimizers_advances_params(self) -> None: + torch.manual_seed(0) + layer = nn.Linear(2, 1) + opt = torch.optim.SGD(layer.parameters(), lr=0.1) + before = layer.weight.detach().clone() + layer.weight.grad = torch.ones_like(layer.weight) + step_optimizers([opt]) + assert not torch.equal(before, layer.weight.detach()) + + def test_step_lr_schedulers_skips_none(self) -> None: + layer = nn.Linear(2, 1) + opt = torch.optim.SGD(layer.parameters(), lr=1.0) + sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5) + before_lr = sched.get_last_lr()[0] + step_lr_schedulers([None, sched, None])  # None entries must be tolerated + after_lr = sched.get_last_lr()[0] + assert after_lr == pytest.approx(before_lr * 0.5) + + +class TestMetricDrivenSchedulers: + """Phase D: metric-driven (ReduceLROnPlateau) scheduler support.""" + + @staticmethod + def _make_plateau( + lr: float = 0.1, patience: int = 1, factor: float = 0.5 + ) -> tuple[ + nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler.ReduceLROnPlateau + ]: + """Return a (layer, optimizer, ReduceLROnPlateau) triple.""" + layer = nn.Linear(2, 1) + opt = torch.optim.SGD(layer.parameters(), lr=lr) + plateau = torch.optim.lr_scheduler.ReduceLROnPlateau( + opt, patience=patience, factor=factor + ) + return layer, opt, plateau + + def test_reduce_lr_on_plateau_now_accepted(self) -> None: + """OptimizerConfig no longer rejects ReduceLROnPlateau.""" + cfg = OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau, + scheduler_kwargs={"patience": 5}, + ) + assert cfg.scheduler_cls is torch.optim.lr_scheduler.ReduceLROnPlateau + + def test_reduce_lr_on_plateau_subclass_accepted(self) -> None: + """OptimizerConfig also accepts 
ReduceLROnPlateau subclasses.""" + cfg = OptimizerConfig( + 
optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + scheduler_cls=_CustomPlateau, + ) + assert cfg.scheduler_cls is _CustomPlateau + + def test_scheduler_metric_adapter_requires_scheduler_cls(self) -> None: + """scheduler_metric_adapter without scheduler_cls raises ValueError.""" + with pytest.raises(ValueError, match="scheduler_metric_adapter provided"): + OptimizerConfig( + optimizer_cls=torch.optim.Adam, + scheduler_metric_adapter="total_loss", + ) + + def test_step_lr_schedulers_skips_metric_driven(self) -> None: + """step_lr_schedulers does not call step() on ReduceLROnPlateau.""" + _, _, plateau = self._make_plateau() + layer2 = nn.Linear(2, 1) + opt2 = torch.optim.SGD(layer2.parameters(), lr=1.0) + steplr = torch.optim.lr_scheduler.StepLR(opt2, step_size=1, gamma=0.5) + + steplr_epoch_before = steplr.last_epoch + with patch.object(plateau, "step", wraps=plateau.step) as mock_plateau_step: + step_lr_schedulers([plateau, steplr]) + + mock_plateau_step.assert_not_called() + assert steplr.last_epoch == steplr_epoch_before + 1 + + def test_step_metric_schedulers_str_adapter(self) -> None: + """step_metric_schedulers with a str adapter passes the right value.""" + _, opt, plateau = self._make_plateau() + summary = {"my_loss": torch.tensor(0.42), "other": 99} + with patch.object(plateau, "step", wraps=plateau.step) as mock_step: + step_metric_schedulers([plateau], ["my_loss"], summary) + mock_step.assert_called_once() + arg = mock_step.call_args[0][0] + assert arg == pytest.approx(0.42) + + def test_step_metric_schedulers_callable_adapter(self) -> None: + """step_metric_schedulers with a callable adapter.""" + _, opt, plateau = self._make_plateau() + summary = {"nested": {"val": 1.23}} + adapter = lambda s: s["nested"]["val"] # noqa: E731 + with patch.object(plateau, "step", wraps=plateau.step) as mock_step: + step_metric_schedulers([plateau], [adapter], summary) + mock_step.assert_called_once() + assert mock_step.call_args[0][0] == 
pytest.approx(1.23) + + def test_step_metric_schedulers_default_adapter(self) -> None: + """step_metric_schedulers with adapter=None uses 'total_loss' key.""" + _, opt, plateau = self._make_plateau() + summary = { + "name": "validation", + "total_loss": torch.tensor(0.55), + "per_component_unweighted": {}, + } + with patch.object(plateau, "step", wraps=plateau.step) as mock_step: + step_metric_schedulers([plateau], [None], summary) + mock_step.assert_called_once() + assert mock_step.call_args[0][0] == pytest.approx(0.55) + + def test_step_metric_schedulers_skips_non_metric(self) -> None: + """step_metric_schedulers skips time-based schedulers and None.""" + layer = nn.Linear(2, 1) + opt = torch.optim.SGD(layer.parameters(), lr=1.0) + steplr = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5) + epoch_before = steplr.last_epoch + summary = {"total_loss": torch.tensor(0.5)} + step_metric_schedulers([None, steplr], [None, None], summary) + # StepLR should NOT have been stepped (it's not metric-driven) + assert steplr.last_epoch == epoch_before + + def test_extract_scheduler_metric_missing_key_raises(self) -> None: + """_extract_scheduler_metric raises KeyError for absent str key.""" + summary = {"a": 1.0, "b": 2.0} + with pytest.raises(KeyError, match="not_here"): + _extract_scheduler_metric(summary, "not_here") diff --git a/test/training/test_runtime.py b/test/training/test_runtime.py new file mode 100644 index 00000000..c78b698a --- /dev/null +++ b/test/training/test_runtime.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for training runtime helpers.""" + +from __future__ import annotations + +import pytest +import torch +from torch import nn +from torch.utils.data import SequentialSampler + +from nvalchemi.training.runtime import ( + configure_dataloader, + freeze_unconfigured_models, + move_to_devices, +) + + +class TestRuntimeHelpers: + @pytest.mark.parametrize("n_models", [1, 2], ids=["single_model", "two_models"]) + def test_move_to_devices_cpu(self, n_models: int) -> None: + models = {str(i): nn.Linear(4, 2) for i in range(n_models)} + devices = [torch.device("cpu")] + out = move_to_devices(models, devices) + assert len(out) == n_models + for m in out.values(): + assert next(m.parameters()).device.type == "cpu" + + def test_move_to_devices_moduledict_preserves_input_shape(self) -> None: + models = nn.ModuleDict({"a": nn.Linear(4, 2), "b": nn.Linear(4, 2)}) + out = move_to_devices(models, [torch.device("cpu")]) + assert out is models + assert list(out.keys()) == ["a", "b"] + for model in out.values(): + assert next(model.parameters()).device.type == "cpu" + + def test_configure_dataloader_supports_sampler(self) -> None: + dataset = [0, 1, 2] + loader = configure_dataloader( + dataset, + batch_size=1, + sampler=SequentialSampler(dataset), + ) + assert [int(batch.item()) for batch in loader] == dataset + + def test_configure_dataloader_sampler_shuffle_conflict(self) -> None: + dataset = [0, 1, 2] + with pytest.raises(ValueError, match="shuffle=True is incompatible"): + configure_dataloader( + dataset, + batch_size=1, + shuffle=True, + 
sampler=SequentialSampler(dataset), + ) + + def test_freeze_unconfigured_models_restores_state(self) -> None: + trained = nn.Linear(2, 1) + omitted = nn.Linear(2, 1) + omitted.eval() + params = list(omitted.parameters()) + params[0].requires_grad_(False) + initial_training = omitted.training + initial_requires_grad = [param.requires_grad for param in params] + with freeze_unconfigured_models( + {"trained": trained, "omitted": omitted}, {"trained": object()} + ): + assert omitted.training is False + assert [param.requires_grad for param in params] == [False] * len(params) + assert omitted.training is initial_training + assert [param.requires_grad for param in params] == initial_requires_grad + + def test_freeze_unconfigured_models_accepts_moduledict(self) -> None: + models = nn.ModuleDict({"trained": nn.Linear(2, 1), "omitted": nn.Linear(2, 1)}) + omitted = models["omitted"] + params = list(omitted.parameters()) + with freeze_unconfigured_models(models, {"trained": object()}): + assert omitted.training is False + assert [param.requires_grad for param in params] == [False] * len(params) + assert omitted.training is True + assert [param.requires_grad for param in params] == [True] * len(params) diff --git a/test/training/test_spec.py b/test/training/test_spec.py new file mode 100644 index 00000000..1d8053d2 --- /dev/null +++ b/test/training/test_spec.py @@ -0,0 +1,746 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for :mod:`nvalchemi.training._spec`.""" + +from __future__ import annotations + +import ast +import json +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.nn as nn + +from nvalchemi.training._spec import ( + _TYPE_SERIALIZERS, + BaseSpec, + _check_no_positional_only, + _dtype_deserialize, + _import_cls, + create_model_spec, + create_model_spec_from_json, + register_type_serializer, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures used across test classes +# --------------------------------------------------------------------------- + + +class _KwOnly: + """Class with ``**kwargs`` in its signature so unknown kwargs are allowed.""" + + def __init__(self, a: int = 1, **kwargs: Any) -> None: + self.a = a + self.kwargs = kwargs + + +class _AnnotatedDtype: + """Class whose ``dtype`` parameter is annotated, exercising the registry path.""" + + def __init__(self, dtype: torch.dtype = torch.float32) -> None: + self.dtype = dtype + + +class _WrapsModule(nn.Module): + """Class whose ``child`` param is annotated as ``nn.Module`` for nested-build.""" + + def __init__(self, child: nn.Module, scale: float = 1.0) -> None: + super().__init__() + self.child = child + self.scale = scale + + +class _WithDevice: + """Class whose ``device`` parameter is annotated.""" + + def __init__(self, device: torch.device = torch.device("cpu")) -> None: + self.device = device + + +class _WithTensor: + """Class holding a single ``torch.Tensor`` parameter.""" + + def __init__(self, weights: torch.Tensor) -> None: + self.weights = weights + + +class _WithTensorLinspace: + """Class holding a ``torch.Tensor`` named ``buf`` (distinct from _WithTensor).""" + + def __init__(self, buf: torch.Tensor) -> None: + self.buf = buf + + +class _WithStringPath: + """Class holding a string that may look 
like an importable class path.""" + + def __init__(self, label: str) -> None: + self.label = label + + +class _WithUnannotatedStringPath: + """Class holding an unannotated string-like parameter.""" + + def __init__(self, label="") -> None: + self.label = label + + +class _WithUnannotatedClassField: + """Class holding an unannotated class-valued parameter.""" + + def __init__(self, plugin_cls=None) -> None: + self.plugin_cls = plugin_cls + + +class _WithClassFields: + """Class holding parametrized and optional class annotations.""" + + def __init__( + self, + optimizer_cls: type[torch.optim.Optimizer], + scheduler_cls: type | None = None, + ) -> None: + self.optimizer_cls = optimizer_cls + self.scheduler_cls = scheduler_cls + + +class _FactoryBuilt: + """Object reconstructed through an importable classmethod factory.""" + + def __init__(self, value: int, device: torch.device, dtype: torch.dtype) -> None: + self.value = value + self.device = device + self.dtype = dtype + + @classmethod + def from_config( + cls, + value: int, + *, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype | None = None, + ) -> "_FactoryBuilt": + return cls(value=value, device=device, dtype=dtype or torch.float32) + + +class _TupleWrapsModules(nn.Module): + """Class whose tuple field is populated from nested specs.""" + + def __init__(self, children: tuple[nn.Module, ...]) -> None: + super().__init__() + self.children_tuple = children + + +def _make_positional_only_cls() -> type: + """Return a class with a positional-only ``__init__`` parameter. + + Using ``exec`` keeps the ``/`` syntax out of the top-level file body + while still producing a real class with positional-only params. 
+ """ + ns: dict[str, Any] = {} + src = ( + "class _PosOnly:\n" + " def __init__(self, x, /, y=0):\n" + " self.x = x\n" + " self.y = y\n" + ) + exec(src, ns) # noqa: S102 — deliberately constructs positional-only signature + return ns["_PosOnly"] + + +class Outer: + """Module-level host class for nested-class qualname resolution tests.""" + + class Inner: + """Nested class used to verify ``_import_cls`` handles nested qualnames.""" + + +# Prototype `main()` fixtures — used by TestFullPrototypeScenario. Defined at +# module level (not inside a test class) so their __qualname__ stays clean. + + +class MyBlock(nn.Module): + """Residual MLP block mirroring the example_spec.py prototype.""" + + def __init__( + self, + hidden_dim: int, + projection_dims: list[int], + dropout_p: float, + activation: nn.Module, + dtype: torch.dtype = torch.float32, + ) -> None: + super().__init__() + self.activation = activation + self.dropout = nn.Dropout(dropout_p) + dims = [hidden_dim, *projection_dims, hidden_dim] + self.projections = nn.ModuleList( + [nn.Linear(dims[i], dims[i + 1], dtype=dtype) for i in range(len(dims) - 1)] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply dropout+activation+projection residually.""" + residual = x + h = x + for proj in self.projections: + h = self.dropout(self.activation(proj(h))) + return h + residual + + +class MyMLIP(nn.Module): + """Toy MLIP with a residual block, feature-scale buffer, and output head.""" + + def __init__( + self, + hidden_dim: int, + cutoff: float, + block: MyBlock, + input_activation: nn.Module, + feature_scale: torch.Tensor, + dtype: torch.dtype = torch.float32, + ) -> None: + super().__init__() + if feature_scale.ndim != 1 or feature_scale.shape[0] != hidden_dim: + raise ValueError( + f"feature_scale must be 1D of length {hidden_dim}, " + f"got shape {tuple(feature_scale.shape)}" + ) + self.hidden_dim = hidden_dim + self.cutoff = cutoff + self.input_proj = nn.Linear(hidden_dim, hidden_dim, dtype=dtype) 
+ self.input_activation = input_activation + self.block = block + self.output_proj = nn.Linear(hidden_dim, 1, dtype=dtype) + self.register_buffer("feature_scale", feature_scale.to(dtype)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run the prototype forward graph.""" + h = self.input_activation(self.input_proj(x)) + h = h * self.feature_scale + h = self.block(h) + return self.output_proj(h) + + +# --------------------------------------------------------------------------- +# Test classes +# --------------------------------------------------------------------------- + + +class TestClsPathResolution: + """Resolution of dotted ``cls_path`` strings back to class objects.""" + + def test_resolves_simple(self) -> None: + assert _import_cls("torch.nn.Linear") is nn.Linear + + def test_resolves_deep_module_path(self) -> None: + # Multi-level module path: greedy prefix matching must import + # the longest valid module and walk the remainder as attributes. + assert ( + _import_cls("torch.nn.modules.activation.SiLU") + is nn.modules.activation.SiLU + ) + + def test_resolves_nested_qualname(self) -> None: + """Resolve a nested-class qualname via greedy module-prefix matching.""" + # ``Outer`` is defined at module scope in this test file, so the + # importable prefix is the test module and ``Outer.Inner`` is an + # attribute chain on it. + cls = _import_cls(f"{Outer.__module__}.Outer.Inner") + assert cls is Outer.Inner + + def test_raises_on_invalid_module(self) -> None: + with pytest.raises(ModuleNotFoundError): + _import_cls("definitely_not_a_real_module.SomeCls") + + def test_raises_on_invalid_attr(self) -> None: + with pytest.raises(AttributeError): + _import_cls("torch.nn.ThisAttrDoesNotExist") + + def test_raises_on_non_class_target(self) -> None: + # torch.tensor is a function, not a class. 
+ with pytest.raises(TypeError, match="non-class"): + _import_cls("torch.tensor") + + +class TestTypeSerializerRegistry: + """Round-trip behavior and security of the type-serializer registry.""" + + def test_register_replaces_existing(self) -> None: + original = _TYPE_SERIALIZERS[torch.dtype] + sentinel_ser = lambda d: "sentinel" # noqa: E731 + sentinel_deser = lambda s: torch.float32 # noqa: E731 + try: + register_type_serializer(torch.dtype, sentinel_ser, sentinel_deser) + ser, deser = _TYPE_SERIALIZERS[torch.dtype] + assert ser is sentinel_ser + assert deser is sentinel_deser + finally: + register_type_serializer(torch.dtype, original[0], original[1]) + # Restored: + assert _TYPE_SERIALIZERS[torch.dtype] == original + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.int64]) + def test_dtype_roundtrip(self, dtype: torch.dtype) -> None: + spec = create_model_spec(_AnnotatedDtype, dtype=dtype) + dumped = json.loads(spec.model_dump_json()) + rebuilt = create_model_spec_from_json(dumped) + assert rebuilt.dtype is dtype + + def test_dtype_raises_on_non_dtype_attr(self) -> None: + # torch.nn exists but is a module, not a dtype. The isinstance guard + # in _dtype_deserialize must reject it to block attr-smuggling. 
+ with pytest.raises(ValueError, match="does not resolve to a torch.dtype"): + _dtype_deserialize("nn") + + def test_dtype_raises_on_non_string(self) -> None: + with pytest.raises(TypeError, match="expected str"): + _dtype_deserialize(42) + + def test_device_roundtrip(self) -> None: + spec = create_model_spec(_WithDevice, device=torch.device("cpu")) + rebuilt = create_model_spec_from_json(json.loads(spec.model_dump_json())) + assert rebuilt.device == torch.device("cpu") + + def test_tensor_roundtrip(self) -> None: + t = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + spec = create_model_spec(_WithTensor, weights=t) + rebuilt = create_model_spec_from_json(json.loads(spec.model_dump_json())) + assert rebuilt.weights.shape == t.shape + assert rebuilt.weights.dtype == t.dtype + assert torch.equal(rebuilt.weights, t) + + def test_class_annotations_roundtrip(self) -> None: + spec = create_model_spec( + _WithClassFields, + optimizer_cls=torch.optim.Adam, + scheduler_cls=torch.optim.lr_scheduler.StepLR, + ) + dumped = json.loads(spec.model_dump_json()) + assert dumped["optimizer_cls"] == "torch.optim.adam.Adam" + assert dumped["scheduler_cls"] == "torch.optim.lr_scheduler.StepLR" + + rebuilt = create_model_spec_from_json(dumped) + assert rebuilt.optimizer_cls is torch.optim.Adam + assert rebuilt.scheduler_cls is torch.optim.lr_scheduler.StepLR + + def test_optional_class_annotation_roundtrip_none(self) -> None: + spec = create_model_spec( + _WithClassFields, + optimizer_cls=torch.optim.SGD, + scheduler_cls=None, + ) + rebuilt = create_model_spec_from_json(json.loads(spec.model_dump_json())) + assert rebuilt.optimizer_cls is torch.optim.SGD + assert rebuilt.scheduler_cls is None + + def test_unannotated_class_field_roundtrip(self) -> None: + spec = create_model_spec( + _WithUnannotatedClassField, + plugin_cls=torch.optim.Adam, + ) + dumped = json.loads(spec.model_dump_json()) + assert dumped["plugin_cls"] == {"__type__": "torch.optim.adam.Adam"} + + rebuilt = 
create_model_spec_from_json(dumped) + assert rebuilt.plugin_cls is torch.optim.Adam + assert rebuilt.build().plugin_cls is torch.optim.Adam + + +class TestSignatureIntrospection: + """Signature-level validation helpers.""" + + def test_positional_only_rejected(self) -> None: + cls_ = _make_positional_only_cls() + with pytest.raises(TypeError, match="positional-only"): + _check_no_positional_only(cls_) + + +class TestCreateModelSpec: + """Construction of a :class:`BaseSpec` via :func:`create_model_spec`.""" + + def test_creates_spec_with_cls_path_and_timestamp(self) -> None: + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + assert spec.cls_path == "torch.nn.modules.linear.Linear" + # timestamp: ISO-8601, parses + from datetime import datetime + + parsed = datetime.fromisoformat(spec.timestamp) + assert parsed.tzinfo is not None + + def test_rejects_unknown_kwarg(self) -> None: + with pytest.raises(TypeError, match="Unknown kwargs"): + create_model_spec(nn.Linear, in_features=4, out_features=2, bogus=1) + + def test_accepts_arbitrary_kwargs_with_var_keyword(self) -> None: + # _KwOnly has **kwargs, so `extra_foo` should be accepted. + spec = create_model_spec(_KwOnly, a=2, extra_foo="hello") + assert spec.a == 2 + assert spec.extra_foo == "hello" + + def test_nested_spec_composition(self) -> None: + act_spec = create_model_spec(nn.SiLU) + spec = create_model_spec(_WrapsModule, child=act_spec, scale=0.5) + # The nested field should be a BaseSpec. 
+ assert isinstance(spec.child, BaseSpec) + assert spec.child.cls_path.endswith(".SiLU") + + def test_tensor_field(self) -> None: + t = torch.linspace(0.0, 1.0, 5) + spec = create_model_spec(_WithTensorLinspace, buf=t) + rebuilt = create_model_spec_from_json(json.loads(spec.model_dump_json())) + assert torch.equal(rebuilt.buf, t) + assert rebuilt.buf.dtype == t.dtype + assert tuple(rebuilt.buf.shape) == tuple(t.shape) + + def test_importable_classmethod_factory(self) -> None: + spec = create_model_spec( + _FactoryBuilt.from_config, + value=7, + dtype=torch.float64, + ) + + assert spec.cls_path.endswith("._FactoryBuilt.from_config") + assert spec.accepts_kwarg("device") + assert not spec.accepts_kwarg("bogus") + + rebuilt = create_model_spec_from_json(json.loads(spec.model_dump_json())) + built = rebuilt.build(device=torch.device("cuda:0")) + + assert isinstance(built, _FactoryBuilt) + assert built.value == 7 + assert built.device == torch.device("cuda:0") + assert built.dtype == torch.float64 + + +class TestCreateModelSpecFromJson: + """JSON-dict rehydration via :func:`create_model_spec_from_json`.""" + + def test_roundtrip_preserves_timestamp(self) -> None: + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + rebuilt = create_model_spec_from_json(json.loads(spec.model_dump_json())) + assert rebuilt.timestamp == spec.timestamp + + def test_recursive_nested_spec_rehydrated(self) -> None: + act_spec = create_model_spec(nn.SiLU) + spec = create_model_spec(_WrapsModule, child=act_spec, scale=0.5) + rebuilt = create_model_spec_from_json(json.loads(spec.model_dump_json())) + assert isinstance(rebuilt.child, BaseSpec) + assert rebuilt.child.cls_path.endswith(".SiLU") + assert rebuilt.child.timestamp == act_spec.timestamp + + def test_tuple_nested_spec_sequence_rehydrated(self) -> None: + specs = ( + create_model_spec(nn.SiLU), + create_model_spec(nn.GELU), + ) + spec = create_model_spec(_TupleWrapsModules, children=specs) + rebuilt = 
create_model_spec_from_json(json.loads(spec.model_dump_json())) + + assert isinstance(rebuilt.children, tuple) + assert [child.cls_path for child in rebuilt.children] == [ + "torch.nn.modules.activation.SiLU", + "torch.nn.modules.activation.GELU", + ] + built = rebuilt.build() + assert isinstance(built.children_tuple, tuple) + assert [type(child) for child in built.children_tuple] == [nn.SiLU, nn.GELU] + + @pytest.mark.parametrize("missing", ["cls_path", "timestamp"]) + def test_missing_required_field_raises_valueerror(self, missing: str) -> None: + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + dumped = json.loads(spec.model_dump_json()) + dumped.pop(missing) + with pytest.raises(ValueError, match=f"missing required field '{missing}'"): + create_model_spec_from_json(dumped) + + def test_bad_cls_path_raises_valueerror(self) -> None: + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + dumped = json.loads(spec.model_dump_json()) + dumped["cls_path"] = "definitely.not.a.real.Class" + with pytest.raises(ValueError, match="Could not resolve cls_path"): + create_model_spec_from_json(dumped) + + def test_string_field_preserves_importable_class_path(self) -> None: + spec = create_model_spec(_WithStringPath, label="torch.nn.Linear") + rebuilt = create_model_spec_from_json(json.loads(spec.model_dump_json())) + assert rebuilt.label == "torch.nn.Linear" + assert rebuilt.build().label == "torch.nn.Linear" + + def test_unannotated_string_field_preserves_importable_class_path(self) -> None: + spec = create_model_spec(_WithUnannotatedStringPath, label="torch.nn.Linear") + rebuilt = create_model_spec_from_json(json.loads(spec.model_dump_json())) + assert rebuilt.label == "torch.nn.Linear" + assert rebuilt.build().label == "torch.nn.Linear" + + def test_unannotated_param_dtype_rehydrates_via_deserializer_probe( + self, + ) -> None: + # nn.Linear.__init__'s dtype parameter is unannotated. 
Without the + # eager deserializer probe in create_model_spec_from_json, the str + # "torch.float32" would pass through untyped and build() would hand + # a str to torch.empty. The probe rehydrates the string into a + # torch.dtype before create_model_spec sees it. + spec = create_model_spec( + nn.Linear, in_features=4, out_features=2, dtype=torch.float32 + ) + dumped = json.loads(spec.model_dump_json()) + rebuilt = create_model_spec_from_json(dumped) + assert rebuilt.dtype == torch.float32 # type: ignore[attr-defined] + model = rebuilt.build() + assert model.weight.dtype == torch.float32 + + +class TestBaseSpecBuild: + """Behavior of :meth:`BaseSpec.build`.""" + + def test_build_basic_nn_linear(self) -> None: + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + m = spec.build() + assert isinstance(m, nn.Linear) + out = m(torch.randn(3, 4)) + assert out.shape == (3, 2) + + def test_build_strict_is_noop(self) -> None: + spec = create_model_spec(nn.Linear, in_features=4, out_features=2) + m = spec.build(strict=True) + assert isinstance(m, nn.Linear) + + def test_build_accepts_runtime_args_and_kwargs(self) -> None: + spec = create_model_spec(nn.Linear, in_features=4) + m = spec.build(out_features=8) + assert m.in_features == 4 + assert m.out_features == 8 + + def test_build_nested_spec_composition(self) -> None: + # Prototype-style: child module built recursively. 
+ act_spec = create_model_spec(nn.SiLU) + spec = create_model_spec(_WrapsModule, child=act_spec, scale=0.5) + obj = spec.build() + assert isinstance(obj, _WrapsModule) + assert isinstance(obj.child, nn.SiLU) + assert obj.scale == 0.5 + + +class TestFullPrototypeScenario: + """End-to-end scenario port of ``example_spec.py::main``.""" + + def test_prototype_main_roundtrip(self, tmp_path: Path) -> None: + from nvalchemi.training._checkpoint import load_checkpoint, save_checkpoint + + hidden_dim = 16 # smaller than prototype for test speed + feature_scale = torch.linspace(0.5, 1.5, hidden_dim) + + spec = create_model_spec( + MyMLIP, + hidden_dim=hidden_dim, + cutoff=5.0, + block=create_model_spec( + MyBlock, + hidden_dim=hidden_dim, + projection_dims=[24, 24], + dropout_p=0.1, + activation=create_model_spec(nn.SiLU), + dtype=torch.float32, + ), + input_activation=create_model_spec(nn.GELU), + feature_scale=feature_scale, + dtype=torch.float32, + ) + + torch.manual_seed(0) + model = spec.build() + assert isinstance(model, MyMLIP) + assert isinstance(model.block, MyBlock) + assert isinstance(model.block.activation, nn.SiLU) + assert isinstance(model.input_activation, nn.GELU) + assert torch.equal(model.feature_scale, feature_scale) + + # --- Save + reload via checkpoint I/O ----------------------------- + save_checkpoint(tmp_path, models={"main": (model, spec)}) + save_checkpoint(tmp_path, models={"main": (model, spec)}) + model_dir = tmp_path / "models" / "main" + assert (model_dir / "spec.json").is_file() + assert (model_dir / "checkpoints" / "0.pt").is_file() + assert (model_dir / "checkpoints" / "1.pt").is_file() + + torch.manual_seed(999) # different seed to prove weights came from ckpt + result = load_checkpoint(tmp_path) + reloaded_model, reloaded_spec = result.models["main"] + + sd_orig = model.state_dict() + sd_new = reloaded_model.state_dict() + assert sd_orig.keys() == sd_new.keys() + for k in sd_orig: + assert torch.equal(sd_orig[k], sd_new[k]) + assert 
reloaded_spec.timestamp == spec.timestamp + assert torch.equal(reloaded_spec.feature_scale, feature_scale) + + # Forward pass under eval() must be bit-identical. + model.eval() + reloaded_model.eval() + x = torch.randn(3, hidden_dim) + with torch.no_grad(): + y_ref = model(x) + y_new = reloaded_model(x) + assert torch.equal(y_ref, y_new) + + # --- Optimizer + scheduler round-trip ----------------------------- + opt_spec = create_model_spec( + torch.optim.AdamW, + lr=1e-3, + betas=(0.9, 0.95), + weight_decay=1e-4, + eps=1e-8, + ) + sched_spec = create_model_spec( + torch.optim.lr_scheduler.CosineAnnealingLR, + T_max=100, + eta_min=0.0, + ) + optimizer = opt_spec.build(model.parameters()) + scheduler = sched_spec.build(optimizer) + + opt_reloaded_spec = create_model_spec_from_json( + json.loads(opt_spec.model_dump_json()) + ) + sched_reloaded_spec = create_model_spec_from_json( + json.loads(sched_spec.model_dump_json()) + ) + reloaded_optimizer = opt_reloaded_spec.build(reloaded_model.parameters()) + reloaded_scheduler = sched_reloaded_spec.build(reloaded_optimizer) + + for k in ("lr", "betas", "weight_decay", "eps"): + assert optimizer.param_groups[0][k] == reloaded_optimizer.param_groups[0][k] + assert scheduler.T_max == reloaded_scheduler.T_max == 100 + assert scheduler.eta_min == reloaded_scheduler.eta_min == 0.0 + + # LR trajectory equivalence over 10 steps. 
+ for p in model.parameters(): + p.grad = torch.zeros_like(p) + for p in reloaded_model.parameters(): + p.grad = torch.zeros_like(p) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + optimizer.step() + reloaded_optimizer.step() + traj_orig, traj_new = [], [] + for _ in range(10): + scheduler.step() + reloaded_scheduler.step() + traj_orig.append(scheduler.get_last_lr()[0]) + traj_new.append(reloaded_scheduler.get_last_lr()[0]) + assert traj_orig == traj_new + + +class TestSecurityNoPickle: + """AST-level security invariants for no-pickle serialization modules.""" + + _TARGETS = ( + Path(__file__).resolve().parents[2] / "nvalchemi" / "_serialization.py", + Path(__file__).resolve().parents[2] / "nvalchemi" / "training" / "_spec.py", + Path(__file__).resolve().parents[2] + / "nvalchemi" + / "training" + / "_checkpoint.py", + ) + _FORBIDDEN_MODULES = frozenset({"pickle", "cloudpickle", "dill", "marshal"}) + + def _tree(self, path: Path) -> ast.AST: + return ast.parse(path.read_text()) + + def test_no_pickle_imports(self) -> None: + for path in self._TARGETS: + tree = self._tree(path) + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + root = alias.name.split(".")[0] + assert root not in self._FORBIDDEN_MODULES, ( + f"{path.name}:{node.lineno} imports forbidden " + f"module {alias.name!r}" + ) + elif isinstance(node, ast.ImportFrom): + if node.module is None: + continue + root = node.module.split(".")[0] + assert root not in self._FORBIDDEN_MODULES, ( + f"{path.name}:{node.lineno} imports from forbidden " + f"module {node.module!r}" + ) + + def test_torch_load_always_weights_only_true(self) -> None: + for path in self._TARGETS: + tree = self._tree(path) + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + func = node.func + if not ( + isinstance(func, ast.Attribute) + and func.attr == "load" + and isinstance(func.value, ast.Name) + and func.value.id == "torch" + 
): + continue + kw = {k.arg: k.value for k in node.keywords if k.arg is not None} + assert "weights_only" in kw, ( + f"{path.name}:{node.lineno} torch.load() missing " + f"weights_only= kwarg" + ) + val = kw["weights_only"] + assert isinstance(val, ast.Constant) and val.value is True, ( + f"{path.name}:{node.lineno} torch.load(weights_only=...) " + f"must be literal True, got {ast.dump(val)}" + ) + + def test_torch_save_always_state_dict(self) -> None: + for path in self._TARGETS: + tree = self._tree(path) + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + func = node.func + if not ( + isinstance(func, ast.Attribute) + and func.attr == "save" + and isinstance(func.value, ast.Name) + and func.value.id == "torch" + ): + continue + assert node.args, ( + f"{path.name}:{node.lineno} torch.save() called with no args" + ) + first = node.args[0] + # Must be either an x.state_dict() call or a variable + # named "state_dict" (accepted when the call site already + # resolved the state_dict, e.g. in _save_component). + is_state_dict_call = ( + isinstance(first, ast.Call) + and isinstance(first.func, ast.Attribute) + and first.func.attr == "state_dict" + ) + is_state_dict_var = ( + isinstance(first, ast.Name) and first.id == "state_dict" + ) + assert is_state_dict_call or is_state_dict_var, ( + f"{path.name}:{node.lineno} torch.save() first arg must be " + f"an x.state_dict() call or a 'state_dict' variable, " + f"got {ast.dump(first)}" + ) diff --git a/test/training/test_stages.py b/test/training/test_stages.py new file mode 100644 index 00000000..09a034e5 --- /dev/null +++ b/test/training/test_stages.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from nvalchemi.dynamics.base import DynamicsStage +from nvalchemi.hooks import HookRegistryMixin +from nvalchemi.training import TrainingStage + +# Canonical name/order snapshot. Must be edited by hand if TrainingStage members +# change — that is the point: an accidental reorder or rename fails this test. +_EXPECTED_MEMBERS: tuple[str, ...] = ( + "SETUP", + "BEFORE_TRAINING", + "BEFORE_EPOCH", + "BEFORE_BATCH", + "BEFORE_FORWARD", + "AFTER_FORWARD", + "BEFORE_LOSS", + "AFTER_LOSS", + "BEFORE_BACKWARD", + "DO_BACKWARD", + "AFTER_BACKWARD", + "BEFORE_OPTIMIZER_STEP", + "DO_OPTIMIZER_STEP", + "AFTER_OPTIMIZER_STEP", + "AFTER_BATCH", + "AFTER_EPOCH", + "AFTER_TRAINING", + "AFTER_VALIDATION", +) + + +class _TrainingHost(HookRegistryMixin): + _stage_type = TrainingStage + + def __init__(self): + self.step_count = 0 + self._init_hooks() + + +class _DynamicsHost(HookRegistryMixin): + _stage_type = DynamicsStage + + def __init__(self): + self.step_count = 0 + self._init_hooks() + + +class TestTrainingStageEnum: + def test_members_in_declared_order(self): + assert tuple(s.name for s in TrainingStage) == _EXPECTED_MEMBERS + + def test_values_are_unique(self): + assert len({s.value for s in TrainingStage}) == len(TrainingStage) + + def test_members_count(self): + assert len(TrainingStage) == 18 + + def test_all_members_are_before_or_after_or_do(self): + for member in TrainingStage: + assert member is TrainingStage.SETUP or member.name.startswith( + ("BEFORE_", "AFTER_", 
"DO_") + ) + + def test_do_backward_between_before_and_after(self): + members = list(TrainingStage) + assert ( + members.index(TrainingStage.DO_BACKWARD) + == members.index(TrainingStage.BEFORE_BACKWARD) + 1 + ) + assert ( + members.index(TrainingStage.DO_BACKWARD) + == members.index(TrainingStage.AFTER_BACKWARD) - 1 + ) + + def test_do_optimizer_step_between_before_and_after(self): + members = list(TrainingStage) + assert ( + members.index(TrainingStage.DO_OPTIMIZER_STEP) + == members.index(TrainingStage.BEFORE_OPTIMIZER_STEP) + 1 + ) + assert ( + members.index(TrainingStage.DO_OPTIMIZER_STEP) + == members.index(TrainingStage.AFTER_OPTIMIZER_STEP) - 1 + ) + + def test_public_optimizer_boundaries_remain_distinct(self): + members = list(TrainingStage) + assert members.index(TrainingStage.AFTER_BACKWARD) + 1 == members.index( + TrainingStage.BEFORE_OPTIMIZER_STEP + ) + assert members.index(TrainingStage.AFTER_OPTIMIZER_STEP) + 1 == members.index( + TrainingStage.AFTER_BATCH + ) + + +class TestTrainingStageRegistration: + def test_register_training_hook_succeeds(self): + host = _TrainingHost() + + class TrainingHook: + frequency = 1 + stage = TrainingStage.BEFORE_BATCH + + def __call__(self, ctx, stage): + pass + + hook = TrainingHook() + host.register_hook(hook) + + assert len(host.hooks) == 1 + assert host.hooks[0] is hook + + def test_call_hooks_dispatches_by_stage(self): + host = _TrainingHost() + host.step_count = 1 + call_log: list[TrainingStage] = [] + + class BeforeBatchHook: + frequency = 1 + stage = TrainingStage.BEFORE_BATCH + + def __call__(self, ctx, stage): + call_log.append(stage) + + class AfterBatchHook: + frequency = 1 + stage = TrainingStage.AFTER_BATCH + + def __call__(self, ctx, stage): + call_log.append(stage) + + host.register_hook(BeforeBatchHook()) + host.register_hook(AfterBatchHook()) + + host._call_hooks(TrainingStage.BEFORE_BATCH, MagicMock()) + + assert call_log == [TrainingStage.BEFORE_BATCH] + + def 
test_dynamics_stage_rejected_on_training_host(self): + host = _TrainingHost() + + class DynamicsHook: + frequency = 1 + stage = DynamicsStage.BEFORE_STEP + + def __call__(self, ctx, stage): + pass + + with pytest.raises( + TypeError, match=r"type DynamicsStage.*only accepts TrainingStage" + ): + host.register_hook(DynamicsHook()) + + def test_training_host_requires_stage(self): + """Pins that a TrainingStage-typed host inherits the generic "stage required" contract.""" + host = _TrainingHost() + + class NoStageHook: + frequency = 1 + stage = None + + def __call__(self, ctx, stage): + pass + + with pytest.raises(TypeError, match="no stage assigned"): + host.register_hook(NoStageHook()) + + +class TestStageIsolation: + def test_training_stage_rejected_on_dynamics_host(self): + host = _DynamicsHost() + + class TrainingHook: + frequency = 1 + stage = TrainingStage.BEFORE_BATCH + + def __call__(self, ctx, stage): + pass + + with pytest.raises( + TypeError, match=r"type TrainingStage.*only accepts DynamicsStage" + ): + host.register_hook(TrainingHook()) + + def test_runs_on_stage_bypass_allows_cross_category(self): + """`_runs_on_stage` bypasses the stage-type check at registration.""" + host = _TrainingHost() + + class CrossCategoryHook: + frequency = 1 + stage = DynamicsStage.BEFORE_STEP # foreign enum; normally rejected + + def _runs_on_stage(self, stage): + return True + + def __call__(self, ctx, stage): + pass + + host.register_hook(CrossCategoryHook()) + assert len(host.hooks) == 1 diff --git a/test/training/test_strategy.py b/test/training/test_strategy.py new file mode 100644 index 00000000..1bb15caf --- /dev/null +++ b/test/training/test_strategy.py @@ -0,0 +1,1821 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for TrainingStrategy, OptimizerConfig, and loop helpers.""" + +from __future__ import annotations + +import json +import operator +from collections.abc import Callable, Iterator, Mapping +from enum import Enum +from typing import Any +from unittest.mock import patch + +import pytest +import torch + +from nvalchemi.data import Batch +from nvalchemi.hooks._context import HookContext, TrainContext +from nvalchemi.models.base import BaseModelMixin +from nvalchemi.training import ( + ComposedLossFunction, + EnergyMSELoss, + ForceMSELoss, + LinearWeight, + TrainingStage, +) +from nvalchemi.training.hooks import TrainingUpdateHook +from nvalchemi.training.optimizers import OptimizerConfig +from nvalchemi.training.strategy import TrainingStrategy, default_training_fn +from test.training.conftest import ( + _build_adam_optimizer_configs, + _build_baseline_strategy_kwargs, + _build_batch, + _build_dataset, + _build_demo_model, +) + + +def demo_training_fn(model: BaseModelMixin, batch: Batch) -> dict[str, torch.Tensor]: + """Training step: forward pass producing ``predicted_energy`` + ``predicted_forces``. + + Module-level so it can round-trip through + :meth:`TrainingStrategy.to_spec_dict` (lambdas and nested functions are + rejected by the serializer). 
+ """ + return default_training_fn(model, batch) + + +def dict_demo_training_fn( + models: dict[str, BaseModelMixin], batch: Batch +) -> dict[str, torch.Tensor]: + """Distillation-style dict-model training function using all named models.""" + student = demo_training_fn(models["student"], batch) + teacher = demo_training_fn(models["teacher"], batch) + assert set(models) == {"student", "teacher"} + return { + "predicted_energy": student["predicted_energy"], + "predicted_forces": teacher["predicted_forces"], + } + + +def mapping_annotated_training_fn( + models: Mapping[str, BaseModelMixin], batch: Batch +) -> dict[str, torch.Tensor]: + """Mapping-annotated training function for validation tests.""" + return demo_training_fn(models["main"], batch) + + +def moduledict_annotated_training_fn( + models: torch.nn.ModuleDict, batch: Batch +) -> dict[str, torch.Tensor]: + """ModuleDict-annotated training function for validation tests.""" + return demo_training_fn(models["main"], batch) + + +def single_model_training_fn( + model: BaseModelMixin, batch: Batch +) -> dict[str, torch.Tensor]: + """Single-model training function for validation tests.""" + return demo_training_fn(model, batch) + + +def _make_demo_model() -> Any: + """Return a freshly seeded demo model for local strategy tests.""" + return _build_demo_model() + + +def _make_batch(n_systems: int = 2, n_atoms_each: int = 3, seed: int = 0) -> Batch: + """Return a deterministic batch for local strategy tests.""" + return _build_batch(n_systems=n_systems, n_atoms_each=n_atoms_each, seed=seed) + + +def _make_dataset( + n_batches: int = 3, + n_systems: int = 2, + n_atoms_each: int = 3, + base_seed: int = 100, +) -> list[Batch]: + """Return a deterministic dataset for local strategy tests.""" + return _build_dataset( + n_batches=n_batches, + n_systems=n_systems, + n_atoms_each=n_atoms_each, + base_seed=base_seed, + ) + + +def _adam_optimizer_configs() -> dict[str, list[OptimizerConfig]]: + """Return the default Adam 
optimizer config mapping.""" + return _build_adam_optimizer_configs() + + +def _make_strategy(**overrides: Any) -> TrainingStrategy: + """Build a strategy with baseline kwargs plus local overrides.""" + models = overrides.pop("models") if "models" in overrides else None + kwargs = _build_baseline_strategy_kwargs(models=models) + kwargs.update(overrides) + return TrainingStrategy(**kwargs) + + +class _RecordingLinear(torch.nn.Linear): + """Linear module that records device-placement calls.""" + + def __init__(self) -> None: + super().__init__(4, 4) + self.to_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def to(self, *args: Any, **kwargs: Any) -> torch.nn.Module: + """Record and forward :meth:`torch.nn.Module.to` calls.""" + self.to_calls.append((args, kwargs)) + return super().to(*args, **kwargs) + + +class _RecordingHook: + """Hook object tagged with ``stage``; forwards ``(ctx, stage)`` to ``callback``. + + Stage filtering is done by the hook runner via ``self.stage``; this + helper just forwards. Recording runs on CPU — callbacks that convert + tensors via ``float(...)`` are not safe for GPU tensors without an + explicit ``.cpu()``. 
+ """ + + def __init__( + self, + stage: Enum, + callback: Callable[[HookContext, Enum], None], + ) -> None: + self.stage = stage + self.frequency = 1 + self._callback = callback + + def __call__(self, ctx: HookContext, stage: Enum) -> None: + self._callback(ctx, stage) + + +class _EveryOtherOptimizerStepHook(TrainingUpdateHook): + """Veto optimizer steps on alternating batches.""" + + priority = 10 + + def __init__(self) -> None: + self.calls = 0 + self.batch_counts: list[int] = [] + self.step_counts: list[int] = [] + self.step_decisions: list[bool] = [] + self.after_skip: list[bool] = [] + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, + ) -> tuple[bool, torch.Tensor | None]: + if stage is TrainingStage.DO_OPTIMIZER_STEP: + should_step = self.calls % 2 == 1 + self.batch_counts.append(ctx.batch_count) + self.step_counts.append(ctx.step_count) + self.step_decisions.append(should_step) + self.calls += 1 + return should_step, ctx.loss + if stage is TrainingStage.AFTER_OPTIMIZER_STEP: + self.after_skip.append(will_skip) + return True, ctx.loss + + +class _EpochSampler: + """Sampler stub that records epochs passed to ``set_epoch``.""" + + def __init__(self) -> None: + self.epochs: list[int] = [] + + def set_epoch(self, epoch: int) -> None: + self.epochs.append(epoch) + + +class _RestartableLoader: + """Re-iterable sized loader with a sampler for restart tests.""" + + def __init__(self, batches: list[Batch]) -> None: + self._batches = batches + self.sampler = _EpochSampler() + + def __iter__(self) -> Iterator[Batch]: + return iter(self._batches) + + def __len__(self) -> int: + return len(self._batches) + + +_VALIDATOR_REJECTION_CASES: list[tuple[str, dict[str, Any]]] = [ + ( + "models must contain at least one BaseModelMixin", + {"models": {}, "optimizer_configs": {}}, + ), + ( + "optimizer_configs must configure at least one model", + {"optimizer_configs": {}}, + ), + ( + r"optimizer_configs\['main'\] must contain", + 
{"optimizer_configs": {"main": []}}, + ), + ( + "models must map names", + {"models": {"main": torch.nn.Linear(1, 1)}, "optimizer_configs": {}}, + ), + ( + "not present in models", + { + "optimizer_configs": { + "missing": [OptimizerConfig(optimizer_cls=torch.optim.Adam)] + } + }, + ), + ( + "devices must have length", + {"devices": [torch.device("cpu"), torch.device("cpu")]}, + ), + ( + "devices must contain at least one torch.device", + {"devices": []}, + ), + ( + "Exactly one of num_epochs or num_steps", + {"num_epochs": 1, "num_steps": 1}, + ), + ( + "Exactly one of num_epochs or num_steps", + {"num_epochs": None, "num_steps": None}, + ), + ("greater than or equal to 1", {"num_epochs": -1}), + ("greater than or equal to 1", {"num_steps": -1, "num_epochs": None}), + ("greater than 0", {"epoch_step_modifier": 0}), + ( + "no attribute", + {"training_fn": "nvalchemi.training.strategy.not_a_real_fn"}, + ), +] + +_DELETE = object() + +_FROM_SPEC_REJECTION_CASES: list[tuple[str, Any, str]] = [ + ("optimizer_configs", [], "optimizer_configs"), + ("optimizer_configs", {"main": [1]}, "optimizer_configs"), + ("devices", "cpu", "devices"), + ("loss_fn_spec", [], "loss_fn_spec"), + ("model_specs", [], "model_specs"), + ("training_fn", _DELETE, "no training_fn"), + ("training_fn", 123, "training_fn"), + ("single_model_input", "yes", "single_model_input"), +] + + +class TestTrainingStrategyValidators: + @pytest.mark.parametrize( + ("match", "overrides"), + _VALIDATOR_REJECTION_CASES, + ids=[ + "empty_models", + "empty_optimizer_configs", + "empty_per_model_list", + "invalid_model_value", + "optimizer_key_missing", + "devices_wrong_length", + "devices_empty", + "both_num_epochs_and_num_steps", + "neither_num_epochs_nor_num_steps", + "negative_num_epochs", + "negative_num_steps", + "nonpositive_epoch_step_modifier", + "training_fn_bad_dotted_path", + ], + ) + def test_construction_rejected( + self, + match: str, + overrides: dict[str, Any], + baseline_strategy_kwargs: dict[str, 
Any], + ) -> None: + kwargs = {**baseline_strategy_kwargs, **overrides} + with pytest.raises(ValueError, match=match): + TrainingStrategy(**kwargs) + + def test_training_fn_dotted_string_resolved( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + strat = TrainingStrategy( + **{**baseline_strategy_kwargs, "training_fn": "operator.add"} + ) + assert strat.training_fn is operator.add + + def test_training_fn_required_message_suggests_default( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + kwargs = dict(baseline_strategy_kwargs) + del kwargs["training_fn"] + with pytest.raises(ValueError, match="default_training_fn"): + TrainingStrategy(**kwargs) + + def test_leaf_loss_fn_normalized_to_composed_loss( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + strategy = TrainingStrategy( + **{**baseline_strategy_kwargs, "loss_fn": EnergyMSELoss()} + ) + assert isinstance(strategy.loss_fn, ComposedLossFunction) + assert len(strategy.loss_fn.components) == 1 + assert isinstance(strategy.loss_fn.components[0], EnergyMSELoss) + + def test_single_model_rejects_mapping_annotation( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + with pytest.raises(ValueError, match="single-model"): + TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "training_fn": mapping_annotated_training_fn, + } + ) + + def test_single_model_rejects_moduledict_annotation( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + with pytest.raises(ValueError, match="single-model"): + TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "training_fn": moduledict_annotated_training_fn, + } + ) + + def test_dict_models_reject_single_model_annotation( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + with pytest.raises(ValueError, match="models=model"): + TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "models": { + "student": _build_demo_model(), + "teacher": _build_demo_model(), + }, + "optimizer_configs": { + "student": 
[OptimizerConfig(optimizer_cls=torch.optim.Adam)] + }, + "training_fn": single_model_training_fn, + } + ) + + def test_duplicate_hook_instances_rejected( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + hook = _RecordingHook(TrainingStage.BEFORE_BATCH, lambda ctx, stage: None) + with pytest.raises(ValueError, match="duplicate hook"): + TrainingStrategy(**{**baseline_strategy_kwargs, "hooks": [hook, hook]}) + + def test_epoch_constructor_alias_populates_epoch_count(self) -> None: + strategy = _make_strategy(epoch=3) + assert strategy.epoch_count == 3 + assert strategy.epoch == 3 + + +class TestTrainingStrategyRun: + def test_single_model_training_fn_receives_model_only( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + seen: list[BaseModelMixin] = [] + + def _training_fn( + model: BaseModelMixin, batch: Batch + ) -> dict[str, torch.Tensor]: + seen.append(model) + return demo_training_fn(model, batch) + + strategy = TrainingStrategy( + **{**baseline_strategy_kwargs, "training_fn": _training_fn} + ) + strategy.run([batch]) + assert seen == [strategy.models["main"]] + + def test_dict_model_training_fn_receives_all_models( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "models": { + "student": _build_demo_model(), + "teacher": _build_demo_model(), + }, + "optimizer_configs": { + "student": [OptimizerConfig(optimizer_cls=torch.optim.Adam)] + }, + "training_fn": dict_demo_training_fn, + } + ) + strategy.run([batch]) + assert strategy.step_count == 1 + + def test_dict_model_multi_device_run_raises( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "models": { + "student": _build_demo_model(), + "teacher": _build_demo_model(), + }, + "optimizer_configs": { + "student": [OptimizerConfig(optimizer_cls=torch.optim.Adam)] + }, + "training_fn": 
dict_demo_training_fn, + "devices": [torch.device("cpu"), torch.device("cpu")], + } + ) + with pytest.raises( + ValueError, match="Named-model training with multiple devices" + ): + strategy.run([batch]) + + def test_moduledict_models_are_accepted_as_named_models( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "models": torch.nn.ModuleDict( + {"student": _build_demo_model(), "teacher": _build_demo_model()} + ), + "optimizer_configs": { + "student": [OptimizerConfig(optimizer_cls=torch.optim.Adam)] + }, + "training_fn": dict_demo_training_fn, + } + ) + assert isinstance(strategy.models, dict) + assert set(strategy.models) == {"student", "teacher"} + strategy.run([batch]) + assert strategy.step_count == 1 + + def test_omitted_model_is_temporarily_frozen_and_eval( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + teacher = _build_demo_model() + teacher.eval() + params = list(teacher.parameters()) + params[0].requires_grad_(False) + initial_training = teacher.training + initial_requires_grad = [param.requires_grad for param in params] + seen_during_run: list[tuple[bool, list[bool]]] = [] + + def _training_fn( + models: dict[str, BaseModelMixin], batch: Batch + ) -> dict[str, torch.Tensor]: + seen_during_run.append( + ( + models["teacher"].training, + [param.requires_grad for param in models["teacher"].parameters()], + ) + ) + return dict_demo_training_fn(models, batch) + + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "models": {"student": _build_demo_model(), "teacher": teacher}, + "optimizer_configs": { + "student": [OptimizerConfig(optimizer_cls=torch.optim.Adam)] + }, + "training_fn": _training_fn, + } + ) + strategy.run([batch]) + assert strategy.models["student"].training is True + assert any( + param.requires_grad for param in strategy.models["student"].parameters() + ) + assert seen_during_run == [(False, [False] * 
len(params))] + assert strategy.models["teacher"].training is initial_training + assert [param.requires_grad for param in params] == initial_requires_grad + + def test_default_training_fn_opt_in_runs_single_model( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + strategy = TrainingStrategy( + **{**baseline_strategy_kwargs, "training_fn": default_training_fn} + ) + strategy.run([batch]) + assert strategy.step_count == 1 + + def test_train_batch_public_api_runs_per_batch_flow_only( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + seen: list[TrainingStage] = [] + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "hooks": [ + _RecordingHook( + TrainingStage.BEFORE_TRAINING, + lambda _ctx, stage: seen.append(stage), + ), + _RecordingHook( + TrainingStage.BEFORE_BATCH, + lambda _ctx, stage: seen.append(stage), + ), + ], + } + ) + + strategy.train_batch(batch) + + assert seen == [TrainingStage.BEFORE_BATCH] + assert strategy.step_count == 1 + assert strategy.batch_count == 1 + assert strategy._last_batch is not None + + def test_train_batch_reuses_runtime_optimizer_state( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + strategy = TrainingStrategy(**baseline_strategy_kwargs) + strategy.train_batch(batch) + optimizers = strategy._optimizers + schedulers = strategy._lr_schedulers + + strategy.train_batch(_build_batch(seed=10)) + + assert strategy.step_count == 2 + assert strategy.batch_count == 2 + assert strategy._optimizers is optimizers + assert strategy._lr_schedulers is schedulers + + def test_two_epoch_loop_updates_counters_and_loss_hooks( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + after_loss_calls: list[int] = [] + + def _record(ctx: HookContext, stage: Enum) -> None: # noqa: ARG001 + assert ctx.loss is not None + after_loss_calls.append(ctx.step_count) + + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "num_epochs": 2, + 
"hooks": [_RecordingHook(TrainingStage.AFTER_LOSS, _record)], + } + ) + dataset = _build_dataset(n_batches=3) + strategy.run(dataset) + + assert strategy.step_count == 2 * len(dataset) + assert strategy.batch_count == 2 * len(dataset) + assert strategy.epoch_count == 2 + assert strategy.epoch_step_count == 0 + assert strategy.epoch == strategy.epoch_count + assert after_loss_calls == list(range(2 * len(dataset))) + + def test_num_steps_recycles_dataloader_until_target(self) -> None: + torch.manual_seed(0) + after_loss_calls: list[int] = [] + + def _record(ctx: HookContext, stage: Enum) -> None: # noqa: ARG001 + after_loss_calls.append(ctx.step_count) + + strategy = _make_strategy( + num_epochs=None, + num_steps=5, + hooks=[_RecordingHook(TrainingStage.AFTER_LOSS, _record)], + ) + strategy.run(_make_dataset(n_batches=2)) + + assert strategy.step_count == 5 + assert strategy.batch_count == 5 + assert after_loss_calls == list(range(5)) + + def test_num_steps_run_at_target_is_noop(self) -> None: + calls = 0 + + def _training_fn( + model: BaseModelMixin, batch: Batch + ) -> dict[str, torch.Tensor]: + nonlocal calls + calls += 1 + return demo_training_fn(model, batch) + + strategy = _make_strategy( + num_epochs=None, + num_steps=1, + training_fn=_training_fn, + ) + dataset = _make_dataset(n_batches=2) + + strategy.run(dataset) + strategy.run(dataset) + + assert calls == 1 + assert strategy.step_count == 1 + assert strategy.batch_count == 1 + + def test_num_epochs_run_at_converted_target_is_noop(self) -> None: + calls = 0 + + def _training_fn( + model: BaseModelMixin, batch: Batch + ) -> dict[str, torch.Tensor]: + nonlocal calls + calls += 1 + return demo_training_fn(model, batch) + + strategy = _make_strategy(num_epochs=1, training_fn=_training_fn) + dataset = _make_dataset(n_batches=2) + + strategy.run(dataset) + strategy.run(dataset) + + assert calls == len(dataset) + assert strategy.step_count == len(dataset) + assert strategy.batch_count == len(dataset) + + def 
test_num_epochs_target_uses_epoch_step_modifier(self) -> None: + strategy = _make_strategy(num_epochs=2, epoch_step_modifier=0.5) + strategy.run(_make_dataset(n_batches=3)) + + assert strategy.step_count == 3 + assert strategy.batch_count == 3 + + def test_num_epochs_target_counts_executed_optimizer_steps(self) -> None: + hook = _EveryOtherOptimizerStepHook() + strategy = _make_strategy( + num_epochs=1, + epoch_step_modifier=0.5, + hooks=[hook], + ) + + strategy.run(_make_dataset(n_batches=4)) + + assert strategy.step_count == 2 + assert strategy.batch_count == 4 + assert strategy.epoch_count == 1 + assert strategy.epoch_step_count == 0 + assert hook.batch_counts == [0, 1, 2, 3] + assert hook.step_counts == [0, 0, 1, 1] + assert hook.step_decisions == [False, True, False, True] + assert hook.after_skip == [True, False, True, False] + + def test_num_epochs_requires_sized_dataloader(self) -> None: + strategy = _make_strategy(num_epochs=1) + + with pytest.raises(ValueError, match="num_epochs requires a sized dataloader"): + strategy.run(iter(_make_dataset(n_batches=1))) + + def test_run_resumes_from_epoch_and_step_count(self) -> None: + dataset = _make_dataset(n_batches=3) + loader = _RestartableLoader(dataset) + batch_index = { + float(batch.energy.detach().cpu().flatten()[0]): i + for i, batch in enumerate(dataset) + } + seen_batches: list[int] = [] + + def _training_fn( + model: BaseModelMixin, batch: Batch + ) -> dict[str, torch.Tensor]: + key = float(batch.energy.detach().cpu().flatten()[0]) + seen_batches.append(batch_index[key]) + return demo_training_fn(model, batch) + + strategy = _make_strategy( + num_epochs=None, + num_steps=7, + step_count=4, + epoch_count=1, + training_fn=_training_fn, + ) + + strategy.run(loader) + + assert loader.sampler.epochs == [1, 2] + assert seen_batches == [1, 2, 0] + assert strategy.step_count == 7 + assert strategy.batch_count == 7 + assert strategy.epoch_count == 2 + assert strategy.epoch_step_count == 1 + + def 
test_run_resumes_from_explicit_epoch_step_count(self) -> None: + dataset = _make_dataset(n_batches=3) + loader = _RestartableLoader(dataset) + batch_index = { + float(batch.energy.detach().cpu().flatten()[0]): i + for i, batch in enumerate(dataset) + } + seen_batches: list[int] = [] + + def _training_fn( + model: BaseModelMixin, batch: Batch + ) -> dict[str, torch.Tensor]: + key = float(batch.energy.detach().cpu().flatten()[0]) + seen_batches.append(batch_index[key]) + return demo_training_fn(model, batch) + + strategy = _make_strategy( + num_epochs=None, + num_steps=7, + step_count=4, + epoch_count=1, + epoch_step_count=1, + training_fn=_training_fn, + ) + + strategy.run(loader) + + assert loader.sampler.epochs == [1, 2] + assert seen_batches == [1, 2, 0] + assert strategy.step_count == 7 + assert strategy.batch_count == 7 + assert strategy.epoch_count == 2 + assert strategy.epoch_step_count == 1 + + def test_run_rejects_inconsistent_explicit_epoch_step_count(self) -> None: + strategy = _make_strategy( + num_epochs=None, + num_steps=7, + step_count=4, + epoch_count=1, + epoch_step_count=2, + ) + + with pytest.raises(ValueError, match="restart counters are inconsistent"): + strategy.run(_make_dataset(n_batches=3)) + + +_EXPECTED_STAGE_ORDER: tuple[TrainingStage, ...] = ( + TrainingStage.BEFORE_TRAINING, + TrainingStage.BEFORE_EPOCH, + TrainingStage.BEFORE_BATCH, + TrainingStage.BEFORE_FORWARD, + TrainingStage.AFTER_FORWARD, + TrainingStage.BEFORE_LOSS, + TrainingStage.AFTER_LOSS, + TrainingStage.BEFORE_BACKWARD, + TrainingStage.AFTER_BACKWARD, + TrainingStage.BEFORE_OPTIMIZER_STEP, + TrainingStage.AFTER_OPTIMIZER_STEP, + TrainingStage.AFTER_BATCH, + TrainingStage.AFTER_EPOCH, + TrainingStage.AFTER_TRAINING, +) + + +# Snapshot shape: (loss_populated, losses_populated, requires_grad). 
+_LossSnapshot = tuple[bool, bool, bool] + + +def _snapshot_ctx(ctx: HookContext) -> _LossSnapshot: + return ( + ctx.loss is not None, + ctx.losses is not None, + bool(ctx.loss.requires_grad) if ctx.loss is not None else False, + ) + + +class TestTrainingStrategyHookOrder: + def test_strategy_context_manager_nests_without_reentry( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + events: list[str] = [] + + class _ContextHook: + stage = TrainingStage.BEFORE_BATCH + frequency = 1 + + def __enter__(self) -> None: + events.append("enter") + + def __exit__(self, exc_type: object, exc: object, tb: object) -> None: + events.append("exit") + + def __call__(self, ctx: HookContext, stage: Enum) -> None: + pass + + hook = _ContextHook() + strategy = TrainingStrategy(**{**baseline_strategy_kwargs, "hooks": [hook]}) + with strategy: + with strategy: + assert events == ["enter"] + assert events == ["enter"] + assert events == ["enter", "exit"] + + def test_entered_strategy_run_reuses_hook_context( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + events: list[str] = [] + + class _ContextHook: + stage = TrainingStage.BEFORE_BATCH + frequency = 1 + + def __enter__(self) -> None: + events.append("enter") + + def __exit__(self, exc_type: object, exc: object, tb: object) -> None: + events.append("exit") + + def __call__(self, ctx: HookContext, stage: Enum) -> None: # noqa: ARG002 + events.append("call") + + hook = _ContextHook() + strategy = TrainingStrategy(**{**baseline_strategy_kwargs, "hooks": [hook]}) + with strategy: + strategy.run([batch]) + assert events == ["enter", "call", "exit"] + + def test_strategy_context_exposes_named_models( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + seen_keys: list[set[str]] = [] + + def _record(ctx: HookContext, stage: Enum) -> None: # noqa: ARG001 + assert isinstance(ctx, TrainContext) + seen_keys.append(set(ctx.models)) + assert ctx.model is ctx.models["main"] + + 
strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "hooks": [_RecordingHook(TrainingStage.BEFORE_BATCH, _record)], + } + ) + strategy.run([batch]) + assert seen_keys == [{"main"}] + + def test_stage_order_one_batch( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + log: list[Enum] = [] + hooks = [ + _RecordingHook(stage, lambda ctx, s, _log=log: _log.append(s)) # noqa: ARG005 + for stage in _EXPECTED_STAGE_ORDER + ] + strategy = TrainingStrategy(**{**baseline_strategy_kwargs, "hooks": hooks}) + strategy.run([batch]) + assert tuple(log) == _EXPECTED_STAGE_ORDER + + def test_hook_context_loss_lifecycle( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + tracked_stages = ( + TrainingStage.BEFORE_LOSS, + TrainingStage.AFTER_LOSS, + TrainingStage.BEFORE_BACKWARD, + TrainingStage.AFTER_BACKWARD, + TrainingStage.BEFORE_OPTIMIZER_STEP, + TrainingStage.AFTER_BATCH, + ) + snapshots: dict[TrainingStage, list[_LossSnapshot]] = { + stage: [] for stage in tracked_stages + } + + def _record_snapshot(ctx: HookContext, stage: TrainingStage) -> None: + snapshots[stage].append(_snapshot_ctx(ctx)) + + hooks = [_RecordingHook(stage, _record_snapshot) for stage in tracked_stages] + strategy = TrainingStrategy(**{**baseline_strategy_kwargs, "hooks": hooks}) + strategy.run([batch]) + + # Before the loss is computed, loss + losses are both absent. + assert snapshots[TrainingStage.BEFORE_LOSS] == [(False, False, False)] + + # AFTER_LOSS + BEFORE_BACKWARD: loss is live and requires grad. + for stage in (TrainingStage.AFTER_LOSS, TrainingStage.BEFORE_BACKWARD): + assert snapshots[stage] == [(True, True, True)] + + # From AFTER_BACKWARD onward, loss is detached. 
+ for stage in ( + TrainingStage.AFTER_BACKWARD, + TrainingStage.BEFORE_OPTIMIZER_STEP, + TrainingStage.AFTER_BATCH, + ): + assert snapshots[stage] == [(True, True, False)] + + +class TestTrainingStrategySpecRoundTrip: + def test_roundtrip_preserves_declarative_fields( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + loss_fn = EnergyMSELoss(per_atom=True) + ForceMSELoss( + normalize_by_atom_count=False + ) + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "optimizer_configs": { + "main": [ + OptimizerConfig( + optimizer_cls=torch.optim.Adam, + optimizer_kwargs={"lr": 1e-3}, + scheduler_cls=torch.optim.lr_scheduler.StepLR, + scheduler_kwargs={"step_size": 3, "gamma": 0.5}, + ) + ] + }, + "num_epochs": 2, + "epoch_step_modifier": 0.5, + "loss_fn": loss_fn, + "devices": [torch.device("cpu")], + } + ) + spec = strategy.to_spec_dict() + spec_back = json.loads(json.dumps(spec)) + + fresh_model = _build_demo_model() + restored = TrainingStrategy.from_spec_dict( + spec_back, models=fresh_model, hooks=[] + ) + assert restored.num_epochs == 2 + assert restored.num_steps is None + assert restored.epoch_step_modifier == pytest.approx(0.5) + assert restored.devices == [torch.device("cpu")] + assert restored.training_fn is demo_training_fn + assert "main" in spec["model_specs"] + assert spec["single_model_input"] is True + restored_cfg = restored.optimizer_configs["main"][0] + assert restored_cfg.optimizer_cls is torch.optim.Adam + assert restored_cfg.optimizer_kwargs["lr"] == pytest.approx(1e-3) + assert restored_cfg.scheduler_cls is torch.optim.lr_scheduler.StepLR + assert restored_cfg.scheduler_kwargs == {"step_size": 3, "gamma": 0.5} + assert isinstance(restored.loss_fn, ComposedLossFunction) + leaves = list(restored.loss_fn.components) + assert len(leaves) == 2 + assert isinstance(leaves[0], EnergyMSELoss) + assert isinstance(leaves[1], ForceMSELoss) + assert leaves[0].per_atom is True + assert leaves[1].normalize_by_atom_count is False + + 
def test_roundtrip_preserves_loss_weights_and_normalization( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + loss_fn = ComposedLossFunction( + [ + EnergyMSELoss(), + ForceMSELoss(normalize_by_atom_count=False), + ], + weights=[0.25, LinearWeight(start=0.1, end=0.5, num_steps=10)], + normalize_weights=False, + ) + strategy = TrainingStrategy(**{**baseline_strategy_kwargs, "loss_fn": loss_fn}) + + spec = json.loads(json.dumps(strategy.to_spec_dict())) + restored = TrainingStrategy.from_spec_dict( + spec, models=_build_demo_model(), hooks=[] + ) + + assert restored.loss_fn.normalize_weights is False + assert restored.loss_fn._weights[0] == pytest.approx(0.25) + assert isinstance(restored.loss_fn._weights[1], LinearWeight) + schedule = restored.loss_fn._weights[1] + assert schedule.start == pytest.approx(0.1) + assert schedule.end == pytest.approx(0.5) + assert schedule.num_steps == 10 + + def test_roundtrip_preserves_scaled_loss_weight_schedule(self) -> None: + schedule = LinearWeight(start=0.2, end=1.0, num_steps=10) + loss_fn = 0.25 * ComposedLossFunction([EnergyMSELoss()], weights=[schedule]) + strategy = _make_strategy(loss_fn=loss_fn) + + spec = json.loads(json.dumps(strategy.to_spec_dict())) + restored = TrainingStrategy.from_spec_dict( + spec, models=_make_demo_model(), hooks=[] + ) + + weight = restored.loss_fn._weights[0] + assert weight(0, 0) == pytest.approx(0.25 * schedule(0, 0)) + assert weight(5, 0) == pytest.approx(0.25 * schedule(5, 0)) + + def test_missing_optimizer_configs_key_raises( + self, strategy: TrainingStrategy + ) -> None: + spec = strategy.to_spec_dict() + del spec["optimizer_configs"] + with pytest.raises(ValueError, match="optimizer_configs"): + TrainingStrategy.from_spec_dict(spec, models=_build_demo_model(), hooks=[]) + + @pytest.mark.parametrize( + ("key", "value", "match"), + _FROM_SPEC_REJECTION_CASES, + ids=[ + "optimizer_configs_not_mapping", + "optimizer_config_entries_not_specs", + "devices_not_list", + 
"loss_fn_spec_not_mapping", + "model_specs_not_mapping", + "missing_training_fn", + "training_fn_not_string", + "single_model_input_not_bool", + ], + ) + def test_from_spec_rejects_malformed_fields( + self, key: str, value: Any, match: str, strategy: TrainingStrategy + ) -> None: + spec = strategy.to_spec_dict() + if value is _DELETE: + del spec[key] + else: + spec[key] = value + + with pytest.raises(ValueError, match=match): + TrainingStrategy.from_spec_dict(spec, models=_build_demo_model(), hooks=[]) + + def test_integer_optimizer_key_migrates_to_main( + self, strategy: TrainingStrategy + ) -> None: + spec = strategy.to_spec_dict() + original = spec["optimizer_configs"]["main"] + spec["optimizer_configs"] = {"0": original} + restored = TrainingStrategy.from_spec_dict( + spec, models=_build_demo_model(), hooks=[] + ) + assert set(restored.optimizer_configs) == {"main"} + + def test_single_model_spec_without_runtime_model_restores_single_call_mode( + self, strategy: TrainingStrategy, batch: Batch + ) -> None: + seen_args: list[BaseModelMixin | dict[str, BaseModelMixin]] = [] + + def _record_training_fn( + model: BaseModelMixin, batch: Batch + ) -> dict[str, torch.Tensor]: + seen_args.append(model) + return default_training_fn(strategy.models["main"], batch) + + restored = TrainingStrategy.from_spec_dict( + strategy.to_spec_dict(), hooks=[], training_fn=_record_training_fn + ) + restored.train_batch(batch) + assert seen_args == [restored.models["main"]] + + def test_single_main_named_spec_restores_named_call_mode( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + strategy = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "models": {"main": _build_demo_model()}, + "optimizer_configs": _build_adam_optimizer_configs(), + "training_fn": mapping_annotated_training_fn, + } + ) + + spec = strategy.to_spec_dict() + restored = TrainingStrategy.from_spec_dict(spec, hooks=[]) + + assert spec["single_model_input"] is False + assert 
restored.single_model_input is False + restored.run([batch]) + assert restored.step_count == 1 + + def test_model_spec_roundtrip_restores_runnable_demo_model( + self, baseline_strategy_kwargs: dict[str, Any], batch: Batch + ) -> None: + strategy = TrainingStrategy( + **{**baseline_strategy_kwargs, "training_fn": default_training_fn} + ) + restored = TrainingStrategy.from_spec_dict(strategy.to_spec_dict(), hooks=[]) + + assert restored.models["main"] is not strategy.models["main"] + restored.run([batch]) + + assert restored.step_count == 1 + + def test_runtime_model_override_merges_over_spec_models( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + spec = TrainingStrategy( + **{ + **baseline_strategy_kwargs, + "models": { + "main": _build_demo_model(), + "teacher": _build_demo_model(), + }, + "optimizer_configs": { + "main": [OptimizerConfig(optimizer_cls=torch.optim.Adam)] + }, + "training_fn": dict_demo_training_fn, + } + ).to_spec_dict() + replacement = _build_demo_model() + restored = TrainingStrategy.from_spec_dict(spec, models=replacement, hooks=[]) + assert restored.models["main"] is replacement + assert "teacher" in restored.models + assert restored.single_model_input is False + + @pytest.mark.parametrize("drop_training_fn", [False, True]) + def test_runtime_training_fn_override( + self, drop_training_fn: bool, strategy: TrainingStrategy + ) -> None: + spec = strategy.to_spec_dict() + if drop_training_fn: + del spec["training_fn"] + restored = TrainingStrategy.from_spec_dict( + spec, + models=_build_demo_model(), + hooks=[], + training_fn=default_training_fn, + ) + assert restored.training_fn is default_training_fn + + def test_non_importable_training_fn_warns_and_is_omitted( + self, baseline_strategy_kwargs: dict[str, Any] + ) -> None: + strategy = TrainingStrategy( + **{**baseline_strategy_kwargs, "training_fn": lambda model, batch: {}} + ) + with pytest.warns(UserWarning, match="Omitting non-importable training_fn"): + spec = 
strategy.to_spec_dict() + assert "training_fn" not in spec + + +class TestValidationCapabilities: + """Phase A introspection methods on TrainingStrategy.""" + + def _make_validation_strategy(self, **overrides: Any) -> TrainingStrategy: + """Build a strategy with a ValidationConfig attached.""" + from nvalchemi.training._validation import ValidationConfig + + batch = _make_batch() + vc_kwargs = overrides.pop("validation_config_kwargs", {}) + vc = ValidationConfig(validation_data=[batch], **vc_kwargs) + return _make_strategy(validation_config=vc, **overrides) + + # -- model resolution (via ValidationLoop.from_training_strategy) -- + + def test_model_arg_returns_live_model_when_slot_none(self) -> None: + """No inference_model slot -> live training model.""" + from nvalchemi.training._validation import ValidationLoop + + strategy = self._make_validation_strategy() + assert strategy.inference_model is None + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._model_arg is strategy.models["main"] + assert loop._modules == (strategy.models["main"],) + assert loop._ema_model_keys == () + + def test_model_arg_returns_slot_when_set_single_model(self) -> None: + """Setting inference_model returns the slot model.""" + from nvalchemi.training._validation import ValidationLoop + + strategy = self._make_validation_strategy() + replacement = torch.nn.Linear(4, 4) + strategy.inference_model = replacement + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._model_arg is replacement + assert loop._modules == (replacement,) + assert loop._ema_model_keys == ("main",) + + def test_set_inference_model_moves_module_to_primary_device(self) -> None: + """Publishing inference_model preserves identity and aligns device.""" + strategy = self._make_validation_strategy(devices=[torch.device("cpu")]) + replacement = _RecordingLinear() + + strategy.set_inference_model(replacement) + + assert strategy.inference_model is replacement + assert replacement.to_calls 
== [ + ((torch.device("cpu"),), {"non_blocking": True}) + ] + + def test_model_arg_moduledict_slot_named_model(self) -> None: + """ModuleDict slot overrides matching keys; missing keys fall back.""" + from nvalchemi.training._validation import ValidationConfig, ValidationLoop + + teacher = _build_demo_model() + student = _build_demo_model() + strategy = _make_strategy( + models={"student": student, "teacher": teacher}, + optimizer_configs={ + "student": [OptimizerConfig(optimizer_cls=torch.optim.Adam)] + }, + training_fn=dict_demo_training_fn, + ) + strategy.validation_config = ValidationConfig(validation_data=[_make_batch()]) + ema_student = torch.nn.Linear(4, 4) + strategy.inference_model = torch.nn.ModuleDict({"student": ema_student}) + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._model_arg["student"] is ema_student + assert loop._model_arg["teacher"] is teacher + assert loop._ema_model_keys == ("student",) + assert ema_student in loop._modules + + def test_model_arg_moduledict_missing_key_falls_back(self) -> None: + """ModuleDict slot missing 'teacher' key -> live teacher model used.""" + from nvalchemi.training._validation import ValidationConfig, ValidationLoop + + teacher = _build_demo_model() + student = _build_demo_model() + strategy = _make_strategy( + models={"student": student, "teacher": teacher}, + optimizer_configs={ + "student": [OptimizerConfig(optimizer_cls=torch.optim.Adam)] + }, + training_fn=dict_demo_training_fn, + ) + strategy.validation_config = ValidationConfig(validation_data=[_make_batch()]) + ema_student = torch.nn.Linear(4, 4) + strategy.inference_model = torch.nn.ModuleDict({"student": ema_student}) + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._model_arg["teacher"] is teacher + assert "teacher" not in loop._ema_model_keys + + def test_model_arg_use_ema_always_empty_slot_raises(self) -> None: + """use_ema='always' with no inference_model slot raises.""" + from 
nvalchemi.training._validation import ValidationLoop + + strategy = self._make_validation_strategy( + validation_config_kwargs={"use_ema": "always"}, + ) + assert strategy.inference_model is None + with pytest.raises(RuntimeError, match="inference_model slot"): + ValidationLoop.from_training_strategy(strategy) + + def test_model_arg_use_ema_never_ignores_slot(self) -> None: + """use_ema='never' ignores the slot even if populated.""" + from nvalchemi.training._validation import ValidationLoop + + strategy = self._make_validation_strategy( + validation_config_kwargs={"use_ema": "never"}, + ) + replacement = torch.nn.Linear(4, 4) + strategy.inference_model = replacement + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._model_arg is strategy.models["main"] + assert loop._modules == (strategy.models["main"],) + assert loop._ema_model_keys == () + + def test_model_arg_use_ema_always_named_missing_raises(self) -> None: + """use_ema='always' with named models where slot misses a key raises.""" + from nvalchemi.training._validation import ValidationConfig, ValidationLoop + + teacher = _build_demo_model() + student = _build_demo_model() + strategy = _make_strategy( + models={"student": student, "teacher": teacher}, + optimizer_configs={ + "student": [OptimizerConfig(optimizer_cls=torch.optim.Adam)] + }, + training_fn=dict_demo_training_fn, + ) + strategy.validation_config = ValidationConfig( + validation_data=[_make_batch()], use_ema="always" + ) + strategy.inference_model = torch.nn.ModuleDict( + {"student": torch.nn.Linear(4, 4)} + ) + with pytest.raises(RuntimeError, match="missing"): + ValidationLoop.from_training_strategy(strategy) + + # -- _inference_autocast -- + + def test_inference_autocast_no_hook_returns_float32(self) -> None: + """No MixedPrecisionHook -> (nullcontext, 'float32').""" + from contextlib import nullcontext + + strategy = self._make_validation_strategy() + factory, precision = strategy._inference_autocast(torch.device("cpu")) + 
assert factory is nullcontext + assert precision == "float32" + + def test_inference_autocast_with_mixed_precision_hook(self) -> None: + """MixedPrecisionHook registered -> its autocast + precision label.""" + from nvalchemi.training.hooks.mixed_precision import MixedPrecisionHook + + mp = MixedPrecisionHook(precision=torch.bfloat16) + strategy = self._make_validation_strategy(hooks=[mp]) + factory, precision = strategy._inference_autocast(torch.device("cpu")) + assert precision == "bfloat16" + ctx = factory() + assert ctx is not None + + def test_inference_autocast_never_ignores_hook(self) -> None: + """use_mixed_precision='never' ignores registered MixedPrecisionHook.""" + from contextlib import nullcontext + + from nvalchemi.training.hooks.mixed_precision import MixedPrecisionHook + + mp = MixedPrecisionHook(precision=torch.bfloat16) + strategy = self._make_validation_strategy( + hooks=[mp], + validation_config_kwargs={"use_mixed_precision": "never"}, + ) + factory, precision = strategy._inference_autocast(torch.device("cpu")) + assert factory is nullcontext + assert precision == "float32" + + def test_inference_autocast_always_no_hook_raises(self) -> None: + """use_mixed_precision='always' without hook raises.""" + strategy = self._make_validation_strategy( + validation_config_kwargs={"use_mixed_precision": "always"}, + ) + with pytest.raises(RuntimeError, match="MixedPrecisionHook"): + strategy._inference_autocast(torch.device("cpu")) + + # -- grad resolution (via ValidationLoop.from_training_strategy) -- + + def test_resolve_grad_enabled(self) -> None: + """grad_mode='enabled' returns True.""" + from nvalchemi.training._validation import ValidationLoop + + strategy = self._make_validation_strategy( + validation_config_kwargs={"grad_mode": "enabled"}, + ) + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._grad_enabled is True + + def test_resolve_grad_disabled(self) -> None: + """grad_mode='disabled' returns False.""" + from 
nvalchemi.training._validation import ValidationLoop + + strategy = self._make_validation_strategy( + validation_config_kwargs={"grad_mode": "disabled"}, + ) + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._grad_enabled is False + + def test_resolve_grad_auto_with_force_loss(self) -> None: + """grad_mode='auto' with ForceMSELoss (requires_eval_grad=True) returns True.""" + from nvalchemi.training._validation import ValidationLoop + + strategy = self._make_validation_strategy( + loss_fn=ForceMSELoss(), + ) + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._grad_enabled is True + + def test_resolve_grad_auto_with_energy_loss(self) -> None: + """grad_mode='auto' with EnergyMSELoss (requires_eval_grad=False) returns False.""" + from nvalchemi.training._validation import ValidationLoop + + strategy = self._make_validation_strategy( + loss_fn=EnergyMSELoss(), + ) + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._grad_enabled is False + + def test_resolve_grad_auto_unknown_component_raises(self) -> None: + """grad_mode='auto' with requires_eval_grad=None raises ValueError.""" + from nvalchemi.training._validation import ValidationLoop + from nvalchemi.training.losses.composition import BaseLossFunction + + class _AmbiguousLoss(BaseLossFunction): + requires_eval_grad = None + + def compute_residual( + self, pred: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + return pred - target + + strategy = self._make_validation_strategy( + loss_fn=_AmbiguousLoss(), + ) + with pytest.raises(ValueError, match="infer whether"): + ValidationLoop.from_training_strategy(strategy) + + # -- loss resolution (via ValidationLoop.from_training_strategy) -- + + def test_resolve_loss_fn_uses_config_loss(self) -> None: + """When validation_config.loss_fn is set, use it.""" + from nvalchemi.training._validation import ValidationLoop + + val_loss = EnergyMSELoss() + strategy = self._make_validation_strategy( + 
validation_config_kwargs={"loss_fn": val_loss}, + ) + loop = ValidationLoop.from_training_strategy(strategy) + assert isinstance(loop._loss_fn, ComposedLossFunction) + assert isinstance(loop._loss_fn.components[0], EnergyMSELoss) + + def test_resolve_loss_fn_falls_back_to_strategy(self) -> None: + """When validation_config.loss_fn is None, use strategy.loss_fn.""" + from nvalchemi.training._validation import ValidationLoop + + strategy = self._make_validation_strategy() + loop = ValidationLoop.from_training_strategy(strategy) + assert loop._loss_fn is not None + assert len(loop._loss_fn.components) == len(strategy.loss_fn.components) + + # -- last_validation field -- + + def test_last_validation_roundtrips(self) -> None: + """last_validation is None by default and stores assigned values.""" + strategy = _make_strategy() + assert strategy.last_validation is None + strategy.last_validation = {"test": 1} + assert strategy.last_validation == {"test": 1} + + +class TestValidationSchedule: + """Phase C: validation checkpoint wiring into run().""" + + @staticmethod + def _make_schedule_strategy( + *, + every_n_epochs: int | None = None, + every_n_steps: int | None = None, + num_epochs: int | None = None, + num_steps: int | None = None, + hooks: list[Any] | None = None, + ) -> TrainingStrategy: + """Build a strategy with a ValidationConfig attached for schedule tests.""" + from nvalchemi.training._validation import ValidationConfig + + overrides: dict[str, Any] = {} + if num_epochs is not None: + overrides["num_epochs"] = num_epochs + if num_steps is not None: + overrides["num_epochs"] = None + overrides["num_steps"] = num_steps + if hooks is not None: + overrides["hooks"] = hooks + val_data = [_make_batch()] + vc = ValidationConfig( + validation_data=val_data, + every_n_epochs=every_n_epochs, + every_n_steps=every_n_steps, + ) + return _make_strategy(validation_config=vc, **overrides) + + # -- every_n_epochs -- + + def test_every_n_epochs_fires_at_correct_boundaries(self) 
-> None: + """Validation fires after epochs 1 and 2, plus the end-of-run pass.""" + strategy = self._make_schedule_strategy(every_n_epochs=1, num_epochs=2) + validate_epochs: list[int] = [] + orig_validate = TrainingStrategy.validate + + def _recording_validate(self_: Any) -> Any: + validate_epochs.append(self_.epoch_count) + return orig_validate(self_) + + dataset = _make_dataset(n_batches=2) + with patch.object(TrainingStrategy, "validate", _recording_validate): + strategy.run(dataset) + # Scheduled epochs 1 and 2, then the unconditional end-of-run pass. + assert validate_epochs == [1, 2, 2] + assert strategy.last_validation is not None + + def test_every_n_epochs_skips_intermediate(self) -> None: + """every_n_epochs=2: fires after epoch 2 (not 1), plus end-of-run.""" + strategy = self._make_schedule_strategy(every_n_epochs=2, num_epochs=3) + validate_epochs: list[int] = [] + orig_validate = TrainingStrategy.validate + + def _recording_validate(self_: Any) -> Any: + validate_epochs.append(self_.epoch_count) + return orig_validate(self_) + + dataset = _make_dataset(n_batches=2) + with patch.object(TrainingStrategy, "validate", _recording_validate): + strategy.run(dataset) + # Scheduled at epoch 2 only, then the unconditional end-of-run pass + # at the final epoch (3). 
+ assert validate_epochs == [2, 3] + + def test_every_n_epochs_freshness_flag(self) -> None: + """_validation_checkpoint returns True only after validation-firing epochs.""" + strategy = self._make_schedule_strategy(every_n_epochs=2, num_epochs=2) + checkpoint_results: list[tuple[int, bool]] = [] + orig_checkpoint = TrainingStrategy._validation_checkpoint + + def _recording_checkpoint(self_: Any, stage: Any) -> Any: + result = orig_checkpoint(self_, stage) + if stage is TrainingStage.AFTER_EPOCH: + checkpoint_results.append((self_.epoch_count, result)) + return result + + dataset = _make_dataset(n_batches=2) + with patch.object( + TrainingStrategy, "_validation_checkpoint", _recording_checkpoint + ): + strategy.run(dataset) + # Epoch 1: no validation (1%2!=0), False; epoch 2: validation, True + assert checkpoint_results == [(1, False), (2, True)] + + # -- every_n_steps -- + + def test_every_n_steps_fires_at_correct_steps(self) -> None: + """every_n_steps=2 fires at step_count 2 and 4, plus the end-of-run pass.""" + strategy = self._make_schedule_strategy(every_n_steps=2, num_steps=5) + validate_steps: list[int] = [] + orig_validate = TrainingStrategy.validate + + def _recording_validate(self_: Any) -> Any: + validate_steps.append(self_.step_count) + return orig_validate(self_) + + dataset = _make_dataset(n_batches=10) + with patch.object(TrainingStrategy, "validate", _recording_validate): + strategy.run(dataset) + # Scheduled at steps 2 and 4, then the unconditional end-of-run pass + # at the final step (5). 
+ assert validate_steps == [2, 4, 5] + + def test_every_n_steps_freshness_toggles(self) -> None: + """_validation_checkpoint returns True only on step boundaries.""" + strategy = self._make_schedule_strategy(every_n_steps=3, num_steps=4) + checkpoint_results: list[tuple[int, bool]] = [] + orig_checkpoint = TrainingStrategy._validation_checkpoint + + def _recording_checkpoint(self_: Any, stage: Any) -> Any: + result = orig_checkpoint(self_, stage) + if stage is TrainingStage.AFTER_OPTIMIZER_STEP: + checkpoint_results.append((self_.step_count, result)) + return result + + dataset = _make_dataset(n_batches=10) + with patch.object( + TrainingStrategy, "_validation_checkpoint", _recording_checkpoint + ): + strategy.run(dataset) + # step 3 fires validation (True); steps 1, 2, 4 are False + assert checkpoint_results == [(1, False), (2, False), (3, True), (4, False)] + + # -- ordering: validate() runs AFTER EMA publishes -- + + def test_step_cadence_validate_runs_after_ema_publish(self) -> None: + """validate() reads inference_model AFTER EMA hook publishes at AFTER_OPTIMIZER_STEP.""" + from nvalchemi.training.hooks import EMAHook + + ema = EMAHook(model_key="main", decay=0.0, start_step=0) + strategy = self._make_schedule_strategy( + every_n_steps=1, num_steps=1, hooks=[ema] + ) + inference_model_at_validate: list[torch.nn.Module | None] = [] + orig_validate = TrainingStrategy.validate + + def _recording_validate(self_: Any) -> Any: + inference_model_at_validate.append(self_.inference_model) + return orig_validate(self_) + + dataset = _make_dataset(n_batches=2) + with patch.object(TrainingStrategy, "validate", _recording_validate): + strategy.run(dataset) + # Scheduled at step 1, then the unconditional end-of-run pass. + assert len(inference_model_at_validate) == 2 + # EMA should have published a module before validate was called. 
+ assert all(model is not None for model in inference_model_at_validate) + + # -- no validation_config -- + + def test_no_validation_config_does_nothing(self) -> None: + """No validation_config: _validation_checkpoint returns False, run() works.""" + strategy = _make_strategy() + assert strategy.validation_config is None + dataset = _make_dataset(n_batches=2) + strategy.run(dataset) + assert ( + strategy._validation_checkpoint(TrainingStage.AFTER_OPTIMIZER_STEP) is False + ) + assert strategy.last_validation is None + + # -- last_validation populated -- + + def test_last_validation_populated_after_schedule(self) -> None: + """After scheduled validation, last_validation has data.""" + strategy = self._make_schedule_strategy(every_n_steps=1, num_steps=1) + dataset = _make_dataset(n_batches=2) + strategy.run(dataset) + assert strategy.last_validation is not None + assert isinstance(strategy.last_validation, dict) + + # -- unconditional end-of-run validation -- + + def test_validation_always_runs_at_end_off_boundary(self) -> None: + """A validation_config always validates at end-of-run, even off boundary.""" + strategy = self._make_schedule_strategy(every_n_steps=3, num_steps=2) + validate_steps: list[int] = [] + orig_validate = TrainingStrategy.validate + + def _recording_validate(self_: Any) -> Any: + validate_steps.append(self_.step_count) + return orig_validate(self_) + + dataset = _make_dataset(n_batches=10) + with patch.object(TrainingStrategy, "validate", _recording_validate): + strategy.run(dataset) + # No in-loop checkpoint fires (step 2 is not a multiple of 3); only the + # unconditional end-of-run pass at the final step (2) runs. 
+ assert validate_steps == [2] + assert strategy.last_validation is not None + + # -- AFTER_VALIDATION hook -- + + def test_after_validation_hook_fires_with_live_summary(self) -> None: + """AFTER_VALIDATION hooks observe the live summary before it is consumed.""" + strategy = self._make_schedule_strategy(every_n_steps=1, num_steps=1) + observed: list[dict[str, Any] | None] = [] + + def _record(ctx: HookContext, stage: Enum) -> None: # noqa: ARG001 + observed.append(ctx.validation) + + strategy.register_hook(_RecordingHook(TrainingStage.AFTER_VALIDATION, _record)) + dataset = _make_dataset(n_batches=2) + strategy.run(dataset) + + # Fired at the step-1 checkpoint and the unconditional end-of-run pass. + assert len(observed) == 2 + assert all(summary is not None for summary in observed) + assert all("total_loss" in summary for summary in observed) + + +class TestMetricSchedulerStepping: + """Phase D: ReduceLROnPlateau steps only at validation checkpoints.""" + + @staticmethod + def _make_metric_strategy( + *, + every_n_steps: int | None = None, + every_n_epochs: int | None = None, + num_steps: int | None = None, + num_epochs: int | None = None, + plateau_patience: int = 1, + plateau_factor: float = 0.5, + plateau_lr: float = 0.1, + add_time_based: bool = False, + ) -> TrainingStrategy: + """Build a strategy with a ReduceLROnPlateau scheduler and ValidationConfig.""" + from nvalchemi.training._validation import ValidationConfig + + opt_cfgs: list[OptimizerConfig] = [ + OptimizerConfig( + optimizer_cls=torch.optim.SGD, + optimizer_kwargs={"lr": plateau_lr}, + scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau, + scheduler_kwargs={ + "patience": plateau_patience, + "factor": plateau_factor, + "threshold": 0.0, + }, + ), + ] + if add_time_based: + opt_cfgs.append( + OptimizerConfig( + optimizer_cls=torch.optim.SGD, + optimizer_kwargs={"lr": 0.5}, + scheduler_cls=torch.optim.lr_scheduler.StepLR, + scheduler_kwargs={"step_size": 1, "gamma": 0.9}, + ), + ) + overrides: 
dict[str, Any] = { + "optimizer_configs": {"main": opt_cfgs}, + } + if num_epochs is not None: + overrides["num_epochs"] = num_epochs + if num_steps is not None: + overrides["num_epochs"] = None + overrides["num_steps"] = num_steps + val_data = [_make_batch()] + vc = ValidationConfig( + validation_data=val_data, + every_n_steps=every_n_steps, + every_n_epochs=every_n_epochs, + ) + return _make_strategy(validation_config=vc, **overrides) + + def test_plateau_steps_at_validation_checkpoints(self) -> None: + """ReduceLROnPlateau.step() is called at validation checkpoints only.""" + # every_n_steps=1 with 3 steps: checkpoint at steps 1, 2, 3 + end-of-run + strategy = self._make_metric_strategy( + every_n_steps=1, + num_steps=3, + plateau_patience=0, + plateau_factor=0.5, + ) + dataset = _make_dataset(n_batches=5) + + plateau_step_calls: list[int] = [] + orig_checkpoint = TrainingStrategy._validation_checkpoint + + def _recording_checkpoint(self_: Any, stage: Any) -> Any: + result = orig_checkpoint(self_, stage) + if result: + plateau_step_calls.append(self_.step_count) + return result + + with patch.object( + TrainingStrategy, "_validation_checkpoint", _recording_checkpoint + ): + strategy.run(dataset) + + # Validation checkpoints fire at steps 1, 2, 3 + assert plateau_step_calls == [1, 2, 3] + # With patience=0 the LR drops on every validation checkpoint + # where metric doesn't improve. The plateau scheduler was stepped + # at each checkpoint, so LR should have dropped. 
+ final_lr = strategy._optimizers[0].param_groups[0]["lr"] + assert final_lr < 0.1 + + def test_plateau_not_stepped_between_checkpoints(self) -> None: + """Between validation checkpoints, ReduceLROnPlateau is NOT stepped.""" + strategy = self._make_metric_strategy( + every_n_steps=3, + num_steps=4, + plateau_patience=10, + ) + dataset = _make_dataset(n_batches=10) + lr_at_each_step: list[float] = [] + + orig_train = TrainingStrategy._train_batch_with_optimizers + + def _recording_train(self_: Any, batch: Any, opts: Any, scheds: Any) -> Any: + result = orig_train(self_, batch, opts, scheds) + lr_at_each_step.append(opts[0].param_groups[0]["lr"]) + return result + + with patch.object( + TrainingStrategy, "_train_batch_with_optimizers", _recording_train + ): + strategy.run(dataset) + + # LR should be constant at all steps (patience=10 means no drop) + assert all(lr == pytest.approx(lr_at_each_step[0]) for lr in lr_at_each_step) + + def test_last_validation_consumed_after_checkpoint(self) -> None: + """last_validation is None after a checkpoint consumes it.""" + strategy = self._make_metric_strategy( + every_n_steps=1, + num_steps=2, + plateau_patience=10, + ) + dataset = _make_dataset(n_batches=5) + + post_checkpoint_states: list[bool] = [] + orig_checkpoint = TrainingStrategy._validation_checkpoint + + def _recording_checkpoint(self_: Any, stage: Any) -> Any: + result = orig_checkpoint(self_, stage) + if result: + # After _validation_checkpoint, last_validation should be consumed + post_checkpoint_states.append(self_.last_validation is None) + return result + + with patch.object( + TrainingStrategy, "_validation_checkpoint", _recording_checkpoint + ): + strategy.run(dataset) + + # Each checkpoint should have consumed last_validation + assert len(post_checkpoint_states) >= 1 + assert all(post_checkpoint_states) + + def test_time_based_scheduler_step_count_unchanged(self) -> None: + """A time-based StepLR scheduler steps every optimizer step, unchanged by metric 
support.""" + strategy = self._make_metric_strategy( + every_n_steps=2, + num_steps=4, + plateau_patience=10, + add_time_based=True, + ) + dataset = _make_dataset(n_batches=10) + strategy.run(dataset) + + # The second optimizer (StepLR with gamma=0.9, step_size=1) should + # have stepped every optimizer step. After 4 steps: lr = 0.5 * 0.9^4 + steplr_opt = strategy._optimizers[1] + expected_lr = 0.5 * (0.9**4) + actual_lr = steplr_opt.param_groups[0]["lr"] + assert actual_lr == pytest.approx(expected_lr, rel=1e-5) + + def test_plateau_lr_drops_with_constant_loss(self) -> None: + """E2E: plateau scheduler drops LR when validation loss plateaus.""" + # patience=1, factor=0.5: after 2 consecutive non-improving + # validations, LR drops. With every_n_steps=1 and num_steps=4, + # validation fires at steps 1,2,3,4 + end-of-run. The loss is + # deterministic (same val data, same model), so it plateaus. + strategy = self._make_metric_strategy( + every_n_steps=1, + num_steps=4, + plateau_patience=1, + plateau_factor=0.5, + plateau_lr=0.01, + ) + dataset = _make_dataset(n_batches=8) + initial_lr = 0.01 + strategy.run(dataset) + + final_lr = strategy._optimizers[0].param_groups[0]["lr"] + # With patience=1, the LR should have dropped at least once + assert final_lr < initial_lr diff --git a/test/training/test_strategy_validate.py b/test/training/test_strategy_validate.py new file mode 100644 index 00000000..43639480 --- /dev/null +++ b/test/training/test_strategy_validate.py @@ -0,0 +1,295 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for :meth:`TrainingStrategy.validate` (Phase B).""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import patch + +import pytest +import torch + +from nvalchemi.data import Batch +from nvalchemi.hooks._context import TrainContext +from nvalchemi.models.base import BaseModelMixin +from nvalchemi.training import EnergyMSELoss +from nvalchemi.training._stages import TrainingStage +from nvalchemi.training._validation import ValidationConfig +from nvalchemi.training.hooks import TrainingUpdateHook, TrainingUpdateOrchestrator +from nvalchemi.training.strategy import TrainingStrategy, default_training_fn +from test.training.conftest import ( + _build_baseline_strategy_kwargs, + _build_batch, +) + + +def _energy_only_training_fn( + model: BaseModelMixin, batch: Batch +) -> dict[str, torch.Tensor]: + """Run the demo model with only energy active.""" + active_outputs = set(model.model_config.active_outputs) + model.set_config("active_outputs", {"energy"}) + try: + return default_training_fn(model, batch) + finally: + model.set_config("active_outputs", active_outputs) + + +def _make_validation_strategy(**overrides: Any) -> TrainingStrategy: + """Build a strategy with a ValidationConfig attached.""" + batch = _build_batch() + vc_kwargs = overrides.pop("validation_config_kwargs", {}) + vc = ValidationConfig(validation_data=[batch], **vc_kwargs) + kwargs = _build_baseline_strategy_kwargs() + kwargs["validation_config"] = vc + kwargs.update(overrides) + return TrainingStrategy(**kwargs) + + +class 
_GradAccumulationVetoHook(TrainingUpdateHook): + """Allow only every ``accumulate_every``-th optimizer step. + + Mirrors gradient accumulation: vetoed ``DO_OPTIMIZER_STEP`` batches keep + ``step_count`` stalled at its last value while ``batch_count`` advances. + """ + + priority = 10 + + def __init__(self, accumulate_every: int) -> None: + self.accumulate_every = accumulate_every + self.optimizer_step_calls = 0 + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, + ) -> tuple[bool, torch.Tensor | None]: + if stage is TrainingStage.DO_OPTIMIZER_STEP: + self.optimizer_step_calls += 1 + return self.optimizer_step_calls % self.accumulate_every == 0, ctx.loss + return True, ctx.loss + + +class TestStrategyValidateLiveWeights: + """validate() with default (live) model weights.""" + + def test_returns_summary_dict_with_expected_keys(self) -> None: + """validate() returns a summary dict with the canonical key set.""" + strategy = _make_validation_strategy() + summary = strategy.validate() + + assert summary is not None + assert summary["name"] == "validation" + assert summary["model_source"] == "live" + assert summary["precision"] == "float32" + assert "total_loss" in summary + assert "per_component_unweighted" in summary + assert "EnergyMSELoss" in summary["per_component_unweighted"] + assert "ForceMSELoss" in summary["per_component_unweighted"] + assert summary["num_batches"] == 1 + + def test_summary_stored_on_last_validation(self) -> None: + """validate() sets last_validation / validation property.""" + strategy = _make_validation_strategy() + summary = strategy.validate() + + assert strategy.last_validation is summary + + +class TestStrategyValidateInferenceModel: + """validate() with inference_model (EMA) slot populated.""" + + def test_single_module_slot_reports_ema_source(self) -> None: + """Setting inference_model (single module) -> model_source='ema'.""" + strategy = _make_validation_strategy( + loss_fn=EnergyMSELoss(), + 
training_fn=_energy_only_training_fn, + validation_config_kwargs={"grad_mode": "disabled"}, + ) + # Populate the inference_model slot with a copy of the live model + live = strategy.models["main"] + import copy + + ema_model = copy.deepcopy(live) + strategy.inference_model = ema_model + + summary = strategy.validate() + + assert summary is not None + assert summary["model_source"] == "ema" + assert summary["ema_model_keys"] == ["main"] + + +class TestStrategyValidateGradIsolation: + """validate() with grad_mode='enabled' preserves training gradients.""" + + def test_grad_enabled_restores_pre_existing_grads(self) -> None: + """Pre-existing param.grad is identical after a grad-enabled validate().""" + strategy = _make_validation_strategy( + validation_config_kwargs={"grad_mode": "enabled"}, + ) + model = strategy.models["main"] + # Set a fake gradient on every parameter + original_grads: dict[str, torch.Tensor] = {} + for name, param in model.named_parameters(): + fake_grad = torch.randn_like(param) + param.grad = fake_grad.clone() + original_grads[name] = fake_grad + + strategy.validate() + + for name, param in model.named_parameters(): + assert param.grad is not None, f"grad lost for {name}" + assert torch.equal(param.grad, original_grads[name]), ( + f"grad changed for {name}" + ) + + +class TestStrategyValidateTrainingModeRestoration: + """validate() restores module training modes when set_eval=True.""" + + def test_train_mode_restored_after_validate(self) -> None: + """Modules in train() mode before validate() are restored to train().""" + strategy = _make_validation_strategy( + validation_config_kwargs={"set_eval": True}, + ) + model = strategy.models["main"] + model.train() + + strategy.validate() + + assert model.training is True + + +class TestStrategyValidateErrorHandling: + """validate() error paths.""" + + def test_raises_when_validation_config_is_none(self) -> None: + """validate() raises RuntimeError when validation_config is not set.""" + kwargs = 
_build_baseline_strategy_kwargs() + strategy = TrainingStrategy(**kwargs) + assert strategy.validation_config is None + + with pytest.raises(RuntimeError, match="requires a validation_config"): + strategy.validate() + + def test_raises_when_mixed_precision_always_without_hook(self) -> None: + """use_mixed_precision='always' without MixedPrecisionHook raises RuntimeError.""" + strategy = _make_validation_strategy( + validation_config_kwargs={"use_mixed_precision": "always"}, + ) + with pytest.raises(RuntimeError, match="MixedPrecisionHook"): + strategy.validate() + + +def _find_orchestrator(strategy: TrainingStrategy) -> TrainingUpdateOrchestrator: + """Return the orchestrator the strategy coalesced its update hooks into.""" + return next( + hook for hook in strategy.hooks if isinstance(hook, TrainingUpdateOrchestrator) + ) + + +class TestStrategyValidateStepCadenceGate: + """Step cadence fires only on batches whose optimizer step ran.""" + + def test_stalled_step_count_validates_once_per_multiple(self) -> None: + """Batches stalled on an eval multiple must not re-fire validation. + + With 3-batch gradient accumulation, ``step_count`` only advances on + every third batch and sits on the eval multiple for the two vetoed + batches that follow it. Without the step-ran gate each stalled batch + re-ran a full validation pass; with it, every multiple fires exactly + once. 
+ """ + accum = _GradAccumulationVetoHook(accumulate_every=3) + strategy = _make_validation_strategy( + validation_config_kwargs={"every_n_steps": 2}, + num_epochs=None, + num_steps=4, + hooks=[accum], + ) + validate_steps: list[int] = [] + orig_validate = TrainingStrategy.validate + + def _recording_validate(self_: Any) -> Any: + validate_steps.append(self_.step_count) + return orig_validate(self_) + + dataset = [_build_batch(seed=i * 10) for i in range(12)] + with patch.object(TrainingStrategy, "validate", _recording_validate): + strategy.run(dataset) + + assert strategy.batch_count == 12 + assert strategy.step_count == 4 + # Steps 2 and 4 each fire exactly once — the stalled batches after + # step 2 (vetoed optimizer steps) do not re-fire — followed by the + # unconditional end-of-run pass at the final step (4). + assert validate_steps == [2, 4, 4] + + def test_vetoed_step_does_not_fire_on_parked_multiple(self) -> None: + """The gate mirrors the orchestrator's per-batch step-skipped flag.""" + strategy = _make_validation_strategy( + validation_config_kwargs={"every_n_steps": 2}, + hooks=[_GradAccumulationVetoHook(accumulate_every=3)], + ) + orchestrator = _find_orchestrator(strategy) + strategy.step_count = 4 + + # The batch that advanced step_count onto the multiple fires. + orchestrator._optimizer_step_skipped = False + assert strategy._should_validate(TrainingStage.AFTER_OPTIMIZER_STEP) is True + assert ( + strategy._validation_checkpoint(TrainingStage.AFTER_OPTIMIZER_STEP) is True + ) + assert strategy.last_validation is not None + + # Vetoed batches leave step_count parked; must not re-fire. 
+ orchestrator._optimizer_step_skipped = True + assert strategy._should_validate(TrainingStage.AFTER_OPTIMIZER_STEP) is False + assert ( + strategy._validation_checkpoint(TrainingStage.AFTER_OPTIMIZER_STEP) is False + ) + + def test_plain_strategy_without_orchestrator_always_fires(self) -> None: + """Without an update orchestrator every step counts as ran.""" + strategy = _make_validation_strategy( + validation_config_kwargs={"every_n_steps": 2}, + ) + strategy.step_count = 4 + + assert strategy._should_validate(TrainingStage.AFTER_OPTIMIZER_STEP) is True + assert ( + strategy._validation_checkpoint(TrainingStage.AFTER_OPTIMIZER_STEP) is True + ) + # Off-multiple steps stay gated by the cadence itself. + strategy.step_count = 5 + assert strategy._should_validate(TrainingStage.AFTER_OPTIMIZER_STEP) is False + + def test_epoch_cadence_ignores_step_ran_signal(self) -> None: + """every_n_epochs fires even when the last optimizer step was vetoed.""" + strategy = _make_validation_strategy( + validation_config_kwargs={"every_n_epochs": 1}, + hooks=[_GradAccumulationVetoHook(accumulate_every=3)], + ) + orchestrator = _find_orchestrator(strategy) + orchestrator._optimizer_step_skipped = True + strategy.step_count = 7 + strategy.epoch_count = 1 + + # Epoch cadence ignores the step-ran signal. + assert strategy._should_validate(TrainingStage.AFTER_EPOCH) is True + assert strategy._validation_checkpoint(TrainingStage.AFTER_EPOCH) is True diff --git a/test/training/test_training_update_orchestrator.py b/test/training/test_training_update_orchestrator.py new file mode 100644 index 00000000..7ce0d7a9 --- /dev/null +++ b/test/training/test_training_update_orchestrator.py @@ -0,0 +1,1361 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for ``TrainingUpdateHook`` and ``TrainingUpdateOrchestrator``. + +Covers the hook framework defined in ``nvalchemi.training.hooks.update`` and +its integration with :class:`TrainingStrategy` (auto-wrap, conflict +detection, dispatch-driven training-loop suppression). The strategy-level +helpers (``demo_training_fn``, ``_make_demo_model`` etc.) are duplicated +locally rather than imported from ``test_strategy`` to keep these tests +self-contained and immune to pytest collection ordering. +""" + +from __future__ import annotations + +import contextlib +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch + +import pydantic +import pytest +import torch + +from nvalchemi.data import AtomicData, Batch +from nvalchemi.hooks._context import TrainContext +from nvalchemi.hooks._protocol import Hook +from nvalchemi.models.base import BaseModelMixin +from nvalchemi.training import ( + EnergyMSELoss, + ForceMSELoss, + TrainingStage, +) +from nvalchemi.training.hooks import ( + TrainingUpdateHook, + TrainingUpdateOrchestrator, +) +from nvalchemi.training.hooks.update import ( + _MULTIPLE_ORCHESTRATOR_MSG, + _check_veto, + _fold_training_update_hooks, + _hook_claims_stage, +) +from nvalchemi.training.optimizers import OptimizerConfig +from nvalchemi.training.strategy import TrainingStrategy, default_training_fn + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + +_UPDATE_STAGES: 
tuple[TrainingStage, ...] = ( + TrainingStage.BEFORE_BATCH, + TrainingStage.DO_BACKWARD, + TrainingStage.DO_OPTIMIZER_STEP, + TrainingStage.AFTER_OPTIMIZER_STEP, +) + +_NON_UPDATE_STAGES: tuple[TrainingStage, ...] = tuple( + s for s in TrainingStage if s not in _UPDATE_STAGES +) + +_GATED_STAGES: tuple[TrainingStage, ...] = ( + TrainingStage.BEFORE_BATCH, + TrainingStage.DO_OPTIMIZER_STEP, +) + +_DO_STAGES: tuple[TrainingStage, ...] = ( + TrainingStage.DO_BACKWARD, + TrainingStage.DO_OPTIMIZER_STEP, +) + + +def _demo_training_fn(model: BaseModelMixin, batch: Batch) -> dict[str, torch.Tensor]: + return default_training_fn(model, batch) + + +def _make_atomic_data(n_atoms: int = 3, seed: int = 0) -> AtomicData: + g = torch.Generator().manual_seed(seed) + positions = torch.randn(n_atoms, 3, generator=g) + atomic_numbers = torch.randint(1, 10, (n_atoms,), dtype=torch.long, generator=g) + energy = torch.randn(1, 1, generator=g) + forces = torch.randn(n_atoms, 3, generator=g) + return AtomicData( + positions=positions, + atomic_numbers=atomic_numbers, + atomic_masses=torch.ones(n_atoms), + energy=energy, + forces=forces, + ) + + +def _make_batch(n_systems: int = 2, n_atoms_each: int = 3, seed: int = 0) -> Batch: + return Batch.from_data_list( + [_make_atomic_data(n_atoms_each, seed=seed + i) for i in range(n_systems)] + ) + + +def _make_demo_model() -> Any: + from nvalchemi.models.demo import DemoModel, DemoModelWrapper + + torch.manual_seed(0) + return DemoModelWrapper(DemoModel(num_atom_types=20, hidden_dim=8)) + + +def _baseline_strategy_kwargs() -> dict[str, Any]: + return { + "models": _make_demo_model(), + "optimizer_configs": OptimizerConfig(optimizer_cls=torch.optim.Adam), + "num_epochs": 1, + "training_fn": _demo_training_fn, + "loss_fn": EnergyMSELoss() + ForceMSELoss(normalize_by_atom_count=True), + } + + +def _make_strategy(**overrides: Any) -> TrainingStrategy: + kwargs = _baseline_strategy_kwargs() + kwargs.update(overrides) + return 
TrainingStrategy(**kwargs) + + +def _make_ctx(loss: torch.Tensor | None = None) -> TrainContext: + if loss is None: + loss = torch.tensor(1.0) + batch = _make_batch() + return TrainContext(batch=batch, loss=loss) + + +def _single_orchestrator(strategy: TrainingStrategy) -> TrainingUpdateOrchestrator: + orchs = [h for h in strategy.hooks if isinstance(h, TrainingUpdateOrchestrator)] + assert len(orchs) == 1, f"Expected exactly one orchestrator, found {len(orchs)}" + return orchs[0] + + +@contextlib.contextmanager +def _patched_update_helpers(): # type: ignore[no-untyped-def] + """Patch the orchestrator-side and strategy-side training helpers. + + Yields a ``SimpleNamespace`` with attributes: + ``orch_zero``, ``orch_step``, ``orch_sched`` for the orchestrator path + (``nvalchemi.training.hooks.update.*``) and ``strategy_zero``, + ``strategy_step``, ``strategy_sched`` for the strategy default path + (``nvalchemi.training.strategy.*``). + """ + with ( + patch("nvalchemi.training.hooks.update.zero_gradients") as orch_zero, + patch("nvalchemi.training.hooks.update.step_optimizers") as orch_step, + patch("nvalchemi.training.hooks.update.step_lr_schedulers") as orch_sched, + patch("nvalchemi.training.strategy.zero_gradients") as strategy_zero, + patch("nvalchemi.training.strategy.step_optimizers") as strategy_step, + patch("nvalchemi.training.strategy.step_lr_schedulers") as strategy_sched, + ): + yield SimpleNamespace( + orch_zero=orch_zero, + orch_step=orch_step, + orch_sched=orch_sched, + strategy_zero=strategy_zero, + strategy_step=strategy_step, + strategy_sched=strategy_sched, + ) + + +@contextlib.contextmanager +def _run_strategy_with_patched_helpers(hooks: list[Any]): # type: ignore[no-untyped-def] + """Build a strategy from ``hooks``, train one batch, and yield the mock namespace. + + The helper patches both strategy-side and orchestrator-side update helpers + while ``train_batch`` runs synchronously, then yields the mocks for + inspection. 
+ """ + strategy = _make_strategy(hooks=hooks) + with _patched_update_helpers() as m: + strategy.train_batch(_make_batch()) + yield m + + +# --------------------------------------------------------------------------- +# Hook subclasses for tests +# --------------------------------------------------------------------------- + + +class _RecordingUpdateHook(TrainingUpdateHook): + def __init__(self, priority: int = 50) -> None: + self.priority = priority + self.calls: list[tuple[TrainingStage, bool]] = [] + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, + ) -> tuple[bool, torch.Tensor]: + self.calls.append((stage, will_skip)) + return True, ctx.loss + + +class _VetoHook(TrainingUpdateHook): + def __init__(self, veto_stage: TrainingStage, priority: int = 50) -> None: + self.priority = priority + self.veto_stage = veto_stage + self.calls: list[tuple[TrainingStage, bool]] = [] + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, + ) -> tuple[bool, torch.Tensor]: + self.calls.append((stage, will_skip)) + return stage is not self.veto_stage, ctx.loss + + +class _BadProceedHook(TrainingUpdateHook): + def __init__(self, proceed: object, priority: int = 50) -> None: + self.priority = priority + self._bad_proceed = proceed + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, + ) -> tuple[bool, torch.Tensor]: + return self._bad_proceed, ctx.loss # type: ignore[return-value] + + +class _LossTransformHook(TrainingUpdateHook): + def __init__(self, factor: float, priority: int = 50) -> None: + self.priority = priority + self.factor = factor + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, + ) -> tuple[bool, torch.Tensor]: + if stage == TrainingStage.DO_BACKWARD: + return True, ctx.loss * self.factor + return True, ctx.loss + + +class _NoneLossHook(TrainingUpdateHook): + def __call__( + self, + ctx: TrainContext, + stage: 
TrainingStage, + will_skip: bool, + ) -> tuple[bool, None]: + return True, None + + +class _GradScalerSetHook(TrainingUpdateHook): + """Update hook that writes ``ctx.grad_scaler`` on ``DO_BACKWARD``.""" + + priority = 10 + + def __init__(self, scaler: object) -> None: + self._scaler = scaler + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, + ) -> tuple[bool, torch.Tensor]: + if stage == TrainingStage.DO_BACKWARD: + ctx.grad_scaler = self._scaler + return True, ctx.loss + + +class _GradScalerReadHook(TrainingUpdateHook): + """Update hook that records ``ctx.grad_scaler`` on ``DO_BACKWARD``.""" + + priority = 20 + + def __init__(self) -> None: + self.observed: object | None = None + + def __call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, + ) -> tuple[bool, torch.Tensor]: + if stage == TrainingStage.DO_BACKWARD: + self.observed = ctx.grad_scaler + return True, ctx.loss + + +class _LifecycleUpdateHook(TrainingUpdateHook): + """Update hook that records lifecycle method calls.""" + + def __init__(self, name: str, events: list[str], priority: int = 50) -> None: + self.name = name + self.events = events + self.priority = priority + + def __enter__(self) -> None: + self.events.append(f"enter:{self.name}") + + def __exit__(self, exc_type: object, exc: object, tb: object) -> None: + self.events.append(f"exit:{self.name}") + + +class _CloseOnlyUpdateHook(TrainingUpdateHook): + """Update hook that records ``close`` calls.""" + + def __init__(self, name: str, events: list[str], priority: int = 50) -> None: + self.name = name + self.events = events + self.priority = priority + + def close(self) -> None: + self.events.append(f"close:{self.name}") + + +class _AfterOptimizerStepHook(TrainingUpdateHook): + """Update hook that records ``will_skip`` on ``AFTER_OPTIMIZER_STEP``.""" + + def __init__(self, priority: int = 50) -> None: + self.priority = priority + self.will_skip_values: list[bool] = [] + + def 
__call__( + self, + ctx: TrainContext, + stage: TrainingStage, + will_skip: bool, + ) -> tuple[bool, torch.Tensor]: + if stage == TrainingStage.AFTER_OPTIMIZER_STEP: + self.will_skip_values.append(will_skip) + return True, ctx.loss + + +class _FakeEqHook: + """Hook-like object whose ``__eq__`` always returns ``True``. + + Used to verify that ``_validate_single_do_claimants`` uses identity + (``is``) rather than equality when checking whether the candidate hook + is already in the existing hook list. + """ + + def __init__(self, stage: TrainingStage | None = None) -> None: + self.stage = stage + self.frequency = 1 + + def __eq__(self, other: object) -> bool: + return True + + def __hash__(self) -> int: + return id(self) + + def __call__(self, ctx: TrainContext, stage: TrainingStage) -> None: + return None + + +class _StageOnlyHook: + def __init__(self, stage: TrainingStage) -> None: + self.stage = stage + self.frequency = 1 + + def __call__(self, ctx: TrainContext, stage: TrainingStage) -> None: + return None + + +class _ContextCaptureHook(_StageOnlyHook): + def __init__(self, stage: TrainingStage) -> None: + super().__init__(stage) + self.contexts: list[TrainContext] = [] + + def __call__(self, ctx: TrainContext, stage: TrainingStage) -> None: + self.contexts.append(ctx) + + +class _RaisingStageHook(_StageOnlyHook): + def __init__(self, stage: TrainingStage) -> None: + super().__init__(stage) + self.enabled = True + + def __call__(self, ctx: TrainContext, stage: TrainingStage) -> None: + if self.enabled: + raise RuntimeError("forced hook failure") + + +class _DoBackwardOwnerHook(_StageOnlyHook): + def __init__(self) -> None: + super().__init__(TrainingStage.DO_BACKWARD) + self.calls = 0 + + def __call__(self, ctx: TrainContext, stage: TrainingStage) -> None: + self.calls += 1 + assert ctx.loss is not None + ctx.loss.backward() + + +class _DoOptimizerStepOwnerHook(_StageOnlyHook): + def __init__(self) -> None: + super().__init__(TrainingStage.DO_OPTIMIZER_STEP) + 
self.calls = 0 + self.contexts: list[TrainContext] = [] + + def __call__(self, ctx: TrainContext, stage: TrainingStage) -> None: + self.calls += 1 + self.contexts.append(ctx) + for optimizer in ctx.optimizers: + optimizer.step() + for scheduler in ctx.lr_schedulers: + if scheduler is not None: + scheduler.step() + + +class _HybridStageRunsOnHook(_StageOnlyHook): + def _runs_on_stage(self, stage: TrainingStage) -> bool: + return False + + +# --------------------------------------------------------------------------- +# Test classes +# --------------------------------------------------------------------------- + + +class TestTrainingUpdateHookDefaults: + def test_default_priority_is_fifty(self) -> None: + assert TrainingUpdateHook.priority == 50 + + def test_runs_on_stage_true_for_update_stages(self) -> None: + hook = TrainingUpdateHook() + for stage in _UPDATE_STAGES: + assert hook._runs_on_stage(stage) is True + + def test_runs_on_stage_false_for_non_update_stages(self) -> None: + hook = TrainingUpdateHook() + for stage in _NON_UPDATE_STAGES: + assert hook._runs_on_stage(stage) is False, ( + f"Expected False for {stage.name}, got True." 
+ ) + + def test_default_call_returns_true_and_ctx_loss(self) -> None: + loss = torch.tensor(3.14) + ctx = _make_ctx(loss=loss) + hook = TrainingUpdateHook() + for stage in _UPDATE_STAGES: + proceed, returned_loss = hook(ctx, stage, will_skip=False) + assert proceed is True + assert returned_loss is loss + + +class TestAddAlgebra: + def test_hook_plus_hook_yields_orchestrator(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=20) + result = a + b + assert isinstance(result, TrainingUpdateOrchestrator) + assert result._hooks == [a, b] + + def test_hook_plus_orchestrator_yields_flat_orchestrator(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=20) + c = _RecordingUpdateHook(priority=30) + orch = TrainingUpdateOrchestrator(b, c) + result = a + orch + assert isinstance(result, TrainingUpdateOrchestrator) + assert result._hooks == [a, b, c] + for inner in result._hooks: + assert not isinstance(inner, TrainingUpdateOrchestrator) + + def test_orchestrator_plus_hook_yields_flat_orchestrator(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=30) + c = _RecordingUpdateHook(priority=20) + orch = TrainingUpdateOrchestrator(a, b) + result = orch + c + assert isinstance(result, TrainingUpdateOrchestrator) + assert result._hooks == [a, c, b] + + def test_orchestrator_plus_orchestrator_yields_flat_orchestrator(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=40) + c = _RecordingUpdateHook(priority=20) + d = _RecordingUpdateHook(priority=30) + left = TrainingUpdateOrchestrator(a, b) + right = TrainingUpdateOrchestrator(c, d) + result = left + right + assert isinstance(result, TrainingUpdateOrchestrator) + assert result._hooks == [a, c, d, b] + + def test_constituents_preserve_identity(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=20) + result = a + b + assert result._hooks[0] 
is a + assert result._hooks[1] is b + + def test_hook_plus_int_raises_type_error(self) -> None: + hook = _RecordingUpdateHook() + with pytest.raises(TypeError): + _ = hook + 42 # type: ignore[operator] + + def test_orchestrator_plus_int_raises_type_error(self) -> None: + orch = TrainingUpdateOrchestrator(_RecordingUpdateHook()) + with pytest.raises(TypeError): + _ = orch + 42 # type: ignore[operator] + + def test_addition_never_returns_bare_hook(self) -> None: + a = _RecordingUpdateHook() + b = _RecordingUpdateHook() + assert isinstance(a + b, TrainingUpdateOrchestrator) + + +class TestPriorityOrdering: + def test_three_hooks_sorted_ascending(self) -> None: + h_high = _RecordingUpdateHook(priority=30) + h_low = _RecordingUpdateHook(priority=10) + h_mid = _RecordingUpdateHook(priority=20) + orch = TrainingUpdateOrchestrator(h_high, h_low, h_mid) + ctx = _make_ctx() + orch(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + assert [h.priority for h in orch._hooks] == [10, 20, 30] + assert orch._hooks == [h_low, h_mid, h_high] + assert h_low.calls and h_mid.calls and h_high.calls + + def test_stable_sort_preserves_insertion_order_on_ties(self) -> None: + first = _RecordingUpdateHook(priority=20) + second = _RecordingUpdateHook(priority=20) + third = _RecordingUpdateHook(priority=20) + orch = TrainingUpdateOrchestrator(first, second, third) + assert orch._hooks == [first, second, third] + + +class TestUpdateStageDispatch: + def test_before_batch_calls_zero_gradients_when_proceed(self) -> None: + hook = _RecordingUpdateHook(priority=10) + orch = TrainingUpdateOrchestrator(hook) + ctx = _make_ctx() + with _patched_update_helpers() as m: + orch(ctx, TrainingStage.BEFORE_BATCH) + m.orch_zero.assert_called_once_with(ctx.optimizers) + + def test_before_batch_skips_zero_gradients_on_veto(self) -> None: + hook = _VetoHook(veto_stage=TrainingStage.BEFORE_BATCH, priority=10) + orch = TrainingUpdateOrchestrator(hook) + ctx = _make_ctx() + with _patched_update_helpers() as m: + orch(ctx, 
TrainingStage.BEFORE_BATCH) + m.orch_zero.assert_not_called() + + def test_do_backward_calls_backward_and_assigns_loss(self) -> None: + param = torch.nn.Parameter(torch.tensor([1.0])) + loss = (param * 3.0).sum() # dL/dparam = 3 prior to chain + hook = _LossTransformHook(factor=2.0, priority=10) + orch = TrainingUpdateOrchestrator(hook) + ctx = _make_ctx(loss=loss) + orch(ctx, TrainingStage.DO_BACKWARD) + assert param.grad is not None + # Original grad (3.0) scaled by 2.0 = 6.0 + assert param.grad.item() == pytest.approx(6.0) + # ctx.loss is replaced with the transformed scalar tensor. + assert ctx.loss is not loss + + def test_do_optimizer_step_calls_step_helpers_when_proceed(self) -> None: + hook = _RecordingUpdateHook(priority=10) + orch = TrainingUpdateOrchestrator(hook) + ctx = _make_ctx() + with _patched_update_helpers() as m: + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + m.orch_step.assert_called_once_with(ctx.optimizers) + m.orch_sched.assert_called_once_with(ctx.lr_schedulers) + + def test_do_optimizer_step_uses_grad_scaler_when_present(self) -> None: + hook = _RecordingUpdateHook(priority=10) + orch = TrainingUpdateOrchestrator(hook) + optimizer = Mock(spec=torch.optim.Optimizer) + scheduler = Mock() + scaler = Mock(spec=torch.amp.GradScaler) + scaler.get_scale.side_effect = [128.0, 128.0] + ctx = _make_ctx() + ctx.optimizers = [optimizer] + ctx.lr_schedulers = [scheduler] + ctx.grad_scaler = scaler + + with _patched_update_helpers() as m: + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + + m.orch_step.assert_not_called() + scaler.step.assert_called_once_with(optimizer) + scaler.update.assert_called_once_with() + m.orch_sched.assert_called_once_with(ctx.lr_schedulers) + assert orch.optimizer_step_skipped is False + + def test_do_optimizer_step_skips_schedulers_when_grad_scaler_skips(self) -> None: + observer = _AfterOptimizerStepHook(priority=20) + orch = TrainingUpdateOrchestrator(_RecordingUpdateHook(priority=10), observer) + optimizer = 
Mock(spec=torch.optim.Optimizer) + scheduler = Mock() + scaler = Mock(spec=torch.amp.GradScaler) + scaler.get_scale.side_effect = [128.0, 64.0] + ctx = _make_ctx() + ctx.optimizers = [optimizer] + ctx.lr_schedulers = [scheduler] + ctx.grad_scaler = scaler + + with _patched_update_helpers() as m: + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + orch(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + + m.orch_step.assert_not_called() + scaler.step.assert_called_once_with(optimizer) + scaler.update.assert_called_once_with() + m.orch_sched.assert_not_called() + assert orch.optimizer_step_skipped is True + assert observer.will_skip_values == [True] + + def test_do_optimizer_step_skips_step_helpers_on_veto(self) -> None: + hook = _VetoHook(veto_stage=TrainingStage.DO_OPTIMIZER_STEP, priority=10) + orch = TrainingUpdateOrchestrator(hook) + ctx = _make_ctx() + with _patched_update_helpers() as m: + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + m.orch_step.assert_not_called() + m.orch_sched.assert_not_called() + + def test_after_optimizer_step_iterates_with_will_skip_false(self) -> None: + h1 = _RecordingUpdateHook(priority=10) + h2 = _RecordingUpdateHook(priority=20) + orch = TrainingUpdateOrchestrator(h1, h2) + ctx = _make_ctx() + orch(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + assert h1.calls == [(TrainingStage.AFTER_OPTIMIZER_STEP, False)] + assert h2.calls == [(TrainingStage.AFTER_OPTIMIZER_STEP, False)] + + def test_after_optimizer_step_receives_will_skip_false_after_step(self) -> None: + observer = _AfterOptimizerStepHook(priority=10) + orch = TrainingUpdateOrchestrator(observer) + ctx = _make_ctx() + with _patched_update_helpers(): + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + orch(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + assert observer.will_skip_values == [False] + + def test_after_optimizer_step_receives_will_skip_true_after_veto(self) -> None: + veto = _VetoHook(veto_stage=TrainingStage.DO_OPTIMIZER_STEP, priority=10) + observer = _AfterOptimizerStepHook(priority=20) + 
orch = TrainingUpdateOrchestrator(veto, observer) + ctx = _make_ctx() + with _patched_update_helpers(): + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + orch(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) + assert observer.will_skip_values == [True] + + @pytest.mark.parametrize( + "stage", + [TrainingStage.BEFORE_FORWARD, TrainingStage.BEFORE_BACKWARD], + ids=lambda s: s.name, + ) + def test_non_update_stage_is_noop(self, stage: TrainingStage) -> None: + h1 = _RecordingUpdateHook(priority=10) + h2 = _RecordingUpdateHook(priority=20) + orch = TrainingUpdateOrchestrator(h1, h2) + ctx = _make_ctx() + orch(ctx, stage) + assert h1.calls == [] + assert h2.calls == [] + + +class TestVetoComposition: + def test_before_batch_no_short_circuit_all_hooks_called(self) -> None: + h1 = _RecordingUpdateHook(priority=10) + h2 = _VetoHook(veto_stage=TrainingStage.BEFORE_BATCH, priority=20) + h3 = _RecordingUpdateHook(priority=30) + h4 = _RecordingUpdateHook(priority=40) + orch = TrainingUpdateOrchestrator(h1, h2, h3, h4) + ctx = _make_ctx() + with _patched_update_helpers() as m: + orch(ctx, TrainingStage.BEFORE_BATCH) + assert len(h1.calls) == 1 + assert len(h2.calls) == 1 + assert len(h3.calls) == 1 + assert len(h4.calls) == 1 + # Hooks BEFORE the vetoing hook saw will_skip=False. + assert h1.calls[0] == (TrainingStage.BEFORE_BATCH, False) + assert h2.calls[0] == (TrainingStage.BEFORE_BATCH, False) + # Hooks AFTER the vetoing hook saw will_skip=True. 
+ assert h3.calls[0] == (TrainingStage.BEFORE_BATCH, True) + assert h4.calls[0] == (TrainingStage.BEFORE_BATCH, True) + m.orch_zero.assert_not_called() + + def test_do_optimizer_step_veto_suppresses_both_helpers(self) -> None: + h1 = _RecordingUpdateHook(priority=10) + h2 = _VetoHook(veto_stage=TrainingStage.DO_OPTIMIZER_STEP, priority=20) + orch = TrainingUpdateOrchestrator(h1, h2) + ctx = _make_ctx() + with _patched_update_helpers() as m: + orch(ctx, TrainingStage.DO_OPTIMIZER_STEP) + m.orch_step.assert_not_called() + m.orch_sched.assert_not_called() + + def test_any_false_among_trues_wins(self) -> None: + # Five priority-buckets; only the priority-30 hook vetoes BEFORE_BATCH. + hooks: list[TrainingUpdateHook] = [ + _RecordingUpdateHook(priority=10), + _RecordingUpdateHook(priority=20), + _VetoHook(veto_stage=TrainingStage.BEFORE_BATCH, priority=30), + _RecordingUpdateHook(priority=40), + _RecordingUpdateHook(priority=50), + ] + orch = TrainingUpdateOrchestrator(*hooks) + ctx = _make_ctx() + with _patched_update_helpers() as m: + orch(ctx, TrainingStage.BEFORE_BATCH) + m.orch_zero.assert_not_called() + + def test_all_true_path_runs_gated_operation(self) -> None: + hooks = [_RecordingUpdateHook(priority=p) for p in (10, 20, 30)] + orch = TrainingUpdateOrchestrator(*hooks) + ctx = _make_ctx() + with _patched_update_helpers() as m: + orch(ctx, TrainingStage.BEFORE_BATCH) + m.orch_zero.assert_called_once_with(ctx.optimizers) + + +class TestLossChain: + def test_two_hook_chain_multiplies_loss(self) -> None: + param = torch.nn.Parameter(torch.tensor([1.0])) + x = 5.0 + loss = (param * x).sum() # base dL/dparam = 5 + hook_lo = _LossTransformHook(factor=0.5, priority=10) + hook_hi = _LossTransformHook(factor=4.0, priority=20) + orch = TrainingUpdateOrchestrator(hook_lo, hook_hi) + ctx = _make_ctx(loss=loss) + orch(ctx, TrainingStage.DO_BACKWARD) + assert param.grad is not None + assert param.grad.item() == pytest.approx(2.0 * x) + + def 
test_passthrough_hook_preserves_chain(self) -> None: + param = torch.nn.Parameter(torch.tensor([1.0])) + x = 2.0 + loss = (param * x).sum() + hook_lo = _LossTransformHook(factor=3.0, priority=10) + # Default __call__ returns (True, ctx.loss) — pass-through. + hook_passthrough = _RecordingUpdateHook(priority=20) + orch = TrainingUpdateOrchestrator(hook_lo, hook_passthrough) + ctx = _make_ctx(loss=loss) + orch(ctx, TrainingStage.DO_BACKWARD) + assert param.grad is not None + assert param.grad.item() == pytest.approx(3.0 * x) + + def test_ctx_loss_replaced_post_chain(self) -> None: + param = torch.nn.Parameter(torch.tensor([1.0])) + original = (param * 1.0).sum() + orch = TrainingUpdateOrchestrator( + _LossTransformHook(factor=0.5, priority=10), + _LossTransformHook(factor=4.0, priority=20), + ) + ctx = _make_ctx(loss=original) + orch(ctx, TrainingStage.DO_BACKWARD) + assert ctx.loss is not original + # Final scalar value: 1.0 * 0.5 * 4.0 = 2.0. + assert ctx.loss.item() == pytest.approx(2.0) + + def test_do_backward_rejects_none_loss(self) -> None: + param = torch.nn.Parameter(torch.tensor([1.0])) + loss = (param * 2.0).sum() + hook = _NoneLossHook() + orch = TrainingUpdateOrchestrator(hook) + ctx = _make_ctx(loss=loss) + with pytest.raises(TypeError, match="DO_BACKWARD"): + orch(ctx, TrainingStage.DO_BACKWARD) + assert param.grad is None + + +class TestStrictBoolValidation: + """``_check_veto`` rejects non-bool ``proceed`` returns on gated stages.""" + + @pytest.mark.parametrize( + "bad_value", + [None, 1, 0, "yes", []], + ids=["none", "int_truthy", "int_zero", "str", "list"], + ) + @pytest.mark.parametrize("stage", _GATED_STAGES, ids=lambda s: s.name) + def test_non_bool_proceed_raises_type_error( + self, bad_value: object, stage: TrainingStage + ) -> None: + hook = _BadProceedHook(proceed=bad_value) + orch = TrainingUpdateOrchestrator(hook) + ctx = _make_ctx() + with ( + _patched_update_helpers(), + pytest.raises(TypeError, match=stage.name) as exc_info, + ): + 
orch(ctx, stage) + assert "_BadProceedHook" in str(exc_info.value) + + @pytest.mark.parametrize("stage", _GATED_STAGES, ids=lambda s: s.name) + def test_true_proceed_does_not_raise(self, stage: TrainingStage) -> None: + hook = _RecordingUpdateHook(priority=10) + orch = TrainingUpdateOrchestrator(hook) + ctx = _make_ctx() + with _patched_update_helpers(): + orch(ctx, stage) # no raise + + def test_non_gated_stages_skip_veto_validation(self) -> None: + param = torch.nn.Parameter(torch.tensor([1.0])) + loss = (param * 2.0).sum() + # proceed=None must not raise on DO_BACKWARD or AFTER_OPTIMIZER_STEP. + bad = _BadProceedHook(proceed=None, priority=10) + orch = TrainingUpdateOrchestrator(bad) + ctx = _make_ctx(loss=loss) + orch(ctx, TrainingStage.DO_BACKWARD) # no raise + assert param.grad is not None + orch(ctx, TrainingStage.AFTER_OPTIMIZER_STEP) # no raise + + def test_check_veto_helper_directly(self) -> None: + sentinel = object() + with pytest.raises(TypeError, match="DO_OPTIMIZER_STEP"): + _check_veto(None, sentinel, TrainingStage.DO_OPTIMIZER_STEP) + # bool decision passes silently. 
+ _check_veto(True, sentinel, TrainingStage.BEFORE_BATCH) + _check_veto(False, sentinel, TrainingStage.BEFORE_BATCH) + + +class TestOrchestratorConstructor: + def test_empty_orchestrator_succeeds(self) -> None: + orch = TrainingUpdateOrchestrator() + assert orch._hooks == [] + + def test_two_hooks_flattened(self) -> None: + a = _RecordingUpdateHook(priority=20) + b = _RecordingUpdateHook(priority=10) + orch = TrainingUpdateOrchestrator(a, b) + assert orch._hooks == [b, a] + + def test_nested_orchestrator_flattened(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=20) + inner = TrainingUpdateOrchestrator(a, b) + c = _RecordingUpdateHook(priority=30) + outer = TrainingUpdateOrchestrator(inner, c) + assert outer._hooks == [a, b, c] + for hook in outer._hooks: + assert not isinstance(hook, TrainingUpdateOrchestrator) + + @pytest.mark.parametrize( + ("bad_value", "type_name"), + [(42, "int"), ("a string", "str"), (object(), "object")], + ) + def test_non_hook_argument_raises_type_error( + self, bad_value: object, type_name: str + ) -> None: + a = _RecordingUpdateHook(priority=10) + with pytest.raises(TypeError, match="argument 1") as exc_info: + TrainingUpdateOrchestrator(a, bad_value) # type: ignore[arg-type] + assert type_name in str(exc_info.value) + assert "*hooks" in str(exc_info.value) + + def test_list_instead_of_varargs_raises_type_error(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=20) + with pytest.raises(TypeError, match="argument 0") as exc_info: + TrainingUpdateOrchestrator([a, b]) # type: ignore[arg-type] + assert "list" in str(exc_info.value) + + +class TestOrchestratorLifecycle: + def test_context_manager_forwards_lifecycle_to_children(self) -> None: + events: list[str] = [] + first = _LifecycleUpdateHook("first", events, priority=10) + second = _LifecycleUpdateHook("second", events, priority=20) + + with TrainingUpdateOrchestrator(second, first): + 
events.append("inside") + + assert events == [ + "enter:first", + "enter:second", + "inside", + "exit:second", + "exit:first", + ] + + def test_context_manager_closes_children_without_exit(self) -> None: + events: list[str] = [] + first = _CloseOnlyUpdateHook("first", events, priority=10) + second = _CloseOnlyUpdateHook("second", events, priority=20) + + with TrainingUpdateOrchestrator(first, second): + events.append("inside") + + assert events == ["inside", "close:second", "close:first"] + + def test_strategy_context_enters_autowrapped_update_hooks(self) -> None: + events: list[str] = [] + update_hook = _LifecycleUpdateHook("update", events) + strategy = _make_strategy(hooks=[update_hook]) + + with strategy: + events.append("inside") + + assert events == ["enter:update", "inside", "exit:update"] + + +class TestRunsOnStage: + def test_base_hook_claims_only_update_stages(self) -> None: + hook = TrainingUpdateHook() + for stage in TrainingStage: + assert hook._runs_on_stage(stage) is (stage in _UPDATE_STAGES) + + def test_orchestrator_claims_only_update_stages(self) -> None: + orch = TrainingUpdateOrchestrator(_RecordingUpdateHook()) + for stage in TrainingStage: + assert orch._runs_on_stage(stage) is (stage in _UPDATE_STAGES) + + def test_hook_claims_stage_helper_matches_runs_on_stage(self) -> None: + hook = TrainingUpdateHook() + for stage in TrainingStage: + assert _hook_claims_stage(hook, stage) is (stage in _UPDATE_STAGES) + + def test_hook_claims_stage_uses_stage_attr_when_no_runs_on_stage(self) -> None: + hook = _StageOnlyHook(TrainingStage.BEFORE_BACKWARD) + assert _hook_claims_stage(hook, TrainingStage.BEFORE_BACKWARD) is True + for stage in TrainingStage: + if stage is TrainingStage.BEFORE_BACKWARD: + continue + assert _hook_claims_stage(hook, stage) is False + + def test_hybrid_hook_runs_on_stage_takes_precedence(self) -> None: + hook = _HybridStageRunsOnHook(stage=TrainingStage.DO_BACKWARD) + # Even though stage == DO_BACKWARD, _runs_on_stage returns 
False. + assert _hook_claims_stage(hook, TrainingStage.DO_BACKWARD) is False + + +class TestAutoWrapConstructor: + def test_single_bare_hook_wrapped_in_orchestrator(self) -> None: + bare = _RecordingUpdateHook(priority=10) + strategy = _make_strategy(hooks=[bare]) + assert len(strategy.hooks) == 1 + wrapper = _single_orchestrator(strategy) + assert wrapper._hooks == [bare] + assert strategy._has_update_orchestrator is True + + def test_multiple_bare_hooks_folded_into_one_orchestrator(self) -> None: + a = _RecordingUpdateHook(priority=20) + b = _RecordingUpdateHook(priority=10) + strategy = _make_strategy(hooks=[a, b]) + assert len(strategy.hooks) == 1 + wrapper = _single_orchestrator(strategy) + assert wrapper._hooks == [b, a] + + def test_explicit_orchestrator_kept_as_is(self) -> None: + bare = _RecordingUpdateHook(priority=10) + explicit = TrainingUpdateOrchestrator(bare) + strategy = _make_strategy(hooks=[explicit]) + assert strategy.hooks[0] is explicit + assert strategy._has_update_orchestrator is True + + def test_explicit_orchestrator_uses_base_frequency_validation(self) -> None: + explicit = TrainingUpdateOrchestrator(_RecordingUpdateHook(priority=10)) + explicit.frequency = 0 + with pytest.raises(pydantic.ValidationError, match="Hook frequency"): + _make_strategy(hooks=[explicit]) + + def test_explicit_orchestrator_plus_bare_folded(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=20) + explicit = TrainingUpdateOrchestrator(a) + strategy = _make_strategy(hooks=[explicit, b]) + wrapper = _single_orchestrator(strategy) + assert len(wrapper._hooks) == 2 + assert any(h is a for h in wrapper._hooks) + assert any(h is b for h in wrapper._hooks) + + def test_non_update_hooks_preserved_with_orchestrator_inserted(self) -> None: + non_a = _StageOnlyHook(TrainingStage.AFTER_BATCH) + non_b = _StageOnlyHook(TrainingStage.BEFORE_FORWARD) + update_a = _RecordingUpdateHook(priority=10) + update_b = _RecordingUpdateHook(priority=20) 
+ strategy = _make_strategy(hooks=[non_a, update_a, non_b, update_b]) + assert strategy.hooks[0] is non_a + assert isinstance(strategy.hooks[1], TrainingUpdateOrchestrator) + assert strategy.hooks[2] is non_b + non_update = [ + h for h in strategy.hooks if not isinstance(h, TrainingUpdateOrchestrator) + ] + assert non_update == [non_a, non_b] + assert strategy.hooks[0] is non_a + assert isinstance(strategy.hooks[1], TrainingUpdateOrchestrator) + assert strategy.hooks[2] is non_b + wrapper = _single_orchestrator(strategy) + assert len(wrapper._hooks) == 2 + assert any(h is update_a for h in wrapper._hooks) + assert any(h is update_b for h in wrapper._hooks) + + def test_orchestrator_inserted_at_first_bare_update_hook(self) -> None: + update = _RecordingUpdateHook(priority=10) + non_update = _StageOnlyHook(TrainingStage.AFTER_BATCH) + strategy = _make_strategy(hooks=[update, non_update]) + assert isinstance(strategy.hooks[0], TrainingUpdateOrchestrator) + assert strategy.hooks[1] is non_update + + def test_no_orchestrator_when_no_update_hooks(self) -> None: + # Auto-wrap is keyed off ``TrainingUpdateHook`` type, not stage + # membership; a plain ``Hook``-protocol object on BEFORE_BATCH does + # not trigger orchestrator creation. 
+ plain_stage_hook = _StageOnlyHook(TrainingStage.BEFORE_BATCH) + strategy = _make_strategy(hooks=[plain_stage_hook]) + assert strategy._has_update_orchestrator is False + + +class TestAutoWrapRegisterHook: + def test_register_first_bare_hook_creates_orchestrator(self) -> None: + strategy = _make_strategy(hooks=[]) + bare = _RecordingUpdateHook(priority=10) + strategy.register_hook(bare) + wrapper = _single_orchestrator(strategy) + assert wrapper._hooks == [bare] + assert strategy._has_update_orchestrator is True + + def test_register_second_bare_hook_merges(self) -> None: + a = _RecordingUpdateHook(priority=10) + strategy = _make_strategy(hooks=[a]) + b = _RecordingUpdateHook(priority=20) + strategy.register_hook(b) + wrapper = _single_orchestrator(strategy) + assert len(wrapper._hooks) == 2 + assert any(h is a for h in wrapper._hooks) + assert any(h is b for h in wrapper._hooks) + + def test_register_non_update_hook_skips_autowrap(self) -> None: + strategy = _make_strategy(hooks=[]) + # Auto-wrap is keyed off ``TrainingUpdateHook`` type, not stage. 
+ plain_stage_hook = _StageOnlyHook(TrainingStage.BEFORE_BATCH) + strategy.register_hook(plain_stage_hook) + assert strategy._has_update_orchestrator is False + assert plain_stage_hook in strategy.hooks + + def test_register_update_hook_with_stage_raises_value_error(self) -> None: + strategy = _make_strategy(hooks=[]) + bare = _RecordingUpdateHook(priority=10) + with pytest.raises(ValueError, match="stage=.*TrainingUpdateHook"): + strategy.register_hook(bare, stage=TrainingStage.BEFORE_BATCH) + + def test_register_second_orchestrator_raises_value_error(self) -> None: + a = _RecordingUpdateHook(priority=10) + strategy = _make_strategy(hooks=[TrainingUpdateOrchestrator(a)]) + b = _RecordingUpdateHook(priority=20) + orch_b = TrainingUpdateOrchestrator(b) + with pytest.raises(ValueError, match="Only one TrainingUpdateOrchestrator"): + strategy.register_hook(orch_b) + + def test_register_orchestrator_uses_base_frequency_validation(self) -> None: + strategy = _make_strategy(hooks=[]) + explicit = TrainingUpdateOrchestrator(_RecordingUpdateHook(priority=10)) + explicit.frequency = 0 + with pytest.raises(ValueError, match="Hook frequency"): + strategy.register_hook(explicit) + + def test_claim_flags_refreshed_after_registration(self) -> None: + strategy = _make_strategy(hooks=[]) + assert strategy._has_do_backward_claim is False + assert strategy._has_do_optimizer_step_claim is False + bare = _RecordingUpdateHook(priority=10) + strategy.register_hook(bare) + # Orchestrator claims both DO stages. 
+ assert strategy._has_do_backward_claim is True + assert strategy._has_do_optimizer_step_claim is True + + +class TestTwoOrchestratorRejection: + def test_constructor_two_orchestrators_raises_validation_error(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=20) + orch_a = TrainingUpdateOrchestrator(a) + orch_b = TrainingUpdateOrchestrator(b) + with pytest.raises(pydantic.ValidationError) as exc_info: + _make_strategy(hooks=[orch_a, orch_b]) + assert "Only one TrainingUpdateOrchestrator" in str(exc_info.value) + + def test_register_hook_two_orchestrators_raises_value_error(self) -> None: + a = _RecordingUpdateHook(priority=10) + b = _RecordingUpdateHook(priority=20) + strategy = _make_strategy(hooks=[TrainingUpdateOrchestrator(a)]) + orch_b = TrainingUpdateOrchestrator(b) + with pytest.raises(ValueError, match="Only one TrainingUpdateOrchestrator"): + strategy.register_hook(orch_b) + # Ensure it is NOT a ValidationError subclass at runtime. + with pytest.raises(ValueError) as exc_info: + strategy.register_hook(orch_b) + assert not isinstance(exc_info.value, pydantic.ValidationError) + + def test_message_references_compose_with_plus(self) -> None: + assert "+" in _MULTIPLE_ORCHESTRATOR_MSG + + +class TestDoStageConflict: + @pytest.mark.parametrize("do_stage", _DO_STAGES, ids=lambda s: s.name) + def test_orchestrator_plus_non_update_hook_with_do_stage_constructor( + self, do_stage: TrainingStage + ) -> None: + bare = _RecordingUpdateHook(priority=10) + orch = TrainingUpdateOrchestrator(bare) + non_update = _StageOnlyHook(do_stage) + with pytest.raises(pydantic.ValidationError) as exc_info: + _make_strategy(hooks=[orch, non_update]) + msg = str(exc_info.value) + assert "At most one hook may claim" in msg + assert do_stage.name in msg + + @pytest.mark.parametrize("do_stage", _DO_STAGES, ids=lambda s: s.name) + def test_register_hook_do_stage_collision_raises_value_error( + self, do_stage: TrainingStage + ) -> None: + bare = 
_RecordingUpdateHook(priority=10) + strategy = _make_strategy(hooks=[bare]) + non_update = _StageOnlyHook(TrainingStage.AFTER_BATCH) + with pytest.raises(ValueError, match=do_stage.name): + strategy.register_hook(non_update, stage=do_stage) + + def test_two_non_update_hooks_with_same_do_stage_rejected(self) -> None: + h1 = _StageOnlyHook(TrainingStage.DO_BACKWARD) + h2 = _StageOnlyHook(TrainingStage.DO_BACKWARD) + with pytest.raises(pydantic.ValidationError) as exc_info: + _make_strategy(hooks=[h1, h2]) + assert "DO_BACKWARD" in str(exc_info.value) + + def test_fake_eq_hook_counted_only_once_via_identity_check(self) -> None: + """``_validate_single_do_claimants`` uses ``is`` for the candidate check. + + A hook whose ``__eq__`` returns ``True`` for any comparison should NOT + be spuriously double-counted when registered once with a DO stage. + This verifies the identity-vs-equality fix. + """ + strategy = _make_strategy(hooks=[]) + fake = _FakeEqHook(stage=TrainingStage.DO_BACKWARD) + # Should succeed: only one claimant of DO_BACKWARD even though + # ``fake == anything`` is True. + strategy.register_hook(fake) + # Use identity, not ``in`` (since fake.__eq__ would return True for + # any peer hook in strategy.hooks). 
+ assert any(h is fake for h in strategy.hooks) + assert strategy._has_do_backward_claim is True + + +class TestTrainContextGradScaler: + def test_default_grad_scaler_is_none(self) -> None: + ctx = _make_ctx() + assert ctx.grad_scaler is None + + def test_grad_scaler_accepts_mocked_instance(self) -> None: + scaler = Mock(spec=torch.amp.GradScaler) + ctx = TrainContext( + batch=_make_batch(), + step_count=0, + grad_scaler=scaler, + ) + assert ctx.grad_scaler is scaler + + def test_grad_scaler_visible_to_later_hook_in_dispatch(self) -> None: + scaler = Mock(spec=torch.amp.GradScaler) + setter = _GradScalerSetHook(scaler) + reader = _GradScalerReadHook() + orch = TrainingUpdateOrchestrator(setter, reader) + param = torch.nn.Parameter(torch.tensor([1.0])) + loss = (param * 1.0).sum() + ctx = _make_ctx(loss=loss) + orch(ctx, TrainingStage.DO_BACKWARD) + assert reader.observed is scaler + + +class TestTrainContextLifecycle: + def test_no_hook_run_does_not_build_train_context(self) -> None: + strategy = _make_strategy(hooks=[]) + with patch.object( + strategy, + "_build_context", + side_effect=AssertionError("_build_context should not run without hooks"), + ) as build_context: + strategy.run([_make_batch()]) + build_context.assert_not_called() + assert strategy._ctx is None + + def test_context_cache_cleared_after_hook_failure_and_retry(self) -> None: + capture = _ContextCaptureHook(TrainingStage.BEFORE_BATCH) + raiser = _RaisingStageHook(TrainingStage.BEFORE_FORWARD) + strategy = _make_strategy(hooks=[capture, raiser]) + + with pytest.raises(RuntimeError, match="forced hook failure"): + strategy.run([_make_batch()]) + + assert strategy._ctx is None + assert len(capture.contexts) == 1 + failed_ctx = capture.contexts[0] + + raiser.enabled = False + strategy.run([_make_batch(seed=10)]) + + assert strategy._ctx is None + assert len(capture.contexts) == 2 + assert capture.contexts[1] is not failed_ctx + assert capture.contexts[1].optimizers is strategy._optimizers + + +class 
TestPlainDoStageHooks: + def test_plain_do_backward_hook_owns_backward(self) -> None: + hook = _DoBackwardOwnerHook() + strategy = _make_strategy(hooks=[hook]) + strategy.run([_make_batch()]) + assert hook.calls == 1 + assert strategy.step_count == 1 + + def test_plain_do_optimizer_hook_suppresses_default_step_helpers(self) -> None: + hook = _DoOptimizerStepOwnerHook() + strategy = _make_strategy(hooks=[hook]) + with _patched_update_helpers() as m: + strategy.run([_make_batch()]) + assert hook.calls == 1 + assert len(hook.contexts) == 1 + m.strategy_step.assert_not_called() + m.strategy_sched.assert_not_called() + + +# --------------------------------------------------------------------------- +# Integration tests: orchestrator vs. strategy default training-loop paths. +# +# These tests run a single ``strategy.run(...)`` per scenario with all six +# helper functions patched (``_patched_update_helpers``) so we can assert +# which path called which helper. We deliberately keep one canonical +# strategy.run() per (path, gating) combination rather than table-driving +# every assertion through a fixture; this preserves stack-trace clarity if +# the strategy's dispatch contract regresses. 
+# ---------------------------------------------------------------------------
+
+
+class TestZeroGradSuppression:
+    """Which zero-grad helper (strategy default vs. orchestrator) a run invokes."""
+
+    def test_veto_suppresses_both_zero_gradient_paths(self) -> None:
+        """A ``BEFORE_BATCH`` veto leaves both zero-grad helpers uncalled."""
+        hook = _VetoHook(veto_stage=TrainingStage.BEFORE_BATCH, priority=10)
+        with _run_strategy_with_patched_helpers(hooks=[hook]) as m:
+            pass
+        m.strategy_zero.assert_not_called()
+        m.orch_zero.assert_not_called()
+
+    def test_orchestrator_zero_grad_called_on_proceed(self) -> None:
+        """With a non-vetoing update hook only the orchestrator helper runs."""
+        hook = _RecordingUpdateHook(priority=10)
+        with _run_strategy_with_patched_helpers(hooks=[hook]) as m:
+            pass
+        m.strategy_zero.assert_not_called()
+        m.orch_zero.assert_called_once()
+
+    def test_default_zero_grad_called_when_no_orchestrator(self) -> None:
+        """Without hooks only the strategy's own zero-grad helper runs."""
+        with _run_strategy_with_patched_helpers(hooks=[]) as m:
+            pass
+        m.strategy_zero.assert_called_once()
+        m.orch_zero.assert_not_called()
+
+
+class TestOptimizerStepSuppression:
+    """Which optimizer/scheduler step helpers a run invokes under each gating."""
+
+    def test_veto_suppresses_step_helpers(self) -> None:
+        """A ``DO_OPTIMIZER_STEP`` veto leaves all four step helpers uncalled."""
+        hook = _VetoHook(veto_stage=TrainingStage.DO_OPTIMIZER_STEP, priority=10)
+        with _run_strategy_with_patched_helpers(hooks=[hook]) as m:
+            pass
+        m.orch_step.assert_not_called()
+        m.orch_sched.assert_not_called()
+        m.strategy_step.assert_not_called()
+        m.strategy_sched.assert_not_called()
+
+    def test_orchestrator_step_helpers_called_on_proceed(self) -> None:
+        """With a non-vetoing update hook only the orchestrator helpers run."""
+        hook = _RecordingUpdateHook(priority=10)
+        with _run_strategy_with_patched_helpers(hooks=[hook]) as m:
+            pass
+        m.orch_step.assert_called_once()
+        m.orch_sched.assert_called_once()
+        m.strategy_step.assert_not_called()
+        m.strategy_sched.assert_not_called()
+
+    def test_default_step_helpers_called_without_orchestrator(self) -> None:
+        """Without hooks only the strategy's own step helpers run."""
+        with _run_strategy_with_patched_helpers(hooks=[]) as m:
+            pass
+        m.strategy_step.assert_called_once()
+        m.strategy_sched.assert_called_once()
+        # Mirror the zero-grad counterpart: with no orchestrator registered the
+        # orchestrator-side helpers must stay idle as well.
+        m.orch_step.assert_not_called()
+        m.orch_sched.assert_not_called()
+
+    def test_unpatched_orchestrator_steps_optimizer_and_scheduler(self) -> None:
+        """With real helpers, one run changes parameters and steps the LR schedule."""
+        hook = _RecordingUpdateHook(priority=10)
+        strategy = _make_strategy(
+            hooks=[hook],
+            optimizer_configs=OptimizerConfig(
+                optimizer_cls=torch.optim.SGD,
+                optimizer_kwargs={"lr": 0.1},
+                scheduler_cls=torch.optim.lr_scheduler.StepLR,
+                scheduler_kwargs={"step_size": 1, "gamma": 0.5},
+            ),
+        )
+        # Snapshot by value; the live parameters are updated in place.
+        before = [p.detach().clone() for p in strategy.models["main"].parameters()]
+
+        strategy.run([_make_batch()])
+
+        after = list(strategy.models["main"].parameters())
+        # strict=True: a parameter-count mismatch should fail loudly instead of
+        # being silently truncated by zip().
+        assert any(
+            not torch.equal(old, new)
+            for old, new in zip(before, after, strict=True)
+        )
+        # StepLR(step_size=1, gamma=0.5) halves lr=0.1 after one scheduler step.
+        assert strategy._optimizers[0].param_groups[0]["lr"] == pytest.approx(0.05)
+        assert hook.calls[-1] == (TrainingStage.AFTER_OPTIMIZER_STEP, False)
+
+
+class TestAfterOptimizerStepAlwaysRuns:
+    """``AFTER_OPTIMIZER_STEP`` is dispatched even when the step is vetoed."""
+
+    def test_after_optimizer_step_runs_when_step_vetoed(self) -> None:
+        """The vetoing hook still records ``AFTER_OPTIMIZER_STEP`` with ``True``."""
+        hook = _VetoHook(veto_stage=TrainingStage.DO_OPTIMIZER_STEP, priority=10)
+        strategy = _make_strategy(hooks=[hook])
+        strategy.train_batch(_make_batch())
+        seen_stages = {stage for stage, _will_skip in hook.calls}
+        assert TrainingStage.AFTER_OPTIMIZER_STEP in seen_stages
+        assert (TrainingStage.AFTER_OPTIMIZER_STEP, True) in hook.calls
+        # Sanity: DO_OPTIMIZER_STEP was indeed dispatched (so the veto path ran).
+        assert TrainingStage.DO_OPTIMIZER_STEP in seen_stages
+
+
+class TestHookProtocolCompliance:
+    """Runtime ``isinstance`` checks against the ``Hook`` protocol."""
+
+    def test_orchestrator_satisfies_hook_protocol(self) -> None:
+        """An orchestrator instance passes ``isinstance(..., Hook)``."""
+        orch = TrainingUpdateOrchestrator(_RecordingUpdateHook())
+        assert isinstance(orch, Hook)
+
+    def test_bare_training_update_hook_does_not_satisfy_protocol(self) -> None:
+        """Bare ``TrainingUpdateHook`` lacks ``frequency``/``stage`` so it fails the check.
+
+        The base class intentionally omits ``frequency``/``stage`` because it is
+        not directly registry-compatible; the orchestrator is the registry-facing
+        wrapper. Auto-wrapping in ``TrainingStrategy`` ensures users do not have
+        to confront this distinction.
+        """
+        bare = TrainingUpdateHook()
+        assert not isinstance(bare, Hook)
+
+
+class TestFoldHelper:
+    """Direct tests of ``_fold_training_update_hooks``."""
+
+    def test_no_update_hooks_returns_input_list(self) -> None:
+        """A list without update hooks comes back equal, with no orchestrator."""
+        non_a = _StageOnlyHook(TrainingStage.AFTER_BATCH)
+        non_b = _StageOnlyHook(TrainingStage.BEFORE_FORWARD)
+        result = _fold_training_update_hooks([non_a, non_b])
+        assert result == [non_a, non_b]
+        assert all(not isinstance(h, TrainingUpdateOrchestrator) for h in result)
+
+    def test_two_orchestrators_raises_value_error(self) -> None:
+        """Two explicit orchestrators in one list raise ``ValueError``."""
+        a = _RecordingUpdateHook(priority=10)
+        b = _RecordingUpdateHook(priority=20)
+        orch_a = TrainingUpdateOrchestrator(a)
+        orch_b = TrainingUpdateOrchestrator(b)
+        with pytest.raises(ValueError, match="Only one TrainingUpdateOrchestrator"):
+            _fold_training_update_hooks([orch_a, orch_b])
+
+    def test_single_bare_hook_wrapped_in_orchestrator(self) -> None:
+        """A lone bare update hook is returned wrapped in one orchestrator."""
+        bare = _RecordingUpdateHook(priority=10)
+        result = _fold_training_update_hooks([bare])
+        assert len(result) == 1
+        assert isinstance(result[0], TrainingUpdateOrchestrator)
+        assert result[0]._hooks == [bare]
diff --git a/test/training/test_validation_config.py b/test/training/test_validation_config.py
new file mode 100644
index 00000000..88c8d197
--- /dev/null
+++ b/test/training/test_validation_config.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for :class:`nvalchemi.training._validation.ValidationConfig`."""
+
+from __future__ import annotations
+
+import pytest
+from pydantic import ValidationError
+
+from nvalchemi.training import EnergyMSELoss, ForceMSELoss
+from nvalchemi.training._validation import ValidationConfig
+from nvalchemi.training.losses.composition import ComposedLossFunction
+
+
+class TestValidationConfigConstruction:
+    """Construction behaviour: default values, ``loss_fn`` wrapping, rejected input."""
+
+    def test_defaults(self) -> None:
+        """Fields left unset take the default values asserted here."""
+        config = ValidationConfig(validation_data=[])
+        for unset in (
+            "validation_fn",
+            "loss_fn",
+            "every_n_epochs",
+            "every_n_steps",
+            "batch_callback",
+        ):
+            assert getattr(config, unset) is None
+        for automatic in ("grad_mode", "use_ema", "use_mixed_precision"):
+            assert getattr(config, automatic) == "auto"
+        assert config.set_eval is True
+        assert config.name == "validation"
+
+    def test_schedule_mutual_exclusion_raises(self) -> None:
+        """Supplying every_n_epochs together with every_n_steps is an error."""
+        conflicting = {"every_n_epochs": 2, "every_n_steps": 5}
+        with pytest.raises(ValidationError, match="Only one of"):
+            ValidationConfig(validation_data=[], **conflicting)
+
+    def test_every_n_epochs_only(self) -> None:
+        """An epoch schedule on its own validates and leaves steps unset."""
+        config = ValidationConfig(validation_data=[], every_n_epochs=3)
+        assert config.every_n_steps is None
+        assert config.every_n_epochs == 3
+
+    def test_every_n_steps_only(self) -> None:
+        """A step schedule on its own validates and leaves epochs unset."""
+        config = ValidationConfig(validation_data=[], every_n_steps=10)
+        assert config.every_n_epochs is None
+        assert config.every_n_steps == 10
+
+    def test_loss_fn_normalization_leaf(self) -> None:
+        """A single leaf loss ends up as a one-component ComposedLossFunction."""
+        config = ValidationConfig(validation_data=[], loss_fn=EnergyMSELoss())
+        wrapped = config.loss_fn
+        assert isinstance(wrapped, ComposedLossFunction)
+        parts = wrapped.components
+        assert len(parts) == 1
+        assert isinstance(parts[0], EnergyMSELoss)
+
+    def test_loss_fn_normalization_composed(self) -> None:
+        """An already-composed loss keeps its type and both components."""
+        energy_plus_force = EnergyMSELoss() + ForceMSELoss()
+        config = ValidationConfig(validation_data=[], loss_fn=energy_plus_force)
+        wrapped = config.loss_fn
+        assert isinstance(wrapped, ComposedLossFunction)
+        assert len(wrapped.components) == 2
+
+    def test_loss_fn_none_stays_none(self) -> None:
+        """Leaving loss_fn unset (use strategy default) keeps it as None."""
+        assert ValidationConfig(validation_data=[]).loss_fn is None
+
+    def test_extra_fields_rejected(self) -> None:
+        """A keyword that is not a declared field fails validation."""
+        unknown = {"bogus_field": True}
+        with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
+            ValidationConfig(validation_data=[], **unknown)
+
+    def test_every_n_epochs_minimum_one(self) -> None:
+        """every_n_epochs below 1 fails validation."""
+        too_small = 0
+        with pytest.raises(ValidationError, match="greater than or equal to 1"):
+            ValidationConfig(validation_data=[], every_n_epochs=too_small)
+
+    def test_every_n_steps_minimum_one(self) -> None:
+        """every_n_steps below 1 fails validation."""
+        too_small = 0
+        with pytest.raises(ValidationError, match="greater than or equal to 1"):
+            ValidationConfig(validation_data=[], every_n_steps=too_small)
+
+    def test_name_minimum_length(self) -> None:
+        """An empty name fails validation."""
+        empty_name = ""
+        with pytest.raises(ValidationError):
+            ValidationConfig(validation_data=[], name=empty_name)
+
+
+class TestValidationDataReiterability:
+    """validation_data must survive repeated iteration; one-shot inputs are refused."""
+
+    def test_validation_data_list_is_reiterable(self) -> None:
+        """Iterating a list-backed validation_data twice yields the same items."""
+        first, second = object(), object()
+        expected = [first, second]
+        config = ValidationConfig(validation_data=[first, second])
+        passes = [list(config.validation_data) for _ in range(2)]
+        assert passes[0] == expected
+        assert passes[1] == expected
+
+    def test_validation_data_generator_rejected(self) -> None:
+        """A generator expression fails validation as not re-iterable."""
+        one_shot = (x for x in [1, 2])
+        with pytest.raises(ValidationError, match="re-iterable"):
+            ValidationConfig(validation_data=one_shot)
+
+    def test_validation_data_bare_iterator_rejected(self) -> None:
+        """A bare list iterator fails validation as not re-iterable."""
+        one_shot = iter([1, 2])
+        with pytest.raises(ValidationError, match="re-iterable"):
+            ValidationConfig(validation_data=one_shot)
+
+    def test_validation_data_non_iterable_rejected(self) -> None:
+        """A value that cannot be iterated at all fails validation."""
+        not_iterable = 42
+        with pytest.raises(ValidationError, match="iterable"):
+            ValidationConfig(validation_data=not_iterable)
diff --git a/test/training/test_validation_loop.py b/test/training/test_validation_loop.py
new file mode 100644
index 00000000..65315929
--- /dev/null
+++ b/test/training/test_validation_loop.py
@@ -0,0 +1,268 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the public standalone :class:`ValidationLoop` API.""" + +from __future__ import annotations + +import math + +import pytest +import torch + +from nvalchemi.training import ( + EnergyMSELoss, + ForceMSELoss, + ValidationConfig, + ValidationLoop, +) +from test.training.conftest import _build_dataset, _build_demo_model +from test.training.test_strategy import demo_training_fn + +device = torch.device("cpu") + + +def _named_validation_fn(models, batch): + """Standalone validation function for the named-model (dict) path.""" + return demo_training_fn(models["main"], batch) + + +def _composed_loss(): + """Return a :class:`ComposedLossFunction` (energy MSE + force MSE).""" + return EnergyMSELoss() + ForceMSELoss(normalize_by_atom_count=True) + + +class TestValidationLoopConstruction: + """Constructor validation for the standalone :class:`ValidationLoop`.""" + + def test_rejects_both_model_and_models(self) -> None: + """Passing both ``model`` and ``models`` raises ``ValueError``.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + with pytest.raises(ValueError, match="Exactly one of"): + ValidationLoop( + validation_data=data, + config=config, + device=device, + model=_build_demo_model(), + models={"main": _build_demo_model()}, + validation_fn=demo_training_fn, + grad_enabled=False, + ) + + def test_rejects_neither_model_nor_models(self) -> None: + """Passing neither ``model`` nor ``models`` raises ``ValueError``.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + with pytest.raises(ValueError, match="Exactly one of"): + ValidationLoop( + validation_data=data, + config=config, + device=device, + validation_fn=demo_training_fn, + grad_enabled=False, + ) + + def test_requires_validation_fn(self) -> None: + """Omitting ``validation_fn`` raises ``ValueError``.""" + data = _build_dataset(n_batches=2) + config = 
ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + with pytest.raises(ValueError, match="validation_fn is required"): + ValidationLoop( + validation_data=data, + config=config, + device=device, + model=_build_demo_model(), + grad_enabled=False, + ) + + def test_requires_loss_fn_when_config_has_none(self) -> None: + """No ``loss_fn`` arg and no ``config.loss_fn`` raises ``ValueError``.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data) + with pytest.raises(ValueError, match="loss_fn must be provided"): + ValidationLoop( + validation_data=data, + config=config, + device=device, + model=_build_demo_model(), + validation_fn=demo_training_fn, + grad_enabled=False, + ) + + def test_loss_fn_from_config_used(self) -> None: + """A ``config.loss_fn`` resolves the loss when the arg is omitted.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + loop = ValidationLoop( + validation_data=data, + config=config, + device=device, + model=_build_demo_model(), + validation_fn=demo_training_fn, + grad_enabled=False, + ) + assert isinstance(loop, ValidationLoop) + + +class TestValidationLoopExecuteSingleModel: + """Execution of the single-model standalone path.""" + + def test_execute_returns_summary_with_expected_keys(self) -> None: + """``execute()`` returns a summary with the expected keys and labels.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + loop = ValidationLoop( + validation_data=data, + config=config, + device=device, + model=_build_demo_model(), + validation_fn=demo_training_fn, + grad_enabled=True, + ) + with loop as active_loop: + summary = active_loop.execute() + assert summary is not None + expected = {"name", "total_loss", "model_source", "precision", "num_batches"} + assert expected <= set(summary) + assert summary["model_source"] == "live" + assert summary["precision"] == 
"float32" + assert "total_loss" in summary + assert math.isfinite(float(summary["total_loss"])) + + def test_execute_outside_context_raises(self) -> None: + """Calling ``execute()`` outside the ``with`` block raises ``RuntimeError``.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + loop = ValidationLoop( + validation_data=data, + config=config, + device=device, + model=_build_demo_model(), + validation_fn=demo_training_fn, + grad_enabled=False, + ) + with pytest.raises(RuntimeError, match="inside a 'with' block"): + loop.execute() + + def test_custom_autocast_labels_precision_mixed(self) -> None: + """A custom ``autocast`` callable labels the precision as ``mixed``.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + # A disabled autocast keeps dtypes at float32 (so the loss dtype check + # passes) while still being a non-``None`` callable, which is what the + # loop uses to label the precision as ``mixed``. 
+ loop = ValidationLoop( + validation_data=data, + config=config, + device=device, + model=_build_demo_model(), + validation_fn=demo_training_fn, + autocast=lambda: torch.autocast( + device_type="cpu", dtype=torch.bfloat16, enabled=False + ), + grad_enabled=True, + ) + with loop as active_loop: + summary = active_loop.execute() + assert summary is not None + assert summary["precision"] == "mixed" + + +class TestValidationLoopNamedModels: + """Execution of the named-model (dict) standalone path.""" + + def test_named_models_execute(self) -> None: + """The named-model path runs and reports a ``live`` source over all batches.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + loop = ValidationLoop( + validation_data=data, + config=config, + device=device, + models={"main": _build_demo_model()}, + validation_fn=_named_validation_fn, + grad_enabled=True, + ) + with loop as active_loop: + summary = active_loop.execute() + assert summary is not None + assert summary["model_source"] == "live" + assert summary["num_batches"] == 2 + + +class TestValidationLoopStateRestoration: + """Restoration of training modes and gradients around the loop.""" + + def test_training_modes_restored(self) -> None: + """Training modes are restored after a successful run.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + model = _build_demo_model() + model.train() + loop = ValidationLoop( + validation_data=data, + config=config, + device=device, + model=model, + validation_fn=demo_training_fn, + grad_enabled=True, + ) + with loop as active_loop: + active_loop.execute() + assert model.training is True + + def test_training_modes_restored_on_exception(self) -> None: + """Training modes are restored even when ``execute()`` raises.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + model = 
_build_demo_model() + model.train() + + def _raising_validation_fn(model_arg, batch): + raise RuntimeError("boom") + + loop = ValidationLoop( + validation_data=data, + config=config, + device=device, + model=model, + validation_fn=_raising_validation_fn, + grad_enabled=False, + ) + with pytest.raises(RuntimeError): + with loop as active_loop: + active_loop.execute() + assert model.training is True + + def test_grads_restored_after_grad_enabled_run(self) -> None: + """A pre-existing gradient is restored after a grad-enabled run.""" + data = _build_dataset(n_batches=2) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + model = _build_demo_model() + first_param = next(iter(model.parameters())) + saved_grad = torch.ones_like(first_param) + first_param.grad = saved_grad.clone() + loop = ValidationLoop( + validation_data=data, + config=config, + device=device, + model=model, + validation_fn=demo_training_fn, + grad_enabled=True, + ) + with loop as active_loop: + active_loop.execute() + assert first_param.grad is not None + assert torch.equal(first_param.grad, saved_grad) diff --git a/test/training/test_validation_loop_distributed.py b/test/training/test_validation_loop_distributed.py new file mode 100644 index 00000000..28e091e3 --- /dev/null +++ b/test/training/test_validation_loop_distributed.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Distributed (multi-rank, gloo/CPU) tests for the standalone ValidationLoop.""" + +from __future__ import annotations + +import math +import os +import socket +from typing import Any + +import pytest +import torch +from torch import distributed as dist + +from nvalchemi.training import ( + EnergyMSELoss, + ForceMSELoss, + ValidationConfig, + ValidationLoop, +) +from test.training.conftest import _build_dataset, _build_demo_model +from test.training.test_strategy import demo_training_fn + + +def _free_port() -> int: + """Return an available localhost TCP port for process-group setup.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _composed_loss() -> Any: + """Return a composed energy + force MSE loss.""" + return EnergyMSELoss() + ForceMSELoss(normalize_by_atom_count=True) + + +def _run_validation_worker( + rank: int, + world_size: int, + port: int, + result_queue: Any, +) -> None: + """Run a standalone ValidationLoop on one gloo/CPU rank and report its summary. + + With ``distributed_manager=None`` the ValidationLoop falls back to the + raw ``torch.distributed`` primitives, so the all-reduce, rank-0 publish, + and ``__exit__`` barrier all use the initialized process group. + + Parameters + ---------- + rank : int + Global rank of this worker. + world_size : int + Total number of ranks. + port : int + TCP port for the gloo rendezvous. + result_queue : Any + Multiprocessing queue used to send ``(rank, summary)`` to the parent. 
+ """ + os.environ.update( + { + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": str(port), + "RANK": str(rank), + "WORLD_SIZE": str(world_size), + "LOCAL_RANK": str(rank), + } + ) + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + # Each rank validates its own rank-seeded, re-iterable shard (same batch + # count on every rank) so the all-reduced mean is well-defined and finite. + data = _build_dataset(n_batches=2, base_seed=100 + rank) + config = ValidationConfig(validation_data=data, loss_fn=_composed_loss()) + loop = ValidationLoop( + validation_data=data, + config=config, + device=torch.device("cpu"), + model=_build_demo_model(), + validation_fn=demo_training_fn, + grad_enabled=True, + ) + with loop as active_loop: + summary = active_loop.execute() + if summary is None: + result_queue.put((rank, None)) + else: + result_queue.put((rank, {"total_loss": float(summary["total_loss"])})) + finally: + dist.barrier() + dist.destroy_process_group() + + +@pytest.mark.slow +@pytest.mark.skipif(not dist.is_gloo_available(), reason="gloo backend required") +def test_distributed_validation_rank0_publishes_others_none() -> None: + """Rank 0 publishes a finite reduced summary; non-publishing ranks get None.""" + world_size = 2 + ctx = torch.multiprocessing.get_context("spawn") + result_queue = ctx.Queue() + port = _free_port() + procs = [ + ctx.Process( + target=_run_validation_worker, + args=(rank, world_size, port, result_queue), + ) + for rank in range(world_size) + ] + for proc in procs: + proc.start() + results: dict[int, Any] = {} + for _ in range(world_size): + rank, summary = result_queue.get(timeout=60) + results[rank] = summary + for proc in procs: + proc.join(timeout=60) + for proc in procs: + # A clean exit on every rank proves the __exit__ barrier did not deadlock. + assert proc.exitcode == 0 + + assert set(results) == set(range(world_size)) + # Rank 0 is the publishing rank: it returns the reduced summary. 
+ assert results[0] is not None + assert math.isfinite(results[0]["total_loss"]) + # Every non-zero rank is non-publishing and returns None. + for rank in range(1, world_size): + assert results[rank] is None diff --git a/uv.lock b/uv.lock index b76651bf..59dad96e 100644 --- a/uv.lock +++ b/uv.lock @@ -1510,22 +1510,22 @@ wheels = [ [[package]] name = "cupy-cuda12x" -version = "14.0.1" +version = "13.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cuda-pathfinder" }, + { name = "fastrlock" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" } }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/11/6d089629f44591864bc8a11fa64c9d4fcd1afb4a7217954c806fb47c4fe5/cupy_cuda12x-14.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:31e6a33579a06fde3ff238b8b6b72446384d17554b2a3b14f818c9ee44b0c2e6", size = 146237981, upload-time = "2026-02-20T10:22:29.065Z" }, - { url = "https://files.pythonhosted.org/packages/37/f0/0f1d79c0c7fccbc2ed0c0ff3be1b0562be60b764c729ca8ded1bd6d953aa/cupy_cuda12x-14.0.1-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bfbde2e9f7946021b49414f9c800991163f2a56a1318f3d7d69cbb06001a1585", size = 135080693, upload-time = "2026-02-20T10:22:35.843Z" }, - { url = "https://files.pythonhosted.org/packages/6d/1b/b3a26fd36e066e9bc25d875488468c9a40e8c7a90acadfacc524a17da457/cupy_cuda12x-14.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:c289e78876c6840b3c512868b8c5d43ac76bc3c581eab1a75c4f2f4a88d5b430", size = 96361678, upload-time = "2026-02-20T10:22:41.718Z" }, - { url = "https://files.pythonhosted.org/packages/38/ca/b93ef9fca1471a65f136a73e10819634c0b83427362fc08fc9f29f935bf0/cupy_cuda12x-14.0.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:f244bc14fad6f1ef0c74abd98afa4b82d2534aecdba911197810ec0047f0d1f3", size = 145578614, upload-time = "2026-02-20T10:22:49.108Z" }, - { url = 
"https://files.pythonhosted.org/packages/5a/a6/944406223a190815d9df156a1d66f3b0352bd8827dc4a8c752196d616dbc/cupy_cuda12x-14.0.1-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:9f0c81c3509f77be3ae8444759d5b314201b2dfcbbf2ae0d0b5fb7a61f20893c", size = 134613763, upload-time = "2026-02-20T10:22:56.792Z" }, - { url = "https://files.pythonhosted.org/packages/11/fd/62e6e3f3c0c9f785b2dbdc2bff01bc375f5c6669d52e5e151f7aeb577801/cupy_cuda12x-14.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:63dc8a3a88d2ffd0386796b915d27acc7f2332c2291efd1ff4f0021b96f02051", size = 96267167, upload-time = "2026-02-20T10:23:02.263Z" }, - { url = "https://files.pythonhosted.org/packages/99/67/f967c5aff77bd6ae6765faf20580db80bb8a7e2574e999166de1d4e50146/cupy_cuda12x-14.0.1-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:9d9b1bdcf9fa777593017867e8733192c071b94639a1b3e8b2ee99eb3f3ea760", size = 145128055, upload-time = "2026-02-20T10:23:08.765Z" }, - { url = "https://files.pythonhosted.org/packages/80/53/037c931731151c504cfc00069eb295c903927c92145115623f13bd2ea076/cupy_cuda12x-14.0.1-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:21fcb4e917e43237edcc5e3a1a1241e2a2946ba9e577ce36fd580bd9856f91e8", size = 134227269, upload-time = "2026-02-20T10:23:16.147Z" }, - { url = "https://files.pythonhosted.org/packages/a3/70/ce8344426effda22152bf30cfb8f9b6477645d0f41df784674369af8f422/cupy_cuda12x-14.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:b7399e7fe4e2be3b5c3974fc892a661e10082836a4c78d0152b39cb483608a89", size = 96250134, upload-time = "2026-02-20T10:23:22.631Z" }, + { url = "https://files.pythonhosted.org/packages/54/64/71c6e08f76c06639e5112f69ee3bc1129be00054ad5f906d7fd3138af579/cupy_cuda12x-13.6.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c790d012fd4d86872b9c89af9f5f15d91c30b8e3a4aa4dd04c2610f45f06ac44", size = 128016458, upload-time = "2025-08-18T08:24:26.394Z" }, + { url = 
"https://files.pythonhosted.org/packages/fc/d9/5c5077243cd92368c3eccecdbf91d76db15db338169042ffd1647533c6b1/cupy_cuda12x-13.6.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:77ba6745a130d880c962e687e4e146ebbb9014f290b0a80dbc4e4634eb5c3b48", size = 113039337, upload-time = "2025-08-18T08:24:31.814Z" }, + { url = "https://files.pythonhosted.org/packages/88/f5/02bea5cdf108e2a66f98e7d107b4c9a6709e5dbfedf663340e5c11719d83/cupy_cuda12x-13.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:a20b7acdc583643a623c8d8e3efbe0db616fbcf5916e9c99eedf73859b6133af", size = 89885526, upload-time = "2025-08-18T08:24:37.258Z" }, + { url = "https://files.pythonhosted.org/packages/12/c5/7e7fc4816d0de0154e5d9053242c3a08a0ca8b43ee656a6f7b3b95055a7b/cupy_cuda12x-13.6.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:a6970ceefe40f9acbede41d7fe17416bd277b1bd2093adcde457b23b578c5a59", size = 127334633, upload-time = "2025-08-18T08:24:43.065Z" }, + { url = "https://files.pythonhosted.org/packages/e0/95/d7e1295141e7d530674a3cc567e13ed0eb6b81524cb122d797ed996b5bea/cupy_cuda12x-13.6.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:79b0cacb5e8b190ef409f9e03f06ac8de1b021b0c0dda47674d446f5557e0eb1", size = 112886268, upload-time = "2025-08-18T08:24:49.294Z" }, + { url = "https://files.pythonhosted.org/packages/ae/8c/14555b63fd78cfac7b88af0094cea0a3cb845d243661ec7da69f7b3ea0de/cupy_cuda12x-13.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:ca06fede7b8b83ca9ad80062544ef2e5bb8d4762d1c4fc3ac8349376de9c8a5e", size = 89785108, upload-time = "2025-08-18T08:24:54.527Z" }, + { url = "https://files.pythonhosted.org/packages/19/ec/f62cb991f11fb41291c4c15b6936d7b67ffa71ddb344ad6e8894e06ce58d/cupy_cuda12x-13.6.0-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:e5426ae3b1b9cf59927481e457a89e3f0b50a35b114a8034ec9110e7a833434c", size = 126904601, upload-time = "2025-08-18T08:24:59.951Z" }, + { url = 
"https://files.pythonhosted.org/packages/f8/b8/30127bcdac53a25f94ee201bf4802fcd8d012145567d77c54174d6d01c01/cupy_cuda12x-13.6.0-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:52d9e7f83d920da7d81ec2e791c2c2c747fdaa1d7b811971b34865ce6371e98a", size = 112654824, upload-time = "2025-08-18T08:25:05.944Z" }, + { url = "https://files.pythonhosted.org/packages/72/36/c9e24acb19f039f814faea880b3704a3661edaa6739456b73b27540663e3/cupy_cuda12x-13.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:297b4268f839de67ef7865c2202d3f5a0fb8d20bd43360bc51b6e60cb4406447", size = 89750580, upload-time = "2025-08-18T08:25:10.972Z" }, ] [[package]] @@ -1827,14 +1827,14 @@ wheels = [ [[package]] name = "gitpython" -version = "3.1.46" +version = "3.1.50" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "gitdb" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/df/b5/59d16470a1f0dfe8c793f9ef56fd3826093fc52b3bd96d6b9d6c26c7e27b/gitpython-3.1.46.tar.gz", hash = "sha256:400124c7d0ef4ea03f7310ac2fbf7151e09ff97f2a3288d64a440c584a29c37f", size = 215371, upload-time = "2026-01-01T15:37:32.073Z" } +sdist = { url = "https://files.pythonhosted.org/packages/33/f6/354ae6491228b5eb40e10d89c4d13c651fe1cf7556e35ebdded50cff57ce/gitpython-3.1.50.tar.gz", hash = "sha256:80da2d12504d52e1f998772dc5baf6e553f8d2fcfe1fcc226c9d9a2ee3372dcc", size = 219798, upload-time = "2026-05-06T04:01:26.571Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058", size = 208620, upload-time = "2026-01-01T15:37:30.574Z" }, + { url = "https://files.pythonhosted.org/packages/20/7a/1c6e3562dfd8950adbb11ffbc65d21e7c89d01a6e4f137fa981056de25c5/gitpython-3.1.50-py3-none-any.whl", hash = "sha256:d352abe2908d07355014abdd21ddf798c2a961469239afec4962e9da884858f9", size = 212507, upload-time = 
"2026-05-06T04:01:23.799Z" }, ] [[package]] @@ -1862,6 +1862,47 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, ] +[[package]] +name = "grpcio" +version = "1.81.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/15/f3/23f47b24f8d8c2028eba501db3acfbb2f592cbb5995eaa6e363a627b74d7/grpcio-1.81.0.tar.gz", hash = "sha256:a5acd7efd3b1fe9b4eb0bcaaa1507eed68a0ad0678b654c3f7b464df9ba9dca5", size = 13032272, upload-time = "2026-06-01T05:56:22.827Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/a8/9916ab10a0201f4c7afb6918125aa2f38a7626ee18ffbc066dd9cb04a74d/grpcio-1.81.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:794e6aa648e8df47d8f908dc8c3b42347d04ec58438f1dcd4e445f09b4f6b0ce", size = 6093557, upload-time = "2026-06-01T05:54:32.64Z" }, + { url = "https://files.pythonhosted.org/packages/a7/43/99e969a048904a65df3129ee53c5f523b7c4e43127786460cac4bee82470/grpcio-1.81.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:cd78145b7f7784661c524624f3526c9c6f891b30a4b54cb93a40806d0d0d61e9", size = 12075345, upload-time = "2026-06-01T05:54:35.77Z" }, + { url = "https://files.pythonhosted.org/packages/83/70/4c3a204e190333768d4f63f4ff56bd0bf405f05b9188f3a59a8bcf161f8b/grpcio-1.81.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:638ccc1b86f7540170a169cb900799b9296a1381e47879ce60b0de9d3db73d33", size = 6640664, upload-time = "2026-06-01T05:54:38.854Z" }, + { url = 
"https://files.pythonhosted.org/packages/2e/a9/0fa17ac8b4e29cf59b26915be6cab8c0d4583ce24a6208a287b6e5f6d072/grpcio-1.81.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:21ec30b9ea320c8207ea7cd05873ad64aa69fdd0e81b6758b3347983ba20b50a", size = 7332542, upload-time = "2026-06-01T05:54:41.39Z" }, + { url = "https://files.pythonhosted.org/packages/f4/18/7c8e3d0dda2fb7a17076fcd6c9085209eabad3354696c64230f87b3a14eb/grpcio-1.81.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dbdb99986548a7e87f8343805ef315fd4eb50ffaabf4fb1206e42f2542bb805d", size = 6842564, upload-time = "2026-06-01T05:54:43.57Z" }, + { url = "https://files.pythonhosted.org/packages/f6/19/2f1726c2e03ad3f3fe241e6b41534532ad580d595de14a4054ad84999c80/grpcio-1.81.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c36f5d5e97944cbda2d4096b4ae262e6e68506246b61582acf1b8591607f3ccc", size = 7446236, upload-time = "2026-06-01T05:54:46.042Z" }, + { url = "https://files.pythonhosted.org/packages/a7/dc/0321f892212e2c0bfe248cea24c00d7d7111639688ec5ffd8e36b5c02fe6/grpcio-1.81.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:9f355384e5543ab77a755a7085225ecc19f32b76032e851cbd8145715d79dec8", size = 8445633, upload-time = "2026-06-01T05:54:48.809Z" }, + { url = "https://files.pythonhosted.org/packages/e5/20/0e7ea7494955cf1beea3077b2fd2c04c84d4480c2ae85a1e1cfa150c62d7/grpcio-1.81.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:77eb4e9fe61486bd1198cc7236ebb0f70e66234e63c0348f40bc2553ed16a88b", size = 7873958, upload-time = "2026-06-01T05:54:52.135Z" }, + { url = "https://files.pythonhosted.org/packages/d3/9e/6438e226046c2a0778060e2b1d791a4827277bbd9d223013c2c63ee7435e/grpcio-1.81.0-cp311-cp311-win32.whl", hash = "sha256:7915a2e63acdc05264a206e1bddfd8e1fb8a29e406c18d72d30f8c124e021374", size = 4202110, upload-time = "2026-06-01T05:54:54.134Z" }, + { url = 
"https://files.pythonhosted.org/packages/42/6b/d0895e93d65b186f5f1737fcc186d7faa487e2d9d934eda111a37a309869/grpcio-1.81.0-cp311-cp311-win_amd64.whl", hash = "sha256:5e925a70fe99fe5794f7beca0ea034c75f068afcc356d79047e73f99cdcca34c", size = 4940942, upload-time = "2026-06-01T05:54:56.749Z" }, + { url = "https://files.pythonhosted.org/packages/82/d5/896a3aaf07068d707d88b282a04914b872db4d32d3c7e6d88e43a3b911fa/grpcio-1.81.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:57b3b0e73a518fa286959b40c3eddd02703504ca186e8b7b2945954519bd8b2c", size = 6053538, upload-time = "2026-06-01T05:54:58.965Z" }, + { url = "https://files.pythonhosted.org/packages/68/6a/7e3eafa4727cd405ff917605ed2949e2af162f233f5cbdd773723a5fea7d/grpcio-1.81.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:8bb1789c94322a13336a2b6c58d9c14d68f8628b6e24205a799c69f5bf8516ce", size = 12053447, upload-time = "2026-06-01T05:55:01.862Z" }, + { url = "https://files.pythonhosted.org/packages/16/79/a4302aa82428de48a922421f522b027a1a727ab4d0926368454aa953d36d/grpcio-1.81.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e4d053900a0d24b75d7521139a3872150301b3d6bde3bed5e12318fb25791e4d", size = 6595872, upload-time = "2026-06-01T05:55:04.946Z" }, + { url = "https://files.pythonhosted.org/packages/b4/1f/7ff2850eaefbecf99af3f624dbb28dd1ad6c5fd4c1d8c26909ed6482673b/grpcio-1.81.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:db217c2e52931719f9937bd12082cd4d7b495b35803d5760686975c285924bf8", size = 7303857, upload-time = "2026-06-01T05:55:07.205Z" }, + { url = "https://files.pythonhosted.org/packages/e2/98/1f3896a9baae1f2aedf4e99c55291d6fa1f30ad9603d63bc18bda967b53e/grpcio-1.81.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:19f201da7b4e5c0559198abe5a97157e726f3abe6e8f5e832d4a50740f6dcc22", size = 6809676, upload-time = "2026-06-01T05:55:09.513Z" }, + { url = 
"https://files.pythonhosted.org/packages/34/8b/3441983718095208c5d797fd3239882e97ea89a629f41c8df94b4eef4df9/grpcio-1.81.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:275144b0115353339dbb8a6f28a9cf8997b5bf40e37f8f66ac0b0ea57e95b43f", size = 7412654, upload-time = "2026-06-01T05:55:12.777Z" }, + { url = "https://files.pythonhosted.org/packages/3c/98/1eddf07df6e4fe85cf67502a793f7b05468b2dca3d1ef35b972cf5d54468/grpcio-1.81.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5192857589f223e5a98ff0e31f6e551b19040e647d17bfe10116c8a2ce3b8696", size = 8408026, upload-time = "2026-06-01T05:55:15.514Z" }, + { url = "https://files.pythonhosted.org/packages/5c/73/3860341e6a1f5347be6ab35c6c0e1e3a8eb59d010388207fd561dcf01a88/grpcio-1.81.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c6ff087cb1f563f47b504b4e29e684129fc5ae4863faf3ebca08a327764ee6cb", size = 7849498, upload-time = "2026-06-01T05:55:18.078Z" }, + { url = "https://files.pythonhosted.org/packages/ae/3f/0ea06bd85c701966aa3f8f37314f2ed83520d2b7590f42d643d445d8bc8b/grpcio-1.81.0-cp312-cp312-win32.whl", hash = "sha256:98c6240f563178fc5877bd50e6ff274463e53e1472128f4110742450739659fa", size = 4184161, upload-time = "2026-06-01T05:55:20.127Z" }, + { url = "https://files.pythonhosted.org/packages/39/e3/a7c387406827a86f99ad7838b995bf9b4a182ffe2d2c439ed2873efec952/grpcio-1.81.0-cp312-cp312-win_amd64.whl", hash = "sha256:87e33b7afcfb3585121b5f007d2c52b8c534104d18f556e840d35193ca2a9141", size = 4929958, upload-time = "2026-06-01T05:55:22.736Z" }, + { url = "https://files.pythonhosted.org/packages/f3/29/779ee53c931d0fd55c1d459fde43e485172caa3ac87cbd43d003a13a0185/grpcio-1.81.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:62bbe463c9f0f2ff24e31bd25f8dd8b4bae78900e315915a3195a0ef1471a855", size = 6054973, upload-time = "2026-06-01T05:55:25.043Z" }, + { url = 
"https://files.pythonhosted.org/packages/9e/b6/7211807926b5a17f8d9a5d47c739a163d6812fefe3e4714e81cf92945ed7/grpcio-1.81.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:43c121e135ae44d1559b430db2b2dfad7421cbbe40e1deba506c7dc62b439719", size = 12048662, upload-time = "2026-06-01T05:55:28.453Z" }, + { url = "https://files.pythonhosted.org/packages/64/89/b1b93ef6b34bd20bbaf707fa99133bc9cc302139d5ec6f77a165c7169796/grpcio-1.81.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f345de40ef2e65f63645d53d251824e6070e07804827c5b00ec2e44555f9f901", size = 6599116, upload-time = "2026-06-01T05:55:31.185Z" }, + { url = "https://files.pythonhosted.org/packages/eb/bc/c89f9b9d1c22895715356a1e009554dae66319e97826bb4d30bcda7d29e8/grpcio-1.81.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:8c0855a350886f713b9e458e2a10d208009dcaa849f574e39cd6067db1fe1279", size = 7307591, upload-time = "2026-06-01T05:55:33.463Z" }, + { url = "https://files.pythonhosted.org/packages/65/4a/1df2a4cb4a1386e066ab7e4175e34bb884b35ccb60d3621c09c84af6aabb/grpcio-1.81.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a524cd530900bd24511fcb7f2ed144da4ea37711c4b094475d0bceca7a93a170", size = 6811797, upload-time = "2026-06-01T05:55:36.731Z" }, + { url = "https://files.pythonhosted.org/packages/8d/dc/fa189d20601a1be25b08850cfb733879bbb1047b62a8feec3a60e3e1a87b/grpcio-1.81.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e7746ba3e6efc9e2b748eff59470a2b8684d5a9ec607c6580bcaa5be175820bc", size = 7415131, upload-time = "2026-06-01T05:55:39.451Z" }, + { url = "https://files.pythonhosted.org/packages/ad/a3/5625c48cb48d23c6631b3e5294f88e4c751f22a52591ae78859fab96dca1/grpcio-1.81.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:aaaa4f7f2057d795952e4eacf3f342be8b5b156992f6ac85023c8b98794ebd47", size = 8408398, upload-time = "2026-06-01T05:55:42.219Z" }, + { url = 
"https://files.pythonhosted.org/packages/75/34/0f8202c6809a46c2b4d69125ef3667c40b1c211f8e19930e5fa1f1197039/grpcio-1.81.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:0fba53cb96004b2b7fb758b46b2288cb49d0b658316a4e73f3ef67230616ee65", size = 7844481, upload-time = "2026-06-01T05:55:44.849Z" }, + { url = "https://files.pythonhosted.org/packages/c0/95/c3366b5b5edf4c4adc90f2e29ca16e57965a8e56dc8d2ee89565ba1905bb/grpcio-1.81.0-cp313-cp313-win32.whl", hash = "sha256:c197e2ef75a442528072b29e9755da299110e8610e8bcbb59a6b4cf55384f005", size = 4182777, upload-time = "2026-06-01T05:55:47.459Z" }, + { url = "https://files.pythonhosted.org/packages/a9/a7/932f2f748511a32e641a2aba0d30dded3ed6e8bc330e0924e4d5d86853e6/grpcio-1.81.0-cp313-cp313-win_amd64.whl", hash = "sha256:194eddfacc84d80f50512e9fd4ee851d5f2499f18f299c95aa8fb4748f0537e0", size = 4928085, upload-time = "2026-06-01T05:55:50.158Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -2672,6 +2713,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/c0/4bc973defd1270b89ccaae04cef0d5fa3ea85b59b108ad2c08aeea9afb76/makefun-1.16.0-py2.py3-none-any.whl", hash = "sha256:43baa4c3e7ae2b17de9ceac20b669e9a67ceeadff31581007cca20a07bbe42c4", size = 22923, upload-time = "2025-05-09T15:00:41.042Z" }, ] +[[package]] +name = "markdown" +version = "3.10.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2b/f4/69fa6ed85ae003c2378ffa8f6d2e3234662abd02c10d216c0ba96081a238/markdown-3.10.2.tar.gz", hash = "sha256:994d51325d25ad8aa7ce4ebaec003febcce822c3f8c911e3b17c52f7f589f950", size = 368805, upload-time = "2026-02-09T14:57:26.942Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/1f/77fa3081e4f66ca3576c896ae5d31c3002ac6607f9747d2e3aa49227e464/markdown-3.10.2-py3-none-any.whl", hash = "sha256:e91464b71ae3ee7afd3017d9f358ef0baf158fd9a298db92f1d4761133824c36", size = 108180, upload-time = "2026-02-09T14:57:25.787Z" }, +] + [[package]] 
name = "markdown-it-py" version = "4.0.0" @@ -3616,6 +3666,7 @@ dependencies = [ { name = "nvalchemi-toolkit-ops" }, { name = "nvidia-physicsnemo" }, { name = "periodictable" }, + { name = "plotext" }, { name = "plum-dispatch" }, { name = "pydantic" }, { name = "rich" }, @@ -3639,6 +3690,7 @@ cu12 = [ { name = "nvalchemi-toolkit-ops", extra = ["torch-cu12"], marker = "(sys_platform != 'darwin' and extra == 'extra-17-nvalchemi-toolkit-cu12') or (extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13')" }, { name = "nvidia-physicsnemo", extra = ["cu12"], marker = "(sys_platform != 'darwin' and extra == 'extra-17-nvalchemi-toolkit-cu12') or (extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13')" }, { name = "torch", version = "2.12.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform != 'darwin'" }, + { name = "torchvision", version = "0.27.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "sys_platform != 'darwin'" }, ] cu13 = [ { name = "cuequivariance-ops-torch-cu13", marker = "sys_platform != 'darwin'" }, @@ -3646,6 +3698,7 @@ cu13 = [ { name = "nvalchemi-toolkit-ops", extra = ["torch-cu13"], marker = "(sys_platform != 'darwin' and extra == 'extra-17-nvalchemi-toolkit-cu13') or (extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13')" }, { name = "nvidia-physicsnemo", extra = ["cu13"], marker = "(sys_platform != 'darwin' and extra == 'extra-17-nvalchemi-toolkit-cu13') or (extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13')" }, { name = "torch", version = "2.12.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "sys_platform != 'darwin'" }, + { name = "torchvision", version = "0.27.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "sys_platform != 'darwin'" }, ] mace = [ { 
name = "cuequivariance-torch" }, @@ -3654,6 +3707,9 @@ mace = [ pymatgen = [ { name = "pymatgen" }, ] +tensorboard = [ + { name = "tensorboard" }, +] [package.dev-dependencies] build = [ @@ -3716,21 +3772,25 @@ requires-dist = [ { name = "nvalchemi-toolkit-ops", git = "https://github.com/NVIDIA/nvalchemi-toolkit-ops.git?rev=7a73c012b7fd5bc649701d2aec802b4b9511a355" }, { name = "nvalchemi-toolkit-ops", extras = ["torch-cu12"], marker = "sys_platform != 'darwin' and extra == 'cu12'", git = "https://github.com/NVIDIA/nvalchemi-toolkit-ops.git?rev=7a73c012b7fd5bc649701d2aec802b4b9511a355" }, { name = "nvalchemi-toolkit-ops", extras = ["torch-cu13"], marker = "sys_platform != 'darwin' and extra == 'cu13'", git = "https://github.com/NVIDIA/nvalchemi-toolkit-ops.git?rev=7a73c012b7fd5bc649701d2aec802b4b9511a355" }, - { name = "nvidia-physicsnemo", specifier = ">=2.0.0" }, - { name = "nvidia-physicsnemo", extras = ["cu12"], marker = "sys_platform != 'darwin' and extra == 'cu12'", specifier = ">=2.0.0" }, - { name = "nvidia-physicsnemo", extras = ["cu13"], marker = "sys_platform != 'darwin' and extra == 'cu13'", specifier = ">=2.0.0" }, + { name = "nvidia-physicsnemo", specifier = ">=2.1.0" }, + { name = "nvidia-physicsnemo", extras = ["cu12"], marker = "sys_platform != 'darwin' and extra == 'cu12'", specifier = ">=2.1.0" }, + { name = "nvidia-physicsnemo", extras = ["cu13"], marker = "sys_platform != 'darwin' and extra == 'cu13'", specifier = ">=2.1.0" }, { name = "periodictable", specifier = "==2.0.2" }, + { name = "plotext" }, { name = "plum-dispatch", specifier = ">=2.5.7" }, { name = "pydantic", specifier = ">=2.11.7" }, { name = "pymatgen", marker = "extra == 'pymatgen'", specifier = ">=2025.10.7" }, { name = "rich", specifier = ">=13.0.0" }, + { name = "tensorboard", marker = "extra == 'tensorboard'" }, { name = "tensordict", specifier = ">=0.11.0" }, { name = "torch", specifier = ">=2.8" }, { name = "torch", marker = "sys_platform != 'darwin' and extra == 'cu12'", 
index = "https://download.pytorch.org/whl/cu126", conflict = { package = "nvalchemi-toolkit", extra = "cu12" } }, { name = "torch", marker = "sys_platform != 'darwin' and extra == 'cu13'", index = "https://download.pytorch.org/whl/cu130", conflict = { package = "nvalchemi-toolkit", extra = "cu13" } }, + { name = "torchvision", marker = "sys_platform != 'darwin' and extra == 'cu12'", index = "https://download.pytorch.org/whl/cu126", conflict = { package = "nvalchemi-toolkit", extra = "cu12" } }, + { name = "torchvision", marker = "sys_platform != 'darwin' and extra == 'cu13'", index = "https://download.pytorch.org/whl/cu130", conflict = { package = "nvalchemi-toolkit", extra = "cu13" } }, { name = "zarr", specifier = ">=3" }, ] -provides-extras = ["aimnet", "ase", "cu12", "cu13", "mace", "pymatgen"] +provides-extras = ["aimnet", "ase", "cu12", "cu13", "mace", "pymatgen", "tensorboard"] [package.metadata.requires-dev] build = [ @@ -4774,7 +4834,7 @@ wheels = [ [[package]] name = "nvidia-physicsnemo" -version = "2.0.0" +version = "2.1.0" source = { registry = "https://pypi.nvidia.com/" } dependencies = [ { name = "cftime" }, @@ -4800,13 +4860,16 @@ dependencies = [ { name = "torch", version = "2.12.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13') or (extra != 'extra-17-nvalchemi-toolkit-cu12' and extra != 'extra-17-nvalchemi-toolkit-cu13')" }, { name = "torch", version = "2.12.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, { name = "torch", version = "2.12.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu13'" }, - { name = "torchvision" }, + { name = "torchvision", version = "0.27.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 
'extra-17-nvalchemi-toolkit-cu13') or (extra != 'extra-17-nvalchemi-toolkit-cu12' and extra != 'extra-17-nvalchemi-toolkit-cu13')" }, + { name = "torchvision", version = "0.27.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, + { name = "torchvision", version = "0.27.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu13'" }, { name = "tqdm" }, { name = "treelib" }, + { name = "urllib3" }, { name = "warp-lang" }, ] wheels = [ - { url = "https://pypi.nvidia.com/nvidia-physicsnemo/nvidia_physicsnemo-2.0.0-py3-none-any.whl", hash = "sha256:fcea6ac198a2925ab81c3f62011225f53b73e1212e5364aac939ab599c0dfd9d" }, + { url = "https://pypi.nvidia.com/nvidia-physicsnemo/nvidia_physicsnemo-2.1.0-py3-none-any.whl", hash = "sha256:2e05dab3d3b4ff4427f37ff2d6802d9817c3f5200b22a8031098d84ff9d6702c" }, ] [package.optional-dependencies] @@ -4816,7 +4879,7 @@ cu12 = [ { name = "nvidia-dali-cuda120" }, { name = "pylibraft-cu12" }, { name = "torch", version = "2.12.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" } }, - { name = "torchvision" }, + { name = "torchvision", version = "0.27.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" } }, ] cu13 = [ { name = "cuml-cu13" }, @@ -4824,7 +4887,7 @@ cu13 = [ { name = "nvidia-dali-cuda130" }, { name = "pylibraft-cu13" }, { name = "torch", version = "2.12.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" } }, - { name = "torchvision" }, + { name = "torchvision", version = "0.27.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" } }, ] [[package]] @@ -5288,6 +5351,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/31/05e764397056194206169869b50cf2fee4dbbbc71b344705b9c0d878d4d8/platformdirs-4.9.2-py3-none-any.whl", hash = 
"sha256:9170634f126f8efdae22fb58ae8a0eaa86f38365bc57897a6c4f781d1f5875bd", size = 21168, upload-time = "2026-02-16T03:56:08.891Z" }, ] +[[package]] +name = "plotext" +version = "5.3.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/d7/f75f397af966fe252d0d34ffd3cae765317fce2134f925f95e7d6725d1ce/plotext-5.3.2.tar.gz", hash = "sha256:52d1e932e67c177bf357a3f0fe6ce14d1a96f7f7d5679d7b455b929df517068e", size = 61967, upload-time = "2024-09-24T15:13:37.728Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/1e/12fe7c40cd2099a1f454518754ed229b01beaf3bbb343127f0cc13ce6c22/plotext-5.3.2-py3-none-any.whl", hash = "sha256:394362349c1ddbf319548cfac17ca65e6d5dfc03200c40dfdc0503b3e95a2283", size = 64047, upload-time = "2024-09-24T15:13:36.296Z" }, +] + [[package]] name = "plotly" version = "6.7.0" @@ -5361,6 +5433,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/8c/83087ebc47ab0396ce092363001fa37c17153119ee282700c0713a195853/prettytable-3.17.0-py3-none-any.whl", hash = "sha256:aad69b294ddbe3e1f95ef8886a060ed1666a0b83018bbf56295f6f226c43d287", size = 34433, upload-time = "2025-11-14T17:33:19.093Z" }, ] +[[package]] +name = "protobuf" +version = "7.35.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/60/fd/5b1491d9e4b586d621c54f4c36b888714164b6875f8d6afa3f9072906a51/protobuf-7.35.0.tar.gz", hash = "sha256:a2efd84605f41e559f1881b0912b44099d0a2ac9bf46b3474823f10fb393b0e6", size = 458677, upload-time = "2026-05-19T23:02:29.197Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl", hash = "sha256:66be6c513931c794fa92c080ffee41671390da3d79da219cf9c0c0907f035dda", size = 433225, upload-time = "2026-05-19T23:02:19.884Z" }, + { url = 
"https://files.pythonhosted.org/packages/8b/39/1c76c2da93f3c507e958e0aecee2391cc44d4625de6c728bbc555195b5a8/protobuf-7.35.0-cp310-abi3-manylinux2014_aarch64.whl", hash = "sha256:fcbe42a4ac09d3ec9c987ddfcd956afd0b15f1ff613bd8371bde9405ffd5c8e5", size = 328847, upload-time = "2026-05-19T23:02:22.3Z" }, + { url = "https://files.pythonhosted.org/packages/91/1a/39f7ce90a238c1a987a4d81ec26379e02ca0aff367de68e4a1fa474215b9/protobuf-7.35.0-cp310-abi3-manylinux2014_s390x.whl", hash = "sha256:4cbf5cc286130e06a6c9bbefac442431173906dfcc979712183d4adcc01b37ee", size = 344030, upload-time = "2026-05-19T23:02:23.591Z" }, + { url = "https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl", hash = "sha256:6c0f98f10c8a05ea30f8993dfef2de093d27b490fdae78bb60c8343795d55011", size = 327130, upload-time = "2026-05-19T23:02:24.637Z" }, + { url = "https://files.pythonhosted.org/packages/8e/e5/e46adb0badc388bfb84877a5f9f026aff63f60e611016cf64dbe77e05446/protobuf-7.35.0-cp310-abi3-win32.whl", hash = "sha256:4c4617b83ade0e279d1d2bfe04025a1adb87f9ed657de038620dc0ff959357f6", size = 428946, upload-time = "2026-05-19T23:02:25.741Z" }, + { url = "https://files.pythonhosted.org/packages/a7/ab/547fbd9e16d879dd13c167478f8ae0a83a428008ca07a5e06acdc23ad473/protobuf-7.35.0-cp310-abi3-win_amd64.whl", hash = "sha256:f05bcadf9a2a6b8dda047007075135fb7d08c73d9177aabc067e1be46881a201", size = 439996, upload-time = "2026-05-19T23:02:26.808Z" }, + { url = "https://files.pythonhosted.org/packages/b8/ef/50433d346c56657a70d27f156c7b349ac59a068b01de4eb796e747eecc43/protobuf-7.35.0-py3-none-any.whl", hash = "sha256:c13f325cf242bad135c350629eeb5d54b24228eb472fb3e2e9ebbd4c5dc20ca0", size = 171659, upload-time = "2026-05-19T23:02:27.842Z" }, +] + [[package]] name = "propcache" version = "0.5.2" @@ -6905,6 +6992,36 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, ] +[[package]] +name = "tensorboard" +version = "2.20.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "grpcio" }, + { name = "markdown" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "protobuf" }, + { name = "setuptools" }, + { name = "tensorboard-data-server" }, + { name = "werkzeug" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/d9/a5db55f88f258ac669a92858b70a714bbbd5acd993820b41ec4a96a4d77f/tensorboard-2.20.0-py3-none-any.whl", hash = "sha256:9dc9f978cb84c0723acf9a345d96c184f0293d18f166bb8d59ee098e6cfaaba6", size = 5525680, upload-time = "2025-07-17T19:20:49.638Z" }, +] + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb", size = 2356, upload-time = "2023-10-23T21:23:32.16Z" }, + { url = "https://files.pythonhosted.org/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60", size = 4823598, upload-time = "2023-10-23T21:23:33.714Z" }, + { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = 
"sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363, upload-time = "2023-10-23T21:23:35.583Z" }, +] + [[package]] name = "tensordict" version = "0.11.0" @@ -6969,7 +7086,9 @@ dependencies = [ { name = "torch", version = "2.12.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13') or (extra != 'extra-17-nvalchemi-toolkit-cu12' and extra != 'extra-17-nvalchemi-toolkit-cu13')" }, { name = "torch", version = "2.12.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, { name = "torch", version = "2.12.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu13'" }, - { name = "torchvision" }, + { name = "torchvision", version = "0.27.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13') or (extra != 'extra-17-nvalchemi-toolkit-cu12' and extra != 'extra-17-nvalchemi-toolkit-cu13')" }, + { name = "torchvision", version = "0.27.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, + { name = "torchvision", version = "0.27.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/08/54/ece85b0eef3700c90db8271a43669b05a0ebbe2edb1962329c34374a297e/timm-1.0.27.tar.gz", hash = "sha256:315dfe63186ca9fb7ff941268941231fd5be259f2b4bb4afa28560ae1015cb9a", size = 2439861, upload-time = "2026-05-08T19:38:36.844Z" } wheels = [ @@ -7264,13 +7383,30 @@ wheels = [ name = "torchvision" version = "0.27.0" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + 
"python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine != 's390x' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine != 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] dependencies = [ - { name = "numpy", version = 
"2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, - { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu13' or extra != 'extra-17-nvalchemi-toolkit-cu12'" }, - { name = "pillow" }, + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13') or (extra != 'extra-17-nvalchemi-toolkit-cu12' and extra != 'extra-17-nvalchemi-toolkit-cu13')" }, + { name = "pillow", marker = "(extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13') or (extra != 'extra-17-nvalchemi-toolkit-cu12' and extra != 'extra-17-nvalchemi-toolkit-cu13')" }, { name = "torch", version = "2.12.0", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-17-nvalchemi-toolkit-cu12' and extra == 'extra-17-nvalchemi-toolkit-cu13') or (extra != 'extra-17-nvalchemi-toolkit-cu12' and extra != 'extra-17-nvalchemi-toolkit-cu13')" }, - { name = "torch", version = "2.12.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, - { name = "torch", version = "2.12.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu13'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/cf/d6/a7e71e981042d5c573e2e61891b9023b190c88adb75b18bed8594371250c/torchvision-0.27.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:df0c166b6bdf7c47f88e81e8b43bc085451d5c50d0c5d1691bc474c1227d6fed", size = 1758812, upload-time = "2026-05-13T14:57:16.662Z" }, @@ -7291,6 +7427,130 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/70/01b6461117a6a94b5af3f8ee166bb0f045056f3cf187750c110dabfdfffa/torchvision-0.27.0-cp313-cp313t-win_amd64.whl", 
hash = "sha256:a49e55055a39a8506fe7e59850522cab004efb2c3839f6057658889c1d69c815", size = 4141602, upload-time = "2026-05-13T14:56:53.449Z" }, ] +[[package]] +name = "torchvision" +version = "0.27.0+cu126" +source = { registry = "https://download.pytorch.org/whl/cu126" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'win32'", + 
"python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'emscripten'", + "python_full_version < 
'3.12' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, + { name = "pillow", marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, + { name = "torch", version = "2.12.0+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, +] +wheels = [ + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:cb9c6377ff8d1716689a58f641a5ccc74e58f7c8c0d1495139d7ca3bc055754d", upload-time = "2026-05-12T16:20:41Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:70e142b5ab5dea7f70dba395f1cee17eb43f58f4c6c625e368b626b41b6f6c3b", upload-time = "2026-05-12T16:20:41Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp311-cp311-win_amd64.whl", hash = "sha256:6cb74e3accf038fb375273f2bc31d6128dfb00824c8ce8264d9d0fce051e9fb7", upload-time = "2026-05-13T02:00:38Z" }, + { url = 
"https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a4bcd3ea7e9124fb40674dd143a3a28cbde63adc8de6d6ffe1d6810cd40032be", upload-time = "2026-05-12T16:20:41Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:d4d03bbe04a2a9320554f31e6219638f869fc289c175388525cb49ac589ee027", upload-time = "2026-05-12T16:20:41Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp312-cp312-win_amd64.whl", hash = "sha256:038691814aef031fddb1c654cc514168b375067840ce189f03de0382f6a72c13", upload-time = "2026-05-13T02:00:38Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:313c8fbc1fa7b0e5192752601e91c3c9987f6f5ee1342691b465e0c33653a307", upload-time = "2026-05-12T16:20:42Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:56477bd091009afc733931724d66c56b21c9fa14ba2c3a1ec24c8ddce86b5cd8", upload-time = "2026-05-12T16:20:42Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp313-cp313-win_amd64.whl", hash = "sha256:b92a80f74b638f6e8c29b1319eb69701ceea98c0ac3c166ac8ef3f45a0493400", upload-time = "2026-05-13T02:00:39Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:af5367582f4189ec76b3ec0ef9b3503a55b03e57e05cdea62d44e12cacbf4b8a", upload-time = "2026-05-12T16:20:42Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:836e7bb5c54238cb810bc263529f63b4b6ed8d183d5c764d5368902fb1acab19", upload-time = "2026-05-12T16:20:42Z" }, + { url = 
"https://download-r2.pytorch.org/whl/cu126/torchvision-0.27.0%2Bcu126-cp313-cp313t-win_amd64.whl", hash = "sha256:3335086359b4e210ebd6240cb383c63ad265345cb8f8041bf0fe876822a6ab4d", upload-time = "2026-05-13T02:00:40Z" }, +] + +[[package]] +name = "torchvision" +version = "0.27.0+cu130" +source = { registry = "https://download.pytorch.org/whl/cu130" } +resolution-markers = [ + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine == 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", + 
"python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version == '3.12.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 's390x' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform == 'emscripten'", + "python_full_version < 
'3.12' and platform_machine == 'x86_64' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and platform_machine == 'aarch64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine != 'aarch64' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'win32'", +] +dependencies = [ + { name = "numpy", version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu13'" }, + { name = "pillow", marker = "extra == 'extra-17-nvalchemi-toolkit-cu13'" }, + { name = "torch", version = "2.12.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu13'" }, +] +wheels = [ + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2aab6d1ce1c476b6e5ddba884d5b65e6819ca3db58ad4d9f863aba102d487a1d", upload-time = "2026-05-12T16:20:44Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:f90237398efb8ce7001b80e1870c921b3a375d91c892ba8b46415f8085a3711d", upload-time = "2026-05-12T16:20:44Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp311-cp311-win_amd64.whl", hash = 
"sha256:cf6b38f3828868962e5469800353be923983ff90a34c9a1ceebc83fafd662e79", upload-time = "2026-05-13T02:00:44Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:0a839a2921410b1135add4c3d90f784c9d1e9e9f3c7b401b216d356ddca23ab2", upload-time = "2026-05-12T16:20:44Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:664dff46fac97a730c90a976a370ae2cad52780df6ae40fad74be77eee8b4528", upload-time = "2026-05-12T16:20:44Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp312-cp312-win_amd64.whl", hash = "sha256:a79f78d23557b5299c1a1eceeef846d6799ea0a3afe30c600c80ebd26a80bbf8", upload-time = "2026-05-13T02:00:45Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:da81245777c47f6dfd60e02f510d9778fb7f6e23119e2fc1ea1bb06777aae338", upload-time = "2026-05-12T16:20:44Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:afa4128f37066b83af9d426841a53147dd3c208efea893c93dc3eb6fa2af2287", upload-time = "2026-05-12T16:20:44Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp313-cp313-win_amd64.whl", hash = "sha256:31533c28f23bf642989a9ae12caa40a2f8cc9b443d556ba2ffb7a51f759e6a11", upload-time = "2026-05-13T02:00:46Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:bb511f033cd3d6f304dc25753d2a28a1d77aa4dd54a219242d9df7fa57d8dd0a", upload-time = "2026-05-12T16:20:44Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:0c375ac4e9a1c09308f81b73d111d50b76eec335dc91a1811ae370467db2cf47", 
upload-time = "2026-05-12T16:20:45Z" }, + { url = "https://download-r2.pytorch.org/whl/cu130/torchvision-0.27.0%2Bcu130-cp313-cp313t-win_amd64.whl", hash = "sha256:34d108e1ce8255e017bf1f732a51ab2e9ddffb443d118db499a0fbbeb0164650", upload-time = "2026-05-13T02:00:47Z" }, +] + [[package]] name = "tqdm" version = "4.67.3" @@ -7434,11 +7694,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.6.3" +version = "2.7.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/0c/06f8b233b8fd13b9e5ee11424ef85419ba0d8ba0b3138bf360be2ff56953/urllib3-2.7.0.tar.gz", hash = "sha256:231e0ec3b63ceb14667c67be60f2f2c40a518cb38b03af60abc813da26505f4c", size = 433602, upload-time = "2026-05-07T16:13:18.596Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, + { url = "https://files.pythonhosted.org/packages/7f/3e/5db95bcf282c52709639744ca2a8b149baccf648e39c8cc87553df9eae0c/urllib3-2.7.0-py3-none-any.whl", hash = "sha256:9fb4c81ebbb1ce9531cce37674bbc6f1360472bc18ca9a553ede278ef7276897", size = 131087, upload-time = "2026-05-07T16:13:17.151Z" }, ] [[package]] @@ -7467,17 +7727,17 @@ wheels = [ [[package]] name = "warp-lang" -version = "1.11.1" +version = "1.14.0" source = { registry = "https://pypi.nvidia.com/" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu12'" }, { name = "numpy", 
version = "2.4.2", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-17-nvalchemi-toolkit-cu13' or extra != 'extra-17-nvalchemi-toolkit-cu12'" }, ] wheels = [ - { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1ad11f1fa775269e991a3d55039152c8a504baf86701c849b485cb8e66c49d15" }, - { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:8b098f41e71d421d80ee7562e38aa8380ff6b0d3b4c6ee866cfbdef733ac5bdc" }, - { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.1-py3-none-manylinux_2_34_aarch64.whl", hash = "sha256:5d0904b0eefcc81f39ba65375427a3de99006088aa43e24a9011263f07d0cd07" }, - { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.11.1-py3-none-win_amd64.whl", hash = "sha256:15dc10aa51fb0fdbe1ca16d52e5fadca35a47ffd9d0c636826506f96bb2e7c41" }, + { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.14.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:12656050545cc77bf9b9b155399496c1a6279b5b6c59e407507d6858a2beb4a2" }, + { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.14.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:70cd127d0e9109417099649fedf9d00f39f1307ccb7a6e9fb87661337868d7de" }, + { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.14.0-py3-none-manylinux_2_34_aarch64.whl", hash = "sha256:f482787e8da9c9ef045601fde99095e16d604fbcc3cbb4a1e0cef0769388b316" }, + { url = "https://pypi.nvidia.com/warp-lang/warp_lang-1.14.0-py3-none-win_amd64.whl", hash = "sha256:936b49ec78237f9760e58cbe9c46ee6f4244aefbd62071c4fa9fd3b313dfa878" }, ] [[package]] @@ -7489,6 +7749,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" }, ] +[[package]] +name = "werkzeug" +version 
= "3.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dd/b2/381be8cfdee792dd117872481b6e378f85c957dd7c5bca38897b08f765fd/werkzeug-3.1.8.tar.gz", hash = "sha256:9bad61a4268dac112f1c5cd4630a56ede601b6ed420300677a869083d70a4c44", size = 875852, upload-time = "2026-04-02T18:49:14.268Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/8c/2e650f2afeb7ee576912636c23ddb621c91ac6a98e66dc8d29c3c69446e1/werkzeug-3.1.8-py3-none-any.whl", hash = "sha256:63a77fb8892bf28ebc3178683445222aa500e48ebad5ec77b0ad80f8726b1f50", size = 226459, upload-time = "2026-04-02T18:49:12.72Z" }, +] + [[package]] name = "wheel" version = "0.46.3"