Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions fme/ace/stepper/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,21 +496,26 @@ def process_ensemble_prediction_generator_list(
def process_prediction_generator_list(
output_list: list[tuple[TensorDict, StepperState | None]],
time: xr.DataArray,
n_ensemble: int,
labels: BatchLabels | None = None,
horizontal_dims: list[str] | None = None,
) -> BatchData:
"""Stack per-step outputs into a single BatchData.

Attaches the terminal stepper_state (from the last entry in
``output_list``) to the returned BatchData so it can propagate to the
next ``Stepper.predict`` call.
The generator yields the explicit ``[batch, ensemble, *spatial]`` layout;
this stacks over time into ``[batch, ensemble, time, *spatial]`` and folds
the ensemble back into the batch only at this storage boundary (BatchData
holds the folded ``data`` + ``n_ensemble``). Attaches the terminal
stepper_state (from the last entry in ``output_list``, already in the folded
layout) so it can propagate to the next ``Stepper.predict`` call.
"""
output_dicts = [item[0] for item in output_list]
terminal_state = output_list[-1][1] if output_list else None
output_timeseries = stack_list_of_tensor_dicts(output_dicts, time_dim=1)
output_timeseries = EnsembleTensorDict(
stack_list_of_tensor_dicts(output_dicts, time_dim=2)
)
folded_data, n_ensemble = fold_ensemble_dim(output_timeseries)
return BatchData.new_on_device(
data=output_timeseries,
data=folded_data,
time=time,
horizontal_dims=horizontal_dims,
labels=labels,
Expand Down Expand Up @@ -1088,11 +1093,9 @@ def get_prediction_generator(
"Initial condition and forcing data must have the same labels, "
f"got {ic_batch_data.labels} and {forcing_data.labels}."
)
ic_dict = ic_batch_data.data
forcing_dict = forcing_data.data
return self.predict_generator(
ic_dict,
forcing_dict,
ic_batch_data.ensemble_data,
forcing_data.ensemble_data,
n_forward_steps,
optimizer,
forcing_data.labels,
Expand All @@ -1116,18 +1119,24 @@ def predict_generator(
data_mask: TensorMapping | None = None,
stepper_state: StepperState | None = None,
) -> Generator[tuple[TensorDict, StepperState | None], None, None]:
state = {k: ic_dict[k].squeeze(self.TIME_DIM) for k in ic_dict}
# ic/forcing carry an explicit [batch, ensemble, time, *spatial] layout
# (BatchData.ensemble_data); the Step folds the ensemble into the batch
# internally and yields the same explicit layout. data_mask and
# stepper_state stay in their folded [batch*ensemble, ...] layout. The
# ensemble dim sits before time here, so time is at TIME_DIM + 1.
time_dim = self.TIME_DIM + 1

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it reasonable to instead update self.TIME_DIM to one higher? If so we don't need this lengthy comment.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried bumping self.TIME_DIM to 2, but the coupled stepper indexes its folded component data via self.atmosphere.TIME_DIM/self.ocean.TIME_DIM (time at dim 1) — bumping it broke 39 coupled tests. So I kept Stepper.TIME_DIM = 1 and use a local time_dim = self.TIME_DIM + 1 in predict_generator (the ensemble dim sits before time in ensemble_data), and shortened the comment to one line.

state = {k: ic_dict[k].squeeze(time_dim) for k in ic_dict}
for step in range(n_forward_steps):
input_forcing = {
k: (
forcing_dict[k][:, step]
forcing_dict[k].select(time_dim, step)
if k not in self._step_obj.next_step_forcing_names
else forcing_dict[k][:, step + 1]
else forcing_dict[k].select(time_dim, step + 1)
)
for k in self._input_only_names
}
next_step_input_dict = {
k: forcing_dict[k][:, step + 1]
k: forcing_dict[k].select(time_dim, step + 1)
for k in self._step_obj.next_step_input_names
}
input_data = {**state, **input_forcing}
Expand Down Expand Up @@ -1212,7 +1221,6 @@ def predict(
time=forcing_data.time[:, self.n_ic_timesteps :],
horizontal_dims=forcing_data.horizontal_dims,
labels=forcing.labels,
n_ensemble=forcing.n_ensemble,
)
if compute_derived_variables:
with timer.context("compute_derived_variables"):
Expand Down Expand Up @@ -1667,8 +1675,8 @@ def _accumulate_loss(
input_ensemble_data = input_data.as_batch_data().broadcast_ensemble(n_ensemble)
forcing_ensemble_data = data.broadcast_ensemble(n_ensemble)
output_generator = self._stepper.predict_generator(
input_ensemble_data.data,
forcing_ensemble_data.data,
input_ensemble_data.ensemble_data,
forcing_ensemble_data.ensemble_data,
n_forward_steps,
optimization,
labels=input_ensemble_data.labels,
Expand All @@ -1688,8 +1696,10 @@ def _accumulate_loss(
contextlib.nullcontext() if optimize_step else torch.no_grad()
)
with grad_context:
gen_step, _ = next(output_iterator)
gen_step = unfold_ensemble_dim(gen_step, n_ensemble=n_ensemble)
# The generator yields the explicit [batch, ensemble, *spatial]
# layout directly.
raw_gen_step, _ = next(output_iterator)
gen_step = EnsembleTensorDict(raw_gen_step)
output_list.append(gen_step)
target_step = add_ensemble_dim(
{
Expand Down
52 changes: 42 additions & 10 deletions fme/ace/stepper/test_single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ def _get_train_stepper(
def test_step():
stepper = _get_stepper(["a", "b"], ["a", "b"])
n_samples = 3
input_data = {x: torch.rand(n_samples, 5, 5).to(DEVICE) for x in ["a", "b"]}
input_data = {x: torch.rand(n_samples, 1, 5, 5).to(DEVICE) for x in ["a", "b"]}

output, _ = stepper.step(
StepArgs(input=input_data, next_step_input_data={}, labels=None)
Expand All @@ -1091,7 +1091,7 @@ def test_step():
def test_step_with_diagnostic():
stepper = _get_stepper(["a"], ["a", "c"], module_name="RepeatChannel")
n_samples = 3
input_data = {"a": torch.rand(n_samples, 5, 5).to(DEVICE)}
input_data = {"a": torch.rand(n_samples, 1, 5, 5).to(DEVICE)}
output, _ = stepper.step(
StepArgs(input=input_data, next_step_input_data={}, labels=None)
)
Expand All @@ -1109,7 +1109,7 @@ def test_step_with_forcing_and_diagnostic(residual_prediction):
residual_prediction=residual_prediction,
)
n_samples = 3
input_data = {x: torch.rand(n_samples, 5, 5).to(DEVICE) for x in ["a", "b"]}
input_data = {x: torch.rand(n_samples, 1, 5, 5).to(DEVICE) for x in ["a", "b"]}
output, _ = stepper.step(
StepArgs(input=input_data, next_step_input_data={}, labels=None)
)
Expand All @@ -1122,12 +1122,44 @@ def test_step_with_forcing_and_diagnostic(residual_prediction):
assert "c" in output


@pytest.mark.parametrize(
"module_name, in_names, out_names",
[
("ChannelSum", ["a", "b"], ["a"]),
("RepeatChannel", ["a"], ["a", "c"]),
],
)
def test_step_ensemble_members_are_independent(module_name, in_names, out_names):
"""A ``[batch, ensemble]`` step must produce, for each ensemble member,
exactly what an independent single-member step on that member's data would
produce. This pins behavior preservation of the fold/unfold around the
network: ensemble members must not be mixed, and the channel dimension must
not be mistaken for the ensemble dimension.
"""
stepper = _get_stepper(in_names, out_names, module_name=module_name)
n_batch, n_ensemble = 2, 3
input_data = {x: torch.rand(n_batch, n_ensemble, 5, 5).to(DEVICE) for x in in_names}
out, _ = stepper.step(
StepArgs(input=input_data, next_step_input_data={}, labels=None)
)
for name in out_names:
assert out[name].shape == (n_batch, n_ensemble, 5, 5)
for e in range(n_ensemble):
member_in = {k: v[:, e : e + 1] for k, v in input_data.items()}
member_out, _ = stepper.step(
StepArgs(input=member_in, next_step_input_data={}, labels=None)
)
assert set(member_out) == set(out)
for name in out:
torch.testing.assert_close(out[name][:, e : e + 1], member_out[name])


def test_step_with_prescribed_ocean():
stepper = _get_stepper(
["a", "b"], ["a", "b"], ocean_config=OceanConfig("a", "mask")
)
input_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "b"]}
ocean_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "mask"]}
input_data = {x: torch.rand(3, 1, 5, 5).to(DEVICE) for x in ["a", "b"]}
ocean_data = {x: torch.rand(3, 1, 5, 5).to(DEVICE) for x in ["a", "mask"]}
output, _ = stepper.step(
StepArgs(input=input_data, next_step_input_data=ocean_data, labels=None)
)
Expand All @@ -1146,8 +1178,8 @@ def test_prescribe_sst_integration():
stepper = _get_stepper(
["a", "b"], ["a", "b"], ocean_config=OceanConfig("a", "mask")
)
input_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "b"]}
ocean_data = {x: torch.rand(3, 5, 5).to(DEVICE) for x in ["a", "mask"]}
input_data = {x: torch.rand(3, 1, 5, 5).to(DEVICE) for x in ["a", "b"]}
ocean_data = {x: torch.rand(3, 1, 5, 5).to(DEVICE) for x in ["a", "mask"]}
prescribed_data = stepper.prescribe_sst(
mask_data=ocean_data,
gen_data=input_data,
Expand Down Expand Up @@ -2078,7 +2110,7 @@ def _step_negative_input(stepper: Stepper) -> tuple[torch.Tensor, torch.Tensor]:
The stepper adds one to its input, so the raw prediction is negative
everywhere and the force-positive corrector clamps it to zero.
"""
input_data = {"a": torch.full((3, 5, 5), -5.0, device=DEVICE)}
input_data = {"a": torch.full((3, 1, 5, 5), -5.0, device=DEVICE)}
output, _ = stepper.step(
StepArgs(input=input_data, next_step_input_data={}, labels=None)
)
Expand Down Expand Up @@ -2599,7 +2631,7 @@ def test_step_unmasked_nan_input_raises():
"""A NaN input not covered by data_mask raises a located error in step()."""
stepper = _get_stepper(["a", "b"], ["a"], module_name="ChannelSum")
n_samples = 2
input_data = {x: torch.rand(n_samples, 5, 5).to(DEVICE) for x in ["a", "b"]}
input_data = {x: torch.rand(n_samples, 1, 5, 5).to(DEVICE) for x in ["a", "b"]}
input_data["b"][1] = torch.nan
with pytest.raises(ValueError, match=r"NaN found in network input.*\bb\b"):
stepper.step(StepArgs(input=input_data, next_step_input_data={}, labels=None))
Expand All @@ -2609,7 +2641,7 @@ def test_step_masked_nan_input_does_not_raise():
"""A NaN input covered by data_mask is zeroed, so the guard does not fire."""
stepper = _get_stepper(["a", "b"], ["a"], module_name="ChannelSum")
n_samples = 2
input_data = {x: torch.rand(n_samples, 5, 5).to(DEVICE) for x in ["a", "b"]}
input_data = {x: torch.rand(n_samples, 1, 5, 5).to(DEVICE) for x in ["a", "b"]}
input_data["b"][1] = torch.nan
data_mask = {"b": torch.tensor([True, False], dtype=torch.bool, device=DEVICE)}
output, _ = stepper.step(
Expand Down
8 changes: 7 additions & 1 deletion fme/core/distributed/parallel_tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,19 @@ def get_multi_call_selector(


def get_tensor_dict(
names: list[str], img_shape: tuple[int, int], n_samples: int
names: list[str],
img_shape: tuple[int, int],
n_samples: int,
n_ensemble: int = 1,
) -> TensorDict:
# The step operates on an explicit [batch, ensemble, *spatial] leading
# pair; build inputs with that ensemble dimension (size 1 by default).
data_dict = {}
device = fme.get_device()
for name in names:
data_dict[name] = torch.rand(
n_samples,
n_ensemble,
*img_shape,
device=device,
)
Expand Down
Binary file modified fme/core/distributed/parallel_tests/testdata/multi_call_input.pt
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
31 changes: 26 additions & 5 deletions fme/core/step/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
)
from fme.core.step.step import StepABC, StepConfigABC, StepSelector
from fme.core.stepper_state import StepperState
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.tensors import fold_sized_ensemble_dim, unfold_ensemble_dim
from fme.core.typing_ import EnsembleTensorDict, TensorDict, TensorMapping

DEFAULT_TIMESTEP = datetime.timedelta(hours=6)
DEFAULT_ENCODED_TIMESTEP = encode_timestep(DEFAULT_TIMESTEP)
Expand Down Expand Up @@ -531,6 +532,15 @@ def _build_channel_mask_dict(
return result


def _ensemble_size(data: TensorMapping) -> int:
"""Size of the explicit ensemble dimension (dim 1) of a
``[batch, ensemble, *spatial]`` tensor mapping; 1 for an empty mapping.
"""
for v in data.values():
return v.shape[1]
return 1


def step_with_adjustments(
input: TensorMapping,
next_step_input_data: TensorMapping,
Expand All @@ -550,10 +560,12 @@ def step_with_adjustments(

Args:
input: Mapping from variable name to tensor of shape
[n_batch, n_lat, n_lon] containing denormalized data from the
initial timestep. In practice this contains the ML inputs.
[n_batch, n_ensemble, n_lat, n_lon] containing denormalized data
from the initial timestep. In practice this contains the ML inputs.
The ensemble dimension is folded into the batch for the per-sample
transforms and the network call, and the output is unfolded back.
next_step_input_data: Mapping from variable name to tensor of shape
[n_batch, n_lat, n_lon] containing denormalized data from
[n_batch, n_ensemble, n_lat, n_lon] containing denormalized data from
the output timestep. In practice this contains the necessary data
at the output timestep for the ocean model and corrector.
network_calls: Callable[[TensorMapping], TensorDict] that takes a
Expand Down Expand Up @@ -584,6 +596,15 @@ def step_with_adjustments(
"""
if prescribed_prognostic_names is None:
prescribed_prognostic_names = []
# Fold the explicit ensemble dimension into the batch so the per-sample
# transforms and the network operate on a single sample dimension (each
# ensemble member is an independent sample); the output is unfolded back at
# the end. data_mask and stepper_state are already in the folded layout.
n_ensemble = _ensemble_size(input)
input = fold_sized_ensemble_dim(EnsembleTensorDict(dict(input)), n_ensemble)
next_step_input_data = fold_sized_ensemble_dim(
EnsembleTensorDict(dict(next_step_input_data)), n_ensemble
)
gmr_state: GlobalMeanRemovalState | None = None
if global_mean_removal is not None:
network_input, gmr_state = global_mean_removal.forward_transform(
Expand Down Expand Up @@ -621,4 +642,4 @@ def step_with_adjustments(
raise ValueError(
f"prescribed_prognostic_name '{name}' not in next_step_input_data"
)
return output, stepper_state
return unfold_ensemble_dim(output, n_ensemble), stepper_state
8 changes: 7 additions & 1 deletion fme/core/step/test_radiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,19 @@ def forward(self, x):


def get_tensor_dict(
names: list[str], img_shape: tuple[int, int], n_samples: int
names: list[str],
img_shape: tuple[int, int],
n_samples: int,
n_ensemble: int = 1,
) -> TensorDict:
# The step operates on an explicit [batch, ensemble, *spatial] leading
# pair; build inputs with that ensemble dimension (size 1 by default).
data_dict = {}
device = fme.get_device()
for name in names:
data_dict[name] = torch.rand(
n_samples,
n_ensemble,
*img_shape,
device=device,
)
Expand Down
Loading
Loading