Skip to content

Refactor step to take an explicit [batch, ensemble] dimension pair#1302

Draft
mcgibbon wants to merge 4 commits into
mainfrom
refactor/step-batch-ensemble-dims
Draft

Refactor step to take an explicit [batch, ensemble] dimension pair#1302
mcgibbon wants to merge 4 commits into
mainfrom
refactor/step-batch-ensemble-dims

Conversation

@mcgibbon

@mcgibbon mcgibbon commented Jun 22, 2026

Copy link
Copy Markdown
Contributor

The step/stepper now receive data with an explicit [batch, ensemble, *spatial] leading dimension pair (via BatchData.ensemble_data) instead of a folded [batch*ensemble, *spatial] dimension paired with an n_ensemble: int that re-interprets it. This makes the data self-describing (n_ensemble = shape[1]) — no integer side-channel — which is the prerequisite for input channel dropout that is shared across ensemble members but independent across the batch (a follow-up, not implemented here).

The ensemble dimension is folded back into the batch at the entry of step_with_adjustments (each ensemble member is an independent sample), so the per-sample body — normalization, global mean removal, the network call, the corrector, the ocean — runs on the folded batch exactly as before, and the output is unfolded back. The change is therefore behavior-preserving (the step body is byte-identical to main on the folded batch).

Changes:

  • fme.ace.data_loading.BatchData.ensemble_data (existing) returns [batch, ensemble, time, *spatial]; n_ensemble length 1 is valid even before broadcast_ensemble. No other accessors are added — data_mask (constant across ensemble) and stepper_state stay in their folded layout.
  • Stepper.predict_generator consumes BatchData.ensemble_data (no n_ensemble argument); data_mask/stepper_state are passed in their folded form. Stepper.TIME_DIM stays 1 (the coupled stepper indexes its folded component data through it); predict_generator uses a local time_dim = TIME_DIM + 1.
  • step_with_adjustments folds the explicit ensemble dimension into the batch at entry and unfolds the output; the per-op step internals are unchanged from main.
  • process_prediction_generator_list stacks the explicit yields and folds back into BatchData's folded storage (deriving n_ensemble from the shape).
  • fme.coupled.stepper drives its components over a flat batch; it folds the components' explicit yields at the boundary and treats the component initial conditions as a flat (n_ensemble=1) view matching the freshly built forcings.

Verification:

  • test_step_ensemble_members_are_independent: each ensemble member's output equals an independent single-member step.

  • Spatial-parallel test_step_regression passes under H=2 and W=2 with regenerated [batch, ensemble] baselines (values reproduce the pre-refactor outputs modulo the size-1 ensemble dim).

  • ace step/stepper, coupled, inference, and aggregator suites pass.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

mcgibbon added 2 commits June 22, 2026 15:59
The step (`step_with_adjustments` and all four `StepABC` implementations)
now receives tensors with an explicit `[batch, ensemble, *spatial]` leading
dimension pair instead of a folded `[batch*ensemble, *spatial]` dimension.
The nn.Module calls still operate on a folded `[batch*ensemble, channel,
*spatial]` batch: each `network_call` folds the ensemble dimension into the
batch immediately before the module and unfolds the module output.

This makes the ensemble structure visible inside the step, a prerequisite
for input channel dropout that is shared across ensemble members but
independent across the batch (a follow-up).

- Global mean removal reduces every dimension after the leading sample
  dim, so it is folded around (each ensemble member is an independent
  sample for mean removal) — keeping it byte-identical and agnostic to the
  spatial rank.
- The corrector, ocean and normalizer are leading-dim agnostic (the
  corrector reduces only HORIZONTAL_DIMS) and run on the [batch, ensemble]
  layout directly. CorrectorState/StepperState gain fold/unfold_ensemble
  helpers; the corrector state is [batch, ensemble, 1, 1] inside the step
  and folded to [batch*ensemble, 1, 1] in the externally-threaded state.
- predict_generator keeps its external contract folded: it unfolds the
  per-step inputs, masks and stepper_state, threads [batch, ensemble]
  through the step, and folds the yielded outputs — so training and
  inference callers are unchanged.
- The input-mask helpers handle the leading [batch, ensemble] pair, and
  new fold/unfold_ensemble_tensor helpers fold a single tensor.

Behavior-preserving: an added ensemble-independence test asserts each
member's output equals an independent single-member step, and the
spatial-parallel regression baselines reproduce the pre-refactor outputs
bit-for-bit (modulo the size-1 ensemble dim).
…emble]

The regression input/output fixtures store tensors, so they must reflect the
new [batch, ensemble, *spatial] step contract. The values are unchanged
(verified bit-for-bit modulo the size-1 ensemble dim); only a size-1 ensemble
dimension is added. Verified under both H=2 and W=2 spatial decompositions.
…nnel

Per review: the stepper interfaces must not be handed a folded
[batch*ensemble, ...] TensorMapping together with an n_ensemble: int that
re-interprets its combined leading dimension. Instead the explicit
[batch, ensemble, *spatial] layout flows through, folding only at the
encapsulated storage/IO boundaries.

- BatchData keeps its folded `data` + `n_ensemble` as encapsulated storage
  and now exposes `ensemble_data` (already present) plus `ensemble_data_mask`
  and `ensemble_stepper_state`, all deriving the explicit [batch, ensemble,
  ...] view (ensemble length 1 is valid even before broadcast_ensemble).
- Stepper.predict_generator no longer takes n_ensemble: it consumes the
  explicit ensemble views (time at dim 2), threads [batch, ensemble, ...]
  through the step, and yields that layout directly. Callers pass
  `.ensemble_data` / `.ensemble_data_mask` / `.ensemble_stepper_state`;
  _accumulate_loss drops its post-yield unfold.
- process_prediction_generator_list consumes the explicit yields and folds
  back into BatchData's folded storage at that one boundary (deriving
  n_ensemble from the shape, no argument).
- The coupled stepper drives its components over a flat batch; it folds the
  components' explicit yields back at the boundary and treats the component
  initial conditions as a flat (n_ensemble=1) view, matching the freshly
  built component forcings. A deeper coupled migration is a separate follow-up.

Behavior-preserving: ace step/stepper, coupled, inference, aggregator and the
spatial-parallel regression suites all pass unchanged.
Comment thread fme/ace/data_loading/batch_data.py Outdated
return unfold_ensemble_dim(TensorDict(self.data), n_ensemble=self.n_ensemble)

@property
def ensemble_data_mask(self) -> TensorMapping | None:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would we need this? If data is missing from a batch member, it will be missing from all ensemble instances. Can't we just broadcast later?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed ensemble_data_mask. You're right — the data_mask is constant across ensemble members, so it no longer gets a [batch, ensemble] view. The step now folds the [batch, ensemble] data into the batch at entry and applies the per-sample (folded [batch*ensemble]) data_mask on the folded batch, exactly as on main.

Comment thread fme/ace/data_loading/batch_data.py Outdated
return unfold_ensemble_dim(dict(self.data_mask), n_ensemble=self.n_ensemble)

@property
def ensemble_stepper_state(self) -> "StepperState | None":

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: why do we need this? Shouldn't the stepper state be able to always know internally what its ensemble count is?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed ensemble_stepper_state. The step now folds the ensemble into the batch at entry, so the corrector runs on the folded batch and the stepper_state stays in its folded [batch*ensemble, ...] layout throughout — no unfolding, and the CorrectorState/StepperState fold-helpers are gone too.

Comment thread fme/ace/stepper/single_module.py Outdated
self,
ic_dict: TensorMapping,
forcing_dict: TensorMapping,
ic_data: TensorMapping,

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary name change, please revert. This change is big enough that we should avoid changelog lines where possible.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted — back to ic_dict/forcing_dict.

# the nn.Module call, which folds transiently inside the Step; the
# generator yields the same explicit layout. The time dimension is thus
# one past TIME_DIM (the ensemble dim is inserted before it).
time_dim = self.TIME_DIM + 1

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it reasonable to instead update self.TIME_DIM to one higher? If so we don't need this lengthy comment.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried bumping self.TIME_DIM to 2, but the coupled stepper indexes its folded component data via self.atmosphere.TIME_DIM/self.ocean.TIME_DIM (time at dim 1) — bumping it broke 39 coupled tests. So I kept Stepper.TIME_DIM = 1 and use a local time_dim = self.TIME_DIM + 1 in predict_generator (the ensemble dim sits before time in ensemble_data), and shortened the comment to one line.

Per review comments on #1302:
- Remove BatchData.ensemble_data_mask and ensemble_stepper_state. The
  data_mask is constant across ensemble members, so it stays per-sample and
  broadcasts at apply; the stepper_state stays in its folded layout.
- step_with_adjustments now folds the explicit [batch, ensemble] data into the
  batch at entry (each ensemble member is an independent sample) and unfolds
  the output, running the existing per-sample body on the folded data. This
  reverts the per-op network_call / mask-helper / global-mean-removal /
  corrector changes back to main, and drops the CorrectorState/StepperState
  fold-helpers and the single-tensor fold/unfold tensor helpers.
- predict_generator consumes BatchData.ensemble_data (self-describing, no
  n_ensemble argument) but takes data_mask/stepper_state in their folded form.
- Revert the predict_generator parameter rename (ic_dict/forcing_dict).
- Keep Stepper.TIME_DIM = 1 (the coupled stepper indexes its folded component
  data via it); predict_generator uses a local time_dim = TIME_DIM + 1.

Behavior-preserving: the step body is byte-identical to main on the folded
batch. ace step/stepper, coupled, inference, aggregator and the spatial-parallel
regression suites pass.
mcgibbon added a commit that referenced this pull request Jun 22, 2026
Per review comments on #1302:
- Remove BatchData.ensemble_data_mask and ensemble_stepper_state. The
  data_mask is constant across ensemble members, so it stays per-sample and
  broadcasts at apply; the stepper_state stays in its folded layout.
- step_with_adjustments now folds the explicit [batch, ensemble] data into the
  batch at entry (each ensemble member is an independent sample) and unfolds
  the output, running the existing per-sample body on the folded data. This
  reverts the per-op network_call / mask-helper / global-mean-removal /
  corrector changes back to main, and drops the CorrectorState/StepperState
  fold-helpers and the single-tensor fold/unfold tensor helpers.
- predict_generator consumes BatchData.ensemble_data (self-describing, no
  n_ensemble argument) but takes data_mask/stepper_state in their folded form.
- Revert the predict_generator parameter rename (ic_dict/forcing_dict).
- Keep Stepper.TIME_DIM = 1 (the coupled stepper indexes its folded component
  data via it); predict_generator uses a local time_dim = TIME_DIM + 1.

Behavior-preserving: the step body is byte-identical to main on the folded
batch. ace step/stepper, coupled, inference, aggregator and the spatial-parallel
regression suites pass.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant