From 39d11b0ed0794e1fcafe41e03cc9c80b623afbce Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Sat, 6 Jun 2026 14:13:29 -0700 Subject: [PATCH 1/9] Add with_rolled_lon to downscaling models Let models re-express their grid in a seam-crossing coarse domain's longitude convention while sharing network weights: - DiffusionModel.with_rolled_lon rebuilds the model through its constructor with full_fine_coords and static_inputs rolled to match the coarse grid. The roll is anchored on the western coarse-cell edge so the fine grid stays aligned to whole coarse cells. Returns self when no roll is needed. - DenoisingMoEPredictor.with_rolled_lon rolls every expert (preserving the shared-grid invariant) and rebuilds so the sigma dispatcher is reconstructed from the rolled experts. Adds tests for no-roll passthrough, coord shifting with shared weights, idempotency, coarse-cell alignment, and rolling all MoE experts. --- fme/downscaling/models.py | 55 +++++ .../predictors/serial_denoising.py | 20 ++ fme/downscaling/test_models.py | 203 ++++++++++++++++++ 3 files changed, 278 insertions(+) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 782fe5e06..733419c7a 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -21,7 +21,10 @@ PairedBatchData, StaticInputs, adjust_fine_coord_range, + coords_require_lon_roll, + find_roll_anchor, load_coords_from_path, + roll_latlon_coords, ) from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector @@ -747,6 +750,58 @@ def metadata(self): else 0, ) + def _lon_roll_amount(self, coarse_lon: torch.Tensor) -> tuple[int, float]: + """ + Number of positions to roll the fine grid (and the lon_start it aligns to) + so the fine cells stay aligned to coarse_lon's coarse cells. + + coarse_lon is the actual coarse domain grid, so it already carries the + convention to align to. 
The roll is anchored on the western coarse-cell + *edge* (half a coarse cell below coarse_lon.min(), which is a cell *center*) + so the fine grid rolls by a whole number of coarse cells and its cells stay + aligned to the coarse cells. Anchoring on the center instead would roll by an + extra downscale_factor // 2 fine points, splitting the boundary coarse cell + across the seam. + """ + lon_start = float(coarse_lon.min()) + fine_lon = self.full_fine_coords.lon + fine_spacing = float(fine_lon[1] - fine_lon[0]) + western_edge = lon_start - self.downscale_factor * fine_spacing / 2.0 + return find_roll_anchor(fine_lon, western_edge), lon_start + + def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DiffusionModel": + """ + Return a new model with full_fine_coords and static_inputs rolled to match + coarse_lon's longitude convention, sharing the network weights. + + Returns self unchanged when coarse_lon does not cross the prime meridian. + The new model is built through the constructor (rather than a shallow copy) + so its coords are re-validated and derived state is rebuilt fresh; the raw + module is unwrapped and passed so __init__ re-wraps it exactly once. 
+ """ + if not coords_require_lon_roll(coarse_lon): + return self + roll_amount, lon_start = self._lon_roll_amount(coarse_lon) + return DiffusionModel( + config=self.config, + module=self.module.module, + normalizer=self.normalizer, + loss=self.loss, + coarse_shape=self.coarse_shape, + downscale_factor=self.downscale_factor, + sigma_data=self.sigma_data, + full_fine_coords=roll_latlon_coords( + self.full_fine_coords, roll_amount, lon_start + ), + in_names=self.in_names, + out_names=self.out_names, + static_inputs=( + self.static_inputs.roll(roll_amount, lon_start) + if self.static_inputs is not None + else None + ), + ) + @dataclasses.dataclass class _CheckpointModelConfigSelector: diff --git a/fme/downscaling/predictors/serial_denoising.py b/fme/downscaling/predictors/serial_denoising.py index 18b198e64..f52900b54 100644 --- a/fme/downscaling/predictors/serial_denoising.py +++ b/fme/downscaling/predictors/serial_denoising.py @@ -226,6 +226,26 @@ def static_inputs(self) -> StaticInputs | None: def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: return self._primary.get_fine_coords_for_batch(batch) + def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DenoisingMoEPredictor": + """New predictor with every expert's coords rolled to match coarse_lon. + + All experts are rolled (not just the primary) so the shared-grid invariant + enforced in __init__ still holds -- nothing relies on the non-primary + experts' coordinates being left unrolled. Rebuilt through __init__ so + _dispatch_module is reconstructed from the rolled experts. Returns self + unchanged when no roll is needed. 
+ """ + rolled = [expert.with_rolled_lon(coarse_lon) for expert in self._experts] + if all(r is e for r, e in zip(rolled, self._experts)): + return self + return DenoisingMoEPredictor( + experts=rolled, + sigma_ranges=self._sigma_ranges, + num_diffusion_generation_steps=self._num_diffusion_generation_steps, + churn=self._churn, + expert_renames=self._expert_renames, + ) + @torch.no_grad() def generate( self, diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 244a3c154..5cb079e67 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -532,6 +532,209 @@ def test_get_fine_coords_for_batch(): assert torch.allclose(result.lon, expected_lon) +def _make_global_fine_coords_and_static(fine_shape: tuple[int, int]): + """Return a global-covering LatLonCoordinates and matching StaticInputs.""" + step = 360 / fine_shape[1] + global_fine_lon = torch.arange(fine_shape[1]) * step + step / 2 + global_fine_lat = _get_monotonic_coordinate(fine_shape[0], stop=fine_shape[0]) + full_fine_coords = LatLonCoordinates(lat=global_fine_lat, lon=global_fine_lon) + static_field = torch.arange( + fine_shape[0] * fine_shape[1], dtype=torch.float32 + ).reshape(*fine_shape) + static_inputs = StaticInputs( + fields=[StaticInput(static_field)], coords=full_fine_coords + ) + return full_fine_coords, static_inputs + + +def test_with_rolled_lon_no_roll_returns_same(): + """with_rolled_lon returns the original model when no roll is needed.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + static_inputs = make_static_inputs(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=static_inputs.coords, + static_inputs=static_inputs, + ) + coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + assert model.with_rolled_lon(coarse_lon) is model + + +def test_with_rolled_lon_shifts_coords_and_shares_weights(): + """with_rolled_lon: new model with rolled coords, shared 
network weights.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + + coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) + rolled = model.with_rolled_lon(coarse_lon) + + # Reconstruction wraps a fresh module around the SAME raw weights. + assert rolled.module is not model.module + assert next(rolled.module.parameters()) is next(model.module.parameters()) + assert not torch.equal(rolled.full_fine_coords.lon, model.full_fine_coords.lon) + assert torch.all(rolled.full_fine_coords.lon[1:] > rolled.full_fine_coords.lon[:-1]) + assert rolled.full_fine_coords.lon[0].item() < 0 + assert rolled.static_inputs is not None + # Compare against model.static_inputs (on-device) rather than the CPU-side original + assert not torch.equal( + rolled.static_inputs.fields[0].data, model.static_inputs.fields[0].data + ) + + +def test_with_rolled_lon_is_idempotent(): + """Rolling an already-rolled model with the same domain is a no-op. + + Guards against accidental double-rolling: the second roll resolves to 0 + (full rotation), so the twice-rolled model has identical coords and static + inputs to the once-rolled one. 
+ """ + coarse_shape = (8, 16) + fine_shape = (16, 32) + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + + coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) + rolled = model.with_rolled_lon(coarse_lon) + twice = rolled.with_rolled_lon(coarse_lon) + + assert torch.equal(twice.full_fine_coords.lon, rolled.full_fine_coords.lon) + assert rolled.static_inputs is not None and twice.static_inputs is not None + assert torch.equal( + twice.static_inputs.fields[0].data, rolled.static_inputs.fields[0].data + ) + + +def test_roll_diffusion_model_keeps_fine_aligned_to_coarse_cells(): + """A seam-crossing domain must roll the fine grid by whole coarse cells. + + The roll is anchored on the western coarse-cell edge, not its center. If it + anchored on the center it would roll an extra downscale_factor // 2 fine + points, leaving no fine margin below the western coarse cell -- which makes + get_fine_coords_for_batch raise -- and splitting that cell across the seam. + """ + coarse_shape = (4, 8) + fine_shape = (16, 32) + factor = 4 + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=factor, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + + # Four of the eight global 45-degree coarse cells, crossing the 0/360 seam and + # expressed in negative convention (physically 292.5, 337.5 and 22.5, 67.5). + # Coarse-lat centers [6, 10] are interior, leaving fine margin above and below. 
+ coarse_lat = [6.0, 10.0] + coarse_lon = [-67.5, -22.5, 22.5, 67.5] + batch = make_batch_data( + (1, len(coarse_lat), len(coarse_lon)), coarse_lat, coarse_lon + ) + + rolled = model.with_rolled_lon(torch.tensor(coarse_lon, dtype=torch.float32)) + # Anchoring on the cell center would leave no margin and raise here. + fine_coords = rolled.get_fine_coords_for_batch(batch) + + # Each coarse cell is covered by exactly `factor` fine cells whose mean is the + # coarse-cell center -- i.e. the fine grid stayed aligned to the coarse cells. + recentered = fine_coords.lon.reshape(len(coarse_lon), factor).mean(dim=1).cpu() + assert torch.allclose(recentered, torch.tensor(coarse_lon), atol=1e-3) + + +def test_denoising_moe_predictor_with_rolled_lon_rolls_all_experts(): + """with_rolled_lon rolls every expert (keeping the shared-grid invariant).""" + from fme.downscaling.predictors.serial_denoising import DenoisingMoEPredictor + + coarse_shape = (8, 16) + fine_shape = (16, 32) + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + + expert0 = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + expert1 = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + predictor = DenoisingMoEPredictor( + experts=[expert0, expert1], + sigma_ranges=[(0.0, 0.5), (0.5, 1.0)], + num_diffusion_generation_steps=2, + churn=0.0, + ) + + coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) + rolled = predictor.with_rolled_lon(coarse_lon) + + # Every expert is a new (rolled) object; _primary stays _experts[0]. + assert rolled._primary is rolled._experts[0] + for rolled_expert, original, source in zip( + rolled._experts, predictor._experts, [expert0, expert1] + ): + assert rolled_expert is not original + # Coords are rolled... 
+ assert rolled_expert.full_fine_coords.lon[0].item() < 0 + # ...but the raw network weights are still shared (fresh wrapper). + assert next(rolled_expert.module.parameters()) is next( + source.module.parameters() + ) + # The sigma dispatcher is rebuilt from the rolled experts, consistent with + # _experts (not left pointing at any pre-roll module). + for entry, rolled_expert in zip(rolled._dispatch_module._entries, rolled._experts): + assert entry[2] is rolled_expert.module + + # No-roll case returns self + non_neg_lon = torch.tensor([0.0, 5.0, 10.0, 15.0], dtype=torch.float32) + assert predictor.with_rolled_lon(non_neg_lon) is predictor + + +def test_denoising_moe_predictor_rejects_mismatched_expert_grids(): + """Experts on different grids are rejected at construction (shared-grid).""" + from fme.downscaling.predictors.serial_denoising import DenoisingMoEPredictor + + fine_coords_a, static_a = _make_global_fine_coords_and_static((16, 32)) + fine_coords_b, static_b = _make_global_fine_coords_and_static((16, 16)) + expert_a = _get_diffusion_model( + coarse_shape=(8, 16), + downscale_factor=2, + full_fine_coords=fine_coords_a, + static_inputs=static_a, + ) + expert_b = _get_diffusion_model( + coarse_shape=(8, 8), + downscale_factor=2, + full_fine_coords=fine_coords_b, + static_inputs=static_b, + ) + with pytest.raises(ValueError, match="metadata"): + DenoisingMoEPredictor( + experts=[expert_a, expert_b], + sigma_ranges=[(0.0, 0.5), (0.5, 1.0)], + num_diffusion_generation_steps=2, + churn=0.0, + ) + + def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( From 5e0ca29dbe67e6cb558cb0bacca8a8952050b7fe Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 12 Jun 2026 15:21:58 -0700 Subject: [PATCH 2/9] Fix utils rename use in models, consolidate coordinate creation --- fme/downscaling/data/__init__.py | 1 + fme/downscaling/data/test_config.py | 12 +++--------- fme/downscaling/models.py | 7 ++++--- 
fme/downscaling/test_models.py | 8 +++----- fme/downscaling/test_utils.py | 14 ++++++++++++-- 5 files changed, 23 insertions(+), 19 deletions(-) diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py index e11d8f662..beea9a3a3 100644 --- a/fme/downscaling/data/__init__.py +++ b/fme/downscaling/data/__init__.py @@ -22,5 +22,6 @@ coords_require_lon_roll, expand_and_fold_tensor, find_roll_anchor, + roll_lon_coords, scale_tuple, ) diff --git a/fme/downscaling/data/test_config.py b/fme/downscaling/data/test_config.py index 487d8b25b..181917c90 100644 --- a/fme/downscaling/data/test_config.py +++ b/fme/downscaling/data/test_config.py @@ -20,7 +20,7 @@ ) from fme.downscaling.data.utils import ClosedInterval from fme.downscaling.requirements import DataRequirements -from fme.downscaling.test_utils import data_paths_helper +from fme.downscaling.test_utils import cell_centered_coordinate, data_paths_helper from fme.downscaling.typing_ import FineResCoarseResPair @@ -34,14 +34,8 @@ def _write_global_nc(path: Path, n_lat: int, n_lon: int, num_timesteps: int) -> cftime.DatetimeProlepticGregorian(2000, 1, 1) + datetime.timedelta(days=i) for i in range(num_timesteps) ] - lon_spacing = 360.0 / n_lon - lat_spacing = 8.0 / n_lat - lons = np.array( - [lon_spacing / 2 + i * lon_spacing for i in range(n_lon)], dtype=np.float32 - ) - lats = np.array( - [lat_spacing / 2 + i * lat_spacing for i in range(n_lat)], dtype=np.float32 - ) + lons = cell_centered_coordinate(0.0, 360.0, n_lon).numpy() + lats = cell_centered_coordinate(0.0, 8.0, n_lat).numpy() data = ( np.broadcast_to(lons[None, None, :], (num_timesteps, n_lat, n_lon)) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 733419c7a..67aca5e46 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -24,7 +24,7 @@ coords_require_lon_roll, find_roll_anchor, load_coords_from_path, - roll_latlon_coords, + roll_lon_coords, ) from fme.downscaling.metrics_and_maths import 
filter_tensor_mapping, interpolate from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector @@ -790,8 +790,9 @@ def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DiffusionModel": coarse_shape=self.coarse_shape, downscale_factor=self.downscale_factor, sigma_data=self.sigma_data, - full_fine_coords=roll_latlon_coords( - self.full_fine_coords, roll_amount, lon_start + full_fine_coords=LatLonCoordinates( + lat=self.full_fine_coords.lat, + lon=roll_lon_coords(self.full_fine_coords.lon, roll_amount, lon_start), ), in_names=self.in_names, out_names=self.out_names, diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 5cb079e67..e2d46fc79 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -28,6 +28,7 @@ ) from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.noise import LogNormalNoiseDistribution +from fme.downscaling.test_utils import cell_centered_coordinate from fme.downscaling.typing_ import FineResCoarseResPair @@ -100,9 +101,7 @@ def make_batch_data( def _get_monotonic_coordinate(size: int, stop: float) -> torch.Tensor: - bounds = torch.linspace(0, stop, size + 1) - coord = (bounds[:-1] + bounds[1:]) / 2 - return coord + return cell_centered_coordinate(0.0, stop, size) def make_paired_batch_data( @@ -534,8 +533,7 @@ def test_get_fine_coords_for_batch(): def _make_global_fine_coords_and_static(fine_shape: tuple[int, int]): """Return a global-covering LatLonCoordinates and matching StaticInputs.""" - step = 360 / fine_shape[1] - global_fine_lon = torch.arange(fine_shape[1]) * step + step / 2 + global_fine_lon = cell_centered_coordinate(0.0, 360.0, fine_shape[1]) global_fine_lat = _get_monotonic_coordinate(fine_shape[0], stop=fine_shape[0]) full_fine_coords = LatLonCoordinates(lat=global_fine_lat, lon=global_fine_lon) static_field = torch.arange( diff --git a/fme/downscaling/test_utils.py b/fme/downscaling/test_utils.py 
index 4f34a4451..b66c827bd 100644 --- a/fme/downscaling/test_utils.py +++ b/fme/downscaling/test_utils.py @@ -3,14 +3,24 @@ import cftime import numpy as np +import torch import xarray as xr from fme.downscaling.typing_ import FineResCoarseResPair +def cell_centered_coordinate(start: float, end: float, n: int) -> torch.Tensor: + """Centers of ``n`` equal-width cells spanning ``[start, end]``. + + Shared primitive for the downscaling tests' coordinate construction + (longitude rolling in particular relies on cell-centered grids). + """ + bounds = torch.linspace(start, end, n + 1) + return (bounds[:-1] + bounds[1:]) / 2 + + def _midpoints_from_count(start, end, n_mid): - width = (end - start) / n_mid - return np.linspace(start + width / 2, end - width / 2, n_mid, dtype=np.float32) + return cell_centered_coordinate(start, end, n_mid).numpy() def create_test_data_on_disk( From d4028ae231edf8b2d0c2e0e41af16b2f03970d88 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 12 Jun 2026 15:51:25 -0700 Subject: [PATCH 3/9] Fix comments --- fme/downscaling/models.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 67aca5e46..0c9b1fc46 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -774,14 +774,19 @@ def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DiffusionModel": Return a new model with full_fine_coords and static_inputs rolled to match coarse_lon's longitude convention, sharing the network weights. + Models with rolled longitude are useful when inference region crosses + the prime meridian, where we want to ensure we can grab proper slices + from the static inputs and provide the right coordinates for the outputs. + Returns self unchanged when coarse_lon does not cross the prime meridian. 
- The new model is built through the constructor (rather than a shallow copy) - so its coords are re-validated and derived state is rebuilt fresh; the raw - module is unwrapped and passed so __init__ re-wraps it exactly once. """ if not coords_require_lon_roll(coarse_lon): return self roll_amount, lon_start = self._lon_roll_amount(coarse_lon) + + # The new model is built through the constructor (rather than a shallow copy) + # so its coords are re-validated and derived state is rebuilt fresh; the raw + # module is unwrapped and passed so __init__ re-wraps it exactly once. return DiffusionModel( config=self.config, module=self.module.module, From 3ee4ee62df54c50f76e8cf22550a37b36ffd293a Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 12 Jun 2026 15:52:24 -0700 Subject: [PATCH 4/9] Remove redundant moe test from test_models and use shared coordinate --- .../predictors/test_serial_denoising.py | 63 ++++++-- fme/downscaling/test_models.py | 137 +++--------------- 2 files changed, 67 insertions(+), 133 deletions(-) diff --git a/fme/downscaling/predictors/test_serial_denoising.py b/fme/downscaling/predictors/test_serial_denoising.py index 4475f4ebd..78d10a832 100644 --- a/fme/downscaling/predictors/test_serial_denoising.py +++ b/fme/downscaling/predictors/test_serial_denoising.py @@ -5,7 +5,7 @@ import torch from fme.core.coordinates import LatLonCoordinates -from fme.downscaling.data import StaticInput, StaticInputs +from fme.downscaling.data import StaticInputs from fme.downscaling.models import CheckpointModelConfig from fme.downscaling.predictors.serial_denoising import ( DenoisingExpertCheckpointConfig, @@ -18,7 +18,7 @@ ) from fme.downscaling.test_models import ( _get_diffusion_model, - _get_monotonic_coordinate, + _make_global_fine_coords_and_static, make_fine_coords, make_paired_batch_data, ) @@ -323,19 +323,54 @@ def test_save_preserves_rename_applied_by_checkpoint_model_config(tmp_path): assert set(reqs.coarse_names) == {"renamed_x"} -def 
_make_global_fine_coords_and_static(fine_shape: tuple[int, int]): - """Return a global-covering LatLonCoordinates and matching StaticInputs.""" - step = 360 / fine_shape[1] - global_fine_lon = torch.arange(fine_shape[1]) * step + step / 2 - global_fine_lat = _get_monotonic_coordinate(fine_shape[0], stop=fine_shape[0]) - full_fine_coords = LatLonCoordinates(lat=global_fine_lat, lon=global_fine_lon) - static_field = torch.arange( - fine_shape[0] * fine_shape[1], dtype=torch.float32 - ).reshape(*fine_shape) - static_inputs = StaticInputs( - fields=[StaticInput(static_field)], coords=full_fine_coords +def test_denoising_moe_predictor_with_rolled_lon_rolls_all_experts(): + """with_rolled_lon rolls every expert (keeping the shared-grid invariant).""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + + expert0 = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + expert1 = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, ) - return full_fine_coords, static_inputs + predictor = DenoisingMoEPredictor( + experts=[expert0, expert1], + sigma_ranges=[(0.0, 0.5), (0.5, 1.0)], + num_diffusion_generation_steps=2, + churn=0.0, + ) + + coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) + rolled = predictor.with_rolled_lon(coarse_lon) + + # Every expert is a new (rolled) object; _primary stays _experts[0]. + assert rolled._primary is rolled._experts[0] + for rolled_expert, original, source in zip( + rolled._experts, predictor._experts, [expert0, expert1] + ): + assert rolled_expert is not original + # Coords are rolled... + assert rolled_expert.full_fine_coords.lon[0].item() < 0 + # ...but the raw network weights are still shared (fresh wrapper). 
+ assert next(rolled_expert.module.parameters()) is next( + source.module.parameters() + ) + # The sigma dispatcher is rebuilt from the rolled experts, consistent with + # _experts (not left pointing at any pre-roll module). + for entry, rolled_expert in zip(rolled._dispatch_module._entries, rolled._experts): + assert entry[2] is rolled_expert.module + + # No-roll case returns self + non_neg_lon = torch.tensor([0.0, 5.0, 10.0, 15.0], dtype=torch.float32) + assert predictor.with_rolled_lon(non_neg_lon) is predictor def test_denoising_moe_predictor_rejects_mismatched_expert_grids(): diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index e2d46fc79..a609a6c8e 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -100,10 +100,6 @@ def make_batch_data( return BatchData(data=data, time=time, latlon_coordinates=latlon) -def _get_monotonic_coordinate(size: int, stop: float) -> torch.Tensor: - return cell_centered_coordinate(0.0, stop, size) - - def make_paired_batch_data( coarse_shape: tuple[int, int], fine_shape: tuple[int, int], @@ -114,10 +110,10 @@ def make_paired_batch_data( """ lat_c, lon_c = coarse_shape lat_f, lon_f = fine_shape - fine_lat = _get_monotonic_coordinate(lat_f, stop=lat_f) - fine_lon = _get_monotonic_coordinate(lon_f, stop=lon_f) - coarse_lat = _get_monotonic_coordinate(lat_c, stop=lat_f) - coarse_lon = _get_monotonic_coordinate(lon_c, stop=lon_f) + fine_lat = cell_centered_coordinate(0.0, lat_f, n=lat_f) + fine_lon = cell_centered_coordinate(0.0, lon_f, n=lon_f) + coarse_lat = cell_centered_coordinate(0.0, lat_f, n=lat_c) + coarse_lon = cell_centered_coordinate(0.0, lon_f, n=lon_c) fine = make_batch_data((batch_size, lat_f, lon_f), fine_lat, fine_lon) coarse = make_batch_data((batch_size, lat_c, lon_c), coarse_lat, coarse_lon) return PairedBatchData(fine=fine, coarse=coarse) @@ -127,8 +123,8 @@ def make_fine_coords(fine_shape: tuple[int, int]) -> LatLonCoordinates: """Create LatLonCoordinates 
with proper monotonic coordinates for given shape.""" lat_size, lon_size = fine_shape return LatLonCoordinates( - lat=_get_monotonic_coordinate(lat_size, stop=lat_size), - lon=_get_monotonic_coordinate(lon_size, stop=lon_size), + lat=cell_centered_coordinate(0.0, lat_size, n=lat_size), + lon=cell_centered_coordinate(0.0, lon_size, n=lon_size), ) @@ -376,8 +372,8 @@ def test_DiffusionModel_generate_on_batch_no_target(): batch_size = 2 n_generated_samples = 2 - coarse_lat = _get_monotonic_coordinate(coarse_shape[0], stop=fine_shape[0]) - coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + coarse_lat = cell_centered_coordinate(0.0, fine_shape[0], n=coarse_shape[0]) + coarse_lon = cell_centered_coordinate(0.0, fine_shape[1], n=coarse_shape[1]) coarse_batch = make_batch_data((batch_size, *coarse_shape), coarse_lat, coarse_lon) samples = model.generate_on_batch_no_target( @@ -418,8 +414,8 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): for alternative_input_shape in [(8, 8), (32, 32)]: fine_shape = tuple(dim * downscale_factor for dim in alternative_input_shape) alt_y, alt_x = alternative_input_shape - coarse_lat = _get_monotonic_coordinate(alt_y, stop=alt_y * downscale_factor) - coarse_lon = _get_monotonic_coordinate(alt_x, stop=alt_x * downscale_factor) + coarse_lat = cell_centered_coordinate(0.0, alt_y * downscale_factor, n=alt_y) + coarse_lon = cell_centered_coordinate(0.0, alt_x * downscale_factor, n=alt_x) coarse_batch = make_batch_data( (batch_size, *alternative_input_shape), coarse_lat, coarse_lon ) @@ -517,8 +513,8 @@ def test_get_fine_coords_for_batch(): ) # Build a batch covering a spatial patch: middle 4 coarse lats and 8 coarse lons. 
- full_coarse_lat = _get_monotonic_coordinate(coarse_shape[0], stop=fine_shape[0]) - full_coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + full_coarse_lat = cell_centered_coordinate(0.0, fine_shape[0], n=coarse_shape[0]) + full_coarse_lon = cell_centered_coordinate(0.0, fine_shape[1], n=coarse_shape[1]) patch_coarse_lat = full_coarse_lat[2:6].tolist() # [5, 7, 9, 11] patch_coarse_lon = full_coarse_lon[4:12].tolist() # [9, 11, ..., 23] batch = make_batch_data((2, 4, 8), patch_coarse_lat, patch_coarse_lon) @@ -534,7 +530,7 @@ def test_get_fine_coords_for_batch(): def _make_global_fine_coords_and_static(fine_shape: tuple[int, int]): """Return a global-covering LatLonCoordinates and matching StaticInputs.""" global_fine_lon = cell_centered_coordinate(0.0, 360.0, fine_shape[1]) - global_fine_lat = _get_monotonic_coordinate(fine_shape[0], stop=fine_shape[0]) + global_fine_lat = cell_centered_coordinate(0.0, fine_shape[0], n=fine_shape[0]) full_fine_coords = LatLonCoordinates(lat=global_fine_lat, lon=global_fine_lon) static_field = torch.arange( fine_shape[0] * fine_shape[1], dtype=torch.float32 @@ -556,7 +552,7 @@ def test_with_rolled_lon_no_roll_returns_same(): full_fine_coords=static_inputs.coords, static_inputs=static_inputs, ) - coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + coarse_lon = cell_centered_coordinate(0.0, fine_shape[1], n=coarse_shape[1]) assert model.with_rolled_lon(coarse_lon) is model @@ -587,30 +583,12 @@ def test_with_rolled_lon_shifts_coords_and_shares_weights(): rolled.static_inputs.fields[0].data, model.static_inputs.fields[0].data ) - -def test_with_rolled_lon_is_idempotent(): - """Rolling an already-rolled model with the same domain is a no-op. - - Guards against accidental double-rolling: the second roll resolves to 0 - (full rotation), so the twice-rolled model has identical coords and static - inputs to the once-rolled one. 
- """ - coarse_shape = (8, 16) - fine_shape = (16, 32) - full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) - model = _get_diffusion_model( - coarse_shape=coarse_shape, - downscale_factor=2, - full_fine_coords=full_fine_coords, - static_inputs=static_inputs, - ) - - coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) - rolled = model.with_rolled_lon(coarse_lon) + # Guards against accidental double-rolling: the second roll resolves to 0 + # (full rotation), so the twice-rolled model has identical coords and static + # inputs to the once-rolled one. twice = rolled.with_rolled_lon(coarse_lon) - assert torch.equal(twice.full_fine_coords.lon, rolled.full_fine_coords.lon) - assert rolled.static_inputs is not None and twice.static_inputs is not None + assert twice.static_inputs is not None assert torch.equal( twice.static_inputs.fields[0].data, rolled.static_inputs.fields[0].data ) @@ -654,85 +632,6 @@ def test_roll_diffusion_model_keeps_fine_aligned_to_coarse_cells(): assert torch.allclose(recentered, torch.tensor(coarse_lon), atol=1e-3) -def test_denoising_moe_predictor_with_rolled_lon_rolls_all_experts(): - """with_rolled_lon rolls every expert (keeping the shared-grid invariant).""" - from fme.downscaling.predictors.serial_denoising import DenoisingMoEPredictor - - coarse_shape = (8, 16) - fine_shape = (16, 32) - full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) - - expert0 = _get_diffusion_model( - coarse_shape=coarse_shape, - downscale_factor=2, - full_fine_coords=full_fine_coords, - static_inputs=static_inputs, - ) - expert1 = _get_diffusion_model( - coarse_shape=coarse_shape, - downscale_factor=2, - full_fine_coords=full_fine_coords, - static_inputs=static_inputs, - ) - predictor = DenoisingMoEPredictor( - experts=[expert0, expert1], - sigma_ranges=[(0.0, 0.5), (0.5, 1.0)], - num_diffusion_generation_steps=2, - churn=0.0, - ) - - coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], 
dtype=torch.float32) - rolled = predictor.with_rolled_lon(coarse_lon) - - # Every expert is a new (rolled) object; _primary stays _experts[0]. - assert rolled._primary is rolled._experts[0] - for rolled_expert, original, source in zip( - rolled._experts, predictor._experts, [expert0, expert1] - ): - assert rolled_expert is not original - # Coords are rolled... - assert rolled_expert.full_fine_coords.lon[0].item() < 0 - # ...but the raw network weights are still shared (fresh wrapper). - assert next(rolled_expert.module.parameters()) is next( - source.module.parameters() - ) - # The sigma dispatcher is rebuilt from the rolled experts, consistent with - # _experts (not left pointing at any pre-roll module). - for entry, rolled_expert in zip(rolled._dispatch_module._entries, rolled._experts): - assert entry[2] is rolled_expert.module - - # No-roll case returns self - non_neg_lon = torch.tensor([0.0, 5.0, 10.0, 15.0], dtype=torch.float32) - assert predictor.with_rolled_lon(non_neg_lon) is predictor - - -def test_denoising_moe_predictor_rejects_mismatched_expert_grids(): - """Experts on different grids are rejected at construction (shared-grid).""" - from fme.downscaling.predictors.serial_denoising import DenoisingMoEPredictor - - fine_coords_a, static_a = _make_global_fine_coords_and_static((16, 32)) - fine_coords_b, static_b = _make_global_fine_coords_and_static((16, 16)) - expert_a = _get_diffusion_model( - coarse_shape=(8, 16), - downscale_factor=2, - full_fine_coords=fine_coords_a, - static_inputs=static_a, - ) - expert_b = _get_diffusion_model( - coarse_shape=(8, 8), - downscale_factor=2, - full_fine_coords=fine_coords_b, - static_inputs=static_b, - ) - with pytest.raises(ValueError, match="metadata"): - DenoisingMoEPredictor( - experts=[expert_a, expert_b], - sigma_ranges=[(0.0, 0.5), (0.5, 1.0)], - num_diffusion_generation_steps=2, - churn=0.0, - ) - - def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( From 
4326050d135e03cae90f8ea6fa0fe1996aea6253 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 12 Jun 2026 16:00:10 -0700 Subject: [PATCH 5/9] comment clean up --- fme/downscaling/models.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 0c9b1fc46..680030549 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -754,18 +754,13 @@ def _lon_roll_amount(self, coarse_lon: torch.Tensor) -> tuple[int, float]: """ Number of positions to roll the fine grid (and the lon_start it aligns to) so the fine cells stay aligned to coarse_lon's coarse cells. - - coarse_lon is the actual coarse domain grid, so it already carries the - convention to align to. The roll is anchored on the western coarse-cell - *edge* (half a coarse cell below coarse_lon.min(), which is a cell *center*) - so the fine grid rolls by a whole number of coarse cells and its cells stay - aligned to the coarse cells. Anchoring on the center instead would roll by an - extra downscale_factor // 2 fine points, splitting the boundary coarse cell - across the seam. """ lon_start = float(coarse_lon.min()) fine_lon = self.full_fine_coords.lon fine_spacing = float(fine_lon[1] - fine_lon[0]) + # Anchor on the western coarse-cell *edge* (not its center, lon_start) so + # the roll is a whole number of coarse cells; anchoring on the center + # would split the boundary coarse cell across the seam. 
western_edge = lon_start - self.downscale_factor * fine_spacing / 2.0 return find_roll_anchor(fine_lon, western_edge), lon_start From 5d3d75b4bce01f48d6afd28389a2ca804acb19f0 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 12 Jun 2026 16:07:28 -0700 Subject: [PATCH 6/9] Add value check for static input data --- fme/downscaling/test_models.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index a609a6c8e..2bf464fe8 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -574,13 +574,23 @@ def test_with_rolled_lon_shifts_coords_and_shares_weights(): # Reconstruction wraps a fresh module around the SAME raw weights. assert rolled.module is not model.module assert next(rolled.module.parameters()) is next(model.module.parameters()) - assert not torch.equal(rolled.full_fine_coords.lon, model.full_fine_coords.lon) assert torch.all(rolled.full_fine_coords.lon[1:] > rolled.full_fine_coords.lon[:-1]) assert rolled.full_fine_coords.lon[0].item() < 0 - assert rolled.static_inputs is not None - # Compare against model.static_inputs (on-device) rather than the CPU-side original - assert not torch.equal( - rolled.static_inputs.fields[0].data, model.static_inputs.fields[0].data + + # Value-level check that coords and static data rolled together, lon-only. + # The roll amount is recovered from the coords; the static field encodes its + # original flat index, so a coord/data roll mismatch or an accidental lat + # roll changes values. Compare against model.static_inputs (on-device) + # rather than the CPU-side original. 
+ orig_lon = model.full_fine_coords.lon + rolled_lon = rolled.full_fine_coords.lon + roll = int(torch.argmin(torch.abs(orig_lon - rolled_lon[0] % 360.0)).item()) + assert roll > 0 + assert torch.allclose(rolled_lon % 360.0, torch.roll(orig_lon, -roll) % 360.0) + assert model.static_inputs is not None and rolled.static_inputs is not None + assert torch.equal( + rolled.static_inputs.fields[0].data, + torch.roll(model.static_inputs.fields[0].data, -roll, dims=-1), ) # Guards against accidental double-rolling: the second roll resolves to 0 From d3fc740b0d07bf3a6f2c1d6f5e2282855426d461 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 12 Jun 2026 16:22:13 -0700 Subject: [PATCH 7/9] More documentation tweaks --- fme/downscaling/models.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 680030549..aa2d20f11 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -754,6 +754,9 @@ def _lon_roll_amount(self, coarse_lon: torch.Tensor) -> tuple[int, float]: """ Number of positions to roll the fine grid (and the lon_start it aligns to) so the fine cells stay aligned to coarse_lon's coarse cells. + + Assumes a uniformly spaced fine grid; validated by roll_lon_coords when + the roll is applied. """ lon_start = float(coarse_lon.min()) fine_lon = self.full_fine_coords.lon @@ -774,6 +777,10 @@ def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DiffusionModel": from the static inputs and provide the right coordinates for the outputs. Returns self unchanged when coarse_lon does not cross the prime meridian. + + Intended for inference only: rebuilding wraps the module in a second + DistributedDataParallel under torch distributed, which is a hazard for + gradient-synchronized training. 
""" if not coords_require_lon_roll(coarse_lon): return self From c0603989b71970cf9349026ea389bec35de335f0 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Fri, 12 Jun 2026 16:39:33 -0700 Subject: [PATCH 8/9] Comment clenaup --- fme/downscaling/test_models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 2bf464fe8..1c4fd75f9 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -580,8 +580,7 @@ def test_with_rolled_lon_shifts_coords_and_shares_weights(): # Value-level check that coords and static data rolled together, lon-only. # The roll amount is recovered from the coords; the static field encodes its # original flat index, so a coord/data roll mismatch or an accidental lat - # roll changes values. Compare against model.static_inputs (on-device) - # rather than the CPU-side original. + # roll changes values. orig_lon = model.full_fine_coords.lon rolled_lon = rolled.full_fine_coords.lon roll = int(torch.argmin(torch.abs(orig_lon - rolled_lon[0] % 360.0)).item()) From c2b6aaac41e8db5a46bce4ac12dad8004432cb39 Mon Sep 17 00:00:00 2001 From: Andre Perkins Date: Mon, 15 Jun 2026 11:34:50 -0700 Subject: [PATCH 9/9] Add flag for rolled model, checked at train time to ensure false --- fme/downscaling/models.py | 12 +++++++++++- fme/downscaling/test_models.py | 27 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index aa2d20f11..5f2d90371 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -402,6 +402,8 @@ def __init__( self._channel_axis = -3 self.full_fine_coords = full_fine_coords.to(get_device()) self.static_inputs = static_inputs.to_device() if static_inputs else None + # Set True only by with_rolled_lon (inference only); guards train_on_batch. 
+ self._is_longitude_rolled = False self._loss_weight_tensor = _build_variable_loss_weight_tensor( config.loss_weights.output_channels, self.out_names ) @@ -495,6 +497,12 @@ def train_on_batch( optimizer: Optimization | NullOptimization, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" + if self._is_longitude_rolled: + raise RuntimeError( + "Cannot train a longitude-rolled DiffusionModel. with_rolled_lon " + "is intended for inference only; rolled models share weights and " + "would corrupt gradient synchronization under distributed training." + ) _static_inputs = self._subset_static_if_available(batch.coarse) coarse, fine = batch.coarse.data, batch.fine.data inputs_norm = self._get_input_from_coarse(coarse, _static_inputs) @@ -789,7 +797,7 @@ def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DiffusionModel": # The new model is built through the constructor (rather than a shallow copy) # so its coords are re-validated and derived state is rebuilt fresh; the raw # module is unwrapped and passed so __init__ re-wraps it exactly once. 
- return DiffusionModel( + rolled = DiffusionModel( config=self.config, module=self.module.module, normalizer=self.normalizer, @@ -809,6 +817,8 @@ def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DiffusionModel": else None ), ) + rolled._is_longitude_rolled = True + return rolled @dataclasses.dataclass diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 1c4fd75f9..348b0ac2d 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -553,6 +553,7 @@ def test_with_rolled_lon_no_roll_returns_same(): static_inputs=static_inputs, ) coarse_lon = cell_centered_coordinate(0.0, fine_shape[1], n=coarse_shape[1]) + assert model._is_longitude_rolled is False assert model.with_rolled_lon(coarse_lon) is model @@ -571,6 +572,10 @@ def test_with_rolled_lon_shifts_coords_and_shares_weights(): coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) rolled = model.with_rolled_lon(coarse_lon) + # The rolled model is flagged inference-only; the original is left unmarked. + assert model._is_longitude_rolled is False + assert rolled._is_longitude_rolled is True + # Reconstruction wraps a fresh module around the SAME raw weights. 
assert rolled.module is not model.module assert next(rolled.module.parameters()) is next(model.module.parameters()) @@ -603,6 +608,28 @@ def test_with_rolled_lon_shifts_coords_and_shares_weights(): ) +def test_train_on_batch_raises_for_rolled_model(): + """A longitude-rolled model is inference only and must reject training.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + batch_size = 2 + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + + coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) + rolled = model.with_rolled_lon(coarse_lon) + + batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) + optimization = OptimizationConfig().build(modules=[rolled.module], max_epochs=2) + with pytest.raises(RuntimeError, match="longitude-rolled"): + rolled.train_on_batch(batch, optimization) + + def test_roll_diffusion_model_keeps_fine_aligned_to_coarse_cells(): """A seam-crossing domain must roll the fine grid by whole coarse cells.