diff --git a/examples/job/README.md b/examples/job/README.md index ed591838fa..65e1e70750 100644 --- a/examples/job/README.md +++ b/examples/job/README.md @@ -10,6 +10,7 @@ For installing and running an agent inside a single sandbox, see [`../install-ag |--------|---------|-----------| | [`bash/`](./bash/) | `BashJobConfig` | Run an arbitrary shell script inside a sandbox (data processing, external evaluation tools) | | [`harbor/`](./harbor/) | `HarborJobConfig` | Run an AI agent benchmark task (SWE-bench, Terminal Bench, …) via the Harbor framework | +| [`observability/`](./observability/) | `BashJobConfig` | Show the exception observability layer (structured logs + optional OTLP metric) via a job that soft-fails on purpose | Both backends share a single `Job(config).run()` entrypoint — pick the config type based on your scenario. diff --git a/examples/job/observability/README.md b/examples/job/observability/README.md new file mode 100644 index 0000000000..aa58a4c070 --- /dev/null +++ b/examples/job/observability/README.md @@ -0,0 +1,76 @@ +# job/observability + +Demo for the **exception observability** layer in the Job SDK +(`rock.sdk.job.observability`): structured exception logs + an optional OTLP +metric, emitted automatically from every `JobExecutor` phase. + +You don't call the observability layer yourself — it's woven into the executor +via the `monitor_job_phase` decorator. Your only knobs are two env vars: + +| Env var | Default | Effect | +|---------|---------|--------| +| `ROCK_JOB_METRICS_OTLP_ENDPOINT` | unset | Unset → **log-only**. Set → also emit the OTLP counter `rock_job.exception.total`. | +| `ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS` | `true` | `false` → drop `job_name` / `sandbox_id` from the **metric** (they stay on the log). | + +On every job exception the SDK dual-writes: + +1. a structured `key=value` **ERROR log** — always, with the full label set, and +2. 
a counter increment on `rock_job.exception.total` — only when an OTLP endpoint is configured. + +Labels: `phase`, `severity`, `exception_type`, `trial_type`, `job_name`, +`experiment_id`, `namespace`, `sandbox_id`. + +## Two failure semantics + +| | What the SDK does | How you observe it | +|---|---|---| +| **soft fail** (script `exit != 0`, timeout, no output) | recorded as `severity=soft`, carried back in `TrialResult.exception_info` — `run()` does **not** raise | `JobResult.status == FAILED`, `TrialResult.exception_info` populated | +| **hard fail** (sandbox can't start, launch fails) | recorded as `severity=hard`, then **re-raised** (fail-fast across the batch) | `run()`/`wait()` raises | + +See [`docs/dev/job/exception-handling.md`](../../../docs/dev/job/exception-handling.md) for the full taxonomy. + +## Files + +| File | Purpose | +|------|---------| +| [`observability_demo.py`](./observability_demo.py) | Entry point — echoes the two knobs, loads `JobConfig.from_yaml()`, runs `Job(config).run()`, prints per-trial `exception_info` | +| [`observability_job_config.yaml.template`](./observability_job_config.yaml.template) | `BashJobConfig` whose script exits non-zero, producing one soft-fail event | + +## Quick run + +```bash +# 1. copy the template and fill in real values () +cp observability_job_config.yaml.template observability_job_config.yaml + +# 2. (optional) turn metrics ON — otherwise the demo is log-only +export ROCK_JOB_METRICS_OTLP_ENDPOINT="http://localhost:4318/v1/metrics" + +# 3. run +python observability_demo.py -c observability_job_config.yaml +``` + +The template's script does `exit 7`, so a clean run produces exactly one +soft-fail event. You'll see it twice in the output: + +``` +ERROR ... job exception phase=collect severity=soft exception_type=BashExitCode ... job_name=observability_demo ... +... +JobResult: status=failed completed=0 failed=1 exit_code=... + trial[0] ...: SOFT FAIL -> BashExitCode: ... 
+``` + +The ERROR log line is **always** emitted. The matching counter increment on +`rock_job.exception.total` is exported **only** when `ROCK_JOB_METRICS_OTLP_ENDPOINT` +is set. The demo calls `get_reporter().shutdown()` at the end to force a final +metric flush before this short-lived process exits (also registered via `atexit`, +so it's safe and idempotent). + +To see a **hard fail** (`run()` raises, `severity=hard`) instead, point +`base_url` at an unreachable admin so the sandbox can't start. + +## Wiring metrics to a collector + +`ROCK_JOB_METRICS_OTLP_ENDPOINT` is a standard OTLP/HTTP metrics endpoint. Point +it at any OpenTelemetry Collector (or a backend that accepts OTLP/HTTP). A quick +local collector that logs received metrics to stdout is enough to eyeball the +counter; then scrape/forward to Prometheus, etc., in your real deployment. diff --git a/examples/job/observability/observability_demo.py b/examples/job/observability/observability_demo.py new file mode 100644 index 0000000000..06366288c9 --- /dev/null +++ b/examples/job/observability/observability_demo.py @@ -0,0 +1,119 @@ +"""Observability demo for the ROCK Job SDK. + +Same shape as ``examples/job/harbor/harbor_demo.py`` — load a YAML config, run +``Job(config).run()``, print the per-trial results — but pointed at the +**exception observability** layer in ``rock.sdk.job.observability``. + +You do NOT call the observability layer yourself. It is woven automatically into +every ``JobExecutor`` phase (start / setup / launch / wait / collect). Your only +knobs are two environment variables: + + ROCK_JOB_METRICS_OTLP_ENDPOINT # unset -> log-only; set -> also emit OTLP counter + ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS # "false" -> drop job_name/sandbox_id from the metric + +On every job exception the SDK dual-writes: + 1. a structured ``key=value`` ERROR log (always, with full labels), and + 2. a counter ``rock_job.exception.total`` (only when an OTLP endpoint is set). 
+ +Two failure semantics are surfaced (see docs/dev/job/exception-handling.md): + - soft fail (script exit != 0, timeout): carried back as data in + ``TrialResult.exception_info`` — run() does NOT raise. + - hard fail (sandbox can't start): re-raised out of run()/wait(). + +The bundled ``observability_job_config.yaml.template`` runs a script that exits +non-zero, so a clean run produces exactly one soft-fail event you can see in the +logs (and as a metric, if you set the OTLP endpoint). + +Usage: + cp observability_job_config.yaml.template observability_job_config.yaml + # fill in , then optionally turn metrics on: + export ROCK_JOB_METRICS_OTLP_ENDPOINT=http://localhost:4318/v1/metrics + python examples/job/observability/observability_demo.py -c observability_job_config.yaml +""" + +import argparse +import asyncio +import logging + +from rock import env_vars +from rock.sdk.job import Job, JobConfig, observability +from rock.sdk.job.result import JobResult + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) +# quiet httpx's per-request INFO logs (warnings and errors still show) +logging.getLogger("httpx").setLevel(logging.WARNING) + + +def print_observability_status() -> None: + """Echo the two knobs so it's obvious whether metrics are on.""" + endpoint = env_vars.ROCK_JOB_METRICS_OTLP_ENDPOINT + high_card = env_vars.ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS + logger.info("─" * 60) + logger.info("Job observability config:") + if endpoint: + logger.info(" metrics: ON -> OTLP counter to %s", endpoint) + else: + logger.info(" metrics: OFF -> log-only (set ROCK_JOB_METRICS_OTLP_ENDPOINT to enable)") + logger.info( + " high-card labels: %s (job_name/sandbox_id %s on the metric)", high_card, "kept" if high_card else "dropped" + ) + logger.info(" counter name: rock_job.exception.total") + logger.info("─" * 60) + + +def summarize(result: JobResult) -> None: + logger.info("─" * 60) + logger.info( + "JobResult: status=%s completed=%d failed=%d 
exit_code=%d", + result.status, + result.n_completed, + result.n_failed, + result.exit_code, + ) + for i, trial in enumerate(result.trial_results): + if trial.exception_info is not None: + # SOFT fail: surfaced as data, run() did not raise. + logger.info( + " trial[%d] %s: SOFT FAIL -> %s: %s", + i, + trial.task_name or "?", + trial.exception_info.exception_type, + trial.exception_info.exception_message, + ) + else: + logger.info(" trial[%d] %s: ok", i, trial.task_name or "?") + logger.info("─" * 60) + if result.n_failed: + logger.info("Each soft failure above was ALSO emitted as one ERROR log, plus one metric") + logger.info("increment if an OTLP endpoint is set, attributed to the 'collect' phase. Check the logs above.") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run a Bash job inside a ROCK sandbox and observe its exceptions") + parser.add_argument("-c", "--config", required=True, help="Path to BashJobConfig YAML file") + return parser.parse_args() + + +async def async_main(args: argparse.Namespace) -> None: + print_observability_status() + config = JobConfig.from_yaml(args.config) + try: + result = await Job(config).run() + except Exception: + # A HARD fail (e.g. sandbox can't start) propagates here and was already + # recorded by the observability layer (severity=hard) before re-raising. + logger.exception("Job hard-failed — observability recorded a severity=hard event for the failing phase") + raise + else: + summarize(result) + finally: + # Short-lived process: force a final metrics flush so the last buffered + # exceptions are exported before exit. atexit also calls this (idempotent), + # but doing it explicitly makes the flush deterministic in a demo. 
+ observability.get_reporter().shutdown() + + +if __name__ == "__main__": + args = parse_args() + asyncio.run(async_main(args)) diff --git a/examples/job/observability/observability_job_config.yaml.template b/examples/job/observability/observability_job_config.yaml.template new file mode 100644 index 0000000000..314006ab07 --- /dev/null +++ b/examples/job/observability/observability_job_config.yaml.template @@ -0,0 +1,29 @@ +# BashJobConfig for the observability demo. +# +# The script below exits non-zero on purpose, so a clean run produces exactly +# one SOFT-FAIL event: one structured ERROR log (always) + one increment of the +# counter rock_job.exception.total (only if ROCK_JOB_METRICS_OTLP_ENDPOINT is set). +# +# Copy to observability_job_config.yaml and fill in the . + +# ── Job Identity ───────────────────────────────────── +experiment_id: "job_observability_demo" +labels: + demo: "observability" +timeout: 120 + +# ── Rock Environment ───────────────────────────────── +environment: + base_url: "" # e.g. http://localhost:8080 + xrl_authorization: "" + image: "" # e.g. python:3.11 + cluster: "" # optional + +# ── Bash Script ────────────────────────────────────── +# exit 7 -> BashExitCode soft fail, attributed to the 'collect' phase. +# Swap in `sleep 100000` (with a small timeout above) to see a ProcessTimeout soft fail instead. +script: | + echo "[trial] doing work..." 
+ sleep 1 + echo "[trial] boom" + exit 7 diff --git a/rock/env_vars.py b/rock/env_vars.py index 11a9830613..2e14ad41e5 100644 --- a/rock/env_vars.py +++ b/rock/env_vars.py @@ -58,6 +58,10 @@ ROCK_MODEL_SERVICE_TRAJ_APPEND_MODE: bool | None = None ROCK_JOB_PROXY_REPLAY_FILE: str + # Client-side Job exception observability + ROCK_JOB_METRICS_OTLP_ENDPOINT: str | None = None + ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS: bool = True + # RuntimeEnv ROCK_RTENV_PYTHON_V31114_INSTALL_CMD: str ROCK_RTENV_PYTHON_V31212_INSTALL_CMD: str @@ -147,6 +151,12 @@ ), # Docker temp auth directory "ROCK_DOCKER_TEMP_AUTH_DIR": lambda: os.getenv("ROCK_DOCKER_TEMP_AUTH_DIR"), + # Client-side Job exception observability + "ROCK_JOB_METRICS_OTLP_ENDPOINT": lambda: os.getenv("ROCK_JOB_METRICS_OTLP_ENDPOINT"), + "ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS": lambda: os.getenv( + "ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS", "true" + ).lower() + == "true", } diff --git a/rock/sdk/job/executor.py b/rock/sdk/job/executor.py index 665a2426d8..40853af772 100644 --- a/rock/sdk/job/executor.py +++ b/rock/sdk/job/executor.py @@ -16,6 +16,7 @@ from rock.actions import CreateBashSessionRequest from rock.logger import init_logger +from rock.sdk.job.observability import monitor_job_phase from rock.sdk.job.operator import Operator from rock.sdk.job.result import TrialResult from rock.sdk.sandbox.client import Sandbox @@ -87,16 +88,27 @@ def _job_tmp_prefix(config: JobConfig) -> str: async def _do_submit(self, trial: AbstractTrial) -> TrialClient: """Start sandbox + execute script for a single trial.""" + sandbox = await self._phase_start(trial) + await self._phase_setup(trial, sandbox) + return await self._phase_launch(trial, sandbox) + + @monitor_job_phase("start") + async def _phase_start(self, trial: AbstractTrial) -> Sandbox: config = trial._config sandbox = Sandbox(config.environment) await sandbox.start() logger.info(f"Sandbox started: sandbox_id={sandbox.sandbox_id}, job_name={config.job_name}") + return 
sandbox + @monitor_job_phase("setup") + async def _phase_setup(self, trial: AbstractTrial, sandbox: Sandbox) -> None: # G4: let trial backfill config from sandbox state before setup await trial.on_sandbox_ready(sandbox) - await trial.setup(sandbox) + @monitor_job_phase("launch") + async def _phase_launch(self, trial: AbstractTrial, sandbox: Sandbox) -> TrialClient: + config = trial._config session = f"rock-job-{config.job_name or 'default'}" env = self._build_session_env(config) await sandbox.create_session(CreateBashSessionRequest(session=session, env_enable=True, env=env)) @@ -121,8 +133,11 @@ async def _do_submit(self, trial: AbstractTrial) -> TrialClient: async def _do_wait(self, client: TrialClient) -> TrialResult | list[TrialResult]: """Wait for a single trial to finish, call trial.collect().""" - from rock.sdk.job.result import ExceptionInfo + obs, success, message = await self._phase_wait(client) + return await self._phase_collect(client, obs, success, message) + @monitor_job_phase("wait") + async def _phase_wait(self, client: TrialClient): config = client.trial._config success, message = await client.sandbox.wait_for_process_completion( pid=client.pid, @@ -138,6 +153,13 @@ async def _do_wait(self, client: TrialClient) -> TrialResult | list[TrialResult] ignore_output=False, response_limited_bytes_in_nohup=None, ) + return obs, success, message + + @monitor_job_phase("collect") + async def _phase_collect(self, client: TrialClient, obs, success, message) -> TrialResult | list[TrialResult]: + from rock.sdk.job.result import ExceptionInfo + + config = client.trial._config exit_code = obs.exit_code if obs.exit_code is not None else 1 if obs.output: logger.info(f"Trial output (job={config.job_name}):\n{obs.output}") @@ -149,6 +171,9 @@ async def _do_wait(self, client: TrialClient) -> TrialResult | list[TrialResult] r.raw_output = obs.output or "" if r.exit_code == 0 and exit_code != 0: r.exit_code = exit_code + # Timeout is detected in _phase_wait but annotated 
here so the soft-fail + # is attributed to the "collect" phase (the decorator scans this method's + # final return value). if not success: fail_info = ExceptionInfo( exception_type="ProcessTimeout", diff --git a/rock/sdk/job/observability.py b/rock/sdk/job/observability.py new file mode 100644 index 0000000000..64d063684e --- /dev/null +++ b/rock/sdk/job/observability.py @@ -0,0 +1,254 @@ +"""Client-side exception observability for the SDK job module. + +Provides a lazily-initialized singleton ``JobMetricsReporter`` (opt-in OTLP +counter, log-only when no endpoint is configured) and the ``monitor_job_phase`` +decorator used by JobExecutor to emit one counter per exception with the +failing phase attributed. + +Metric: ``rock_job.exception.total`` (counter). +Labels: phase, severity, exception_type, trial_type, job_name, experiment_id, +namespace, sandbox_id. job_name/sandbox_id are dropped from the metric (not the +log) when ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS is false. +""" + +from __future__ import annotations + +import asyncio +import atexit +import functools +from collections.abc import Callable + +from rock import env_vars +from rock.logger import init_logger +from rock.sdk.job.result import TrialResult + +logger = init_logger(__name__) + +_COUNTER_NAME = "rock_job.exception.total" +_HIGH_CARDINALITY_KEYS = ("job_name", "sandbox_id") + + +def _fmt_labels(labels: dict[str, str]) -> str: + return " ".join(f"{k}={v}" for k, v in labels.items()) + + +class JobMetricsReporter: + """Dual-writes job exceptions to logs + (optionally) an OTLP counter.""" + + def __init__(self) -> None: + self._enabled = False + self._counter = None + self._reader = None + self._provider = None + endpoint = env_vars.ROCK_JOB_METRICS_OTLP_ENDPOINT + if endpoint: + self._init_otel(endpoint) + + def _init_otel(self, endpoint: str) -> None: + try: + from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter + from opentelemetry.sdk.metrics import 
MeterProvider + from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader + + reader = PeriodicExportingMetricReader(OTLPMetricExporter(endpoint=endpoint)) + provider = MeterProvider(metric_readers=[reader]) + meter = provider.get_meter("rock.sdk.job") + self._counter = meter.create_counter( + name=_COUNTER_NAME, + description="Count of job execution exceptions (hard + soft)", + unit="1", + ) + self._reader = reader + self._provider = provider + self._enabled = True + atexit.register(self.shutdown) + logger.info("job metrics reporter enabled endpoint=%s", endpoint) + except Exception as e: # noqa: BLE001 — never let metrics setup break jobs + logger.warning("job metrics reporter init failed, log-only mode: %s", e) + self._enabled = False + self._counter = None + + def _filter_labels(self, labels: dict[str, str]) -> dict[str, str]: + if env_vars.ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS: + return labels + return {k: v for k, v in labels.items() if k not in _HIGH_CARDINALITY_KEYS} + + def record_exception(self, phase: str, exc_type: str, severity: str, labels: dict[str, str]) -> None: + # 1) structured log — always, always with full labels + logger.error( + "job exception phase=%s severity=%s exception_type=%s %s", + phase, + severity, + exc_type, + _fmt_labels(labels), + ) + # 2) metric — only when enabled; high-card filter happens here, not on the log + if self._enabled and self._counter is not None: + attrs = { + "phase": phase, + "severity": severity, + "exception_type": exc_type, + **self._filter_labels(labels), + } + self._counter.add(1, attrs) + + def shutdown(self) -> None: + """Flush buffered metrics before the process exits. + + PeriodicExportingMetricReader exports on an interval, so a short-lived + job can die with its last exceptions still buffered. Registered via + atexit so the provider force-flushes on a clean interpreter exit. + Idempotent and exception-isolated: teardown must never raise. 
+ """ + provider = self._provider + if provider is None: + return + self._provider = None + self._enabled = False + try: + provider.shutdown() + except Exception: # noqa: BLE001 — never let metrics teardown break exit + logger.warning("job metrics reporter shutdown failed", exc_info=True) + + +_REPORTER: JobMetricsReporter | None = None + + +def get_reporter() -> JobMetricsReporter: + """Return the process-wide singleton reporter (lazy-init on first use).""" + global _REPORTER + if _REPORTER is None: + _REPORTER = JobMetricsReporter() + return _REPORTER + + +def _trial_type(trial) -> str: + name = type(trial).__name__ + base = name[:-5] if name.endswith("Trial") else name + return base.lower().lstrip("_") + + +def _extract_labels(args: tuple) -> dict[str, str]: + """Best-effort label extraction from a decorated method's positional args. + + Duck-typed (no imports of AbstractTrial/TrialClient/Sandbox) to avoid + circular imports: + - TrialClient : has both ``.trial`` and ``.sandbox`` + - AbstractTrial: has ``._config`` + - Sandbox : has ``.sandbox_id`` + """ + trial = None + sandbox = None + for a in args: + if hasattr(a, "trial") and hasattr(a, "sandbox"): + trial = a.trial + sandbox = a.sandbox + elif hasattr(a, "_config"): + trial = a + elif hasattr(a, "sandbox_id"): + sandbox = a + # Stop once we have a full (trial, sandbox) pair so later positional + # args (e.g. the nohup observation) can't clobber them. 
+ if trial is not None and sandbox is not None: + break + + labels = { + "trial_type": "unknown", + "job_name": "unknown", + "experiment_id": "unknown", + "namespace": "unknown", + "sandbox_id": "unknown", + } + if trial is not None: + cfg = trial._config + labels["trial_type"] = _trial_type(trial) + labels["job_name"] = str(getattr(cfg, "job_name", None) or "unknown") + labels["experiment_id"] = str(getattr(cfg, "experiment_id", None) or "unknown") + labels["namespace"] = str(getattr(cfg, "namespace", None) or "unknown") + if sandbox is not None: + labels["sandbox_id"] = str(getattr(sandbox, "sandbox_id", None) or "unknown") + return labels + + +def _emit_soft(reporter, phase: str, result, args: tuple) -> None: + """Emit one soft event per returned TrialResult carrying exception_info. + + Labels are extracted lazily, only when a soft fail is actually present, so + the success path pays no extraction cost. + """ + items = result if isinstance(result, list) else [result] + labels = None + for r in items: + if isinstance(r, TrialResult) and r.exception_info is not None: + if labels is None: + labels = _extract_labels(args) + reporter.record_exception( + phase=phase, + exc_type=r.exception_info.exception_type or "unknown", + severity="soft", + labels=labels, + ) + + +def _safe_record_hard(reporter, phase: str, exc: BaseException, args: tuple) -> None: + """Record a hard failure. Observability must never break a job, so any + error in extraction/recording is swallowed and logged.""" + try: + reporter.record_exception(phase, type(exc).__name__, "hard", _extract_labels(args)) + except Exception: + logger.warning("job observability failed recording hard exception for phase=%s", phase, exc_info=True) + + +def _safe_emit_soft(reporter, phase: str, result, args: tuple) -> None: + """Emit soft failures. 
Swallows and logs any observability error so it can + never break a job.""" + try: + _emit_soft(reporter, phase, result, args) + except Exception: + logger.warning("job observability failed emitting soft exception for phase=%s", phase, exc_info=True) + + +def monitor_job_phase(phase: str) -> Callable: + """Decorate a JobExecutor phase method to emit log+metric on exceptions. + + - Hard fail (method raises): record severity="hard" then re-raise. + - Soft fail (method returns TrialResult(s) carrying exception_info): + record severity="soft" for each. + Only the leaf phase methods are decorated; outer _do_submit/_do_wait are + NOT decorated, so a re-raised hard fail is counted exactly once. + + All observability side-effects are exception-isolated: a bug in label + extraction or recording can never turn a successful phase into a failure, + and the original wrapped exception is always re-raised unchanged. + """ + + def deco(f): + if asyncio.iscoroutinefunction(f): + + @functools.wraps(f) + async def awrapper(self, *args, **kwargs): + reporter = get_reporter() + try: + result = await f(self, *args, **kwargs) + except Exception as e: + _safe_record_hard(reporter, phase, e, args) + raise + _safe_emit_soft(reporter, phase, result, args) + return result + + return awrapper + + @functools.wraps(f) + def wrapper(self, *args, **kwargs): + reporter = get_reporter() + try: + result = f(self, *args, **kwargs) + except Exception as e: + _safe_record_hard(reporter, phase, e, args) + raise + _safe_emit_soft(reporter, phase, result, args) + return result + + return wrapper + + return deco diff --git a/tests/unit/sdk/job/test_executor.py b/tests/unit/sdk/job/test_executor.py index 18091aa84d..0363fd2146 100644 --- a/tests/unit/sdk/job/test_executor.py +++ b/tests/unit/sdk/job/test_executor.py @@ -8,6 +8,7 @@ import rock.sdk.job.trial.bash # register BashJobConfig -> BashTrial # noqa: F401 from rock.sdk.bench.constants import USER_DEFINED_LOGS +from rock.sdk.job import observability 
from rock.sdk.job.config import BashJobConfig from rock.sdk.job.executor import JobClient, JobExecutor, TrialClient from rock.sdk.job.operator import ScatterOperator @@ -269,3 +270,137 @@ async def test_do_submit_calls_on_sandbox_ready_after_start(self): trial.on_sandbox_ready.assert_awaited_once_with(mock_sandbox) # Must be called AFTER sandbox.start() assert mock_sandbox.start.call_count == 1 + + +class _RecordingReporter: + def __init__(self): + self.events = [] + + def record_exception(self, phase, exc_type, severity, labels): + self.events.append((phase, exc_type, severity, labels)) + + +class TestSubmitPhaseAttribution: + async def test_setup_failure_attributed_to_setup_phase(self, monkeypatch): + from rock.sdk.job.trial.bash import BashTrial + + fake = _RecordingReporter() + monkeypatch.setattr(observability, "_REPORTER", fake) + + mock_sandbox = _make_mock_sandbox() + with patch("rock.sdk.job.executor.Sandbox", return_value=mock_sandbox): + executor = JobExecutor() + trial = BashTrial(BashJobConfig(script="echo hi", job_name="t")) + trial.setup = AsyncMock(side_effect=ValueError("setup boom")) + with pytest.raises(ValueError, match="setup boom"): + await executor._do_submit(trial) + + hard = [e for e in fake.events if e[2] == "hard"] + assert len(hard) == 1 + assert hard[0][0] == "setup" + assert hard[0][1] == "ValueError" + + async def test_launch_failure_attributed_to_launch_phase(self, monkeypatch): + from rock.sdk.job.trial.bash import BashTrial + + fake = _RecordingReporter() + monkeypatch.setattr(observability, "_REPORTER", fake) + + mock_sandbox = _make_mock_sandbox() + error_obs = MagicMock() + error_obs.output = "boom" + mock_sandbox.start_nohup_process = AsyncMock(return_value=(None, error_obs)) + with patch("rock.sdk.job.executor.Sandbox", return_value=mock_sandbox): + executor = JobExecutor() + trial = BashTrial(BashJobConfig(script="echo hi", job_name="t")) + with pytest.raises(RuntimeError, match="Failed to start trial"): + await 
executor._do_submit(trial) + + hard = [e for e in fake.events if e[2] == "hard"] + assert len(hard) == 1 + assert hard[0][0] == "launch" + assert hard[0][1] == "RuntimeError" + + +class TestWaitPhaseAttribution: + async def test_soft_fail_emitted_in_collect_phase(self, monkeypatch): + # exit_code != 0 -> BashTrial.collect sets BashExitCode exception_info (soft) + fake = _RecordingReporter() + monkeypatch.setattr(observability, "_REPORTER", fake) + + mock_sandbox = _make_mock_sandbox() + nohup_obs = MagicMock() + nohup_obs.output = "bad" + nohup_obs.exit_code = 1 + mock_sandbox.handle_nohup_output = AsyncMock(return_value=nohup_obs) + + with patch("rock.sdk.job.executor.Sandbox", return_value=mock_sandbox): + executor = JobExecutor() + results = await executor.run(ScatterOperator(size=1), BashJobConfig(script="x", job_name="t")) + + assert results[0].exception_info is not None + soft = [e for e in fake.events if e[2] == "soft"] + assert len(soft) == 1 + assert soft[0][0] == "collect" + assert soft[0][1] == "BashExitCode" + + async def test_process_timeout_soft_fail_in_collect_phase(self, monkeypatch): + fake = _RecordingReporter() + monkeypatch.setattr(observability, "_REPORTER", fake) + + mock_sandbox = _make_mock_sandbox() + mock_sandbox.wait_for_process_completion = AsyncMock(return_value=(False, "timeout")) + + with patch("rock.sdk.job.executor.Sandbox", return_value=mock_sandbox): + executor = JobExecutor() + results = await executor.run(ScatterOperator(size=1), BashJobConfig(script="x", job_name="t")) + + assert results[0].exception_info is not None + soft = [e for e in fake.events if e[2] == "soft"] + assert len(soft) == 1 + assert soft[0][0] == "collect" + assert soft[0][1] == "ProcessTimeout" + + +class TestObservabilityEndToEnd: + async def test_success_path_emits_no_events(self, monkeypatch): + fake = _RecordingReporter() + monkeypatch.setattr(observability, "_REPORTER", fake) + mock_sandbox = _make_mock_sandbox() + with 
patch("rock.sdk.job.executor.Sandbox", return_value=mock_sandbox): + executor = JobExecutor() + await executor.run(ScatterOperator(size=1), BashJobConfig(script="echo hi", job_name="t")) + assert fake.events == [] + + async def test_labels_contain_sandbox_and_job_identity(self, monkeypatch): + fake = _RecordingReporter() + monkeypatch.setattr(observability, "_REPORTER", fake) + mock_sandbox = _make_mock_sandbox() + nohup_obs = MagicMock() + nohup_obs.output = "bad" + nohup_obs.exit_code = 1 + mock_sandbox.handle_nohup_output = AsyncMock(return_value=nohup_obs) + with patch("rock.sdk.job.executor.Sandbox", return_value=mock_sandbox): + executor = JobExecutor() + await executor.run(ScatterOperator(size=1), BashJobConfig(script="x", job_name="job-xyz")) + soft = [e for e in fake.events if e[2] == "soft"] + assert len(soft) == 1 + labels = soft[0][3] + assert labels["job_name"] == "job-xyz" + assert labels["sandbox_id"] == "sb-test" + assert labels["trial_type"] == "bash" + + async def test_hard_fail_counted_exactly_once(self, monkeypatch): + from rock.sdk.job.trial.bash import BashTrial + + fake = _RecordingReporter() + monkeypatch.setattr(observability, "_REPORTER", fake) + mock_sandbox = _make_mock_sandbox() + with patch("rock.sdk.job.executor.Sandbox", return_value=mock_sandbox): + executor = JobExecutor() + trial = BashTrial(BashJobConfig(script="x", job_name="t")) + trial.setup = AsyncMock(side_effect=ValueError("boom")) + with pytest.raises(ValueError): + await executor._do_submit(trial) + # exactly one hard event despite re-raise propagating through _do_submit + assert len([e for e in fake.events if e[2] == "hard"]) == 1 diff --git a/tests/unit/sdk/job/test_observability.py b/tests/unit/sdk/job/test_observability.py new file mode 100644 index 0000000000..e6c5dc9394 --- /dev/null +++ b/tests/unit/sdk/job/test_observability.py @@ -0,0 +1,234 @@ +"""Tests for rock.sdk.job.observability — reporter, decorator, env config.""" + +from __future__ import annotations + 
+import pytest
+
+from rock import env_vars
+from rock.sdk.job import observability
+from rock.sdk.job.observability import JobMetricsReporter, monitor_job_phase
+from rock.sdk.job.result import ExceptionInfo, TrialResult
+
+
+class TestEnvVars:
+    def test_otlp_endpoint_defaults_to_none(self, monkeypatch):
+        monkeypatch.delenv("ROCK_JOB_METRICS_OTLP_ENDPOINT", raising=False)
+        assert env_vars.ROCK_JOB_METRICS_OTLP_ENDPOINT is None
+
+    def test_otlp_endpoint_reads_env(self, monkeypatch):
+        monkeypatch.setenv("ROCK_JOB_METRICS_OTLP_ENDPOINT", "http://otlp:4318/v1/metrics")
+        assert env_vars.ROCK_JOB_METRICS_OTLP_ENDPOINT == "http://otlp:4318/v1/metrics"
+
+    def test_high_cardinality_defaults_true(self, monkeypatch):
+        monkeypatch.delenv("ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS", raising=False)
+        assert env_vars.ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS is True
+
+    def test_high_cardinality_false_when_set_false(self, monkeypatch):
+        monkeypatch.setenv("ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS", "false")
+        assert env_vars.ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS is False
+
+
+class _FakeCounter:
+    def __init__(self):
+        self.calls = []
+
+    def add(self, value, attributes):
+        self.calls.append((value, attributes))
+
+
+class TestReporter:
+    def test_disabled_when_no_endpoint(self, monkeypatch):
+        monkeypatch.delenv("ROCK_JOB_METRICS_OTLP_ENDPOINT", raising=False)
+        reporter = JobMetricsReporter()
+        assert reporter._enabled is False
+
+    def test_record_exception_logs_when_disabled(self, monkeypatch, caplog):
+        monkeypatch.delenv("ROCK_JOB_METRICS_OTLP_ENDPOINT", raising=False)
+        reporter = JobMetricsReporter()
+        import logging
+
+        # init_logger() sets propagate=False; caplog listens on the root logger,
+        # so enable propagation for this assertion to capture the record.
+        monkeypatch.setattr(observability.logger, "propagate", True)
+        with caplog.at_level(logging.ERROR, logger="rock.sdk.job.observability"):
+            reporter.record_exception("setup", "ValueError", "hard", {"job_name": "j1"})
+        assert any("phase=setup" in r.message for r in caplog.records)
+        assert any("severity=hard" in r.message for r in caplog.records)
+
+    def test_record_exception_emits_counter_when_enabled(self, monkeypatch):
+        reporter = JobMetricsReporter()
+        reporter._enabled = True
+        reporter._counter = _FakeCounter()
+        reporter.record_exception("collect", "BashExitCode", "soft", {"job_name": "j1", "sandbox_id": "sb1"})
+        assert len(reporter._counter.calls) == 1
+        value, attrs = reporter._counter.calls[0]
+        assert value == 1
+        assert attrs["phase"] == "collect"
+        assert attrs["severity"] == "soft"
+        assert attrs["exception_type"] == "BashExitCode"
+        assert attrs["job_name"] == "j1"
+        assert attrs["sandbox_id"] == "sb1"
+
+    def test_high_cardinality_labels_filtered_when_disabled(self, monkeypatch):
+        monkeypatch.setenv("ROCK_JOB_METRICS_HIGH_CARDINALITY_LABELS", "false")
+        reporter = JobMetricsReporter()
+        reporter._enabled = True
+        reporter._counter = _FakeCounter()
+        reporter.record_exception("collect", "X", "soft", {"job_name": "j1", "sandbox_id": "sb1", "namespace": "ns"})
+        _, attrs = reporter._counter.calls[0]
+        assert "job_name" not in attrs
+        assert "sandbox_id" not in attrs
+        assert attrs["namespace"] == "ns"
+
+    def test_get_reporter_is_singleton(self, monkeypatch):
+        monkeypatch.setattr(observability, "_REPORTER", None)
+        r1 = observability.get_reporter()
+        r2 = observability.get_reporter()
+        assert r1 is r2
+
+
+class _FakeProvider:
+    def __init__(self, raises=False):
+        self.shutdown_calls = 0
+        self._raises = raises
+
+    def shutdown(self):
+        self.shutdown_calls += 1
+        if self._raises:
+            raise RuntimeError("provider shutdown blew up")
+
+
+class TestShutdown:
+    def test_shutdown_flushes_provider_and_disables(self, monkeypatch):
+        monkeypatch.delenv("ROCK_JOB_METRICS_OTLP_ENDPOINT", raising=False)
+        reporter = JobMetricsReporter()
+        provider = _FakeProvider()
+        reporter._provider = provider
+        reporter._enabled = True
+        reporter.shutdown()
+        assert provider.shutdown_calls == 1
+        assert reporter._enabled is False
+
+    def test_shutdown_is_idempotent(self, monkeypatch):
+        monkeypatch.delenv("ROCK_JOB_METRICS_OTLP_ENDPOINT", raising=False)
+        reporter = JobMetricsReporter()
+        provider = _FakeProvider()
+        reporter._provider = provider
+        reporter.shutdown()
+        reporter.shutdown()
+        assert provider.shutdown_calls == 1
+
+    def test_shutdown_swallows_provider_errors(self, monkeypatch):
+        monkeypatch.delenv("ROCK_JOB_METRICS_OTLP_ENDPOINT", raising=False)
+        reporter = JobMetricsReporter()
+        reporter._provider = _FakeProvider(raises=True)
+        # must never raise: observability cannot break process teardown
+        reporter.shutdown()
+        assert reporter._enabled is False
+
+    def test_shutdown_noop_without_provider(self, monkeypatch):
+        monkeypatch.delenv("ROCK_JOB_METRICS_OTLP_ENDPOINT", raising=False)
+        reporter = JobMetricsReporter()
+        assert reporter._provider is None
+        reporter.shutdown()  # no provider -> clean no-op
+
+
+class _FakeReporter:
+    def __init__(self):
+        self.events = []  # (phase, exc_type, severity, labels)
+
+    def record_exception(self, phase, exc_type, severity, labels):
+        self.events.append((phase, exc_type, severity, labels))
+
+
+class _FakeTrial:
+    """Duck-typed AbstractTrial: only needs ._config."""
+
+    def __init__(self, config):
+        self._config = config
+
+
+class _Cfg:
+    def __init__(self, job_name="j", experiment_id="e", namespace="n"):
+        self.job_name = job_name
+        self.experiment_id = experiment_id
+        self.namespace = namespace
+
+
+class _Holder:
+    """Stand-in for JobExecutor; hosts decorated methods."""
+
+    @monitor_job_phase("setup")
+    async def boom_async(self, trial):
+        raise ValueError("nope")
+
+    @monitor_job_phase("collect")
+    async def soft_async(self, trial):
+        return TrialResult(task_name="t", exception_info=ExceptionInfo(exception_type="BashExitCode"))
+
+    @monitor_job_phase("collect")
+    async def ok_async(self, trial):
+        return TrialResult(task_name="t")
+
+    @monitor_job_phase("launch")
+    def boom_sync(self, trial):
+        raise RuntimeError("sync-nope")
+
+
+class TestDecorator:
+    async def test_hard_fail_emits_once_and_reraises(self, monkeypatch):
+        fake = _FakeReporter()
+        monkeypatch.setattr(observability, "_REPORTER", fake)
+        h = _Holder()
+        with pytest.raises(ValueError, match="nope"):
+            await h.boom_async(_FakeTrial(_Cfg(job_name="j1")))
+        assert len(fake.events) == 1
+        phase, exc_type, severity, labels = fake.events[0]
+        assert phase == "setup"
+        assert exc_type == "ValueError"
+        assert severity == "hard"
+        assert labels["job_name"] == "j1"
+        assert labels["trial_type"] == "fake"  # _FakeTrial -> "fake"
+
+    async def test_soft_fail_emitted_from_return_value(self, monkeypatch):
+        fake = _FakeReporter()
+        monkeypatch.setattr(observability, "_REPORTER", fake)
+        h = _Holder()
+        await h.soft_async(_FakeTrial(_Cfg()))
+        assert len(fake.events) == 1
+        phase, exc_type, severity, _ = fake.events[0]
+        assert phase == "collect"
+        assert exc_type == "BashExitCode"
+        assert severity == "soft"
+
+    async def test_clean_return_emits_nothing(self, monkeypatch):
+        fake = _FakeReporter()
+        monkeypatch.setattr(observability, "_REPORTER", fake)
+        h = _Holder()
+        await h.ok_async(_FakeTrial(_Cfg()))
+        assert fake.events == []
+
+    def test_sync_hard_fail_emits(self, monkeypatch):
+        fake = _FakeReporter()
+        monkeypatch.setattr(observability, "_REPORTER", fake)
+        h = _Holder()
+        with pytest.raises(RuntimeError, match="sync-nope"):
+            h.boom_sync(_FakeTrial(_Cfg()))
+        assert len(fake.events) == 1
+        assert fake.events[0][0] == "launch"
+        assert fake.events[0][2] == "hard"
+
+    async def test_broken_reporter_does_not_break_job(self, monkeypatch):
+        class _ExplodingReporter:
+            def record_exception(self, *a, **k):
+                raise RuntimeError("reporter blew up")
+
+        monkeypatch.setattr(observability, "_REPORTER", _ExplodingReporter())
+        h = _Holder()
+        # soft path: a broken reporter must NOT propagate out of the phase
+        result = await h.soft_async(_FakeTrial(_Cfg()))
+        assert result.exception_info is not None  # job result still returned
+
+        # hard path: the ORIGINAL exception must still propagate (not the reporter's)
+        with pytest.raises(ValueError, match="nope"):
+            await h.boom_async(_FakeTrial(_Cfg()))