Skip to content
1 change: 1 addition & 0 deletions fme/downscaling/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
coords_require_lon_roll,
expand_and_fold_tensor,
find_roll_anchor,
roll_lon_coords,
scale_tuple,
)
12 changes: 3 additions & 9 deletions fme/downscaling/data/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from fme.downscaling.data.utils import ClosedInterval
from fme.downscaling.requirements import DataRequirements
from fme.downscaling.test_utils import data_paths_helper
from fme.downscaling.test_utils import cell_centered_coordinate, data_paths_helper
from fme.downscaling.typing_ import FineResCoarseResPair


Expand All @@ -34,14 +34,8 @@ def _write_global_nc(path: Path, n_lat: int, n_lon: int, num_timesteps: int) ->
cftime.DatetimeProlepticGregorian(2000, 1, 1) + datetime.timedelta(days=i)
for i in range(num_timesteps)
]
lon_spacing = 360.0 / n_lon
lat_spacing = 8.0 / n_lat
lons = np.array(
[lon_spacing / 2 + i * lon_spacing for i in range(n_lon)], dtype=np.float32
)
lats = np.array(
[lat_spacing / 2 + i * lat_spacing for i in range(n_lat)], dtype=np.float32
)
lons = cell_centered_coordinate(0.0, 360.0, n_lon).numpy()
lats = cell_centered_coordinate(0.0, 8.0, n_lat).numpy()

data = (
np.broadcast_to(lons[None, None, :], (num_timesteps, n_lat, n_lon))
Expand Down
73 changes: 73 additions & 0 deletions fme/downscaling/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
PairedBatchData,
StaticInputs,
adjust_fine_coord_range,
coords_require_lon_roll,
find_roll_anchor,
load_coords_from_path,
roll_lon_coords,
)
from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate
from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector
Expand Down Expand Up @@ -399,6 +402,8 @@ def __init__(
self._channel_axis = -3
self.full_fine_coords = full_fine_coords.to(get_device())
self.static_inputs = static_inputs.to_device() if static_inputs else None
# Set True only by with_rolled_lon (inference only); guards train_on_batch.
self._is_longitude_rolled = False
self._loss_weight_tensor = _build_variable_loss_weight_tensor(
config.loss_weights.output_channels, self.out_names
)
Expand Down Expand Up @@ -492,6 +497,12 @@ def train_on_batch(
optimizer: Optimization | NullOptimization,
) -> ModelOutputs:
"""Performs a denoising training step on a batch of data."""
if self._is_longitude_rolled:
raise RuntimeError(
"Cannot train a longitude-rolled DiffusionModel. with_rolled_lon "
"is intended for inference only; rolled models share weights and "
"would corrupt gradient synchronization under distributed training."
)
_static_inputs = self._subset_static_if_available(batch.coarse)
coarse, fine = batch.coarse.data, batch.fine.data
inputs_norm = self._get_input_from_coarse(coarse, _static_inputs)
Expand Down Expand Up @@ -747,6 +758,68 @@ def metadata(self):
else 0,
)

def _lon_roll_amount(self, coarse_lon: torch.Tensor) -> tuple[int, float]:
    """
    Work out how far the fine longitude grid must be rolled for ``coarse_lon``.

    The fine grid spacing is read from the first two entries of
    ``self.full_fine_coords.lon``, i.e. the grid is assumed to be uniformly
    spaced (the existing documentation states that ``roll_lon_coords``
    validates this when the roll is actually applied).

    Args:
        coarse_lon: Coarse longitude coordinates; only their minimum is used.

    Returns:
        A ``(roll_amount, lon_start)`` pair: the value returned by
        ``find_roll_anchor`` for the western edge of the first coarse cell,
        and the smallest coarse longitude as a Python float.
    """
    lon_start = float(coarse_lon.min())
    fine_lon = self.full_fine_coords.lon
    fine_step = float(fine_lon[1] - fine_lon[0])
    # One coarse cell spans ``downscale_factor`` fine cells.
    coarse_step = self.downscale_factor * fine_step
    # lon_start is the *center* of the westernmost coarse cell. Anchor on that
    # cell's western *edge* instead, so the roll is a whole number of coarse
    # cells; anchoring on the center would split the boundary coarse cell
    # across the seam.
    west_edge = lon_start - coarse_step / 2.0
    roll_amount = find_roll_anchor(fine_lon, west_edge)
    return roll_amount, lon_start

def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DiffusionModel":
    """
    Return a new model with full_fine_coords and static_inputs rolled to match
    coarse_lon's longitude convention, sharing the network weights.

    Models with rolled longitude are useful when the inference region crosses
    the prime meridian, where we want to ensure we can grab proper slices
    from the static inputs and provide the right coordinates for the outputs.

    Returns self unchanged when coarse_lon does not cross the prime meridian.

    Args:
        coarse_lon: Coarse longitude coordinates of the region of interest.
            They are passed to coords_require_lon_roll to decide whether a
            roll is needed and to _lon_roll_amount to size it.

    Returns:
        self when coords_require_lon_roll(coarse_lon) is falsy; otherwise a
        newly constructed DiffusionModel whose _is_longitude_rolled flag is
        True (train_on_batch raises RuntimeError for such a model).

    Intended for inference only: rebuilding wraps the module in a second

    Copy link
    Copy Markdown
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    This will have to be tackled if we want to train with patches across the meridian.

    Copy link
    Copy Markdown
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Maybe just do evaluations with patches crossing the meridian first to see if this gap in training distribution is an issue or not.

    DistributedDataParallel under torch distributed, which is a hazard for
    gradient-synchronized training.

    Copy link
    Copy Markdown
    Contributor

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Can an attribute _is_longitude_rolled be added to the model so that an assertion at training time would make the source of this error clear if someone tried to train with a checkpoint that had rolled coords?

    """
    # Fast path: nothing to roll, so hand back this very instance.
    if not coords_require_lon_roll(coarse_lon):
        return self
    # The same (roll_amount, lon_start) pair is applied to both the fine
    # longitude coordinates and the static inputs below.
    roll_amount, lon_start = self._lon_roll_amount(coarse_lon)

    # The new model is built through the constructor (rather than a shallow copy)
    # so its coords are re-validated and derived state is rebuilt fresh; the raw
    # module is unwrapped and passed so __init__ re-wraps it exactly once.
    rolled = DiffusionModel(
        config=self.config,
        module=self.module.module,
        normalizer=self.normalizer,
        loss=self.loss,
        coarse_shape=self.coarse_shape,
        downscale_factor=self.downscale_factor,
        sigma_data=self.sigma_data,
        # Only longitude is rolled; latitude is reused as-is.
        full_fine_coords=LatLonCoordinates(
            lat=self.full_fine_coords.lat,
            lon=roll_lon_coords(self.full_fine_coords.lon, roll_amount, lon_start),
        ),
        in_names=self.in_names,
        out_names=self.out_names,
        static_inputs=(
            self.static_inputs.roll(roll_amount, lon_start)
            if self.static_inputs is not None
            else None
        ),
    )
    # Flag the copy as rolled so that train_on_batch refuses to train it.
    rolled._is_longitude_rolled = True
    return rolled


@dataclasses.dataclass
class _CheckpointModelConfigSelector:
Expand Down
20 changes: 20 additions & 0 deletions fme/downscaling/predictors/serial_denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,26 @@ def static_inputs(self) -> StaticInputs | None:
def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates:
return self._primary.get_fine_coords_for_batch(batch)

def with_rolled_lon(self, coarse_lon: torch.Tensor) -> "DenoisingMoEPredictor":
    """Build a predictor whose experts all use coordinates rolled for coarse_lon.

    Every expert's own ``with_rolled_lon`` is called, not just the primary's,
    so all experts end up on the same grid and the predictor's shared-grid
    requirement keeps holding; nothing depends on non-primary experts staying
    unrolled.

    Args:
        coarse_lon: Coarse longitude coordinates forwarded to each expert.

    Returns:
        This predictor itself when every expert returned itself (no roll was
        needed). Otherwise a new DenoisingMoEPredictor created through the
        constructor from the rolled experts, so that its _dispatch_module is
        rebuilt from them, with the current sigma ranges, number of generation
        steps, churn and expert renames carried over.
    """
    rolled_experts = []
    nothing_rolled = True
    for expert in self._experts:
        candidate = expert.with_rolled_lon(coarse_lon)
        # An expert signals "no roll needed" by returning itself.
        if candidate is not expert:
            nothing_rolled = False
        rolled_experts.append(candidate)

    if nothing_rolled:
        return self

    return DenoisingMoEPredictor(
        experts=rolled_experts,
        sigma_ranges=self._sigma_ranges,
        num_diffusion_generation_steps=self._num_diffusion_generation_steps,
        churn=self._churn,
        expert_renames=self._expert_renames,
    )

@torch.no_grad()
def generate(
self,
Expand Down
63 changes: 49 additions & 14 deletions fme/downscaling/predictors/test_serial_denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from fme.core.coordinates import LatLonCoordinates
from fme.downscaling.data import StaticInput, StaticInputs
from fme.downscaling.data import StaticInputs
from fme.downscaling.models import CheckpointModelConfig
from fme.downscaling.predictors.serial_denoising import (
DenoisingExpertCheckpointConfig,
Expand All @@ -18,7 +18,7 @@
)
from fme.downscaling.test_models import (
_get_diffusion_model,
_get_monotonic_coordinate,
_make_global_fine_coords_and_static,
make_fine_coords,
make_paired_batch_data,
)
Expand Down Expand Up @@ -323,19 +323,54 @@ def test_save_preserves_rename_applied_by_checkpoint_model_config(tmp_path):
assert set(reqs.coarse_names) == {"renamed_x"}


def _make_global_fine_coords_and_static(fine_shape: tuple[int, int]):
"""Return a global-covering LatLonCoordinates and matching StaticInputs."""
step = 360 / fine_shape[1]
global_fine_lon = torch.arange(fine_shape[1]) * step + step / 2
global_fine_lat = _get_monotonic_coordinate(fine_shape[0], stop=fine_shape[0])
full_fine_coords = LatLonCoordinates(lat=global_fine_lat, lon=global_fine_lon)
static_field = torch.arange(
fine_shape[0] * fine_shape[1], dtype=torch.float32
).reshape(*fine_shape)
static_inputs = StaticInputs(
fields=[StaticInput(static_field)], coords=full_fine_coords
def test_denoising_moe_predictor_with_rolled_lon_rolls_all_experts():
"""with_rolled_lon rolls every expert (keeping the shared-grid invariant)."""
coarse_shape = (8, 16)
fine_shape = (16, 32)
full_fine_coords, static_inputs = _make_global_fine_coords_and_static(fine_shape)

expert0 = _get_diffusion_model(
coarse_shape=coarse_shape,
downscale_factor=2,
full_fine_coords=full_fine_coords,
static_inputs=static_inputs,
)
expert1 = _get_diffusion_model(
coarse_shape=coarse_shape,
downscale_factor=2,
full_fine_coords=full_fine_coords,
static_inputs=static_inputs,
)
return full_fine_coords, static_inputs
predictor = DenoisingMoEPredictor(
experts=[expert0, expert1],
sigma_ranges=[(0.0, 0.5), (0.5, 1.0)],
num_diffusion_generation_steps=2,
churn=0.0,
)

coarse_lon = torch.tensor([-10.0, -5.0, 0.0, 5.0], dtype=torch.float32)
rolled = predictor.with_rolled_lon(coarse_lon)

# Every expert is a new (rolled) object; _primary stays _experts[0].
assert rolled._primary is rolled._experts[0]
for rolled_expert, original, source in zip(
rolled._experts, predictor._experts, [expert0, expert1]
):
assert rolled_expert is not original
# Coords are rolled...
assert rolled_expert.full_fine_coords.lon[0].item() < 0
# ...but the raw network weights are still shared (fresh wrapper).
assert next(rolled_expert.module.parameters()) is next(
source.module.parameters()
)
# The sigma dispatcher is rebuilt from the rolled experts, consistent with
# _experts (not left pointing at any pre-roll module).
for entry, rolled_expert in zip(rolled._dispatch_module._entries, rolled._experts):
assert entry[2] is rolled_expert.module

# No-roll case returns self
non_neg_lon = torch.tensor([0.0, 5.0, 10.0, 15.0], dtype=torch.float32)
assert predictor.with_rolled_lon(non_neg_lon) is predictor


def test_denoising_moe_predictor_rejects_mismatched_expert_grids():
Expand Down
Loading
Loading