Skip to content

Add with_rolled_lon to downscaling models#1237

Merged
frodre merged 11 commits into
mainfrom
feature/lon-roll-model
Jun 15, 2026
Merged

Add with_rolled_lon to downscaling models#1237
frodre merged 11 commits into
mainfrom
feature/lon-roll-model

Conversation

@frodre

@frodre frodre commented Jun 6, 2026

Copy link
Copy Markdown
Collaborator

PR 4 of 5 in the prime-meridian longitude stack (PRs 1–3 now merged to main). Lets a model re-express its grid in a seam-crossing coarse domain's longitude convention while sharing the trained network weights, so a single checkpoint can generate over a domain expressed west of 0 or east of 360.

Changes:

  • fme.downscaling.models.DiffusionModel.with_rolled_lon: rebuild the model through its constructor with full_fine_coords and static_inputs rolled to match the coarse grid, anchored on the western coarse-cell edge so the fine grid stays aligned to whole coarse cells; returns self when no roll is needed. Inference-only (rebuilding re-wraps the module under torch distributed).
  • fme.downscaling.predictors.serial_denoising.DenoisingMoEPredictor.with_rolled_lon: roll every expert (preserving the shared-grid invariant) and rebuild so the sigma dispatcher is reconstructed from the rolled experts.
  • fme.downscaling.data exports roll_lon_coords for the model layer.
  • fme.downscaling.test_models: tests for no-roll passthrough, coord shifting with shared weights (including value-level checks that coords and static data roll together, and that a double roll is a no-op), and coarse-cell alignment for a seam-crossing domain. MoE rolling tests live in test_serial_denoising next to the existing grid-validation test.
  • Test cleanup: shared cell_centered_coordinate helper in test_utils replaces per-file midpoint-coordinate constructions (test_models, test_config); removed a test and helper in test_models/test_serial_denoising duplicated from Validate expert grid compatibility in DenoisingMoEPredictor.__init__ #1234.
  • Tests added
  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Base: main (PRs 1–3 of the stack merged)

Stack

PR Head → Base Title Status
#1234 refactor/moe-validate-experts-initmain Validate expert grid compatibility in DenoisingMoEPredictor.__init__ merged
#1235 feature/lon-roll-primitivesmain Add longitude roll primitives merged
#1236 feature/lon-roll-data-layermain Roll seam-crossing longitudes in the data layer merged
#1237 feature/lon-roll-modelmain Add with_rolled_lon to models this PR
#1238 feature/lon-roll-integration → PR4 Roll the model in inference/predict/evaluator open

frodre added a commit that referenced this pull request Jun 8, 2026
…1234)

First in a 5-PR stack adding support for longitude domains that cross
the 0/360 prime meridian in downscaling. This standalone hardening PR
moves expert grid-compatibility validation into the predictor
constructor so every construction path is protected, not just the
config-build path: only the primary expert's coordinates are used for
input prep and output coords, so an expert built on a mismatched grid
would otherwise silently downscale onto the wrong grid.

Changes:
- `fme.downscaling.predictors.serial_denoising`: move
`_validate_experts_compatible` from `DenoisingMoEConfig.build` into
`DenoisingMoEPredictor.__init__`, so it holds for `build`, `from_state`,
and future callers (e.g. `with_rolled_lon`).
- `fme.downscaling.test_models`: add
`test_denoising_moe_predictor_rejects_mismatched_expert_grids`,
constructing the predictor directly with mismatched-grid experts and
asserting it raises.
- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and
"latest_deps_only_image.txt" file updated

Base: `main`

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) |
`refactor/moe-validate-experts-init` → `main` | Validate expert grid
compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) |
`feature/lon-roll-primitives` → PR1 | Add longitude roll primitives |
| [#1236](#1236) |
`feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in
the data layer |
| [#1237](#1237) |
`feature/lon-roll-model` → PR3 | Add with_rolled_lon to models |
| [#1238](#1238) |
`feature/lon-roll-integration` → PR4 | Roll the model in
inference/predict/evaluator |
frodre added a commit that referenced this pull request Jun 9, 2026
)

PR 2 of 5 in the prime-meridian longitude stack. Adds the pure
coordinate/data rolling utilities needed to re-express a global grid in
a seam-crossing domain's convention. These have no production callers
yet — later PRs wire them into the data and model layers — so they are
reviewable in isolation with full unit coverage. The interval-based roll
only triggers when an interval actually crosses the seam (`start < 0` or
`stop > 360`), so in-range intervals are a no-op and non-global grids
are left untouched.

Primitives overview (PR #1235)

These primitives are always used as a pair: find_roll_anchor (or
find_roll_anchor_from_interval) computes the roll amount once; callers
pass it to all subsequent roll_lon_coords and roll_lon_data so
coordinates and field tensors shift by the same amount.

Two downstream pathways use them:
- Dataset load — rolls each loaded grid into the user's configured
lon_extent convention (PR #1236)
- Model setup — rolls the model's fine grid to match the incoming coarse
batch's convention (PR #1237)

Changes:
- `fme.downscaling.data.utils`: add `ClosedInterval.finite_values`,
`_requires_lon_roll`, `coords_require_lon_roll`, `find_roll_anchor`,
`find_roll_anchor_from_interval`, `roll_lon_coords`, `roll_lon_data`,
and private helpers `_validate_rollable_lon` and
`_validate_monotonic_lon`.
- `roll_lon_coords` (1-D coordinate tensor) and `roll_lon_data` (N-D
field tensor) form a parallel pair: both apply the same roll amount, but
`roll_lon_coords` also remaps values to keep the result monotonically
increasing, while `roll_lon_data` is a pure cyclic shift. Callers
pre-compute the roll amount once via `find_roll_anchor` and pass it to
both.
- `roll_latlon_coords` is not included here; it operates on a
`LatLonCoordinates` struct rather than a raw tensor and belongs in the
PR that first uses it.
- `fme.downscaling.data` (`__init__`): export the new roll helpers.
- `fme.downscaling.data.test_utils`: unit tests for roll amounts,
seam-crossing conventions, round-trip invertibility,
non-global/non-uniform rejection, and invalid input validation.
- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and
"latest_deps_only_image.txt" file updated

Base: `refactor/moe-validate-experts-init` (PR 1)

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) |
`refactor/moe-validate-experts-init` → `main` | Validate expert grid
compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) |
`feature/lon-roll-primitives` → PR1 | Add longitude roll primitives |
| [#1236](#1236) |
`feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in
the data layer |
| [#1237](#1237) |
`feature/lon-roll-model` → PR3 | Add with_rolled_lon to models |
| [#1238](#1238) |
`feature/lon-roll-integration` → PR4 | Roll the model in
inference/predict/evaluator |
@frodre frodre force-pushed the feature/lon-roll-data-layer branch from 7455589 to 806b5cd Compare June 9, 2026 21:25
Base automatically changed from feature/lon-roll-data-layer to main June 12, 2026 20:17
frodre added a commit that referenced this pull request Jun 12, 2026
PR 3 of 5 in the prime-meridian longitude stack. Applies the roll
primitives (PR 2) in the data layer so a longitude interval that crosses
the 0/360 seam can be subset instead of raising `NotImplementedError`.
In-range intervals resolve to a zero roll and behave exactly as before.

Changes:
- `fme.downscaling.data.datasets.HorizontalSubsetDataset`: roll data and
coordinates into the requested interval's convention rather than raising
on wraparound.
- `fme.downscaling.data.config`: extract `_build_aligned_subset_pair`,
which rolls coarse and fine lon coords into the extent's convention
(`_roll_lons_to_extent_convention`) before `adjust_fine_coord_range`, so
fine/coarse subselection stays aligned across the seam.
- `fme.downscaling.data.static.StaticInputs.roll`: roll static fields
and their lon coordinates to match.
- `fme.downscaling.data.test_config`,
`fme.downscaling.data.test_datasets`,
`fme.downscaling.data.test_static`: tests for seam-crossing subsetting
(negative and >360 conventions), fine/coarse scale-factor preservation
across the seam (even and odd downscale factors), end-to-end paired
loader with a seam-crossing extent, and `StaticInputs.roll`.

Note: surfacing the coarse grid convention on
`GriddedData`/`PairedGriddedData` (`coarse_latlon_coords`) was deferred
to the integration PR after review discussion.

- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and
"latest_deps_only_image.txt" file updated

Base: `feature/lon-roll-primitives` (PR 2)

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) |
`refactor/moe-validate-experts-init` → `main` | Validate expert grid
compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) |
`feature/lon-roll-primitives` → PR1 | Add longitude roll primitives |
| [#1236](#1236) |
`feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in
the data layer |
| [#1237](#1237) |
`feature/lon-roll-model` → PR3 | Add with_rolled_lon to models |
| [#1238](#1238) |
`feature/lon-roll-integration` → PR4 | Roll the model in
inference/predict/evaluator |
Let models re-express their grid in a seam-crossing coarse domain's
longitude convention while sharing network weights:

- DiffusionModel.with_rolled_lon rebuilds the model through its constructor
  with full_fine_coords and static_inputs rolled to match the coarse grid.
  The roll is anchored on the western coarse-cell edge so the fine grid
  stays aligned to whole coarse cells. Returns self when no roll is needed.
- DenoisingMoEPredictor.with_rolled_lon rolls every expert (preserving the
  shared-grid invariant) and rebuilds so the sigma dispatcher is
  reconstructed from the rolled experts.

Adds tests for no-roll passthrough, coord shifting with shared weights,
idempotency, coarse-cell alignment, and rolling all MoE experts.
@frodre frodre force-pushed the feature/lon-roll-model branch from b77e2de to 39d11b0 Compare June 12, 2026 21:24
Comment thread fme/downscaling/models.py

Returns self unchanged when coarse_lon does not cross the prime meridian.

Intended for inference only: rebuilding wraps the module in a second

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will have to be tackled if we want to train with patches across the meridian.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just do evaluations with patches crossing the meridian first to see if this gap in training distribution is an issue or not.

@frodre frodre marked this pull request as ready for review June 12, 2026 23:38

@AnnaKwa AnnaKwa left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, minor comment about checking that a rolled model is not used in training

Comment thread fme/downscaling/models.py

Returns self unchanged when coarse_lon does not cross the prime meridian.

Intended for inference only: rebuilding wraps the module in a second

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just do evaluations with patches crossing the meridian first to see if this gap in training distribution is an issue or not.

Comment thread fme/downscaling/models.py

Intended for inference only: rebuilding wraps the module in a second
DistributedDataParallel under torch distributed, which is a hazard for
gradient-synchronized training.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can an attribute _is_longitude_rolled be added to the model so that an assertion at training time would make the source of this error would be clear if someone tried to train with a checkpoint that had rolled coords?

@frodre frodre enabled auto-merge (squash) June 15, 2026 18:38
@frodre frodre merged commit 6bb72fa into main Jun 15, 2026
7 checks passed
@frodre frodre deleted the feature/lon-roll-model branch June 15, 2026 19:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants