Draft
193 commits
3c2a036
feat(training): add TrainingStage enum for hook-lifecycle dispatch
laserkelvin Apr 22, 2026
47eb1fc
test(training): cover TrainingStage registration and stage isolation
laserkelvin Apr 22, 2026
6ee340f
docs(training): add sphinx stub for nvalchemi.training module
laserkelvin Apr 22, 2026
820c16c
docs: removing unused documentation
laserkelvin Apr 22, 2026
6c77b35
feat(training): add BaseSpec and create_model_spec factories for no-p…
laserkelvin Apr 23, 2026
c0afd01
feat(training): add strict-mode build() and FromSpecMixin opt-in cons…
laserkelvin Apr 23, 2026
08e5824
feat(training): add no-pickle save/load_checkpoint with path-footgun …
laserkelvin Apr 23, 2026
631d3f2
test(training): add BaseSpec, FromSpecMixin, and checkpoint test suit…
laserkelvin Apr 23, 2026
6aa587d
feat(training): re-export BaseSpec, FromSpecMixin, and checkpoint API…
laserkelvin Apr 23, 2026
8f8bdb7
fix(training): resolve nested-class qualnames via greedy module-prefi…
laserkelvin Apr 23, 2026
c3b00db
Merge pull request #1 from laserkelvin/feat-training-stage-enum
laserkelvin Apr 23, 2026
d0254e6
docs: rewriting docs for clarity
laserkelvin Apr 24, 2026
4d28a7f
refactor & fix: adding validator method back and removing mixin class
laserkelvin Apr 24, 2026
15b70cb
refactor: not allowing spec.json to walk and improving error messages
laserkelvin Apr 24, 2026
da8c1a0
test: adding more unit test cases
laserkelvin Apr 24, 2026
89da5d9
feat: adding checkpoint save load workflow
laserkelvin Apr 27, 2026
d317e99
Merge remote-tracking branch 'fork/training-epic' into feat-base-spec…
laserkelvin Apr 28, 2026
857a4d5
refactor(training): drop init_hash from BaseSpec in favor of stable s…
laserkelvin Apr 28, 2026
e83c48c
feat(training): add map_location and model_name kwargs to load_checkp…
laserkelvin Apr 28, 2026
ec20338
fix(training): rehydrate unannotated dtype/device kwargs via deserial…
laserkelvin Apr 28, 2026
9dcc927
test(training): add PR #2 review regression tests
laserkelvin Apr 28, 2026
01ac71e
feat(training/losses): add loss-function foundation with graph-aware …
laserkelvin Apr 27, 2026
0037777
refactor: modifying schedule abstraction
laserkelvin Apr 28, 2026
e11cd49
refactor(training): generalize load_checkpoint to model_names set and…
laserkelvin Apr 29, 2026
9aa4e59
refactor(training/losses): organize loss utilities and reductions
laserkelvin Apr 29, 2026
717a18a
perf(training/losses): remove GPU sync and scratch allocations from r…
laserkelvin Apr 29, 2026
d923522
refactor(training/losses): protocol-based weight schedule and docs co…
laserkelvin Apr 29, 2026
965c3c9
Merge branch 'feat-base-spec-serialization' into feat-loss-functions
laserkelvin Apr 29, 2026
080b9ba
Merge pull request #2 from laserkelvin/feat-base-spec-serialization
laserkelvin Apr 29, 2026
bd8dcfd
refactor(training/losses): adopt batch-first signature, drop HookContext
laserkelvin Apr 29, 2026
8085a73
feat(training/losses): add EnergyLoss, ForceLoss, StressLoss concrete…
laserkelvin Apr 29, 2026
9745099
Merge branch 'training-epic' into feat-loss-functions
laserkelvin Apr 29, 2026
e98d2b8
refactor(training/losses): simplify to user-defined _forward with opt…
laserkelvin Apr 30, 2026
5788741
feat(training/losses): add opt-in ignore_nan masking for missing targ…
laserkelvin Apr 30, 2026
430cfc4
test(training/losses): verify create_model_spec round-trips concrete …
laserkelvin Apr 30, 2026
c651187
refactor: reaching desired abstraction for loss and composed loss
laserkelvin Apr 30, 2026
320267c
refactor: jaxtyping on terms
laserkelvin Apr 30, 2026
41fdfd3
docs: made docstrings for base classes more human
laserkelvin Apr 30, 2026
265c5ef
docs(training): add user guide for loss functions
laserkelvin Apr 30, 2026
0d07bea
refactor(training/losses): opt-in shape validation via assert_same_shape
laserkelvin May 1, 2026
a170225
refactor(training/losses): move weight scheduling from leaves to comp…
laserkelvin May 1, 2026
2b9cb29
refactor(training/losses): accept Batch metadata in tensor losses
laserkelvin May 1, 2026
e2c73c1
docs(training/losses): update user guide for tensor-first API
laserkelvin May 4, 2026
7613f3d
feat(training/losses): expose per-sample loss tensors for diagnostics
laserkelvin May 4, 2026
9e3f785
refactor(training/losses): tighten jaxtyping hints on reductions helpers
laserkelvin May 5, 2026
203ba59
docs: improving validation and docstrings in reductions
laserkelvin May 5, 2026
8271295
refactor(training/losses): align reduction output shape conventions
laserkelvin May 5, 2026
dbfb074
fix: making per-atom normalization match intention
laserkelvin May 5, 2026
25f7c74
refactor: removing step and epoch from loss signature
laserkelvin May 5, 2026
d707547
docs: revising documentation on loss implementation
laserkelvin May 5, 2026
507041c
feat(training): add runtime primitives for strategies
laserkelvin May 7, 2026
bc31a4c
feat(training): add TrainingStrategy orchestration
laserkelvin May 7, 2026
edbfe75
refactor: adding strict shape testing
laserkelvin May 7, 2026
3bd36cc
feat(training): add MAE energy and L2-norm force loss terms
laserkelvin May 8, 2026
15bf2ee
test(training): cover MAE energy and L2-norm force loss terms
laserkelvin May 8, 2026
1c764cc
docs(training): document MAE energy and L2-norm force loss terms
laserkelvin May 8, 2026
a02882c
Merge pull request #3 from laserkelvin/feat-loss-functions
laserkelvin May 8, 2026
9396caa
Merge remote-tracking branch 'fork/training-epic' into feat-mae-l2-lo…
laserkelvin May 8, 2026
26a2f94
Merge remote-tracking branch 'fork/training-epic' into feat-training-…
laserkelvin May 8, 2026
6586fc7
refactor(training): cache HookContext per batch and add optimizer plu…
laserkelvin May 11, 2026
f9b7a2a
feat(training): add DO_BACKWARD and DO_OPTIMIZER_STEP stages with exc…
laserkelvin May 11, 2026
84f1bc6
feat(training/hooks): add TrainingUpdateHook framework with plum-disp…
laserkelvin May 13, 2026
68bce43
feat(training/strategy): integrate TrainingUpdateOrchestrator with au…
laserkelvin May 13, 2026
b5f2ef3
test(training/hooks): cover TrainingUpdateHook framework and orchestr…
laserkelvin May 13, 2026
0d347bd
fix(training): harden serialization primitives
laserkelvin May 13, 2026
b14cce1
docs: clarifying docstring for model in hook context
laserkelvin May 13, 2026
dbf837f
feat(training): add TrainingStrategy orchestration
laserkelvin May 7, 2026
f3e719d
Merge origin/main into training-epic
laserkelvin May 17, 2026
9f79849
Merge training-epic into feat-training-runtime-primitives
laserkelvin May 18, 2026
15dcd2c
Merge pull request #4 from laserkelvin/feat-training-runtime-primitives
laserkelvin May 18, 2026
498e138
Merge branch 'training-epic' into feat-training-update-orchestrator
laserkelvin May 18, 2026
efaf180
Merge remote-tracking branch 'fork/training-epic' into feat-training-…
laserkelvin May 18, 2026
dda8374
Address training strategy review feedback
laserkelvin May 18, 2026
ea53486
Harden restored model specs
laserkelvin May 18, 2026
04f03b9
Preserve composed loss weights in specs
laserkelvin May 18, 2026
85b2f63
Preserve training model call mode in specs
laserkelvin May 18, 2026
ba81a7c
Support ModuleDict in optimizer setup
laserkelvin May 18, 2026
a091c31
Reject empty optimizer configs
laserkelvin May 18, 2026
6344265
Cover training strategy validation gaps
laserkelvin May 18, 2026
108ebcb
Restore strategy validation messages
laserkelvin May 18, 2026
368c856
Cache constructor serialization introspection
laserkelvin May 18, 2026
7d0f6b2
Avoid duplicate freeze parameter traversal
laserkelvin May 19, 2026
372ebbb
feat(training): add MixedPrecisionHook
laserkelvin May 11, 2026
8271a08
test(training): extract shared training test fixtures to conftest
laserkelvin May 12, 2026
2a5fe45
fix(training/hooks): skip post-backward update veto validation
laserkelvin May 13, 2026
45f2b5d
fix(training): align AMP unscale with optimizer steps
laserkelvin May 20, 2026
d149a42
docs(training): document mixed precision hooks
laserkelvin May 20, 2026
083dadd
fix(training): narrow AMP autocast scope
laserkelvin May 20, 2026
35ec6b2
refactor(training): dispatch mixed precision hook stages
laserkelvin May 20, 2026
0b80ce0
feat(training): add EMAHook core for exponential moving average
laserkelvin May 12, 2026
00a8fce
feat(training): add EMAHook state_dict and load_state_dict for checkp…
laserkelvin May 12, 2026
48930b4
test(training): add EMAHook unit tests
laserkelvin May 13, 2026
ca211ea
docs(training): document EMAHook checkpoint recipe
laserkelvin May 13, 2026
3362e01
Merge pull request #5 from laserkelvin/feat-training-strategy-orchest…
laserkelvin May 20, 2026
f13df53
Merge remote-tracking branch 'fork/training-epic' into feat-training-…
laserkelvin May 20, 2026
ec9bb80
docs: improving docstrings for training update hook
laserkelvin May 21, 2026
73b9995
fix(training): clear train context after batch failures
laserkelvin May 21, 2026
99c35be
fix(training): expose optimizer step skip state
laserkelvin May 21, 2026
5be961f
fix(training): preserve update hook insertion order
laserkelvin May 21, 2026
c08544b
refactor(training): dispatch update stages directly
laserkelvin May 21, 2026
d71d283
docs(training): clarify update stage context
laserkelvin May 21, 2026
083d03d
test(training): cover update stage ownership
laserkelvin May 21, 2026
96e0df9
refactor: removing skipping attributes from training context
laserkelvin May 22, 2026
c06d985
refactor(training): use match for update stage dispatch
laserkelvin May 22, 2026
0f85a21
feat(training): expose single-batch training flow
laserkelvin May 22, 2026
bc1bd00
docs(training): document update hook constraints
laserkelvin May 22, 2026
ef5831d
refactor(training): clarify optimizer lifecycle boundaries
laserkelvin May 22, 2026
3fb554f
fix(training): run to target step count
laserkelvin May 23, 2026
18f79cf
fix(training): resume dataloader epochs deterministically
laserkelvin May 23, 2026
294b905
fix(training): align step targets with optimizer updates
laserkelvin May 23, 2026
4cbbd33
feat(training): add MixedPrecisionHook
laserkelvin May 11, 2026
9316934
test(training): extract shared training test fixtures to conftest
laserkelvin May 12, 2026
cebaa1b
fix(training): align AMP unscale with optimizer steps
laserkelvin May 20, 2026
627b16a
docs(training): document mixed precision hooks
laserkelvin May 20, 2026
ee9c3e0
fix(training): narrow AMP autocast scope
laserkelvin May 20, 2026
75ea950
refactor(training): dispatch mixed precision hook stages
laserkelvin May 20, 2026
c63e3c1
test(training): align mixed precision tests with train batch helper
laserkelvin May 22, 2026
8b75aad
fix(training): prevent duplicate mixed precision hooks
laserkelvin May 22, 2026
2396c0c
test(training): align update hook API expectations
laserkelvin May 26, 2026
14580ac
Merge pull request #9 from laserkelvin/feat-training-update-orchestrator
laserkelvin May 27, 2026
b9aa80b
Merge remote-tracking branch 'fork/training-epic' into feat-mixed-pre…
laserkelvin May 27, 2026
ad7ba4c
test: consolidating and using existing device fixture
laserkelvin May 28, 2026
92eaa19
Merge pull request #7 from laserkelvin/feat-mixed-precision-hook
laserkelvin May 28, 2026
c85c44f
Merge branch 'feat-training-update-orchestrator' into feat-ema-hook
laserkelvin May 28, 2026
03b307b
Merge remote-tracking branch 'fork/training-epic' into feat-ema-hook
laserkelvin May 28, 2026
cdbe62e
Merge remote-tracking branch 'fork/training-epic' into feat-ema-hook
laserkelvin May 28, 2026
6b810c1
feat(training): add strategy checkpoint restart loading
laserkelvin May 28, 2026
e441123
fix(training): restore checkpoint restart consistency
laserkelvin May 28, 2026
690b5d3
docs(training): note checkpoint restart workflow
laserkelvin May 28, 2026
8acb8b2
Merge pull request #8 from laserkelvin/feat-ema-hook
laserkelvin May 29, 2026
9720c5b
fix(data): generate edge rows in io benchmark
laserkelvin May 29, 2026
af3095d
refactor(data): profile io benchmark readback
laserkelvin May 29, 2026
a424720
refactor(data): batch zarr dataloader reads
laserkelvin May 29, 2026
01bc4f3
feat(data): compare zarr readback modes
laserkelvin May 30, 2026
e1a23e8
docs(data): document zarr readback modes
laserkelvin May 30, 2026
849adf0
docs(data): refresh zarr benchmark examples
laserkelvin May 30, 2026
eb51e24
Merge remote-tracking branch 'origin/main' into training-epic
laserkelvin May 30, 2026
e774991
Merge remote-tracking branch 'fork/training-epic' into feat-checkpoin…
laserkelvin May 31, 2026
35f76ee
feat(training): add periodic checkpoint hook
laserkelvin Jun 2, 2026
def6893
fix(training): respect checkpoint hook lifecycle
laserkelvin Jun 2, 2026
2b64eac
fix(training): make checkpoint hook cadence explicit
laserkelvin Jun 2, 2026
03b5b8e
refactor: simplifying mutual exclusion
laserkelvin Jun 2, 2026
5263c85
test(training): cover checkpoint hook restart cycles
laserkelvin Jun 2, 2026
007f473
feat(training): add strategy checkpoint helpers
laserkelvin Jun 2, 2026
ff64226
docs: adding explicit note about hook state persistence
laserkelvin Jun 2, 2026
63429bb
feat(data): benchmark shuffled zarr readback
laserkelvin Jun 2, 2026
6bf3e79
refactor(training): rename loss classes and harmonize ignore_nonfinite
laserkelvin Jun 3, 2026
144e0e8
fix(training): align EnergyMAELoss per_atom reduction with atom-count…
laserkelvin Jun 3, 2026
a8a115a
refactor(training): extract template-method pattern from BaseLossFunc…
laserkelvin Jun 3, 2026
2dfb7a2
docs(training): document custom mask, reduce, and plum dispatch patterns
laserkelvin Jun 3, 2026
67ee763
docs(skills): add nvalchemi-loss-api agent skill
laserkelvin Jun 3, 2026
cd1e9d5
docs(training): document distributed checkpoint semantics
laserkelvin Jun 3, 2026
6159191
Merge pull request #12 from laserkelvin/feat-checkpoint-loading
laserkelvin Jun 3, 2026
36bee78
Merge remote-tracking branch 'origin/main' into training-epic
laserkelvin Jun 3, 2026
46ed09a
Merge remote-tracking branch 'fork/training-epic' into training-epic
laserkelvin Jun 3, 2026
782d3dd
Merge remote-tracking branch 'fork/training-epic' into feat-mae-l2-lo…
laserkelvin Jun 3, 2026
b80e5fb
fix(training): update merged test files with renamed loss classes
laserkelvin Jun 3, 2026
b2165c5
Merge pull request #6 from laserkelvin/feat-mae-l2-loss-terms
laserkelvin Jun 3, 2026
fe7124e
feat(data): add read-only subcommand to nvalchemi-io-test CLI
laserkelvin Jun 4, 2026
d11c304
feat(training): add distributed manager DDP support
laserkelvin Jun 4, 2026
99e3a70
docs(training): add distributed manager guide
laserkelvin Jun 4, 2026
b901f9e
fix(training): unwrap DDP models for checkpoints
laserkelvin Jun 4, 2026
8f82416
fix(training): avoid duplicating manager in train context
laserkelvin Jun 4, 2026
16c01a0
fix(training): keep dataloader on strategy workflow
laserkelvin Jun 4, 2026
85891dc
feat(training): generalize DDP sampler configuration
laserkelvin Jun 4, 2026
c92b754
refactor: adding batch method from raw dicts
laserkelvin Jun 4, 2026
88ac068
refactor: adding batch method from raw dicts
laserkelvin Jun 4, 2026
938c253
test: adding unit tests for mega prefetch
laserkelvin Jun 4, 2026
30be8e0
refactor: modifying dataset and dataloader to work with megaprefetching
laserkelvin Jun 4, 2026
891f88f
docs(training): add DDP MLP example
laserkelvin Jun 4, 2026
3d1e779
fix(training): initialize DDP example from env
laserkelvin Jun 4, 2026
351fb2e
fix(training): avoid env reads in DDP example
laserkelvin Jun 4, 2026
59cb7c3
refactor(data): review fixes, double-buffer prefetch, read amplificat…
laserkelvin Jun 4, 2026
4736df6
docs(training): improve DDP example pedagogy
laserkelvin Jun 4, 2026
aa3bee4
docs: adding documentation on zarr perf tuning
laserkelvin Jun 4, 2026
5bef612
docs: updating agent skills to include zarr perf tuning
laserkelvin Jun 4, 2026
fafc2ed
fix(training): simplify DDP sampler injection
laserkelvin Jun 5, 2026
9946cb2
fix(data): propagate field-level metadata through skip_validation path
laserkelvin Jun 5, 2026
f0cbd6b
test(data): add coverage for field_levels in from_raw_dicts and Zarr …
laserkelvin Jun 5, 2026
03d695c
Merge pull request #14 from laserkelvin/feat-distributed-manager
laserkelvin Jun 5, 2026
e39dad8
perf(data): optimize shuffled Zarr reads
laserkelvin Jun 5, 2026
78a1900
refactor(data): simplify fused prefetch loader API
laserkelvin Jun 5, 2026
a17edd8
refactor(data): clarify reader batch loading hooks
laserkelvin Jun 6, 2026
21c3d23
docs(data): explain reader batch loading pipeline
laserkelvin Jun 6, 2026
2e608fe
docs(data): refresh Zarr read tuning guide
laserkelvin Jun 6, 2026
563bd60
docs(data): update Zarr performance agent skill
laserkelvin Jun 6, 2026
1a0a8ff
docs(data): refresh datapipes API guide
laserkelvin Jun 6, 2026
f47ef03
Merge pull request #13 from laserkelvin/fix-io-edge-roundtrip-profiling
laserkelvin Jun 6, 2026
1f07d31
add unweighted component loss to allow monitoring during validation
ys-teh Jun 6, 2026
16dbfe0
resolve ewald cache shape issue
ys-teh Jun 6, 2026
23bceaf
add transforms function to dataset
ys-teh Jun 6, 2026
b2f91b2
support runtime strategy restore for periodic checkpoints
ys-teh Jun 6, 2026
dab08ad
adds mace training example
ys-teh Jun 7, 2026
11 changes: 7 additions & 4 deletions .claude/skills/nvalchemi-data-storage/SKILL.md
```diff
@@ -178,20 +178,23 @@ Iterates over a `Dataset` in batches, producing `Batch` objects.
 ```python
 from nvalchemi.data.datapipes import AtomicDataZarrReader, Dataset, DataLoader
 
-reader = AtomicDataZarrReader("dataset.zarr")
-ds = Dataset(reader, device="cuda", num_workers=4)
+reader = AtomicDataZarrReader("dataset.zarr", pin_memory=True)
+ds = Dataset(reader, device="cuda", num_workers=1)
 
 loader = DataLoader(
     ds,
     batch_size=32,
     shuffle=True,
     drop_last=False,
     sampler=None,  # optional torch Sampler
-    prefetch_factor=2,  # batches to prefetch ahead
-    num_streams=4,  # CUDA streams for prefetching
+    prefetch_factor=16,  # fuse 16 batches per read_many call
+    num_streams=2,  # CUDA streams for prefetching
     use_streams=True,  # enable stream prefetching
 )
 
+# For throughput tuning (skip_validation, prefetch_factor, chunk/shard
+# sizing), load the nvalchemi-zarr-perf agent skill.
+
 for batch in loader:
     # batch is a Batch with concatenated tensors on target device
     print(batch.num_graphs, batch.num_nodes)
```
221 changes: 221 additions & 0 deletions .claude/skills/nvalchemi-loss-api/SKILL.md
@@ -0,0 +1,221 @@
---
name: nvalchemi-loss-api
description: How to use built-in loss functions and implement custom losses using the BaseLossFunction template-method pattern — residual types, per-atom normalization, masking, and graph-balanced reductions.
---

# nvalchemi Loss API

## Overview

Loss functions are `torch.nn.Module` subclasses rooted at `BaseLossFunction`.
Each leaf consumes `(pred, target, **kwargs)` and returns a scalar.
`ComposedLossFunction` routes keyed prediction/target mappings to leaves,
applies per-component weights (float or `LossWeightSchedule`), and returns
a `ComposedLossOutput` TypedDict.

```python
from nvalchemi.training import (
    BaseLossFunction,
    ComposedLossFunction,
    ReductionContext,
    EnergyMSELoss,
    EnergyMAELoss,
    ForceMSELoss,
    ForceL2NormLoss,
    StressMSELoss,
)
```

---

## Built-in losses

| Class | Target shape | Residual | Key defaults | Extra knobs |
|---|---|---|---|---|
| `EnergyMSELoss` | `(B, 1)` | squared | `energy` / `predicted_energy` | `per_atom`, `ignore_nonfinite` |
| `EnergyMAELoss` | `(B, 1)` or `(B,)` | absolute | `energy` / `predicted_energy` | `per_atom`, `ignore_nonfinite` |
| `ForceMSELoss` | `(V, 3)` or `(B, V_max, 3)` | squared component | `forces` / `predicted_forces` | `normalize_by_atom_count`, `ignore_nonfinite` |
| `ForceL2NormLoss` | `(V, 3)` or `(B, V_max, 3)` | vector L2 norm | `forces` / `predicted_forces` | `normalize_by_atom_count`, `ignore_nonfinite` |
| `StressMSELoss` | `(B, 3, 3)` | squared (Frobenius) | `stress` / `predicted_stress` | `ignore_nonfinite` |

**Composition sugar:**

```python
loss_fn = 1.0 * EnergyMSELoss() + 10.0 * ForceMSELoss() + 0.1 * StressMSELoss()
out = loss_fn(predictions, targets, step=step, epoch=epoch, batch=batch)
out["total_loss"].backward()
```
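
One way to read the `weight * loss + weight * loss` sugar is as ordinary Python operator overloading. The sketch below is dependency-free and uses plain lists instead of tensors. `ToyLeaf` and `ToyComposed` are hypothetical stand-ins, not the nvalchemi classes, and only the `"total_loss"` key is taken from the text above; the per-component keys are an assumption made for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class ToyLeaf:
    """Hypothetical leaf: mean squared error over plain Python lists."""

    target_key: str
    prediction_key: str

    def __call__(self, pred, target):
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

    def __rmul__(self, weight):
        # `10.0 * leaf` lands here because float.__mul__ returns NotImplemented.
        return ToyComposed([(float(weight), self)])


@dataclass
class ToyComposed:
    """Hypothetical composition: a list of (weight, leaf) pairs."""

    terms: list = field(default_factory=list)

    def __add__(self, other):
        return ToyComposed(self.terms + other.terms)

    def __call__(self, predictions, targets):
        out = {}
        total = 0.0
        for weight, leaf in self.terms:
            # Route values by key, in the spirit of the keyed mappings above.
            value = leaf(predictions[leaf.prediction_key], targets[leaf.target_key])
            out[leaf.target_key] = value  # assumed key name, illustration only
            total += weight * value
        out["total_loss"] = total
        return out


energy = ToyLeaf("energy", "predicted_energy")
forces = ToyLeaf("forces", "predicted_forces")
loss_fn = 1.0 * energy + 10.0 * forces

out = loss_fn(
    {"predicted_energy": [1.0, 3.0], "predicted_forces": [0.5, 0.5]},
    {"energy": [1.0, 1.0], "forces": [0.0, 1.0]},
)
```

Here the energy term evaluates to 2.0 and the force term to 0.25, so the weighted total is 1.0 * 2.0 + 10.0 * 0.25.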

**Graph metadata:** losses that need graph structure (`per_atom=True`,
`normalize_by_atom_count=True`, or padded layouts) accept either `batch=`
(which supplies `batch_idx`, `num_graphs`, and `num_nodes_per_graph`
automatically) or the same values as explicit kwargs.

---

## Template-method pattern

`BaseLossFunction.forward` orchestrates five hooks:

```
forward(pred, target, **kwargs)
  1. validate(pred, target)                                    # shape checks
  2. pred, target, ctx = normalize(pred, target, **kwargs)     # pre-processing
  3. valid = mask(pred, target, ctx, **kwargs)                 # boolean validity mask
  4. residual = compute_residual(pred, target, valid)          # ABSTRACT: must override
  5. scalar = reduce(residual, valid, ctx, **kwargs)           # collapse to scalar
```

**Minimum implementation:** override `compute_residual` only. Defaults
handle shape validation, all-True masking, and validity-weighted mean reduction.
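
The five-hook flow can be sketched without any dependencies. The following is a toy illustration of the template-method idea, assuming plain lists in place of tensors; `ToyLoss` and `ToySquaredError` are hypothetical names and the default hook bodies are guesses at the described behaviour, not the `BaseLossFunction` source.

```python
from abc import ABC, abstractmethod


class ToyLoss(ABC):
    """Hypothetical stand-in for the five-hook flow; lists replace tensors."""

    def forward(self, pred, target, **kwargs):
        self.validate(pred, target)
        pred, target, ctx = self.normalize(pred, target, **kwargs)
        valid = self.mask(pred, target, ctx, **kwargs)
        residual = self.compute_residual(pred, target, valid)
        return self.reduce(residual, valid, ctx, **kwargs)

    def validate(self, pred, target):
        if len(pred) != len(target):
            raise ValueError("pred and target must have the same length")

    def normalize(self, pred, target, **kwargs):
        return pred, target, {}  # no pre-processing, empty context

    def mask(self, pred, target, ctx, **kwargs):
        return [True] * len(pred)  # every entry is valid

    @abstractmethod
    def compute_residual(self, pred, target, valid):
        """The only hook a subclass has to provide."""

    def reduce(self, residual, valid, ctx, **kwargs):
        # Mean over valid entries; clamp the count to avoid dividing by zero.
        n_valid = sum(valid)
        return sum(residual) / max(n_valid, 1)


class ToySquaredError(ToyLoss):
    def compute_residual(self, pred, target, valid):
        # Zero invalid entries so they cannot leak into the reduction.
        return [(p - t) ** 2 if ok else 0.0 for p, t, ok in zip(pred, target, valid)]


loss = ToySquaredError().forward([1.0, 2.0, 4.0], [1.0, 0.0, 1.0])
```

The subclass supplies only `compute_residual`; the call returns the mean of the squared residuals 0, 4 and 9.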

---

## Writing a custom loss

### Minimal: compute_residual only

```python
import torch

from nvalchemi.training import BaseLossFunction


class HuberEnergyLoss(BaseLossFunction):
    def __init__(self, *, target_key="energy", prediction_key="predicted_energy", delta=1.0):
        super().__init__()
        self.target_key = target_key
        self.prediction_key = prediction_key
        self.delta = delta

    def compute_residual(self, pred, target, valid):
        # Zero invalid entries before applying the Huber branches.
        residual = torch.where(valid, pred - target, torch.zeros_like(pred))
        abs_r = residual.abs()
        return torch.where(
            abs_r < self.delta,
            0.5 * residual.pow(2),
            self.delta * (abs_r - 0.5 * self.delta),
        )
```
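
As a sanity check on the two branches, the same piecewise formula can be evaluated on scalars. This is a dependency-free sketch; `toy_huber` is a hypothetical helper that mirrors the expression in `compute_residual` above for a single residual value.

```python
def toy_huber(residual: float, delta: float = 1.0) -> float:
    """Scalar version of the piecewise Huber expression used above."""
    abs_r = abs(residual)
    if abs_r < delta:
        return 0.5 * residual**2  # quadratic near zero
    return delta * (abs_r - 0.5 * delta)  # linear in the tails


small = toy_huber(0.5)   # quadratic branch
large = toy_huber(3.0)   # linear branch
edge = toy_huber(1.0)    # both branches agree at |r| == delta
```

The quadratic branch gives 0.5 * 0.25 for a residual of 0.5, the linear branch gives 1.0 * (3.0 - 0.5) for a residual of 3.0, and both expressions evaluate to 0.5 at the boundary, so the loss is continuous there.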

### Per-atom normalization (normalize override)

Override `normalize` to divide by atom counts and pass weights via
`ReductionContext["weights"]`. The base `reduce` picks up weights
automatically.

```python
import torch

from nvalchemi.training import BaseLossFunction, ReductionContext


class PerAtomEnergyMSE(BaseLossFunction):
    target_key = "energy"
    prediction_key = "predicted_energy"

    def normalize(self, pred, target, **kwargs):
        ctx = ReductionContext()
        counts = kwargs["num_nodes_per_graph"].to(dtype=pred.dtype).unsqueeze(-1).clamp_min(1.0)
        ctx["weights"] = counts  # base reduce uses this for atom-count-weighted mean
        return pred / counts, target / counts, ctx

    def compute_residual(self, pred, target, valid):
        residual = torch.where(valid, pred - target, torch.zeros_like(pred))
        return residual.pow(2)
```
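
To make the effect of the weights concrete, here is a small worked example in plain Python. It assumes the atom-count-weighted mean has the form `sum(w * r) / sum(w)` with `w` equal to the atom counts; that formula is an assumption about the base `reduce`, and `toy_per_atom_weighted_mse` is a hypothetical helper.

```python
def toy_per_atom_weighted_mse(pred, target, num_nodes_per_graph):
    """Per-atom squared error with an assumed atom-count-weighted mean."""
    counts = [max(float(n), 1.0) for n in num_nodes_per_graph]  # mirrors clamp_min(1.0)
    residual = [(p / c - t / c) ** 2 for p, t, c in zip(pred, target, counts)]
    # Assumed reduction: sum(w * r) / sum(w) with w = atom counts.
    weighted = sum(c * r for c, r in zip(counts, residual)) / sum(counts)
    unweighted = sum(residual) / len(residual)
    return weighted, unweighted


# Two graphs with 2 and 6 atoms; total-energy errors of 2.0 and 12.0.
weighted, unweighted = toy_per_atom_weighted_mse([2.0, 12.0], [0.0, 0.0], [2, 6])
```

The per-atom squared errors are 1.0 and 4.0. Under the assumed formula the larger graph contributes three times the weight, which pulls the weighted mean above the plain mean of the two values.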

### Custom masking (mask override)

Override `mask` to exclude non-finite targets, padding, or other invalid entries.
Return a boolean tensor broadcast-compatible with `pred`/`target`.

```python
def mask(self, pred, target, ctx, **kwargs):
    if self.ignore_nonfinite:
        return torch.isfinite(target)
    return torch.ones_like(target, dtype=torch.bool)
```

For padded force layouts `(B, V_max, 3)`, combine a node mask with a non-finite check:

```python
def mask(self, pred, target, ctx, **kwargs):
    num_nodes_per_graph = kwargs.get("num_nodes_per_graph")
    node_mask = _padded_node_mask(num_nodes_per_graph, pred, pred.shape[1])
    valid = node_mask.unsqueeze(-1).expand_as(pred)
    if self.ignore_nonfinite:
        valid = valid & torch.isfinite(target)
    return valid
```
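
For readers unfamiliar with padded layouts, the following dependency-free sketch shows what such a validity mask looks like on nested lists. `toy_padded_node_mask` and `toy_valid_mask` are hypothetical helpers; the behaviour of the private `_padded_node_mask` is assumed here to be "row index below the graph's atom count", which the source does not confirm.

```python
import math


def toy_padded_node_mask(num_nodes_per_graph, v_max):
    """Hypothetical helper: True for real atoms, False for padding rows."""
    return [[j < n for j in range(v_max)] for n in num_nodes_per_graph]


def toy_valid_mask(target, num_nodes_per_graph, ignore_nonfinite=True):
    """Expand the node mask over xyz and optionally drop non-finite targets."""
    v_max = len(target[0])
    node_mask = toy_padded_node_mask(num_nodes_per_graph, v_max)
    return [
        [
            [is_real and (math.isfinite(x) or not ignore_nonfinite) for x in atom]
            for atom, is_real in zip(graph, graph_mask)
        ]
        for graph, graph_mask in zip(target, node_mask)
    ]


nan = float("nan")
target = [
    [[0.0, 0.0, 0.0], [nan, 0.0, 0.0], [0.0, 0.0, 0.0]],  # 2 real atoms + 1 padding row
    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],  # 3 real atoms
]
valid = toy_valid_mask(target, num_nodes_per_graph=[2, 3])
n_valid = sum(flag for graph in valid for atom in graph for flag in atom)
```

The padding row of the first graph and the single NaN component are both excluded, so only finite components of real atoms count toward the reduction.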

The `valid` tensor flows into `compute_residual` as the third argument.
Zero invalid entries with `torch.where(valid, ..., torch.zeros_like(...))`.

### Custom reduction (reduce override)

Override `reduce` for graph-balanced or other non-mean reductions.
Populate `self.per_sample_loss` with a detached `(B,)` tensor for diagnostics.

```python
from nvalchemi.training.losses.reductions import per_graph_sum

def reduce(self, residual, valid, ctx, **kwargs):
    batch_idx = kwargs["batch_idx"]
    num_graphs = kwargs["num_graphs"]
    valid_f = valid.to(dtype=residual.dtype)
    # Per-atom SE summed over xyz, then per-graph mean, then mean over graphs
    per_atom_se = residual.sum(dim=-1)
    per_atom_valid = valid_f.sum(dim=-1)
    per_graph_num = per_graph_sum(per_atom_se, batch_idx, num_graphs)
    per_graph_den = per_graph_sum(per_atom_valid, batch_idx, num_graphs)
    per_sample = per_graph_num / per_graph_den.clamp_min(1.0)
    self.per_sample_loss = per_sample.detach()
    return per_sample.mean()
```
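
The difference between a graph-balanced reduction and a plain mean over atoms is easiest to see with numbers. The sketch below uses plain lists; `toy_per_graph_sum` is a hypothetical scatter-add stand-in and is not claimed to match the signature or implementation of `per_graph_sum` in `reductions.py`.

```python
def toy_per_graph_sum(values, batch_idx, num_graphs):
    """Hypothetical scatter-add: sum values into their owning graph."""
    out = [0.0] * num_graphs
    for value, graph in zip(values, batch_idx):
        out[graph] += value
    return out


# One single-atom graph with a large error, one four-atom graph with small errors.
per_atom_se = [4.0, 1.0, 1.0, 1.0, 1.0]
per_atom_valid = [1.0, 1.0, 1.0, 1.0, 1.0]
batch_idx = [0, 1, 1, 1, 1]

num = toy_per_graph_sum(per_atom_se, batch_idx, num_graphs=2)
den = toy_per_graph_sum(per_atom_valid, batch_idx, num_graphs=2)
per_sample = [n / max(d, 1.0) for n, d in zip(num, den)]  # per-graph mean

graph_balanced = sum(per_sample) / len(per_sample)
atom_mean = sum(per_atom_se) / len(per_atom_se)
```

With the graph-balanced reduction each graph contributes equally regardless of size, so the single-atom graph is not swamped by the larger one as it is in the plain atom mean.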

### Layout dispatch with plum (dense vs padded forces)

`ForceMSELoss` and `ForceL2NormLoss` use `plum-dispatch` to handle both
dense `(V, 3)` and padded `(B, V_max, 3)` layouts without `if/else` on
`ndim`. Their `mask` and `reduce` hooks delegate to `@overload`/`@dispatch`
helper methods — one overload per layout. See these implementations in
`nvalchemi/training/losses/terms.py` as the reference pattern for
multi-layout losses.

```python
from plum import dispatch, overload

@overload
def _my_helper(self, pred: Forces, target: Forces, **kwargs):
    """Dense (V, 3) path."""
    ...

@overload
def _my_helper(self, pred: _PaddedForces, target: _PaddedForces, **kwargs):
    """Padded (B, V_max, 3) path."""
    ...

@dispatch
def _my_helper(self, pred, target, **kwargs):
    pass  # plum routes to matching overload at runtime
```

---

## Conventions

1. **Define `target_key` and `prediction_key`** on any loss that participates
in `ComposedLossFunction` — these route tensors from the prediction/target
mappings.
2. **Accept `**kwargs`** in hooks that receive them — `ComposedLossFunction`
forwards metadata kwargs to every component.
3. **`compute_residual` must zero invalid entries** using the `valid` mask
argument — the base `reduce` handles weighting but not masking.
4. **`ReductionContext`** is a `dict` subclass (not TypedDict) for
`torch.compile` compatibility. Conventional key: `"weights"` for
atom-count weights consumed by the base `reduce`.

---

## Key files

| File | Contents |
|---|---|
| `nvalchemi/training/losses/composition.py` | `BaseLossFunction`, `ComposedLossFunction`, `ReductionContext` |
| `nvalchemi/training/losses/terms.py` | All 5 built-in leaf losses |
| `nvalchemi/training/losses/reductions.py` | `per_graph_sum`, `per_graph_mean`, `frobenius_mse` |
| `nvalchemi/training/losses/schedules.py` | `ConstantWeight`, `LinearWeight`, `CosineWeight`, `PiecewiseWeight` |
| `nvalchemi/training/losses/base.py` | `LossWeightSchedule` protocol, re-exports |
| `test/training/test_losses.py` | Comprehensive tests for all loss terms |
| `docs/userguide/losses.md` | Full user guide with examples |