Add `use_single_component_optim` flag to `CoupledTrainStepperConfig` #1317
jpdunc23 wants to merge 1 commit
Adds an opt-in `use_single_component_optim` flag to `CoupledTrainStepperConfig`.
When enabled, each `train_on_batch` call optimizes exactly one of {ocean, atmosphere}:
a per-batch coin owned by `CoupledStepperTrainLoss` selects one realm (fair 50/50 among
realms whose loss is non-null for that batch, the sole non-null realm otherwise), and the
non-selected realm contributes zero loss and runs under `torch.no_grad()`. The selected
realm keeps its existing `n_steps` / `optimize_last_step_only` / `loss_weight` schedule —
the coin only decides eligibility. The coin is flipped once per batch right after
`sample_n_steps()`, is disabled during evaluation (full per-step metrics preserved), and
is seeded distributed-consistently so every rank makes the same decision. Default off
reproduces current behavior exactly, and training-config changes are not persisted in
inference checkpoints, so checkpoint loading is unaffected.
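As a rough illustration of the selection rule described above, the per-batch choice could look something like the following. This is a minimal sketch, not the PR's implementation: `choose_component`, `REALMS`, and the `non_null` mapping are hypothetical names, and the seed value is arbitrary. It only assumes that each rank holds a `np.random.RandomState` seeded identically.

```python
import numpy as np

REALMS = ("ocean", "atmosphere")


def choose_component(rng: np.random.RandomState, non_null: dict) -> str:
    """Return the realm to optimize for this batch (hypothetical helper)."""
    eligible = [realm for realm in REALMS if non_null[realm]]
    if not eligible:
        raise ValueError("at least one realm must have a non-null loss")
    if len(eligible) == 1:
        return eligible[0]  # sole non-null realm; no random draw is consumed
    return eligible[rng.randint(len(eligible))]  # uniform over eligible realms


# Two "ranks" seeded identically draw the same sequence of decisions.
rank0 = np.random.RandomState(1234)
rank1 = np.random.RandomState(1234)
both = {"ocean": True, "atmosphere": True}
picks0 = [choose_component(rank0, both) for _ in range(8)]
picks1 = [choose_component(rank1, both) for _ in range(8)]
```

Because the sketch skips the draw when only one realm is eligible, ranks stay in sync only if they agree on which realms are non-null for a batch; drawing unconditionally would be an alternative design.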
Changes:
- `CoupledStepperTrainLoss.__call__` now routes its optimized-step check through the
`step_is_optimized(realm, step)` wrapper (previously called the per-realm schedule
directly); `compute_loss` deliberately still calls the schedule directly to keep
eval-only diagnostic metrics for the non-selected realm
- `CoupledStepperTrainLoss.__init__` gains a `single_component_optim` parameter
- `CoupledTrainStepperConfig` gains `use_single_component_optim: bool = False`, validated
in `__post_init__` and plumbed through `_build_loss`
- [ ] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated
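The `step_is_optimized(realm, step)` wrapper mentioned in the changes list can be pictured roughly as below. This is a sketch under assumptions: `SelectiveSchedule`, `select`, and the callable-per-realm schedule representation are hypothetical stand-ins, not the actual `CoupledStepperTrainLoss` internals.

```python
from typing import Callable, Dict, Optional


class SelectiveSchedule:
    """Hypothetical wrapper gating per-realm schedules by the selected realm."""

    def __init__(self, schedules: Dict[str, Callable[[int], bool]]):
        self._schedules = schedules
        self._selected: Optional[str] = None  # None == both realms eligible

    def select(self, realm: Optional[str]) -> None:
        self._selected = realm

    def step_is_optimized(self, realm: str, step: int) -> bool:
        # A non-selected realm is never optimized for this batch.
        if self._selected is not None and realm != self._selected:
            return False
        # Otherwise defer to the realm's own schedule, unchanged.
        return self._schedules[realm](step)


schedule = SelectiveSchedule(
    {
        "ocean": lambda step: True,          # optimize every ocean step
        "atmosphere": lambda step: step == 3,  # optimize only step 3
    }
)
schedule.select("atmosphere")
```

With `select(None)` the wrapper reduces to the underlying schedules, which matches the stated goal that default-off behavior is unchanged.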
    ...
    # NEW coin state:
    self._single_component_optim: bool = single_component_optim
    self._coin_rng: np.random.RandomState | None = None  # lazy, like TimeLengthProbabilities

Can we maybe just call this _single_component_optim_rng? The "coin" terminology is somewhat random.
I also think we could probably have a separate, more general object to own the RNG, perhaps defined in fme.core and reused by TimeLengthProbabilities; otherwise, we'll just be repeating a lot of the TimeLengthProbabilities API here. Something like:

    class RandomChoice:
        def __init__(
            self,
            outcomes,  # 1-D array-like or int, same as np.random.RandomState.choice's `a` argument
            probabilities,  # 1-D array-like, same as np.random.RandomState.choice's `p` argument
        ):
            ...

        def initialize_rng(self):
            ...  # same as TimeLengthProbabilities implementation

        def seed_rng(self):
            ...  # same as TimeLengthProbabilities implementation

        def sample(self) -> int:
            ...  # same as TimeLengthProbabilities implementation

I think it makes sense to include that change as part of this PR, since this is perhaps the first time we need an RNG similar to the one in TimeLengthProbabilities. Alternatively, we could add RandomChoice and the TimeLengthProbabilities refactor in a separate PR first.
@mcgibbon please let me know what you think.
Re `_coin_rng`: yes, it implies a choice between two values, and it's possible we'd add more components in the future. Maybe `_which_component_rng`? Or even just `_rng`, does it use another one?
I would advise against a `RandomChoice` class. I don't see what benefit it provides that isn't already provided by `np.random.RandomState`. All it seems to do is restrict the class to one of its sampling methods and move the "choice" verb from the method name to the class name.
> Or even just _rng, does it use another one?

No, though `seed_step_sampler` does call `ComponentLossSchedule.seed_rng`, so I think it is best to disambiguate using `_which_component_rng`.

> I would advise against a RandomChoice class.

OK, there will be some repetition, but perhaps less than I feared, given that I don't think we actually need `seed_rng` here.
Let's call this `_component_choice_rng`.

@mcgibbon requesting your review of the plan prior to implementation. You can view the plan with nice markdown rendering here: https://github.com/ai2cm/ace/blob/a0a4f16a6ddc74fb9990dc47492b4dd8bccf1859/pr-plan.md
> - When `use_single_component_optim=True` but one component is statically null (zero
>   `loss_weight`, or `n_steps` whose max is 0), the coin degrades to always selecting the
>   sole non-null realm, i.e. the flag has no observable effect. Should `__post_init__`
>   (a) raise a `ValueError` (the flag is meaningless here), (b) emit a warning and proceed,

It should raise a ValueError.
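A possible shape for that check is sketched below. The dataclasses here are hypothetical stand-ins (`ComponentLossSketch`, `CoupledTrainSketchConfig`, `max_n_steps`, and `statically_null` are illustrative names, not fields of the real config), and the flag uses the `optimize_single_component_per_batch` spelling proposed elsewhere in this thread.

```python
from dataclasses import dataclass


@dataclass
class ComponentLossSketch:
    """Hypothetical per-component loss settings."""

    loss_weight: float = 1.0
    max_n_steps: int = 1

    @property
    def statically_null(self) -> bool:
        # Null for every batch: zero weight, or never any optimized steps.
        return self.loss_weight == 0 or self.max_n_steps == 0


@dataclass
class CoupledTrainSketchConfig:
    """Hypothetical stand-in for the coupled train stepper config."""

    ocean: ComponentLossSketch
    atmosphere: ComponentLossSketch
    optimize_single_component_per_batch: bool = False

    def __post_init__(self):
        null_realms = [
            name
            for name, cfg in (("ocean", self.ocean), ("atmosphere", self.atmosphere))
            if cfg.statically_null
        ]
        if self.optimize_single_component_per_batch and null_realms:
            raise ValueError(
                "optimize_single_component_per_batch has no effect when a "
                f"component loss is statically null: {null_realms}"
            )
```

Constructing the config with the flag enabled and, say, `ComponentLossSketch(loss_weight=0.0)` for one realm fails at construction time, which surfaces the misconfiguration before training starts.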

>   or (c) silently allow it as graceful degradation? The plan currently assumes (c); (b)
>   seems most useful for catching misconfigured experiments. Note this is distinct from the
>   always-rejected case where *both* realms are null.
> - Flag name: `use_single_component_optim` (proposed) vs `optimize_single_component_per_batch`?

`optimize_single_component_per_batch` is much clearer; I prefer it. It's ambiguous what "single component optim" would mean, e.g. FTO and FTA are arguably single-component optimization. The longer name resolves that ambiguity.

>   seems most useful for catching misconfigured experiments. Note this is distinct from the
>   always-rejected case where *both* realms are null.
> - Flag name: `use_single_component_optim` (proposed) vs `optimize_single_component_per_batch`?
>   Method name: `flip_optimization_coin()` vs an alternative. Soft proposals only; `Config`

Use a different name for this method.
Suggesting component_optimization_choice() to match the numpy RandomState.choice() convention.

> - Flag name: `use_single_component_optim` (proposed) vs `optimize_single_component_per_batch`?
>   Method name: `flip_optimization_coin()` vs an alternative. Soft proposals only; `Config`
>   suffix and `_`-private conventions hold either way.
> - Coin policy is fixed at fair 50/50 among non-null realms; loss-weight-proportional

We should use a different name than "coin". I do think this should give equal weighting to components, for now - no need for configurability past what we already have in the component loss configs.

        self._coin_rng: np.random.RandomState | None = None  # lazy, like TimeLengthProbabilities
        self._selected_realm: Literal["ocean", "atmosphere"] | None = None  # None == both eligible

    def _initialize_coin_rng(self) -> None:  # NEW, private; lazy-init, mirrors TimeLengthProbabilities.initialize_rng

Call this _init_component_choice_rng

        # )  # identical across ranks -> identical coin decisions, reproducible from run seed
        ...

    def flip_optimization_coin(self) -> None:  # NEW: re-draw selected realm; no-op when disabled

Call this component_optimization_choice

> @@ -0,0 +1,228 @@
> # Add `use_single_component_optim` flag to `CoupledTrainStepperConfig`

Replace use_single_component_optim with optimize_single_component_per_batch throughout.