Add use_single_component_optim flag to CoupledTrainStepperConfig #1317

Draft
jpdunc23 wants to merge 1 commit into main from feature/single-component-optim

@jpdunc23 jpdunc23 commented Jun 25, 2026

Member

Adds an opt-in `use_single_component_optim` flag to `CoupledTrainStepperConfig`.
When enabled, each `train_on_batch` call optimizes exactly one of {ocean, atmosphere}.
A per-batch coin owned by `CoupledStepperTrainLoss` selects one realm: a fair 50/50 draw
among the realms whose loss is non-null for that batch, or the sole non-null realm when
only one qualifies. The non-selected realm contributes zero loss and runs under
`torch.no_grad()`. The selected realm keeps its existing `n_steps` /
`optimize_last_step_only` / `loss_weight` schedule; the coin only decides eligibility.
The coin is flipped once per batch, right after `sample_n_steps()`. It is disabled during
evaluation, so full per-step metrics are preserved, and it is seeded identically on every
distributed rank so that all ranks make the same decision. With the flag off (the
default), current behavior is reproduced exactly. Training-config changes are not
persisted in inference checkpoints, so checkpoint loading is unaffected.
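
To make the selection rule concrete, here is a minimal sketch, not the PR's implementation. `choose_component` and `REALMS` are hypothetical names, the standard-library `random.Random` stands in for whatever RNG the loss object owns, and raising when no realm is eligible is an assumption of this sketch.

```python
import random
from typing import Optional, Sequence

# Hypothetical realm names, taken from the description above.
REALMS = ("ocean", "atmosphere")


def choose_component(
    rng: random.Random, non_null: Sequence[str], enabled: bool
) -> Optional[str]:
    """Return the realm to optimize for this batch.

    None means the feature is off and both realms stay eligible.
    """
    if not enabled:
        return None
    eligible = [realm for realm in REALMS if realm in non_null]
    if not eligible:
        # Assumption: a batch with no non-null realm is treated as an error.
        raise ValueError("no realm has a non-null loss for this batch")
    if len(eligible) == 1:
        # Sole non-null realm: nothing to draw.
        return eligible[0]
    # Fair draw among the non-null realms.
    return rng.choice(eligible)
```

With the flag off the function returns `None`, which matches the "both eligible" reading of the default behavior.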

Changes:
- `CoupledStepperTrainLoss.__call__` now routes its optimized-step check through the
  `step_is_optimized(realm, step)` wrapper (previously called the per-realm schedule
  directly); `compute_loss` deliberately still calls the schedule directly to keep
  eval-only diagnostic metrics for the non-selected realm
- `CoupledStepperTrainLoss.__init__` gains a `single_component_optim` parameter
- `CoupledTrainStepperConfig` gains `use_single_component_optim: bool = False`, validated
  in `__post_init__` and plumbed through `_build_loss`
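
A rough sketch of the routing described in the first bullet, under stated assumptions: `LossGateSketch`, `select`, and `schedule_is_optimized` are hypothetical names, and each per-realm schedule is modeled as a plain callable from step index to bool. Only `step_is_optimized(realm, step)` is named in this PR.

```python
from typing import Callable, Dict, Optional

# Assumption: a per-realm schedule answers "is this step optimized?".
Schedule = Callable[[int], bool]


class LossGateSketch:
    def __init__(self, schedules: Dict[str, Schedule]):
        self._schedules = schedules
        # None means both realms are eligible (feature off or evaluation).
        self._selected_realm: Optional[str] = None

    def select(self, realm: Optional[str]) -> None:
        self._selected_realm = realm

    def step_is_optimized(self, realm: str, step: int) -> bool:
        """Training path: the non-selected realm is never optimized."""
        if self._selected_realm is not None and realm != self._selected_realm:
            return False
        return self._schedules[realm](step)

    def schedule_is_optimized(self, realm: str, step: int) -> bool:
        """Diagnostic path: ask the schedule directly, ignoring the selection."""
        return self._schedules[realm](step)


gate = LossGateSketch(
    {
        "ocean": lambda step: step == 2,  # e.g. last step only, with 3 steps
        "atmosphere": lambda step: True,  # every step
    }
)
gate.select("ocean")
```

Here `gate.step_is_optimized("atmosphere", 0)` is `False` while `gate.schedule_is_optimized("atmosphere", 0)` is still `True`, which is the distinction the bullet draws between `__call__` and `compute_loss`.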

- [ ] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

@jpdunc23 jpdunc23 requested a review from mcgibbon June 25, 2026 17:38
Comment thread pr-plan.md
...
# NEW coin state:
self._single_component_optim: bool = single_component_optim
self._coin_rng: np.random.RandomState | None = None # lazy, like TimeLengthProbabilities

Member Author

Can we maybe just call this `_single_component_optim_rng`? The "coin" terminology is somewhat random.

I also think we could probably have a separate, more general object to own the RNG, perhaps defined in `fme.core` and reused by `TimeLengthProbabilities`; otherwise, we'll just be repeating a lot of the `TimeLengthProbabilities` API here. Something like:

```python
class RandomChoice:
    def __init__(
        self,
        outcomes,  # 1-D array-like or int, same as np.random.RandomState.choice's `a` argument
        probabilities,  # 1-D array-like, same as np.random.RandomState.choice's `p` argument
    ):
        ...

    def initialize_rng(self):
        ...  # same as TimeLengthProbabilities implementation

    def seed_rng(self):
        ...  # same as TimeLengthProbabilities implementation

    def sample(self) -> int:
        ...  # same as TimeLengthProbabilities implementation
```

I think it makes sense to include that change as part of this PR, since this is maybe the first time we need an RNG similar to the one in `TimeLengthProbabilities`. Alternatively, we could add `RandomChoice` and the `TimeLengthProbabilities` refactor in a separate PR first.

@mcgibbon please let me know what you think.

Contributor

Re `_coin_rng`: yes, it implies a choice between two values, and it's possible we'd add more components in the future. Maybe `_which_component_rng`? Or even just `_rng`, does it use another one?

I would advise against a `RandomChoice` class. I don't get what benefit it provides that isn't already provided by `np.random.RandomState`; all it seems to do is restrict the class to only one of its sampling methods and move the "choice" verb from the method to the class name.

Member Author

> Or even just `_rng`, does it use another one?

No, though `seed_step_sampler` does call `ComponentLossSchedule.seed_rng`, so I think it is best to disambiguate using `_which_component_rng`.

> I would advise against a `RandomChoice` class.

OK, there will be some repetition, but perhaps less than I feared, given that I don't think we actually need `seed_rng` here.

@jpdunc23 jpdunc23 Jun 25, 2026

Member Author

Let's call this `_component_choice_rng`.

@jpdunc23

Member Author

@mcgibbon requesting your review of the plan prior to implementation. You can view the plan with nice markdown rendering here: https://github.com/ai2cm/ace/blob/a0a4f16a6ddc74fb9990dc47492b4dd8bccf1859/pr-plan.md

Comment thread pr-plan.md
- When `use_single_component_optim=True` but one component is statically null (zero
`loss_weight`, or `n_steps` whose max is 0), the coin degrades to always selecting the
sole non-null realm — i.e. the flag has no observable effect. Should `__post_init__`
(a) raise a `ValueError` (the flag is meaningless here), (b) emit a warning and proceed,

Contributor

It should raise a `ValueError`.
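
If the `ValueError` option is adopted, the check could look roughly like the following. This is a sketch under assumptions: both dataclasses and their fields (`max_n_steps`, `statically_null`) are hypothetical stand-ins for the real config classes, which are not shown in this thread.

```python
from dataclasses import dataclass


@dataclass
class ComponentLossSketch:
    """Hypothetical stand-in for a per-component loss config."""

    loss_weight: float = 1.0
    max_n_steps: int = 1

    @property
    def statically_null(self) -> bool:
        # Null by configuration: zero weight, or no steps ever optimized.
        return self.loss_weight == 0 or self.max_n_steps == 0


@dataclass
class CoupledTrainConfigSketch:
    """Hypothetical stand-in for the coupled train stepper config."""

    ocean: ComponentLossSketch
    atmosphere: ComponentLossSketch
    use_single_component_optim: bool = False

    def __post_init__(self) -> None:
        null = [
            name
            for name in ("ocean", "atmosphere")
            if getattr(self, name).statically_null
        ]
        if len(null) == 2:
            # The always-rejected case: nothing would be optimized.
            raise ValueError("both components have a statically null loss")
        if self.use_single_component_optim and null:
            # The flag would have no observable effect, so reject it.
            raise ValueError(
                f"use_single_component_optim=True requires both components to "
                f"have a non-null loss, but {null[0]} is statically null"
            )
```

Failing at construction time surfaces a misconfigured experiment before any training step runs.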

Comment thread pr-plan.md
or (c) silently allow it as graceful degradation? The plan currently assumes (c); (b)
seems most useful for catching misconfigured experiments. Note this is distinct from the
always-rejected case where *both* realms are null.
- Flag name: `use_single_component_optim` (proposed) vs `optimize_single_component_per_batch`?

Contributor

`optimize_single_component_per_batch` is much clearer; I prefer it. It's ambiguous what "single component optim" would be, e.g. arguably FTO and FTA are single-component optimization. The longer name resolves that ambiguity.

Comment thread pr-plan.md
seems most useful for catching misconfigured experiments. Note this is distinct from the
always-rejected case where *both* realms are null.
- Flag name: `use_single_component_optim` (proposed) vs `optimize_single_component_per_batch`?
Method name: `flip_optimization_coin()` vs an alternative. Soft proposals only — `Config`

Contributor

Use a different name for this method.

Member Author

Suggesting `component_optimization_choice()` to match the NumPy `RandomState.choice()` convention.

Comment thread pr-plan.md
- Flag name: `use_single_component_optim` (proposed) vs `optimize_single_component_per_batch`?
Method name: `flip_optimization_coin()` vs an alternative. Soft proposals only — `Config`
suffix and `_`-private conventions hold either way.
- Coin policy is fixed at fair 50/50 among non-null realms; loss-weight-proportional

@mcgibbon mcgibbon Jun 25, 2026

Contributor

We should use a different name than "coin". I do think this should give equal weighting to components for now; there is no need for configurability beyond what we already have in the component loss configs.

Member Author

Agreed

Comment thread pr-plan.md
self._coin_rng: np.random.RandomState | None = None # lazy, like TimeLengthProbabilities
self._selected_realm: Literal["ocean", "atmosphere"] | None = None # None == both eligible

def _initialize_coin_rng(self) -> None: # NEW — private; lazy-init, mirrors TimeLengthProbabilities.initialize_rng

Member Author

Call this `_init_component_choice_rng`.

Comment thread pr-plan.md
# ) # identical across ranks -> identical coin decisions, reproducible from run seed
...

def flip_optimization_coin(self) -> None: # NEW — re-draw selected realm; no-op when disabled

Member Author

Call this component_optimization_choice
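
Putting the renames from this thread together, one possible shape for the lazily initialized RNG and the per-batch draw is sketched below. It rests on several assumptions: `ComponentChoiceSketch`, its `seed` argument, and `simulate_rank` are hypothetical, the plan excerpt does not show how the seed is derived, and the standard-library `random.Random` stands in for `np.random.RandomState`.

```python
import random
from typing import List, Optional


class ComponentChoiceSketch:
    def __init__(self, seed: int, enabled: bool = True):
        self._seed = seed  # assumption: the same value on every rank
        self._enabled = enabled
        self._component_choice_rng: Optional[random.Random] = None  # lazy
        self.selected_realm: Optional[str] = None  # None == both eligible

    def _init_component_choice_rng(self) -> None:
        # Create the RNG on first use only, so later draws continue the stream.
        if self._component_choice_rng is None:
            self._component_choice_rng = random.Random(self._seed)

    def component_optimization_choice(self) -> None:
        """Re-draw the selected realm; a no-op when the feature is disabled."""
        if not self._enabled:
            return
        self._init_component_choice_rng()
        assert self._component_choice_rng is not None
        self.selected_realm = self._component_choice_rng.choice(
            ("ocean", "atmosphere")
        )


def simulate_rank(seed: int, n_batches: int) -> List[Optional[str]]:
    """Mimic one rank drawing once per batch and recording its choices."""
    chooser = ComponentChoiceSketch(seed)
    choices: List[Optional[str]] = []
    for _ in range(n_batches):
        chooser.component_optimization_choice()
        choices.append(chooser.selected_realm)
    return choices
```

Two simulated ranks given the same seed, for example `simulate_rank(1234, 50)` called twice, produce the same sequence of realms, which is the property the plan relies on to keep ranks in agreement.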

Comment thread pr-plan.md
@@ -0,0 +1,228 @@
# Add `use_single_component_optim` flag to `CoupledTrainStepperConfig`

Member Author

Replace `use_single_component_optim` with `optimize_single_component_per_batch` throughout.


2 participants