Feature/uncorrected inference metrics [don't merge]#1222
Conversation
CorrectorABC.__call__ now returns a CorrectionResult(corrected, before) instead of a plain TensorDict. The `before` field holds the pre-correction values of exactly the variables the corrector modified, detected by tensor identity (correctors apply changes out-of-place). This is a prerequisite for computing metrics on uncorrected model outputs to quantify corrector reliance. Behavior-preserving: the single production call site in step_with_adjustments unpacks .corrected. Adds a shared captured_before helper and per-corrector tests for the before-capture semantics. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
StepABC.step now returns StepOutput(output, uncorrected) instead of a plain TensorDict. The `uncorrected` field is a sparse, detached snapshot of the pre-correction values of exactly the variables the corrector modified (sourced from CorrectionResult.before). This will let inference evaluate uncorrected outputs to quantify corrector reliance. step_with_adjustments builds the shadow at the corrector boundary; ocean and prescribed-prognostic adjustments run after and are intentionally excluded. All step implementations (single/secondary/radiation/fcn3) inherit the new return type; MultiCallStep composes its wrapped step's shadow and MultiCall unpacks .output. The rollout in predict_generator always feeds the corrected output forward; Stepper.step, TrainStepper._accumulate_loss and the coupled stepper unpack .output to preserve existing behavior. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Stepper.predict gains an opt-in return_uncorrected flag (via @overload) that additionally returns a BatchData holding the pre-correction values of corrector-modified variables, sharing the prediction's time coordinate and with no derived variables computed. predict_paired gains the same flag and carries the shadow in a new optional PairedData.uncorrected_prediction field, paired against the same reference; this rides through the generic inference Looper unchanged so existing callers and the default path are untouched. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Adds compute_uncorrected_metrics (default off) to InferenceEvaluatorAggregatorConfig. When enabled, the evaluator builds a second, time-mean-only aggregator and wraps both in InferenceEvaluatorAggregatorWithUncorrected, which records the stepper's uncorrected (pre-correction) outputs from PairedData.uncorrected_prediction and logs them under an "uncorrected/" prefix (e.g. inference/uncorrected/time_mean/rmse/<var>), with diagnostics written to an uncorrected/ subdirectory. The evaluator opts the stepper into producing the shadow via predict_paired(return_uncorrected=True). This quantifies how much the stepper relies on its corrector for skill. No effect when disabled or when the stepper has no corrector. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Flip the default on so uncorrected (pre-correction) inference metrics are computed by default. Guard the composite aggregator so the shadow is summarized and flushed only when it actually recorded a batch, since a corrector-less stepper produces an empty shadow whose aggregators would otherwise raise "No data recorded." on flush. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…inference-metrics
Resolve conflicts from main's configurable global-mean-removal / LossSchedule refactor overlapping this branch's StepOutput and uncorrected-shadow plumbing. - single_module.py train-loss loop: keep this branch's evolved loop (StepOutput capture + optimize_precorrected loss_gen) as a superset of main's renamed loop. Adopt main's LossSchedule-driven n_loss_steps and grad_context, and feed loss_gen into main's extracted _accumulate_step_loss helper. No main-only logic dropped. - single_module.py low-level step: union already auto-merged — keeps the new global_mean_removal/data_mask plumbing AND the StepOutput uncorrected capture / output-process-func handling. - test_train.py: keep both newly added tests — this branch's test_inline_inference_logs_uncorrected_metrics and main's skipif(cuda) decorator on the insolation parametrized test. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Move the prefix-aware name/prefix matching logic out of StaticSpatialMasking into a new shared NameAndPrefixMatcher in fme/core/name_and_prefix_matcher.py, exposing a single matches(name) -> bool method. The matching semantics are preserved exactly: a bare name matches the 2D variable and its 3D levels, a trailing-underscore prefix matches all levels, and an explicit name_<level> matches exactly. Spatial masking now constructs and uses NameAndPrefixMatcher in place of its private _build_regex. Pure refactor, no behavior change. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ted-inference-metrics
Replace the all-or-nothing TrainStepperConfig.optimize_precorrected bool with an optional precorrector_optimization: PreCorrectorOptimizationConfig. Presence of the config enables pre-corrector optimization; its exclude_names_and_prefixes list excludes specific variables (using the shared NameAndPrefixMatcher convention) so they are optimized against their corrected values while other corrector-modified variables stay pre-corrected. This keeps the sea-ice force-positive correction in the loss (removing a mid-latitude positive sea-ice bias) while retaining the pre-corrector temperature gains. The config owns the matcher and exposes select_precorrected; the stepper never constructs or holds it. detach_uncorrected is driven by precorrector_optimization is None. Because the corrector-modified variable set is only known at runtime, the "exclusion matched zero variables" check is a warn-once on the first training step. Rollout state and returned predictions are unchanged (always fully corrected). Migrate the ocean-train integration fixture to the new config, excluding sea_ice_fraction. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
|
||
| Returns: | ||
| The denormalized output data at the next time step. | ||
| The output at the next timestep and the pre-correction values of any |
There was a problem hiding this comment.
Let's be generic here and just say it's the output including next-timestep data, so this can't go out of date with further updates to StepOutput.
| """ | ||
|
|
||
| output: TensorDict | ||
| uncorrected: TensorDict |
There was a problem hiding this comment.
As long as you put this as corrections: TensorDict | None (or the singular correction) in your re-implementation this should be fine.
The | None is important to support steps that don't define this.
mcgibbon
left a comment
There was a problem hiding this comment.
We had a slack discussion and converged on reporting metrics on the corrections instead of the pre-corrected state, since it's hard to derive the mean effect of the corrector and the variability of that effect based on metrics like RMSE applied to pre-corrector fields. Also talked about having a Loss like object that takes in StepOutput, and being able to swap it between computing regularization on the corrections, and applying the loss to de-corrected state (would be nice to be able to configure both options).
…ection values StepABC.step now returns a StepOutput(output, stepper_state, uncorrected) dataclass instead of a tuple[TensorDict, StepperState | None]. The new uncorrected field is a sparse, detached snapshot of the pre-correction values of exactly the variables a corrector modified, so downstream features can derive the correction (output - uncorrected) or use the raw pre-correction values without re-running the model. It is an empty dict when no corrector ran, so consumers need no None checks. stepper_state keeps its existing passthrough semantics. step_with_adjustments captures the shadow at the corrector boundary via a new captured_before helper (tensor-identity detection of out-of-place edits), detaching unconditionally; ocean and prescribed-prognostic adjustments run after the corrector and are intentionally excluded. The corrector ABC is left unchanged. All step implementations (single/secondary/radiation/fcn3) inherit the new return type; MultiCallStep composes its wrapped step's shadow and the MultiCall helper returns an empty shadow. The rollout in predict_generator always feeds the corrected output forward as state; Stepper.step applies the name-preserving output process func to the shadow too. This PR is pure StepOutput-through-step plumbing: the per-step StepOutput.uncorrected is computed at the corrector boundary but discarded at the Stepper.predict boundary, so predict returns its existing corrected-only BatchData and no BatchData/PairedData surface changes. Carrying the uncorrected series on the prediction is deferred to the correction-metrics PR (#1284), which introduces an encapsulated, time-aware container for it. Pure plumbing: no user-visible behavior change, and existing checkpoints load unchanged. Adds step- and stepper-seam tests plus captured_before unit tests; the spatial-parallel step regression matrix passes unchanged under torchrun. Part of #1271 (PR 1 of the #1218/#1222 split). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Adds normalized-space metrics of the corrector's correction (output - uncorrected) to the inference aggregators, plus an optional denormalized correction netCDF, on by default behind aggregator config flags. This is PR 2 of the 3-PR split of #1218/#1222 and builds on the StepOutput plumbing from #1271. Carriage: the pre-correction ``uncorrected`` series is carried from Stepper.predict to the consumers through a new opaque, time-aware container, StepDiagnostics (fme/ace/data_loading/step_diagnostics.py), instead of a raw public field on BatchData. It follows the StepperState encapsulation pattern: BatchData and PairedData each hold a single opaque step_diagnostics field (default None) and never inspect its contents. Unlike StepperState (terminal per-sample state), the payload is a per-timestep diagnostic series, so the container is time-aware: it is forwarded by reference through every structure-preserving method and time-sliced/padded alongside data by the time-touching ones (select_time_slice, remove_initial_condition, get_start/get_end, prepend), scattered by scatter_spatial, broadcast by broadcast_ensemble, and moved by to_device/to_cpu/pin_memory. __post_init__ validates its leading sample dim like stepper_state. This fixes the silently-dropped path the reviewer flagged on #1283: compute_derived_variables and PairedData.from_batch_data now preserve the series, so it survives the real inference loop. Stepper.predict builds the container from the stacked per-step StepOutput shadows and attaches it to the prediction. The correction aggregator and netCDF writer reach into it via the single get_uncorrected accessor. Metrics (computed as normalize(output) - normalize(uncorrected) per corrected key, using the network normalizer the existing *_norm metrics use): - inference/time_mean_norm/correction_magnitude/{var}: area-weighted global mean of the time-mean of |normalized correction|, plus a channel_mean over the corrected variables only. - inference/time_mean_norm/correction_map/{var}: signed time-mean map, logged as an image and flushed to time_mean_norm_correction_diagnostics.nc. - inference/mean_norm/weighted_correction_magnitude/{var}: per-step area-weighted global mean of |normalized correction|. - inference/mean_norm/weighted_correction_std/{var}: per-step area-weighted spatial std of the signed normalized correction (mirrors weighted_std_gen). These live in a new fme/ace/aggregator/inference/correction.py with dedicated CorrectionTimeMeanAggregator / CorrectionMeanAggregator and a CorrectionRecorder shared by both inference aggregators. They are kept in a separate group merged into the existing time_mean_norm / mean_norm label groups, so the time-series table uses a distinct "correction_series" key that to_inference_logs resolves to the same prefix without colliding with the main series table. Availability and gating: - Time-mean metrics in all inference types; time-series metrics only in standalone evaluator and no-target inference (inline training drops them via the existing enable_time_series path). - The no-target inference aggregator now receives the stepper's network normalizer (plumbed through InferenceAggregatorConfig.build and the inference job), introducing mean_norm / time_mean_norm groups there containing only correction metrics. Correction metrics are skipped when no normalizer is available, preserving backward compatibility for callers that omit it. - log_correction_metrics: bool = True on both the evaluator and no-target aggregator configs. No effect when the stepper has no corrector: the container's uncorrected mapping is empty and the correction aggregators stay silent. Disk output: - save_correction_files: bool = False on DataWriterConfig writes autoregressive_corrections.nc with the denormalized correction time series (output - uncorrected, physical units, with variable metadata) for the sparse corrected variables, respecting the save-names subset and time-coarsening, via a single-source RawDataWriter in PairedDataWriter. The uncorrected/-prefixed error metrics from #1222 are intentionally dropped. Adds a shared parametrized round-trip test asserting the container survives (and stays time-aligned through) every structure-preserving method on BatchData and PairedData, so a future method that forgets to thread it fails CI; aggregator unit tests asserting exact magnitude/std/map/channel_mean values for a constant-offset correction and the flag-off/no-corrector silence paths; writer tests for the sparse denormalized file (incl. time-coarsening); config validation/defaults tests; and an end-to-end train+inference test asserting time-mean correction metrics on the inline inference-loop path (series dropped), per-step series in standalone inference, and the corrections netCDF. Part of #1272 (PR 2 of the #1218/#1222 split). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Short description of why the PR is needed and how it satisfies those requirements, in sentence form.
Changes:
symbol (e.g.
fme.core.my_function) or script and concise description of changes or added featureCan group multiple related symbols on a single bullet
Tests added
If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated
Resolves # (delete if none)