diff --git a/fme/downscaling/data/__init__.py b/fme/downscaling/data/__init__.py index e11d8f662..beea9a3a3 100644 --- a/fme/downscaling/data/__init__.py +++ b/fme/downscaling/data/__init__.py @@ -22,5 +22,6 @@ coords_require_lon_roll, expand_and_fold_tensor, find_roll_anchor, + roll_lon_coords, scale_tuple, ) diff --git a/fme/downscaling/data/test_config.py b/fme/downscaling/data/test_config.py index 487d8b25b..181917c90 100644 --- a/fme/downscaling/data/test_config.py +++ b/fme/downscaling/data/test_config.py @@ -20,7 +20,7 @@ ) from fme.downscaling.data.utils import ClosedInterval from fme.downscaling.requirements import DataRequirements -from fme.downscaling.test_utils import data_paths_helper +from fme.downscaling.test_utils import cell_centered_coordinate, data_paths_helper from fme.downscaling.typing_ import FineResCoarseResPair @@ -34,14 +34,8 @@ def _write_global_nc(path: Path, n_lat: int, n_lon: int, num_timesteps: int) -> cftime.DatetimeProlepticGregorian(2000, 1, 1) + datetime.timedelta(days=i) for i in range(num_timesteps) ] - lon_spacing = 360.0 / n_lon - lat_spacing = 8.0 / n_lat - lons = np.array( - [lon_spacing / 2 + i * lon_spacing for i in range(n_lon)], dtype=np.float32 - ) - lats = np.array( - [lat_spacing / 2 + i * lat_spacing for i in range(n_lat)], dtype=np.float32 - ) + lons = cell_centered_coordinate(0.0, 360.0, n_lon).numpy() + lats = cell_centered_coordinate(0.0, 8.0, n_lat).numpy() data = ( np.broadcast_to(lons[None, None, :], (num_timesteps, n_lat, n_lon)) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 782fe5e06..5f2d90371 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -21,7 +21,10 @@ PairedBatchData, StaticInputs, adjust_fine_coord_range, + coords_require_lon_roll, + find_roll_anchor, load_coords_from_path, + roll_lon_coords, ) from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector @@ -399,6 +402,8 @@ def __init__( self._channel_axis = -3 self.full_fine_coords = full_fine_coords.to(get_device()) self.static_inputs = static_inputs.to_device() if static_inputs else None + # Set True only by with_rolled_lon (inference only); guards train_on_batch. + self._is_longitude_rolled = False self._loss_weight_tensor = _build_variable_loss_weight_tensor( config.loss_weights.output_channels, self.out_names ) @@ -492,6 +497,12 @@ def train_on_batch( optimizer: Optimization | NullOptimization, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" + if self._is_longitude_rolled: + raise RuntimeError( + "Cannot train a longitude-rolled DiffusionModel. with_rolled_lon " + "is intended for inference only; rolled models share weights and " + "would corrupt gradient synchronization under distributed training." + ) _static_inputs = self._subset_static_if_available(batch.coarse) coarse, fine = batch.coarse.data, batch.fine.data inputs_norm = self._get_input_from_coarse(coarse, _static_inputs) @@ -747,6 +758,68 @@ def metadata(self): else 0, ) + def _lon_roll_amount(self, coarse_lon: torch.Tensor) -> tuple[int, float]: + """ + Number of positions to roll the fine grid (and the lon_start it aligns to) + so the fine cells stay aligned to coarse_lon's coarse cells. + + Assumes a uniformly spaced fine grid; validated by roll_lon_coords when + the roll is applied. + """ + lon_start = float(coarse_lon.min()) + fine_lon = self.full_fine_coords.lon + fine_spacing = float(fine_lon[1] - fine_lon[0]) + # Anchor on the western coarse-cell *edge* (not its center, lon_start) so + # the roll is a whole number of coarse cells; anchoring on the center + # would split the boundary coarse cell across the seam. + western_edge = lon_start - self.downscale_factor * fine_spacing / 2.0 + return find_roll_anchor(fine_lon, western_edge), lon_start + + def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DiffusionModel": + """ + Return a new model with full_fine_coords and static_inputs rolled to match + coarse_lon's longitude convention, sharing the network weights. + + Models with rolled longitude are useful when inference region crosses + the prime meridian, where we want to ensure we can grab proper slices + from the static inputs and provide the right coordinates for the outputs. + + Returns self unchanged when coarse_lon does not cross the prime meridian. + + Intended for inference only: rebuilding wraps the module in a second + DistributedDataParallel under torch distributed, which is a hazard for + gradient-synchronized training. + """ + if not coords_require_lon_roll(coarse_lon): + return self + roll_amount, lon_start = self._lon_roll_amount(coarse_lon) + + # The new model is built through the constructor (rather than a shallow copy) + # so its coords are re-validated and derived state is rebuilt fresh; the raw + # module is unwrapped and passed so __init__ re-wraps it exactly once. + rolled = DiffusionModel( + config=self.config, + module=self.module.module, + normalizer=self.normalizer, + loss=self.loss, + coarse_shape=self.coarse_shape, + downscale_factor=self.downscale_factor, + sigma_data=self.sigma_data, + full_fine_coords=LatLonCoordinates( + lat=self.full_fine_coords.lat, + lon=roll_lon_coords(self.full_fine_coords.lon, roll_amount, lon_start), + ), + in_names=self.in_names, + out_names=self.out_names, + static_inputs=( + self.static_inputs.roll(roll_amount, lon_start) + if self.static_inputs is not None + else None + ), + ) + rolled._is_longitude_rolled = True + return rolled + @dataclasses.dataclass class _CheckpointModelConfigSelector: diff --git a/fme/downscaling/predictors/serial_denoising.py b/fme/downscaling/predictors/serial_denoising.py index 18b198e64..f52900b54 100644 --- a/fme/downscaling/predictors/serial_denoising.py +++ b/fme/downscaling/predictors/serial_denoising.py @@ -226,6 +226,26 @@ def static_inputs(self) -> StaticInputs | None: def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates: return self._primary.get_fine_coords_for_batch(batch) + def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DenoisingMoEPredictor": + """New predictor with every expert's coords rolled to match coarse_lon. + + All experts are rolled (not just the primary) so the shared-grid invariant + enforced in __init__ still holds -- nothing relies on the non-primary + experts' coordinates being left unrolled. Rebuilt through __init__ so + _dispatch_module is reconstructed from the rolled experts. Returns self + unchanged when no roll is needed. + """ + rolled = [expert.with_rolled_lon(coarse_lon) for expert in self._experts] + if all(r is e for r, e in zip(rolled, self._experts)): + return self + return DenoisingMoEPredictor( + experts=rolled, + sigma_ranges=self._sigma_ranges, + num_diffusion_generation_steps=self._num_diffusion_generation_steps, + churn=self._churn, + expert_renames=self._expert_renames, + ) + @torch.no_grad() def generate( self, diff --git a/fme/downscaling/predictors/test_serial_denoising.py b/fme/downscaling/predictors/test_serial_denoising.py index 4475f4ebd..78d10a832 100644 --- a/fme/downscaling/predictors/test_serial_denoising.py +++ b/fme/downscaling/predictors/test_serial_denoising.py @@ -5,7 +5,7 @@ import torch from fme.core.coordinates import LatLonCoordinates -from fme.downscaling.data import StaticInput, StaticInputs +from fme.downscaling.data import StaticInputs from fme.downscaling.models import CheckpointModelConfig from fme.downscaling.predictors.serial_denoising import ( DenoisingExpertCheckpointConfig, @@ -18,7 +18,7 @@ ) from fme.downscaling.test_models import ( _get_diffusion_model, - _get_monotonic_coordinate, + _make_global_fine_coords_and_static, make_fine_coords, make_paired_batch_data, ) @@ -323,19 +323,54 @@ def test_save_preserves_rename_applied_by_checkpoint_model_config(tmp_path): assert set(reqs.coarse_names) == {"renamed_x"} -def _make_global_fine_coords_and_static(fine_shape: tuple[int, int]): - """Return a global-covering LatLonCoordinates and matching StaticInputs.""" - step = 360 / fine_shape[1] - global_fine_lon = torch.arange(fine_shape[1]) * step + step / 2 - global_fine_lat = _get_monotonic_coordinate(fine_shape[0], stop=fine_shape[0]) - full_fine_coords = LatLonCoordinates(lat=global_fine_lat, lon=global_fine_lon) - static_field = torch.arange( - fine_shape[0] * fine_shape[1], dtype=torch.float32 - ).reshape(*fine_shape) - static_inputs = StaticInputs( - fields=[StaticInput(static_field)], coords=full_fine_coords +def test_denoising_moe_predictor_with_rolled_lon_rolls_all_experts(): + """with_rolled_lon rolls every expert (keeping the shared-grid invariant).""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + + expert0 = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + expert1 = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, ) - return full_fine_coords, static_inputs + predictor = DenoisingMoEPredictor( + experts=[expert0, expert1], + sigma_ranges=[(0.0, 0.5), (0.5, 1.0)], + num_diffusion_generation_steps=2, + churn=0.0, + ) + + coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) + rolled = predictor.with_rolled_lon(coarse_lon) + + # Every expert is a new (rolled) object; _primary stays _experts[0]. + assert rolled._primary is rolled._experts[0] + for rolled_expert, original, source in zip( + rolled._experts, predictor._experts, [expert0, expert1] + ): + assert rolled_expert is not original + # Coords are rolled... + assert rolled_expert.full_fine_coords.lon[0].item() < 0 + # ...but the raw network weights are still shared (fresh wrapper). + assert next(rolled_expert.module.parameters()) is next( + source.module.parameters() + ) + # The sigma dispatcher is rebuilt from the rolled experts, consistent with + # _experts (not left pointing at any pre-roll module). + for entry, rolled_expert in zip(rolled._dispatch_module._entries, rolled._experts): + assert entry[2] is rolled_expert.module + + # No-roll case returns self + non_neg_lon = torch.tensor([0.0, 5.0, 10.0, 15.0], dtype=torch.float32) + assert predictor.with_rolled_lon(non_neg_lon) is predictor def test_denoising_moe_predictor_rejects_mismatched_expert_grids(): diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 244a3c154..348b0ac2d 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -28,6 +28,7 @@ ) from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector from fme.downscaling.noise import LogNormalNoiseDistribution +from fme.downscaling.test_utils import cell_centered_coordinate from fme.downscaling.typing_ import FineResCoarseResPair @@ -99,12 +100,6 @@ def make_batch_data( return BatchData(data=data, time=time, latlon_coordinates=latlon) -def _get_monotonic_coordinate(size: int, stop: float) -> torch.Tensor: - bounds = torch.linspace(0, stop, size + 1) - coord = (bounds[:-1] + bounds[1:]) / 2 - return coord - - def make_paired_batch_data( coarse_shape: tuple[int, int], fine_shape: tuple[int, int], @@ -115,10 +110,10 @@ def make_paired_batch_data( """ lat_c, lon_c = coarse_shape lat_f, lon_f = fine_shape - fine_lat = _get_monotonic_coordinate(lat_f, stop=lat_f) - fine_lon = _get_monotonic_coordinate(lon_f, stop=lon_f) - coarse_lat = _get_monotonic_coordinate(lat_c, stop=lat_f) - coarse_lon = _get_monotonic_coordinate(lon_c, stop=lon_f) + fine_lat = cell_centered_coordinate(0.0, lat_f, n=lat_f) + fine_lon = cell_centered_coordinate(0.0, lon_f, n=lon_f) + coarse_lat = cell_centered_coordinate(0.0, lat_f, n=lat_c) + coarse_lon = cell_centered_coordinate(0.0, lon_f, n=lon_c) fine = make_batch_data((batch_size, lat_f, lon_f), fine_lat, fine_lon) coarse = make_batch_data((batch_size, lat_c, lon_c), coarse_lat, coarse_lon) return PairedBatchData(fine=fine, coarse=coarse) @@ -128,8 +123,8 @@ def make_fine_coords(fine_shape: tuple[int, int]) -> LatLonCoordinates: """Create LatLonCoordinates with proper monotonic coordinates for given shape.""" lat_size, lon_size = fine_shape return LatLonCoordinates( - lat=_get_monotonic_coordinate(lat_size, stop=lat_size), - lon=_get_monotonic_coordinate(lon_size, stop=lon_size), + lat=cell_centered_coordinate(0.0, lat_size, n=lat_size), + lon=cell_centered_coordinate(0.0, lon_size, n=lon_size), ) @@ -377,8 +372,8 @@ def test_DiffusionModel_generate_on_batch_no_target(): batch_size = 2 n_generated_samples = 2 - coarse_lat = _get_monotonic_coordinate(coarse_shape[0], stop=fine_shape[0]) - coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + coarse_lat = cell_centered_coordinate(0.0, fine_shape[0], n=coarse_shape[0]) + coarse_lon = cell_centered_coordinate(0.0, fine_shape[1], n=coarse_shape[1]) coarse_batch = make_batch_data((batch_size, *coarse_shape), coarse_lat, coarse_lon) samples = model.generate_on_batch_no_target( @@ -419,8 +414,8 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): for alternative_input_shape in [(8, 8), (32, 32)]: fine_shape = tuple(dim * downscale_factor for dim in alternative_input_shape) alt_y, alt_x = alternative_input_shape - coarse_lat = _get_monotonic_coordinate(alt_y, stop=alt_y * downscale_factor) - coarse_lon = _get_monotonic_coordinate(alt_x, stop=alt_x * downscale_factor) + coarse_lat = cell_centered_coordinate(0.0, alt_y * downscale_factor, n=alt_y) + coarse_lon = cell_centered_coordinate(0.0, alt_x * downscale_factor, n=alt_x) coarse_batch = make_batch_data( (batch_size, *alternative_input_shape), coarse_lat, coarse_lon ) @@ -518,8 +513,8 @@ def test_get_fine_coords_for_batch(): ) # Build a batch covering a spatial patch: middle 4 coarse lats and 8 coarse lons. - full_coarse_lat = _get_monotonic_coordinate(coarse_shape[0], stop=fine_shape[0]) - full_coarse_lon = _get_monotonic_coordinate(coarse_shape[1], stop=fine_shape[1]) + full_coarse_lat = cell_centered_coordinate(0.0, fine_shape[0], n=coarse_shape[0]) + full_coarse_lon = cell_centered_coordinate(0.0, fine_shape[1], n=coarse_shape[1]) patch_coarse_lat = full_coarse_lat[2:6].tolist() # [5, 7, 9, 11] patch_coarse_lon = full_coarse_lon[4:12].tolist() # [9, 11, ..., 23] batch = make_batch_data((2, 4, 8), patch_coarse_lat, patch_coarse_lon) @@ -532,6 +527,147 @@ def test_get_fine_coords_for_batch(): assert torch.allclose(result.lon, expected_lon) +def _make_global_fine_coords_and_static(fine_shape: tuple[int, int]): + """Return a global-covering LatLonCoordinates and matching StaticInputs.""" + global_fine_lon = cell_centered_coordinate(0.0, 360.0, fine_shape[1]) + global_fine_lat = cell_centered_coordinate(0.0, fine_shape[0], n=fine_shape[0]) + full_fine_coords = LatLonCoordinates(lat=global_fine_lat, lon=global_fine_lon) + static_field = torch.arange( + fine_shape[0] * fine_shape[1], dtype=torch.float32 + ).reshape(*fine_shape) + static_inputs = StaticInputs( + fields=[StaticInput(static_field)], coords=full_fine_coords + ) + return full_fine_coords, static_inputs + + +def test_with_rolled_lon_no_roll_returns_same(): + """with_rolled_lon returns the original model when no roll is needed.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + static_inputs = make_static_inputs(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=static_inputs.coords, + static_inputs=static_inputs, + ) + coarse_lon = cell_centered_coordinate(0.0, fine_shape[1], n=coarse_shape[1]) + assert model._is_longitude_rolled is False + assert model.with_rolled_lon(coarse_lon) is model + + +def test_with_rolled_lon_shifts_coords_and_shares_weights(): + """with_rolled_lon: new model with rolled coords, shared network weights.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + + coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) + rolled = model.with_rolled_lon(coarse_lon) + + # The rolled model is flagged inference-only; the original is left unmarked. + assert model._is_longitude_rolled is False + assert rolled._is_longitude_rolled is True + + # Reconstruction wraps a fresh module around the SAME raw weights. + assert rolled.module is not model.module + assert next(rolled.module.parameters()) is next(model.module.parameters()) + assert torch.all(rolled.full_fine_coords.lon[1:] > rolled.full_fine_coords.lon[:-1]) + assert rolled.full_fine_coords.lon[0].item() < 0 + + # Value-level check that coords and static data rolled together, lon-only. + # The roll amount is recovered from the coords; the static field encodes its + # original flat index, so a coord/data roll mismatch or an accidental lat + # roll changes values. + orig_lon = model.full_fine_coords.lon + rolled_lon = rolled.full_fine_coords.lon + roll = int(torch.argmin(torch.abs(orig_lon - rolled_lon[0] % 360.0)).item()) + assert roll > 0 + assert torch.allclose(rolled_lon % 360.0, torch.roll(orig_lon, -roll) % 360.0) + assert model.static_inputs is not None and rolled.static_inputs is not None + assert torch.equal( + rolled.static_inputs.fields[0].data, + torch.roll(model.static_inputs.fields[0].data, -roll, dims=-1), + ) + + # Guards against accidental double-rolling: the second roll resolves to 0 + # (full rotation), so the twice-rolled model has identical coords and static + # inputs to the once-rolled one. + twice = rolled.with_rolled_lon(coarse_lon) + assert torch.equal(twice.full_fine_coords.lon, rolled.full_fine_coords.lon) + assert twice.static_inputs is not None + assert torch.equal( + twice.static_inputs.fields[0].data, rolled.static_inputs.fields[0].data + ) + + +def test_train_on_batch_raises_for_rolled_model(): + """A longitude-rolled model is inference only and must reject training.""" + coarse_shape = (8, 16) + fine_shape = (16, 32) + batch_size = 2 + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=2, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + + coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32) + rolled = model.with_rolled_lon(coarse_lon) + + batch = make_paired_batch_data(coarse_shape, fine_shape, batch_size) + optimization = OptimizationConfig().build(modules=[rolled.module], max_epochs=2) + with pytest.raises(RuntimeError, match="longitude-rolled"): + rolled.train_on_batch(batch, optimization) + + +def test_roll_diffusion_model_keeps_fine_aligned_to_coarse_cells(): + """A seam-crossing domain must roll the fine grid by whole coarse cells. + + The roll is anchored on the western coarse-cell edge, not its center. If it + anchored on the center it would roll an extra downscale_factor // 2 fine + points, leaving no fine margin below the western coarse cell -- which makes + get_fine_coords_for_batch raise -- and splitting that cell across the seam. + """ + coarse_shape = (4, 8) + fine_shape = (16, 32) + factor = 4 + full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape) + model = _get_diffusion_model( + coarse_shape=coarse_shape, + downscale_factor=factor, + full_fine_coords=full_fine_coords, + static_inputs=static_inputs, + ) + + # Four of the eight global 45-degree coarse cells, crossing the 0/360 seam and + # expressed in negative convention (physically 292.5, 337.5 and 22.5, 67.5). + # Coarse-lat centers [6, 10] are interior, leaving fine margin above and below. + coarse_lat = [6.0, 10.0] + coarse_lon = [-67.5, -22.5, 22.5, 67.5] + batch = make_batch_data( + (1, len(coarse_lat), len(coarse_lon)), coarse_lat, coarse_lon + ) + + rolled = model.with_rolled_lon(torch.tensor(coarse_lon, dtype=torch.float32)) + # Anchoring on the cell center would leave no margin and raise here. + fine_coords = rolled.get_fine_coords_for_batch(batch) + + # Each coarse cell is covered by exactly `factor` fine cells whose mean is the + # coarse-cell center -- i.e. the fine grid stayed aligned to the coarse cells. + recentered = fine_coords.lon.reshape(len(coarse_lon), factor).mean(dim=1).cpu() + assert torch.allclose(recentered, torch.tensor(coarse_lon), atol=1e-3) + + def test_checkpoint_config_topography_raises(): with pytest.raises(ValueError): CheckpointModelConfig( diff --git a/fme/downscaling/test_utils.py b/fme/downscaling/test_utils.py index 4f34a4451..b66c827bd 100644 --- a/fme/downscaling/test_utils.py +++ b/fme/downscaling/test_utils.py @@ -3,14 +3,24 @@ import cftime import numpy as np +import torch import xarray as xr from fme.downscaling.typing_ import FineResCoarseResPair +def cell_centered_coordinate(start: float, end: float, n: int) -> torch.Tensor: + """Centers of ``n`` equal-width cells spanning ``[start, end]``. + + Shared primitive for the downscaling tests' coordinate construction + (longitude rolling in particular relies on cell-centered grids). + """ + bounds = torch.linspace(start, end, n + 1) + return (bounds[:-1] + bounds[1:]) / 2 + + def _midpoints_from_count(start, end, n_mid): - width = (end - start) / n_mid - return np.linspace(start + width / 2, end - width / 2, n_mid, dtype=np.float32) + return cell_centered_coordinate(start, end, n_mid).numpy() def create_test_data_on_disk(