Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions torchrec/metrics/deferrable_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,62 @@
"""

import logging
import traceback
from collections.abc import Iterator, Mapping
from concurrent.futures import Future
from typing import Any, Callable

import torch

# Optional telemetry dependencies. When they cannot be imported, the failure
# is recorded through torch's API-usage logger and minimal stand-ins are
# defined so the rest of this module can reference the same names.
try:
    # Guarded: TorchRec is packaged into inference paths without the logging
    # handler shim; an unconditional import would break those packages.
    from torchrec.distributed.logging_handlers import (
        EventLoggingHandler,
        TorchrecComponent,
    )
    from torchrec.distributed.logging_utils import EventType
except Exception:
    # Leave a breadcrumb that the fallback path was taken.
    torch._C._log_api_usage_once(
        "torchrec.metrics.deferrable_metrics.import_failure.logging_handlers"
    )

    # Imported here (not at the top of the file) because they are only
    # needed to build the fallback definitions below.
    from enum import Enum as _Enum
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Static type checkers still resolve the real definitions; this
        # branch never executes at runtime.
        from torchrec.distributed.logging_handlers import (
            EventLoggingHandler,
            TorchrecComponent,
        )
        from torchrec.distributed.logging_utils import EventType
    else:

        # Stand-in exposing only the member this module reads.
        # NOTE(review): the value presumably mirrors the real enum — confirm.
        class TorchrecComponent(_Enum):
            REC_METRICS = "rec_metrics"

        # Stand-in exposing only the member this module reads.
        class EventType(_Enum):
            INFO = "INFO"

        # Stand-in whose log_event accepts any arguments and does nothing,
        # so telemetry calls become no-ops when the real handler is absent.
        class EventLoggingHandler:
            @staticmethod
            def log_event(*args: object, **kwargs: object) -> None:
                pass


# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Event name attached to the telemetry record emitted when a Future fails.
_EVENT_NAME: str = "DeferrableMetrics.deferred_failure"
# Caps, measured with len() (i.e. in characters), passed to _truncate() for
# the error message and the formatted stack trace before they are logged.
_ERROR_MESSAGE_MAX_LEN: int = 4096
_STACK_TRACE_MAX_LEN: int = 8192
# Suffix _truncate() appends to text it has shortened.
_TRUNCATION_MARKER: str = "...[truncated]"


def _truncate(s: str, n: int) -> str:
    """Return ``s`` shortened so that it is never longer than ``n`` characters.

    Strings that already fit are returned unchanged. Longer strings are cut
    and suffixed with ``_TRUNCATION_MARKER`` so the total length is exactly
    ``n``. When ``n`` is too small to hold the marker, the string is
    hard-cut to ``n`` characters without a marker instead of returning a
    marker that would itself exceed the cap.

    Args:
        s: Text to limit.
        n: Maximum number of characters in the result; values <= 0 yield "".

    Returns:
        A string whose length is at most ``max(n, 0)``.
    """
    if n <= 0:
        return ""
    if len(s) <= n:
        return s
    marker_len = len(_TRUNCATION_MARKER)
    if n < marker_len:
        # The marker alone would overflow the cap; keep the prefix only.
        return s[:n]
    return s[: n - marker_len] + _TRUNCATION_MARKER


def device_supports_async(device: torch.device) -> bool:
"""Check if a device supports non-blocking async transfers (CUDA events)."""
Expand Down Expand Up @@ -79,6 +127,13 @@ class DeferrableMetrics(Mapping[str, Any]):
Implements Mapping[str, Any] so it is a drop-in replacement for
Dict[str, MetricValue] at both type and runtime level. Dict-style
access (__getitem__, __iter__, __len__) calls resolve() internally.

Future-backed instances emit a `DeferrableMetrics.deferred_failure`
INFO event to `torchrec_event_logging` if the Future raises. Success
is already captured by the enclosing `RecMetricModule.compute` decorator,
so no SUCCESS counterpart is emitted here. The done-callback is registered
in __init__ and fires once per Future-backed instance, regardless of
whether resolve/subscribe is called, so unobserved futures still surface
their failures.
"""

_warned: bool = False
Expand All @@ -96,6 +151,33 @@ def __init__(
self._data = dict(inner)
self._resolved = True

if self._future is not None:
self._future.add_done_callback(self._emit_failure_if_raised)

def _emit_failure_if_raised(self, f: Future[dict[str, Any]]) -> None:
    """Done-callback that reports a failed Future as a telemetry event.

    Does nothing when ``f.result()`` returns normally. If it raises, one
    INFO event named ``_EVENT_NAME`` is sent through
    ``EventLoggingHandler.log_event`` carrying the exception type, a
    length-capped message and a length-capped formatted traceback.

    Args:
        f: The completed Future this callback was attached to.
    """
    # Invoked from the Future's done-callback machinery. The emit is
    # synchronous to match the local @event_logger convention, and nothing
    # raised while reporting may escape into whoever completed the Future.
    try:
        f.result()
    except BaseException as exc:  # noqa: B036
        try:
            # Payload construction stays inside this guard so that a
            # failure in str()/traceback formatting is swallowed as well.
            component = TorchrecComponent.REC_METRICS.value
            details = {"exception_type": type(exc).__name__}
            message = _truncate(str(exc), _ERROR_MESSAGE_MAX_LEN)
            rendered = "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            )
            EventLoggingHandler.log_event(
                component=component,
                event_name=_EVENT_NAME,
                event_type=EventType.INFO,
                metadata=details,
                error_message=message,
                stack_trace=_truncate(rendered, _STACK_TRACE_MAX_LEN),
            )
        except BaseException:  # noqa: B036
            # Best-effort telemetry: deliberately ignore reporting errors.
            pass

def _warn_sync_access(self) -> None:
"""Log a warning once per process when dict-style access
triggers resolve() on a Future-backed instance."""
Expand Down
65 changes: 65 additions & 0 deletions torchrec/metrics/tests/test_deferrable_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torchrec.metrics.deferrable_metrics import (
DeferrableMetrics,
device_supports_async,
EventType,
transfer_tensors_to_cpu,
)

Expand Down Expand Up @@ -273,6 +274,70 @@ def test_future_backed_update_dict_merges(self) -> None:
self.assertEqual(received[0]["cpu_val"], "extra")


class FailureCaptureTest(unittest.TestCase):
    """Checks the INFO event emitted when a Future behind DeferrableMetrics fails."""

    def setUp(self) -> None:
        # Reset the once-per-process warning flag so tests do not interact.
        DeferrableMetrics._warned = False
        # Replace the telemetry sink with a mock for the duration of a test.
        self._log = patch(
            "torchrec.metrics.deferrable_metrics.EventLoggingHandler.log_event"
        )
        self.mock_log = self._log.start()
        self.addCleanup(self._log.stop)

    def test_pre_resolved_dict_emits_nothing(self) -> None:
        # A dict-backed instance has no Future, hence nothing to report.
        metrics = DeferrableMetrics({"a": 1})
        metrics.resolve()
        self.mock_log.assert_not_called()

    def test_future_success_emits_nothing(self) -> None:
        # The SUCCESS event comes from the enclosing RecMetricModule.compute
        # decorator; this class only reports the failure path.
        pending: Future[dict] = Future()
        DeferrableMetrics(pending)
        pending.set_result({"a": 1})
        self.mock_log.assert_not_called()

    def test_future_failure_emits_info_event_with_payload(self) -> None:
        pending: Future[dict] = Future()
        metrics = DeferrableMetrics(pending)
        pending.set_exception(RuntimeError("metric blew up"))
        with self.assertRaisesRegex(RuntimeError, "metric blew up"):
            metrics.resolve()

        self.mock_log.assert_called_once()
        _, payload = self.mock_log.call_args
        self.assertEqual(payload["event_name"], "DeferrableMetrics.deferred_failure")
        self.assertEqual(payload["event_type"], EventType.INFO)
        self.assertEqual(payload["metadata"]["exception_type"], "RuntimeError")
        self.assertEqual(payload["error_message"], "metric blew up")
        self.assertTrue(payload["stack_trace"])

    def test_future_failure_via_subscribe_path(self) -> None:
        pending: Future[dict] = Future()
        metrics = DeferrableMetrics(pending)
        seen_errors: list[Exception] = []
        metrics.subscribe(lambda d: None, on_error=seen_errors.append)
        pending.set_exception(RuntimeError("subscribe path"))
        self.assertEqual(len(seen_errors), 1)
        self.mock_log.assert_called_once()

    def test_unobserved_future_failure_still_emits(self) -> None:
        # The wrapper is created and immediately discarded; nobody calls
        # resolve/subscribe, yet the done-callback must still report.
        pending: Future[dict] = Future()
        DeferrableMetrics(pending)
        pending.set_exception(RuntimeError("orphaned"))
        self.mock_log.assert_called_once()

    def test_telemetry_failure_does_not_raise(self) -> None:
        # A raising log_event must not leak out of the done-callback; the
        # caller should only ever see the original metric error.
        self.mock_log.side_effect = RuntimeError("scuba down")
        pending: Future[dict] = Future()
        metrics = DeferrableMetrics(pending)
        pending.set_exception(ValueError("metric error"))
        with self.assertRaisesRegex(ValueError, "metric error"):
            metrics.resolve()


class TransferTensorsToCpuTest(unittest.TestCase):

def test_cpu_tensors_passthrough(self) -> None:
Expand Down
Loading