Optimize old_per_token_logps recomputation in GRPOTrainer: per-rollout window check instead of static modulo #5770

@wengeezhang

Description

Feature request

Replace the static modulo check that gates old_per_token_logps recomputation in GRPOTrainer with a precise per-rollout window check.

Currently, when gradient_accumulation_steps is not a multiple of steps_per_generation × num_iterations, the trainer recomputes old_per_token_logps for every rollout — even when the specific rollout's consumption window does not cross an optimizer.step() boundary.

The proposed change replaces:

if self.args.gradient_accumulation_steps % generate_every != 0:
    old_per_token_logps = recompute(...)

with a per-rollout time-domain check:

    a = self._step                                          # first step that consumes this rollout
    b = ((a // generate_every) + 1) * generate_every - 1    # last step that consumes this rollout
    grad_accum = self.args.gradient_accumulation_steps
    weights_change_mid_rollout = (a // grad_accum) != (b // grad_accum)
    if weights_change_mid_rollout:
        old_per_token_logps = recompute(...)

This skips unnecessary forward passes when the rollout window stays within a single grad-accum bucket, while preserving identical behaviour in the default (aligned) configuration.
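The equivalence in the aligned case can be checked in isolation: when gradient_accumulation_steps is a multiple of generate_every, a rollout window that starts at a multiple of generate_every can never straddle a grad-accum boundary, so the new check never fires, just like the old modulo check. A minimal standalone sketch (the numeric values are illustrative, not from the trainer):

```python
# Aligned/default config: grad_accum is a multiple of generate_every.
# The per-rollout window check should then never trigger a recomputation,
# matching the old `grad_accum % generate_every != 0` gate.
generate_every = 4   # illustrative: steps_per_generation * num_iterations
grad_accum = 8       # illustrative: a multiple of generate_every

for a in range(0, 64, generate_every):                # rollout start steps
    b = ((a // generate_every) + 1) * generate_every - 1
    # The window [a, b] stays inside a single grad-accum bucket.
    assert (a // grad_accum) == (b // grad_accum)
print("aligned config: no rollout window crosses an optimizer.step() boundary")
```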

Related PR: #5757

Motivation

In non-aligned configurations (where gradient_accumulation_steps is not a multiple of generate_every), the current static modulo check is over-conservative — it triggers a full _get_per_token_logps_and_entropies forward pass for every rollout, regardless of whether the model weights actually changed mid-window.

Concrete example (spg=2, num_iter=1, grad_accum=3, generate_every=2):

  • Only 2 out of every 6 rollouts truly need recomputation
  • The current check recomputes for all 6 — wasting 4 forward passes

This adds up to significant unnecessary compute in large-scale GRPO training runs with non-default accumulation settings. The fix is correctness-preserving and introduces zero regression for the default config path.

Your contribution

I have a working implementation ready and have submitted it as PR #5757. The change is minimal (~10 lines in grpo_trainer.py) and correctness-preserving — identical behaviour in the default config, with fewer unnecessary forward passes in non-aligned configurations. I will also include a unit test to verify the recomputation decisions for representative non-aligned settings (e.g. spg=2, num_iter=1, grad_accum=3).
