Close DataLoadingThread silent-death observability gap (#4270)
Open
kaanbaloglu wants to merge 1 commit into
@kaanbaloglu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D105462584.
kaanbaloglu added a commit to kaanbaloglu/torchrec that referenced this pull request on May 18, 2026
…4270)

Summary:
Behind JK `pytorch/torchrec:enable_data_loading_thread_failure_capture` (default off).

`DataLoadingThread.run()` at `torchrec/distributed/train_pipeline/utils.py` currently only catches `StopIteration`. Any other exception, most commonly CUDA OOM on `batch.to(device)` or `OSError` on the Hive/Manifold-backed `next(self._dataloader_iter)`, kills the daemon thread silently. The consumer in `TrainPipelineFusedSparseDist.progress()` then blocks forever on `_buffer_filled_event.wait()`, because no producer is left to set it. Symptom: the training job appears stuck, gets SIGABRT'd by DPP starvation kill at 1500s, and surfaces in MAST as the generic `DPP_WORKER_STUCK_FULL_OUTPUT_QUEUE`; the root cause is invisible to investigators.

This is gap meta-pytorch#2 in the torchrec Scuba audit and shares the misclassification class with the dataloader hang fixed in D103494006 / D105399948.

When the JK is on, the safe path captures non-`StopIteration` exceptions, emits a `FAILURE` event to `torchrec_event_logging`, and wakes the consumer with the captured error so `get_next_batch()` re-raises it promptly instead of hanging.

Design details:

1. `_captured_exception_event` is a NEW `threading.Event`, separate from the existing `_buffer_filled_event`. Overloading `_buffer_filled_event` would muddy the existing "buffer filled vs end-of-stream vs stop" invariant: that event already has three legitimate setters (normal fill, `StopIteration` exit, `stop()`). A dedicated event keeps the failure-signal channel orthogonal and matches the gold-standard pattern in `torchrec/metrics/cpu_offloaded_metric_module.py:193-194, 300-302, 601-660`.
2. `stage` in the FAILURE metadata distinguishes `next_iterator` vs `copy_to_device` so investigators can see whether the exception came from the dataloader source (Hive/Manifold/etc.) or the host-to-device copy (CUDA OOM, MTIA OOM). One event name (`DataLoadingThread.fetch_failure`) keeps dashboards/alerts single-rooted; the stage axis is queryable via `metadata.stage`.
3. Captured exception is terminal: `get_next_batch()` re-raises the same captured exception on every subsequent call. Recovery is `TrainPipelineFusedSparseDist.reset()`, which already drops and rebuilds `_batch_loader`. Matches the contract in `cpu_offloaded_metric_module.py:300-302`.
4. `default=False` is passed explicitly to `torch._utils_internal.justknobs_check` to dodge the wrapper's default-True trap. A pinning test guards the regression that bit D105399948 round 2.
5. The defensive `EventLoggingHandler` / `TorchrecComponent` import block copies the template from `cpu_offloaded_metric_module.py:23-60`; it handles torch-package contexts where even the OSS shim at `torchrec/distributed/logging_handlers.py` is unavailable.
6. `StopIteration` remains a separate `except` arm before `except Exception`, with its original behavior preserved exactly: end-of-epoch is a normal-termination path, not a failure. A regression test asserts no FAILURE event fires on natural iterator exhaustion.

Off-path: when the JK is off, `run()` executes the original `try/except StopIteration` block unchanged. The new `_captured_exception` / `_captured_exception_event` state is initialized but never written; `get_next_batch()`'s new check gates on `_capture_failures_enabled` so it's a no-op when the JK is off. Bit-exact preservation per the killswitch-fallback rule.

Scope: only `TrainPipelineFusedSparseDist` uses `DataLoadingThread` in production today (`train_pipelines.py:1436, 2321, 2338`). `EvalPipelineFusedSparseDist` inherits the field but doesn't exercise it.

Differential Revision: D105462584
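The default-True trap in design detail 4 can be illustrated with a small stand-in. This is only a sketch: `wrapper_check`, `capture_failures_enabled`, and `recording_check` are hypothetical names invented here, and the snippet does not import or reproduce `torch._utils_internal.justknobs_check`. Only the knob name comes from the summary above.

```python
from typing import Callable, List, Tuple

# Knob name taken from the PR summary; everything else here is hypothetical.
JK_NAME = "pytorch/torchrec:enable_data_loading_thread_failure_capture"


def wrapper_check(name: str, default: bool = True) -> bool:
    """Hypothetical stand-in for a knob wrapper whose fallback default is True."""
    return default


def capture_failures_enabled(check: Callable[..., bool] = wrapper_check) -> bool:
    # Pass default=False explicitly so an unresolved knob cannot turn the new path on.
    return check(JK_NAME, default=False)


# Pinning-style check: record the arguments the gate hands to the wrapper.
recorded: List[Tuple[str, bool]] = []


def recording_check(name: str, default: bool = True) -> bool:
    recorded.append((name, default))
    return wrapper_check(name, default=default)


enabled = capture_failures_enabled(check=recording_check)
trap = wrapper_check(JK_NAME)  # omitting default falls back to True
```

A pinning test in this style asserts on the recorded `default` argument rather than on the returned value alone, so a later refactor that drops the explicit `default=False` fails loudly.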
26dc555 to 61d9192
61d9192 to e6c616b
a94f5e4 to 57fdc9c
kaanbaloglu added a commit to kaanbaloglu/torchrec that referenced this pull request on May 22, 2026
57fdc9c to 0e171f9
0e171f9 to b496366
Summary:
Behind JK `pytorch/torchrec:enable_data_loading_thread_failure_capture` (default off).

`DataLoadingThread.run()` at `torchrec/distributed/train_pipeline/utils.py` currently only catches `StopIteration`. Any other exception, most commonly CUDA OOM on `batch.to(device)` or `OSError` on the Hive/Manifold-backed `next(self._dataloader_iter)`, kills the daemon thread silently. The consumer in `TrainPipelineFusedSparseDist.progress()` then blocks forever on `_buffer_filled_event.wait()`, because no producer is left to set it. Symptom: the training job appears stuck, gets SIGABRT'd by DPP starvation kill at 1500s, and surfaces in MAST as the generic `DPP_WORKER_STUCK_FULL_OUTPUT_QUEUE`; the root cause is invisible to investigators.

Shares the misclassification class with the dataloader hang fixed in D103494006 / D105399948: a real failure (CUDA OOM, dataloader timeout) gets reported to MAST as a generic stuck-job kill, hiding the actual error type and stack from investigators. Closing this site removes one source of those misclassifications.
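The hang described above can be reproduced in miniature with plain `threading`. The sketch below is not torchrec code: `SilentLoader` and `failing_source` are hypothetical names, and the consumer uses a short timeout only so the demo terminates instead of blocking forever.

```python
import threading
from typing import Any, Iterator


class SilentLoader(threading.Thread):
    """Hypothetical loader that, like the pre-fix path, handles only StopIteration."""

    def __init__(self, source: Iterator[Any]) -> None:
        super().__init__(daemon=True)
        self._source = source
        self.buffer_filled = threading.Event()
        self.batch: Any = None

    def run(self) -> None:
        try:
            self.batch = next(self._source)
        except StopIteration:
            self.batch = None
        # Reached only when next() succeeded or raised StopIteration.
        self.buffer_filled.set()


def failing_source() -> Iterator[Any]:
    raise OSError("simulated storage read failure")
    yield  # unreachable; its presence makes this function a generator


loader = SilentLoader(failing_source())
loader.start()
loader.join()  # the thread has already died from the uncaught OSError

# The consumer side: nobody is left to set the event, so the wait times out.
woke = loader.buffer_filled.wait(timeout=0.2)
thread_alive = loader.is_alive()
```

Python's default thread excepthook prints the traceback to stderr, but nothing reaches the consumer: without the timeout, `wait()` would never return.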
When the JK is on, the safe path captures non-`StopIteration` exceptions, emits a `FAILURE` event to `torchrec_event_logging`, and wakes the consumer with the captured error so `get_next_batch()` re-raises it promptly instead of hanging.

Design details:

1. `_captured_exception_event` is a NEW `threading.Event`, separate from the existing `_buffer_filled_event`. Overloading `_buffer_filled_event` would muddy the existing "buffer filled vs end-of-stream vs stop" invariant: that event already has three legitimate setters (normal fill, `StopIteration` exit, `stop()`). A dedicated event keeps the failure-signal channel orthogonal and matches the gold-standard pattern in `torchrec/metrics/cpu_offloaded_metric_module.py:193-194, 300-302, 601-660`.
2. `stage` in the FAILURE metadata distinguishes `next_iterator` vs `copy_to_device` so investigators can see whether the exception came from the dataloader source (Hive/Manifold/etc.) or the host-to-device copy (CUDA OOM, MTIA OOM). One event name (`DataLoadingThread.fetch_failure`) keeps dashboards/alerts single-rooted; the stage axis is queryable via `metadata.stage`.
3. Captured exception is terminal: `get_next_batch()` re-raises the same captured exception on every subsequent call. Recovery is `TrainPipelineFusedSparseDist.reset()`, which already drops and rebuilds `_batch_loader`. Matches the contract in `cpu_offloaded_metric_module.py:300-302`.
4. `default=False` is passed explicitly to `torch._utils_internal.justknobs_check` to dodge the wrapper's default-True trap. A pinning test guards the regression that bit D105399948 round 2.
5. The defensive `EventLoggingHandler` / `TorchrecComponent` import block copies the template from `cpu_offloaded_metric_module.py:23-60`; it handles torch-package contexts where even the OSS shim at `torchrec/distributed/logging_handlers.py` is unavailable.
6. `StopIteration` remains a separate `except` arm before `except Exception`, with its original behavior preserved exactly: end-of-epoch is a normal-termination path, not a failure. A regression test asserts no FAILURE event fires on natural iterator exhaustion.

Off-path: when the JK is off, `run()` executes the original `try/except StopIteration` block unchanged. The new `_captured_exception` / `_captured_exception_event` state is initialized but never written; `get_next_batch()`'s new check gates on `_capture_failures_enabled` so it's a no-op when the JK is off. Bit-exact preservation per the killswitch-fallback rule.

Scope: only `TrainPipelineFusedSparseDist` uses `DataLoadingThread` in production today (`train_pipelines.py:1436, 2321, 2338`). `EvalPipelineFusedSparseDist` inherits the field but doesn't exercise it.

Differential Revision: D105462584
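The capture path in the design details can be sketched roughly as follows. This is an illustration under assumptions, not the code in this diff: `CapturingLoader`, its single-fetch `run()`, the polling loop in `get_next_batch()`, and the in-memory `events` list (standing in for the event logger) are all hypothetical. Only the attribute names, the two stage values, and the event name are taken from the summary. How the real consumer is woken is not shown in this description, so the sketch polls the dedicated event rather than setting `_buffer_filled_event` on failure.

```python
import threading
from typing import Any, Callable, Dict, Iterator, List, Optional


class CapturingLoader(threading.Thread):
    """Hypothetical single-fetch loader; not the torchrec DataLoadingThread."""

    def __init__(
        self,
        source: Iterator[Any],
        copy_fn: Callable[[Any], Any],
        capture_failures: bool = False,
    ) -> None:
        super().__init__(daemon=True)
        self._source = source
        self._copy_fn = copy_fn
        self._capture_failures_enabled = capture_failures
        self._buffer_filled_event = threading.Event()
        # Dedicated failure channel, kept separate from the buffer event.
        self._captured_exception_event = threading.Event()
        self._captured_exception: Optional[Exception] = None
        self._batch: Any = None
        self.events: List[Dict[str, str]] = []  # stand-in for an event logger

    def run(self) -> None:
        stage = "next_iterator"
        try:
            item = next(self._source)
            stage = "copy_to_device"
            self._batch = self._copy_fn(item)
        except StopIteration:
            # Normal end of stream: no failure event is emitted.
            self._batch = None
        except Exception as exc:
            if not self._capture_failures_enabled:
                raise  # approximates the off-path: the thread dies as before
            self.events.append(
                {
                    "name": "DataLoadingThread.fetch_failure",
                    "stage": stage,
                    "error_type": type(exc).__name__,
                }
            )
            self._captured_exception = exc
            self._captured_exception_event.set()
            return
        self._buffer_filled_event.set()

    def get_next_batch(self, poll_seconds: float = 0.05) -> Any:
        while True:
            if (
                self._capture_failures_enabled
                and self._captured_exception_event.is_set()
            ):
                # Terminal: every call re-raises the same captured exception.
                raise self._captured_exception
            if self._buffer_filled_event.wait(timeout=poll_seconds):
                return self._batch


def failing_copy(batch: Any) -> Any:
    raise RuntimeError("simulated out-of-memory during device copy")


loader = CapturingLoader(iter([1, 2, 3]), failing_copy, capture_failures=True)
loader.start()

errors: List[Exception] = []
for _ in range(2):
    try:
        loader.get_next_batch()
    except RuntimeError as exc:
        errors.append(exc)
```

Both calls surface the same exception object, and the recorded event carries `stage` set to `copy_to_device`. With an exhausted iterator the loader returns `None` and records no event, which is the behavior the regression test in item 6 pins.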