
Add eval-only global mean relaxation to SingleModuleStep #1251

Open
mcgibbon wants to merge 2 commits into main from feature/global-mean-relaxation


@mcgibbon mcgibbon commented Jun 10, 2026

Adds an eval-only Newtonian relaxation of the global mean of selected output variables toward configured target values, applied between any global-mean-removal inverse transform and the corrector so the corrector operates on the relaxed field. Eval/train state on the step is made explicit via new train()/eval() methods on StepABC, matching the torch.nn.Module API.

Changes:

  • fme.core.step.global_mean_relaxation: new GlobalMeanRelaxationConfig, GlobalMeanRelaxationVariableConfig, and GlobalMeanRelaxation; per variable, subtracts (area_weighted_mean(x) - target) / timescale_steps. target may be a float or "mean" (resolved to the network normalizer's mean at build time).

  • fme.core.step.single_module.SingleModuleStepConfig: new optional global_mean_relaxation field, validated against out_names. Wired into step_with_adjustments via a new kwarg, gated on the step's training flag so it is inert during training.

  • fme.core.step.step.StepABC: new train(mode=True)/eval() methods and a _training flag, mirroring torch.nn.Module. fme.ace.stepper.single_module.SingleModuleStepper.set_eval/set_train now delegate to the step's methods instead of iterating modules directly.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated
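The per-variable update described above can be illustrated with a minimal pure-Python sketch. It is not the PR's implementation: `area_weighted_mean` and `relax_global_mean` here are hypothetical stand-ins operating on flat lists, and the field and cell areas are made-up numbers chosen so the arithmetic is exact.

```python
def area_weighted_mean(values, weights):
    """Weighted mean of a flattened field; weights stand in for grid-cell areas."""
    total_weight = sum(weights)
    return sum(v * w for v, w in zip(values, weights)) / total_weight


def relax_global_mean(values, weights, target, timescale_steps):
    """Subtract one uniform offset so the global mean moves toward `target`."""
    offset = (area_weighted_mean(values, weights) - target) / timescale_steps
    return [v - offset for v in values]


# Hypothetical 4-cell field with unequal cell areas (illustrative numbers only).
field = [1.0, 2.0, 3.0, 6.0]
areas = [1.0, 1.0, 2.0, 4.0]

mean_before = area_weighted_mean(field, areas)  # 33 / 8 = 4.125
relaxed = relax_global_mean(field, areas, target=0.125, timescale_steps=4)
mean_after = area_weighted_mean(relaxed, areas)  # 4.125 - (4.125 - 0.125) / 4
```

Because the offset is uniform, the spatial pattern is unchanged and only the global mean shifts. With `timescale_steps=1` the mean would reach the target in a single step; larger values close a fraction `1 / timescale_steps` of the gap per step.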

mcgibbon added 2 commits June 10, 2026 01:13
Adds a Newtonian relaxation of the global mean of selected output
variables toward configured target values, applied after the network
call (and any global-mean-removal inverse) and before the corrector.
Relaxation is gated on the network module's training flag so it is
inert during training and only active during evaluation/inference.

The new SingleModuleStepConfig.global_mean_relaxation field is
optional (defaults to None) so existing checkpoints load unchanged.
Each variable specifies a target — either a float in physical units
or the literal "mean" to use the network normalizer's mean — and a
timescale in steps. Each step subtracts a uniform offset of
(area_weighted_mean(x) - target) / timescale_steps from the field.

The eval-only global-mean-relaxation gate previously read
``self.module.torch_module.training`` to detect eval mode, which only
worked because ``SingleModuleStepper.set_eval`` iterates the step's
modules and toggles their training flag. Make the contract explicit:
``StepABC`` now exposes ``train(mode=True)`` and ``eval()`` matching
the ``torch.nn.Module`` API. The methods toggle a ``_training`` flag
on the step and forward to ``self.modules``. The relaxation gate
reads ``self._training``.

``SingleModuleStepper.set_eval``/``set_train`` now delegate to the
step's new methods instead of iterating modules directly.
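As a rough illustration of the contract described in this commit message, the sketch below mimics the `train(mode=True)`/`eval()` pattern of `torch.nn.Module` without importing torch. `FakeModule` and `SketchStep` are hypothetical stand-ins, not the classes in this PR; the only behavior assumed from the description is that the step sets a `_training` flag and forwards the mode to its modules.

```python
class FakeModule:
    """Stand-in for a torch.nn.Module; it only tracks the training flag."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self


class SketchStep:
    """Hypothetical step with an explicit train/eval state."""

    def __init__(self, modules):
        self._training = True
        self.modules = list(modules)

    def train(self, mode=True):
        # Set the step-level flag, then forward the mode to every module.
        self._training = mode
        for module in self.modules:
            module.train(mode)
        return self

    def eval(self):
        # Same convention as torch.nn.Module: eval() is train(False).
        return self.train(False)

    def relaxation_active(self):
        # The gate reads the step's own flag, not a module's flag.
        return not self._training


step = SketchStep([FakeModule(), FakeModule()])
active_in_train = step.relaxation_active()
step.eval()
active_in_eval = step.relaxation_active()
```

Reading the step's own flag means the gate no longer depends on which module happens to be inspected, or on the step having any modules at all.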

@yyexela yyexela left a comment

Looks good overall, just some comments/questions/suggestions to clarify.

from fme.core.normalizer import StandardNormalizer
from fme.core.typing_ import TensorDict, TensorMapping

USE_NORMALIZATION_MEAN: Literal["mean"] = "mean"

Question: Is this the right name to use? USE_NORMALIZATION_MEAN reads like a boolean flag rather than a fixed literal string.

Suggestion: Use NORMALIZATION_MEAN_TARGET instead

Comment thread fme/core/step/step.py
Comment on lines +238 to +241
def __init__(self) -> None:
    # Mirrors ``torch.nn.Module.training`` so that step-level eval/train
    # state is observable without reaching into the underlying modules.
    self._training: bool = True

Claude: This new base-class __init__ sets _training, but MockStep in
test_step_registry.py doesn't call super().__init__(), so it won't have
that attribute. Calling .train() or .eval() on a MockStep will raise
AttributeError. Fix: add super().__init__() as the first line of
MockStep.__init__.

Alexey: If this is right, also add a test for it.
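A test along these lines could start from the sketch below. All class names are hypothetical stand-ins for `StepABC` and `MockStep`. Note one assumption that matters: whether `.train()`/`.eval()` themselves raise depends on whether they read `_training` before assigning it, which this thread does not show. In the sketch they only assign, so the failure appears on the first read instead.

```python
class BaseStep:
    """Hypothetical base class that sets the flag in __init__."""

    def __init__(self):
        self._training = True

    def train(self, mode=True):
        # Plain assignment: creates the attribute if it is missing.
        self._training = mode
        return self

    def eval(self):
        return self.train(False)

    def is_training(self):
        # Read: fails if __init__ never ran and train() was never called.
        return self._training


class MockStepWithoutSuper(BaseStep):
    def __init__(self):
        self.name = "mock"  # no super().__init__(), so no _training yet


class MockStepWithSuper(BaseStep):
    def __init__(self):
        super().__init__()
        self.name = "mock"


def read_fails(step):
    """Return True if reading the training flag raises AttributeError."""
    try:
        step.is_training()
    except AttributeError:
        return True
    return False


fails_without_super = read_fails(MockStepWithoutSuper())
fails_with_super = read_fails(MockStepWithSuper())
```

Either way, calling `super().__init__()` first in the subclass constructor removes the dependence on call order.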

Comment on lines +49 to +53
The relaxation is applied during ``.step`` in eval mode only, after
the network call (and any global-mean-removal inverse transform) and
before the corrector. For each named variable, the step subtracts a
uniform offset of ``(area_weighted_mean(x) - target) / timescale_steps``
from the field.

Issue: Missing integration test for global_mean_removal + global_mean_relaxation together

Suggestion: Make sure the ordering is tested since that's a big part of how the PR says the data should flow.
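One possible shape for such an ordering test is sketched below. Everything in it is hypothetical: `sketch_step` stands in for the real step pipeline, the mean is an unweighted average rather than an area-weighted one, and `RecordingCorrector` simply captures what it receives. The point is only that the two orderings give different fields, so a test can pin down the intended one.

```python
def mean(values):
    return sum(values) / len(values)


def restore_global_mean(values, removed_mean):
    """Stand-in for the global-mean-removal inverse transform."""
    return [v + removed_mean for v in values]


def relax(values, target, timescale_steps):
    offset = (mean(values) - target) / timescale_steps
    return [v - offset for v in values]


class RecordingCorrector:
    """Identity corrector that remembers the field it was given."""

    def __init__(self):
        self.seen = None

    def __call__(self, values):
        self.seen = list(values)
        return values


def sketch_step(anomaly, removed_mean, target, timescale_steps, corrector):
    # Order described in the PR: inverse transform, relaxation, corrector.
    field = restore_global_mean(anomaly, removed_mean)
    field = relax(field, target, timescale_steps)
    return corrector(field)


corrector = RecordingCorrector()
output = sketch_step([-1.0, 1.0], 10.0, target=6.0, timescale_steps=2,
                     corrector=corrector)

# Swapped order for comparison: relaxing the anomaly before restoring the mean.
swapped = restore_global_mean(relax([-1.0, 1.0], 6.0, 2), 10.0)
```

Asserting on `corrector.seen` checks that the corrector operated on the relaxed field, and comparing against `swapped` guards against the relaxation running before the inverse transform.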

Comment on lines +120 to +130
def __init__(
    self,
    targets: dict[str, float],
    timescales: dict[str, float],
    gridded_operations: GriddedOperations,
):
    if set(targets) != set(timescales):
        raise ValueError(
            "targets and timescales must have the same keys: "
            f"targets={list(targets)}, timescales={list(timescales)}."
        )

Issue: This guard is never tested.

Suggestion: Looks like class GlobalMeanRelaxation is only ever constructed in GlobalMeanRelaxationConfig.build(), where targets and timescales are built from the same variables, so their keys always match and this guard can never trigger. Either test this guard or make class GlobalMeanRelaxation private (i.e. class _GlobalMeanRelaxation).
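If the guard is kept and tested, the test could look roughly like the sketch below. `RelaxationSketch` is a hypothetical stand-in that reproduces only the key check quoted above (the real constructor also takes `gridded_operations`), and the variable names `"a"` and `"b"` are placeholders. In the repository's test suite the same check would more naturally be written with `pytest.raises(ValueError)`.

```python
class RelaxationSketch:
    """Stand-in reproducing only the key-matching guard."""

    def __init__(self, targets, timescales):
        if set(targets) != set(timescales):
            raise ValueError(
                "targets and timescales must have the same keys: "
                f"targets={list(targets)}, timescales={list(timescales)}."
            )
        self.targets = dict(targets)
        self.timescales = dict(timescales)


def raises_value_error(targets, timescales):
    """Return True if construction is rejected by the guard."""
    try:
        RelaxationSketch(targets, timescales)
    except ValueError:
        return True
    return False


mismatched = raises_value_error({"a": 1.0}, {"b": 10.0})
missing_key = raises_value_error({"a": 1.0, "b": 2.0}, {"a": 10.0})
matched = raises_value_error({"a": 1.0}, {"a": 10.0})
```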

Comment on lines +424 to +432
# Newtonian relaxation toward configured target global means is an
# eval-time-only adjustment; gating on the step's training flag
# keeps it from contributing to the loss during training.
if self._global_mean_relaxation is not None and not self._training:
    global_mean_relaxation: GlobalMeanRelaxation | None = (
        self._global_mean_relaxation
    )
else:
    global_mean_relaxation = None

Nitpick: This can be simplified to:

global_mean_relaxation = self._global_mean_relaxation if not self._training else None

Pretty ugly otherwise
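The two forms can be checked for equivalence over all four combinations of "relaxation configured or not" and "training or not". The sketch below uses free functions and a sentinel object in place of the real attributes, so the names are illustrative only.

```python
def gate_original(relaxation, training):
    """Branching form from the diff."""
    if relaxation is not None and not training:
        selected = relaxation
    else:
        selected = None
    return selected


def gate_simplified(relaxation, training):
    """One-line form from the nitpick."""
    return relaxation if not training else None


sentinel = object()  # stands in for a configured relaxation object
cases = [(r, t) for r in (None, sentinel) for t in (True, False)]
results = [(gate_original(r, t), gate_simplified(r, t)) for r, t in cases]
all_equal = all(a is b for a, b in results)
```

The `is not None` check in the original is redundant because returning `None` when the attribute is already `None` gives the same result in both branches.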
