
Add correction inference metrics and optional correction netCDF output#1284

Open
jpdunc23 wants to merge 3 commits into feature/step-output from feature/correction-metrics


@jpdunc23 jpdunc23 commented Jun 16, 2026

Member

Steppers can apply correctors (e.g. ocean heat content, global dryness) that adjust the network's raw output each step, but today there is no way to see how much a trained model relies on its corrector. This PR adds normalized-space metrics of the correction (output − uncorrected) to the inference aggregators, on by default behind aggregator config flags, and an optional denormalized correction netCDF, off by default behind a data-writer config flag. It builds on the StepOutput.uncorrected plumbing from #1283 (PR1, base branch) and additionally carries that shadow up to the aggregators on a time-aware diagnostics carriage. The uncorrected/-prefixed error metrics from #1222 are intentionally dropped in favor of metrics of the correction itself.

Per review on #1283, the consumer-facing plumbing that surfaces the pre-correction shadow on prediction data was moved out of that PR into this one, and reworked into an opaque, encapsulated carriage (rather than a raw uncorrected field) so it cannot be silently dropped by BatchData/PairedData methods.

Changes:

  • fme.ace.data_loading.step_diagnostics.StepDiagnostics (new) — an opaque, time-aware carriage holding the per-step uncorrected series (today the only diagnostic). It validates its leading sample dim and exposes the structure-preserving operations (device movement, time slicing, ensemble broadcast, spatial scatter, pin-memory) so the container can be forwarded and time-sliced alongside data without BatchData inspecting its contents.
  • fme.ace.data_loading.batch_data: BatchData.step_diagnostics / PairedData.step_diagnostics optional fields (following the data_mask/stepper_state precedent), threaded through every structure-preserving method via the container protocol and time-sliced by the time-touching methods, so compute_derived_variables and friends cannot silently drop or misalign it.
  • fme.ace.stepper.single_module.Stepper.predict — stacks the per-step StepOutput.uncorrected shadows into a forward-step-aligned series, wraps it in StepDiagnostics, and attaches it to the final prediction after the IC-windowing pipeline.
  • fme.ace.aggregator.inference.correction (new) — CorrectionTimeMeanAggregator (time_mean_norm/correction_magnitude/{var} area-weighted global mean of the time-mean of |normalized correction|, a channel_mean over corrected variables only, and the signed time_mean_norm/correction_map/{var} image), CorrectionMeanAggregator (mean_norm/weighted_correction_magnitude/{var} and mean_norm/weighted_correction_std/{var} per-step series), and a shared CorrectionRecorder. Corrections are normalize(output) − normalize(uncorrected) per corrected key using the network normalizer.
  • fme.ace.aggregator.inference.main: adds log_correction_metrics: bool = True on both InferenceEvaluatorAggregatorConfig and InferenceAggregatorConfig; the recorder is wired into both aggregators' record/summary/time-series/diagnostics paths. Correction metrics merge into the existing time_mean_norm / mean_norm groups (and are the only members of those groups for no-target inference). The no-target aggregator now accepts the stepper's network normalizer; correction metrics are skipped when no normalizer is supplied (backward compatible).
  • fme.ace.inference.inference — passes stepper.normalizer.normalize into the no-target aggregator build.
  • fme.ace.inference.data_writer.main: a new save_correction_files: bool = False option on DataWriterConfig writes autoregressive_corrections.nc with the denormalized correction time series for the sparse corrected variables, respecting the save-names subset and time-coarsening, via a single-source writer in PairedDataWriter.
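
The stacking step described for Stepper.predict can be pictured with a minimal sketch. It assumes NumPy arrays in place of torch tensors, a [sample, time, lat, lon] layout, a hypothetical helper name `stack_uncorrected`, and an illustrative variable name `ohc`; none of these are taken from the PR's code.

```python
import numpy as np


def stack_uncorrected(per_step):
    """Stack per-step sparse shadows into one series per corrected variable.

    per_step: list of dicts, one per forward step, each mapping a variable
    name to an array of shape [sample, lat, lon]. This is a hypothetical
    stand-in for the list of StepOutput.uncorrected mappings.
    """
    if not per_step:
        return {}
    names = per_step[0].keys()
    # Assumption: the corrector touches the same variables at every step.
    for step in per_step:
        if step.keys() != names:
            raise ValueError("corrected variables changed between steps")
    # New time axis at dim 1, so index t lines up with forward step t.
    return {name: np.stack([s[name] for s in per_step], axis=1) for name in names}


# Three forward steps, two samples, a 4x8 grid; step t holds the value t.
steps = [{"ohc": np.full((2, 4, 8), float(t))} for t in range(3)]
series = stack_uncorrected(steps)
```

A stepper without a corrector would produce empty per-step mappings, and the sketch then returns an empty dict, which matches the "empty when no corrector ran" behavior described for the shadow.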

Availability: time-mean metrics in all inference types; time-series metrics only in standalone evaluator and no-target inference (inline training drops them via the existing enable_time_series path). Defaults preserve current behavior except the newly logged metrics for corrector-equipped runs; correction-file saving is off by default.

  • Tests added
  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Tests: test_batch_data parametrized tests asserting step_diagnostics survives every structure-preserving method (including compute_derived_variables) on both BatchData and PairedData, the absent/None pass-through, post-init sample-dim validation, scatter, and pin-memory; aggregator unit tests asserting exact magnitude/std/map/channel_mean values for a constant-offset correction plus the flag-off and no-corrector silence paths; writer tests for the sparse denormalized file including time-coarsening and the empty-without-corrector case; config validation/defaults tests; and an end-to-end train+inference test (test_train_and_inference_correction_metrics) asserting time-mean correction metrics appear in inline training inference with the per-step series dropped, the per-step series appear in standalone inference, and autoregressive_corrections.nc is written.

Stacked on #1283 (PR1) — base branch is feature/step-output; retarget to main once that merges.

Resolves #1272

🤖 Generated with Claude Code

Adds normalized-space metrics of the corrector's correction
(output - uncorrected) to the inference aggregators, on by default behind
aggregator config flags, plus an optional denormalized correction netCDF that
is off by default. This is PR 2 of the 3-PR split of #1218/#1222 and builds on
the StepOutput plumbing from #1271.

Carriage: the pre-correction ``uncorrected`` series is carried from
Stepper.predict to the consumers through a new opaque, time-aware container,
StepDiagnostics (fme/ace/data_loading/step_diagnostics.py), instead of a raw
public field on BatchData. It follows the StepperState encapsulation pattern:
BatchData and PairedData each hold a single opaque step_diagnostics field
(default None) and never inspect its contents. Unlike StepperState (terminal
per-sample state), the payload is a per-timestep diagnostic series, so the
container is time-aware: it is forwarded by reference through every
structure-preserving method and time-sliced/padded alongside data by the
time-touching ones (select_time_slice, remove_initial_condition,
get_start/get_end, prepend), scattered by scatter_spatial, broadcast by
broadcast_ensemble, and moved by to_device/to_cpu/pin_memory. __post_init__
validates its leading sample dim like stepper_state. This fixes the
silently-dropped path the reviewer flagged on #1283: compute_derived_variables
and PairedData.from_batch_data now preserve the series, so it survives the real
inference loop. Stepper.predict builds the container from the stacked per-step
StepOutput shadows and attaches it to the prediction. The correction aggregator
and netCDF writer reach into it via the single get_uncorrected accessor.
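
The encapsulation pattern above can be sketched in a few lines. This is a minimal illustration, not the PR's implementation: `DiagnosticsSketch` and `BatchSketch` are hypothetical stand-ins for StepDiagnostics and BatchData, nested Python lists replace tensors, and only one time-touching method is shown.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class DiagnosticsSketch:
    """Hypothetical stand-in for StepDiagnostics.

    uncorrected maps variable name -> values indexed [sample][time].
    """

    uncorrected: dict

    def select_time_slice(self, time_slice):
        # Slice the time axis of every series; the sample axis is untouched.
        return DiagnosticsSketch(
            {k: [row[time_slice] for row in v] for k, v in self.uncorrected.items()}
        )

    def get_uncorrected(self):
        # Single accessor for consumers such as aggregators and writers.
        return self.uncorrected


@dataclasses.dataclass
class BatchSketch:
    """Hypothetical stand-in for BatchData with an opaque diagnostics field."""

    data: dict
    step_diagnostics: Optional[DiagnosticsSketch] = None

    def select_time_slice(self, time_slice):
        # Delegate to the container rather than inspecting its contents.
        diag = self.step_diagnostics
        return BatchSketch(
            data={k: [row[time_slice] for row in v] for k, v in self.data.items()},
            step_diagnostics=None if diag is None else diag.select_time_slice(time_slice),
        )


batch = BatchSketch(
    data={"sst": [[0, 1, 2, 3]]},
    step_diagnostics=DiagnosticsSketch({"sst": [[10, 11, 12, 13]]}),
)
sliced = batch.select_time_slice(slice(1, 3))
```

Because the batch only calls methods on the container, a new diagnostic could be added inside the container without changing the batch class, which is the motivation the text gives for keeping the field opaque.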

Metrics (computed as normalize(output) - normalize(uncorrected) per corrected
key, using the network normalizer the existing *_norm metrics use):

- inference/time_mean_norm/correction_magnitude/{var}: area-weighted global mean
  of the time-mean of |normalized correction|, plus a channel_mean over the
  corrected variables only.
- inference/time_mean_norm/correction_map/{var}: signed time-mean map, logged as
  an image and flushed to time_mean_norm_correction_diagnostics.nc.
- inference/mean_norm/weighted_correction_magnitude/{var}: per-step area-weighted
  global mean of |normalized correction|.
- inference/mean_norm/weighted_correction_std/{var}: per-step area-weighted
  spatial std of the signed normalized correction (mirrors weighted_std_gen).
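
The four metrics above can be sketched with NumPy as follows. The sketch assumes a scalar mean/std normalizer, a [sample, time, lat, lon] layout, and that sample averaging happens together with the time mean; the function names are hypothetical and the PR's aggregators may reduce in a different order.

```python
import numpy as np


def normalized_correction(output, uncorrected, mean, std):
    # normalize(output) - normalize(uncorrected). For a mean/std normalizer
    # the mean cancels, but both terms are written out to mirror the text.
    return (output - mean) / std - (uncorrected - mean) / std


def area_weighted_mean(field, area):
    # Weighted mean over the trailing (lat, lon) dims.
    return (field * area).sum(axis=(-2, -1)) / area.sum()


def correction_metrics(correction, area):
    """correction: [sample, time, lat, lon]; area: [lat, lon] weights."""
    # correction_magnitude: global mean of the time-mean of |correction|.
    time_mean_abs = np.abs(correction).mean(axis=(0, 1))
    magnitude = area_weighted_mean(time_mean_abs, area)
    # correction_map: signed time-mean map of shape [lat, lon].
    correction_map = correction.mean(axis=(0, 1))
    # weighted_correction_magnitude: per-step global mean of |correction|.
    step_abs = area_weighted_mean(np.abs(correction), area).mean(axis=0)
    # weighted_correction_std: per-step area-weighted spatial std (signed).
    mu = area_weighted_mean(correction, area)
    var = area_weighted_mean((correction - mu[..., None, None]) ** 2, area)
    step_std = np.sqrt(var).mean(axis=0)
    return magnitude, correction_map, step_abs, step_std


# Constant-offset correction: every grid cell is shifted by 0.5 before
# normalization, i.e. by 0.5 / std = 0.25 in normalized space.
area = np.array([1.0, 3.0])[:, None] * np.ones((2, 4))
uncorrected = np.zeros((1, 3, 2, 4))
output = uncorrected + 0.5
corr = normalized_correction(output, uncorrected, mean=1.0, std=2.0)
magnitude, correction_map, step_abs, step_std = correction_metrics(corr, area)
```

With a spatially constant correction the magnitude equals the offset in normalized units and the spatial std is zero, which is the kind of closed-form case the aggregator unit tests are described as checking.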

These live in a new fme/ace/aggregator/inference/correction.py with dedicated
CorrectionTimeMeanAggregator / CorrectionMeanAggregator and a CorrectionRecorder
shared by both inference aggregators. They are kept in a separate group merged
into the existing time_mean_norm / mean_norm label groups, so the time-series
table uses a distinct "correction_series" key that to_inference_logs resolves to
the same prefix without colliding with the main series table.

Availability and gating:

- Time-mean metrics in all inference types; time-series metrics only in
  standalone evaluator and no-target inference (inline training drops them via
  the existing enable_time_series path).
- The no-target inference aggregator now receives the stepper's network
  normalizer (plumbed through InferenceAggregatorConfig.build and the inference
  job), introducing mean_norm / time_mean_norm groups there containing only
  correction metrics. Correction metrics are skipped when no normalizer is
  available, preserving backward compatibility for callers that omit it.
- log_correction_metrics: bool = True on both the evaluator and no-target
  aggregator configs. No effect when the stepper has no corrector: the
  container's uncorrected mapping is empty and the correction aggregators stay
  silent.

Disk output:

- save_correction_files: bool = False on DataWriterConfig writes
  autoregressive_corrections.nc with the denormalized correction time series
  (output - uncorrected, physical units, with variable metadata) for the sparse
  corrected variables, respecting the save-names subset and time-coarsening, via
  a single-source RawDataWriter in PairedDataWriter.

The uncorrected/-prefixed error metrics from #1222 are intentionally dropped.

Adds a shared parametrized round-trip test asserting the container survives (and
stays time-aligned through) every structure-preserving method on BatchData and
PairedData, so a future method that forgets to thread it fails CI; aggregator
unit tests asserting exact magnitude/std/map/channel_mean values for a
constant-offset correction and the flag-off/no-corrector silence paths; writer
tests for the sparse denormalized file (incl. time-coarsening); config
validation/defaults tests; and an end-to-end train+inference test asserting
time-mean correction metrics on the inline inference-loop path (series dropped),
per-step series in standalone inference, and the corrections netCDF.

Part of #1272 (PR 2 of the #1218/#1222 split).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
jpdunc23 added a commit that referenced this pull request Jun 17, 2026
…ection values

StepABC.step now returns a StepOutput(output, stepper_state, uncorrected)
dataclass instead of a tuple[TensorDict, StepperState | None]. The new
uncorrected field is a sparse, detached snapshot of the pre-correction values
of exactly the variables a corrector modified, so downstream features can
derive the correction (output - uncorrected) or use the raw pre-correction
values without re-running the model. It is an empty dict when no corrector ran,
so consumers need no None checks. stepper_state keeps its existing
passthrough semantics.
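
A minimal sketch of the return type described above, assuming plain floats in place of tensors. `StepOutputSketch` and `correction` are hypothetical names, and the field defaults shown here are an illustration rather than the actual dataclass definition.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class StepOutputSketch:
    """Hypothetical stand-in for StepOutput; floats replace tensors."""

    output: dict
    stepper_state: Optional[Any] = None
    # Sparse: only variables a corrector modified. Empty when none ran.
    uncorrected: dict = dataclasses.field(default_factory=dict)


def correction(step):
    # Iterating an empty mapping yields nothing, so no None check is needed.
    return {name: step.output[name] - before for name, before in step.uncorrected.items()}


no_corrector = StepOutputSketch(output={"t": 1.0})
with_corrector = StepOutputSketch(output={"t": 1.0, "q": 0.5}, uncorrected={"q": 0.75})
```

Using an empty dict rather than None keeps every consumer on a single code path, at the cost of not distinguishing "no corrector configured" from "corrector ran but changed nothing".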

step_with_adjustments captures the shadow at the corrector boundary via a new
captured_before helper (tensor-identity detection of out-of-place edits),
detaching unconditionally; ocean and prescribed-prognostic adjustments run
after the corrector and are intentionally excluded. The corrector ABC is left
unchanged. All step implementations (single/secondary/radiation/fcn3) inherit
the new return type; MultiCallStep composes its wrapped step's shadow and the
MultiCall helper returns an empty shadow. The rollout in predict_generator
always feeds the corrected output forward as state; Stepper.step applies the
name-preserving output process func to the shadow too.
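
One possible reading of the identity-based capture, as a sketch only. `captured_before_sketch` and `toy_corrector` are hypothetical, lists stand in for tensors, and a list copy is a loose stand-in for detaching; the real helper's signature is not shown in this message.

```python
def captured_before_sketch(before, after):
    """Return pre-correction values for entries the corrector replaced.

    A corrector that edits out-of-place returns a new object under the same
    key, so an `is not` check flags it. In-place edits would not be detected
    by this check.
    """
    return {
        name: list(value)  # copy as a loose stand-in for tensor.detach()
        for name, value in before.items()
        if name in after and after[name] is not value
    }


def toy_corrector(data):
    # Out-of-place: builds a new list for "q", passes "t" through untouched.
    out = dict(data)
    out["q"] = [max(x, 0.0) for x in data["q"]]
    return out


raw = {"t": [280.0, 281.0], "q": [-0.1, 0.2]}
corrected = toy_corrector(raw)
shadow = captured_before_sketch(raw, corrected)
```

The shadow holds only "q", so the result is sparse by construction, and the corrector itself needs no changes to report which variables it touched.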

This PR is pure StepOutput-through-step plumbing: the per-step
StepOutput.uncorrected is computed at the corrector boundary but discarded at
the Stepper.predict boundary, so predict returns its existing corrected-only
BatchData and no BatchData/PairedData surface changes. Carrying the
uncorrected series on the prediction is deferred to the correction-metrics PR
(#1284), which introduces an encapsulated, time-aware container for it.

Pure plumbing: no user-visible behavior change, and existing checkpoints load
unchanged. Adds step- and stepper-seam tests plus captured_before unit tests;
the spatial-parallel step regression matrix passes unchanged under torchrun.

Part of #1271 (PR 1 of the #1218/#1222 split).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jpdunc23 jpdunc23 force-pushed the feature/step-output branch from d6e1d56 to cb8029d Compare June 17, 2026 23:39
@jpdunc23 jpdunc23 force-pushed the feature/correction-metrics branch from 07cb0f4 to 1a13e76 Compare June 17, 2026 23:39


@dataclasses.dataclass
class StepDiagnostics:

Member Author


Agent chose StepDiagnostics as opposed to @mcgibbon's original StepMetrics, reasoning "since the payload is diagnostic input, not a metric".

from .reduced import AreaWeightedSingleTargetReducedMetric, _SeriesData, data_to_table


def compute_correction_norm(

Contributor


Issue: I think this can't distinguish between a value that receives a zero correction from the corrector and a value the corrector doesn't touch at all. If true, and we use zero as a sentinel for variables the corrector doesn't touch, this could easily lead to bugs such as mis-estimated means from dropping zero entries.
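
The concern can be illustrated with a toy example. This is not a claim about what compute_correction_norm does; the function names, variable names, and values are hypothetical and only show why a zero sentinel and explicit key membership can disagree.

```python
def mean_correction_dropping_zeros(corrections):
    # Treats 0.0 as "not corrected": a genuine zero correction is lost.
    kept = [c for c in corrections.values() if c != 0.0]
    return sum(kept) / len(kept) if kept else 0.0


def mean_correction_by_key(corrections, corrected_names):
    # Uses explicit membership, so a genuine zero correction still counts.
    kept = [corrections[name] for name in corrected_names]
    return sum(kept) / len(kept) if kept else 0.0


# "a" is corrected by exactly zero, "b" is corrected by 1.0, and "c" is
# never touched by the corrector but shows up as zero in a dense difference.
corrections = {"a": 0.0, "b": 1.0, "c": 0.0}
corrected_names = ["a", "b"]

sentinel_mean = mean_correction_dropping_zeros(corrections)
keyed_mean = mean_correction_by_key(corrections, corrected_names)
```

The two means differ because the sentinel version cannot tell "a" from "c", whereas a sparse mapping keyed by corrected variable carries that distinction explicitly.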
