Skip to content

Feature/uncorrected inference metrics [don't merge]#1222

Draft
jpdunc23 wants to merge 14 commits into
mainfrom
feature/uncorrected-inference-metrics
Draft

Feature/uncorrected inference metrics [don't merge]#1222
jpdunc23 wants to merge 14 commits into
mainfrom
feature/uncorrected-inference-metrics

Conversation

@jpdunc23

@jpdunc23 jpdunc23 commented Jun 3, 2026

Copy link
Copy Markdown
Member

Short description of why the PR is needed and how it satisfies those requirements, in sentence form.

Changes:

  • symbol (e.g. fme.core.my_function) or script and concise description of changes or added feature

  • Can group multiple related symbols on a single bullet

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Resolves # (delete if none)

jpdunc23 and others added 14 commits May 28, 2026 19:27
CorrectorABC.__call__ now returns a CorrectionResult(corrected, before)
instead of a plain TensorDict. The `before` field holds the pre-correction
values of exactly the variables the corrector modified, detected by tensor
identity (correctors apply changes out-of-place). This is a prerequisite for
computing metrics on uncorrected model outputs to quantify corrector reliance.

Behavior-preserving: the single production call site in step_with_adjustments
unpacks .corrected. Adds a shared captured_before helper and per-corrector
tests for the before-capture semantics.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
StepABC.step now returns StepOutput(output, uncorrected) instead of a plain
TensorDict. The `uncorrected` field is a sparse, detached snapshot of the
pre-correction values of exactly the variables the corrector modified (sourced
from CorrectionResult.before). This will let inference evaluate uncorrected
outputs to quantify corrector reliance.

step_with_adjustments builds the shadow at the corrector boundary; ocean and
prescribed-prognostic adjustments run after and are intentionally excluded.
All step implementations (single/secondary/radiation/fcn3) inherit the new
return type; MultiCallStep composes its wrapped step's shadow and MultiCall
unpacks .output. The rollout in predict_generator always feeds the corrected
output forward; Stepper.step, TrainStepper._accumulate_loss and the coupled
stepper unpack .output to preserve existing behavior.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Stepper.predict gains an opt-in return_uncorrected flag (via @overload) that
additionally returns a BatchData holding the pre-correction values of
corrector-modified variables, sharing the prediction's time coordinate and
with no derived variables computed. predict_paired gains the same flag and
carries the shadow in a new optional PairedData.uncorrected_prediction field,
paired against the same reference; this rides through the generic inference
Looper unchanged so existing callers and the default path are untouched.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Adds compute_uncorrected_metrics (default off) to
InferenceEvaluatorAggregatorConfig. When enabled, the evaluator builds a
second, time-mean-only aggregator and wraps both in
InferenceEvaluatorAggregatorWithUncorrected, which records the stepper's
uncorrected (pre-correction) outputs from PairedData.uncorrected_prediction and
logs them under an "uncorrected/" prefix (e.g.
inference/uncorrected/time_mean/rmse/<var>), with diagnostics written to an
uncorrected/ subdirectory. The evaluator opts the stepper into producing the
shadow via predict_paired(return_uncorrected=True). This quantifies how much
the stepper relies on its corrector for skill. No effect when disabled or when
the stepper has no corrector.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Flip the default on so uncorrected (pre-correction) inference metrics are
computed by default. Guard the composite aggregator so the shadow is
summarized and flushed only when it actually recorded a batch, since a
corrector-less stepper produces an empty shadow whose aggregators would
otherwise raise "No data recorded." on flush.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Resolve conflicts from main's configurable global-mean-removal /
LossSchedule refactor overlapping this branch's StepOutput and
uncorrected-shadow plumbing.

- single_module.py train-loss loop: keep this branch's evolved loop
  (StepOutput capture + optimize_precorrected loss_gen) as a superset of
  main's renamed loop. Adopt main's LossSchedule-driven n_loss_steps and
  grad_context, and feed loss_gen into main's extracted
  _accumulate_step_loss helper. No main-only logic dropped.
- single_module.py low-level step: union already auto-merged — keeps the
  new global_mean_removal/data_mask plumbing AND the StepOutput
  uncorrected capture / output-process-func handling.
- test_train.py: keep both newly added tests — this branch's
  test_inline_inference_logs_uncorrected_metrics and main's skipif(cuda)
  decorator on the insolation parametrized test.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Move the prefix-aware name/prefix matching logic out of
StaticSpatialMasking into a new shared NameAndPrefixMatcher in
fme/core/name_and_prefix_matcher.py, exposing a single matches(name) ->
bool method. The matching semantics are preserved exactly: a bare name
matches the 2D variable and its 3D levels, a trailing-underscore prefix
matches all levels, and an explicit name_<level> matches exactly.

Spatial masking now constructs and uses NameAndPrefixMatcher in place of
its private _build_regex. Pure refactor, no behavior change.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replace the all-or-nothing TrainStepperConfig.optimize_precorrected bool
with an optional precorrector_optimization: PreCorrectorOptimizationConfig.
Presence of the config enables pre-corrector optimization; its
exclude_names_and_prefixes list excludes specific variables (using the
shared NameAndPrefixMatcher convention) so they are optimized against
their corrected values while other corrector-modified variables stay
pre-corrected. This keeps the sea-ice force-positive correction in the
loss (removing a mid-latitude positive sea-ice bias) while retaining the
pre-corrector temperature gains.

The config owns the matcher and exposes select_precorrected; the stepper
never constructs or holds it. detach_uncorrected is driven by
precorrector_optimization is None. Because the corrector-modified
variable set is only known at runtime, the "exclusion matched zero
variables" check is a warn-once on the first training step. Rollout state
and returned predictions are unchanged (always fully corrected).

Migrate the ocean-train integration fixture to the new config, excluding
sea_ice_fraction.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Comment thread fme/ace/step/fcn3.py

Returns:
The denormalized output data at the next time step.
The output at the next timestep and the pre-correction values of any

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's be generic here and just say it's the output including next-timestep data, so this can't go out of date with further updates to StepOutput.

Comment thread fme/core/step/step.py
"""

output: TensorDict
uncorrected: TensorDict

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as you put this as corrections: TensorDict | None (or the singular correction) in your re-implementation this should be fine.

The | None is important to support steps that don't define this.

@mcgibbon mcgibbon left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had a slack discussion and converged on reporting metrics on the corrections instead of the pre-corrected state, since it's hard to derive the mean effect of the corrector and the variability of that effect based on metrics like RMSE applied to pre-corrector fields. Also talked about having a Loss like object that takes in StepOutput, and being able to swap it between computing regularization on the corrections, and applying the loss to de-corrected state (would be nice to be able to configure both options).

jpdunc23 added a commit that referenced this pull request Jun 17, 2026
…ection values

StepABC.step now returns a StepOutput(output, stepper_state, uncorrected)
dataclass instead of a tuple[TensorDict, StepperState | None]. The new
uncorrected field is a sparse, detached snapshot of the pre-correction values
of exactly the variables a corrector modified, so downstream features can
derive the correction (output - uncorrected) or use the raw pre-correction
values without re-running the model. It is an empty dict when no corrector ran,
so consumers need no None checks. stepper_state keeps its existing
passthrough semantics.

step_with_adjustments captures the shadow at the corrector boundary via a new
captured_before helper (tensor-identity detection of out-of-place edits),
detaching unconditionally; ocean and prescribed-prognostic adjustments run
after the corrector and are intentionally excluded. The corrector ABC is left
unchanged. All step implementations (single/secondary/radiation/fcn3) inherit
the new return type; MultiCallStep composes its wrapped step's shadow and the
MultiCall helper returns an empty shadow. The rollout in predict_generator
always feeds the corrected output forward as state; Stepper.step applies the
name-preserving output process func to the shadow too.

This PR is pure StepOutput-through-step plumbing: the per-step
StepOutput.uncorrected is computed at the corrector boundary but discarded at
the Stepper.predict boundary, so predict returns its existing corrected-only
BatchData and no BatchData/PairedData surface changes. Carrying the
uncorrected series on the prediction is deferred to the correction-metrics PR
(#1284), which introduces an encapsulated, time-aware container for it.

Pure plumbing: no user-visible behavior change, and existing checkpoints load
unchanged. Adds step- and stepper-seam tests plus captured_before unit tests;
the spatial-parallel step regression matrix passes unchanged under torchrun.

Part of #1271 (PR 1 of the #1218/#1222 split).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
jpdunc23 added a commit that referenced this pull request Jun 17, 2026
Adds normalized-space metrics of the corrector's correction
(output - uncorrected) to the inference aggregators, plus an optional
denormalized correction netCDF, on by default behind aggregator config flags.
This is PR 2 of the 3-PR split of #1218/#1222 and builds on the StepOutput
plumbing from #1271.

Carriage: the pre-correction ``uncorrected`` series is carried from
Stepper.predict to the consumers through a new opaque, time-aware container,
StepDiagnostics (fme/ace/data_loading/step_diagnostics.py), instead of a raw
public field on BatchData. It follows the StepperState encapsulation pattern:
BatchData and PairedData each hold a single opaque step_diagnostics field
(default None) and never inspect its contents. Unlike StepperState (terminal
per-sample state), the payload is a per-timestep diagnostic series, so the
container is time-aware: it is forwarded by reference through every
structure-preserving method and time-sliced/padded alongside data by the
time-touching ones (select_time_slice, remove_initial_condition,
get_start/get_end, prepend), scattered by scatter_spatial, broadcast by
broadcast_ensemble, and moved by to_device/to_cpu/pin_memory. __post_init__
validates its leading sample dim like stepper_state. This fixes the
silently-dropped path the reviewer flagged on #1283: compute_derived_variables
and PairedData.from_batch_data now preserve the series, so it survives the real
inference loop. Stepper.predict builds the container from the stacked per-step
StepOutput shadows and attaches it to the prediction. The correction aggregator
and netCDF writer reach into it via the single get_uncorrected accessor.

Metrics (computed as normalize(output) - normalize(uncorrected) per corrected
key, using the network normalizer the existing *_norm metrics use):

- inference/time_mean_norm/correction_magnitude/{var}: area-weighted global mean
  of the time-mean of |normalized correction|, plus a channel_mean over the
  corrected variables only.
- inference/time_mean_norm/correction_map/{var}: signed time-mean map, logged as
  an image and flushed to time_mean_norm_correction_diagnostics.nc.
- inference/mean_norm/weighted_correction_magnitude/{var}: per-step area-weighted
  global mean of |normalized correction|.
- inference/mean_norm/weighted_correction_std/{var}: per-step area-weighted
  spatial std of the signed normalized correction (mirrors weighted_std_gen).

These live in a new fme/ace/aggregator/inference/correction.py with dedicated
CorrectionTimeMeanAggregator / CorrectionMeanAggregator and a CorrectionRecorder
shared by both inference aggregators. They are kept in a separate group merged
into the existing time_mean_norm / mean_norm label groups, so the time-series
table uses a distinct "correction_series" key that to_inference_logs resolves to
the same prefix without colliding with the main series table.

Availability and gating:

- Time-mean metrics in all inference types; time-series metrics only in
  standalone evaluator and no-target inference (inline training drops them via
  the existing enable_time_series path).
- The no-target inference aggregator now receives the stepper's network
  normalizer (plumbed through InferenceAggregatorConfig.build and the inference
  job), introducing mean_norm / time_mean_norm groups there containing only
  correction metrics. Correction metrics are skipped when no normalizer is
  available, preserving backward compatibility for callers that omit it.
- log_correction_metrics: bool = True on both the evaluator and no-target
  aggregator configs. No effect when the stepper has no corrector: the
  container's uncorrected mapping is empty and the correction aggregators stay
  silent.

Disk output:

- save_correction_files: bool = False on DataWriterConfig writes
  autoregressive_corrections.nc with the denormalized correction time series
  (output - uncorrected, physical units, with variable metadata) for the sparse
  corrected variables, respecting the save-names subset and time-coarsening, via
  a single-source RawDataWriter in PairedDataWriter.

The uncorrected/-prefixed error metrics from #1222 are intentionally dropped.

Adds a shared parametrized round-trip test asserting the container survives (and
stays time-aligned through) every structure-preserving method on BatchData and
PairedData, so a future method that forgets to thread it fails CI; aggregator
unit tests asserting exact magnitude/std/map/channel_mean values for a
constant-offset correction and the flag-off/no-corrector silence paths; writer
tests for the sparse denormalized file (incl. time-coarsening); config
validation/defaults tests; and an end-to-end train+inference test asserting
time-mean correction metrics on the inline inference-loop path (series dropped),
per-step series in standalone inference, and the corrections netCDF.

Part of #1272 (PR 2 of the #1218/#1222 split).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants