Model Training Example with MACE #109

Draft
ys-teh wants to merge 193 commits into NVIDIA:main from ys-teh:feature/mace-training-ex

ys-teh commented Jun 9, 2026

Collaborator

ALCHEMI Toolkit Pull Request

Description

This PR adds an advanced training example for a charged MACE model and the supporting code modifications needed to train it with available ALCHEMI tools.

Note: This cannot be merged until training-epic is merged.

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Performance improvement
  • Documentation update
  • Refactoring (no functional changes)
  • CI/CD or infrastructure change

Related Issues

Changes Made

  • Adds a per_component_unweighted loss output to composition.py and strategy.py so that the unweighted per-component losses can be tracked and logged. This is especially useful when the component loss weights vary during training.
  • Implements a core cache invalidation fix for Ewald and PME cell caches.
    • Added _cell_cache_needs_update() to both EwaldModelWrapper and PMEModelWrapper.
    • The helper now treats missing cache, shape mismatch, device mismatch, dtype mismatch, or changed cell values as stale cache conditions.
    • Updated the forward paths to call this helper before recomputing Ewald/PME cache state.
    • Adds regression tests for both wrappers covering a missing cached cell, reuse of an identical cached cell, a train/validation batch-size shape mismatch, same-shape changed cell values, and a dtype mismatch.
  • Adds transform support to datasets. (NEEDS IMPROVEMENT)
    • Adds a reusable AtomicData transform pipeline for datapipes.
    • Introduces a common Transform base with __call__, to(), set_epoch(), state serialization hooks, and __repr__ support.
    • Adds Compose so callers can pass either one transform or an ordered sequence of transforms.
    • Adds built-in transforms for (1) dtype casting of floating tensors while preserving integer graph fields, and (2) per-field scaling for labels, unit conversion, or target normalization.
    • Adds tests for invalid transforms, ordered composition, dtype preservation, field scaling, prefetch behavior, and epoch forwarding.
    • Caveat: The transforms are currently not compatible with skip_validation=True, because that path skips the creation of AtomicData. Potential solution: act directly on the raw dicts.
  • Adds periodic checkpoint support for runtime hook state. (NEEDS IMPROVEMENT)
    • EMAHook now restores averaged weights and num_updates after restart.
    • Added TrainingStrategy.restore_checkpoint(path) for restoring into an already-built strategy.
    • Periodic checkpoints now work with complex wrappers like MACE without requiring model specs.
    • Existing TrainingStrategy.load_checkpoint(path) behavior is preserved for self-contained/spec-based checkpoints.
    • Potential caveat: periodic saves use include_model_specs=False.
    • Adds tests for live strategy restore, periodic checkpoint restore, EMA restore, and DDP checkpointing.
  • Adds a model training example script using a charged MACE model.
    • Adds an advanced MACE training example using nvalchemi data pipes, TrainingStrategy, composed losses, NeighborListHook, DDP, EMA, checkpointing, and distributed validation.
    • Adds _mace_training_helpers.py with MACE-style Huber losses for energy/forces/stress, optional charge MSE, staged loss weighting, stress unit conversion, training loss logging, validation, and parameter counting.
    • Adds _mace_models.py with builders for standard MACE and charged MACE models, including cuEquivariance config support.
  • Adds charged MACE training docs and user guide. (WIP)
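
The transform pipeline described in the list above can be sketched roughly as follows. This is a minimal illustration and not the code in this PR: samples are plain dicts standing in for AtomicData, `ScaleField` is a hypothetical name for the per-field scaling transform, and the `to()` and state serialization hooks are left out for brevity.

```python
from typing import Any, Dict, Sequence, Union

Sample = Dict[str, Any]  # stand-in for AtomicData in this sketch


class Transform:
    """Base class: subclasses override __call__; set_epoch defaults to a no-op."""

    def __call__(self, sample: Sample) -> Sample:
        raise NotImplementedError

    def set_epoch(self, epoch: int) -> None:
        pass

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"


class ScaleField(Transform):
    """Hypothetical transform: multiply one named field by a constant factor."""

    def __init__(self, field: str, factor: float) -> None:
        self.field = field
        self.factor = factor

    def __call__(self, sample: Sample) -> Sample:
        out = dict(sample)  # shallow copy so the caller's sample is not mutated
        out[self.field] = out[self.field] * self.factor
        return out

    def __repr__(self) -> str:
        return f"ScaleField(field={self.field!r}, factor={self.factor!r})"


class Compose(Transform):
    """Apply transforms in order; accepts one transform or a sequence."""

    def __init__(self, transforms: Union[Transform, Sequence[Transform]]) -> None:
        if isinstance(transforms, Transform):
            transforms = [transforms]
        self.transforms = list(transforms)
        for t in self.transforms:
            if not isinstance(t, Transform):
                raise TypeError(f"expected a Transform, got {type(t).__name__}")

    def __call__(self, sample: Sample) -> Sample:
        for t in self.transforms:
            sample = t(sample)
        return sample

    def set_epoch(self, epoch: int) -> None:
        # Forward the epoch to every child transform.
        for t in self.transforms:
            t.set_epoch(epoch)


pipeline = Compose([ScaleField("energy", 10.0), ScaleField("stress", 0.5)])
result = pipeline({"energy": 1.5, "stress": 4.0})
```

Because a transform of this shape only needs a mapping from field names to values, the same design could in principle run on the raw dicts, which is the potential fix mentioned in the caveat.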

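The stale-cache conditions listed for the Ewald and PME wrappers can be illustrated with a small standalone function. This is a sketch under assumptions, not the `_cell_cache_needs_update()` method added in this PR: the free-function form, the argument names, and the use of exact equality (`torch.equal`) for "changed cell values" are all choices made for the example.

```python
from typing import Optional

import torch


def cell_cache_needs_update(cached: Optional[torch.Tensor], cell: torch.Tensor) -> bool:
    """Return True when the cached cell cannot be reused for `cell`."""
    if cached is None:
        return True  # nothing has been cached yet
    if cached.shape != cell.shape:
        return True  # e.g. training and validation batches of different size
    if cached.device != cell.device:
        return True
    if cached.dtype != cell.dtype:
        return True
    # Same shape, device, and dtype: stale only if some value changed.
    # Exact comparison is an assumption of this sketch.
    return not torch.equal(cached, cell)


cell = torch.eye(3).expand(4, 3, 3).clone()  # batch of four cubic cells
changed = cell.clone()
changed[0, 0, 0] = 2.0  # same shape, one lattice value modified
```

Checking the cheap metadata first means the elementwise comparison only runs when the cache could plausibly be reused.
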
Testing

  • Unit tests pass locally (make pytest)
  • Linting passes (make lint)
  • New tests added for new functionality meet coverage expectations

Checklist

  • I have read and understand the Contributing Guidelines
  • I have updated the CHANGELOG.md
  • I have performed a self-review of my code
  • I have added docstrings to new functions/classes
  • I have updated the documentation (if applicable)

Additional Notes

Tip

This repository uses Greptile, an AI code review service, to help conduct
pull request reviews. We encourage contributors to read and consider suggestions
made by Greptile, but note that human maintainers will provide the necessary
reviews for merging: Greptile's comments are not a qualitative judgement
of your code, nor is it an indication that the PR will be accepted/rejected.
We encourage the use of emoji reactions to Greptile comments, depending on
their usefulness and accuracy.

laserkelvin and others added 30 commits April 21, 2026 20:14
Introduces the nvalchemi.training subpackage with a TrainingStage enum
whose BEFORE_*/AFTER_* members parallel DynamicsStage. Lays the
foundation for training hooks without committing to any dispatch host
(TrainingStrategy lands in a later feature).
Adds pytest coverage for the TrainingStage enum shape (14 members,
BEFORE_*/AFTER_* uniformity) and for HookRegistryMixin behaviour on
TrainingStage-typed hosts: registration and dispatch succeed, foreign
DynamicsStage hooks are rejected, and the _runs_on_stage bypass
continues to permit cross-category hooks. Mirrors the existing
test/hooks/test_registry.py style.
Adds a minimal docs entry under docs/modules/training/ that renders
TrainingStage via autoclass and cross-links the user guide, core hook
framework, and sibling dynamics hooks. Wires the new page into the
modules toctree. Mirrors the narrative + seealso pattern used by
docs/modules/hooks.rst.
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
feat(training): add TrainingStage enum and training module skeleton
…pec JSON diff

The signature hash used repr() of default values including floats, which produces
fragile false positives across float representation quirks. The value it added
(an early warning signal on __init__ signature drift) is not worth the noise —
the spec JSON consistency check in save_checkpoint already catches kwarg changes,
and build() now wraps TypeError with the spec's cls_path and timestamp for
actionable diagnostics when a signature genuinely changed.

Addresses PR NVIDIA#2 review feedback from R. Zubatyuk.
…oint

- map_location is forwarded to torch.load() for models, optimizers, and
  schedulers, and applied via model.to() post-load to ensure live modules
  sit on the target device.
- model_name filters the load to a single model plus its associated
  optimizers/schedulers (via manifest associations). Raises KeyError with
  the sorted available names when unknown.

Addresses PR NVIDIA#2 review feedback from A. Thakur (map_location) and
R. Zubatyuk (model_name).
…izer probe

When __init__ has no type annotation for a parameter whose value is a
serialized custom type (e.g. nn.Linear(..., dtype=torch.float32)), the
rehydrated spec previously carried a str instead of the typed value.
_resolve_annotation would infer type(value) == str at rehydrate time,
skip the registered BeforeValidator, and spec.build() would then hand
the string to torch.empty(), raising an opaque TypeError.

create_model_spec_from_json now probes registered deserializers on
str/dict values before passing them to create_model_spec. The first
deserializer that accepts the value wins; otherwise the raw value
passes through unchanged. This covers nn.Linear, nn.Conv2d, nn.LayerNorm,
and other core PyTorch modules that leave dtype/device unannotated.

Also promotes the existing xfail-strict test to asserting the fixed
behaviour.

Addresses PR NVIDIA#2 review feedback from A. Thakur (torch.float32 string
handling).
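
The "first deserializer that accepts the value wins" behaviour described in this commit can be sketched as below. Everything here is hypothetical: the registry is a plain list, and the two deserializers use standard-library types (`Fraction`, `PurePosixPath`) as stand-ins for the torch dtype and device values the commit is actually about.

```python
from fractions import Fraction
from pathlib import PurePosixPath
from typing import Any, Callable, List

Deserializer = Callable[[Any], Any]


def fraction_from_tagged_str(value: Any) -> Fraction:
    """Accept strings such as 'fraction:3/4'; reject everything else."""
    if isinstance(value, str) and value.startswith("fraction:"):
        return Fraction(value[len("fraction:"):])
    raise ValueError("not a tagged fraction")


def path_from_dict(value: Any) -> PurePosixPath:
    """Accept dicts such as {'kind': 'path', 'value': 'a/b'}."""
    if isinstance(value, dict) and value.get("kind") == "path":
        return PurePosixPath(value["value"])
    raise ValueError("not a serialized path")


REGISTRY: List[Deserializer] = [fraction_from_tagged_str, path_from_dict]


def probe_deserializers(value: Any, registry: List[Deserializer]) -> Any:
    """Try each deserializer on str/dict values; fall back to the raw value."""
    if not isinstance(value, (str, dict)):
        return value  # only serialized forms are probed
    for deserialize in registry:
        try:
            return deserialize(value)  # first one that accepts wins
        except (ValueError, TypeError, KeyError):
            continue
    return value  # nothing matched: pass through unchanged


typed = probe_deserializers("fraction:3/4", REGISTRY)
untouched = probe_deserializers("plain string", REGISTRY)
```

The pass-through fallback is what keeps ordinary string arguments intact when no registered type recognizes them.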
Adds eight regression tests covering PR NVIDIA#2 reviewer concerns:

- TestAssociations::test_scheduler_attaches_to_second_optimizer
  Proves that with two optimizers, the scheduler round-trips onto the
  optimizer it was saved with (responds to A. Thakur's A1).

- TestDtypeRoundtrip::test_dtype_kwarg_roundtrip
  Full checkpoint round-trip for nn.Linear(dtype=torch.float32), which
  exercises the dtype rehydration fix in the prior commit (A. Thakur A3).

- TestEMACheckpoint::test_ema_base_model_roundtrip
  Documents the best-practice pattern for saving and restoring an
  swa_utils.AveragedModel: save the base model and the inner averaged
  model as separate entries, reconstruct the wrapper in user code.

- TestLoadCheckpointKwargs (five tests): map_location='cpu', CUDA
  variant (skipif), model_name= single-model load, model_name= with
  associated optimizer, unknown model_name= raises KeyError.

Found by glad-vale.
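
The save-separately pattern for `swa_utils.AveragedModel` mentioned in TestEMACheckpoint might look like the sketch below. It is an illustration, not the test itself: the checkpoint is an in-memory dict rather than a file, the EMA rule and decay are arbitrary, and storing `n_averaged` alongside the two state dicts is an extra assumption of this example.

```python
import copy

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel


def ema_avg(avg: torch.Tensor, cur: torch.Tensor, num_averaged: torch.Tensor) -> torch.Tensor:
    """Exponential moving average with an arbitrary decay of 0.9."""
    return 0.9 * avg + 0.1 * cur


torch.manual_seed(0)
base = nn.Linear(4, 2)
ema = AveragedModel(base, avg_fn=ema_avg)

# Fake a few training steps: nudge the base weights, then update the average.
for _ in range(3):
    with torch.no_grad():
        for p in base.parameters():
            p.add_(0.1)
    ema.update_parameters(base)

# Save the base model and the inner averaged model as separate entries.
checkpoint = {
    "base": copy.deepcopy(base.state_dict()),
    "ema_inner": copy.deepcopy(ema.module.state_dict()),
    "n_averaged": int(ema.n_averaged),  # assumption: also keep the update count
}

# Restore: rebuild the wrapper in user code, then load each piece.
restored_base = nn.Linear(4, 2)
restored_base.load_state_dict(checkpoint["base"])
restored_ema = AveragedModel(restored_base, avg_fn=ema_avg)
restored_ema.module.load_state_dict(checkpoint["ema_inner"])
restored_ema.n_averaged.fill_(checkpoint["n_averaged"])
```

Rebuilding the wrapper in user code avoids serializing the averaging function, which is typically a closure or lambda.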
…reductions

Establish the Pydantic+BaseSpec-compatible loss abstraction so downstream
TrainingStrategy work can plug a loss_fn in. Introduces an abstract
BaseLossFunction whose __call__ applies a weight schedule to a subclass-
provided compute(), four serializable weight schedules discriminated on
a type literal, and scatter-based graph-aware reduction primitives
structured for a future swap to nvalchemi.math.segment_ops.
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
… fix device transfer order

- Rename model_name:str to model_names:Iterable[str] so callers can
  select a subset of models in one load call; None loads everything.
- Validate all names upfront against manifest.models and raise KeyError
  listing every unknown name together with available names.
- Collapse the single-vs-multi code paths into one filter-then-iterate
  pipeline driven by a set membership check.
- Load weights directly onto map_location: call model.to(map_location)
  before load_state_dict so the copy is device-local and avoids a
  redundant GPU->CPU->GPU transfer per parameter.
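
The upfront name validation described in this commit can be sketched as a single filter step. The function name, the list standing in for manifest.models, and the model names used below are hypothetical; only the behaviour (None loads everything, unknown names raise one KeyError listing all of them plus the available names) comes from the commit message.

```python
from typing import Iterable, List, Optional


def select_model_names(
    available: Iterable[str], requested: Optional[Iterable[str]]
) -> List[str]:
    """Return the model names to load, validating every requested name first."""
    available_set = set(available)
    if requested is None:
        return sorted(available_set)  # None means "load everything"
    wanted = set(requested)
    unknown = sorted(wanted - available_set)
    if unknown:
        # Report every unknown name at once, together with the valid choices.
        raise KeyError(
            f"Unknown model name(s) {unknown}; available: {sorted(available_set)}"
        )
    # Single filter-then-iterate path driven by set membership.
    return [name for name in sorted(available_set) if name in wanted]


manifest_models = ["charge_model", "ema_model", "energy_model"]  # hypothetical names
subset = select_model_names(manifest_models, ["energy_model", "charge_model"])
```

Validating before any weights are read means a typo in one name fails fast instead of after a partial load.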
…eduction hot path

Rework graph-aware reductions to avoid blocking host-device syncs and
redundant allocations on the per-step training path. When num_graphs is
supplied (the common hot-path case), skip batch_idx min/max scans and
the associated .item() sync entirely. Replace the torch.ones(V) +
scatter node-count path with torch.bincount(batch_idx, minlength=...).
Collapse repeated shape validation and batch-index resolution into a
single _prep_reduction helper. Make empty-batch errors actionable by
naming num_graphs as the recovery handle, and document the CUDA-long
batch_idx contract in the module docstring.
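
The node-count change described in this commit can be illustrated with a toy batch. The tensors are made up for the example; the point is only that `torch.bincount(batch_idx, minlength=num_graphs)` yields the same counts as the ones-plus-scatter pattern without allocating the temporary tensor of ones, and that passing `num_graphs` explicitly removes the need to scan `batch_idx` for its maximum.

```python
import torch

# Six nodes spread over four graphs; the last graph is empty.
batch_idx = torch.tensor([0, 0, 1, 2, 2, 2], dtype=torch.long)
num_graphs = 4

# Previous pattern: allocate a tensor of ones and scatter-add it.
ones = torch.ones(batch_idx.shape[0], dtype=torch.long)
counts_scatter = torch.zeros(num_graphs, dtype=torch.long).scatter_add_(0, batch_idx, ones)

# Replacement: count directly, padding out to num_graphs entries.
counts_bincount = torch.bincount(batch_idx, minlength=num_graphs)

# Example use: a per-graph mean of a per-node quantity.
values = torch.tensor([1.0, 3.0, 5.0, 2.0, 4.0, 6.0])
sums = torch.zeros(num_graphs).scatter_add_(0, batch_idx, values)
means = sums / counts_bincount.clamp(min=1)  # clamp guards the empty graph
```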
…nsolidation

Shift loss-weight scheduling from a Pydantic discriminated-union to a
protocol-based design so loss round-trip is not coupled to the schedule
registry. BaseLossFunction.weight is now typed as SerializeAsAny over
the LossWeightSchedule protocol, and concrete schedules drop their
schedule_type Literal field. Losses serialize with an identity
ConstantWeight by default; upstream TrainingStrategy will reconstruct
per-loss schedules from their (instance, spec) pairs, matching the
existing model/optimizer/checkpoint convention.

Alongside the type change: factor a _RampSchedule base for the shared
linear/cosine ramp window, add _is_spec_dict / _is_spec_dict_sequence /
_build_sequence_of_specs helpers in _spec.py to centralize nested spec
classification across build and JSON rehydration paths, re-export
BaseLossFunction/ComposedLossFunction from base.py for discoverability,
consolidate arithmetic and composition semantics in a single module
docstring, raise an actionable ValueError when a per_epoch=True
schedule sees ctx.epoch is None, and wrap weight-call failures with a
message naming the expected (step, epoch) -> float contract. Update
affected tests and clarify the package docstring to call out that only
components and static weights round-trip through BaseSpec.
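
A protocol-based weight schedule with the (step, epoch) -> float contract could look like the sketch below. `LossWeightSchedule` and `ConstantWeight` are names taken from this commit message, but their fields here are guesses; `LinearRamp` and its parameters are hypothetical and only loosely modelled on the ramp window described above.

```python
from dataclasses import dataclass
from typing import Optional, Protocol


class LossWeightSchedule(Protocol):
    """Anything callable as (step, epoch) -> float satisfies the protocol."""

    def __call__(self, step: int, epoch: Optional[int]) -> float: ...


@dataclass(frozen=True)
class ConstantWeight:
    value: float = 1.0  # identity weight by default

    def __call__(self, step: int, epoch: Optional[int]) -> float:
        return self.value


@dataclass(frozen=True)
class LinearRamp:
    """Hypothetical ramp: `start` until `begin`, then linear to `end` over `length`."""

    start: float
    end: float
    begin: int
    length: int
    per_epoch: bool = False

    def __call__(self, step: int, epoch: Optional[int]) -> float:
        if self.per_epoch:
            if epoch is None:
                raise ValueError(
                    "per_epoch=True schedule needs an epoch, but epoch is None"
                )
            t = epoch
        else:
            t = step
        # Fraction of the ramp window completed, clamped to [0, 1].
        frac = min(max((t - self.begin) / self.length, 0.0), 1.0)
        return self.start + frac * (self.end - self.start)


def weighted_loss(raw_loss: float, weight: LossWeightSchedule, step: int, epoch: Optional[int]) -> float:
    """Apply a schedule to an already-computed loss value."""
    return weight(step, epoch) * raw_loss


ramp = LinearRamp(start=0.0, end=10.0, begin=100, length=100)
halfway = weighted_loss(2.0, ramp, step=150, epoch=None)
```

With structural typing, a new schedule needs no registration step; it only has to be callable with the same two arguments.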
…tion

feat(training): add BaseSpec, checkpoint system with multi-model and optimizer/scheduler support
Loss functions now take a Batch (plus keyword-only step/epoch) instead of
a HookContext, decoupling the losses module from nvalchemi.hooks. Callers
pass targets and predictions on the batch via standard attribute access,
which matches how TrainingStrategy will wire predictions into the batch
before invoking the loss.
laserkelvin and others added 29 commits June 3, 2026 21:54
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
…ion cap

Address review findings from multi-agent code review:

- P1-1: True double-buffered _iter_prefetch using deque[Future] queue
  so next Zarr read overlaps consumer batch processing
- P1-2: Extract _merge_physical_runs helper with max_amplification=8
  cap to bound gap-merge read amplification
- P1-3: _DefaultRoundtripGroup preserves backward compat for
  nvalchemi-io-test CLI (auto-inserts 'roundtrip' subcommand)
- P2-4: Factor _build_batch_storage shared by from_data_list and
  from_raw_dicts, eliminating batch-construction duplication
- P2-5: Unify _MegaPrefetchResult and load functions into single
  _load_mega dispatching on skip_validation
- P2-6: Batch-then-move device transfer (one .to() per field per
  batch instead of per sample)
- P2-7: Precompute per-position atom/edge pointer offsets once per
  run before field loop
- P2-8: Custom keys preserved as system-level in from_raw_dicts
- P2-10: read PATH validates dir-only via file_okay=False
- P3-11: Tests for read_many([]), single-element, mega error
  propagation, custom key roundtrip
- P3-14: Improved get_mega_batches error message
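
The double-buffering idea in P1-1 (keep a queue of futures so the next read overlaps the consumer's work) can be sketched with the standard library. This is not the `_iter_prefetch` from this PR: the function name, the `load` callable, and the `depth` parameter are assumptions, and a single worker thread stands in for the Zarr reader.

```python
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Deque, Iterable, Iterator, TypeVar

K = TypeVar("K")
T = TypeVar("T")

_DONE = object()  # sentinel marking an exhausted key iterator


def iter_prefetch(keys: Iterable[K], load: Callable[[K], T], depth: int = 2) -> Iterator[T]:
    """Yield load(key) in order while up to `depth` reads run ahead."""
    key_iter = iter(keys)
    pending: Deque[Future] = deque()
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Prime the queue with the first `depth` reads.
        for key in key_iter:
            pending.append(pool.submit(load, key))
            if len(pending) >= depth:
                break
        while pending:
            result = pending.popleft().result()  # re-raises loader errors
            nxt = next(key_iter, _DONE)
            if nxt is not _DONE:
                # Queue the next read before handing control to the consumer.
                pending.append(pool.submit(load, nxt))
            yield result


squares = list(iter_prefetch(range(5), lambda i: i * i))
```

Submitting the replacement read before `yield` is what lets the background read proceed while the caller is still processing the current batch.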
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
Batch.from_raw_dicts misclassified custom per-atom/per-edge Zarr fields
as system-level because it only checked AtomicData._default_*_keys.
UniformLevelStorage then crashed on variable-length tensors.

Add Reader.field_levels property (empty default, overridden by
AtomicDataZarrReader to expose store metadata).  Dataset caches it
at init and forwards to from_raw_dicts.  _build_batch_storage now
checks field_levels before the system-level fallback.
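
The classification order described in this commit (built-in keys first, then the reader's field_levels, then the system-level fallback) can be sketched as follows. The key sets, the level strings "atom", "edge", and "system", and the custom field names are placeholders for this example, not the values used by AtomicData or AtomicDataZarrReader.

```python
from typing import Mapping

# Placeholder stand-ins for AtomicData's default per-atom and per-edge keys.
DEFAULT_ATOM_KEYS = frozenset({"positions", "forces"})
DEFAULT_EDGE_KEYS = frozenset({"edge_index", "shifts"})


def classify_field(name: str, field_levels: Mapping[str, str]) -> str:
    """Return 'atom', 'edge', or 'system' for a field name."""
    if name in DEFAULT_ATOM_KEYS:
        return "atom"
    if name in DEFAULT_EDGE_KEYS:
        return "edge"
    # Custom fields: consult the reader's metadata before falling back.
    level = field_levels.get(name)
    if level in ("atom", "edge"):
        return level
    return "system"


# Hypothetical metadata a reader might expose for two custom fields.
store_levels = {"partial_charges": "atom", "bond_order": "edge"}
levels = {
    name: classify_field(name, store_levels)
    for name in ["forces", "partial_charges", "bond_order", "total_charge"]
}
```

With an empty mapping, which is the default the commit gives Reader.field_levels, every custom field still lands in the system-level bucket as before.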
…roundtrip

- Batch: atom-level, edge-level, and fallback classification via field_levels
- Zarr: custom atom/edge fields survive skip_validation Dataset path
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
…rofiling

Improve Zarr I/O profiling and batched data loading
Signed-off-by: Ying Shi Teh <yteh@nvidia.com>
copy-pr-bot Bot commented Jun 9, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

2 participants