From cd62b767113ee83a5db1d25313019b94dd121dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Mon, 25 May 2026 17:45:48 +0800 Subject: [PATCH 1/9] feat(datasets): add list_organizations to registry --- rock/sdk/envhub/datasets/registry/base.py | 5 ++++ rock/sdk/envhub/datasets/registry/oss.py | 6 ++++ tests/unit/datasets/test_oss_registry.py | 35 +++++++++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/rock/sdk/envhub/datasets/registry/base.py b/rock/sdk/envhub/datasets/registry/base.py index 5c9192d2d6..978690eeb4 100644 --- a/rock/sdk/envhub/datasets/registry/base.py +++ b/rock/sdk/envhub/datasets/registry/base.py @@ -16,6 +16,11 @@ def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test """List task ids for one dataset split. Returns None if dataset/split has no tasks.""" ... + @abstractmethod + def list_organizations(self) -> list[str]: + """List organization names under the dataset registry. Single backend call.""" + ... + @abstractmethod def upload_dataset( self, diff --git a/rock/sdk/envhub/datasets/registry/oss.py b/rock/sdk/envhub/datasets/registry/oss.py index 4e90613571..512ce84ca6 100644 --- a/rock/sdk/envhub/datasets/registry/oss.py +++ b/rock/sdk/envhub/datasets/registry/oss.py @@ -70,6 +70,12 @@ def _extract_tasks_from_split(self, bucket: oss2.Bucket, split_prefix: str) -> l all_tasks = sorted(set(dir_tasks + file_tasks)) return all_tasks + def list_organizations(self) -> list[str]: + bucket = self._build_bucket() + base = self._registry.oss_dataset_path or "datasets" + result = bucket.list_objects_v2(prefix=f"{base}/", delimiter="/", max_keys=1000) + return sorted(self._last_segment(p) for p in result.prefix_list) + def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" diff --git a/tests/unit/datasets/test_oss_registry.py b/tests/unit/datasets/test_oss_registry.py index 9237ac1d6f..4ba00b3324 100644 --- a/tests/unit/datasets/test_oss_registry.py +++ b/tests/unit/datasets/test_oss_registry.py @@ -281,3 +281,38 @@ def test_upload_dataset_oss_key_format(tmp_path): key = mock_bucket.put_object.call_args[0][0] assert key == "datasets/qwen/my-bench/train/task-001/task.toml" + + +# --------------------------------------------------------------------------- +# list_organizations tests +# --------------------------------------------------------------------------- + + +def test_list_organizations_returns_sorted_org_names(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ + "datasets/qwen/", + "datasets/alibaba/", + "datasets/AoneBenchDev/", + ]) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + orgs = registry.list_organizations() + + call_kwargs = mock_bucket.list_objects_v2.call_args[1] + assert call_kwargs["prefix"] == "datasets/" + assert call_kwargs["delimiter"] == "/" + assert call_kwargs["max_keys"] == 1000 + assert orgs == ["AoneBenchDev", "alibaba", "qwen"] + + +def test_list_organizations_returns_empty_when_no_orgs(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[]) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + orgs = registry.list_organizations() + + assert orgs == [] From fb9032f67e8dd6a6f0b018fd0aa12294dbda510e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Mon, 25 May 2026 17:51:40 +0800 Subject: [PATCH 2/9] feat(datasets): add list_org_datasets to registry --- rock/sdk/envhub/datasets/registry/base.py | 5 +++++ rock/sdk/envhub/datasets/registry/oss.py | 8 +++++++ tests/unit/datasets/test_oss_registry.py | 27 +++++++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/rock/sdk/envhub/datasets/registry/base.py b/rock/sdk/envhub/datasets/registry/base.py index 978690eeb4..1df3a6b5bb 100644 --- a/rock/sdk/envhub/datasets/registry/base.py +++ b/rock/sdk/envhub/datasets/registry/base.py @@ -21,6 +21,11 @@ def list_organizations(self) -> list[str]: """List organization names under the dataset registry. Single backend call.""" ... + @abstractmethod + def list_org_datasets(self, organization: str) -> list[str]: + """List dataset names under one organization. Single backend call.""" + ... + @abstractmethod def upload_dataset( self, diff --git a/rock/sdk/envhub/datasets/registry/oss.py b/rock/sdk/envhub/datasets/registry/oss.py index 512ce84ca6..7118e0be20 100644 --- a/rock/sdk/envhub/datasets/registry/oss.py +++ b/rock/sdk/envhub/datasets/registry/oss.py @@ -76,6 +76,14 @@ def list_organizations(self) -> list[str]: result = bucket.list_objects_v2(prefix=f"{base}/", delimiter="/", max_keys=1000) return sorted(self._last_segment(p) for p in result.prefix_list) + def list_org_datasets(self, organization: str) -> list[str]: + bucket = self._build_bucket() + base = self._registry.oss_dataset_path or "datasets" + result = bucket.list_objects_v2( + prefix=f"{base}/{organization}/", delimiter="/", max_keys=1000 + ) + return sorted(self._last_segment(p) for p in result.prefix_list) + def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" diff --git a/tests/unit/datasets/test_oss_registry.py b/tests/unit/datasets/test_oss_registry.py index 4ba00b3324..800305a78b 100644 --- a/tests/unit/datasets/test_oss_registry.py +++ b/tests/unit/datasets/test_oss_registry.py @@ -316,3 +316,30 @@ def test_list_organizations_returns_empty_when_no_orgs(): orgs = registry.list_organizations() assert orgs == [] + + +def test_list_org_datasets_returns_sorted_dataset_names(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ + "datasets/qwen/bench-2/", + "datasets/qwen/bench-1/", + ]) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + datasets = registry.list_org_datasets("qwen") + + call_kwargs = mock_bucket.list_objects_v2.call_args[1] + assert call_kwargs["prefix"] == "datasets/qwen/" + assert call_kwargs["delimiter"] == "/" + assert call_kwargs["max_keys"] == 1000 + assert datasets == ["bench-1", "bench-2"] + + +def test_list_org_datasets_returns_empty_when_org_missing(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[]) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + assert registry.list_org_datasets("nonexistent") == [] From f3bc24126712b076dcd85c461cd24b63ebd5956f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Mon, 25 May 2026 17:54:47 +0800 Subject: [PATCH 3/9] feat(datasets): add list_dataset_splits to registry --- rock/sdk/envhub/datasets/registry/base.py | 5 +++++ rock/sdk/envhub/datasets/registry/oss.py | 8 +++++++ tests/unit/datasets/test_oss_registry.py | 26 +++++++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/rock/sdk/envhub/datasets/registry/base.py b/rock/sdk/envhub/datasets/registry/base.py index 1df3a6b5bb..a7e5aa31c6 100644 --- a/rock/sdk/envhub/datasets/registry/base.py +++ b/rock/sdk/envhub/datasets/registry/base.py @@ -26,6 +26,11 @@ def list_org_datasets(self, organization: str) -> list[str]: """List dataset names under one organization. Single backend call.""" ... + @abstractmethod + def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: + """List split names under one dataset. Single backend call.""" + ... + @abstractmethod def upload_dataset( self, diff --git a/rock/sdk/envhub/datasets/registry/oss.py b/rock/sdk/envhub/datasets/registry/oss.py index 7118e0be20..e8ef9f1e7b 100644 --- a/rock/sdk/envhub/datasets/registry/oss.py +++ b/rock/sdk/envhub/datasets/registry/oss.py @@ -84,6 +84,14 @@ def list_org_datasets(self, organization: str) -> list[str]: ) return sorted(self._last_segment(p) for p in result.prefix_list) + def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: + bucket = self._build_bucket() + base = self._registry.oss_dataset_path or "datasets" + result = bucket.list_objects_v2( + prefix=f"{base}/{organization}/{dataset}/", delimiter="/", max_keys=1000 + ) + return sorted(self._last_segment(p) for p in result.prefix_list) + def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" diff --git a/tests/unit/datasets/test_oss_registry.py b/tests/unit/datasets/test_oss_registry.py index 800305a78b..6ea438db0a 100644 --- a/tests/unit/datasets/test_oss_registry.py +++ b/tests/unit/datasets/test_oss_registry.py @@ -343,3 +343,29 @@ def test_list_org_datasets_returns_empty_when_org_missing(): with patch.object(registry, "_build_bucket", return_value=mock_bucket): assert registry.list_org_datasets("nonexistent") == [] + + +def test_list_dataset_splits_returns_sorted_split_names(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ + "datasets/qwen/bench/train/", + "datasets/qwen/bench/test/", + ]) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + splits = registry.list_dataset_splits("qwen", "bench") + + call_kwargs = mock_bucket.list_objects_v2.call_args[1] + assert call_kwargs["prefix"] == "datasets/qwen/bench/" + assert call_kwargs["delimiter"] == "/" + assert splits == ["test", "train"] + + +def test_list_dataset_splits_returns_empty_when_dataset_missing(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[]) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + assert registry.list_dataset_splits("qwen", "nope") == [] From 8b438a8c46de5f27939c4324212e8acdced35165 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Mon, 25 May 2026 18:48:42 +0800 Subject: [PATCH 4/9] feat(datasets): add list_all_datasets with bounded concurrency --- rock/sdk/envhub/datasets/registry/base.py | 5 ++ rock/sdk/envhub/datasets/registry/oss.py | 15 ++++- tests/unit/datasets/test_oss_registry.py | 68 +++++++++++++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/rock/sdk/envhub/datasets/registry/base.py b/rock/sdk/envhub/datasets/registry/base.py index a7e5aa31c6..ee1e6ce3bc 100644 --- a/rock/sdk/envhub/datasets/registry/base.py +++ b/rock/sdk/envhub/datasets/registry/base.py @@ -31,6 +31,11 @@ def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: """List split names under one dataset. Single backend call.""" ... + @abstractmethod + def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]: + """List all (org, dataset) pairs. 1 + N_org backend calls with bounded concurrency.""" + ... + @abstractmethod def upload_dataset( self, diff --git a/rock/sdk/envhub/datasets/registry/oss.py b/rock/sdk/envhub/datasets/registry/oss.py index e8ef9f1e7b..bd77ceee29 100644 --- a/rock/sdk/envhub/datasets/registry/oss.py +++ b/rock/sdk/envhub/datasets/registry/oss.py @@ -1,6 +1,6 @@ from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import oss2 @@ -92,6 +92,19 @@ def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: ) return sorted(self._last_segment(p) for p in result.prefix_list) + def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]: + orgs = self.list_organizations() + if not orgs: + return [] + pairs: list[tuple[str, str]] = [] + with ThreadPoolExecutor(max_workers=concurrency) as ex: + future_to_org = {ex.submit(self.list_org_datasets, o): o for o in orgs} + for fut in as_completed(future_to_org): + org = future_to_org[fut] + for ds in fut.result(): + pairs.append((org, ds)) + return sorted(pairs) + def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" diff --git a/tests/unit/datasets/test_oss_registry.py b/tests/unit/datasets/test_oss_registry.py index 6ea438db0a..8a3cc27b17 100644 --- a/tests/unit/datasets/test_oss_registry.py +++ b/tests/unit/datasets/test_oss_registry.py @@ -369,3 +369,71 @@ def test_list_dataset_splits_returns_empty_when_dataset_missing(): with patch.object(registry, "_build_bucket", return_value=mock_bucket): assert registry.list_dataset_splits("qwen", "nope") == [] + + +def test_list_all_datasets_returns_sorted_pairs(): + registry = OssDatasetRegistry(make_registry_info()) + + def fake_list_org_datasets(org): + return {"qwen": ["bench-2", "bench-1"], "alibaba": ["pinch"]}[org] + + with patch.object(registry, "list_organizations", return_value=["qwen", "alibaba"]): + with patch.object(registry, "list_org_datasets", side_effect=fake_list_org_datasets): + pairs = registry.list_all_datasets() + + assert pairs == [("alibaba", "pinch"), ("qwen", "bench-1"), ("qwen", "bench-2")] + + +def test_list_all_datasets_uses_bounded_concurrency(): + registry = OssDatasetRegistry(make_registry_info()) + + with patch.object(registry, "list_organizations", return_value=["o1", "o2"]): + with patch.object(registry, "list_org_datasets", return_value=["d"]): + with patch("rock.sdk.envhub.datasets.registry.oss.ThreadPoolExecutor") as mock_pool: + with patch("rock.sdk.envhub.datasets.registry.oss.as_completed", side_effect=lambda d: list(d)): + mock_executor = MagicMock() + mock_pool.return_value.__enter__.return_value = mock_executor + future = MagicMock() + future.result.return_value = ["d"] + mock_executor.submit.return_value = future + registry.list_all_datasets(concurrency=7) + + mock_pool.assert_called_once_with(max_workers=7) + + +def test_list_all_datasets_default_concurrency_is_10(): + registry = OssDatasetRegistry(make_registry_info()) + + with patch.object(registry, "list_organizations", return_value=["o1"]): + with patch.object(registry, "list_org_datasets", return_value=["d"]): + with patch("rock.sdk.envhub.datasets.registry.oss.ThreadPoolExecutor") as mock_pool: + with patch("rock.sdk.envhub.datasets.registry.oss.as_completed", side_effect=lambda d: list(d)): + mock_executor = MagicMock() + mock_pool.return_value.__enter__.return_value = mock_executor + future = MagicMock() + future.result.return_value = ["d"] + mock_executor.submit.return_value = future + registry.list_all_datasets() + + mock_pool.assert_called_once_with(max_workers=10) + + +def test_list_all_datasets_propagates_exception_from_worker(): + import pytest as _pytest + registry = OssDatasetRegistry(make_registry_info()) + + def fake_list_org_datasets(org): + if org == "bad": + raise RuntimeError("oss boom") + return ["d"] + + with patch.object(registry, "list_organizations", return_value=["good", "bad"]): + with patch.object(registry, "list_org_datasets", side_effect=fake_list_org_datasets): + with _pytest.raises(RuntimeError, match="oss boom"): + registry.list_all_datasets() + + +def test_list_all_datasets_empty_when_no_orgs(): + registry = OssDatasetRegistry(make_registry_info()) + with patch.object(registry, "list_organizations", return_value=[]): + assert registry.list_all_datasets() == [] From 2b714127834c9f2c7202cae8eabee1297380e6a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Mon, 25 May 2026 18:51:53 +0800 Subject: [PATCH 5/9] feat(datasets): add fast-path methods to DatasetClient --- rock/sdk/envhub/datasets/client.py | 12 +++++++++ tests/unit/datasets/test_client.py | 39 ++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/rock/sdk/envhub/datasets/client.py b/rock/sdk/envhub/datasets/client.py index 2ea4d70af0..9939e6623b 100644 --- a/rock/sdk/envhub/datasets/client.py +++ b/rock/sdk/envhub/datasets/client.py @@ -14,6 +14,18 @@ def list_datasets(self, org: str | None = None) -> list[DatasetSpec]: def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test") -> DatasetSpec | None: return self._registry.list_dataset_tasks(organization, dataset, split) + def list_organizations(self) -> list[str]: + return self._registry.list_organizations() + + def list_org_datasets(self, organization: str) -> list[str]: + return self._registry.list_org_datasets(organization) + + def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]: + return self._registry.list_all_datasets(concurrency) + + def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: + return self._registry.list_dataset_splits(organization, dataset) + def upload_dataset( self, source: LocalDatasetConfig, diff --git a/tests/unit/datasets/test_client.py b/tests/unit/datasets/test_client.py index baf8a8f753..f811763120 100644 --- a/tests/unit/datasets/test_client.py +++ b/tests/unit/datasets/test_client.py @@ -42,3 +42,42 @@ def test_dataset_client_list_tasks_delegates_to_registry_with_default_split(): mock_list_tasks.assert_called_once_with("qwen", "bench", "test") assert result == expected + + +def test_dataset_client_list_organizations_delegates(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_organizations", return_value=["a", "b"]) as m: + result = client.list_organizations() + m.assert_called_once_with() + assert result == ["a", "b"] + + +def test_dataset_client_list_org_datasets_delegates(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_org_datasets", return_value=["d1"]) as m: + result = client.list_org_datasets("qwen") + m.assert_called_once_with("qwen") + assert result == ["d1"] + + +def test_dataset_client_list_all_datasets_delegates_with_default_concurrency(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_all_datasets", return_value=[("a", "x")]) as m: + result = client.list_all_datasets() + m.assert_called_once_with(10) + assert result == [("a", "x")] + + +def test_dataset_client_list_all_datasets_passes_custom_concurrency(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_all_datasets", return_value=[]) as m: + client.list_all_datasets(concurrency=5) + m.assert_called_once_with(5) + + +def test_dataset_client_list_dataset_splits_delegates(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_dataset_splits", return_value=["test", "train"]) as m: + result = client.list_dataset_splits("qwen", "bench") + m.assert_called_once_with("qwen", "bench") + assert result == ["test", "train"] From 92f72219d6e668faa7d83234ab3c7a013dabaa64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Mon, 25 May 2026 18:55:53 +0800 Subject: [PATCH 6/9] feat(datasets): rewrite list with --depth and fast paths AI-Model: claude-opus-4-7 AI-Contributed/Feature: 48/48 AI-Contributed/UT: 107/107 --- rock/cli/command/datasets.py | 48 +++++++-- tests/unit/datasets/test_datasets_command.py | 107 +++++++++++++++++++ 2 files changed, 146 insertions(+), 9 deletions(-) diff --git a/rock/cli/command/datasets.py b/rock/cli/command/datasets.py index b53636cb79..8264cda894 100644 --- a/rock/cli/command/datasets.py +++ b/rock/cli/command/datasets.py @@ -59,20 +59,47 @@ def _build_oss_registry_info(self, args: argparse.Namespace) -> OssRegistryInfo: async def _list(self, args: argparse.Namespace) -> None: registry_info = self._build_oss_registry_info(args) client = DatasetClient(registry_info) - datasets = client.list_datasets(org=getattr(args, "org", None)) - if not datasets: - print("No datasets found.") + if getattr(args, "org", None): + datasets = client.list_org_datasets(args.org) + pairs = [(args.org, d) for d in datasets] + self._render_org_dataset_pairs(pairs) + return + + if getattr(args, "depth", 2) == 1: + orgs = client.list_organizations() + self._render_orgs(orgs) return - col_id = max(len("Dataset"), max(len(d.id) for d in datasets)) - col_split = max(len("Split"), max(len(d.split) for d in datasets)) + pairs = client.list_all_datasets() + self._render_org_dataset_pairs(pairs) - header = f"{'Dataset':<{col_id}} {'Split':<{col_split}} {'Tasks':>6}" + @staticmethod + def _render_org_dataset_pairs(pairs: list[tuple[str, str]]) -> None: + if not pairs: + print("No datasets found.") + return + col_org = max(len("Organization"), max(len(o) for o, _ in pairs)) + col_ds = max(len("Dataset"), max(len(d) for _, d in pairs)) + header = f"{'Organization':<{col_org}} {'Dataset':<{col_ds}}" print(header) print("-" * len(header)) - for ds in sorted(datasets, key=lambda d: (d.id, d.split)): - print(f"{ds.id:<{col_id}} {ds.split:<{col_split}} {len(ds.task_ids):>6}") + for o, d in pairs: + print(f"{o:<{col_org}} {d:<{col_ds}}") + n_orgs = len({o for o, _ in pairs}) + print(f"\n{len(pairs)} datasets in {n_orgs} organizations.") + + @staticmethod + def _render_orgs(orgs: list[str]) -> None: + if not orgs: + print("No organizations found.") + return + width = max(len("Organization"), max(len(o) for o in orgs)) + print(f"{'Organization':<{width}}") + print("-" * width) + for o in orgs: + print(o) + print(f"\n{len(orgs)} organizations.") async def _tasks(self, args: argparse.Namespace) -> None: registry_info = self._build_oss_registry_info(args) @@ -142,7 +169,10 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--region", help="OSS region (overrides config.ini)") list_parser = datasets_subparsers.add_parser("list", help="List datasets in OSS registry") - list_parser.add_argument("--org", help="Filter by organization") + list_group = list_parser.add_mutually_exclusive_group() + list_group.add_argument("--depth", type=int, choices=[1, 2], default=2, + help="1: list orgs only. 2 (default): list orgs and datasets.") + list_group.add_argument("--org", help="List datasets under the given organization only") add_oss_args(list_parser) tasks_parser = datasets_subparsers.add_parser("tasks", help="List task IDs under one dataset split") tasks_parser.add_argument("--org", required=True, help="Organization name") diff --git a/tests/unit/datasets/test_datasets_command.py b/tests/unit/datasets/test_datasets_command.py index 7c0b0d068e..d25d3bddb7 100644 --- a/tests/unit/datasets/test_datasets_command.py +++ b/tests/unit/datasets/test_datasets_command.py @@ -21,6 +21,7 @@ def make_base_args(**kwargs): org=None, dataset=None, split=None, + depth=2, offset=0, limit=None, ) @@ -204,3 +205,109 @@ def test_tasks_prints_no_tasks_message_when_not_found(capsys): out = capsys.readouterr().out assert "No tasks found" in out + + +# --------------------------------------------------------------------------- +# list subcommand tests (depth + --org rewrite) +# --------------------------------------------------------------------------- + + +def test_list_default_calls_list_all_datasets_and_renders_two_columns(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=2, org=None) + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_all_datasets.return_value = [ + ("alibaba", "pinch"), + ("qwen", "bench-1"), + ] + asyncio.run(cmd._list(args)) + + MockClient.return_value.list_all_datasets.assert_called_once_with() + out = capsys.readouterr().out + assert "Organization" in out + assert "Dataset" in out + assert "alibaba" in out and "pinch" in out + assert "qwen" in out and "bench-1" in out + assert "2 datasets in 2 organizations." in out + + +def test_list_depth_1_calls_list_organizations_and_renders_one_column(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=1, org=None) + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_organizations.return_value = ["alibaba", "qwen"] + asyncio.run(cmd._list(args)) + + MockClient.return_value.list_organizations.assert_called_once_with() + MockClient.return_value.list_all_datasets.assert_not_called() + out = capsys.readouterr().out + assert "Organization" in out + assert "alibaba" in out + assert "qwen" in out + assert "2 organizations." in out + + +def test_list_with_org_calls_list_org_datasets(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=2, org="alibaba") + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_org_datasets.return_value = ["pinch", "webdev"] + asyncio.run(cmd._list(args)) + + MockClient.return_value.list_org_datasets.assert_called_once_with("alibaba") + MockClient.return_value.list_all_datasets.assert_not_called() + MockClient.return_value.list_organizations.assert_not_called() + out = capsys.readouterr().out + assert "alibaba" in out and "pinch" in out and "webdev" in out + assert "2 datasets in 1 organizations." in out + + +def test_list_empty_prints_no_datasets_message(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=2, org=None) + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_all_datasets.return_value = [] + asyncio.run(cmd._list(args)) + + out = capsys.readouterr().out + assert "No datasets found." in out + + +def test_list_depth_1_empty_prints_no_organizations_message(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=1, org=None) + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_organizations.return_value = [] + asyncio.run(cmd._list(args)) + + out = capsys.readouterr().out + assert "No organizations found." in out + + +def test_list_parser_depth_and_org_mutually_exclusive(): + parser = _build_parser() + with pytest.raises(SystemExit): + parser.parse_args(["datasets", "list", "--depth", "2", "--org", "alibaba"]) + + +def test_list_parser_depth_default_is_2(): + parser = _build_parser() + parsed = parser.parse_args(["datasets", "list"]) + assert parsed.depth == 2 + assert parsed.org is None + + +def test_list_parser_rejects_invalid_depth(): + parser = _build_parser() + with pytest.raises(SystemExit): + parser.parse_args(["datasets", "list", "--depth", "3"]) From 49fa4628d909920bbdbd6200e1e46060748e5e95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Mon, 25 May 2026 18:57:23 +0800 Subject: [PATCH 7/9] feat(datasets): add splits subcommand AI-Model: claude-opus-4-7 AI-Contributed/Feature: 24/24 AI-Contributed/UT: 63/63 --- rock/cli/command/datasets.py | 24 ++++++++ tests/unit/datasets/test_datasets_command.py | 63 ++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/rock/cli/command/datasets.py b/rock/cli/command/datasets.py index 8264cda894..e0949b024f 100644 --- a/rock/cli/command/datasets.py +++ b/rock/cli/command/datasets.py @@ -35,6 +35,8 @@ async def arun(self, args: argparse.Namespace) -> None: await self._list(args) elif args.datasets_command == "tasks": await self._tasks(args) + elif args.datasets_command == "splits": + await self._splits(args) elif args.datasets_command == "upload": await self._upload(args) else: @@ -130,6 +132,23 @@ async def _tasks(self, args: argparse.Namespace) -> None: for task_id in shown_task_ids: print(task_id) + async def _splits(self, args: argparse.Namespace) -> None: + registry_info = self._build_oss_registry_info(args) + client = DatasetClient(registry_info) + splits = client.list_dataset_splits(args.org, args.dataset) + + if not splits: + print(f"No splits found for dataset '{args.org}/{args.dataset}'.") + return + + width = max(len("Split"), max(len(s) for s in splits)) + print(f"{'Split':<{width}}") + print("-" * width) + for s in splits: + print(s) + word = "split" if len(splits) == 1 else "splits" + print(f"\n{len(splits)} {word}.") + async def _upload(self, args: argparse.Namespace) -> None: local_dir = Path(args.dir) if not local_dir.is_dir(): @@ -182,6 +201,11 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None: tasks_parser.add_argument("--limit", type=_positive_int, default=None, help="Maximum number of tasks to show") add_oss_args(tasks_parser) + splits_parser = datasets_subparsers.add_parser("splits", help="List splits under one dataset") + splits_parser.add_argument("--org", required=True, help="Organization name") + splits_parser.add_argument("--dataset", required=True, help="Dataset name") + add_oss_args(splits_parser) + upload_parser = datasets_subparsers.add_parser("upload", help="Upload local task dirs to OSS") upload_parser.add_argument("--org", required=True, help="Organization name") upload_parser.add_argument("--dataset", required=True, help="Dataset name") diff --git a/tests/unit/datasets/test_datasets_command.py b/tests/unit/datasets/test_datasets_command.py index d25d3bddb7..6c47f41e82 100644 --- a/tests/unit/datasets/test_datasets_command.py +++ b/tests/unit/datasets/test_datasets_command.py @@ -311,3 +311,66 @@ def test_list_parser_rejects_invalid_depth(): parser = _build_parser() with pytest.raises(SystemExit): parser.parse_args(["datasets", "list", "--depth", "3"]) + + +# --------------------------------------------------------------------------- +# splits subcommand tests +# --------------------------------------------------------------------------- + + +def test_splits_lists_split_names(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="splits", org="alibaba", dataset="pinch") + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_dataset_splits.return_value = ["test", "train"] + asyncio.run(cmd._splits(args)) + + MockClient.return_value.list_dataset_splits.assert_called_once_with("alibaba", "pinch") + out = capsys.readouterr().out + assert "Split" in out + assert "test" in out + assert "train" in out + assert "2 splits." in out + + +def test_splits_empty_prints_no_splits_message(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="splits", org="alibaba", dataset="missing") + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_dataset_splits.return_value = [] + asyncio.run(cmd._splits(args)) + + out = capsys.readouterr().out + assert "No splits found for dataset 'alibaba/missing'." in out + + +def test_splits_singular_footer_for_one_split(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="splits", org="alibaba", dataset="pinch") + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_dataset_splits.return_value = ["test"] + asyncio.run(cmd._splits(args)) + + out = capsys.readouterr().out + assert "1 split." in out + + +def test_splits_parser_requires_org_and_dataset(): + parser = _build_parser() + with pytest.raises(SystemExit): + parser.parse_args(["datasets", "splits"]) + with pytest.raises(SystemExit): + parser.parse_args(["datasets", "splits", "--org", "alibaba"]) + + +def test_splits_parser_accepts_org_and_dataset(): + parser = _build_parser() + parsed = parser.parse_args(["datasets", "splits", "--org", "alibaba", "--dataset", "pinch"]) + assert parsed.org == "alibaba" + assert parsed.dataset == "pinch" From a18d497be2c287f68edaf841ad0e898401099949 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Mon, 25 May 2026 18:58:48 +0800 Subject: [PATCH 8/9] chore(datasets): apply lint and format AI-Model: claude-opus-4-7 AI-Contributed/Feature: 36/53 AI-Contributed/UT: 44/65 --- rock/cli/command/datasets.py | 40 +++++++----- rock/sdk/envhub/datasets/client.py | 1 - rock/sdk/envhub/datasets/registry/base.py | 1 - rock/sdk/envhub/datasets/registry/oss.py | 28 ++++----- tests/unit/datasets/test_datasets_command.py | 5 +- tests/unit/datasets/test_oss_registry.py | 64 ++++++++++++-------- 6 files changed, 80 insertions(+), 59 deletions(-) diff --git a/rock/cli/command/datasets.py b/rock/cli/command/datasets.py index e0949b024f..7e36f1d377 100644 --- a/rock/cli/command/datasets.py +++ b/rock/cli/command/datasets.py @@ -121,8 +121,6 @@ async def _tasks(self, args: argparse.Namespace) -> None: print("No tasks found after applying offset/limit.") return - limit_text = str(args.limit) if args.limit is not None else "all" - print() print("=" * 80) print(f"Dataset: {spec.id} Split: {spec.split} Total: {total} Shown: {len(shown_task_ids)}") @@ -181,16 +179,23 @@ async def add_parser_to(subparsers: argparse._SubParsersAction) -> None: def add_oss_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--bucket", help="OSS bucket name (overrides config.ini)") parser.add_argument("--endpoint", help="OSS endpoint URL (overrides config.ini)") - parser.add_argument("--access-key-id", dest="access_key_id", - help="OSS access key ID (overrides config.ini)") - parser.add_argument("--access-key-secret", dest="access_key_secret", - help="OSS access key secret (overrides config.ini)") + parser.add_argument( + "--access-key-id", dest="access_key_id", help="OSS access key ID (overrides config.ini)" + ) + parser.add_argument( + "--access-key-secret", dest="access_key_secret", help="OSS access key secret (overrides config.ini)" + ) parser.add_argument("--region", help="OSS region (overrides config.ini)") list_parser = datasets_subparsers.add_parser("list", help="List datasets in OSS registry") list_group = list_parser.add_mutually_exclusive_group() - list_group.add_argument("--depth", type=int, choices=[1, 2], default=2, - help="1: list orgs only. 2 (default): list orgs and datasets.") + list_group.add_argument( + "--depth", + type=int, + choices=[1, 2], + default=2, + help="1: list orgs only. 2 (default): list orgs and datasets.", + ) list_group.add_argument("--org", help="List datasets under the given organization only") add_oss_args(list_parser) tasks_parser = datasets_subparsers.add_parser("tasks", help="List task IDs under one dataset split") @@ -210,11 +215,16 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None: upload_parser.add_argument("--org", required=True, help="Organization name") upload_parser.add_argument("--dataset", required=True, help="Dataset name") upload_parser.add_argument("--split", required=True, help="Split name (e.g. train, test, v1.0)") - upload_parser.add_argument("--dir", required=True, - help="Local directory containing {task_id}/ subdirectories") - upload_parser.add_argument("--concurrency", type=int, default=4, - choices=range(1, 17), metavar="[1-16]", - help="Upload concurrency (default: 4)") - upload_parser.add_argument("--overwrite", action="store_true", - help="Overwrite existing tasks in OSS (default: skip)") + upload_parser.add_argument("--dir", required=True, help="Local directory containing {task_id}/ subdirectories") + upload_parser.add_argument( + "--concurrency", + type=int, + default=4, + choices=range(1, 17), + metavar="[1-16]", + help="Upload concurrency (default: 4)", + ) + upload_parser.add_argument( + "--overwrite", action="store_true", help="Overwrite existing tasks in OSS (default: skip)" + ) add_oss_args(upload_parser) diff --git a/rock/sdk/envhub/datasets/client.py b/rock/sdk/envhub/datasets/client.py index 9939e6623b..16189defec 100644 --- a/rock/sdk/envhub/datasets/client.py +++ b/rock/sdk/envhub/datasets/client.py @@ -4,7 +4,6 @@ class DatasetClient: - def __init__(self, registry: OssRegistryInfo) -> None: self._registry = OssDatasetRegistry(registry) diff --git a/rock/sdk/envhub/datasets/registry/base.py b/rock/sdk/envhub/datasets/registry/base.py index ee1e6ce3bc..0532671ffe 100644 --- a/rock/sdk/envhub/datasets/registry/base.py +++ b/rock/sdk/envhub/datasets/registry/base.py @@ -5,7 +5,6 @@ class BaseDatasetRegistry(ABC): - @abstractmethod def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: """List all datasets. Filtered to `organization` if provided.""" diff --git a/rock/sdk/envhub/datasets/registry/oss.py b/rock/sdk/envhub/datasets/registry/oss.py index bd77ceee29..af36481ce9 100644 --- a/rock/sdk/envhub/datasets/registry/oss.py +++ b/rock/sdk/envhub/datasets/registry/oss.py @@ -14,7 +14,6 @@ class OssDatasetRegistry(BaseDatasetRegistry): - def __init__(self, registry: OssRegistryInfo) -> None: self._registry = registry @@ -58,7 +57,7 @@ def _extract_tasks_from_split(self, bucket: oss2.Bucket, split_prefix: str) -> l if key.endswith("/"): continue # Get the relative path from split_prefix - relative = key[len(split_prefix):] + relative = key[len(split_prefix) :] # Only direct files (no nested paths with "/") if "/" in relative: continue @@ -79,17 +78,13 @@ def list_organizations(self) -> list[str]: def list_org_datasets(self, organization: str) -> list[str]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" - result = bucket.list_objects_v2( - prefix=f"{base}/{organization}/", delimiter="/", max_keys=1000 - ) + result = bucket.list_objects_v2(prefix=f"{base}/{organization}/", delimiter="/", max_keys=1000) return sorted(self._last_segment(p) for p in result.prefix_list) def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" - result = bucket.list_objects_v2( - prefix=f"{base}/{organization}/{dataset}/", delimiter="/", max_keys=1000 - ) + result = bucket.list_objects_v2(prefix=f"{base}/{organization}/{dataset}/", delimiter="/", max_keys=1000) return sorted(self._last_segment(p) for p in result.prefix_list) def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]: @@ -128,11 +123,13 @@ def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: split = self._last_segment(split_prefix) task_ids = self._extract_tasks_from_split(bucket, split_prefix) - datasets.append(DatasetSpec( - id=f"{org}/{name}", - split=split, - task_ids=task_ids, - )) + datasets.append( + DatasetSpec( + id=f"{org}/{name}", + split=split, + task_ids=task_ids, + ) + ) return datasets @@ -192,10 +189,7 @@ def upload_dataset( raw: dict[str, int | None | Exception] = {} with ThreadPoolExecutor(max_workers=concurrency) as executor: - futures = { - executor.submit(self._upload_task, bucket, org, name, split, d, overwrite): d - for d in task_dirs - } + futures = {executor.submit(self._upload_task, bucket, org, name, split, d, overwrite): d for d in task_dirs} for future, task_dir in futures.items(): try: raw[task_dir.name] = future.result() diff --git a/tests/unit/datasets/test_datasets_command.py b/tests/unit/datasets/test_datasets_command.py index 6c47f41e82..4546a47985 100644 --- a/tests/unit/datasets/test_datasets_command.py +++ b/tests/unit/datasets/test_datasets_command.py @@ -29,6 +29,7 @@ def make_base_args(**kwargs): setattr(args, k, v) return args + def make_registry_info(): return OssRegistryInfo(oss_bucket="b", oss_access_key_id="k", oss_access_key_secret="s") @@ -46,7 +47,9 @@ def test_command_name(): def test_build_oss_registry_info_from_cli_args(): cmd = DatasetsCommand() - args = make_base_args(bucket="cli-bucket", endpoint="https://oss.example.com", access_key_id="kid", access_key_secret="ksec") + args = make_base_args( + bucket="cli-bucket", endpoint="https://oss.example.com", access_key_id="kid", access_key_secret="ksec" + ) with patch("rock.cli.command.datasets.ConfigManager") as mock_mgr: ds_cfg = mock_mgr.return_value.get_config.return_value.dataset_config diff --git a/tests/unit/datasets/test_oss_registry.py b/tests/unit/datasets/test_oss_registry.py index 8a3cc27b17..fc3dd9be70 100644 --- a/tests/unit/datasets/test_oss_registry.py +++ b/tests/unit/datasets/test_oss_registry.py @@ -27,10 +27,12 @@ def test_list_datasets_returns_all(): make_list_result(prefixes=["datasets/qwen/"]), make_list_result(prefixes=["datasets/qwen/my-bench/"]), make_list_result(prefixes=["datasets/qwen/my-bench/train/"]), - make_list_result(prefixes=[ - "datasets/qwen/my-bench/train/task-001/", - "datasets/qwen/my-bench/train/task-002/", - ]), + make_list_result( + prefixes=[ + "datasets/qwen/my-bench/train/task-001/", + "datasets/qwen/my-bench/train/task-002/", + ] + ), ] with patch.object(registry, "_build_bucket", return_value=mock_bucket): @@ -58,6 +60,7 @@ def test_list_datasets_filter_by_org(): assert first_call_kwargs["prefix"] == "datasets/qwen/" assert len(datasets) == 1 + def test_list_datasets_counts_directory_and_file_tasks(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() @@ -104,6 +107,7 @@ def test_build_prefix_with_split(): registry = OssDatasetRegistry(make_registry_info()) assert registry._build_prefix("qwen", "my-bench", "train") == "datasets/qwen/my-bench/train" + # --------------------------------------------------------------------------- # list_dataset_tasks tests # --------------------------------------------------------------------------- @@ -112,10 +116,12 @@ def test_build_prefix_with_split(): def test_list_dataset_tasks_uses_default_test_split_and_sorts_task_ids(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() - mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ - "datasets/qwen/my-bench/test/task-002/", - "datasets/qwen/my-bench/test/task-001/", - ]) + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/my-bench/test/task-002/", + "datasets/qwen/my-bench/test/task-001/", + ] + ) with patch.object(registry, "_build_bucket", return_value=mock_bucket): spec = registry.list_dataset_tasks("qwen", "my-bench") @@ -132,9 +138,11 @@ def test_list_dataset_tasks_uses_default_test_split_and_sorts_task_ids(): def test_list_dataset_tasks_supports_custom_split(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() - mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ - "datasets/qwen/my-bench/train/task-001/", - ]) + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/my-bench/train/task-001/", + ] + ) with patch.object(registry, "_build_bucket", return_value=mock_bucket): spec = registry.list_dataset_tasks("qwen", "my-bench", "train") @@ -146,6 +154,7 @@ def test_list_dataset_tasks_supports_custom_split(): first_call_kwargs = mock_bucket.list_objects_v2.call_args_list[0][1] assert first_call_kwargs["prefix"] == "datasets/qwen/my-bench/train/" + def test_list_dataset_tasks_includes_directory_and_file_tasks_with_suffix_stripped(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() @@ -291,11 +300,13 @@ def test_upload_dataset_oss_key_format(tmp_path): def test_list_organizations_returns_sorted_org_names(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() - mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ - "datasets/qwen/", - "datasets/alibaba/", - "datasets/AoneBenchDev/", - ]) + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/", + "datasets/alibaba/", + "datasets/AoneBenchDev/", + ] + ) with patch.object(registry, "_build_bucket", return_value=mock_bucket): orgs = registry.list_organizations() @@ -321,10 +332,12 @@ def test_list_organizations_returns_empty_when_no_orgs(): def test_list_org_datasets_returns_sorted_dataset_names(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() - mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ - "datasets/qwen/bench-2/", - "datasets/qwen/bench-1/", - ]) + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/bench-2/", + "datasets/qwen/bench-1/", + ] + ) with patch.object(registry, "_build_bucket", return_value=mock_bucket): datasets = registry.list_org_datasets("qwen") @@ -348,10 +361,12 @@ def test_list_org_datasets_returns_empty_when_org_missing(): def test_list_dataset_splits_returns_sorted_split_names(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() - mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ - "datasets/qwen/bench/train/", - "datasets/qwen/bench/test/", - ]) + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/bench/train/", + "datasets/qwen/bench/test/", + ] + ) with patch.object(registry, "_build_bucket", return_value=mock_bucket): splits = registry.list_dataset_splits("qwen", "bench") @@ -420,6 +435,7 @@ def test_list_all_datasets_default_concurrency_is_10(): def test_list_all_datasets_propagates_exception_from_worker(): import pytest as _pytest + registry = OssDatasetRegistry(make_registry_info()) def fake_list_org_datasets(org): From b6d29486789ff38adc9c47638c296183cef38d0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Wed, 27 May 2026 10:38:24 +0800 Subject: [PATCH 9/9] fix(datasets): preserve list parser exclusivity Defer the default list depth until runtime so argparse still treats an explicit --depth value as present in the mutually exclusive group on Python 3.10. Update envhub CLI docs for the rewritten list output and splits command. Co-Authored-By: Codex AI-Model: gpt-5 AI-Contributed/Feature: 47/47 AI-Contributed/UT: 8/8 --- docs/dev/envhub/README.md | 42 +++++++++++++++++--- rock/cli/command/datasets.py | 5 ++- tests/unit/datasets/test_datasets_command.py | 8 ++-- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/docs/dev/envhub/README.md b/docs/dev/envhub/README.md index 6255efe1f6..7f5b4080e4 100644 --- a/docs/dev/envhub/README.md +++ b/docs/dev/envhub/README.md @@ -248,20 +248,52 @@ class DatasetClient: rock datasets list [OPTIONS] Options: - --org TEXT 只列出指定 organization 的 datasets + --depth {1,2} 1: 只列出 organizations;2: 列出 organizations 和 datasets(默认) + --org TEXT 只列出指定 organization 的 datasets(与 --depth 互斥) --bucket TEXT OSS bucket 名称(覆盖 config.ini) --endpoint TEXT OSS endpoint(覆盖 config.ini) --access-key-id TEXT OSS access key ID(覆盖 config.ini) --access-key-secret TEXT OSS access key secret(覆盖 config.ini) + --region TEXT OSS region(覆盖 config.ini) ``` 输出示例: ``` -Dataset Split Tasks -qwen/my-bench train 42 -qwen/my-bench test 10 -alibaba/code-eval train 100 +Organization Dataset +-------------------------- +alibaba code-eval +qwen my-bench + +2 datasets in 2 organizations. +``` + +#### rock datasets splits + +``` +rock datasets splits [OPTIONS] + +Required: + --org TEXT Organization 名称 + --dataset TEXT Dataset 名称 + +Options: + --bucket TEXT OSS bucket 名称(覆盖 config.ini) + --endpoint TEXT OSS endpoint(覆盖 config.ini) + --access-key-id TEXT OSS access key ID(覆盖 config.ini) + --access-key-secret TEXT OSS access key secret(覆盖 config.ini) + --region TEXT OSS region(覆盖 config.ini) +``` + +输出示例: + +``` +Split +----- +test +train + +2 splits. ``` #### rock datasets upload diff --git a/rock/cli/command/datasets.py b/rock/cli/command/datasets.py index 7e36f1d377..d35b8b87ea 100644 --- a/rock/cli/command/datasets.py +++ b/rock/cli/command/datasets.py @@ -68,7 +68,8 @@ async def _list(self, args: argparse.Namespace) -> None: self._render_org_dataset_pairs(pairs) return - if getattr(args, "depth", 2) == 1: + depth = getattr(args, "depth", None) or 2 + if depth == 1: orgs = client.list_organizations() self._render_orgs(orgs) return @@ -193,7 +194,7 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None: "--depth", type=int, choices=[1, 2], - default=2, + default=None, help="1: list orgs only. 2 (default): list orgs and datasets.", ) list_group.add_argument("--org", help="List datasets under the given organization only") diff --git a/tests/unit/datasets/test_datasets_command.py b/tests/unit/datasets/test_datasets_command.py index 4546a47985..7fc17561f9 100644 --- a/tests/unit/datasets/test_datasets_command.py +++ b/tests/unit/datasets/test_datasets_command.py @@ -215,9 +215,9 @@ def test_tasks_prints_no_tasks_message_when_not_found(capsys): # --------------------------------------------------------------------------- -def test_list_default_calls_list_all_datasets_and_renders_two_columns(capsys): +def test_list_default_depth_calls_list_all_datasets_and_renders_two_columns(capsys): cmd = DatasetsCommand() - args = make_base_args(datasets_command="list", depth=2, org=None) + args = make_base_args(datasets_command="list", depth=None, org=None) with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): with patch("rock.cli.command.datasets.DatasetClient") as MockClient: @@ -303,10 +303,10 @@ def test_list_parser_depth_and_org_mutually_exclusive(): parser.parse_args(["datasets", "list", "--depth", "2", "--org", "alibaba"]) -def test_list_parser_depth_default_is_2(): +def test_list_parser_depth_default_is_deferred_to_runtime(): parser = _build_parser() parsed = parser.parse_args(["datasets", "list"]) - assert parsed.depth == 2 + assert parsed.depth is None assert parsed.org is None