Feature request
Replace the static modulo check that gates old_per_token_logps recomputation in GRPOTrainer with a precise per-rollout window check.
Currently, when gradient_accumulation_steps is not a multiple of steps_per_generation × num_iterations, the trainer recomputes old_per_token_logps for every rollout — even when the specific rollout's consumption window does not cross an optimizer.step() boundary.
The proposed change replaces:
```python
if self.args.gradient_accumulation_steps % generate_every != 0:
    old_per_token_logps = recompute(...)
```
with a per-rollout time-domain check:
```python
a = self._step                                        # current micro-step
b = ((a // generate_every) + 1) * generate_every - 1  # last micro-step consuming this rollout
# Recompute only if the rollout's consumption window crosses an optimizer.step() boundary.
weights_change_mid_rollout = (a // grad_accum) != (b // grad_accum)
if weights_change_mid_rollout:
    old_per_token_logps = recompute(...)
```
This skips unnecessary forward passes when the rollout window stays within a single grad-accum bucket, while preserving identical behaviour in the default (aligned) configuration.
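The equivalence in the aligned case and the savings in the non-aligned case can be sketched with two standalone gate functions (the names `old_gate`/`new_gate` and the configs below are illustrative, not the actual trainer attributes):

```python
def old_gate(grad_accum: int, generate_every: int) -> bool:
    # Current static check: recompute for every rollout whenever the
    # config is non-aligned, independent of the step.
    return grad_accum % generate_every != 0

def new_gate(step: int, grad_accum: int, generate_every: int) -> bool:
    # Proposed per-rollout check: recompute only when the consumption
    # window [step, last] straddles an optimizer.step() boundary.
    last = ((step // generate_every) + 1) * generate_every - 1
    return (step // grad_accum) != (last // grad_accum)

# Aligned default (e.g. grad_accum=4, generate_every=2): neither gate ever fires,
# so behaviour is identical on the default config path.
assert not old_gate(4, 2)
assert not any(new_gate(s, 4, 2) for s in range(0, 24, 2))
```

In a non-aligned config (e.g. `grad_accum=3`, `generate_every=2`), `old_gate` is unconditionally true while `new_gate` fires only for the rollouts whose window actually spans two grad-accum buckets.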
Related PR: #5757
Motivation
In non-aligned configurations (where gradient_accumulation_steps is not a multiple of generate_every), the current static modulo check is over-conservative — it triggers a full _get_per_token_logps_and_entropies forward pass for every rollout, regardless of whether the model weights actually changed mid-window.
Concrete example (spg=2, num_iter=1, grad_accum=3, generate_every=2):
- Only 2 out of every 6 rollouts truly need recomputation
- The current check recomputes for all 6 — wasting 4 forward passes
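The 2-out-of-6 count can be verified with a short enumeration over one repeating period of rollouts (a direct simulation of the proposed check, not trainer code):

```python
generate_every, grad_accum = 2, 3  # generate_every = spg × num_iter = 2 × 1

recomputes = 0
for rollout in range(6):                  # 6 rollouts = 12 micro-steps
    a = rollout * generate_every          # first micro-step consuming this rollout
    b = a + generate_every - 1            # last micro-step consuming it
    if (a // grad_accum) != (b // grad_accum):
        recomputes += 1                   # window straddles an optimizer.step()

print(recomputes)  # prints 2
```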
This adds up to significant unnecessary compute in large-scale GRPO training runs with non-default accumulation settings. The fix is correctness-preserving and introduces zero regression for the default config path.
Your contribution
I have a working implementation ready and will submit a PR shortly (#5757). The change is minimal (~10 lines in grpo_trainer.py) and correctness-preserving — identical behaviour in the default config, with fewer unnecessary forward passes in non-aligned configurations. I will also include a unit test to verify the recomputation decisions for representative non-aligned settings (e.g. spg=2, num_iter=1, grad_accum=3).
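A possible shape for such a unit test, with the proposed check extracted into a hypothetical standalone helper for testability (the function and class names are illustrative):

```python
import unittest

def needs_recompute(step: int, generate_every: int, grad_accum: int) -> bool:
    """Standalone version of the proposed per-rollout window check."""
    last = ((step // generate_every) + 1) * generate_every - 1
    return (step // grad_accum) != (last // grad_accum)

class TestRecomputeDecision(unittest.TestCase):
    def test_spg2_iter1_accum3(self):
        # generate_every = spg × num_iter = 2, grad_accum = 3: only rollouts
        # whose window straddles an optimizer.step() boundary recompute.
        expected = {0: False, 2: True, 4: False, 6: False, 8: True, 10: False}
        for step, want in expected.items():
            self.assertEqual(
                needs_recompute(step, generate_every=2, grad_accum=3), want
            )
```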