Feature: Variable Masking#1246
Conversation
The synthetic input channels produced by GlobalMeanRemoval used to be
concatenated separately in network_call, which meant the channel-mask
machinery did not produce mask channels for them — when both
include_channel_mask_inputs and append_as_input were enabled, the
extras silently lacked their accompanying mask channels.
Refactor so each GlobalMeanRemoval exposes the synthetic channels as
named tensors in normalized space: extra_channel_names returns sentinel
names (prefixed __gmr_extra__) and extras_normalized() returns
{name: tensor} after forward_transform. The stepper extends its
in_packer with the sentinel names, step_with_adjustments merges the
extras into the normalized input dict, and network_call no longer
mentions GMR at all. The channel-mask path naturally produces mask
channels for the extras (defaulting to ones via the existing fallback
in _build_channel_mask_dict).
This is a backwards-incompatible change to the network input channel
count for the combined include_channel_mask_inputs + append_as_input
setting: the count goes from 2*n_in + n_extra to 2*(n_in + n_extra),
and the channel ordering changes from [in | mask(in) | extras] to
[in | extras | mask(in) | mask(extras)]. Checkpoints using only one of
the two features are unaffected.
Changes:
- fme.core.step.global_mean_removal: add extra_channel_names and
extras_normalized(); drop get_extra_channels (replaced by the named
dict API). The cached forward/inverse state is preserved.
- fme.core.step.single_module.SingleModuleStep.__init__: extend the
input packer with the GMR's extra_channel_names so packing and the
channel-mask routines treat extras as ordinary input channels.
- fme.core.step.single_module.step_with_adjustments: merge normalized
GMR extras into the input dict before network_call.
- Tests cover the new combined include_channel_mask_inputs + GMR
extras path.
GlobalMeanRemoval used to cache per-step state (offsets / shifts and
the normalized extras dict) on the instance between forward_transform
and inverse_transform / extras_normalized. Correctness then depended on
call order — interleaved or skipped calls silently produced wrong
outputs.
Refactor so forward_transform returns an opaque GlobalMeanRemovalState
that callers thread through to inverse_transform and
extras_normalized. The instance carries no per-step state; the
RuntimeError("called before forward_transform") guard is no longer
needed because the contract is enforced by the type signature.
Changes:
- fme.core.step.global_mean_removal: add GlobalMeanRemovalState
dataclass; forward_transform returns (TensorDict, state);
inverse_transform and extras_normalized take state. Remove the
per-instance cached attributes.
- fme.core.step.single_module.step_with_adjustments: capture state
from forward_transform and pass it back to extras_normalized and
inverse_transform.
- Tests cover the new state-threading contract with interleaved
forward/forward/inverse/inverse round-trips for both shared and
per-channel variants.
| total number of input channels. | ||
| """ | ||
|
|
||
| min_vars: int | str = "min" |
There was a problem hiding this comment.
I think 1 is a more reasonable default. Also, no need for "min" when we have 0.
| min_vars: int | str = "min" | |
| min_vars: int = 1 |
There was a problem hiding this comment.
Done — changed to min_vars: int = 1 and removed the "min" string handling entirely (updated __post_init__ and _sample_uniform).
| per_variable: Independent per-channel Bernoulli masking config. | ||
| """ | ||
|
|
||
| uniform: UniformMaskingConfig | None = None |
There was a problem hiding this comment.
Issue: this configuration typing doesn't imply only one type is allowed.
Suggestion: Put a kind: Literal["uniform"] = "uniform" attribute on UniformMaskingConfig, and then use it directly as the type. Similarly put a kind: Literal["per_variable"] = "per_variable" on the PerVariableMaskingConfig.
This will have a secondary benefit that if for some reason we want to remove support for one of these masking types, it will only invalidate checkpoints that use that type, instead of all checkpoints trained with that code.
There was a problem hiding this comment.
Done — added kind: Literal["uniform"] = "uniform" to UniformMaskingConfig and kind: Literal["per_variable"] = "per_variable" to PerVariableMaskingConfig. Moved sample_mask onto each sub-config and replaced the VariableMaskingConfig wrapper dataclass with a type alias UniformMaskingConfig | PerVariableMaskingConfig. Call sites now use the sub-configs directly (e.g. PerVariableMaskingConfig(rate=0.5)), and dacite discriminates correctly in strict mode.
| ) | ||
| assert ( | ||
| differs | ||
| ), "rate=1.0 should zero inputs and mask indicators, producing different outputs" |
There was a problem hiding this comment.
Claude: this test (and test_input_dropout_with_channel_mask_inputs_zeroes_gmr_indicators below) does not actually isolate the mask-indicator zeroing branch. With data_mask=None, _build_channel_mask_dict returns all-ones regardless, so the input-channel zeroing alone is sufficient to make rate=1.0 differ from rate=0.0 — the assertion would still pass if the mask_tensor = mask_tensor * channel_mask... line were deleted. Either pass a non-trivial data_mask, inspect the packed network input directly, or trim the docstring to not overclaim.
There was a problem hiding this comment.
Done — trimmed the docstrings to not overclaim. Both tests now just assert that rate=1.0 produces different outputs from rate=0.0 (without claiming to isolate the mask-indicator zeroing path). Also renamed test_input_dropout_with_channel_mask_inputs_zeroes_gmr_indicators → test_input_dropout_with_channel_mask_inputs_and_gmr.
|
Claude: PR description is misleading — it credits this PR with the GMR stateless refactor and |
|
Re the PR description, note I'll merge that branch into main before this PR gets merged, and then this PR will get rebased from main and merged. So the PR description should only be for the part on this branch. |
|
Updated the PR description to only describe what's on this branch: variable masking support in |
|
Will do a full code review once I have the base PRs reviewed and merged, and then this PR merges any of those updates in. |
86f6a88 to
816dd67
Compare
| min_vars: int = 1 | ||
| max_vars: int | str = "max" |
There was a problem hiding this comment.
Issue (required): When I looked at this code, I immediately filled in that min_vars would mean the minimum number of variables we're working with, not the minimum number of masked variables. Also, I kind of expect that we should always be including the option of not masking a variable, meaning min_vars doesn't need to be configurable (and should always be set to zero, not 1).
Suggestion (optional): rename max_vars to max_masked_vars, and remove min_vars (set it to 0 always).
| forcing_data: BatchData, | ||
| n_forward_steps: int, | ||
| optimizer: OptimizationABC, | ||
| n_ensemble: int | None = None, |
There was a problem hiding this comment.
Question: n_ensemble is already an attribute on both initial_condition and forcing_data, why does it need to be passed separately? Can we revert the changes to this file?
| labels: BatchLabels | None = None, | ||
| data_mask: TensorMapping | None = None, | ||
| stepper_state: StepperState | None = None, | ||
| n_ensemble: int = 1, |
There was a problem hiding this comment.
Issue: We've been intentionally hiding the ensemble-ness of the input from the Step object up to this point. I hadn't realized that this feature is going to require breaking that contract.
Option 1: Find a way to avoid passing n_ensemble to a step, for example by refactoring the data channel dropout concern into a different location that already has access to the BatchData containing the n_ensemble (hard), or by deciding we're just going to always dropout the channels uniformly across all batch members (easy, but requires scientific re-evaluation at least for your "best" case).
Option 2: If we are going to pass this data, it should be done by passing TensorMapping that contain an ensemble dimension, instead of adding a low-level contract that the batch dimension has a certain striding pattern and this integer indicates its interpretation. This is a fundamental refactor to all of our Step concepts, and I don't love it because it's training-specific - even when we need this in training, it's not something we need during post-training inference.
If using uniform-across-batch dropout isn't feasible, I'd explore whether we can more fundamentally refactor out this step-specific training concern into a layer that knows about ensembles before refactoring all Step TensorMappings to have an ensemble dimension.
There was a problem hiding this comment.
If we go uniform-across-batch, make sure that we coordinate using dist cross-gpu so that the behavior doesn't change depending on GPU decomposition, which requires a @pytest.mark.parallel test.
Introduce a synthetic input-dropout presence mask threaded through StepArgs and applied late inside SingleModuleStep.network_call, separate from the real data_mask. Adds the make_input_dropout_mask sampling hook (train-mode guarded, keyed by packed input channel names incl. GMR extras) and the non-random has_input_dropout introspection hook to StepABC, delegated through MultiCallStep and Stepper. The late application and AND-combined channel-mask indicator are inert until a caller supplies StepArgs.input_dropout_mask; the existing in-step sampler is left in place and removed in a follow-up commit.
Generate the synthetic input-dropout mask in TrainStepper._accumulate_loss on the pre-broadcast batch via the make_input_dropout_mask hook, repeat-interleave it to the folded ensemble batch, and pass it through predict_generator as StepArgs.input_dropout_mask. Ensemble members of a base sample share one mask without exposing n_ensemble to the Step. Delete the in-step dropout sampler and the n_ensemble plumbing that threaded through StepArgs, predict_generator/get_prediction_generator, and the coupled stepper (keeping BatchData/PrognosticState ensemble metadata). Dropout is now applied late and deterministically; the loss still scores predictions made from dropped inputs. Coupled training fails loud if a component configures input_dropout, since that route never samples the mask. Resolves the args.py:43 and single_module.py:1066 review comments.
Now that no Step-facing caller needs ensemble-aware sampling, sample_mask returns a plain [batch, n_channels] mask; drop the n_ensemble parameter and the _repeat_ensemble_mask / _get_base_batch_size helpers. Rename UniformMaskingConfig.max_vars to max_masked_vars and drop min_vars, so the minimum masked-variable count is now 0 (a window may have none masked). This training-only config has no inference backwards-compatibility concern, so no alias/deprecation shims are added. Docstrings consistently say 'masked variables'.
The per-sample dropout mask has no spatial dimension, so under spatial/model parallelism each tile of a sample must receive an identical mask; otherwise co-ranks' independent torch.rand draws would mask different channels and corrupt the sample. Add a Distributed.broadcast_spatial primitive (identity for non_distributed and data-parallel torch_distributed; broadcasts over the spatial group only for model_torch_distributed) and call it inside the make_input_dropout_mask hook before returning. Add parallel tests for the primitive and for mask agreement across spatial tiles under RNG divergence.
# Conflicts: # fme/core/step/multi_call.py # fme/core/step/test_step.py
Adds training-time input channel dropout (variable masking) to
SingleModuleStep.The motivation is to study how randomly withholding input channels during training affects model robustness and generalization — particularly whether a model trained with masking can perform inference when some inputs are unavailable.
The masking is sampled in the ensemble-aware training layer and applied as a concrete, late input mask inside the Step. This keeps the Step ensemble-agnostic (no
n_ensembleis threaded intoStepArgs/Step, resolving the PR #1246 review comments atargs.py:43andsingle_module.py:1066). Synthetic dropout is kept separate from the realdata_mask(which marks genuinely-absent variables consumed by preprocessing and loss masking): predictions made from dropped inputs are still scored by the loss; only truly-absent target variables are excluded.Changes:
fme.core.var_masking.UniformMaskingConfig/PerVariableMaskingConfig: two masking strategies — uniform count-based and per-channel Bernoulli, selectable via akinddiscriminator.sample_maskreturns a plain[batch, n_channels]mask (no ensemble awareness).fme.core.step.SingleModuleStepConfig.input_dropout: new optional field; when set, randomly zeros a subset of packed input channels per sample during training (no-op at eval time). Added only toSingleModuleStepConfig; sibling configs (FCN3StepConfig,SecondaryModuleStepConfig,MultiCallStepConfig,SeparateRadiationStepConfig) are untouched. The config is serialized but inert at inference.fme.core.step.args.StepArgs.input_dropout_mask: synthetic training-only input presence mask, keyed by the Step's packed input channel names; preserved through input processing and applied late insidenetwork_call(after GMR/normalization), separate fromdata_mask.fme.core.step.StepABC.make_input_dropout_mask(mode-guarded sampling hook) andhas_input_dropout(non-random introspection); delegated throughMultiCallStepandfme.ace.stepper.Stepper.SingleModuleStepsamples overin_packer.names(incl. GMR extra channels, which remain independently maskable).fme.ace.stepper.TrainStepper._accumulate_loss: samples the mask on the pre-broadcast batch via the hook, repeat-interleaves it to the folded ensemble batch (ensemble members of a base sample share one mask), and passes it throughpredict_generator.fme.coupled.CoupledTrainStepper: raises a clear error if any component step configuresinput_dropout(unsupported for coupled training — fail loud rather than silently no-op).fme.core.distributed.Distributed.broadcast_spatial: new primitive; the sampling hook broadcasts the mask across the spatial/model-parallel group so all tiles of a sample agree (identity for non-distributed and data-parallel backends).Example configs:
Bernoulli per-variable masking at 5% rate:
Uniform masking of up to 3 channels:
Behavior notes:
Dropout is now per-window (sampled once per rollout, held across forward steps), matching inference reality and the real absent-variable
data_mask— was per-step.UniformMaskingConfigminimum masked count is now 0 (min_varsdropped,max_varsrenamedmax_masked_vars); a window may have no channels dropped. This is a training-only config with no inference back-compat concern, so no alias/deprecation shims were added.Channel-mask indicator channels (
include_channel_mask_inputs) are not maskable; an indicator reflects its data channel's AND-combinedreal & syntheticpresence.Synthetic dropout no longer uses
data_mask; only real missing-variable masks exclude targets from the loss.Tests added
If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated