From 5d00f3a670a1b726aa1b27b29fc19c96638b2ec8 Mon Sep 17 00:00:00 2001 From: Kaan Baloglu Date: Wed, 3 Jun 2026 11:13:12 -0700 Subject: [PATCH] DeferrableMetrics: emit INFO event on Future-backed deferred failure (#4275) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Closes an observability gap on the Future-backed path of `DeferrableMetrics`. `RecMetricModule.compute` is decorated with `EventLoggingHandler.event_logger(REC_METRICS)`, which emits SUCCESS as soon as `compute()` returns. When `compute()` returns a `DeferrableMetrics` wrapping a Future, the actual metric computation runs on a worker thread and may raise long after the decorator already logged SUCCESS. The consumer's `resolve()` or `subscribe()` then surfaces the error (re-raised by `resolve()`, delivered to `on_error` by `subscribe()`) outside the decorator's scope, so the failure never reaches `torchrec_event_logging`. In a 7-day analysis of silent APS DEAD attempts (joined torchrec_event_logging ∩ mast_hpc_job_run_status), ~33% last logged `RecMetricModule.compute` as SUCCESS before dying — roughly 1,100 APS attempts/week with no torchrec-side attribution. `DeferrableMetrics.__init__` now registers an internal done-callback (via `add_done_callback`) on Future-backed instances. The callback runs `f.result()` in a try/except; on exception it synchronously calls `EventLoggingHandler.log_event` with `event_type=INFO`, `event_name='DeferrableMetrics.deferred_failure'`, `metadata` carrying `exception_type`, a truncated `error_message`, and a truncated `stack_trace`. Telemetry exceptions are swallowed so the caller is never affected. The done-callback fires once per Future-backed instance regardless of whether `resolve()` / `subscribe()` is ever called, so unobserved futures still surface their failures. Pre-resolved (non-Future) instances are not instrumented. Mirrors the per-site capture pattern from `DataLoadingThread` (D105462584): instrument at the boundary where the Future actually resolves, not where the wrapping call returned. 
**Design choices**: - **Synchronous emit, no background worker / queue / daemon thread**. Matches the existing pattern in `torchrec/metrics/cpu_offloaded_metric_module.py:557` (the metric worker thread itself runs synchronous `EventLoggingHandler.event_logger` calls) and `metric_module.py:339,412`. The FB shim adds to an in-memory Scuba sample buffer; per-call cost is microseconds. - **`event_type=INFO`, not `FAILURE`**. Per Nipun's review on D105462584, the dataset invariant `#START == #SUCCESS + #FAILURE` requires every FAILURE to pair with a START. Standalone failure records should use INFO. The START/SUCCESS for the enclosing `RecMetricModule.compute` call is already emitted by the existing decorator; we don't need to duplicate it. - **No JK gate**. Pure additive observability — no behavior change for any caller. `resolve()` / `subscribe()` / `update()` return the same values and raise the same exceptions. The only new artifacts are rows in `torchrec_event_logging`. Existing RecMetrics decorators are also unconditional; this matches the local convention. - **Defensive import block for `EventLoggingHandler` / `TorchrecComponent` / `EventType`**. Mirrors the guarded import in `metric_module.py:30-65` and `cpu_offloaded_metric_module.py:27-65`. `DeferrableMetrics` is imported by `metric_module.py`, which gets packaged into inference paths without the logging handler shim; an unconditional import here would break those packages. The fallback substitutes minimal stub enums and a no-op `log_event`. **Known tradeoff**: the emit is unsampled. Per-job volume is bounded by `Future-backed metric computations × ranks`. A typical APS eval job (~10 compute calls × ~128 ranks) caps at ~1,280 events for a persistent-failure run — small relative to the 200K+ events / 28d already on `torchrec_event_logging`. If canary shows a single job emitting tens of thousands of events, the right follow-up is a process-local cap (first N per `exception_type` then drop with one warning). 
Sampling tools (`n_batch_log_event`) lose signal in non-APF contexts where the batch counter isn't updated. Scope: only `DeferrableMetrics` Future-backed instances. `RecMetricModule.update` / `compute` decorators unchanged; per-metric implementations unchanged. Anything that wraps a Future in a different container (raw `Future`, custom awaitable) is not covered. Differential Revision: D107334098 --- torchrec/metrics/deferrable_metrics.py | 82 +++++++++++++++++++ .../metrics/tests/test_deferrable_metrics.py | 65 +++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/torchrec/metrics/deferrable_metrics.py b/torchrec/metrics/deferrable_metrics.py index 7264b72ae..d842a9cd2 100644 --- a/torchrec/metrics/deferrable_metrics.py +++ b/torchrec/metrics/deferrable_metrics.py @@ -19,14 +19,62 @@ """ import logging +import traceback from collections.abc import Iterator, Mapping from concurrent.futures import Future from typing import Any, Callable import torch +try: + # Guarded: TorchRec is packaged into inference paths without the logging + # handler shim; an unconditional import would break those packages. 
+ from torchrec.distributed.logging_handlers import ( + EventLoggingHandler, + TorchrecComponent, + ) + from torchrec.distributed.logging_utils import EventType +except Exception: + torch._C._log_api_usage_once( + "torchrec.metrics.deferrable_metrics.import_failure.logging_handlers" + ) + + from enum import Enum as _Enum + from typing import TYPE_CHECKING + + if TYPE_CHECKING: + from torchrec.distributed.logging_handlers import ( + EventLoggingHandler, + TorchrecComponent, + ) + from torchrec.distributed.logging_utils import EventType + else: + + class TorchrecComponent(_Enum): + REC_METRICS = "rec_metrics" + + class EventType(_Enum): + INFO = "INFO" + + class EventLoggingHandler: + @staticmethod + def log_event(*args: object, **kwargs: object) -> None: + pass + + logger: logging.Logger = logging.getLogger(__name__) +_EVENT_NAME: str = "DeferrableMetrics.deferred_failure" +_ERROR_MESSAGE_MAX_LEN: int = 4096 +_STACK_TRACE_MAX_LEN: int = 8192 +_TRUNCATION_MARKER: str = "...[truncated]" + + +def _truncate(s: str, n: int) -> str: + if len(s) <= n: + return s + return s[: max(0, n - len(_TRUNCATION_MARKER))] + _TRUNCATION_MARKER + def device_supports_async(device: torch.device) -> bool: """Check if a device supports non-blocking async transfers (CUDA events).""" @@ -79,6 +127,13 @@ class DeferrableMetrics(Mapping[str, Any]): Implements Mapping[str, Any] so it is a drop-in replacement for Dict[str, MetricValue] at both type and runtime level. Dict-style access (__getitem__, __iter__, __len__) calls resolve() internally. + + Future-backed instances emit a `DeferrableMetrics.deferred_failure` + INFO event to `torchrec_event_logging` if the Future raises. Success + is already captured by the enclosing `RecMetricModule.compute` decorator, + so no SUCCESS counterpart is emitted here. The done-callback fires once + per Future regardless of whether resolve/subscribe is called, so + unobserved futures still surface their failures. 
""" _warned: bool = False @@ -96,6 +151,33 @@ def __init__( self._data = dict(inner) self._resolved = True + if self._future is not None: + self._future.add_done_callback(self._emit_failure_if_raised) + + def _emit_failure_if_raised(self, f: Future[dict[str, Any]]) -> None: + # Runs on the Future's done-callback thread. Emit is synchronous to + # match the local @event_logger convention; telemetry must never + # raise into the caller. + try: + f.result() + except BaseException as e: # noqa: B036 + try: + EventLoggingHandler.log_event( + component=TorchrecComponent.REC_METRICS.value, + event_name=_EVENT_NAME, + event_type=EventType.INFO, + metadata={"exception_type": type(e).__name__}, + error_message=_truncate(str(e), _ERROR_MESSAGE_MAX_LEN), + stack_trace=_truncate( + "".join( + traceback.format_exception(type(e), e, e.__traceback__) + ), + _STACK_TRACE_MAX_LEN, + ), + ) + except BaseException: # noqa: B036 + pass + def _warn_sync_access(self) -> None: """Log a warning once per process when dict-style access triggers resolve() on a Future-backed instance.""" diff --git a/torchrec/metrics/tests/test_deferrable_metrics.py b/torchrec/metrics/tests/test_deferrable_metrics.py index 896ff19ce..65fc4c2eb 100644 --- a/torchrec/metrics/tests/test_deferrable_metrics.py +++ b/torchrec/metrics/tests/test_deferrable_metrics.py @@ -16,6 +16,7 @@ from torchrec.metrics.deferrable_metrics import ( DeferrableMetrics, device_supports_async, + EventType, transfer_tensors_to_cpu, ) @@ -273,6 +274,70 @@ def test_future_backed_update_dict_merges(self) -> None: self.assertEqual(received[0]["cpu_val"], "extra") +class FailureCaptureTest(unittest.TestCase): + """Verifies INFO-event capture for deferred failures from Future-backed instances.""" + + def setUp(self) -> None: + DeferrableMetrics._warned = False + self._log = patch( + "torchrec.metrics.deferrable_metrics.EventLoggingHandler.log_event" + ) + self.mock_log = self._log.start() + self.addCleanup(self._log.stop) + + def 
test_pre_resolved_dict_emits_nothing(self) -> None: + DeferrableMetrics({"a": 1}).resolve() + self.assertEqual(self.mock_log.call_count, 0) + + def test_future_success_emits_nothing(self) -> None: + # Success is captured by the enclosing RecMetricModule.compute + # decorator's SUCCESS event; we only emit on the failure path. + f: Future[dict] = Future() + DeferrableMetrics(f) + f.set_result({"a": 1}) + self.assertEqual(self.mock_log.call_count, 0) + + def test_future_failure_emits_info_event_with_payload(self) -> None: + f: Future[dict] = Future() + dm = DeferrableMetrics(f) + f.set_exception(RuntimeError("metric blew up")) + with self.assertRaisesRegex(RuntimeError, "metric blew up"): + dm.resolve() + self.assertEqual(self.mock_log.call_count, 1) + kwargs = self.mock_log.call_args.kwargs + self.assertEqual(kwargs["event_name"], "DeferrableMetrics.deferred_failure") + self.assertEqual(kwargs["event_type"], EventType.INFO) + self.assertEqual(kwargs["metadata"]["exception_type"], "RuntimeError") + self.assertEqual(kwargs["error_message"], "metric blew up") + self.assertTrue(kwargs["stack_trace"]) + + def test_future_failure_via_subscribe_path(self) -> None: + f: Future[dict] = Future() + dm = DeferrableMetrics(f) + errors: list[Exception] = [] + dm.subscribe(lambda d: None, on_error=errors.append) + f.set_exception(RuntimeError("subscribe path")) + self.assertEqual(len(errors), 1) + self.assertEqual(self.mock_log.call_count, 1) + + def test_unobserved_future_failure_still_emits(self) -> None: + # Future-backed DeferrableMetrics constructed and the reference + # dropped; Future raises. The done-callback runs unconditionally. + f: Future[dict] = Future() + DeferrableMetrics(f) + f.set_exception(RuntimeError("orphaned")) + self.assertEqual(self.mock_log.call_count, 1) + + def test_telemetry_failure_does_not_raise(self) -> None: + # log_event itself raising must never propagate into the caller. 
+ self.mock_log.side_effect = RuntimeError("scuba down") + f: Future[dict] = Future() + dm = DeferrableMetrics(f) + f.set_exception(ValueError("metric error")) + with self.assertRaisesRegex(ValueError, "metric error"): + dm.resolve() + + class TransferTensorsToCpuTest(unittest.TestCase): def test_cpu_tensors_passthrough(self) -> None: