From 5a0f837c433749adb6de1ecacb04fbfc2d037d4c Mon Sep 17 00:00:00 2001 From: Jiachen Zhang Date: Fri, 22 May 2026 18:06:55 +0800 Subject: [PATCH 01/10] fix(admin): retry SandboxTable ops once on stale connection after DB restart (#987) (#997) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test(admin): add SandboxTable reconnect tests with real PG process restart Covers the scenario where the postgres process is killed and restarted inside a running container (pg_ctl stop/start), leaving the container port stable but invalidating existing connections. pool_pre_ping=False forces the decorator — not the pool — to handle recovery. * fix(admin): retry SandboxTable ops once on stale connection after DB restart Adds _retry_on_disconnect decorator applied to all six SandboxTable methods. Retries once when DBAPIError.connection_invalidated is True, which SQLAlchemy sets when asyncpg detects "connection is closed" — meaning the query never executed and is safe to retry. Addresses stale connections caused by DB process restart or NAT idle timeout dropping the TCP connection. * test(admin): simulate 3s PG outage to enforce back-off requirement A bare single-attempt retry fires immediately after the DB stops and finds it still down. Only a retry strategy with cumulative back-off exceeding the 3-second outage window can bridge the gap. This makes the test RED against the old no-sleep implementation and GREEN once sufficient exponential back-off is in place. * fix(admin): retry SandboxTable ops with exponential back-off across DB outages The retry decorator now spans both failure modes seen during a PG restart: 1. statement-execution path - an already-checked-out connection goes stale and asyncpg raises sqlalchemy.exc.InterfaceError / OperationalError (DBAPIError subclasses, wrapped by SQLAlchemy's _handle_dbapi_exception). 2. connect path - the pool tries to dial a fresh connection while PG is still down; asyncpg raises ConnectionError / ConnectionResetError / OSError directly. SQLAlchemy does NOT wrap connect-path failures into DBAPIError, so the previous "except DBAPIError" missed this path entirely - retries fired only on the first stale-connection error and then crashed on the second attempt's connect failure. Exception set: (OperationalError, InterfaceError, DisconnectionError, ConnectionError, OSError, asyncio.TimeoutError) Excluded on purpose: DatabaseError - it would swallow IntegrityError / ProgrammingError / DataError, all permanent failures that must fast-fail. ATTEMPTS=4 with exponential back-off (1s, 2s, 4s) gives a cumulative 7s window, sufficient to bridge typical PG process-restart outages. (cherry picked from commit f8b456dd5bd464ac54703a5e8336752c3c9d1da5) Signed-off-by: Jiachen Zhang --- rock/admin/core/sandbox_table.py | 58 ++++++ .../core/test_sandbox_table_reconnect.py | 170 ++++++++++++++++++ 2 files changed, 228 insertions(+) create mode 100644 tests/unit/admin/core/test_sandbox_table_reconnect.py diff --git a/rock/admin/core/sandbox_table.py b/rock/admin/core/sandbox_table.py index fb1d0c1949..3db3c939eb 100644 --- a/rock/admin/core/sandbox_table.py +++ b/rock/admin/core/sandbox_table.py @@ -2,9 +2,12 @@ from __future__ import annotations +import asyncio +import functools from typing import TYPE_CHECKING, Any from sqlalchemy import select +from sqlalchemy.exc import DisconnectionError, InterfaceError, OperationalError from sqlalchemy.ext.asyncio import AsyncSession from rock.admin.core.db_provider import DatabaseProvider @@ -21,6 +24,55 @@ logger = init_logger(__name__) +_DISCONNECT_RETRY_ATTEMPTS = 4 + +# Exceptions retried with exponential back-off across DB outages. +# - OperationalError / InterfaceError: SQLAlchemy-wrapped runtime connection +# problems on the statement-execution path (stale connection, server gone, +# socket-level failures observed mid-query). +# - DisconnectionError: explicit pool-level "connection is invalid" signal. +# - OSError / ConnectionError / asyncio.TimeoutError: asyncpg's connect path +# raises these directly; SQLAlchemy does NOT wrap them into DBAPIError +# because they fire before a statement is ever issued. Without catching +# them here, retries cannot bridge a multi-second PG restart window. +# Excluded on purpose: DatabaseError (would swallow IntegrityError, +# DataError, ProgrammingError — all permanent failures that must fast-fail). +_RETRY_EXCEPTIONS: tuple[type[BaseException], ...] = ( + OperationalError, + InterfaceError, + DisconnectionError, + ConnectionError, + OSError, + asyncio.TimeoutError, +) + + +def _retry_on_disconnect(func): + """Retry up to _DISCONNECT_RETRY_ATTEMPTS times across DB outages.""" + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + last_exc: BaseException | None = None + for attempt in range(1, _DISCONNECT_RETRY_ATTEMPTS + 1): + try: + return await func(*args, **kwargs) + except _RETRY_EXCEPTIONS as exc: + last_exc = exc + logger.warning( + "DB connection lost on %s (attempt %d/%d): %r", + func.__name__, + attempt, + _DISCONNECT_RETRY_ATTEMPTS, + exc, + ) + if attempt < _DISCONNECT_RETRY_ATTEMPTS: + await asyncio.sleep(1.0 * 2 ** (attempt - 1)) + assert last_exc is not None + raise last_exc + + return wrapper + + class SandboxTable: """Sandbox-specific database access layer backed by DatabaseProvider. @@ -46,6 +98,7 @@ def __init__(self, db_provider: DatabaseProvider, rock_config: RockConfig | None metric_prefix="meta_store.db", ) + @_retry_on_disconnect @monitor_metastore_operation async def create( self, @@ -76,6 +129,7 @@ async def create( session.add(record) await session.commit() + @_retry_on_disconnect @monitor_metastore_operation async def get(self, sandbox_id: str) -> dict | None: """Return a sandbox row as a plain dict, or ``None`` if not found.""" @@ -85,6 +139,7 @@ async def get(self, sandbox_id: str) -> dict | None: return None return record.to_dict() + @_retry_on_disconnect @monitor_metastore_operation async def update(self, sandbox_id: str, info: SandboxInfo) -> None: """Partial update of scalar columns; always overwrites ``status`` with *info*.""" @@ -100,6 +155,7 @@ async def update(self, sandbox_id: str, info: SandboxInfo) -> None: setattr(record, key, value) await session.commit() + @_retry_on_disconnect @monitor_metastore_operation async def delete(self, sandbox_id: str) -> None: """Hard-delete a sandbox record.""" @@ -109,6 +165,7 @@ async def delete(self, sandbox_id: str) -> None: await session.delete(record) await session.commit() + @_retry_on_disconnect @monitor_metastore_operation async def list_by(self, column: str, value: str | int | float | bool) -> list[dict]: """Equality query on a single column. Only columns in ``SandboxRecord.LIST_BY_ALLOWLIST`` are permitted.""" @@ -120,6 +177,7 @@ async def list_by(self, column: str, value: str | int | float | bool) -> list[di result = await session.execute(stmt) return [r.to_dict() for r in result.scalars().all()] + @_retry_on_disconnect @monitor_metastore_operation async def list_by_in(self, column: str, values: list[str | int | float | bool]) -> list[dict]: """IN query on a single column. Only columns in ``SandboxRecord.LIST_BY_ALLOWLIST`` are permitted.""" diff --git a/tests/unit/admin/core/test_sandbox_table_reconnect.py b/tests/unit/admin/core/test_sandbox_table_reconnect.py new file mode 100644 index 0000000000..b4aafe6ab8 --- /dev/null +++ b/tests/unit/admin/core/test_sandbox_table_reconnect.py @@ -0,0 +1,170 @@ +"""SandboxTable reconnect tests — PostgreSQL process restart inside the container. + +Setup +----- +PID 1 of the container is ``sh`` blocked on ``sleep infinity``. +postgres runs as a background child. ``pg_ctl stop / start`` restarts the +postgres process without touching the container, so the host port stays stable +and data is preserved (same PGDATA directory, new process). +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from rock.admin.core.sandbox_table import SandboxTable + +_PGUSER = "test" +_PGPASS = "test" +_PGDB = "testdb" +_PGDATA = "/var/lib/postgresql/data" + + +def _wait_pg_ready_sql(container, user: str, db: str, timeout: int = 30) -> None: + """Two-stage wait: pg_isready (socket up) then SELECT 1 (queries accepted). + + Mirrors the logic in tests/unit/conftest.py::pg_container to close the + startup race window between the socket accepting and WAL replay finishing. + """ + import time + + deadline = time.time() + timeout + while time.time() < deadline: + code, _ = container.exec_run(f"pg_isready -U {user}") + if code == 0: + code, _ = container.exec_run(f'psql -U {user} -d {db} -c "SELECT 1"') + if code == 0: + return + time.sleep(0.5) + raise TimeoutError(f"PostgreSQL did not become ready within {timeout}s") + + +@pytest.mark.need_docker +class TestSandboxTablePgProcessRestart: + """PostgreSQL process restart inside a running container. + + pool_pre_ping=False so the pool does NOT silently reconnect — the + @_retry_on_disconnect decorator must handle recovery. + """ + + @pytest.fixture + def restartable_pg(self): + """Container where postgres runs as a background child of PID 1 (sleep infinity).""" + import socket + import uuid + + import docker + + client = docker.from_env() + name = f"rock-test-pg-proc-{uuid.uuid4().hex[:8]}" + + hostname = socket.gethostname() + try: + current = client.containers.get(hostname) + networks = current.attrs["NetworkSettings"]["Networks"] + network_name = "bridge" if "bridge" in networks else next(iter(networks), None) + except Exception: + network_name = None + + env = {"POSTGRES_USER": _PGUSER, "POSTGRES_PASSWORD": _PGPASS, "POSTGRES_DB": _PGDB} + run_kwargs = { + "image": "postgres:16-alpine", + "name": name, + "detach": True, + "environment": env, + "entrypoint": ["sh", "-c"], + # Single-element list: Docker passes the whole string as argv[1] to sh -c. + # A plain string would be split on spaces, breaking the & operator. + "command": ["docker-entrypoint.sh postgres & sleep infinity"], + } + if network_name: + run_kwargs["network"] = network_name + else: + run_kwargs["ports"] = {"5432/tcp": None} + + container = client.containers.run(**run_kwargs) + try: + _wait_pg_ready_sql(container, _PGUSER, _PGDB) + container.reload() + if network_name: + host = container.attrs["NetworkSettings"]["Networks"][network_name]["IPAddress"] + port = 5432 + else: + host = "127.0.0.1" + port = int(container.ports["5432/tcp"][0]["HostPort"]) + + yield { + "container": container, + "url": f"postgresql://{_PGUSER}:{_PGPASS}@{host}:{port}/{_PGDB}", + } + finally: + try: + container.stop(timeout=5) + container.remove() + except Exception: + pass + + @pytest.fixture + async def table(self, restartable_pg): + """pool_size=1, pool_pre_ping=False — decorator must handle stale connections.""" + from sqlalchemy.ext.asyncio import create_async_engine + + from rock.admin.core.schema import Base + + url = restartable_pg["url"].replace("postgresql://", "postgresql+asyncpg://") + engine = create_async_engine( + url, + pool_size=1, + max_overflow=0, + pool_pre_ping=False, + connect_args={"statement_cache_size": 0}, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + provider = MagicMock() + provider.engine = engine + t = SandboxTable(provider) + yield t + await engine.dispose() + + _OUTAGE_SECONDS = 4 + + def _do_pg_restart_blocking(self, restartable_pg) -> None: + """Stop, hold the outage open for _OUTAGE_SECONDS, then start. + + Runs in an executor so the asyncio loop stays free to drive the + retry/back-off path inside SandboxTable while PG is down. + Cumulative back-off in the production wrapper (1+2+4+8 = 15s) must + exceed _OUTAGE_SECONDS for the test to pass; a no-sleep retry fires + immediately, finds the DB still down, and fails. + """ + import time + + container = restartable_pg["container"] + container.exec_run(f"su postgres -c 'pg_ctl stop -D {_PGDATA} -m fast'") + time.sleep(self._OUTAGE_SECONDS) + container.exec_run(f"su postgres -c 'pg_ctl start -D {_PGDATA} -l /tmp/pg.log'") + _wait_pg_ready_sql(container, _PGUSER, _PGDB) + + async def test_retry_recovers_after_pg_restart(self, table, restartable_pg): + import asyncio + + await table.create("pgr-1", {"user_id": "bob", "create_time": "2025-01-01T00:00:00Z"}) + await table.list_by_in("sandbox_id", ["pgr-1"]) # warm pool + + # Kick off the outage in the background so the query below races against it. + restart_task = asyncio.create_task(asyncio.to_thread(self._do_pg_restart_blocking, restartable_pg)) + + # Let `pg_ctl stop` actually land and the asyncpg reader observe the RST + # before issuing the query; the query must therefore traverse the outage + # window via the retry decorator's back-off. + await asyncio.sleep(0.5) + + result = await table.list_by_in("sandbox_id", ["pgr-1"]) + + await restart_task + assert len(result) == 1 + assert result[0]["sandbox_id"] == "pgr-1" From 2c5d2455cf4f0f5c64a041e598565de6b1cff0c6 Mon Sep 17 00:00:00 2001 From: jinbai340997 <15652831212@163.com> Date: Fri, 22 May 2026 22:30:10 +0800 Subject: [PATCH 02/10] deduplicate region scheduler.tasks via base config inheritance Add `_base: ` resolution in RockConfig.from_env() with deep merge support for dicts and identity-keyed lists. Multi-region YAML configs can now factor out a single base file; previously the `_base` key was silently dropped by the kwargs whitelist, leading to dataclass-default fallbacks (e.g. redis port=0) at runtime. fixes #1004 --- rock/config.py | 70 +++++++++ tests/unit/test_config.py | 294 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 364 insertions(+) diff --git a/rock/config.py b/rock/config.py index cec3899e7a..b15b5336d9 100644 --- a/rock/config.py +++ b/rock/config.py @@ -357,6 +357,17 @@ def from_env(cls, config_path: str | None = None): with open(config_file) as f: config = yaml.safe_load(f) + # Handle _base config inheritance + if "_base" in config: + base_path = Path(config.pop("_base")) + if not base_path.is_absolute(): + base_path = config_file.parent / base_path + if not base_path.exists(): + raise Exception(f"base config file {base_path} not found") + with open(base_path) as f: + base_config = yaml.safe_load(f) + config = cls._deep_merge(base_config, config) + # Convert nested dictionaries to dataclass objects kwargs = {} if "ray" in config: @@ -385,6 +396,65 @@ def from_env(cls, config_path: str | None = None): return cls(**kwargs) + # ============================================================================ + # Merging Rules: + # 1. Dictionary elements within the list are matched based on their `task_class`. + # 2. Matched elements: Regional configurations are deep-merged to override the base library (applied at the field level, not as a complete replacement). + # 3. Unmatched base tasks: Retained as-is. + # 4. Newly added regional tasks: Appended to the list. + # 5. To "disable" a specific base task within a region: Set `enabled: false`. + # ============================================================================ + + @staticmethod + def _deep_merge(base: dict, override: dict) -> dict: + """Deep merge override into base. Override values take precedence.""" + result = base.copy() + for key, value in override.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = RockConfig._deep_merge(result[key], value) + elif key in result and isinstance(result[key], list) and isinstance(value, list): + result[key] = RockConfig._merge_lists(result[key], value) + else: + result[key] = value + return result + + @staticmethod + def _merge_lists(base_list: list, override_list: list) -> list: + """Merge two lists. For lists of dicts with 'task_class' key, merge by that key.""" + if not base_list or not override_list: + return override_list if override_list else base_list + + # Check if both lists contain dicts with a common identity key + merge_key = None + for candidate in ("task_class", "name", "id"): + if all(isinstance(item, dict) and candidate in item for item in base_list) and all( + isinstance(item, dict) and candidate in item for item in override_list + ): + merge_key = candidate + break + + if not merge_key: + # No merge key found, override replaces base entirely + return override_list + + # Merge by key: base items are kept/overridden, new override items appended + override_map = {item[merge_key]: item for item in override_list} + result = [] + seen = set() + for item in base_list: + key_val = item[merge_key] + if key_val in override_map: + # Deep merge the matching item + result.append(RockConfig._deep_merge(item, override_map[key_val])) + else: + result.append(item) + seen.add(key_val) + # Append new items from override that aren't in base + for item in override_list: + if item[merge_key] not in seen: + result.append(item) + return result + def __post_init__(self) -> None: logger.info(f"init RockConfig: {self}") diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index f6a993e1c8..bc69baf1a8 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,4 +1,5 @@ import tempfile +import textwrap from pathlib import Path import pytest @@ -237,3 +238,296 @@ async def test_resolve_includes_non_mapping_raises(): k8s = {"template_includes": ["bad.yml"]} with pytest.raises(ValueError, match="must be a mapping"): _resolve_k8s_template_includes(k8s, base) + + +# --------------------------------------------------------------------------- +# Unit tests for RockConfig._deep_merge +# --------------------------------------------------------------------------- + + +class TestDeepMerge: + """Tests for RockConfig._deep_merge static method.""" + + def test_disjoint_keys(self): + """Non-overlapping keys are all preserved.""" + base = {"a": 1, "b": 2} + override = {"c": 3} + result = RockConfig._deep_merge(base, override) + assert result == {"a": 1, "b": 2, "c": 3} + + def test_override_scalar(self): + """Override value replaces base for scalar keys.""" + base = {"a": 1, "b": 2} + override = {"b": 99} + result = RockConfig._deep_merge(base, override) + assert result == {"a": 1, "b": 99} + + def test_nested_dict_merge(self): + """Nested dicts are recursively merged.""" + base = {"x": {"a": 1, "b": 2}} + override = {"x": {"b": 20, "c": 30}} + result = RockConfig._deep_merge(base, override) + assert result == {"x": {"a": 1, "b": 20, "c": 30}} + + def test_deeply_nested_merge(self): + """Three-level nested dicts merge correctly.""" + base = {"l1": {"l2": {"a": 1, "b": 2}}} + override = {"l1": {"l2": {"b": 99, "c": 3}}} + result = RockConfig._deep_merge(base, override) + assert result == {"l1": {"l2": {"a": 1, "b": 99, "c": 3}}} + + def test_override_dict_with_scalar(self): + """Override a dict value with a scalar replaces entirely.""" + base = {"a": {"nested": 1}} + override = {"a": "flat"} + result = RockConfig._deep_merge(base, override) + assert result == {"a": "flat"} + + def test_override_scalar_with_dict(self): + """Override a scalar with a dict replaces entirely.""" + base = {"a": "flat"} + override = {"a": {"nested": 1}} + result = RockConfig._deep_merge(base, override) + assert result == {"a": {"nested": 1}} + + def test_empty_base(self): + base = {} + override = {"a": 1} + assert RockConfig._deep_merge(base, override) == {"a": 1} + + def test_empty_override(self): + base = {"a": 1} + override = {} + assert RockConfig._deep_merge(base, override) == {"a": 1} + + def test_both_empty(self): + assert RockConfig._deep_merge({}, {}) == {} + + def test_base_not_mutated(self): + """_deep_merge must not mutate the base dict.""" + base = {"a": {"b": 1}} + override = {"a": {"b": 2}} + RockConfig._deep_merge(base, override) + assert base == {"a": {"b": 1}} + + def test_list_values_delegate_to_merge_lists(self): + """When both values are lists, _merge_lists is invoked.""" + base = {"items": [1, 2]} + override = {"items": [3, 4]} + result = RockConfig._deep_merge(base, override) + # No merge key → override replaces base + assert result == {"items": [3, 4]} + + +# --------------------------------------------------------------------------- +# Unit tests for RockConfig._merge_lists +# --------------------------------------------------------------------------- + + +class TestMergeLists: + """Tests for RockConfig._merge_lists static method.""" + + # --- edge cases: empty inputs --- + + def test_empty_base_returns_override(self): + assert RockConfig._merge_lists([], [{"a": 1}]) == [{"a": 1}] + + def test_empty_override_returns_base(self): + assert RockConfig._merge_lists([{"a": 1}], []) == [{"a": 1}] + + def test_both_empty(self): + assert RockConfig._merge_lists([], []) == [] + + # --- no merge key: override replaces --- + + def test_no_merge_key_plain_values(self): + """Lists of non-dicts → override replaces base.""" + assert RockConfig._merge_lists([1, 2], [3, 4]) == [3, 4] + + def test_no_common_key_in_dicts(self): + """Dicts without a shared identity key → override replaces base.""" + base = [{"foo": 1}] + override = [{"bar": 2}] + assert RockConfig._merge_lists(base, override) == [{"bar": 2}] + + # --- merge by task_class --- + + def test_merge_by_task_class_override(self): + """Matched items are deep-merged by task_class.""" + base = [ + {"task_class": "cleanup", "interval": 60, "enabled": True}, + {"task_class": "report", "interval": 300}, + ] + override = [ + {"task_class": "cleanup", "interval": 120}, + ] + result = RockConfig._merge_lists(base, override) + assert len(result) == 2 + assert result[0] == {"task_class": "cleanup", "interval": 120, "enabled": True} + assert result[1] == {"task_class": "report", "interval": 300} + + def test_merge_by_task_class_append_new(self): + """New items in override are appended.""" + base = [{"task_class": "cleanup", "interval": 60}] + override = [ + {"task_class": "cleanup", "interval": 120}, + {"task_class": "audit", "interval": 600}, + ] + result = RockConfig._merge_lists(base, override) + assert len(result) == 2 + assert result[0]["task_class"] == "cleanup" + assert result[0]["interval"] == 120 + assert result[1] == {"task_class": "audit", "interval": 600} + + def test_merge_by_task_class_disable(self): + """Override can disable a base task via enabled: false.""" + base = [{"task_class": "cleanup", "interval": 60, "enabled": True}] + override = [{"task_class": "cleanup", "enabled": False}] + result = RockConfig._merge_lists(base, override) + assert result[0]["enabled"] is False + assert result[0]["interval"] == 60 # preserved from base + + # --- merge by name --- + + def test_merge_by_name(self): + base = [{"name": "svc-a", "port": 80}] + override = [{"name": "svc-a", "port": 8080}] + result = RockConfig._merge_lists(base, override) + assert result == [{"name": "svc-a", "port": 8080}] + + # --- merge by id --- + + def test_merge_by_id(self): + base = [{"id": "x1", "value": 10}] + override = [{"id": "x1", "value": 20}] + result = RockConfig._merge_lists(base, override) + assert result == [{"id": "x1", "value": 20}] + + # --- key priority: task_class > name > id --- + + def test_merge_key_priority_task_class_over_name(self): + """When both task_class and name exist, task_class is used.""" + base = [{"task_class": "A", "name": "na", "v": 1}] + override = [{"task_class": "A", "name": "nb", "v": 2}] + result = RockConfig._merge_lists(base, override) + assert result[0]["v"] == 2 + assert result[0]["name"] == "nb" # overridden + + # --- nested deep merge within list items --- + + def test_nested_dict_merge_within_list_item(self): + """Dict values inside matched list items are recursively merged.""" + base = [{"task_class": "t1", "params": {"a": 1, "b": 2}}] + override = [{"task_class": "t1", "params": {"b": 20, "c": 30}}] + result = RockConfig._merge_lists(base, override) + assert result[0]["params"] == {"a": 1, "b": 20, "c": 30} + + # --- preserving order --- + + def test_order_preserved(self): + """Base order is preserved; new override items appended at end.""" + base = [ + {"task_class": "B", "v": 1}, + {"task_class": "A", "v": 2}, + ] + override = [ + {"task_class": "C", "v": 3}, + {"task_class": "A", "v": 22}, + ] + result = RockConfig._merge_lists(base, override) + assert [item["task_class"] for item in result] == ["B", "A", "C"] + assert result[1]["v"] == 22 + + +# --------------------------------------------------------------------------- +# Integration test: from_env with _base inheritance +# --------------------------------------------------------------------------- + + +class TestFromEnvBaseInheritance: + """Test RockConfig.from_env _base config inheritance using temp files.""" + + def test_base_inheritance_deep_merges(self, tmp_path: Path): + """Child config inherits and overrides base via _deep_merge.""" + base_file = tmp_path / "base.yml" + base_file.write_text( + textwrap.dedent("""\ + ray: + namespace: "base-ns" + runtime_env: + working_dir: ./ + warmup: + images: + - "python:3.11" + """) + ) + + child_file = tmp_path / "child.yml" + child_file.write_text( + textwrap.dedent("""\ + _base: base.yml + ray: + namespace: "child-ns" + """) + ) + + config = RockConfig.from_env(config_path=str(child_file)) + # Overridden + assert config.ray.namespace == "child-ns" + # Inherited from base (runtime_env is deep-merged: child has no runtime_env so base is kept) + assert config.ray.runtime_env == {"working_dir": "./"} + assert config.warmup.images == ["python:3.11"] + + def test_base_inheritance_scheduler_tasks_merge(self, tmp_path: Path): + """Scheduler tasks list is merged by task_class key.""" + base_file = tmp_path / "base.yml" + base_file.write_text( + textwrap.dedent("""\ + scheduler: + tasks: + - task_class: "rock.admin.scheduler.tasks.cleanup.CleanupTask" + enabled: true + interval_seconds: 60 + - task_class: "rock.admin.scheduler.tasks.report.ReportTask" + enabled: true + interval_seconds: 300 + """) + ) + + child_file = tmp_path / "child.yml" + child_file.write_text( + textwrap.dedent("""\ + _base: base.yml + scheduler: + tasks: + - task_class: "rock.admin.scheduler.tasks.cleanup.CleanupTask" + interval_seconds: 120 + - task_class: "rock.admin.scheduler.tasks.audit.AuditTask" + enabled: true + interval_seconds: 600 + """) + ) + + config = RockConfig.from_env(config_path=str(child_file)) + tasks = config.scheduler.tasks + task_map = {t.task_class: t for t in tasks} + + # Overridden interval + assert task_map["rock.admin.scheduler.tasks.cleanup.CleanupTask"].interval_seconds == 120 + assert task_map["rock.admin.scheduler.tasks.cleanup.CleanupTask"].enabled is True # inherited + # Preserved from base + assert task_map["rock.admin.scheduler.tasks.report.ReportTask"].interval_seconds == 300 + # Newly appended + assert task_map["rock.admin.scheduler.tasks.audit.AuditTask"].interval_seconds == 600 + + def test_base_not_found_raises(self, tmp_path: Path): + child_file = tmp_path / "child.yml" + child_file.write_text( + textwrap.dedent("""\ + _base: nonexistent.yml + ray: + namespace: "ns" + """) + ) + with pytest.raises(Exception, match="base config file.*not found"): + RockConfig.from_env(config_path=str(child_file)) From 317cca58986f3f20527bea81e1fa107879ea4f42 Mon Sep 17 00:00:00 2001 From: jinbai <15652831212@163.com> Date: Fri, 22 May 2026 22:32:01 +0800 Subject: [PATCH 03/10] chore: bump version to 1.8.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 883fd83e70..79b238416c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta" authors = [{ name = "chatos@alibaba" }] requires-python = "<4.0,>=3.10" name = "rl-rock" -version = "1.8.0" +version = "1.8.1" description = "ROCK-Reinforcement Open Construction Kit" readme = "README.md" dependencies = [ From 1e31d7c40b837b494833dd701f33f3ce2549401b Mon Sep 17 00:00:00 2001 From: daifangwen Date: Thu, 23 Apr 2026 06:57:00 +0000 Subject: [PATCH 04/10] add tracking config into job config --- .../proposals/job-metrics-reporting-config.md | 287 ++++++++++++++++++ rock/sdk/envhub/__init__.py | 4 +- rock/sdk/envhub/config.py | 25 ++ 3 files changed, 314 insertions(+), 2 deletions(-) create mode 100644 docs/proposals/job-metrics-reporting-config.md diff --git a/docs/proposals/job-metrics-reporting-config.md b/docs/proposals/job-metrics-reporting-config.md new file mode 100644 index 0000000000..0c5387eda1 --- /dev/null +++ b/docs/proposals/job-metrics-reporting-config.md @@ -0,0 +1,287 @@ +## Job 级指标汇报配置方案 — 集成 Harbor ml_tracker + +### 背景 + +ROCK 的 Job 系统需要在 **Bench 评测** 和 **RL 训练** 场景中汇报运行指标。Harbor 框架已内置 `ml_tracker` 模块,可汇报以下关键指标: + +| 类别 | 指标 | +|------|------| +| **Reward** | `reward/*`(verifier 输出的各 reward key) | +| **Duration** | `total_duration_sec`、`agent_duration_sec` | +| **Token** | `input_tokens`、`output_tokens`、`cache_tokens`、`cost_usd` | +| **RL 训练** | `logprobs_mean`、`entropy`、`loss`、`kl_divergence`、`advantage`、`grad_norm`、`clip_fraction`、`value_loss`、`explained_variance` | +| **Running** | `pass_rate`、`avg_reward`、`error_rate` | +| **Summary** | `final_pass_rate`、`final_avg_reward`、`final_error_rate`、`total_trials`、`total_errors`、`total_duration_sec` | + +但当前 ml_tracker 的启用方式依赖**环境变量** `ROCK_API_KEY` 的存在性(硬编码判断),用户无法通过 Job 配置声明式地控制是否启用、传入超参数等。 + +**改动前**(Harbor `job.py`): + +```python +# 硬编码检查环境变量,无配置入口 +if os.environ.get("ROCK_API_KEY"): + self._tracker = MLTrackerFactory.create(...) +``` + +--- + +### 目标 + +在 `EnvironmentConfig` 上新增 **`tracking`** 字段,让用户在 YAML 的 `environment` 段中声明式地启用 Harbor 内置的 ml_tracker,汇报 Bench/RL 训练指标。 + +**设计原则**: +- **字段名不绑定具体 SDK**:用 `tracking`(而非 `ml_tracker`),避免配置字段与具体包名耦合 +- **复用 Harbor 已有能力**:不另起炉灶,底层仍调用 Harbor `ml_tracker` 模块 +- 所有字段可选,零配置向后兼容(默认不启用,保持现有行为) +- 不侵入 `HarborJobConfig.metrics: list[MetricConfig]`(那是评测结果的聚合策略,语义不同) + +--- + +### 方案(已实现) + +#### 模型定义 + +**ROCK 侧** — `rock/sdk/envhub/config.py`: + +`TrackingConfig` 定义在 `EnvironmentConfig` 同级,作为 `EnvironmentConfig` 的二级字段: + +```python +class TrackingConfig(BaseModel): + """Experiment tracking configuration. + + When present and enabled, activates Harbor's built-in ml_tracker to report + per-trial metrics (reward, duration, token usage, RL training signals) + and a final job-level summary. + """ + + enabled: bool = Field( + default=True, + description="Whether to enable experiment tracking for this job.", + ) + params: dict[str, Any] = Field( + default_factory=dict, + description=( + "User-defined hyperparameters merged into ml_tracker.init(config=...). " + "Combined with auto-collected job metadata (agents, datasets, etc.)." + ), + ) + +class EnvironmentConfig(SandboxConfig): + uploads: list[tuple[str, str]] = Field(default_factory=list) + env: dict[str, str] = Field(default_factory=dict) + oss_mirror: OssMirrorConfig | None = None + tracking: TrackingConfig | None = Field( + default=None, + description="Experiment tracking configuration. None = disabled (default).", + ) +``` + +**Harbor 侧** — `harbor/ml_tracker/config.py`(内部模块名保持 `ml_tracker` 不变): + +```python +class MLTrackerConfig(BaseModel): + enabled: bool = Field(default=True) + params: dict[str, Any] = Field( + default_factory=dict, + description="User-defined hyperparameters merged into ml_tracker.init(config=...).", + ) +``` + +> **命名决策**: +> - 用户配置字段名 = `tracking`(不绑定具体 SDK,未来可扩展到其他 tracker) +> - 子字段 = `params`(而非 `config`,避免 `tracking.config` 语义重复) +> - Harbor 内部模块目录仍叫 `ml_tracker/`(内部实现,不暴露给用户) + +#### 在配置层次中的位置 + +`tracking` 放在 `EnvironmentConfig` 下作为二级字段,而非 `JobConfig` 的一级字段。原因: + +- **与 `oss_mirror` 同层**:`tracking` 和 `oss_mirror` 都是环境级别的能力配置,放在 environment 下更内聚 +- **Harbor 的 `EnvironmentConfig` 天然包含这类配置**:Harbor YAML 中 environment 段是 tracking 信息的自然归属 +- **简化序列化**:`to_harbor_yaml()` 通过 `to_harbor_environment()` 序列化 environment 时自然携带 tracking + +```python +# JobConfig 不直接暴露 tracking,通过 environment 间接访问 +class JobConfig(BaseModel): + environment: EnvironmentConfig = Field(default_factory=EnvironmentConfig) + job_name: str | None = None + namespace: str | None = None + experiment_id: str | None = None + labels: dict[str, str] = Field(default_factory=dict) + timeout: int = 7200 +``` + +> **默认 `None`**:不写 `tracking` 时行为等价于改动前(不启用)。用户显式写 `environment.tracking: {}` 即可启用。 + +--- + +#### YAML 配置示例 + +**最简启用**(所有默认值,自动采集 agent/dataset 信息): + +```yaml +experiment_id: exp-rl-001 +job_name: qwen-72b-swe-bench +environment: + tracking: {} +``` + +**记录额外超参数**(RL 训练场景): + +```yaml +experiment_id: exp-rl-002 +job_name: rl-grpo-run-3 +environment: + tracking: + params: + model: qwen-72b-instruct + algorithm: GRPO + learning_rate: 1.0e-5 + batch_size: 64 + kl_coeff: 0.05 + num_rollouts: 4 +``` + +**显式禁用**(覆盖团队默认配置): + +```yaml +environment: + tracking: + enabled: false +``` + +**不写 `tracking`**(默认行为,等同于禁用): + +```yaml +experiment_id: exp-001 +job_name: my-job +environment: {} +# tracking 不出现 → None → 不启用 +``` + +--- + +### 字段说明 + +| 字段路径 | 类型 | 默认值 | 说明 | +|----------|------|--------|------| +| **`environment.tracking`** | `TrackingConfig \| None` | `None` | 开关。`None` = 不启用(向后兼容);写 `{}` = 启用。 | +| **`environment.tracking.enabled`** | `bool` | `True` | 细粒度开关。配合 `tracking: { enabled: false }` 可显式禁用。 | +| **`environment.tracking.params`** | `dict[str, Any]` | `{}` | 用户自定义超参数,与自动采集的 job metadata 合并后传给 `ml_tracker.init(config=...)`。 | + +两层开关的设计意图: +- `tracking` 不写 / `null` → 不启用(向后兼容,默认路径) +- `tracking: {}` → 启用(`enabled` 默认 `True`) +- `tracking: { enabled: false }` → 显式禁用(团队配置模板中可以预留 `tracking` 段落但暂时关闭) + +--- + +### 汇报的指标详情 + +启用 tracking 后,Harbor 框架会在以下时机自动汇报: + +**每个 Trial 结束时**(`TrialEvent.END` hook): + +``` +reward/* — verifier 输出的 reward 值(每个 key 单独上报) +total_duration_sec — Trial 总耗时 +agent_duration_sec — Agent 执行耗时 +input_tokens — 输入 token 数 +output_tokens — 输出 token 数 +cache_tokens — 缓存 token 数 +cost_usd — 推理花费(USD) +logprobs_mean — rollout log probabilities 均值(RL) +entropy — 策略熵 = -logprobs_mean(RL) +loss — 训练 loss(RL,来自 agent metadata) +kl_divergence — KL 散度(RL) +advantage — 优势值(RL) +grad_norm — 梯度范数(RL) +clip_fraction — PPO clip fraction(RL) +value_loss — 值函数 loss(RL) +explained_variance — 解释方差(RL) +pass_rate — 截至当前的通过率(running) +avg_reward — 截至当前的平均 reward(running) +error_rate — 截至当前的错误率(running) +``` + +**Job 结束时**(`report_job_summary`): + +``` +final_pass_rate — 最终通过率 +final_avg_reward — 最终平均 reward +final_error_rate — 最终错误率 +total_trials — 总 trial 数 +total_errors — 总错误数 +total_duration_sec — Job 总耗时 +``` + +--- + +### 与现有体系的关系 + +``` +JobConfig +├── environment: EnvironmentConfig +│ ├── uploads, env, ... ← 已有: 环境级配置 +│ ├── oss_mirror: OssMirrorConfig ← 已有: OSS 镜像配置 +│ └── tracking: TrackingConfig | None ← NEW: 实验追踪配置 +├── labels: dict[str, str] ← 已有: Job 级标签 +└── ... + +HarborJobConfig(JobConfig) +├── environment.tracking (inherited) ← NEW: 通过 environment 继承 +├── metrics: list[MetricConfig] ← 已有: 评测结果聚合方式(sum/mean/max) +└── ... + +BashJobConfig(JobConfig) +├── environment.tracking (inherited) ← NEW: 通过 environment 继承 +└── ... +``` + +**关键区分**: +- **`environment.tracking`**(新增)= "实验追踪:每个 Trial 的 **业务指标怎么记录**"(reward/token/RL signals → ml_tracker SDK) +- **`metrics`**(HarborJobConfig 已有)= "评测聚合:多个 Trial 的结果 **怎么聚合成最终分数**"(mean/sum/max) + +两者语义正交,互不冲突。`tracking` 与 `oss_mirror` 同层,都属于环境级别的能力配置。 + +--- + +### 改动文件清单 + +#### ROCK 侧 + +| 文件 | 改动 | +|------|------| +| `rock/sdk/envhub/config.py` | 新增 `TrackingConfig` 类 + `EnvironmentConfig.tracking` 字段 | + +#### Harbor 侧 + +| 文件 | 改动 | +|------|------| +| `harbor/ml_tracker/config.py` | `config` 字段重命名为 `params` | +| `harbor/ml_tracker/factory.py` | 新增 `tracker_config` 参数,合并用户 `params` 到自动采集的 config | +| `harbor/models/job/config.py` | 新增 `tracking: MLTrackerConfig \| None` 字段 | +| `harbor/job.py` | 从 `self.config.tracking` 读取配置替代 env var 硬编码;传递 `tracker_config` 给 factory | +| `tests/unit/ml_tracker/test_config.py` | `config` → `params` 适配 | +| `tests/unit/ml_tracker/test_factory.py` | 新增 `test_create_merges_user_params` 测试 | +| `tests/unit/ml_tracker/test_job_integration.py` | 重写为测试 `tracking` 字段的配置集成 | + +#### 配置传递链路 + +``` +用户 YAML + → rock HarborJobConfig.environment.tracking (解析 + 校验) + → to_harbor_yaml() → environment 段携带 tracking + → harbor JobConfig.tracking (反序列化) + → harbor Job.__init__ 读取 → MLTrackerFactory.create(tracker_config=...) +``` + +--- + +### 向后兼容性 + +- `EnvironmentConfig.tracking` 默认为 `None`,不写等价于改动前行为(不启用)。 +- `SandboxConfig` 基类不受影响(`tracking` 只加在 `EnvironmentConfig` 层)。 +- `BashJobConfig` / `HarborJobConfig`:通过 `environment` 间接访问,不涉及 `extra="forbid"` 问题。 +- `_HarborJobFields`:environment 中 `tracking` 为 `None` 时被序列化过滤,不出现在 Harbor YAML 中。 +- `ROCK_API_KEY` 环境变量:Harbor `job.py` 中同时检查 `tracking is not None and tracking.enabled` **和** `ROCK_API_KEY`,两个条件都满足才启用。这保证了即使配置启用了 tracking,没有 API key 也不会报错。 diff --git a/rock/sdk/envhub/__init__.py b/rock/sdk/envhub/__init__.py index 115ee11588..5a528d5c34 100644 --- a/rock/sdk/envhub/__init__.py +++ b/rock/sdk/envhub/__init__.py @@ -1,3 +1,3 @@ -from rock.sdk.envhub.config import EnvironmentConfig, OssMirrorConfig +from rock.sdk.envhub.config import EnvironmentConfig, OssMirrorConfig, TrackingConfig -__all__ = ["EnvironmentConfig", "OssMirrorConfig"] +__all__ = ["EnvironmentConfig", "OssMirrorConfig", "TrackingConfig"] diff --git a/rock/sdk/envhub/config.py b/rock/sdk/envhub/config.py index 6aad1688aa..5ed3a8cc10 100644 --- a/rock/sdk/envhub/config.py +++ b/rock/sdk/envhub/config.py @@ -6,6 +6,8 @@ from __future__ import annotations +from typing import Any + from pydantic import BaseModel, Field, model_validator from rock.sdk.sandbox.config import SandboxConfig @@ -70,6 +72,25 @@ def _record_replay_mutually_exclusive(self): "set one (recording mode) or the other (replay mode), not both." ) return self +class TrackingConfig(BaseModel): + """Experiment tracking configuration. + + When present and enabled, activates Harbor's built-in ml_tracker to report + per-trial metrics (reward, duration, token usage, RL training signals) + and a final job-level summary. + """ + + enabled: bool = Field( + default=True, + description="Whether to enable experiment tracking for this job.", + ) + params: dict[str, Any] = Field( + default_factory=dict, + description=( + "User-defined hyperparameters merged into ml_tracker.init(config=...). " + "Combined with auto-collected job metadata (agents, datasets, etc.)." + ), + ) class EnvironmentConfig(SandboxConfig): @@ -85,3 +106,7 @@ class EnvironmentConfig(SandboxConfig): proxy: ProxyConfig | None = None """In-sandbox model-service proxy for OpenAI request record/replay. None (default) means no proxy is started.""" + tracking: TrackingConfig | None = Field( + default=None, + description="Experiment tracking configuration. None = disabled (default).", + ) From 6ea807cce2fdfe8eaaa16171de2ccb11321dfce6 Mon Sep 17 00:00:00 2001 From: daifangwen Date: Thu, 23 Apr 2026 09:02:52 +0000 Subject: [PATCH 05/10] update docs --- .../proposals/job-metrics-reporting-config.md | 41 ++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/docs/proposals/job-metrics-reporting-config.md b/docs/proposals/job-metrics-reporting-config.md index 0c5387eda1..01fd693037 100644 --- a/docs/proposals/job-metrics-reporting-config.md +++ b/docs/proposals/job-metrics-reporting-config.md @@ -76,21 +76,33 @@ class EnvironmentConfig(SandboxConfig): ) ``` -**Harbor 侧** — `harbor/ml_tracker/config.py`(内部模块名保持 `ml_tracker` 不变): +**Harbor 侧** — `harbor/tracker/config.py`(模块目录已从 `ml_tracker/` 重命名为 `tracker/`): ```python -class MLTrackerConfig(BaseModel): +class TrackingConfig(BaseModel): enabled: bool = Field(default=True) params: dict[str, Any] = Field( default_factory=dict, - description="User-defined hyperparameters merged into ml_tracker.init(config=...).", + description="User-defined hyperparameters merged into tracker init config.", + ) +``` + +Harbor 侧 `EnvironmentConfig`(`harbor/models/trial/config.py`)中同样作为二级字段: + +```python +class EnvironmentConfig(BaseModel): + ... + tracking: TrackingConfig | None = Field( + default=None, + description="Experiment tracking configuration. None = disabled (default).", ) ``` > **命名决策**: > - 用户配置字段名 = `tracking`(不绑定具体 SDK,未来可扩展到其他 tracker) > - 子字段 = `params`(而非 `config`,避免 `tracking.config` 语义重复) -> - Harbor 内部模块目录仍叫 `ml_tracker/`(内部实现,不暴露给用户) +> - ROCK 和 Harbor 两侧配置类统一命名为 `TrackingConfig` +> - Harbor 内部模块目录从 `ml_tracker/` 重命名为 `tracker/`,实现类 `MLTrackerImpl` 保持不变(因为底层仍使用 `ml_tracker` SDK) #### 在配置层次中的位置 @@ -258,13 +270,13 @@ BashJobConfig(JobConfig) | 文件 | 改动 | |------|------| -| `harbor/ml_tracker/config.py` | `config` 字段重命名为 `params` | -| `harbor/ml_tracker/factory.py` | 新增 `tracker_config` 参数,合并用户 `params` 到自动采集的 config | -| `harbor/models/job/config.py` | 新增 `tracking: MLTrackerConfig \| None` 字段 | -| `harbor/job.py` | 从 `self.config.tracking` 读取配置替代 env var 硬编码;传递 `tracker_config` 给 factory | -| `tests/unit/ml_tracker/test_config.py` | `config` → `params` 适配 | -| `tests/unit/ml_tracker/test_factory.py` | 新增 `test_create_merges_user_params` 测试 | -| `tests/unit/ml_tracker/test_job_integration.py` | 重写为测试 `tracking` 字段的配置集成 | +| `harbor/tracker/config.py` | 新模块,`TrackingConfig`(`enabled` + `params`) | +| `harbor/tracker/base.py` | `BaseMLTracker` → `BaseTracker` | +| `harbor/tracker/tracker.py` | `MLTrackerImpl` 改为继承 `BaseTracker`,逻辑不变 | +| `harbor/tracker/factory.py` | `MLTrackerFactory`,新增 `tracker_config` 参数,合并用户 `params` | +| `harbor/models/trial/config.py` | `EnvironmentConfig` 新增 `tracking: TrackingConfig \| None` 字段 | +| `harbor/job.py` | 从 `self.config.environment.tracking` 读取配置;`tracking.enabled` 控制启用(不再硬编码检查 env var) | +| ~~`harbor/ml_tracker/`~~ | 整个目录重命名为 `harbor/tracker/` | #### 配置传递链路 @@ -272,8 +284,9 @@ BashJobConfig(JobConfig) 用户 YAML → rock HarborJobConfig.environment.tracking (解析 + 校验) → to_harbor_yaml() → environment 段携带 tracking - → harbor JobConfig.tracking (反序列化) - → harbor Job.__init__ 读取 → MLTrackerFactory.create(tracker_config=...) + → harbor EnvironmentConfig.tracking (反序列化) + → harbor Job.__init__ 从 self.config.environment.tracking 读取 + → MLTrackerFactory.create(tracker_config=...) ``` --- @@ -284,4 +297,4 @@ BashJobConfig(JobConfig) - `SandboxConfig` 基类不受影响(`tracking` 只加在 `EnvironmentConfig` 层)。 - `BashJobConfig` / `HarborJobConfig`:通过 `environment` 间接访问,不涉及 `extra="forbid"` 问题。 - `_HarborJobFields`:environment 中 `tracking` 为 `None` 时被序列化过滤,不出现在 Harbor YAML 中。 -- `ROCK_API_KEY` 环境变量:Harbor `job.py` 中同时检查 `tracking is not None and tracking.enabled` **和** `ROCK_API_KEY`,两个条件都满足才启用。这保证了即使配置启用了 tracking,没有 API key 也不会报错。 +- `ROCK_API_KEY` 环境变量:不再用于控制启用逻辑(改由 `tracking.enabled` 控制)。`ROCK_API_KEY` 仅在 `MLTrackerImpl.__init__` 中作为 `ml_tracker.login(key=...)` 的凭证使用,未设置时传 `None`(由 SDK 自行处理鉴权 fallback)。 From cf6a5e6d8750f33fe5bb28d459478547b464cb61 Mon Sep 17 00:00:00 2001 From: daifangwen Date: Thu, 23 Apr 2026 09:43:14 +0000 Subject: [PATCH 06/10] add test cases --- tests/unit/sdk/job/test_config.py | 110 ++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/tests/unit/sdk/job/test_config.py b/tests/unit/sdk/job/test_config.py index 8c678ab337..e8b41e9081 100644 --- a/tests/unit/sdk/job/test_config.py +++ b/tests/unit/sdk/job/test_config.py @@ -20,6 +20,7 @@ VerifierConfig, ) from rock.sdk.envhub import EnvironmentConfig +from rock.sdk.envhub.config import TrackingConfig from rock.sdk.job.config import BashJobConfig, JobConfig # --------------------------------------------------------------------------- @@ -684,3 +685,112 @@ def test_defaults_to_datetime_string(self): def test_explicit_name_preserved(self): assert BashJobConfig(job_name="x").job_name == "x" + + +# --------------------------------------------------------------------------- +# TrackingConfig +# --------------------------------------------------------------------------- + + +class TestTrackingConfig: + def test_default_values(self): + config = TrackingConfig() + assert config.enabled is True + assert config.params == {} + + def test_disabled(self): + config = TrackingConfig(enabled=False) + assert config.enabled is False + assert config.params == {} + + def test_custom_params(self): + config = TrackingConfig(params={"learning_rate": 0.01, "epochs": 10, "model": "qwen-72b"}) + assert config.params["learning_rate"] == 0.01 + assert config.params["epochs"] == 10 + assert config.params["model"] == "qwen-72b" + + def test_from_dict(self): + data = {"enabled": True, "params": {"batch_size": 32}} + config = TrackingConfig.model_validate(data) + assert config.enabled is True + assert config.params["batch_size"] == 32 + + def test_from_dict_minimal(self): + config = TrackingConfig.model_validate({}) + assert config.enabled is True + assert config.params == {} + + def test_serialization_roundtrip(self): + config = TrackingConfig(enabled=True, params={"lr": 0.001, "algo": "GRPO"}) + json_str = config.model_dump_json() + restored = TrackingConfig.model_validate_json(json_str) + assert restored == config + + +# --------------------------------------------------------------------------- +# TrackingConfig on base EnvironmentConfig +# --------------------------------------------------------------------------- + + +class TestTrackingConfigOnBaseEnvironment: + def test_tracking_default_none(self): + env = EnvironmentConfig() + assert env.tracking is None + + def test_tracking_enabled(self): + env = EnvironmentConfig(tracking=TrackingConfig()) + assert env.tracking is not None + assert env.tracking.enabled is True + assert env.tracking.params == {} + + def test_tracking_disabled(self): + env = EnvironmentConfig(tracking=TrackingConfig(enabled=False)) + assert env.tracking is not None + assert env.tracking.enabled is False + + def test_tracking_with_params(self): + env = EnvironmentConfig(tracking=TrackingConfig(params={"model": "qwen-72b", "lr": 1e-5})) + assert env.tracking.params["model"] == "qwen-72b" + assert env.tracking.params["lr"] == 1e-5 + + def test_tracking_from_dict(self): + data = {"tracking": {"enabled": True, "params": {"batch_size": 64}}} + env = EnvironmentConfig.model_validate(data) + assert env.tracking is not None + assert env.tracking.params["batch_size"] == 64 + + def test_tracking_none_from_dict(self): + data = {"tracking": None} + env = EnvironmentConfig.model_validate(data) + assert env.tracking is None + + def test_tracking_empty_dict_from_yaml(self): + """Simulates YAML `tracking: {}` — should enable with defaults.""" + data = {"tracking": {}} + env = EnvironmentConfig.model_validate(data) + assert env.tracking is not None + assert env.tracking.enabled is True + assert env.tracking.params == {} + + def test_tracking_coexists_with_other_fields(self): + env = EnvironmentConfig( + image="python:3.11", + env={"MY_VAR": "hello"}, + tracking=TrackingConfig(params={"model": "test"}), + ) + assert env.image == "python:3.11" + assert env.env == {"MY_VAR": "hello"} + assert env.tracking.params["model"] == "test" + + def test_serialization_roundtrip_with_tracking(self): + env = EnvironmentConfig(tracking=TrackingConfig(params={"lr": 0.01})) + json_str = env.model_dump_json() + restored = EnvironmentConfig.model_validate_json(json_str) + assert restored.tracking is not None + assert restored.tracking.params["lr"] == 0.01 + + def test_serialization_roundtrip_without_tracking(self): + env = EnvironmentConfig() + json_str = env.model_dump_json() + restored = EnvironmentConfig.model_validate_json(json_str) + assert restored.tracking is None From dddfc50b04910328f29ce03f8d8f4615d0023338 Mon Sep 17 00:00:00 2001 From: daifangwen Date: Thu, 23 Apr 2026 10:07:11 +0000 Subject: [PATCH 07/10] add api_key --- rock/sdk/envhub/config.py | 4 +++ tests/unit/sdk/job/test_config.py | 52 +++++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/rock/sdk/envhub/config.py b/rock/sdk/envhub/config.py index 5ed3a8cc10..ad80ff56cd 100644 --- a/rock/sdk/envhub/config.py +++ b/rock/sdk/envhub/config.py @@ -84,6 +84,10 @@ class TrackingConfig(BaseModel): default=True, description="Whether to enable experiment tracking for this job.", ) + api_key: str | None = Field( + default=None, + description="API key for the tracking platform. Falls back to ROCK_API_KEY env var if not set.", + ) params: dict[str, Any] = Field( default_factory=dict, description=( diff --git a/tests/unit/sdk/job/test_config.py b/tests/unit/sdk/job/test_config.py index e8b41e9081..b9896133fa 100644 --- a/tests/unit/sdk/job/test_config.py +++ b/tests/unit/sdk/job/test_config.py @@ -696,11 +696,13 @@ class TestTrackingConfig: def test_default_values(self): config = TrackingConfig() assert config.enabled is True + assert config.api_key is None assert config.params == {} def test_disabled(self): config = TrackingConfig(enabled=False) assert config.enabled is False + assert config.api_key is None assert config.params == {} def test_custom_params(self): @@ -709,22 +711,56 @@ def test_custom_params(self): assert config.params["epochs"] == 10 assert config.params["model"] == "qwen-72b" + def test_api_key(self): + config = TrackingConfig(api_key="sk-test-key-123") + assert config.api_key == "sk-test-key-123" + assert config.enabled is True + + def test_api_key_with_disabled(self): + config = TrackingConfig(enabled=False, api_key="sk-key") + assert config.enabled is False + assert config.api_key == "sk-key" + + def test_api_key_none_by_default(self): + config = TrackingConfig() + assert config.api_key is None + def test_from_dict(self): data = {"enabled": True, "params": {"batch_size": 32}} config = TrackingConfig.model_validate(data) assert config.enabled is True + assert config.api_key is None assert config.params["batch_size"] == 32 + def test_from_dict_with_api_key(self): + data = {"api_key": "sk-from-dict", "params": {"lr": 0.01}} + config = TrackingConfig.model_validate(data) + assert config.api_key == "sk-from-dict" + assert config.params["lr"] == 0.01 + def test_from_dict_minimal(self): config = TrackingConfig.model_validate({}) assert config.enabled is True + assert config.api_key is None assert config.params == {} def test_serialization_roundtrip(self): - config = TrackingConfig(enabled=True, params={"lr": 0.001, "algo": "GRPO"}) + config = TrackingConfig(enabled=True, api_key="sk-round", params={"lr": 0.001, "algo": "GRPO"}) json_str = config.model_dump_json() restored = TrackingConfig.model_validate_json(json_str) assert restored == config + assert restored.api_key == "sk-round" + + def test_exclude_none_omits_api_key_when_not_set(self): + config = TrackingConfig(params={"lr": 0.01}) + data = config.model_dump(mode="json", exclude_none=True) + assert "api_key" not in data + assert data["params"] == {"lr": 0.01} + + def test_exclude_none_includes_api_key_when_set(self): + config = TrackingConfig(api_key="sk-present") + data = config.model_dump(mode="json", exclude_none=True) + assert data["api_key"] == "sk-present" # --------------------------------------------------------------------------- @@ -753,12 +789,23 @@ def test_tracking_with_params(self): assert env.tracking.params["model"] == "qwen-72b" assert env.tracking.params["lr"] == 1e-5 + def test_tracking_with_api_key(self): + env = EnvironmentConfig(tracking=TrackingConfig(api_key="sk-env-key", params={"model": "test"})) + assert env.tracking.api_key == "sk-env-key" + assert env.tracking.params["model"] == "test" + def test_tracking_from_dict(self): data = {"tracking": {"enabled": True, "params": {"batch_size": 64}}} env = EnvironmentConfig.model_validate(data) assert env.tracking is not None assert env.tracking.params["batch_size"] == 64 + def test_tracking_from_dict_with_api_key(self): + data = {"tracking": {"api_key": "sk-yaml-key", "params": {"lr": 0.01}}} + env = EnvironmentConfig.model_validate(data) + assert env.tracking.api_key == "sk-yaml-key" + assert env.tracking.params["lr"] == 0.01 + def test_tracking_none_from_dict(self): data = {"tracking": None} env = EnvironmentConfig.model_validate(data) @@ -783,10 +830,11 @@ def test_tracking_coexists_with_other_fields(self): assert env.tracking.params["model"] == "test" def test_serialization_roundtrip_with_tracking(self): - env = EnvironmentConfig(tracking=TrackingConfig(params={"lr": 0.01})) + env = EnvironmentConfig(tracking=TrackingConfig(api_key="sk-rt", params={"lr": 0.01})) json_str = env.model_dump_json() restored = EnvironmentConfig.model_validate_json(json_str) assert restored.tracking is not None + assert restored.tracking.api_key == "sk-rt" assert restored.tracking.params["lr"] == 0.01 def test_serialization_roundtrip_without_tracking(self): From dbd635a58ecf0e0ed23e266019aae6b29a103580 Mon Sep 17 00:00:00 2001 From: daifangwen Date: Fri, 24 Apr 2026 03:36:30 +0000 Subject: [PATCH 08/10] add tracking to harbor environment config --- rock/sdk/bench/models/trial/config.py | 3 ++- tests/unit/sdk/job/test_config.py | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/rock/sdk/bench/models/trial/config.py b/rock/sdk/bench/models/trial/config.py index 05f26ce20d..d4a6982859 100644 --- a/rock/sdk/bench/models/trial/config.py +++ b/rock/sdk/bench/models/trial/config.py @@ -7,7 +7,7 @@ from rock.sdk.bench.models.environment_type import EnvironmentType from rock.sdk.envhub import EnvironmentConfig as _EnvConfig -from rock.sdk.envhub.config import OssMirrorConfig +from rock.sdk.envhub.config import OssMirrorConfig, TrackingConfig class AgentConfig(BaseModel): @@ -33,6 +33,7 @@ class EnvironmentConfig(BaseModel): suppress_override_warnings: bool = False mounts_json: list[dict[str, Any]] | None = None oss_mirror: OssMirrorConfig | None = None + tracking: TrackingConfig | None = None oss_deps: dict[str, str] = Field(default_factory=dict) env: dict[str, str] = Field(default_factory=dict) kwargs: dict[str, Any] = Field(default_factory=dict) diff --git a/tests/unit/sdk/job/test_config.py b/tests/unit/sdk/job/test_config.py index b9896133fa..631d576eef 100644 --- a/tests/unit/sdk/job/test_config.py +++ b/tests/unit/sdk/job/test_config.py @@ -293,12 +293,31 @@ def test_returns_valid_yaml_string(self): parsed = yaml.safe_load(yaml_str) assert isinstance(parsed, dict) + def test_tracking_config_preserved_in_harbor_yaml(self): + """tracking config on environment must survive to_harbor_yaml() serialization.""" + tracking = TrackingConfig(enabled=True, api_key="sk-test-123", params={"lr": 0.01}) + env = RockEnvironmentConfig(tracking=tracking) + cfg = HarborJobConfig(experiment_id="test-exp", environment=env) + yaml_str = cfg.to_harbor_yaml() + data = yaml.safe_load(yaml_str) + assert "environment" in data + assert "tracking" in data["environment"], "tracking must not be stripped by to_harbor_yaml()" + assert data["environment"]["tracking"]["enabled"] is True + assert data["environment"]["tracking"]["api_key"] == "sk-test-123" + assert data["environment"]["tracking"]["params"] == {"lr": 0.01} + + def test_tracking_config_none_omitted_in_harbor_yaml(self): + """When tracking is None (default), it should not appear in harbor YAML.""" + cfg = HarborJobConfig(experiment_id="test-exp") + yaml_str = cfg.to_harbor_yaml() + data = yaml.safe_load(yaml_str) + env_data = data.get("environment", {}) + assert "tracking" not in env_data + # --------------------------------------------------------------------------- # HarborJobConfig.from_yaml # --------------------------------------------------------------------------- - - class TestHarborJobConfigFromYaml: def test_round_trip(self, tmp_path): """Write a YAML config, read it back, verify fields.""" From 462b089d27f73d2e45537b4daacb745acbc73f9c Mon Sep 17 00:00:00 2001 From: jiaoliao Date: Mon, 1 Jun 2026 06:05:06 +0000 Subject: [PATCH 09/10] sdk support gpu --- rock/sdk/sandbox/client.py | 2 ++ rock/sdk/sandbox/config.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/rock/sdk/sandbox/client.py b/rock/sdk/sandbox/client.py index 3a57eacda5..90007daba0 100644 --- a/rock/sdk/sandbox/client.py +++ b/rock/sdk/sandbox/client.py @@ -171,6 +171,8 @@ async def start(self): "startup_timeout": self.config.startup_timeout, "memory": self.config.memory, "cpus": self.config.cpus, + "num_gpus": self.config.num_gpus, + "accelerator_type": self.config.accelerator_type, "registry_username": self.config.registry_username, "registry_password": self.config.registry_password, "use_kata_runtime": self.config.use_kata_runtime, diff --git a/rock/sdk/sandbox/config.py b/rock/sdk/sandbox/config.py index 4fcf59e030..6ca3fa1c01 100644 --- a/rock/sdk/sandbox/config.py +++ b/rock/sdk/sandbox/config.py @@ -36,6 +36,8 @@ class SandboxConfig(BaseConfig): memory: str = "8g" cpus: float = 2 limit_cpus: float | None = None + num_gpus: float | None = None + accelerator_type: str | None = None user_id: str | None = None experiment_id: str | None = None cluster: str = env_vars.ROCK_DEFAULT_CLUSTER From 008f43fcccec3e796e064ee86c7a6312c9cf8e6c Mon Sep 17 00:00:00 2001 From: jiaoliao Date: Mon, 1 Jun 2026 06:19:05 +0000 Subject: [PATCH 10/10] update version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 79b238416c..8024cc11fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta" authors = [{ name = "chatos@alibaba" }] requires-python = "<4.0,>=3.10" name = "rl-rock" -version = "1.8.1" +version = "1.8.3" description = "ROCK-Reinforcement Open Construction Kit" readme = "README.md" dependencies = [