Inline +4K SST perturbation-response evaluation during training #1310

Draft
mcgibbon wants to merge 8 commits into main from feature/inline-perturbation-response-eval

Conversation

@mcgibbon

Contributor

Make the +4 K uniform-SST warming response visible during training by
running an unperturbed baseline and a perturbed climate together as separate
batch members of a single inline rollout at evaluated checkpoints, differencing
their time-means, and logging the structural diagnostics the land-amplification
goal is judged on. This is a diagnostic eval only: it never contributes to
checkpoint selection and involves no external (c96-SHiELD) reference. It
reports the model's own response structure.

This is a new inline-inference type (a peer of inline inference), not a change to
the inference evaluator: the response needs a 0 K baseline climatology, so its
top-level operation differs (run two climates, then difference them). It reuses
the existing InferenceTask/InferenceCallback plumbing with weight=0.
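As a rough illustration of the "run, then difference two climates" operation described above, here is a minimal pure-Python sketch. It is not the aggregator's code: `group_time_means`, the scalar per-member values, and the integer `group_index` (in place of the one-hot encoding) are all assumptions made for the example.

```python
from typing import List, Sequence


def group_time_means(
    timesteps: Sequence[Sequence[float]],
    group_index: Sequence[int],
    n_groups: int,
) -> List[float]:
    """Average each group's values over its batch members and all timesteps.

    Hypothetical helper for illustration only.
    """
    sums = [0.0] * n_groups
    counts = [0] * n_groups
    for values in timesteps:  # one entry per timestep
        for member, value in enumerate(values):
            group = group_index[member]
            sums[group] += value
            counts[group] += 1
    return [s / c for s, c in zip(sums, counts)]


# Made-up rollout: 2 initial conditions x (baseline, perturbed) = 4 batch
# members, with a global-mean surface temperature (K) at 3 timesteps.
group_index = [0, 0, 1, 1]  # group 0 = baseline, group 1 = perturbed
timesteps = [
    [288.0, 289.0, 292.0, 293.0],
    [288.5, 289.5, 292.5, 293.5],
    [289.0, 290.0, 293.0, 294.0],
]

means = group_time_means(timesteps, group_index, n_groups=2)
response = means[1] - means[0]  # perturbed time-mean minus baseline time-mean
```

Because both climates are members of the same rollout, the difference is taken between time-means accumulated over exactly the same steps.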

Changes:

  • fme.ace.aggregator.inference.perturbation_response: new PerturbationResponseAggregator + PerturbationResponseAggregatorConfig. Accumulates a per-group time-mean over baseline/perturbed batch members (one-hot group encoding, group 0 = baseline; more than one perturbation raises NotImplementedError), differences each perturbation against the baseline, and logs the near-surface land/ocean warming ratio by latitude band, the free-troposphere/surface warming ratio over tropical ocean, and the global-mean column warming profile. Ratios are NaN-guarded; cross-rank reduction uses reduce_sum.

  • fme.ace.data_loading.perturbation_pair.build_perturbation_pair_data: turns an unperturbed inference IC + forcing loader into a paired baseline/perturbed rollout. Perturbs every forcing window (ocean-masked) and the prognostic IC when surface temperature is prognostic, with a configurable IC-perturbation scope (whole-field vs ocean-masked). Returns the [2*n_local, 2] one-hot group encoding.

  • fme.ace.train.train_config: LabeledPerturbation, PerturbationResponseInferenceConfig, the TrainConfig.perturbation_inference field, name/label validation, epoch sets, and TrainBuilders.get_perturbation_inference_data.

  • fme.ace.train.train: get_inference_callback also builds weight-0 perturbation-response tasks; _make_perturbation_response_task does the pairing; build_trainer wires it in.

  • fme.ace.stepper.single_module.TrainStepper: expose surface_temperature_name/ocean_fraction_name (delegating to the inner stepper).

  • Tests added (aggregator unit tests, paired-data unit tests, end-to-end training smoke test asserting the response log keys + diagnostics appear and that the eval does not affect checkpoint selection)

  • If dependencies changed, "deps only" image rebuilt — n/a

Design note (research repo): notes/2026-06-23-inline-perturbation-response-eval-design.md

🤖 Generated with Claude Code

mcgibbon added 3 commits June 23, 2026 22:23
Add PerturbationResponseAggregator + config: accumulates a per-group
time-mean over baseline/perturbed batch members (one-hot group encoding,
group 0 = baseline), differences each perturbation against the baseline,
and logs the land-amplification structural diagnostics — near-surface
land/ocean warming ratio by latitude band, free-troposphere/surface
warming ratio over tropical ocean, and the global-mean column warming
profile. No external reference; ratios are NaN-guarded; >1 perturbation
raises NotImplementedError.
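To make the NaN-guarded ratio concrete, the following is a small sketch of a land/ocean warming ratio over one latitude band. The helper names, the flattened cell layout, the area weighting, and the `eps` threshold are assumptions for illustration; the commit does not spell out how the guard is implemented.

```python
import math
from typing import Sequence


def masked_mean(
    values: Sequence[float], weights: Sequence[float], mask: Sequence[bool]
) -> float:
    """Area-weighted mean over cells where mask is True; NaN if none selected."""
    total_weight = sum(w for w, m in zip(weights, mask) if m)
    if total_weight == 0.0:
        return math.nan
    weighted = sum(v * w for v, w, m in zip(values, weights, mask) if m)
    return weighted / total_weight


def warming_ratio(numerator: float, denominator: float, eps: float = 1e-12) -> float:
    """Ratio of two warming values, NaN wherever the ratio is undefined."""
    if math.isnan(numerator) or math.isnan(denominator) or abs(denominator) < eps:
        return math.nan
    return numerator / denominator


# Made-up band of four cells: perturbed-minus-baseline time-mean warming (K).
warming = [6.0, 5.0, 4.0, 4.0]
area = [1.0, 1.0, 2.0, 2.0]
is_land = [True, True, False, False]
is_ocean = [not m for m in is_land]

land = masked_mean(warming, area, is_land)
ocean = masked_mean(warming, area, is_ocean)
ratio = warming_ratio(land, ocean)

# A band with no land cells yields NaN instead of raising or returning inf.
no_land = warming_ratio(masked_mean(warming, area, [False] * 4), ocean)
```

Returning NaN keeps a logged ratio from showing up as a misleading 0 or inf when a band contains no land or the ocean warming is zero.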
Add build_perturbation_pair_data: turns an unperturbed inference initial
condition + forcing loader into a paired version that runs a baseline and
a perturbed climate as separate batch members of one rollout. The SST
perturbation is applied to every forcing window (ocean-masked) and to the
prognostic initial condition when surface temperature is prognostic, with
a configurable initial-condition scope (whole-field vs ocean-masked).
Returns a [2*n_local, 2] one-hot group encoding (group 0 = baseline).
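The pairing can be pictured with a short sketch. This is an assumption-laden toy, not build_perturbation_pair_data itself: `build_pair_sketch`, the boolean `is_ocean` mask, nested lists in place of tensors, and the baseline-first member ordering are all hypothetical, and it covers only a single ocean-masked field.

```python
from typing import List, Sequence, Tuple


def build_pair_sketch(
    sst: Sequence[Sequence[float]],  # [n_local][n_cells], unperturbed
    is_ocean: Sequence[bool],        # [n_cells]
    delta_k: float = 4.0,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Duplicate each member into baseline + perturbed and label the groups."""
    baseline = [list(member) for member in sst]
    perturbed = [
        # Add the uniform offset over ocean cells only; land is left untouched.
        [t + delta_k if ocean else t for t, ocean in zip(member, is_ocean)]
        for member in sst
    ]
    # One-hot rows: column 0 marks the baseline, column 1 the perturbation.
    one_hot = [[1, 0] for _ in baseline] + [[0, 1] for _ in perturbed]
    return baseline + perturbed, one_hot


sst = [[300.0, 290.0], [301.0, 291.0]]  # 2 local members, 2 cells
is_ocean = [True, False]
paired, one_hot = build_pair_sketch(sst, is_ocean)
```

With `n_local = 2` the encoding has shape `[4, 2]`, matching the `[2*n_local, 2]` shape named in the commit message.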
Add PerturbationResponseInferenceConfig (peer to InlineInferenceConfig)
and a TrainConfig.perturbation_inference field, a builder that constructs
the unperturbed base inference data, and build_trainer wiring that pairs
it (build_perturbation_pair_data) and runs it as a weight-0 InferenceTask
through the existing inference callback. The response aggregator is keyed
by perturbation label only (the callback adds the task-name prefix).
Expose surface_temperature_name/ocean_fraction_name on TrainStepper.
End-to-end smoke test in test_train.py asserts the response log keys and
diagnostics appear and that the eval does not affect checkpoint selection.
mcgibbon marked this pull request as draft June 24, 2026 00:49
mcgibbon added 5 commits June 24, 2026 00:51
test_symbols requires every nested TrainConfig dataclass to be in
fme.ace.__all__. Export LabeledPerturbation, PerturbationResponseInferenceConfig,
PerturbationResponseAggregatorConfig, and LatitudeBand.

- flush_diagnostics: run the collective reduce on all ranks but only write
  the netCDF on the root rank (avoid a multi-writer race on the shared path)
- drop the unused n_timesteps argument from the aggregator build signature
- document that IC ensembles are unsupported and that masks are localized
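The first bullet above describes a common distributed pattern: every rank enters the collective, and only the root touches the file. Below is a minimal sketch with the reduction and the writer stubbed out so it runs in a single process. `flush_diagnostics_sketch`, `fake_reduce_sum`, and the `write` callback are hypothetical stand-ins, not the repository's API.

```python
from typing import Callable, List


def flush_diagnostics_sketch(
    local_sums: List[float],
    reduce_sum: Callable[[List[float]], List[float]],
    is_root: bool,
    write: Callable[[List[float]], None],
) -> None:
    # Every rank must call the collective; gating it on is_root would leave
    # the root waiting on ranks that never join the reduction.
    global_sums = reduce_sum(local_sums)
    # Only the root writes, so the shared output path has a single writer.
    if is_root:
        write(global_sums)


# Single-process stand-in for two ranks.
rank_sums = [[1.0, 2.0], [3.0, 4.0]]
reduce_calls: List[int] = []
written: List[List[float]] = []


def fake_reduce_sum(local: List[float]) -> List[float]:
    """Stub for a cross-rank sum: records the call, returns the global sum."""
    reduce_calls.append(1)
    return [sum(column) for column in zip(*rank_sums)]


for rank, local in enumerate(rank_sums):
    flush_diagnostics_sketch(
        local, fake_reduce_sum, is_root=(rank == 0), write=written.append
    )
```

In the stubbed run both ranks call the reduction, but only one result is written.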
…arallelism

Build the latitude band and land/ocean masks from rank-localized coordinates
so they match the (spatially scattered) recorded data and localized area
weights under model parallelism. Identity when no spatial parallelism.
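The shape issue behind this change can be shown with a toy example. The sketch assumes a latitude-only split across two ranks and an inclusive band; `latitude_band_mask` and the slicing scheme are hypothetical and stand in for whatever localization the distributed layer actually provides.

```python
from typing import List, Sequence


def latitude_band_mask(
    lats: Sequence[float], lat_min: float, lat_max: float
) -> List[bool]:
    """True where a latitude falls inside the (inclusive) band."""
    return [lat_min <= lat <= lat_max for lat in lats]


global_lats = [-75.0, -45.0, -15.0, 15.0, 45.0, 75.0]
# Assumed spatial decomposition: each rank holds half of the latitudes.
rank_slices = [slice(0, 3), slice(3, 6)]
local_data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # scattered data per rank

# A mask built from global coordinates has 6 entries, while each rank only
# holds 3 values.
global_mask = latitude_band_mask(global_lats, -30.0, 30.0)

# A mask built from each rank's localized coordinates matches its local data.
local_masks = [
    latitude_band_mask(global_lats[s], -30.0, 30.0) for s in rank_slices
]
```

With a single rank the slice covers the whole grid, so localization is the identity, consistent with the note above.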
The ENSO (nino34) and IPO (TPI) regional aggregators built LatLonRegion
weights from global coordinates while inference data is spatially scattered,
which would shape-mismatch under model parallelism. Localize the coordinates
before constructing the regions (identity without spatial parallelism).

Standalone fix to pre-existing aggregators, separable from the
perturbation-response feature.

self._counts was a CPU tensor passed to reduce_sum; NCCL all_reduce
rejects CPU tensors, so the diagnostic would crash on multi-GPU runs
(masked by the CPU-only test suite). Move the counts clone to the device
for the reduce, matching the group sums (which are already on device).