Add eval-only global mean relaxation to SingleModuleStep#1251
Conversation
Adds a Newtonian relaxation of the global mean of selected output variables toward configured target values, applied after the network call (and any global-mean-removal inverse) and before the corrector. Relaxation is gated on the network module's training flag so it is inert during training and only active during evaluation/inference. The new SingleModuleStepConfig.global_mean_relaxation field is optional (defaults to None) so existing checkpoints load unchanged. Each variable specifies a target — either a float in physical units or the literal "mean" to use the network normalizer's mean — and a timescale in steps. Each step subtracts a uniform offset of (area_weighted_mean(x) - target) / timescale_steps from the field.
The eval-only global-mean-relaxation gate previously read ``self.module.torch_module.training`` to detect eval mode, which only worked because ``SingleModuleStepper.set_eval`` iterates the step's modules and toggles their training flag. Make the contract explicit: ``StepABC`` now exposes ``train(mode=True)`` and ``eval()`` matching the ``torch.nn.Module`` API. The methods toggle a ``_training`` flag on the step and forward to ``self.modules``. The relaxation gate reads ``self._training``. ``SingleModuleStepper.set_eval``/``set_train`` now delegate to the step's new methods instead of iterating modules directly.
yyexela
left a comment
There was a problem hiding this comment.
Looks good overall, just some comments/questions/suggestions to clarify.
| from fme.core.normalizer import StandardNormalizer | ||
| from fme.core.typing_ import TensorDict, TensorMapping | ||
|
|
||
| USE_NORMALIZATION_MEAN: Literal["mean"] = "mean" |
There was a problem hiding this comment.
Question: Is this the right name to use? The name USE_NORMALIZATION_MEAN seems like it'd be a bool more than a fixed literal string.
There was a problem hiding this comment.
Suggestion: Use NORMALIZATION_MEAN_TARGET instead
| def __init__(self) -> None: | ||
| # Mirrors ``torch.nn.Module.training`` so that step-level eval/train | ||
| # state is observable without reaching into the underlying modules. | ||
| self._training: bool = True |
There was a problem hiding this comment.
Claude: This new base-class __init__ sets _training, but MockStep in
test_step_registry.py doesn't call super().__init__(), so it won't have
that attribute. Calling .train() or .eval() on a MockStep will raise
AttributeError. Fix: add super().__init__() as the first line of
MockStep.__init__.
There was a problem hiding this comment.
Alexey: If this is right, also add a test for this
| The relaxation is applied during ``.step`` in eval mode only, after | ||
| the network call (and any global-mean-removal inverse transform) and | ||
| before the corrector. For each named variable, the step subtracts a | ||
| uniform offset of ``(area_weighted_mean(x) - target) / timescale_steps`` | ||
| from the field. |
There was a problem hiding this comment.
Issue: Missing integration test for global_mean_removal + global_mean_relaxation together
Suggestion: Make sure the ordering is tested since that's a big part of how the PR says the data should flow.
| def __init__( | ||
| self, | ||
| targets: dict[str, float], | ||
| timescales: dict[str, float], | ||
| gridded_operations: GriddedOperations, | ||
| ): | ||
| if set(targets) != set(timescales): | ||
| raise ValueError( | ||
| "targets and timescales must have the same keys: " | ||
| f"targets={list(targets)}, timescales={list(timescales)}." | ||
| ) |
There was a problem hiding this comment.
Issue: This guard is never tested.
Suggestion: Looks like class GlobalMeanRelaxation only ever gets made in GlobalMeanRelaxationConfig.build(), so this guard would never trigger since in that case targets == timescales. Either test this guard or make class GlobalMeanRelaxation private (ie. class _GlobalMeanRelaxation)
| # Newtonian relaxation toward configured target global means is an | ||
| # eval-time-only adjustment; gating on the step's training flag | ||
| # keeps it from contributing to the loss during training. | ||
| if self._global_mean_relaxation is not None and not self._training: | ||
| global_mean_relaxation: GlobalMeanRelaxation | None = ( | ||
| self._global_mean_relaxation | ||
| ) | ||
| else: | ||
| global_mean_relaxation = None |
There was a problem hiding this comment.
Nitpick: This can be simplified to:
global_mean_relaxation = self._global_mean_relaxation if not self._training else None
Adds an eval-only Newtonian relaxation of the global mean of selected output variables toward configured target values, applied between any global-mean-removal inverse transform and the corrector so the corrector operates on the relaxed field. Eval/train state on the step is made explicit via new
train()/eval()methods onStepABC, matching thetorch.nn.ModuleAPI.Changes:
fme.core.step.global_mean_relaxation: newGlobalMeanRelaxationConfig,GlobalMeanRelaxationVariableConfig, andGlobalMeanRelaxation; per variable, subtracts(area_weighted_mean(x) - target) / timescale_steps.targetmay be a float or"mean"(resolved to the network normalizer's mean at build time).fme.core.step.single_module.SingleModuleStepConfig: new optionalglobal_mean_relaxationfield, validated againstout_names. Wired intostep_with_adjustmentsvia a new kwarg, gated on the step's training flag so it is inert during training.fme.core.step.step.StepABC: newtrain(mode=True)/eval()methods and a_trainingflag, mirroringtorch.nn.Module.fme.ace.stepper.single_module.SingleModuleStepper.set_eval/set_trainnow delegate to the step's methods instead of iterating modules directly.Tests added
If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated