Add correction inference metrics and optional correction netCDF output #1284
Open
jpdunc23 wants to merge 3 commits into
This was referenced Jun 17, 2026
Adds normalized-space metrics of the corrector's correction (output - uncorrected) to the inference aggregators, on by default behind aggregator config flags, plus an optional denormalized correction netCDF that is off by default. This is PR 2 of the 3-PR split of #1218/#1222 and builds on the StepOutput plumbing from #1271.

Carriage: the pre-correction ``uncorrected`` series is carried from Stepper.predict to the consumers through a new opaque, time-aware container, StepDiagnostics (fme/ace/data_loading/step_diagnostics.py), instead of a raw public field on BatchData. It follows the StepperState encapsulation pattern: BatchData and PairedData each hold a single opaque step_diagnostics field (default None) and never inspect its contents. Unlike StepperState (terminal per-sample state), the payload is a per-timestep diagnostic series, so the container is time-aware: it is forwarded by reference through every structure-preserving method and time-sliced/padded alongside data by the time-touching ones (select_time_slice, remove_initial_condition, get_start/get_end, prepend), scattered by scatter_spatial, broadcast by broadcast_ensemble, and moved by to_device/to_cpu/pin_memory. __post_init__ validates its leading sample dim like stepper_state.

This fixes the silently-dropped path the reviewer flagged on #1283: compute_derived_variables and PairedData.from_batch_data now preserve the series, so it survives the real inference loop. Stepper.predict builds the container from the stacked per-step StepOutput shadows and attaches it to the prediction. The correction aggregator and netCDF writer reach into it via the single get_uncorrected accessor.

Metrics (computed as normalize(output) - normalize(uncorrected) per corrected key, using the network normalizer the existing *_norm metrics use):
- inference/time_mean_norm/correction_magnitude/{var}: area-weighted global mean of the time-mean of |normalized correction|, plus a channel_mean over the corrected variables only.
- inference/time_mean_norm/correction_map/{var}: signed time-mean map, logged as an image and flushed to time_mean_norm_correction_diagnostics.nc.
- inference/mean_norm/weighted_correction_magnitude/{var}: per-step area-weighted global mean of |normalized correction|.
- inference/mean_norm/weighted_correction_std/{var}: per-step area-weighted spatial std of the signed normalized correction (mirrors weighted_std_gen).

These live in a new fme/ace/aggregator/inference/correction.py with dedicated CorrectionTimeMeanAggregator / CorrectionMeanAggregator and a CorrectionRecorder shared by both inference aggregators. They are kept in a separate group merged into the existing time_mean_norm / mean_norm label groups, so the time-series table uses a distinct "correction_series" key that to_inference_logs resolves to the same prefix without colliding with the main series table.

Availability and gating:
- Time-mean metrics in all inference types; time-series metrics only in standalone evaluator and no-target inference (inline training drops them via the existing enable_time_series path).
- The no-target inference aggregator now receives the stepper's network normalizer (plumbed through InferenceAggregatorConfig.build and the inference job), introducing mean_norm / time_mean_norm groups there containing only correction metrics. Correction metrics are skipped when no normalizer is available, preserving backward compatibility for callers that omit it.
- log_correction_metrics: bool = True on both the evaluator and no-target aggregator configs. No effect when the stepper has no corrector: the container's uncorrected mapping is empty and the correction aggregators stay silent.
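The metric definitions above can be illustrated with a minimal sketch for a single time step. It uses plain Python lists in place of tensors, and the helper names (`normalized_correction`, `weighted_magnitude`, `weighted_std`) as well as the normalizer statistics are hypothetical, not taken from correction.py. The exact form of the weighted std used by the real aggregators is an assumption here.

```python
import math
from typing import Callable, Dict, List

Field = List[float]  # one value per grid cell, flattened


def normalized_correction(
    output: Dict[str, Field],
    uncorrected: Dict[str, Field],
    normalize: Callable[[str, float], float],
) -> Dict[str, Field]:
    """normalize(output) - normalize(uncorrected), for corrected keys only."""
    return {
        name: [
            normalize(name, out) - normalize(name, unc)
            for out, unc in zip(output[name], values)
        ]
        for name, values in uncorrected.items()
    }


def weighted_mean(values: Field, weights: Field) -> float:
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


def weighted_magnitude(correction: Field, weights: Field) -> float:
    # Area-weighted global mean of |normalized correction|.
    return weighted_mean([abs(c) for c in correction], weights)


def weighted_std(correction: Field, weights: Field) -> float:
    # Area-weighted spatial std of the signed correction (assumed form).
    mean = weighted_mean(correction, weights)
    variance = weighted_mean([(c - mean) ** 2 for c in correction], weights)
    return math.sqrt(variance)


# Hypothetical per-variable normalizer statistics.
MEANS = {"T": 280.0}
STDS = {"T": 10.0}


def normalize(name: str, value: float) -> float:
    return (value - MEANS[name]) / STDS[name]


# Constant offset of +5 on "T"; "u" was not touched by the corrector.
uncorrected = {"T": [280.0, 290.0, 270.0]}
output = {"T": [285.0, 295.0, 275.0], "u": [1.0, 2.0, 3.0]}
area_weights = [1.0, 2.0, 1.0]

correction = normalized_correction(output, uncorrected, normalize)
magnitude = weighted_magnitude(correction["T"], area_weights)
spread = weighted_std(correction["T"], area_weights)
```

With a constant offset the normalized correction is the same in every cell, so the magnitude equals the offset divided by the std and the spatial std is zero, which is the kind of case the aggregator unit tests are described as checking.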
Disk output:
- save_correction_files: bool = False on DataWriterConfig writes autoregressive_corrections.nc with the denormalized correction time series (output - uncorrected, physical units, with variable metadata) for the sparse corrected variables, respecting the save-names subset and time-coarsening, via a single-source RawDataWriter in PairedDataWriter.

The uncorrected/-prefixed error metrics from #1222 are intentionally dropped.

Adds a shared parametrized round-trip test asserting the container survives (and stays time-aligned through) every structure-preserving method on BatchData and PairedData, so a future method that forgets to thread it fails CI; aggregator unit tests asserting exact magnitude/std/map/channel_mean values for a constant-offset correction and the flag-off/no-corrector silence paths; writer tests for the sparse denormalized file (incl. time-coarsening); config validation/defaults tests; and an end-to-end train+inference test asserting time-mean correction metrics on the inline inference-loop path (series dropped), per-step series in standalone inference, and the corrections netCDF.

Part of #1272 (PR 2 of the #1218/#1222 split).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
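The carriage and its round-trip test described in this commit message could look roughly like the following sketch. `ToyStepDiagnostics` and `ToyBatch` are hypothetical stand-ins, nested lists replace tensors, and only two methods are shown; the real StepDiagnostics and BatchData interfaces may differ.

```python
import dataclasses
from typing import Dict, List, Optional

Series = List[List[float]]  # indexed as [sample][time]


@dataclasses.dataclass
class ToyStepDiagnostics:
    """Opaque, time-aware holder for the per-step uncorrected series."""

    _uncorrected: Dict[str, Series]
    n_samples: int

    def __post_init__(self):
        # Validate the leading sample dimension of every series.
        for name, series in self._uncorrected.items():
            if len(series) != self.n_samples:
                raise ValueError(f"{name}: expected {self.n_samples} samples")

    def get_uncorrected(self) -> Dict[str, Series]:
        return self._uncorrected

    def select_time_slice(self, time_slice: slice) -> "ToyStepDiagnostics":
        sliced = {
            name: [row[time_slice] for row in series]
            for name, series in self._uncorrected.items()
        }
        return ToyStepDiagnostics(sliced, self.n_samples)


@dataclasses.dataclass
class ToyBatch:
    data: Dict[str, Series]
    step_diagnostics: Optional[ToyStepDiagnostics] = None

    def select_time_slice(self, time_slice: slice) -> "ToyBatch":
        # Time-touching method: slice the container alongside the data.
        diag = self.step_diagnostics
        return ToyBatch(
            data={k: [row[time_slice] for row in v] for k, v in self.data.items()},
            step_diagnostics=None if diag is None else diag.select_time_slice(time_slice),
        )

    def compute_derived_variables(self) -> "ToyBatch":
        # Structure-preserving method: forward the container by reference.
        new_data = dict(self.data)
        new_data["doubled"] = [[2.0 * v for v in row] for row in self.data["a"]]
        return ToyBatch(new_data, self.step_diagnostics)


diagnostics = ToyStepDiagnostics({"a": [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]}, 2)
batch = ToyBatch({"a": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}, diagnostics)

sliced = batch.select_time_slice(slice(1, 3))
derived = batch.compute_derived_variables()
```

A parametrized test would then loop over every such method and assert that the container is still present and time-aligned afterwards, so that a newly added method that forgets to thread it fails.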
jpdunc23 added a commit that referenced this pull request Jun 17, 2026
…ection values

StepABC.step now returns a StepOutput(output, stepper_state, uncorrected) dataclass instead of a tuple[TensorDict, StepperState | None]. The new uncorrected field is a sparse, detached snapshot of the pre-correction values of exactly the variables a corrector modified, so downstream features can derive the correction (output - uncorrected) or use the raw pre-correction values without re-running the model. It is an empty dict when no corrector ran, so consumers need no None checks. stepper_state keeps its existing passthrough semantics.

step_with_adjustments captures the shadow at the corrector boundary via a new captured_before helper (tensor-identity detection of out-of-place edits), detaching unconditionally; ocean and prescribed-prognostic adjustments run after the corrector and are intentionally excluded. The corrector ABC is left unchanged. All step implementations (single/secondary/radiation/fcn3) inherit the new return type; MultiCallStep composes its wrapped step's shadow and the MultiCall helper returns an empty shadow. The rollout in predict_generator always feeds the corrected output forward as state; Stepper.step applies the name-preserving output process func to the shadow too.

This PR is pure StepOutput-through-step plumbing: the per-step StepOutput.uncorrected is computed at the corrector boundary but discarded at the Stepper.predict boundary, so predict returns its existing corrected-only BatchData and no BatchData/PairedData surface changes. Carrying the uncorrected series on the prediction is deferred to the correction-metrics PR (#1284), which introduces an encapsulated, time-aware container for it.

Pure plumbing: no user-visible behavior change, and existing checkpoints load unchanged. Adds step- and stepper-seam tests plus captured_before unit tests; the spatial-parallel step regression matrix passes unchanged under torchrun.

Part of #1271 (PR 1 of the #1218/#1222 split).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
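As a rough illustration of the return type and the identity-based capture described in this commit message, consider the sketch below. Lists stand in for tensors, a list copy stands in for a detached snapshot, and the signature of `captured_before` plus the `toy_corrector` and `step_with_corrector` helpers are assumptions made for the example, not the repository's actual code.

```python
import dataclasses
from typing import Any, Callable, Dict, List, Optional

TensorDict = Dict[str, List[float]]  # lists stand in for tensors


@dataclasses.dataclass
class StepOutput:
    output: TensorDict
    stepper_state: Optional[Any] = None
    # Empty when no corrector ran, so consumers need no None checks.
    uncorrected: TensorDict = dataclasses.field(default_factory=dict)


def captured_before(before: TensorDict, after: TensorDict) -> TensorDict:
    """Snapshot the entries of `before` that were replaced out of place."""
    return {
        name: list(value)  # copy stands in for a detached snapshot
        for name, value in before.items()
        if name in after and after[name] is not value
    }


def toy_corrector(gen: TensorDict) -> TensorDict:
    # Hypothetical corrector: replaces "heat" out of place, leaves the rest.
    corrected = dict(gen)
    corrected["heat"] = [v - 1.0 for v in gen["heat"]]
    return corrected


def step_with_corrector(
    raw: TensorDict,
    corrector: Optional[Callable[[TensorDict], TensorDict]] = None,
) -> StepOutput:
    if corrector is None:
        return StepOutput(output=raw)
    corrected = corrector(raw)
    return StepOutput(output=corrected, uncorrected=captured_before(raw, corrected))


raw = {"heat": [3.0, 4.0], "wind": [1.0, 2.0]}
with_corrector = step_with_corrector(raw, toy_corrector)
without_corrector = step_with_corrector(raw)
```

Because detection relies on object identity, this sketch only notices entries the corrector replaced with a new object; an in-place edit of the same object would go undetected, which matches the commit message's wording about out-of-place edits.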
d6e1d56 to cb8029d
07cb0f4 to 1a13e76
# Conflicts: # fme/ace/test_train.py
jpdunc23 commented Jun 18, 2026
@dataclasses.dataclass
class StepDiagnostics:
Member
Author
Agent chose StepDiagnostics as opposed to @mcgibbon's original StepMetrics, reasoning "since the payload is diagnostic input, not a metric".
mcgibbon reviewed Jun 18, 2026
from .reduced import AreaWeightedSingleTargetReducedMetric, _SeriesData, data_to_table

def compute_correction_norm(
Contributor
Issue: I think this can't distinguish between a value that received a zero correction from the corrector and a value that the corrector doesn't affect. If true, this could easily lead to bugs such as underestimating means by dropping zero entries, if we use this as a sentinel for variables the corrector doesn't touch.
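The distinction raised here can be made concrete with a small sketch. It is a hypothetical illustration, not the code under review: it assumes the set of corrected variables is read from key membership in the sparse uncorrected mapping, and contrasts that with treating a zero correction as a sentinel.

```python
from typing import Dict, List

Field = List[float]


def mean_abs_correction(
    output: Dict[str, Field], uncorrected: Dict[str, Field]
) -> Dict[str, float]:
    """Mean |output - uncorrected| for every key the corrector reported."""
    return {
        name: sum(abs(o - u) for o, u in zip(output[name], before)) / len(before)
        for name, before in uncorrected.items()
    }


output = {"heat": [1.0, 2.0], "wind": [5.0, 6.0]}
# Hypothetical case: the corrector handled "heat" but changed nothing.
uncorrected = {"heat": [1.0, 2.0]}

# Key membership keeps "heat" with a zero correction and omits "wind".
by_membership = mean_abs_correction(output, uncorrected)

# Using zero as a sentinel for "not corrected" would discard "heat" too.
by_nonzero = {name: value for name, value in by_membership.items() if value != 0.0}
```

Any average over corrected variables computed from `by_nonzero` would silently exclude the zero entry, which is the kind of under-estimation the comment warns about.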
Steppers can apply correctors (e.g. ocean heat content, global dryness) that adjust the network's raw output each step, but today there is no way to see how much a trained model relies on its corrector. This PR adds normalized-space metrics of the correction (`output − uncorrected`) to the inference aggregators, on by default behind aggregator config flags, and an optional denormalized correction netCDF that is off by default. It builds on the `StepOutput.uncorrected` plumbing from #1283 (PR1, base branch) and additionally carries that shadow up to the aggregators on a time-aware diagnostics carriage. The `uncorrected/`-prefixed error metrics from #1222 are intentionally dropped in favor of metrics of the correction itself.

Per review on #1283, the consumer-facing plumbing that surfaces the pre-correction shadow on prediction data was moved out of that PR into this one, and reworked into an opaque, encapsulated carriage (rather than a raw `uncorrected` field) so it cannot be silently dropped by `BatchData`/`PairedData` methods.

Changes:

- `fme.ace.data_loading.step_diagnostics.StepDiagnostics` (new): an opaque, time-aware carriage holding the per-step `uncorrected` series (today the only diagnostic). It validates its leading sample dim and exposes the structure-preserving operations (device movement, time slicing, ensemble broadcast, spatial scatter, pin-memory) so the container can be forwarded and time-sliced alongside `data` without `BatchData` inspecting its contents.
- `fme.ace.data_loading.batch_data`: `BatchData.step_diagnostics` / `PairedData.step_diagnostics` optional fields (following the `data_mask`/`stepper_state` precedent), threaded through every structure-preserving method via the container protocol and time-sliced by the time-touching methods, so `compute_derived_variables` and friends cannot silently drop or misalign it.
- `fme.ace.stepper.single_module.Stepper.predict`: stacks the per-step `StepOutput.uncorrected` shadows into a forward-step-aligned series, wraps it in `StepDiagnostics`, and attaches it to the final prediction after the IC-windowing pipeline.
- `fme.ace.aggregator.inference.correction` (new): `CorrectionTimeMeanAggregator` (`time_mean_norm/correction_magnitude/{var}` area-weighted global mean of the time-mean of |normalized correction|, a `channel_mean` over corrected variables only, and the signed `time_mean_norm/correction_map/{var}` image), `CorrectionMeanAggregator` (`mean_norm/weighted_correction_magnitude/{var}` and `mean_norm/weighted_correction_std/{var}` per-step series), and a shared `CorrectionRecorder`. Corrections are `normalize(output) − normalize(uncorrected)` per corrected key using the network normalizer.
- `fme.ace.aggregator.inference.main`: `log_correction_metrics: bool = True` on both `InferenceEvaluatorAggregatorConfig` and `InferenceAggregatorConfig`; the recorder is wired into both aggregators' record/summary/time-series/diagnostics paths. Correction metrics merge into the existing `time_mean_norm`/`mean_norm` groups (and are the only members of those groups for no-target inference). The no-target aggregator now accepts the stepper's network normalizer; correction metrics are skipped when no normalizer is supplied (backward compatible).
- `fme.ace.inference.inference`: passes `stepper.normalizer.normalize` into the no-target aggregator build.
- `fme.ace.inference.data_writer.main`: `save_correction_files: bool = False` on `DataWriterConfig` writes `autoregressive_corrections.nc` with the denormalized correction time series for the sparse corrected variables, respecting the save-names subset and time-coarsening, via a single-source writer in `PairedDataWriter`.

Availability: time-mean metrics in all inference types; time-series metrics only in standalone evaluator and no-target inference (inline training drops them via the existing `enable_time_series` path). Defaults preserve current behavior except the newly logged metrics for corrector-equipped runs; correction-file saving is off by default.

Tests: `test_batch_data` parametrized tests asserting `step_diagnostics` survives every structure-preserving method (including `compute_derived_variables`) on both `BatchData` and `PairedData`, the absent/None pass-through, post-init sample-dim validation, scatter, and pin-memory; aggregator unit tests asserting exact magnitude/std/map/channel_mean values for a constant-offset correction plus the flag-off and no-corrector silence paths; writer tests for the sparse denormalized file including time-coarsening and the empty-without-corrector case; config validation/defaults tests; and an end-to-end train+inference test (`test_train_and_inference_correction_metrics`) asserting time-mean correction metrics appear in inline training inference with the per-step series dropped, the per-step series appear in standalone inference, and `autoregressive_corrections.nc` is written.

Stacked on #1283 (PR1): base branch is `feature/step-output`; retarget to `main` once that merges.

Resolves #1272
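The defaults and gating described above can be summarized in a short sketch. The two flag names and their defaults come from this description; the `Toy*` config classes and the `should_record_corrections` helper are hypothetical and only illustrate the three conditions under which correction metrics would be recorded.

```python
import dataclasses
from typing import Callable, Mapping, Optional, Sequence


@dataclasses.dataclass
class ToyAggregatorConfig:
    # Correction metrics are logged by default.
    log_correction_metrics: bool = True


@dataclasses.dataclass
class ToyDataWriterConfig:
    # Writing autoregressive_corrections.nc is opt-in.
    save_correction_files: bool = False


def should_record_corrections(
    config: ToyAggregatorConfig,
    normalize: Optional[Callable[[str, float], float]],
    uncorrected: Mapping[str, Sequence[float]],
) -> bool:
    """Record only if the flag is on, a normalizer exists, and a corrector ran."""
    return (
        config.log_correction_metrics
        and normalize is not None
        and len(uncorrected) > 0
    )


def identity_normalize(name: str, value: float) -> float:
    # Hypothetical stand-in for the stepper's network normalizer.
    return value


default_config = ToyAggregatorConfig()
shadow = {"heat": [3.0, 4.0]}

records = should_record_corrections(default_config, identity_normalize, shadow)
no_normalizer = should_record_corrections(default_config, None, shadow)
no_corrector = should_record_corrections(default_config, identity_normalize, {})
flag_off = should_record_corrections(
    ToyAggregatorConfig(log_correction_metrics=False), identity_normalize, shadow
)
```

Under these assumptions a run without a corrector, without a normalizer, or with the flag turned off logs nothing new, which is how the description's backward-compatibility claim would hold.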
🤖 Generated with Claude Code