fix(rollout): isolate per-trajectory exceptions in generate_and_rm_group #200
aoshen02 wants to merge 1 commit into
asyncio.gather(*tasks) in generate_and_rm_group had no return_exceptions=True, so a single trajectory raising an unhandled exception cancelled the whole gather and crashed the entire rollout via CancelledError (which also swallows the root cause).

This is benign for plain RLVR rollouts where generate_and_rm never raises, but agentic rollouts can raise after the custom generate() returns (e.g. trajectory token-merge / prefix-drift edge cases outside generate()'s own try/except). One bad sample took down an entire 500-instance batch.

Catch per-trajectory exceptions, log the real traceback (exc_info), and substitute an ABORTED resolved=False placeholder (same fan-out list shape) so the batch completes. Mirrors the existing _abort() sample contract; ABORTED is already skipped by the reward-model and routing-replay paths.

Same latent gap exists upstream in slime (sglang_rollout.py) and miles (inference_rollout_common.py).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
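For readers who want to see the gather behaviour in isolation, here is a minimal standalone sketch. The `trajectory` coroutine and its failure at index 1 are hypothetical stand-ins and are not taken from this repository; only `asyncio.gather` and its `return_exceptions` flag are real.

```python
import asyncio


async def trajectory(idx: int) -> int:
    """Hypothetical stand-in for one per-trajectory task."""
    await asyncio.sleep(0)
    if idx == 1:
        raise ValueError(f"trajectory {idx} failed")
    return idx


async def run_bare() -> str:
    # Default gather: the first exception propagates to the awaiting caller,
    # so the results of the healthy trajectories never reach it.
    try:
        await asyncio.gather(*(trajectory(i) for i in range(3)))
    except ValueError as exc:
        return f"batch lost: {exc}"
    return "batch ok"


async def run_isolated() -> list:
    # return_exceptions=True: failures come back as values, in task order.
    return await asyncio.gather(
        *(trajectory(i) for i in range(3)), return_exceptions=True
    )


bare_outcome = asyncio.run(run_bare())
isolated_outcome = asyncio.run(run_isolated())
```

With the flag set, the caller receives one entry per task and can decide per entry whether to keep the result or substitute a placeholder.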
Code Review
This pull request introduces per-trajectory exception isolation in generate_and_rm_group using asyncio.gather(..., return_exceptions=True) to prevent a single trajectory failure from crashing the entire rollout batch. However, when an exception is caught, the code always appends the aborted sample wrapped in a list ([sample]), which can lead to a mixed-type list and downstream crashes if fan-out is not active. The reviewer suggested dynamically detecting whether fan-out is active and wrapping the aborted sample accordingly, while also simplifying the loop unpacking.
```python
results = await asyncio.gather(*[t for _, t in pairs], return_exceptions=True)
group = []
for sample, res in zip([s for s, _ in pairs], results):
    if isinstance(res, BaseException):
        logger.error(
            "[generate_and_rm_group] trajectory crashed, isolating idx=%s: %r",
            getattr(sample, "index", "?"), res, exc_info=res,
        )
        sample.tokens = [0, 0]
        sample.response = ""
        sample.response_length = 1
        sample.loss_mask = [0]
        sample.rollout_log_probs = [0.0]
        sample.reward = 0.0
        sample.status = Sample.Status.ABORTED
        group.append([sample])
    else:
        group.append(res)
```
In standard rollouts (without fan-out), generate_and_rm returns a single Sample object, meaning group is expected to be a flat list[Sample]. However, when a trajectory raises an exception, the current implementation always appends [sample] (a list containing the sample) to group on line 516. This results in a mixed-type list (e.g., [Sample, [Sample], Sample]), which will cause downstream components (such as reward models or filters) to crash with AttributeError or TypeError when they attempt to process the group.
To prevent this, we should dynamically detect whether fan-out is being used by checking whether any of the successful results are lists, and wrap the aborted sample in a list only if fan-out is active. Additionally, we can simplify the loop by unpacking pairs directly instead of building an intermediate list with a list comprehension.
```python
results = await asyncio.gather(*[t for _, t in pairs], return_exceptions=True)
is_fanout = any(isinstance(res, list) for res in results if not isinstance(res, BaseException))
group = []
for (sample, _), res in zip(pairs, results):
    if isinstance(res, BaseException):
        logger.error(
            "[generate_and_rm_group] trajectory crashed, isolating idx=%s: %r",
            getattr(sample, "index", "?"), res, exc_info=res,
        )
        sample.tokens = [0, 0]
        sample.response = ""
        sample.response_length = 1
        sample.loss_mask = [0]
        sample.rollout_log_probs = [0.0]
        sample.reward = 0.0
        sample.status = Sample.Status.ABORTED
        group.append([sample] if is_fanout else sample)
    else:
        group.append(res)
```
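The shape rule in the suggestion can be exercised without the rollout stack. The sketch below is an assumption-laden reduction: `Sample`, `Status`, and `build_group` are hypothetical stand-ins written for this example, and only the `is_fanout` detection and the conditional wrapping mirror the suggested change.

```python
from dataclasses import dataclass
from enum import Enum


class Status(Enum):  # hypothetical stand-in for Sample.Status
    COMPLETED = "completed"
    ABORTED = "aborted"


@dataclass
class Sample:  # hypothetical, reduced to the fields needed here
    index: int
    status: Status = Status.COMPLETED
    reward: float = 1.0


def build_group(samples, results):
    """Assemble a group whose entries all share one shape."""
    # Fan-out is inferred from the successful results only.
    is_fanout = any(
        isinstance(r, list) for r in results if not isinstance(r, BaseException)
    )
    group = []
    for sample, res in zip(samples, results):
        if isinstance(res, BaseException):
            sample.status = Status.ABORTED
            sample.reward = 0.0
            # Wrap the placeholder only when the healthy results are lists.
            group.append([sample] if is_fanout else sample)
        else:
            group.append(res)
    return group


samples = [Sample(0), Sample(1), Sample(2)]
flat = build_group(samples, [samples[0], RuntimeError("boom"), samples[2]])

fan_samples = [Sample(0), Sample(1)]
fanned = build_group(fan_samples, [[fan_samples[0]], RuntimeError("boom")])

all_failed = build_group([Sample(0)], [RuntimeError("boom")])
```

One property of this detection worth keeping in mind: when every trajectory in a group fails, there is no successful result to inspect, so `is_fanout` is `False` and the placeholders come back unwrapped even if fan-out was configured.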
Problem

`generate_and_rm_group` gathers per-trajectory tasks with a bare `asyncio.gather(*tasks)` (no `return_exceptions=True`). If any single trajectory raises an unhandled exception, gather cancels the siblings and propagates, crashing the entire rollout via `CancelledError`, which also swallows the root exception (logs show only CancelledError, not what actually failed).

This is benign for plain RLVR rollouts (where `generate_and_rm` reliably catches its own errors), which is why it has not surfaced before. But agentic rollouts can raise after the custom `generate()` returns, e.g. trajectory token-merge / prefix-drift edge cases that run outside `generate()`'s own try/except. Observed on a 500-instance SWE-bench eval: a single bad trajectory (~1 in 350-400) reproducibly took down all 500 (crashed ~321 and ~373 on two runs), with only a `CancelledError` in the logs.

Fix

Catch per-trajectory exceptions at the group gather:

- `return_exceptions=True` so one failure no longer cancels the batch.
- `logger.error(..., exc_info=res)` to surface the real traceback (currently swallowed by CancelledError).
- `ABORTED` / `resolved=False` placeholder with the same fan-out list shape, reusing the existing `_abort()` sample contract (`tokens=[0,0]`, `loss_mask=[0]`, `status=ABORTED`, `reward=0.0`).

`ABORTED` is already a first-class status that downstream short-circuits (reward-model skip at `generate_and_rm`, routing-replay skip), so the placeholder introduces no new sample shape: it is identical to what every timeout / missing-image abort already produces.

Notes

- Same latent gap exists upstream in slime (`slime/rollout/sglang_rollout.py`) and miles (`inference_rollout_common.py`); both have the identical bare gather, so it is worth porting there too.
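The logging item in the Fix section depends on `exc_info` accepting an exception instance, since a result returned by gather is handled outside any `except` block. A minimal sketch of that behaviour follows; the logger name, the `failing_step` helper, and the index value 7 are hypothetical and chosen only for illustration.

```python
import io
import logging

stream = io.StringIO()
logger = logging.getLogger("rollout_demo")  # hypothetical logger name
logger.setLevel(logging.ERROR)
logger.propagate = False
logger.addHandler(logging.StreamHandler(stream))


def failing_step():
    raise KeyError("prefix drift")


try:
    failing_step()
except KeyError as exc:
    captured = exc  # mimics an exception object returned by gather

# Outside the except block there is no "current" exception, so the instance
# itself is passed via exc_info; logging formats its attached traceback.
logger.error(
    "trajectory crashed, isolating idx=%s: %r", 7, captured, exc_info=captured
)
log_text = stream.getvalue()
```

Passing `exc_info=True` at this point would not help, because that form reads the exception currently being handled, and there is none once the `except` block has exited.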