Refactor step to take an explicit [batch, ensemble] dimension pair #1302
mcgibbon wants to merge 4 commits into
The step (`step_with_adjustments` and all four `StepABC` implementations) now receives tensors with an explicit `[batch, ensemble, *spatial]` leading dimension pair instead of a folded `[batch*ensemble, *spatial]` dimension. The nn.Module calls still operate on a folded `[batch*ensemble, channel, *spatial]` batch: each `network_call` folds the ensemble dimension into the batch immediately before the module and unfolds the module output. This makes the ensemble structure visible inside the step, a prerequisite for input channel dropout that is shared across ensemble members but independent across the batch (a follow-up).

- Global mean removal reduces every dimension after the leading sample dim, so it is folded around (each ensemble member is an independent sample for mean removal), which keeps it byte-identical and agnostic to the spatial rank.
- The corrector, ocean and normalizer are leading-dim agnostic (the corrector reduces only HORIZONTAL_DIMS) and run on the [batch, ensemble] layout directly. CorrectorState/StepperState gain fold/unfold_ensemble helpers; the corrector state is [batch, ensemble, 1, 1] inside the step and folded to [batch*ensemble, 1, 1] in the externally-threaded state.
- predict_generator keeps its external contract folded: it unfolds the per-step inputs, masks and stepper_state, threads [batch, ensemble] through the step, and folds the yielded outputs, so training and inference callers are unchanged.
- The input-mask helpers handle the leading [batch, ensemble] pair, and new fold/unfold_ensemble_tensor helpers fold a single tensor.

Behavior-preserving: an added ensemble-independence test asserts each member's output equals an independent single-member step, and the spatial-parallel regression baselines reproduce the pre-refactor outputs bit-for-bit (modulo the size-1 ensemble dim).
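The fold/unfold around the module call can be sketched roughly as follows. This is a minimal illustration, not the PR's code: it uses NumPy arrays in place of torch tensors, the helper bodies are hypothetical (only the name `unfold_ensemble_dim` appears in the diff), and it assumes a batch-major fold in which folded index `b * n_ensemble + e` corresponds to `[b, e]`.

```python
import numpy as np


def fold_ensemble_dim(data):
    """Hypothetical helper: [batch, ensemble, ...] -> [batch*ensemble, ...]."""
    return {
        name: x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
        for name, x in data.items()
    }


def unfold_ensemble_dim(data, n_ensemble):
    """Hypothetical helper: [batch*ensemble, ...] -> [batch, ensemble, ...]."""
    out = {}
    for name, x in data.items():
        if x.shape[0] % n_ensemble != 0:
            raise ValueError(
                f"{name}: leading dim {x.shape[0]} not divisible by {n_ensemble}"
            )
        out[name] = x.reshape((x.shape[0] // n_ensemble, n_ensemble) + x.shape[1:])
    return out


# Example layout: batch=2, ensemble=3, lat=4, lon=5 (sizes chosen arbitrarily).
explicit = {
    "air_temperature": np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)
}
folded = fold_ensemble_dim(explicit)
restored = unfold_ensemble_dim(folded, n_ensemble=3)
```

Because both directions are pure reshapes, the round trip returns the original values unchanged, which is what makes a fold-then-unfold wrapper behavior-preserving.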
…emble] The regression input/output fixtures store tensors, so they must reflect the new [batch, ensemble, *spatial] step contract. The values are unchanged (verified bit-for-bit modulo the size-1 ensemble dim); only a size-1 ensemble dimension is added. Verified under both H=2 and W=2 spatial decompositions.
…nnel Per review: the stepper interfaces must not be handed a folded [batch*ensemble, ...] TensorMapping together with an n_ensemble: int that re-interprets its combined leading dimension. Instead the explicit [batch, ensemble, *spatial] layout flows through, folding only at the encapsulated storage/IO boundaries.

- BatchData keeps its folded `data` + `n_ensemble` as encapsulated storage and now exposes `ensemble_data` (already present) plus `ensemble_data_mask` and `ensemble_stepper_state`, all deriving the explicit [batch, ensemble, ...] view (ensemble length 1 is valid even before broadcast_ensemble).
- Stepper.predict_generator no longer takes n_ensemble: it consumes the explicit ensemble views (time at dim 2), threads [batch, ensemble, ...] through the step, and yields that layout directly. Callers pass `.ensemble_data` / `.ensemble_data_mask` / `.ensemble_stepper_state`; _accumulate_loss drops its post-yield unfold.
- process_prediction_generator_list consumes the explicit yields and folds back into BatchData's folded storage at that one boundary (deriving n_ensemble from the shape, no argument).
- The coupled stepper drives its components over a flat batch; it folds the components' explicit yields back at the boundary and treats the component initial conditions as a flat (n_ensemble=1) view, matching the freshly built component forcings. A deeper coupled migration is a separate follow-up.

Behavior-preserving: ace step/stepper, coupled, inference, aggregator and the spatial-parallel regression suites all pass unchanged.
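As a rough illustration of folding at a single storage boundary, the sketch below stacks per-step yields along a time axis at dim 2 and then folds, reading `n_ensemble` from the shape rather than from an argument. The function name `stack_and_fold` is hypothetical, NumPy stands in for torch, and the real `process_prediction_generator_list` may differ in its details.

```python
import numpy as np


def stack_and_fold(yields):
    """Hypothetical sketch of the storage boundary.

    yields: list (one entry per time step) of dicts mapping a variable name
    to an array of shape [batch, ensemble, *spatial].
    Returns (folded, n_ensemble), where folded arrays have shape
    [batch*ensemble, time, *spatial].
    """
    names = list(yields[0].keys())
    # Insert time right after the [batch, ensemble] pair, i.e. at dim 2.
    stacked = {name: np.stack([y[name] for y in yields], axis=2) for name in names}
    # The explicit layout is self-describing: n_ensemble is just shape[1].
    n_batch, n_ensemble = stacked[names[0]].shape[:2]
    folded = {
        name: x.reshape((n_batch * n_ensemble,) + x.shape[2:])
        for name, x in stacked.items()
    }
    return folded, n_ensemble


# Seven time steps of a batch=2, ensemble=3, 4x5 field, each filled with its step index.
steps = [{"t": np.full((2, 3, 4, 5), float(i))} for i in range(7)]
folded, n_ensemble = stack_and_fold(steps)
```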
        return unfold_ensemble_dim(TensorDict(self.data), n_ensemble=self.n_ensemble)

    @property
    def ensemble_data_mask(self) -> TensorMapping | None:
Why would we need this? If data is missing from a batch member, it will be missing from all ensemble instances. Can't we just broadcast later?
Removed ensemble_data_mask. You're right — the data_mask is constant across ensemble members, so it no longer gets a [batch, ensemble] view. The step now folds the [batch, ensemble] data into the batch at entry and applies the per-sample (folded [batch*ensemble]) data_mask on the folded batch, exactly as on main.
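The reviewer's point can be checked with a small sketch: a mask that is constant across ensemble members gives the same result whether it is broadcast against the explicit layout or repeated and applied on the folded batch. The `[batch, *spatial]` mask shape, the zero-fill, and the batch-major fold order are assumptions made for illustration; NumPy stands in for torch.

```python
import numpy as np

n_batch, n_ensemble = 2, 3
rng = np.random.default_rng(0)
data = rng.normal(size=(n_batch, n_ensemble, 4, 5))  # [batch, ensemble, lat, lon]
mask = rng.random((n_batch, 4, 5)) > 0.3             # assumed per-sample mask, no ensemble dim

# Option A: broadcast the mask over a size-1 ensemble axis on the explicit layout.
masked_explicit = np.where(mask[:, None], data, 0.0)

# Option B: fold the data, repeat the mask once per ensemble member, apply, unfold.
folded_data = data.reshape((n_batch * n_ensemble,) + data.shape[2:])
folded_mask = np.repeat(mask, n_ensemble, axis=0)    # [m0, m0, m0, m1, m1, m1]
masked_folded = np.where(folded_mask, folded_data, 0.0).reshape(data.shape)
```

Under these assumptions the two results are identical, so a separate `[batch, ensemble]` view of the mask carries no extra information.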
        return unfold_ensemble_dim(dict(self.data_mask), n_ensemble=self.n_ensemble)

    @property
    def ensemble_stepper_state(self) -> "StepperState | None":
Question: why do we need this? Shouldn't the stepper state be able to always know internally what its ensemble count is?
Removed ensemble_stepper_state. The step now folds the ensemble into the batch at entry, so the corrector runs on the folded batch and the stepper_state stays in its folded [batch*ensemble, ...] layout throughout — no unfolding, and the CorrectorState/StepperState fold-helpers are gone too.
        self,
        ic_dict: TensorMapping,
        forcing_dict: TensorMapping,
        ic_data: TensorMapping,
unnecessary name change, please revert. This change is big enough that we should avoid changelog lines where possible.
Reverted — back to ic_dict/forcing_dict.
        # the nn.Module call, which folds transiently inside the Step; the
        # generator yields the same explicit layout. The time dimension is thus
        # one past TIME_DIM (the ensemble dim is inserted before it).
        time_dim = self.TIME_DIM + 1
is it reasonable to instead update self.TIME_DIM to one higher? If so we don't need this lengthy comment.
I tried bumping self.TIME_DIM to 2, but the coupled stepper indexes its folded component data via self.atmosphere.TIME_DIM/self.ocean.TIME_DIM (time at dim 1) — bumping it broke 39 coupled tests. So I kept Stepper.TIME_DIM = 1 and use a local time_dim = self.TIME_DIM + 1 in predict_generator (the ensemble dim sits before time in ensemble_data), and shortened the comment to one line.
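To make the offset concrete, here is a small sketch (NumPy in place of torch, array sizes invented for illustration) showing that time sits at dim 1 in the folded layout and at dim 2 once the ensemble dimension is made explicit, and that selecting a time step along either gives the same values.

```python
import numpy as np

TIME_DIM = 1  # time axis of the folded [batch*ensemble, time, lat, lon] layout

folded = np.arange(6 * 7 * 4 * 5).reshape(6, 7, 4, 5)
explicit = folded.reshape(2, 3, 7, 4, 5)  # [batch, ensemble, time, lat, lon]

# The ensemble dim is inserted before time, so time moves one axis to the right.
time_dim = TIME_DIM + 1

first_folded = np.take(folded, 0, axis=TIME_DIM)      # [batch*ensemble, lat, lon]
first_explicit = np.take(explicit, 0, axis=time_dim)  # [batch, ensemble, lat, lon]
```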
Per review comments on #1302:

- Remove BatchData.ensemble_data_mask and ensemble_stepper_state. The data_mask is constant across ensemble members, so it stays per-sample and broadcasts at apply; the stepper_state stays in its folded layout.
- step_with_adjustments now folds the explicit [batch, ensemble] data into the batch at entry (each ensemble member is an independent sample) and unfolds the output, running the existing per-sample body on the folded data. This reverts the per-op network_call / mask-helper / global-mean-removal / corrector changes back to main, and drops the CorrectorState/StepperState fold-helpers and the single-tensor fold/unfold tensor helpers.
- predict_generator consumes BatchData.ensemble_data (self-describing, no n_ensemble argument) but takes data_mask/stepper_state in their folded form.
- Revert the predict_generator parameter rename (ic_dict/forcing_dict).
- Keep Stepper.TIME_DIM = 1 (the coupled stepper indexes its folded component data via it); predict_generator uses a local time_dim = TIME_DIM + 1.

Behavior-preserving: the step body is byte-identical to main on the folded batch. ace step/stepper, coupled, inference, aggregator and the spatial-parallel regression suites pass.
The step/stepper now receive data with an explicit `[batch, ensemble, *spatial]` leading dimension pair (via `BatchData.ensemble_data`) instead of a folded `[batch*ensemble, *spatial]` dimension paired with an `n_ensemble: int` that re-interprets it. This makes the data self-describing (`n_ensemble = shape[1]`), with no integer side-channel, which is the prerequisite for input channel dropout that is shared across ensemble members but independent across the batch (a follow-up, not implemented here).

The ensemble dimension is folded back into the batch at the entry of `step_with_adjustments` (each ensemble member is an independent sample), so the per-sample body (normalization, global mean removal, the network call, the corrector, the ocean) runs on the folded batch exactly as before, and the output is unfolded back. The change is therefore behavior-preserving (the step body is byte-identical to `main` on the folded batch).

Changes:

- `fme.ace.data_loading.BatchData.ensemble_data` (existing) returns `[batch, ensemble, time, *spatial]`; an `n_ensemble` of 1 is valid even before `broadcast_ensemble`. No other accessors are added: `data_mask` (constant across ensemble) and `stepper_state` stay in their folded layout.
- `Stepper.predict_generator` consumes `BatchData.ensemble_data` (no `n_ensemble` argument); `data_mask`/`stepper_state` are passed in their folded form. `Stepper.TIME_DIM` stays 1 (the coupled stepper indexes its folded component data through it); predict_generator uses a local `time_dim = TIME_DIM + 1`.
- `step_with_adjustments` folds the explicit ensemble dimension into the batch at entry and unfolds the output; the per-op step internals are unchanged from `main`.
- `process_prediction_generator_list` stacks the explicit yields and folds back into `BatchData`'s folded storage (deriving `n_ensemble` from the shape).
- `fme.coupled.stepper` drives its components over a flat batch; it folds the components' explicit yields at the boundary and treats the component initial conditions as a flat (`n_ensemble=1`) view matching the freshly built forcings.

Verification:

- `test_step_ensemble_members_are_independent`: each ensemble member's output equals an independent single-member step.
- Spatial-parallel `test_step_regression` passes under `H=2` and `W=2` with regenerated `[batch, ensemble]` baselines (values reproduce the pre-refactor outputs modulo the size-1 ensemble dim).
- ace step/stepper, coupled, inference, and aggregator suites pass.
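The idea behind the ensemble-independence check can be sketched as below. This is not the PR's test: `toy_per_sample_body` and `toy_step_with_fold` are hypothetical stand-ins (a per-sample mean removal followed by `tanh`), and NumPy replaces torch. It only shows why folding at entry and unfolding at exit keeps each member independent of the others.

```python
import numpy as np


def toy_per_sample_body(x):
    """Stand-in for the folded step body: acts on [sample, lat, lon].

    Each sample's own global mean is removed, so no information crosses
    the leading (sample) dimension.
    """
    return np.tanh(x - x.mean(axis=(1, 2), keepdims=True))


def toy_step_with_fold(x):
    """Fold [batch, ensemble, ...] at entry, run the body, unfold the output."""
    n_batch, n_ensemble = x.shape[:2]
    folded = x.reshape((n_batch * n_ensemble,) + x.shape[2:])
    out = toy_per_sample_body(folded)
    return out.reshape((n_batch, n_ensemble) + out.shape[1:])


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 5))  # [batch, ensemble, lat, lon]
full = toy_step_with_fold(x)

# Each member of the full ensemble step should match a single-member step.
members_match = all(
    np.allclose(full[:, m : m + 1], toy_step_with_fold(x[:, m : m + 1]))
    for m in range(x.shape[1])
)
```

A body that mixed information across the leading dimension (for example, a mean taken over all samples) would fail this comparison, which is the property such a test guards.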
Tests added
If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated