Model Training Example with MACE #109
Draft
ys-teh wants to merge 193 commits into
Introduces the nvalchemi.training subpackage with a TrainingStage enum whose BEFORE_*/AFTER_* members parallel DynamicsStage. Lays the foundation for training hooks without committing to any dispatch host (TrainingStrategy lands in a later feature).
Adds pytest coverage for the TrainingStage enum shape (14 members, BEFORE_*/AFTER_* uniformity) and for HookRegistryMixin behaviour on TrainingStage-typed hosts: registration and dispatch succeed, foreign DynamicsStage hooks are rejected, and the _runs_on_stage bypass continues to permit cross-category hooks. Mirrors the existing test/hooks/test_registry.py style.
Adds a minimal docs entry under docs/modules/training/ that renders TrainingStage via autoclass and cross-links the user guide, core hook framework, and sibling dynamics hooks. Wires the new page into the modules toctree. Mirrors the narrative + seealso pattern used by docs/modules/hooks.rst.
Signed-off-by: Kelvin Lee <kinlongkelvi@nvidia.com>
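The BEFORE_*/AFTER_* uniformity check described above can be illustrated with a small sketch. The member names below are hypothetical placeholders (the real TrainingStage has 14 members whose names are not listed in this PR text), and `has_uniform_pairs` is an illustrative helper, not part of nvalchemi.

```python
from __future__ import annotations

from enum import Enum, auto


class TrainingStage(Enum):
    """Hypothetical subset of stages; the real enum defines 14 members."""

    BEFORE_EPOCH = auto()
    AFTER_EPOCH = auto()
    BEFORE_STEP = auto()
    AFTER_STEP = auto()


def has_uniform_pairs(stages: type[Enum]) -> bool:
    """Return True when every member is BEFORE_X or AFTER_X and each X has both."""
    names = {member.name for member in stages}
    if not all(name.startswith(("BEFORE_", "AFTER_")) for name in names):
        return False
    # Strip the BEFORE_/AFTER_ prefix to get the stage suffix (e.g. "EPOCH").
    suffixes = {name.split("_", 1)[1] for name in names}
    return all(
        f"BEFORE_{suffix}" in names and f"AFTER_{suffix}" in names
        for suffix in suffixes
    )
```

A shape test in this style would assert both the member count and the pairing, so that adding a lone BEFORE_* member fails fast.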
…ickle hyperparameter serialization
…e with security invariants
… at package level
feat(training): add TrainingStage enum and training module skeleton
…pec JSON diff
The signature hash used repr() of default values including floats, which produces fragile false positives across float representation quirks. The value it added (an early warning signal on __init__ signature drift) is not worth the noise: the spec JSON consistency check in save_checkpoint already catches kwarg changes, and build() now wraps TypeError with the spec's cls_path and timestamp for actionable diagnostics when a signature genuinely changed.
Addresses PR NVIDIA#2 review feedback from R. Zubatyuk.
…oint
- map_location is forwarded to torch.load() for models, optimizers, and schedulers, and applied via model.to() post-load to ensure live modules sit on the target device.
- model_name filters the load to a single model plus its associated optimizers/schedulers (via manifest associations). Raises KeyError with the sorted available names when unknown.
Addresses PR NVIDIA#2 review feedback from A. Thakur (map_location) and R. Zubatyuk (model_name).
…izer probe
When __init__ has no type annotation for a parameter whose value is a serialized custom type (e.g. nn.Linear(..., dtype=torch.float32)), the rehydrated spec previously carried a str instead of the typed value. _resolve_annotation would infer type(value) == str at rehydrate time, skip the registered BeforeValidator, and spec.build() would then hand the string to torch.empty(), raising an opaque TypeError.
create_model_spec_from_json now probes registered deserializers on str/dict values before passing them to create_model_spec. The first deserializer that accepts the value wins; otherwise the raw value passes through unchanged. This covers nn.Linear, nn.Conv2d, nn.LayerNorm, and other core PyTorch modules that leave dtype/device unannotated.
Also promotes the existing xfail-strict test to asserting the fixed behaviour.
Addresses PR NVIDIA#2 review feedback from A. Thakur (torch.float32 string handling).
Adds eight regression tests covering PR NVIDIA#2 reviewer concerns:
- TestAssociations::test_scheduler_attaches_to_second_optimizer: proves that with two optimizers, the scheduler round-trips onto the optimizer it was saved with (responds to A. Thakur's A1).
- TestDtypeRoundtrip::test_dtype_kwarg_roundtrip: full checkpoint round-trip for nn.Linear(dtype=torch.float32), which exercises the dtype rehydration fix in the prior commit (A. Thakur A3).
- TestEMACheckpoint::test_ema_base_model_roundtrip: documents the best-practice pattern for saving and restoring an swa_utils.AveragedModel: save the base model and the inner averaged model as separate entries, reconstruct the wrapper in user code.
- TestLoadCheckpointKwargs (five tests): map_location='cpu', CUDA variant (skipif), model_name= single-model load, model_name= with associated optimizer, unknown model_name= raises KeyError.
Found by glad-vale.
…reductions
Establish the Pydantic+BaseSpec-compatible loss abstraction so downstream TrainingStrategy work can plug a loss_fn in. Introduces an abstract BaseLossFunction whose __call__ applies a weight schedule to a subclass-provided compute(), four serializable weight schedules discriminated on a type literal, and scatter-based graph-aware reduction primitives structured for a future swap to nvalchemi.math.segment_ops.
… fix device transfer order
- Rename model_name: str to model_names: Iterable[str] so callers can select a subset of models in one load call; None loads everything.
- Validate all names upfront against manifest.models and raise KeyError listing every unknown name together with available names.
- Collapse the single-vs-multi code paths into one filter-then-iterate pipeline driven by a set membership check.
- Load weights directly onto map_location: call model.to(map_location) before load_state_dict so the copy is device-local and avoids a redundant GPU->CPU->GPU transfer per parameter.
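The upfront validation and filter-then-iterate selection described in this commit might look roughly like the sketch below. `select_models` is a hypothetical helper and the manifest is reduced to a plain collection of names; the actual load_checkpoint signature is not shown in this PR text.

```python
from __future__ import annotations

from collections.abc import Iterable


def select_models(
    available: Iterable[str], model_names: Iterable[str] | None
) -> list[str]:
    """Validate requested names against the manifest and return the selection.

    None selects every model. Unknown names are reported together in a single
    KeyError, along with the names that are available.
    """
    available_set = set(available)
    if model_names is None:
        return sorted(available_set)
    requested = set(model_names)
    unknown = requested - available_set
    if unknown:
        raise KeyError(
            f"Unknown model names {sorted(unknown)}; "
            f"available: {sorted(available_set)}"
        )
    # One code path for single and multiple names: filter by set membership.
    return sorted(name for name in available_set if name in requested)
```

Reporting every unknown name at once spares the caller a fix-one-rerun-find-the-next loop when several names are misspelled.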
…eduction hot path
Rework graph-aware reductions to avoid blocking host-device syncs and redundant allocations on the per-step training path.
When num_graphs is supplied (the common hot-path case), skip batch_idx min/max scans and the associated .item() sync entirely. Replace the torch.ones(V) + scatter node-count path with torch.bincount(batch_idx, minlength=...). Collapse repeated shape validation and batch-index resolution into a single _prep_reduction helper.
Make empty-batch errors actionable by naming num_graphs as the recovery handle, and document the CUDA-long batch_idx contract in the module docstring.
…nsolidation
Shift loss-weight scheduling from a Pydantic discriminated-union to a protocol-based design so loss round-trip is not coupled to the schedule registry. BaseLossFunction.weight is now typed as SerializeAsAny over the LossWeightSchedule protocol, and concrete schedules drop their schedule_type Literal field. Losses serialize with an identity ConstantWeight by default; upstream TrainingStrategy will reconstruct per-loss schedules from their (instance, spec) pairs, matching the existing model/optimizer/checkpoint convention.
Alongside the type change:
- factor a _RampSchedule base for the shared linear/cosine ramp window;
- add _is_spec_dict / _is_spec_dict_sequence / _build_sequence_of_specs helpers in _spec.py to centralize nested spec classification across build and JSON rehydration paths;
- re-export BaseLossFunction/ComposedLossFunction from base.py for discoverability;
- consolidate arithmetic and composition semantics in a single module docstring;
- raise an actionable ValueError when a per_epoch=True schedule sees ctx.epoch is None;
- wrap weight-call failures with a message naming the expected (step, epoch) -> float contract.
Update affected tests and clarify the package docstring to call out that only components and static weights round-trip through BaseSpec.
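A protocol-based weight schedule with the (step, epoch) -> float contract could be sketched as below. The field names (`start`, `end`, `begin`, `length`) and the `CosineRamp` class are assumptions made for illustration; only the protocol name, ConstantWeight, the per_epoch flag, and the ValueError on a missing epoch come from the commit message.

```python
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Protocol


class LossWeightSchedule(Protocol):
    """Anything callable as (step, epoch) -> float satisfies the protocol."""

    def __call__(self, step: int, epoch: int | None) -> float: ...


@dataclass(frozen=True)
class ConstantWeight:
    """Identity-style schedule: always returns the same weight."""

    value: float = 1.0

    def __call__(self, step: int, epoch: int | None) -> float:
        return self.value


@dataclass(frozen=True)
class CosineRamp:
    """Hypothetical ramp from start to end over a window of `length` ticks."""

    start: float
    end: float
    begin: int
    length: int
    per_epoch: bool = False

    def __call__(self, step: int, epoch: int | None) -> float:
        if self.per_epoch and epoch is None:
            raise ValueError("per_epoch=True schedule requires an epoch, got None")
        tick = epoch if self.per_epoch else step
        # Clamp progress through the ramp window to [0, 1].
        frac = min(max((tick - self.begin) / self.length, 0.0), 1.0)
        blend = 0.5 * (1.0 - math.cos(math.pi * frac))
        return self.start + (self.end - self.start) * blend
```

Because the protocol is structural, a plain function or lambda with the same signature also qualifies, which is what decouples losses from any schedule registry.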
…tion feat(training): add BaseSpec, checkpoint system with multi-model and optimizer/scheduler support
Loss functions now take a Batch (plus keyword-only step/epoch) instead of a HookContext, decoupling the losses module from nvalchemi.hooks. Callers pass targets and predictions on the batch via standard attribute access, which matches how TrainingStrategy will wire predictions into the batch before invoking the loss.
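As a rough illustration of the calling convention described here, the sketch below has a loss read targets and predictions from batch attributes and accept step/epoch as keyword-only arguments. `BaseLoss`, `EnergyMSE`, and the attribute names `energy` / `predicted_energy` are hypothetical stand-ins, and a SimpleNamespace stands in for the real Batch.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class BaseLoss(ABC):
    """Hypothetical base: __call__ scales compute() by a (step, epoch) weight."""

    def __init__(self, weight=lambda step, epoch: 1.0):
        self.weight = weight

    @abstractmethod
    def compute(self, batch) -> float:
        """Return the unweighted loss for one batch."""

    def __call__(self, batch, *, step: int, epoch=None) -> float:
        # step/epoch are keyword-only, so no hook context object is needed.
        return self.weight(step, epoch) * self.compute(batch)


class EnergyMSE(BaseLoss):
    """Mean squared error between predicted and target energies."""

    def compute(self, batch) -> float:
        errors = [
            (pred - target) ** 2
            for pred, target in zip(batch.predicted_energy, batch.energy)
        ]
        return sum(errors) / len(errors)


# Stand-in for a Batch carrying both targets and predictions as attributes.
batch = SimpleNamespace(energy=[1.0, 2.0], predicted_energy=[1.0, 4.0])
loss_value = EnergyMSE()(batch, step=0)
```

Keeping the loss signature down to a batch plus two scalars means it can be unit-tested without constructing any hook machinery.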
…ion cap
Address review findings from multi-agent code review:
- P1-1: True double-buffered _iter_prefetch using deque[Future] queue so next Zarr read overlaps consumer batch processing
- P1-2: Extract _merge_physical_runs helper with max_amplification=8 cap to bound gap-merge read amplification
- P1-3: _DefaultRoundtripGroup preserves backward compat for nvalchemi-io-test CLI (auto-inserts 'roundtrip' subcommand)
- P2-4: Factor _build_batch_storage shared by from_data_list and from_raw_dicts, eliminating batch-construction duplication
- P2-5: Unify _MegaPrefetchResult and load functions into single _load_mega dispatching on skip_validation
- P2-6: Batch-then-move device transfer (one .to() per field per batch instead of per sample)
- P2-7: Precompute per-position atom/edge pointer offsets once per run before field loop
- P2-8: Custom keys preserved as system-level in from_raw_dicts
- P2-10: read PATH validates dir-only via file_okay=False
- P3-11: Tests for read_many([]), single-element, mega error propagation, custom key roundtrip
- P3-14: Improved get_mega_batches error message
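The double-buffering idea in P1-1 can be sketched with the standard library alone. This is not the PR's _iter_prefetch: `iter_prefetch`, its `load` callback, and the `depth` parameter are illustrative assumptions, and a single worker thread stands in for the Zarr reader.

```python
from __future__ import annotations

from collections import deque
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TypeVar

K = TypeVar("K")
T = TypeVar("T")


def iter_prefetch(
    keys: Iterable[K], load: Callable[[K], T], depth: int = 2
) -> Iterator[T]:
    """Yield load(key) in order while up to depth - 1 later loads run ahead."""
    pending: deque[Future[T]] = deque()
    with ThreadPoolExecutor(max_workers=1) as pool:
        for key in keys:
            pending.append(pool.submit(load, key))
            if len(pending) >= depth:
                # While the consumer handles this result, the next load is
                # already running on the worker thread.
                yield pending.popleft().result()
        # Drain whatever is still in flight once the keys are exhausted.
        while pending:
            yield pending.popleft().result()
```

Calling `.result()` re-raises any exception from the worker in the consumer thread, so a failed read surfaces at the point where that batch would have been consumed.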
Batch.from_raw_dicts misclassified custom per-atom/per-edge Zarr fields as system-level because it only checked AtomicData._default_*_keys. UniformLevelStorage then crashed on variable-length tensors. Add Reader.field_levels property (empty default, overridden by AtomicDataZarrReader to expose store metadata). Dataset caches it at init and forwards to from_raw_dicts. _build_batch_storage now checks field_levels before the system-level fallback.
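The classification order described above (known defaults, then reader-supplied field_levels, then the system-level fallback) could be sketched as follows. `classify_field` is a hypothetical helper, the level strings "atom"/"edge"/"system" are assumed labels, and checking the defaults before field_levels is an assumption; the PR text only states that field_levels is consulted before the fallback.

```python
from __future__ import annotations

from collections.abc import Collection, Mapping


def classify_field(
    name: str,
    field_levels: Mapping[str, str],
    default_atom_keys: Collection[str],
    default_edge_keys: Collection[str],
) -> str:
    """Return 'atom', 'edge', or 'system' for a field name."""
    if name in default_atom_keys:
        return "atom"
    if name in default_edge_keys:
        return "edge"
    # Store metadata exposed by the reader wins over the fallback, so custom
    # per-atom/per-edge fields are no longer treated as system-level.
    if name in field_levels:
        return field_levels[name]
    return "system"
```

With an empty field_levels mapping (the default for readers that do not expose store metadata), the function reduces to the previous defaults-then-system behaviour.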
…roundtrip
- Batch: atom-level, edge-level, and fallback classification via field_levels
- Zarr: custom atom/edge fields survive skip_validation Dataset path
Add distributed manager DDP support
…rofiling Improve Zarr I/O profiling and batched data loading
Signed-off-by: Ying Shi Teh <yteh@nvidia.com>
ALCHEMI Toolkit Pull Request
Description
This PR adds an advanced training example for a charged MACE model and the supporting code modifications needed to train it with available ALCHEMI tools.
Note: This cannot be merged until training-epic is merged.
Type of Change
Related Issues
Changes Made
- … composition.py and strategy.py so that it can be tracked and logged. This is especially useful when the component loss weights are varied during model training.
- … Compose so callers can pass either one transform or an ordered sequence of transforms.
- … _mace_training_helpers.py with MACE-style Huber losses for energy/forces/stress, optional charge MSE, staged loss weighting, stress unit conversion, training loss logging, validation, and parameter counting.
- … _mace_models.py with builders for standard MACE and charged MACE models, including cuEquivariance config support.
Testing
- … make pytest)
- … make lint)
Checklist
Additional Notes
Tip
This repository uses Greptile, an AI code review service, to help conduct
pull request reviews. We encourage contributors to read and consider suggestions
made by Greptile, but note that human maintainers will provide the necessary
reviews for merging: Greptile's comments are not a qualitative judgement
of your code, nor is it an indication that the PR will be accepted/rejected.
We encourage the use of emoji reactions to Greptile comments, depending on
their usefulness and accuracy.