Add with_rolled_lon to downscaling models #1237
…1234) First in a 5-PR stack adding support for longitude domains that cross the 0/360 prime meridian in downscaling.

This standalone hardening PR moves expert grid-compatibility validation into the predictor constructor so every construction path is protected, not just the config-build path: only the primary expert's coordinates are used for input prep and output coords, so an expert built on a mismatched grid would otherwise silently downscale onto the wrong grid.

Changes:

- `fme.downscaling.predictors.serial_denoising`: move `_validate_experts_compatible` from `DenoisingMoEConfig.build` into `DenoisingMoEPredictor.__init__`, so it holds for `build`, `from_state`, and future callers (e.g. `with_rolled_lon`).
- `fme.downscaling.test_models`: add `test_denoising_moe_predictor_rejects_mismatched_expert_grids`, constructing the predictor directly with mismatched-grid experts and asserting it raises.

- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Base: `main`

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) | `refactor/moe-validate-experts-init` → `main` | Validate expert grid compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) | `feature/lon-roll-primitives` → PR1 | Add longitude roll primitives |
| [#1236](#1236) | `feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in the data layer |
| [#1237](#1237) | `feature/lon-roll-model` → PR3 | Add with_rolled_lon to models |
| [#1238](#1238) | `feature/lon-roll-integration` → PR4 | Roll the model in inference/predict/evaluator |
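The point about protecting every construction path can be illustrated with a minimal sketch. `Expert`, `MoEPredictor`, and `validate_experts_compatible` below are hypothetical stand-ins, not the real `fme` classes, and the real `_validate_experts_compatible` may check more than lat/lon equality.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class Expert:
    """Hypothetical stand-in for one expert; only its grid matters here."""

    name: str
    lat: Tuple[float, ...]
    lon: Tuple[float, ...]


def validate_experts_compatible(experts: Sequence[Expert]) -> None:
    """Raise if any expert's grid differs from the primary (first) expert's."""
    if not experts:
        raise ValueError("at least one expert is required")
    primary = experts[0]
    for expert in experts[1:]:
        if expert.lat != primary.lat or expert.lon != primary.lon:
            raise ValueError(
                f"expert {expert.name!r} is on a different grid than "
                f"primary expert {primary.name!r}"
            )


class MoEPredictor:
    """Hypothetical predictor: validation lives in __init__, not in a builder."""

    def __init__(self, experts: Sequence[Expert]):
        validate_experts_compatible(experts)
        self.experts = list(experts)

    @classmethod
    def from_state(cls, state: dict) -> "MoEPredictor":
        # An alternate construction path; it still goes through __init__,
        # so the grid check cannot be bypassed.
        return cls([Expert(**kwargs) for kwargs in state["experts"]])
```

Because `from_state` routes through `__init__`, a mismatched grid is rejected no matter which entry point built the predictor.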
) PR 2 of 5 in the prime-meridian longitude stack. Adds the pure coordinate/data rolling utilities needed to re-express a global grid in a seam-crossing domain's convention. These have no production callers yet; later PRs wire them into the data and model layers, so they are reviewable in isolation with full unit coverage. The interval-based roll only triggers when an interval actually crosses the seam (`start < 0` or `stop > 360`), so in-range intervals are a no-op and non-global grids are left untouched.

### Primitives overview (PR #1235)

These primitives are always used as a pair: `find_roll_anchor` (or `find_roll_anchor_from_interval`) computes the roll amount once; callers pass it to all subsequent `roll_lon_coords` and `roll_lon_data` calls so coordinates and field tensors shift by the same amount. Two downstream pathways use them:

- Dataset load: rolls each loaded grid into the user's configured `lon_extent` convention (PR #1236)
- Model setup: rolls the model's fine grid to match the incoming coarse batch's convention (PR #1237)

Changes:

- `fme.downscaling.data.utils`: add `ClosedInterval.finite_values`, `_requires_lon_roll`, `coords_require_lon_roll`, `find_roll_anchor`, `find_roll_anchor_from_interval`, `roll_lon_coords`, `roll_lon_data`, and private helpers `_validate_rollable_lon` and `_validate_monotonic_lon`.
- `roll_lon_coords` (1-D coordinate tensor) and `roll_lon_data` (N-D field tensor) form a parallel pair: both apply the same roll amount, but `roll_lon_coords` also remaps values to keep the result monotonically increasing, while `roll_lon_data` is a pure cyclic shift. Callers pre-compute the roll amount once via `find_roll_anchor` and pass it to both.
- `roll_latlon_coords` is not included here; it operates on a `LatLonCoordinates` struct rather than a raw tensor and belongs in the PR that first uses it.
- `fme.downscaling.data` (`__init__`): export the new roll helpers.
- `fme.downscaling.data.test_utils`: unit tests for roll amounts, seam-crossing conventions, round-trip invertibility, non-global/non-uniform rejection, and invalid input validation.

- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Base: `refactor/moe-validate-experts-init` (PR 1)

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) | `refactor/moe-validate-experts-init` → `main` | Validate expert grid compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) | `feature/lon-roll-primitives` → PR1 | Add longitude roll primitives |
| [#1236](#1236) | `feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in the data layer |
| [#1237](#1237) | `feature/lon-roll-model` → PR3 | Add with_rolled_lon to models |
| [#1238](#1238) | `feature/lon-roll-integration` → PR4 | Roll the model in inference/predict/evaluator |
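A rough sketch of the "compute the roll once, apply it to both coordinates and data" pattern described above. The `*_sketch` functions are hypothetical and use plain Python lists in place of tensors; they are not the signatures of `find_roll_anchor_from_interval`, `roll_lon_coords`, or `roll_lon_data`. The sketch assumes a global, uniformly spaced, ascending longitude axis in [0, 360) and does no input validation, and passing `start` to the coordinate roll to pick the convention is an assumption of this example.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def find_roll_shift_sketch(lon: Sequence[float], start: float, stop: float) -> int:
    """Return the right-shift (in grid points) for a seam-crossing interval."""
    if not (start < 0.0 or stop > 360.0):
        return 0  # in-range interval: rolling is a no-op
    wrapped = start % 360.0
    n = len(lon)
    # Index of the first grid point at or east of the wrapped interval start.
    first = next((i for i, v in enumerate(lon) if v >= wrapped), n)
    return (n - first) % n


def roll_lon_data_sketch(values: Sequence[T], shift: int) -> List[T]:
    """Pure cyclic shift to the right by `shift` positions."""
    values = list(values)
    if shift == 0:
        return values
    return values[-shift:] + values[:-shift]


def roll_lon_coords_sketch(lon: Sequence[float], shift: int, start: float) -> List[float]:
    """Cyclic shift plus a 360-degree remap so the result stays increasing."""
    if shift == 0:
        return list(lon)
    rolled = roll_lon_data_sketch(lon, shift)
    # Negative convention: the leading (eastern) points move down by 360.
    # >360 convention: the trailing (western) points move up by 360.
    base = -360.0 if start < 0.0 else 0.0
    return [v + base + (0.0 if k < shift else 360.0) for k, v in enumerate(rolled)]
```

For a four-point grid `[0, 90, 180, 270]` and the interval `[-100, 100]`, the shift is 1, the coordinates become `[-90, 0, 90, 180]`, and a field `["a", "b", "c", "d"]` becomes `["d", "a", "b", "c"]`, so each value still sits under its own longitude.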
7455589 to 806b5cd
PR 3 of 5 in the prime-meridian longitude stack. Applies the roll primitives (PR 2) in the data layer so a longitude interval that crosses the 0/360 seam can be subset instead of raising `NotImplementedError`. In-range intervals resolve to a zero roll and behave exactly as before.

Changes:

- `fme.downscaling.data.datasets.HorizontalSubsetDataset`: roll data and coordinates into the requested interval's convention rather than raising on wraparound.
- `fme.downscaling.data.config`: extract `_build_aligned_subset_pair`, which rolls coarse and fine lon coords into the extent's convention (`_roll_lons_to_extent_convention`) before `adjust_fine_coord_range`, so fine/coarse subselection stays aligned across the seam.
- `fme.downscaling.data.static.StaticInputs.roll`: roll static fields and their lon coordinates to match.
- `fme.downscaling.data.test_config`, `fme.downscaling.data.test_datasets`, `fme.downscaling.data.test_static`: tests for seam-crossing subsetting (negative and >360 conventions), fine/coarse scale-factor preservation across the seam (even and odd downscale factors), end-to-end paired loader with a seam-crossing extent, and `StaticInputs.roll`.

Note: surfacing the coarse grid convention on `GriddedData`/`PairedGriddedData` (`coarse_latlon_coords`) was deferred to the integration PR after review discussion.

- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Base: `feature/lon-roll-primitives` (PR 2)

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) | `refactor/moe-validate-experts-init` → `main` | Validate expert grid compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) | `feature/lon-roll-primitives` → PR1 | Add longitude roll primitives |
| [#1236](#1236) | `feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in the data layer |
| [#1237](#1237) | `feature/lon-roll-model` → PR3 | Add with_rolled_lon to models |
| [#1238](#1238) | `feature/lon-roll-integration` → PR4 | Roll the model in inference/predict/evaluator |
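The subsetting behaviour described in this PR (seam-crossing intervals are re-expressed and selected, in-range intervals behave as before) might look roughly like the sketch below. `subset_lon_sketch` is a hypothetical helper over plain lists, not `HorizontalSubsetDataset`; it assumes a global ascending longitude axis in [0, 360) and a 1-D field.

```python
from typing import List, Sequence, Tuple


def subset_lon_sketch(
    lon: Sequence[float],
    values: Sequence[float],
    start: float,
    stop: float,
) -> Tuple[List[float], List[float]]:
    """Select points with start <= lon <= stop, allowing seam-crossing intervals."""
    n = len(lon)
    if start < 0.0 or stop > 360.0:
        wrapped = start % 360.0
        # First grid point at or east of the wrapped start; n means "none".
        first = next((i for i, v in enumerate(lon) if v >= wrapped), n)
        # Negative convention shifts the eastern block down by 360;
        # the >360 convention shifts the western block up by 360.
        east_off, west_off = (-360.0, 0.0) if start < 0.0 else (0.0, 360.0)
        pairs = [(lon[i] + east_off, values[i]) for i in range(first, n)]
        pairs += [(lon[i] + west_off, values[i]) for i in range(first)]
    else:
        # In-range interval: zero roll, identical to plain subsetting.
        pairs = list(zip(lon, values))
    kept = [(c, v) for c, v in pairs if start <= c <= stop]
    return [c for c, _ in kept], [v for _, v in kept]
```

On a 30-degree grid with centers 15, 45, ..., 345, the interval `[-40, 50]` selects the coordinates `[-15, 15, 45]`, pulling the last grid point across the seam together with its data value.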
Let models re-express their grid in a seam-crossing coarse domain's longitude convention while sharing network weights:

- `DiffusionModel.with_rolled_lon` rebuilds the model through its constructor with `full_fine_coords` and `static_inputs` rolled to match the coarse grid. The roll is anchored on the western coarse-cell edge so the fine grid stays aligned to whole coarse cells. Returns `self` when no roll is needed.
- `DenoisingMoEPredictor.with_rolled_lon` rolls every expert (preserving the shared-grid invariant) and rebuilds so the sigma dispatcher is reconstructed from the rolled experts.

Adds tests for no-roll passthrough, coord shifting with shared weights, idempotency, coarse-cell alignment, and rolling all MoE experts.
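A minimal sketch of the behaviour listed above: return `self` when no roll is needed, share weights on rebuild, and anchor on the western coarse-cell edge. `TinyModel` is a hypothetical stand-in for the real model; it assumes uniformly spaced cell centers and fine cells that nest exactly inside coarse cells.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TinyModel:
    """Hypothetical stand-in for a downscaling model."""

    weights: List[float]   # shared between the original and the rolled model
    fine_lon: List[float]  # fine-grid cell centers

    def with_rolled_lon(self, coarse_lon: List[float]) -> "TinyModel":
        half = (coarse_lon[1] - coarse_lon[0]) / 2.0
        west_edge = coarse_lon[0] - half
        east_edge = coarse_lon[-1] + half
        if west_edge >= 0.0 and east_edge <= 360.0:
            return self  # coarse domain does not cross the seam
        # Anchor on the western coarse-cell EDGE (not the center) so the
        # rolled fine grid starts on a whole coarse cell.
        anchor = west_edge % 360.0
        n = len(self.fine_lon)
        first = next((i for i, v in enumerate(self.fine_lon) if v >= anchor), n)
        east, west = self.fine_lon[first:], self.fine_lon[:first]
        if west_edge < 0.0:
            rolled = [v - 360.0 for v in east] + west
        else:
            rolled = east + [v + 360.0 for v in west]
        # Rebuild through the constructor; the weights object is reused.
        return TinyModel(weights=self.weights, fine_lon=rolled)
```

With 15-degree fine cells and 30-degree coarse cells centered at -45, -15, 15, 45, the western edge is -60, so the rolled fine grid starts at -52.5 and the first coarse cell covers exactly two fine cells. Anchoring on the coarse center (-45) would instead start the fine grid at -37.5 and split that coarse cell.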
b77e2de to 39d11b0
Returns self unchanged when coarse_lon does not cross the prime meridian.

Intended for inference only: rebuilding wraps the module in a second
This will have to be tackled if we want to train with patches across the meridian.
Maybe just do evaluations with patches crossing the meridian first to see if this gap in training distribution is an issue or not.
AnnaKwa left a comment
LGTM, minor comment about checking that a rolled model is not used in training
Intended for inference only: rebuilding wraps the module in a second
DistributedDataParallel under torch distributed, which is a hazard for
gradient-synchronized training.
Can an attribute `_is_longitude_rolled` be added to the model, so that an assertion at training time would make the source of this error clear if someone tried to train with a checkpoint that had rolled coords?
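One way the suggested guard could look, as a sketch only. `RolledAwareModel` and `assert_trainable` are hypothetical names; only the `_is_longitude_rolled` attribute comes from the comment above, and where the real training loop would call such a check is an assumption.

```python
class RolledAwareModel:
    """Hypothetical model carrying a flag set by the rolling path."""

    def __init__(self, is_longitude_rolled: bool = False):
        self._is_longitude_rolled = is_longitude_rolled

    def with_rolled_lon(self) -> "RolledAwareModel":
        # A rebuilt (rolled) model marks itself as inference-only.
        return RolledAwareModel(is_longitude_rolled=True)


def assert_trainable(model: object) -> None:
    """Fail fast, with a clear message, if a rolled model reaches training."""
    if getattr(model, "_is_longitude_rolled", False):
        raise RuntimeError(
            "model has rolled longitude coordinates and is inference-only; "
            "train from the unrolled model instead"
        )
```

An explicit `raise` is used here rather than a bare `assert` so the check still runs when Python is started with `-O`.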
PR 4 of 5 in the prime-meridian longitude stack (PRs 1–3 now merged to main). Lets a model re-express its grid in a seam-crossing coarse domain's longitude convention while sharing the trained network weights, so a single checkpoint can generate over a domain expressed west of 0 or east of 360.
Changes:

- `fme.downscaling.models.DiffusionModel.with_rolled_lon`: rebuild the model through its constructor with `full_fine_coords` and `static_inputs` rolled to match the coarse grid, anchored on the western coarse-cell edge so the fine grid stays aligned to whole coarse cells; returns `self` when no roll is needed. Inference-only (rebuilding re-wraps the module under torch distributed).
- `fme.downscaling.predictors.serial_denoising.DenoisingMoEPredictor.with_rolled_lon`: roll every expert (preserving the shared-grid invariant) and rebuild so the sigma dispatcher is reconstructed from the rolled experts.
- `fme.downscaling.data` exports `roll_lon_coords` for the model layer.
- `fme.downscaling.test_models`: tests for no-roll passthrough, coord shifting with shared weights (including value-level checks that coords and static data roll together, and that a double roll is a no-op), and coarse-cell alignment for a seam-crossing domain. MoE rolling tests live in `test_serial_denoising` next to the existing grid-validation test.
- `cell_centered_coordinate` helper in `test_utils` replaces per-file midpoint-coordinate constructions (`test_models`, `test_config`); removed a test and helper in `test_models`/`test_serial_denoising` duplicated from Validate expert grid compatibility in `DenoisingMoEPredictor.__init__` #1234.

Base: `main` (PRs 1–3 of the stack merged)

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) | `refactor/moe-validate-experts-init` → `main` | Validate expert grid compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) | `feature/lon-roll-primitives` → `main` | Add longitude roll primitives |
| [#1236](#1236) | `feature/lon-roll-data-layer` → `main` | Roll seam-crossing longitudes in the data layer |
| [#1237](#1237) | `feature/lon-roll-model` → `main` | Add with_rolled_lon to models |
| [#1238](#1238) | `feature/lon-roll-integration` → PR4 | Roll the model in inference/predict/evaluator |