Skip to content

Feature: Variable Masking#1246

Open
yyexela wants to merge 22 commits into
mainfrom
feature/var_masking_simple
Open

Feature: Variable Masking#1246
yyexela wants to merge 22 commits into
mainfrom
feature/var_masking_simple

Conversation

@yyexela

@yyexela yyexela commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Adds training-time input channel dropout (variable masking) to SingleModuleStep.

The motivation is to study how randomly withholding input channels during training affects model robustness and generalization — particularly whether a model trained with masking can perform inference when some inputs are unavailable.

The masking is sampled in the ensemble-aware training layer and applied as a concrete, late input mask inside the Step. This keeps the Step ensemble-agnostic (no n_ensemble is threaded into StepArgs/Step, resolving the PR #1246 review comments at args.py:43 and single_module.py:1066). Synthetic dropout is kept separate from the real data_mask (which marks genuinely-absent variables consumed by preprocessing and loss masking): predictions made from dropped inputs are still scored by the loss; only truly-absent target variables are excluded.

Changes:

  • fme.core.var_masking.UniformMaskingConfig / PerVariableMaskingConfig: two masking strategies — uniform count-based and per-channel Bernoulli, selectable via a kind discriminator. sample_mask returns a plain [batch, n_channels] mask (no ensemble awareness).
  • fme.core.step.SingleModuleStepConfig.input_dropout: new optional field; when set, randomly zeros a subset of packed input channels per sample during training (no-op at eval time). Added only to SingleModuleStepConfig; sibling configs (FCN3StepConfig, SecondaryModuleStepConfig, MultiCallStepConfig, SeparateRadiationStepConfig) are untouched. The config is serialized but inert at inference.
  • fme.core.step.args.StepArgs.input_dropout_mask: synthetic training-only input presence mask, keyed by the Step's packed input channel names; preserved through input processing and applied late inside network_call (after GMR/normalization), separate from data_mask.
  • fme.core.step.StepABC.make_input_dropout_mask (mode-guarded sampling hook) and has_input_dropout (non-random introspection); delegated through MultiCallStep and fme.ace.stepper.Stepper. SingleModuleStep samples over in_packer.names (incl. GMR extra channels, which remain independently maskable).
  • fme.ace.stepper.TrainStepper._accumulate_loss: samples the mask on the pre-broadcast batch via the hook, repeat-interleaves it to the folded ensemble batch (ensemble members of a base sample share one mask), and passes it through predict_generator.
  • fme.coupled.CoupledTrainStepper: raises a clear error if any component step configures input_dropout (unsupported for coupled training — fail loud rather than silently no-op).
  • fme.core.distributed.Distributed.broadcast_spatial: new primitive; the sampling hook broadcasts the mask across the spatial/model-parallel group so all tiles of a sample agree (identity for non-distributed and data-parallel backends).

Example configs:

Bernoulli per-variable masking at 5% rate:

stepper:
  input_dropout:
    kind: per_variable
    rate: 0.05
  include_channel_mask_inputs: true

Uniform masking of up to 3 channels:

stepper:
  input_dropout:
    kind: uniform
    max_masked_vars: 3
  include_channel_mask_inputs: true

Behavior notes:

  • Dropout is now per-window (sampled once per rollout, held across forward steps), matching inference reality and the real absent-variable data_mask — was per-step.

  • UniformMaskingConfig minimum masked count is now 0 (min_vars dropped, max_vars renamed max_masked_vars); a window may have no channels dropped. This is a training-only config with no inference back-compat concern, so no alias/deprecation shims were added.

  • Channel-mask indicator channels (include_channel_mask_inputs) are not maskable; an indicator reflects its data channel's AND-combined real & synthetic presence.

  • Synthetic dropout no longer uses data_mask; only real missing-variable masks exclude targets from the loss.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

mcgibbon and others added 6 commits June 9, 2026 16:37
The synthetic input channels produced by GlobalMeanRemoval used to be
concatenated separately in network_call, which meant the channel-mask
machinery did not produce mask channels for them — when both
include_channel_mask_inputs and append_as_input were enabled, the
extras silently lacked their accompanying mask channels.

Refactor so each GlobalMeanRemoval exposes the synthetic channels as
named tensors in normalized space: extra_channel_names returns sentinel
names (prefixed __gmr_extra__) and extras_normalized() returns
{name: tensor} after forward_transform. The stepper extends its
in_packer with the sentinel names, step_with_adjustments merges the
extras into the normalized input dict, and network_call no longer
mentions GMR at all. The channel-mask path naturally produces mask
channels for the extras (defaulting to ones via the existing fallback
in _build_channel_mask_dict).

This is a backwards-incompatible change to the network input channel
count for the combined include_channel_mask_inputs + append_as_input
setting: the count goes from 2*n_in + n_extra to 2*(n_in + n_extra),
and the channel ordering changes from [in | mask(in) | extras] to
[in | extras | mask(in) | mask(extras)]. Checkpoints using only one of
the two features are unaffected.

Changes:
- fme.core.step.global_mean_removal: add extra_channel_names and
  extras_normalized(); drop get_extra_channels (replaced by the named
  dict API). The cached forward/inverse state is preserved.
- fme.core.step.single_module.SingleModuleStep.__init__: extend the
  input packer with the GMR's extra_channel_names so packing and the
  channel-mask routines treat extras as ordinary input channels.
- fme.core.step.single_module.step_with_adjustments: merge normalized
  GMR extras into the input dict before network_call.
- Tests cover the new combined include_channel_mask_inputs + GMR
  extras path.
GlobalMeanRemoval used to cache per-step state (offsets / shifts and
the normalized extras dict) on the instance between forward_transform
and inverse_transform / extras_normalized. Correctness then depended on
call order — interleaved or skipped calls silently produced wrong
outputs.

Refactor so forward_transform returns an opaque GlobalMeanRemovalState
that callers thread through to inverse_transform and
extras_normalized. The instance carries no per-step state; the
RuntimeError("called before forward_transform") guard is no longer
needed because the contract is enforced by the type signature.

Changes:
- fme.core.step.global_mean_removal: add GlobalMeanRemovalState
  dataclass; forward_transform returns (TensorDict, state);
  inverse_transform and extras_normalized take state. Remove the
  per-instance cached attributes.
- fme.core.step.single_module.step_with_adjustments: capture state
  from forward_transform and pass it back to extras_normalized and
  inverse_transform.
- Tests cover the new state-threading contract with interleaved
  forward/forward/inverse/inverse round-trips for both shared and
  per-channel variants.
@yyexela yyexela changed the title [WIP] Feature/var masking simple Feature/var masking simple Jun 9, 2026
@yyexela yyexela changed the title Feature/var masking simple Feature: Variable Masking Jun 9, 2026
@yyexela yyexela requested review from Arcomano1234 and mcgibbon June 9, 2026 18:44
@mcgibbon mcgibbon changed the base branch from main to refactor/global-mean-removal-stateless June 9, 2026 18:56
Comment thread fme/core/var_masking.py Outdated
total number of input channels.
"""

min_vars: int | str = "min"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 1 is a more reasonable default. Also, no need for "min" when we have 0.

Suggested change
min_vars: int | str = "min"
min_vars: int = 1

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — changed to min_vars: int = 1 and removed the "min" string handling entirely (updated __post_init__ and _sample_uniform).

Comment thread fme/core/var_masking.py Outdated
per_variable: Independent per-channel Bernoulli masking config.
"""

uniform: UniformMaskingConfig | None = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: this configuration typing doesn't imply only one type is allowed.

Suggestion: Put a kind: Literal["uniform"] = "uniform" attribute on UniformMaskingConfig, and then use it directly as the type. Similarly put a kind: Literal["per_variable"] = "per_variable" on the PerVariableMaskingConfig.

This will have a secondary benefit that if for some reason we want to remove support for one of these masking types, it will only invalidate checkpoints that use that type, instead of all checkpoints trained with that code.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — added kind: Literal["uniform"] = "uniform" to UniformMaskingConfig and kind: Literal["per_variable"] = "per_variable" to PerVariableMaskingConfig. Moved sample_mask onto each sub-config and replaced the VariableMaskingConfig wrapper dataclass with a type alias UniformMaskingConfig | PerVariableMaskingConfig. Call sites now use the sub-configs directly (e.g. PerVariableMaskingConfig(rate=0.5)), and dacite discriminates correctly in strict mode.

Comment thread fme/core/step/test_step.py
Comment thread fme/core/step/test_step.py Outdated
)
assert (
differs
), "rate=1.0 should zero inputs and mask indicators, producing different outputs"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude: this test (and test_input_dropout_with_channel_mask_inputs_zeroes_gmr_indicators below) does not actually isolate the mask-indicator zeroing branch. With data_mask=None, _build_channel_mask_dict returns all-ones regardless, so the input-channel zeroing alone is sufficient to make rate=1.0 differ from rate=0.0 — the assertion would still pass if the mask_tensor = mask_tensor * channel_mask... line were deleted. Either pass a non-trivial data_mask, inspect the packed network input directly, or trim the docstring to not overclaim.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — trimmed the docstrings to not overclaim. Both tests now just assert that rate=1.0 produces different outputs from rate=0.0 (without claiming to isolate the mask-indicator zeroing path). Also renamed test_input_dropout_with_channel_mask_inputs_zeroes_gmr_indicatorstest_input_dropout_with_channel_mask_inputs_and_gmr.

@mcgibbon

mcgibbon commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Claude: PR description is misleading — it credits this PR with the GMR stateless refactor and PerChannelGlobalMeanRemoval additions, but those land via the stacked base PR #1244 (refactor/global-mean-removal-stateless). This PR only adds variable masking, a one-line docstring on n_extra_input_channels, and a gmr_state is not None assert tightening in step_with_adjustments. A reviewer reading the description will look for GMR code that is not here. Also worth noting that input_dropout is added only to SingleModuleStepConfig; siblings (FCN3StepConfig, SecondaryModuleStepConfig, MultiCallStepConfig, SeparateRadiationStepConfig) are intentionally untouched — a one-line note would save the reviewer asking.

@mcgibbon

mcgibbon commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Re the PR description, note I'll merge that branch into main before this PR gets merged, and then this PR will get rebased from main and merged. So the PR description should only be for the part on this branch.

@yyexela

yyexela commented Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

Updated the PR description to only describe what's on this branch: variable masking support in SingleModuleStep (the GMR refactor and PerChannelGlobalMeanRemoval additions belong to the base PR #1244). Also added a note that input_dropout is intentionally added only to SingleModuleStepConfig.

@yyexela yyexela requested a review from mcgibbon June 9, 2026 19:58
@mcgibbon

mcgibbon commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Will do a full code review once I have the base PRs reviewed and merged, and then this PR merges any of those updates in.

@mcgibbon mcgibbon force-pushed the refactor/global-mean-removal-stateless branch 2 times, most recently from 86f6a88 to 816dd67 Compare June 9, 2026 22:44
Base automatically changed from refactor/global-mean-removal-stateless to main June 10, 2026 01:19
Comment thread fme/core/var_masking.py Outdated
Comment on lines +19 to +20
min_vars: int = 1
max_vars: int | str = "max"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue (required): When I looked at this code, I immediately filled in that min_vars would mean the minimum number of variables we're working with, not the minimum number of masked variables. Also, I kind of expect that we should always be including the option of not masking a variable, meaning min_vars doesn't need to be configurable (and should always be set to zero, not 1).

Suggestion (optional): rename max_vars to max_masked_vars, and remove min_vars (set it to 0 always).

Comment thread fme/ace/stepper/single_module.py Outdated
forcing_data: BatchData,
n_forward_steps: int,
optimizer: OptimizationABC,
n_ensemble: int | None = None,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: n_ensemble is already an attribute on both initial_condition and forcing_data, why does it need to be passed separately? Can we revert the changes to this file?

Comment thread fme/core/step/args.py Outdated
labels: BatchLabels | None = None,
data_mask: TensorMapping | None = None,
stepper_state: StepperState | None = None,
n_ensemble: int = 1,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: We've been intentionally hiding the ensemble-ness of the input from the Step object up to this point. I hadn't realized that this feature is going to require breaking that contract.

Option 1: Find a way to avoid passing n_ensemble to a step, for example by refactoring the data channel dropout concern into a different location that already has access to the BatchData containing the n_ensemble (hard), or by deciding we're just going to always dropout the channels uniformly across all batch members (easy, but requires scientific re-evaluation at least for your "best" case).

Option 2: If we are going to pass this data, it should be done by passing TensorMapping that contain an ensemble dimension, instead of adding a low-level contract that the batch dimension has a certain striding pattern and this integer indicates its interpretation. This is a fundamental refactor to all of our Step concepts, and I don't love it because it's training-specific - even when we need this in training, it's not something we need during post-training inference.

If using uniform-across-batch dropout isn't feasible, I'd explore whether we can more fundamentally refactor out this step-specific training concern into a layer that knows about ensembles before refactoring all Step TensorMappings to have an ensemble dimension.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we go uniform-across-batch, make sure that we coordinate using dist cross-gpu so that the behavior doesn't change depending on GPU decomposition, which requires a @pytest.mark.parallel test.

yyexela added 7 commits June 18, 2026 13:31
Introduce a synthetic input-dropout presence mask threaded through StepArgs
and applied late inside SingleModuleStep.network_call, separate from the real
data_mask. Adds the make_input_dropout_mask sampling hook (train-mode guarded,
keyed by packed input channel names incl. GMR extras) and the non-random
has_input_dropout introspection hook to StepABC, delegated through
MultiCallStep and Stepper.

The late application and AND-combined channel-mask indicator are inert until
a caller supplies StepArgs.input_dropout_mask; the existing in-step sampler is
left in place and removed in a follow-up commit.
Generate the synthetic input-dropout mask in TrainStepper._accumulate_loss on
the pre-broadcast batch via the make_input_dropout_mask hook, repeat-interleave
it to the folded ensemble batch, and pass it through predict_generator as
StepArgs.input_dropout_mask. Ensemble members of a base sample share one mask
without exposing n_ensemble to the Step.

Delete the in-step dropout sampler and the n_ensemble plumbing that threaded
through StepArgs, predict_generator/get_prediction_generator, and the coupled
stepper (keeping BatchData/PrognosticState ensemble metadata). Dropout is now
applied late and deterministically; the loss still scores predictions made from
dropped inputs. Coupled training fails loud if a component configures
input_dropout, since that route never samples the mask.

Resolves the args.py:43 and single_module.py:1066 review comments.
Now that no Step-facing caller needs ensemble-aware sampling, sample_mask
returns a plain [batch, n_channels] mask; drop the n_ensemble parameter and the
_repeat_ensemble_mask / _get_base_batch_size helpers.

Rename UniformMaskingConfig.max_vars to max_masked_vars and drop min_vars, so
the minimum masked-variable count is now 0 (a window may have none masked).
This training-only config has no inference backwards-compatibility concern, so
no alias/deprecation shims are added. Docstrings consistently say 'masked
variables'.
The per-sample dropout mask has no spatial dimension, so under spatial/model
parallelism each tile of a sample must receive an identical mask; otherwise
co-ranks' independent torch.rand draws would mask different channels and
corrupt the sample. Add a Distributed.broadcast_spatial primitive (identity for
non_distributed and data-parallel torch_distributed; broadcasts over the
spatial group only for model_torch_distributed) and call it inside the
make_input_dropout_mask hook before returning. Add parallel tests for the
primitive and for mask agreement across spatial tiles under RNG divergence.
# Conflicts:
#	fme/core/step/multi_call.py
#	fme/core/step/test_step.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants