diff --git a/docs/dev/envhub/README.md b/docs/dev/envhub/README.md index 6255efe1f6..7f5b4080e4 100644 --- a/docs/dev/envhub/README.md +++ b/docs/dev/envhub/README.md @@ -248,20 +248,52 @@ class DatasetClient: rock datasets list [OPTIONS] Options: - --org TEXT 只列出指定 organization 的 datasets + --depth {1,2} 1: 只列出 organizations;2: 列出 organizations 和 datasets(默认) + --org TEXT 只列出指定 organization 的 datasets(与 --depth 互斥) --bucket TEXT OSS bucket 名称(覆盖 config.ini) --endpoint TEXT OSS endpoint(覆盖 config.ini) --access-key-id TEXT OSS access key ID(覆盖 config.ini) --access-key-secret TEXT OSS access key secret(覆盖 config.ini) + --region TEXT OSS region(覆盖 config.ini) ``` 输出示例: ``` -Dataset Split Tasks -qwen/my-bench train 42 -qwen/my-bench test 10 -alibaba/code-eval train 100 +Organization Dataset +-------------------------- +alibaba code-eval +qwen my-bench + +2 datasets in 2 organizations. +``` + +#### rock datasets splits + +``` +rock datasets splits [OPTIONS] + +Required: + --org TEXT Organization 名称 + --dataset TEXT Dataset 名称 + +Options: + --bucket TEXT OSS bucket 名称(覆盖 config.ini) + --endpoint TEXT OSS endpoint(覆盖 config.ini) + --access-key-id TEXT OSS access key ID(覆盖 config.ini) + --access-key-secret TEXT OSS access key secret(覆盖 config.ini) + --region TEXT OSS region(覆盖 config.ini) +``` + +输出示例: + +``` +Split +----- +test +train + +2 splits. 
``` #### rock datasets upload diff --git a/rock/cli/command/datasets.py b/rock/cli/command/datasets.py index b53636cb79..d35b8b87ea 100644 --- a/rock/cli/command/datasets.py +++ b/rock/cli/command/datasets.py @@ -35,6 +35,8 @@ async def arun(self, args: argparse.Namespace) -> None: await self._list(args) elif args.datasets_command == "tasks": await self._tasks(args) + elif args.datasets_command == "splits": + await self._splits(args) elif args.datasets_command == "upload": await self._upload(args) else: @@ -59,20 +61,48 @@ def _build_oss_registry_info(self, args: argparse.Namespace) -> OssRegistryInfo: async def _list(self, args: argparse.Namespace) -> None: registry_info = self._build_oss_registry_info(args) client = DatasetClient(registry_info) - datasets = client.list_datasets(org=getattr(args, "org", None)) - if not datasets: - print("No datasets found.") + if getattr(args, "org", None): + datasets = client.list_org_datasets(args.org) + pairs = [(args.org, d) for d in datasets] + self._render_org_dataset_pairs(pairs) + return + + depth = getattr(args, "depth", None) or 2 + if depth == 1: + orgs = client.list_organizations() + self._render_orgs(orgs) return - col_id = max(len("Dataset"), max(len(d.id) for d in datasets)) - col_split = max(len("Split"), max(len(d.split) for d in datasets)) + pairs = client.list_all_datasets() + self._render_org_dataset_pairs(pairs) - header = f"{'Dataset':<{col_id}} {'Split':<{col_split}} {'Tasks':>6}" + @staticmethod + def _render_org_dataset_pairs(pairs: list[tuple[str, str]]) -> None: + if not pairs: + print("No datasets found.") + return + col_org = max(len("Organization"), max(len(o) for o, _ in pairs)) + col_ds = max(len("Dataset"), max(len(d) for _, d in pairs)) + header = f"{'Organization':<{col_org}} {'Dataset':<{col_ds}}" print(header) print("-" * len(header)) - for ds in sorted(datasets, key=lambda d: (d.id, d.split)): - print(f"{ds.id:<{col_id}} {ds.split:<{col_split}} {len(ds.task_ids):>6}") + for o, d in pairs: + 
print(f"{o:<{col_org}} {d:<{col_ds}}") + n_orgs = len({o for o, _ in pairs}) + print(f"\n{len(pairs)} datasets in {n_orgs} organizations.") + + @staticmethod + def _render_orgs(orgs: list[str]) -> None: + if not orgs: + print("No organizations found.") + return + width = max(len("Organization"), max(len(o) for o in orgs)) + print(f"{'Organization':<{width}}") + print("-" * width) + for o in orgs: + print(o) + print(f"\n{len(orgs)} organizations.") async def _tasks(self, args: argparse.Namespace) -> None: registry_info = self._build_oss_registry_info(args) @@ -92,8 +122,6 @@ async def _tasks(self, args: argparse.Namespace) -> None: print("No tasks found after applying offset/limit.") return - limit_text = str(args.limit) if args.limit is not None else "all" - print() print("=" * 80) print(f"Dataset: {spec.id} Split: {spec.split} Total: {total} Shown: {len(shown_task_ids)}") @@ -103,6 +131,23 @@ async def _tasks(self, args: argparse.Namespace) -> None: for task_id in shown_task_ids: print(task_id) + async def _splits(self, args: argparse.Namespace) -> None: + registry_info = self._build_oss_registry_info(args) + client = DatasetClient(registry_info) + splits = client.list_dataset_splits(args.org, args.dataset) + + if not splits: + print(f"No splits found for dataset '{args.org}/{args.dataset}'.") + return + + width = max(len("Split"), max(len(s) for s in splits)) + print(f"{'Split':<{width}}") + print("-" * width) + for s in splits: + print(s) + word = "split" if len(splits) == 1 else "splits" + print(f"\n{len(splits)} {word}.") + async def _upload(self, args: argparse.Namespace) -> None: local_dir = Path(args.dir) if not local_dir.is_dir(): @@ -135,14 +180,24 @@ async def add_parser_to(subparsers: argparse._SubParsersAction) -> None: def add_oss_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--bucket", help="OSS bucket name (overrides config.ini)") parser.add_argument("--endpoint", help="OSS endpoint URL (overrides config.ini)") - 
parser.add_argument("--access-key-id", dest="access_key_id", - help="OSS access key ID (overrides config.ini)") - parser.add_argument("--access-key-secret", dest="access_key_secret", - help="OSS access key secret (overrides config.ini)") + parser.add_argument( + "--access-key-id", dest="access_key_id", help="OSS access key ID (overrides config.ini)" + ) + parser.add_argument( + "--access-key-secret", dest="access_key_secret", help="OSS access key secret (overrides config.ini)" + ) parser.add_argument("--region", help="OSS region (overrides config.ini)") list_parser = datasets_subparsers.add_parser("list", help="List datasets in OSS registry") - list_parser.add_argument("--org", help="Filter by organization") + list_group = list_parser.add_mutually_exclusive_group() + list_group.add_argument( + "--depth", + type=int, + choices=[1, 2], + default=None, + help="1: list orgs only. 2 (default): list orgs and datasets.", + ) + list_group.add_argument("--org", help="List datasets under the given organization only") add_oss_args(list_parser) tasks_parser = datasets_subparsers.add_parser("tasks", help="List task IDs under one dataset split") tasks_parser.add_argument("--org", required=True, help="Organization name") @@ -152,15 +207,25 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None: tasks_parser.add_argument("--limit", type=_positive_int, default=None, help="Maximum number of tasks to show") add_oss_args(tasks_parser) + splits_parser = datasets_subparsers.add_parser("splits", help="List splits under one dataset") + splits_parser.add_argument("--org", required=True, help="Organization name") + splits_parser.add_argument("--dataset", required=True, help="Dataset name") + add_oss_args(splits_parser) + upload_parser = datasets_subparsers.add_parser("upload", help="Upload local task dirs to OSS") upload_parser.add_argument("--org", required=True, help="Organization name") upload_parser.add_argument("--dataset", required=True, help="Dataset name") 
upload_parser.add_argument("--split", required=True, help="Split name (e.g. train, test, v1.0)") - upload_parser.add_argument("--dir", required=True, - help="Local directory containing {task_id}/ subdirectories") - upload_parser.add_argument("--concurrency", type=int, default=4, - choices=range(1, 17), metavar="[1-16]", - help="Upload concurrency (default: 4)") - upload_parser.add_argument("--overwrite", action="store_true", - help="Overwrite existing tasks in OSS (default: skip)") + upload_parser.add_argument("--dir", required=True, help="Local directory containing {task_id}/ subdirectories") + upload_parser.add_argument( + "--concurrency", + type=int, + default=4, + choices=range(1, 17), + metavar="[1-16]", + help="Upload concurrency (default: 4)", + ) + upload_parser.add_argument( + "--overwrite", action="store_true", help="Overwrite existing tasks in OSS (default: skip)" + ) add_oss_args(upload_parser) diff --git a/rock/sdk/envhub/datasets/client.py b/rock/sdk/envhub/datasets/client.py index 2ea4d70af0..16189defec 100644 --- a/rock/sdk/envhub/datasets/client.py +++ b/rock/sdk/envhub/datasets/client.py @@ -4,7 +4,6 @@ class DatasetClient: - def __init__(self, registry: OssRegistryInfo) -> None: self._registry = OssDatasetRegistry(registry) @@ -14,6 +13,18 @@ def list_datasets(self, org: str | None = None) -> list[DatasetSpec]: def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test") -> DatasetSpec | None: return self._registry.list_dataset_tasks(organization, dataset, split) + def list_organizations(self) -> list[str]: + return self._registry.list_organizations() + + def list_org_datasets(self, organization: str) -> list[str]: + return self._registry.list_org_datasets(organization) + + def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]: + return self._registry.list_all_datasets(concurrency) + + def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: + return 
self._registry.list_dataset_splits(organization, dataset) + def upload_dataset( self, source: LocalDatasetConfig, diff --git a/rock/sdk/envhub/datasets/registry/base.py b/rock/sdk/envhub/datasets/registry/base.py index 5c9192d2d6..0532671ffe 100644 --- a/rock/sdk/envhub/datasets/registry/base.py +++ b/rock/sdk/envhub/datasets/registry/base.py @@ -5,7 +5,6 @@ class BaseDatasetRegistry(ABC): - @abstractmethod def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: """List all datasets. Filtered to `organization` if provided.""" @@ -16,6 +15,26 @@ def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test """List task ids for one dataset split. Returns None if dataset/split has no tasks.""" ... + @abstractmethod + def list_organizations(self) -> list[str]: + """List organization names under the dataset registry. Single backend call.""" + ... + + @abstractmethod + def list_org_datasets(self, organization: str) -> list[str]: + """List dataset names under one organization. Single backend call.""" + ... + + @abstractmethod + def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: + """List split names under one dataset. Single backend call.""" + ... + + @abstractmethod + def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]: + """List all (org, dataset) pairs. 1 + N_org backend calls with bounded concurrency.""" + ... 
# Listing methods of OssDatasetRegistry (rock/sdk/envhub/datasets/registry/oss.py).
# They are written at module indentation here because the surrounding text is a
# whitespace-collapsed diff; inside the class they sit one level deeper.


def _list_child_names(self, prefix: str) -> list[str]:
    """Return the sorted names of the immediate child "directories" of ``prefix``.

    A delimiter listing is issued against the bucket and every returned common
    prefix is reduced to its last path segment.  The listing is paginated: the
    previous implementation issued a single ``max_keys=1000`` request and
    silently dropped everything beyond the first page.

    Args:
        prefix: Key prefix to list under; callers pass it with a trailing "/".

    Returns:
        Sorted child names; an empty list when nothing lives under ``prefix``.
    """
    bucket = self._build_bucket()
    names: list[str] = []
    token = ""  # empty token requests the first page
    while True:
        result = bucket.list_objects_v2(
            prefix=prefix,
            delimiter="/",
            continuation_token=token,
            max_keys=1000,
        )
        names.extend(self._last_segment(p) for p in result.prefix_list)
        # Follow only an explicit ``True`` so that a malformed (or mocked)
        # response can never keep this loop spinning.
        if result.is_truncated is not True:
            break
        token = result.next_continuation_token
        if not token:
            # Truncated but no token to resume from: stop instead of
            # re-requesting the first page forever.
            break
    return sorted(names)


def list_organizations(self) -> list[str]:
    """List organization names found directly under the dataset root prefix.

    The root is ``oss_dataset_path`` from the registry info, falling back to
    "datasets" when that is unset or empty.
    """
    base = self._registry.oss_dataset_path or "datasets"
    return self._list_child_names(f"{base}/")


def list_org_datasets(self, organization: str) -> list[str]:
    """List dataset names directly under ``organization``.

    Returns an empty list when the organization has no datasets (or does not
    exist); no error is raised for an unknown organization.
    """
    base = self._registry.oss_dataset_path or "datasets"
    return self._list_child_names(f"{base}/{organization}/")


def list_dataset_splits(self, organization: str, dataset: str) -> list[str]:
    """List split names directly under ``organization``/``dataset``.

    Returns an empty list when the dataset has no splits (or does not exist).
    """
    base = self._registry.oss_dataset_path or "datasets"
    return self._list_child_names(f"{base}/{organization}/{dataset}/")


def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]:
    """List every ``(organization, dataset)`` pair in the registry.

    Organizations are listed first, then each organization's datasets are
    fetched on a thread pool bounded by ``concurrency``.

    Args:
        concurrency: Maximum number of worker threads.  It is passed straight
            to ``ThreadPoolExecutor``, which rejects values below 1.

    Returns:
        Pairs sorted by organization, then dataset.  Empty when there are no
        organizations.

    Raises:
        Exception: whatever a per-organization listing raised; it is re-raised
            by ``Future.result()`` and no partial result is returned.
    """
    orgs = self.list_organizations()
    if not orgs:
        return []
    pairs: list[tuple[str, str]] = []
    with ThreadPoolExecutor(max_workers=concurrency) as ex:
        future_to_org = {ex.submit(self.list_org_datasets, o): o for o in orgs}
        # Completion order is arbitrary; the final sort makes output stable.
        for fut in as_completed(future_to_org):
            org = future_to_org[fut]
            pairs.extend((org, ds) for ds in fut.result())
    return sorted(pairs)
test_dataset_client_list_organizations_delegates(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_organizations", return_value=["a", "b"]) as m: + result = client.list_organizations() + m.assert_called_once_with() + assert result == ["a", "b"] + + +def test_dataset_client_list_org_datasets_delegates(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_org_datasets", return_value=["d1"]) as m: + result = client.list_org_datasets("qwen") + m.assert_called_once_with("qwen") + assert result == ["d1"] + + +def test_dataset_client_list_all_datasets_delegates_with_default_concurrency(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_all_datasets", return_value=[("a", "x")]) as m: + result = client.list_all_datasets() + m.assert_called_once_with(10) + assert result == [("a", "x")] + + +def test_dataset_client_list_all_datasets_passes_custom_concurrency(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_all_datasets", return_value=[]) as m: + client.list_all_datasets(concurrency=5) + m.assert_called_once_with(5) + + +def test_dataset_client_list_dataset_splits_delegates(): + client = DatasetClient(make_registry_info()) + with patch.object(client._registry, "list_dataset_splits", return_value=["test", "train"]) as m: + result = client.list_dataset_splits("qwen", "bench") + m.assert_called_once_with("qwen", "bench") + assert result == ["test", "train"] diff --git a/tests/unit/datasets/test_datasets_command.py b/tests/unit/datasets/test_datasets_command.py index 7c0b0d068e..7fc17561f9 100644 --- a/tests/unit/datasets/test_datasets_command.py +++ b/tests/unit/datasets/test_datasets_command.py @@ -21,6 +21,7 @@ def make_base_args(**kwargs): org=None, dataset=None, split=None, + depth=2, offset=0, limit=None, ) @@ -28,6 +29,7 @@ def make_base_args(**kwargs): setattr(args, k, v) return args + def 
make_registry_info(): return OssRegistryInfo(oss_bucket="b", oss_access_key_id="k", oss_access_key_secret="s") @@ -45,7 +47,9 @@ def test_command_name(): def test_build_oss_registry_info_from_cli_args(): cmd = DatasetsCommand() - args = make_base_args(bucket="cli-bucket", endpoint="https://oss.example.com", access_key_id="kid", access_key_secret="ksec") + args = make_base_args( + bucket="cli-bucket", endpoint="https://oss.example.com", access_key_id="kid", access_key_secret="ksec" + ) with patch("rock.cli.command.datasets.ConfigManager") as mock_mgr: ds_cfg = mock_mgr.return_value.get_config.return_value.dataset_config @@ -204,3 +208,172 @@ def test_tasks_prints_no_tasks_message_when_not_found(capsys): out = capsys.readouterr().out assert "No tasks found" in out + + +# --------------------------------------------------------------------------- +# list subcommand tests (depth + --org rewrite) +# --------------------------------------------------------------------------- + + +def test_list_default_depth_calls_list_all_datasets_and_renders_two_columns(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=None, org=None) + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_all_datasets.return_value = [ + ("alibaba", "pinch"), + ("qwen", "bench-1"), + ] + asyncio.run(cmd._list(args)) + + MockClient.return_value.list_all_datasets.assert_called_once_with() + out = capsys.readouterr().out + assert "Organization" in out + assert "Dataset" in out + assert "alibaba" in out and "pinch" in out + assert "qwen" in out and "bench-1" in out + assert "2 datasets in 2 organizations." 
in out + + +def test_list_depth_1_calls_list_organizations_and_renders_one_column(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=1, org=None) + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_organizations.return_value = ["alibaba", "qwen"] + asyncio.run(cmd._list(args)) + + MockClient.return_value.list_organizations.assert_called_once_with() + MockClient.return_value.list_all_datasets.assert_not_called() + out = capsys.readouterr().out + assert "Organization" in out + assert "alibaba" in out + assert "qwen" in out + assert "2 organizations." in out + + +def test_list_with_org_calls_list_org_datasets(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=2, org="alibaba") + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_org_datasets.return_value = ["pinch", "webdev"] + asyncio.run(cmd._list(args)) + + MockClient.return_value.list_org_datasets.assert_called_once_with("alibaba") + MockClient.return_value.list_all_datasets.assert_not_called() + MockClient.return_value.list_organizations.assert_not_called() + out = capsys.readouterr().out + assert "alibaba" in out and "pinch" in out and "webdev" in out + assert "2 datasets in 1 organizations." 
in out + + +def test_list_empty_prints_no_datasets_message(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=2, org=None) + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_all_datasets.return_value = [] + asyncio.run(cmd._list(args)) + + out = capsys.readouterr().out + assert "No datasets found." in out + + +def test_list_depth_1_empty_prints_no_organizations_message(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", depth=1, org=None) + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_organizations.return_value = [] + asyncio.run(cmd._list(args)) + + out = capsys.readouterr().out + assert "No organizations found." 
in out + + +def test_list_parser_depth_and_org_mutually_exclusive(): + parser = _build_parser() + with pytest.raises(SystemExit): + parser.parse_args(["datasets", "list", "--depth", "2", "--org", "alibaba"]) + + +def test_list_parser_depth_default_is_deferred_to_runtime(): + parser = _build_parser() + parsed = parser.parse_args(["datasets", "list"]) + assert parsed.depth is None + assert parsed.org is None + + +def test_list_parser_rejects_invalid_depth(): + parser = _build_parser() + with pytest.raises(SystemExit): + parser.parse_args(["datasets", "list", "--depth", "3"]) + + +# --------------------------------------------------------------------------- +# splits subcommand tests +# --------------------------------------------------------------------------- + + +def test_splits_lists_split_names(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="splits", org="alibaba", dataset="pinch") + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_dataset_splits.return_value = ["test", "train"] + asyncio.run(cmd._splits(args)) + + MockClient.return_value.list_dataset_splits.assert_called_once_with("alibaba", "pinch") + out = capsys.readouterr().out + assert "Split" in out + assert "test" in out + assert "train" in out + assert "2 splits." in out + + +def test_splits_empty_prints_no_splits_message(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="splits", org="alibaba", dataset="missing") + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_dataset_splits.return_value = [] + asyncio.run(cmd._splits(args)) + + out = capsys.readouterr().out + assert "No splits found for dataset 'alibaba/missing'." 
in out + + +def test_splits_singular_footer_for_one_split(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="splits", org="alibaba", dataset="pinch") + + with patch.object(DatasetsCommand, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient") as MockClient: + MockClient.return_value.list_dataset_splits.return_value = ["test"] + asyncio.run(cmd._splits(args)) + + out = capsys.readouterr().out + assert "1 split." in out + + +def test_splits_parser_requires_org_and_dataset(): + parser = _build_parser() + with pytest.raises(SystemExit): + parser.parse_args(["datasets", "splits"]) + with pytest.raises(SystemExit): + parser.parse_args(["datasets", "splits", "--org", "alibaba"]) + + +def test_splits_parser_accepts_org_and_dataset(): + parser = _build_parser() + parsed = parser.parse_args(["datasets", "splits", "--org", "alibaba", "--dataset", "pinch"]) + assert parsed.org == "alibaba" + assert parsed.dataset == "pinch" diff --git a/tests/unit/datasets/test_oss_registry.py b/tests/unit/datasets/test_oss_registry.py index 9237ac1d6f..fc3dd9be70 100644 --- a/tests/unit/datasets/test_oss_registry.py +++ b/tests/unit/datasets/test_oss_registry.py @@ -27,10 +27,12 @@ def test_list_datasets_returns_all(): make_list_result(prefixes=["datasets/qwen/"]), make_list_result(prefixes=["datasets/qwen/my-bench/"]), make_list_result(prefixes=["datasets/qwen/my-bench/train/"]), - make_list_result(prefixes=[ - "datasets/qwen/my-bench/train/task-001/", - "datasets/qwen/my-bench/train/task-002/", - ]), + make_list_result( + prefixes=[ + "datasets/qwen/my-bench/train/task-001/", + "datasets/qwen/my-bench/train/task-002/", + ] + ), ] with patch.object(registry, "_build_bucket", return_value=mock_bucket): @@ -58,6 +60,7 @@ def test_list_datasets_filter_by_org(): assert first_call_kwargs["prefix"] == "datasets/qwen/" assert len(datasets) == 1 + def test_list_datasets_counts_directory_and_file_tasks(): 
registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() @@ -104,6 +107,7 @@ def test_build_prefix_with_split(): registry = OssDatasetRegistry(make_registry_info()) assert registry._build_prefix("qwen", "my-bench", "train") == "datasets/qwen/my-bench/train" + # --------------------------------------------------------------------------- # list_dataset_tasks tests # --------------------------------------------------------------------------- @@ -112,10 +116,12 @@ def test_build_prefix_with_split(): def test_list_dataset_tasks_uses_default_test_split_and_sorts_task_ids(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() - mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ - "datasets/qwen/my-bench/test/task-002/", - "datasets/qwen/my-bench/test/task-001/", - ]) + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/my-bench/test/task-002/", + "datasets/qwen/my-bench/test/task-001/", + ] + ) with patch.object(registry, "_build_bucket", return_value=mock_bucket): spec = registry.list_dataset_tasks("qwen", "my-bench") @@ -132,9 +138,11 @@ def test_list_dataset_tasks_uses_default_test_split_and_sorts_task_ids(): def test_list_dataset_tasks_supports_custom_split(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() - mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[ - "datasets/qwen/my-bench/train/task-001/", - ]) + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/my-bench/train/task-001/", + ] + ) with patch.object(registry, "_build_bucket", return_value=mock_bucket): spec = registry.list_dataset_tasks("qwen", "my-bench", "train") @@ -146,6 +154,7 @@ def test_list_dataset_tasks_supports_custom_split(): first_call_kwargs = mock_bucket.list_objects_v2.call_args_list[0][1] assert first_call_kwargs["prefix"] == "datasets/qwen/my-bench/train/" + def 
test_list_dataset_tasks_includes_directory_and_file_tasks_with_suffix_stripped(): registry = OssDatasetRegistry(make_registry_info()) mock_bucket = MagicMock() @@ -281,3 +290,166 @@ def test_upload_dataset_oss_key_format(tmp_path): key = mock_bucket.put_object.call_args[0][0] assert key == "datasets/qwen/my-bench/train/task-001/task.toml" + + +# --------------------------------------------------------------------------- +# list_organizations tests +# --------------------------------------------------------------------------- + + +def test_list_organizations_returns_sorted_org_names(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/", + "datasets/alibaba/", + "datasets/AoneBenchDev/", + ] + ) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + orgs = registry.list_organizations() + + call_kwargs = mock_bucket.list_objects_v2.call_args[1] + assert call_kwargs["prefix"] == "datasets/" + assert call_kwargs["delimiter"] == "/" + assert call_kwargs["max_keys"] == 1000 + assert orgs == ["AoneBenchDev", "alibaba", "qwen"] + + +def test_list_organizations_returns_empty_when_no_orgs(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[]) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + orgs = registry.list_organizations() + + assert orgs == [] + + +def test_list_org_datasets_returns_sorted_dataset_names(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/bench-2/", + "datasets/qwen/bench-1/", + ] + ) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + datasets = registry.list_org_datasets("qwen") + + call_kwargs = 
mock_bucket.list_objects_v2.call_args[1] + assert call_kwargs["prefix"] == "datasets/qwen/" + assert call_kwargs["delimiter"] == "/" + assert call_kwargs["max_keys"] == 1000 + assert datasets == ["bench-1", "bench-2"] + + +def test_list_org_datasets_returns_empty_when_org_missing(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[]) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + assert registry.list_org_datasets("nonexistent") == [] + + +def test_list_dataset_splits_returns_sorted_split_names(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result( + prefixes=[ + "datasets/qwen/bench/train/", + "datasets/qwen/bench/test/", + ] + ) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + splits = registry.list_dataset_splits("qwen", "bench") + + call_kwargs = mock_bucket.list_objects_v2.call_args[1] + assert call_kwargs["prefix"] == "datasets/qwen/bench/" + assert call_kwargs["delimiter"] == "/" + assert splits == ["test", "train"] + + +def test_list_dataset_splits_returns_empty_when_dataset_missing(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + mock_bucket.list_objects_v2.return_value = make_list_result(prefixes=[]) + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + assert registry.list_dataset_splits("qwen", "nope") == [] + + +def test_list_all_datasets_returns_sorted_pairs(): + registry = OssDatasetRegistry(make_registry_info()) + + def fake_list_org_datasets(org): + return {"qwen": ["bench-2", "bench-1"], "alibaba": ["pinch"]}[org] + + with patch.object(registry, "list_organizations", return_value=["qwen", "alibaba"]): + with patch.object(registry, "list_org_datasets", side_effect=fake_list_org_datasets): + pairs = registry.list_all_datasets() + + 
assert pairs == [("alibaba", "pinch"), ("qwen", "bench-1"), ("qwen", "bench-2")] + + +def test_list_all_datasets_uses_bounded_concurrency(): + registry = OssDatasetRegistry(make_registry_info()) + + with patch.object(registry, "list_organizations", return_value=["o1", "o2"]): + with patch.object(registry, "list_org_datasets", return_value=["d"]): + with patch("rock.sdk.envhub.datasets.registry.oss.ThreadPoolExecutor") as mock_pool: + with patch("rock.sdk.envhub.datasets.registry.oss.as_completed", side_effect=lambda d: list(d)): + mock_executor = MagicMock() + mock_pool.return_value.__enter__.return_value = mock_executor + future = MagicMock() + future.result.return_value = ["d"] + mock_executor.submit.return_value = future + registry.list_all_datasets(concurrency=7) + + mock_pool.assert_called_once_with(max_workers=7) + + +def test_list_all_datasets_default_concurrency_is_10(): + registry = OssDatasetRegistry(make_registry_info()) + + with patch.object(registry, "list_organizations", return_value=["o1"]): + with patch.object(registry, "list_org_datasets", return_value=["d"]): + with patch("rock.sdk.envhub.datasets.registry.oss.ThreadPoolExecutor") as mock_pool: + with patch("rock.sdk.envhub.datasets.registry.oss.as_completed", side_effect=lambda d: list(d)): + mock_executor = MagicMock() + mock_pool.return_value.__enter__.return_value = mock_executor + future = MagicMock() + future.result.return_value = ["d"] + mock_executor.submit.return_value = future + registry.list_all_datasets() + + mock_pool.assert_called_once_with(max_workers=10) + + +def test_list_all_datasets_propagates_exception_from_worker(): + import pytest as _pytest + + registry = OssDatasetRegistry(make_registry_info()) + + def fake_list_org_datasets(org): + if org == "bad": + raise RuntimeError("oss boom") + return ["d"] + + with patch.object(registry, "list_organizations", return_value=["good", "bad"]): + with patch.object(registry, "list_org_datasets", side_effect=fake_list_org_datasets): + with 
_pytest.raises(RuntimeError, match="oss boom"): + registry.list_all_datasets() + + +def test_list_all_datasets_empty_when_no_orgs(): + registry = OssDatasetRegistry(make_registry_info()) + with patch.object(registry, "list_organizations", return_value=[]): + assert registry.list_all_datasets() == []