From 35fffe0f177ddc66343d04488c2e70f2b836dce0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Thu, 4 Jun 2026 10:57:32 +0800 Subject: [PATCH 1/4] perf(datasets): cache OSS bucket, add server-side pagination and --filter for tasks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Cache oss2.Bucket instance to avoid creating 25+ connections per list_all_datasets call (5.0s → 2.7s) - Add _PaginationCache with continuation_token to resume sequential page access in O(1) - Push offset/limit down to _extract_tasks_from_split with early termination (19.4s → 1.5s for --limit 10) - Adapt max_keys to actual limit needed instead of always 1000 - Add --filter flag to tasks command, pushed down as OSS prefix for server-side filtering - Update BaseDatasetRegistry and DatasetClient interfaces with offset/limit/task_filter params Co-Authored-By: Claude Code AI-Model: claude-opus-4-6 AI-Contributed/Feature: 120/540 AI-Contributed/UT: 17/448 --- rock/cli/command/datasets.py | 247 ++++++++++- rock/sdk/envhub/datasets/client.py | 25 +- rock/sdk/envhub/datasets/registry/base.py | 25 +- rock/sdk/envhub/datasets/registry/oss.py | 310 ++++++++++--- tests/unit/datasets/test_client.py | 25 +- tests/unit/datasets/test_datasets_command.py | 435 ++++++++++++++++++- 6 files changed, 994 insertions(+), 73 deletions(-) diff --git a/rock/cli/command/datasets.py b/rock/cli/command/datasets.py index d35b8b87ea..7f6603cf37 100644 --- a/rock/cli/command/datasets.py +++ b/rock/cli/command/datasets.py @@ -1,7 +1,10 @@ from __future__ import annotations import argparse +import contextlib +import json import sys +from dataclasses import asdict from pathlib import Path from rock.cli.command.command import Command @@ -27,6 +30,46 @@ def _positive_int(value: str) -> int: return ivalue +def _is_json_output(args: argparse.Namespace) -> bool: + return getattr(args, "output", None) == "json" + + +def _print_json(data: dict | list) -> None: + print(json.dumps(data, ensure_ascii=False, indent=2)) + + +def _dataset_to_dict(dataset) -> dict: + data = asdict(dataset) + data["task_count"] = len(dataset.task_ids) + return data + + +def _normalize_task_path(path: str, *, allow_empty: bool = False, directory: bool = False) -> str: + raw = (path or "").replace("\\", "/") + if raw.startswith("/"): + raise ValueError("relative task path must not be absolute") + + parts = [p for p in raw.split("/") if p and p != "."] + if any(p == ".." for p in parts): + raise ValueError("relative task path must not contain '..'") + if not parts: + if allow_empty: + return "" + raise ValueError("relative task path is required") + + normalized = "/".join(parts) + if directory or raw.endswith("/"): + normalized += "/" + return normalized + + +def _write_stdout_bytes(content: bytes) -> None: + try: + sys.stdout.write(content.decode("utf-8")) + except UnicodeDecodeError: + sys.stdout.buffer.write(content) + + class DatasetsCommand(Command): name = "datasets" @@ -39,6 +82,8 @@ async def arun(self, args: argparse.Namespace) -> None: await self._splits(args) elif args.datasets_command == "upload": await self._upload(args) + elif args.datasets_command in ("fs", "files"): + await self._fs(args) else: raise ValueError(f"Unknown datasets command: {args.datasets_command}") @@ -62,6 +107,17 @@ async def _list(self, args: argparse.Namespace) -> None: registry_info = self._build_oss_registry_info(args) client = DatasetClient(registry_info) + if _is_json_output(args): + if getattr(args, "depth", None) == 1: + _print_json({"organizations": client.list_organizations()}) + return + datasets = sorted( + client.list_datasets(getattr(args, "org", None)), + key=lambda d: (d.id, d.split), + ) + _print_json({"datasets": [_dataset_to_dict(d) for d in datasets]}) + return + if getattr(args, "org", None): datasets = client.list_org_datasets(args.org) pairs = [(args.org, d) for d in datasets] @@ -107,16 +163,37 @@ def _render_orgs(orgs: list[str]) -> None: async def _tasks(self, args: argparse.Namespace) -> None: registry_info = self._build_oss_registry_info(args) client = DatasetClient(registry_info) - spec = client.list_dataset_tasks(args.org, args.dataset, args.split) + spec = client.list_dataset_tasks( + args.org, args.dataset, args.split, + offset=args.offset, limit=args.limit, task_filter=getattr(args, "filter", None), + ) if spec is None or not spec.task_ids: + if _is_json_output(args): + _print_json({ + "dataset": f"{args.org}/{args.dataset}", + "split": args.split, + "total": 0, + "offset": args.offset, + "limit": args.limit, + "task_ids": [], + }) + return print(f"No tasks found for dataset '{args.org}/{args.dataset}' split '{args.split}'.") return - total = len(spec.task_ids) - start = args.offset - end = start + args.limit if args.limit is not None else None - shown_task_ids = spec.task_ids[start:end] + shown_task_ids = spec.task_ids + + if _is_json_output(args): + _print_json({ + "dataset": spec.id, + "split": spec.split, + "total": len(shown_task_ids), + "offset": args.offset, + "limit": args.limit, + "task_ids": shown_task_ids, + }) + return if not shown_task_ids: print("No tasks found after applying offset/limit.") @@ -124,7 +201,7 @@ async def _tasks(self, args: argparse.Namespace) -> None: print() print("=" * 80) - print(f"Dataset: {spec.id} Split: {spec.split} Total: {total} Shown: {len(shown_task_ids)}") + print(f"Dataset: {spec.id} Split: {spec.split} Shown: {len(shown_task_ids)}") print("=" * 80) print("#Task name") print("-" * 10) @@ -150,8 +227,8 @@ async def _splits(self, args: argparse.Namespace) -> None: async def _upload(self, args: argparse.Namespace) -> None: local_dir = Path(args.dir) - if not local_dir.is_dir(): - raise ValueError(f"--dir '{local_dir}' does not exist or is not a directory") + if not local_dir.exists(): + raise ValueError(f"--dir '{local_dir}' does not exist") registry_info = self._build_oss_registry_info(args) source = LocalDatasetConfig(path=local_dir) @@ -163,20 +240,131 @@ async def _upload(self, args: argparse.Namespace) -> None: ) base = registry_info.oss_dataset_path or "datasets" - print(f"Uploading to oss://{registry_info.oss_bucket}/{base}/{args.org}/{args.dataset}/{args.split}/") + if not _is_json_output(args): + print(f"Uploading to oss://{registry_info.oss_bucket}/{base}/{args.org}/{args.dataset}/{args.split}/") client = DatasetClient(registry_info) - result = client.upload_dataset(source, target, concurrency=args.concurrency) + if _is_json_output(args): + with contextlib.redirect_stdout(sys.stderr): + result = client.upload_dataset(source, target, concurrency=args.concurrency) + _print_json(asdict(result)) + else: + result = client.upload_dataset(source, target, concurrency=args.concurrency) + print(f"\nDone: {result.uploaded} uploaded, {result.skipped} skipped, {result.failed} failed") - print(f"\nDone: {result.uploaded} uploaded, {result.skipped} skipped, {result.failed} failed") if result.failed > 0: sys.exit(1) + async def _fs(self, args: argparse.Namespace) -> None: + if args.fs_command == "ls": + await self._fs_ls(args) + elif args.fs_command == "get": + await self._fs_get(args) + elif args.fs_command == "download": + await self._fs_download(args) + else: + raise ValueError(f"Unknown datasets fs command: {args.fs_command}") + + async def _fs_ls(self, args: argparse.Namespace) -> None: + path = _normalize_task_path(args.path, allow_empty=True, directory=bool(args.path)) if args.path else "" + registry_info = self._build_oss_registry_info(args) + client = DatasetClient(registry_info) + files = client.list_task_files(args.org, args.dataset, args.split, args.task, path.rstrip("/")) + + if _is_json_output(args): + _print_json({ + "dataset": f"{args.org}/{args.dataset}", + "split": args.split, + "task": args.task, + "path": path, + "files": [asdict(f) for f in files], + }) + return + + for file in files: + print(file.path) + + async def _fs_get(self, args: argparse.Namespace) -> None: + path = _normalize_task_path(args.path) if args.path else None + registry_info = self._build_oss_registry_info(args) + client = DatasetClient(registry_info) + if path is None: + files = client.list_task_files(args.org, args.dataset, args.split, args.task, "") + if len(files) != 1: + raise ValueError("--path is required when task contains zero or multiple files") + path = files[0].path + content = client.get_task_file(args.org, args.dataset, args.split, args.task, path) + if content is None: + raise FileNotFoundError(f"Task file not found: {args.org}/{args.dataset}/{args.split}/{args.task}/{path}") + + if _is_json_output(args): + _print_json({ + "dataset": f"{args.org}/{args.dataset}", + "split": args.split, + "task": args.task, + "path": path, + "content": content.decode("utf-8"), + }) + return + + _write_stdout_bytes(content) + + async def _fs_download(self, args: argparse.Namespace) -> None: + path = _normalize_task_path(args.path, directory=args.path.endswith("/")) + dest = Path(args.dest) + registry_info = self._build_oss_registry_info(args) + client = DatasetClient(registry_info) + + content = client.get_task_file(args.org, args.dataset, args.split, args.task, path) + if content is not None: + target = dest / Path(path).name if dest.is_dir() else dest + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(content) + if _is_json_output(args): + _print_json({"downloaded": [str(target)]}) + else: + print(str(target)) + return + + prefix = path if path.endswith("/") else f"{path}/" + files = client.list_task_files(args.org, args.dataset, args.split, args.task, prefix.rstrip("/")) + if not files: + raise FileNotFoundError(f"Task path not found: {args.org}/{args.dataset}/{args.split}/{args.task}/{path}") + + downloaded: list[str] = [] + for file in files: + file_content = client.get_task_file(args.org, args.dataset, args.split, args.task, file.path) + if file_content is None: + continue + relative = file.path[len(prefix) :] if file.path.startswith(prefix) else file.path + target = dest / relative + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(file_content) + downloaded.append(str(target)) + + if _is_json_output(args): + _print_json({"downloaded": downloaded}) + else: + for path in downloaded: + print(path) + @staticmethod async def add_parser_to(subparsers: argparse._SubParsersAction) -> None: datasets_parser = subparsers.add_parser("datasets", description="Dataset operations on OSS") + datasets_parser.set_defaults(output=None) datasets_subparsers = datasets_parser.add_subparsers(dest="datasets_command") + def add_output_arg(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "-o", + "--output", + "--ouput", + dest="output", + choices=["json"], + default=argparse.SUPPRESS, + help="Output format. Supported: json", + ) + def add_oss_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--bucket", help="OSS bucket name (overrides config.ini)") parser.add_argument("--endpoint", help="OSS endpoint URL (overrides config.ini)") @@ -188,6 +376,8 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None: ) parser.add_argument("--region", help="OSS region (overrides config.ini)") + add_output_arg(datasets_parser) + list_parser = datasets_subparsers.add_parser("list", help="List datasets in OSS registry") list_group = list_parser.add_mutually_exclusive_group() list_group.add_argument( @@ -200,9 +390,11 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None: list_group.add_argument("--org", help="List datasets under the given organization only") add_oss_args(list_parser) tasks_parser = datasets_subparsers.add_parser("tasks", help="List task IDs under one dataset split") + add_output_arg(tasks_parser) tasks_parser.add_argument("--org", required=True, help="Organization name") tasks_parser.add_argument("--dataset", required=True, help="Dataset name") tasks_parser.add_argument("--split", default="test", help="Split name (default: test)") + tasks_parser.add_argument("--filter", default=None, help="Filter tasks by prefix (e.g. --filter 0xerr0r)") tasks_parser.add_argument("--offset", type=_non_negative_int, default=0, help="Skip first N tasks") tasks_parser.add_argument("--limit", type=_positive_int, default=None, help="Maximum number of tasks to show") add_oss_args(tasks_parser) @@ -213,10 +405,15 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None: add_oss_args(splits_parser) upload_parser = datasets_subparsers.add_parser("upload", help="Upload local task dirs to OSS") + add_output_arg(upload_parser) upload_parser.add_argument("--org", required=True, help="Organization name") upload_parser.add_argument("--dataset", required=True, help="Dataset name") upload_parser.add_argument("--split", required=True, help="Split name (e.g. train, test, v1.0)") - upload_parser.add_argument("--dir", required=True, help="Local directory containing {task_id}/ subdirectories") + upload_parser.add_argument( + "--dir", + required=True, + help="Local dataset directory containing {task_id}/ subdirectories or direct task files, or one task file", + ) upload_parser.add_argument( "--concurrency", type=int, @@ -229,3 +426,29 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None: "--overwrite", action="store_true", help="Overwrite existing tasks in OSS (default: skip)" ) add_oss_args(upload_parser) + + fs_parser = datasets_subparsers.add_parser("fs", aliases=["files"], help="Inspect files under one task") + fs_subparsers = fs_parser.add_subparsers(dest="fs_command") + + def add_task_fs_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--org", required=True, help="Organization name") + parser.add_argument("--dataset", required=True, help="Dataset name") + parser.add_argument("--split", default="test", help="Split name (default: test)") + parser.add_argument("--task", required=True, help="Task ID") + add_oss_args(parser) + + ls_parser = fs_subparsers.add_parser("ls", help="List files under one task") + add_output_arg(ls_parser) + add_task_fs_args(ls_parser) + ls_parser.add_argument("--path", default="", help="Relative task directory path to list") + + get_parser = fs_subparsers.add_parser("get", help="Print one task file to stdout") + add_output_arg(get_parser) + add_task_fs_args(get_parser) + get_parser.add_argument("--path", help="Relative task file path") + + download_parser = fs_subparsers.add_parser("download", help="Download one task file or directory") + add_output_arg(download_parser) + add_task_fs_args(download_parser) + download_parser.add_argument("--path", required=True, help="Relative task file or directory path") + download_parser.add_argument("--dest", required=True, help="Local destination file or directory") diff --git a/rock/sdk/envhub/datasets/client.py b/rock/sdk/envhub/datasets/client.py index 16189defec..002a9f98bb 100644 --- a/rock/sdk/envhub/datasets/client.py +++ b/rock/sdk/envhub/datasets/client.py @@ -1,5 +1,5 @@ from rock.sdk.bench.models.job.config import LocalDatasetConfig, OssRegistryInfo, RegistryDatasetConfig -from rock.sdk.envhub.datasets.models import DatasetSpec, UploadResult +from rock.sdk.envhub.datasets.models import DatasetSpec, TaskFile, UploadResult from rock.sdk.envhub.datasets.registry.oss import OssDatasetRegistry @@ -10,8 +10,19 @@ def __init__(self, registry: OssRegistryInfo) -> None: def list_datasets(self, org: str | None = None) -> list[DatasetSpec]: return self._registry.list_datasets(org) - def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test") -> DatasetSpec | None: - return self._registry.list_dataset_tasks(organization, dataset, split) + def list_dataset_tasks( + self, + organization: str, + dataset: str, + split: str = "test", + *, + offset: int = 0, + limit: int | None = None, + task_filter: str | None = None, + ) -> DatasetSpec | None: + return self._registry.list_dataset_tasks( + organization, dataset, split, offset=offset, limit=limit, task_filter=task_filter + ) def list_organizations(self) -> list[str]: return self._registry.list_organizations() @@ -25,6 +36,14 @@ def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]: def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: return self._registry.list_dataset_splits(organization, dataset) + def list_task_files( + self, organization: str, dataset: str, split: str, task_id: str, path: str = "" + ) -> list[TaskFile]: + return self._registry.list_task_files(organization, dataset, split, task_id, path) + + def get_task_file(self, organization: str, dataset: str, split: str, task_id: str, path: str) -> bytes | None: + return self._registry.get_task_file(organization, dataset, split, task_id, path) + def upload_dataset( self, source: LocalDatasetConfig, diff --git a/rock/sdk/envhub/datasets/registry/base.py b/rock/sdk/envhub/datasets/registry/base.py index 0532671ffe..65820402de 100644 --- a/rock/sdk/envhub/datasets/registry/base.py +++ b/rock/sdk/envhub/datasets/registry/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from rock.sdk.bench.models.job.config import LocalDatasetConfig, RegistryDatasetConfig -from rock.sdk.envhub.datasets.models import DatasetSpec, UploadResult +from rock.sdk.envhub.datasets.models import DatasetSpec, TaskFile, UploadResult class BaseDatasetRegistry(ABC): @@ -11,7 +11,16 @@ def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: ... @abstractmethod - def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test") -> DatasetSpec | None: + def list_dataset_tasks( + self, + organization: str, + dataset: str, + split: str = "test", + *, + offset: int = 0, + limit: int | None = None, + task_filter: str | None = None, + ) -> DatasetSpec | None: """List task ids for one dataset split. Returns None if dataset/split has no tasks.""" ... @@ -30,6 +39,18 @@ def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: """List split names under one dataset. Single backend call.""" ... + @abstractmethod + def list_task_files( + self, organization: str, dataset: str, split: str, task_id: str, path: str = "" + ) -> list[TaskFile]: + """List files under a task path. Paths are relative to the task root.""" + ... + + @abstractmethod + def get_task_file(self, organization: str, dataset: str, split: str, task_id: str, path: str) -> bytes | None: + """Read one task file by relative path. Returns None when the object does not exist.""" + ... + @abstractmethod def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]: """List all (org, dataset) pairs. 1 + N_org backend calls with bounded concurrency.""" diff --git a/rock/sdk/envhub/datasets/registry/oss.py b/rock/sdk/envhub/datasets/registry/oss.py index af36481ce9..9c676eef86 100644 --- a/rock/sdk/envhub/datasets/registry/oss.py +++ b/rock/sdk/envhub/datasets/registry/oss.py @@ -1,28 +1,41 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field from pathlib import Path import oss2 from rock.logger import init_logger from rock.sdk.bench.models.job.config import LocalDatasetConfig, OssRegistryInfo, RegistryDatasetConfig -from rock.sdk.envhub.datasets.models import DatasetSpec, UploadResult +from rock.sdk.envhub.datasets.models import DatasetSpec, TaskFile, UploadResult from rock.sdk.envhub.datasets.registry.base import BaseDatasetRegistry logger = init_logger(__name__) +@dataclass +class _PaginationCache: + split_prefix: str = "" + tasks: list[str] = field(default_factory=list) + continuation_token: str = "" + is_exhausted: bool = False + + class OssDatasetRegistry(BaseDatasetRegistry): def __init__(self, registry: OssRegistryInfo) -> None: self._registry = registry + self._bucket: oss2.Bucket | None = None + self._page_cache = _PaginationCache() def _build_bucket(self) -> oss2.Bucket: - auth = oss2.Auth( - self._registry.oss_access_key_id or "", - self._registry.oss_access_key_secret or "", - ) - return oss2.Bucket(auth, self._registry.oss_endpoint or "", self._registry.oss_bucket) + if self._bucket is None: + auth = oss2.Auth( + self._registry.oss_access_key_id or "", + self._registry.oss_access_key_secret or "", + ) + self._bucket = oss2.Bucket(auth, self._registry.oss_endpoint or "", self._registry.oss_bucket) + return self._bucket def _build_prefix(self, org: str, name: str, split: str | None = None) -> str: base = self._registry.oss_dataset_path or "datasets" @@ -31,11 +44,55 @@ def _build_prefix(self, org: str, name: str, split: str | None = None) -> str: parts.append(split) return "/".join(parts) + def _build_task_prefix(self, org: str, name: str, split: str, task_id: str) -> str: + return f"{self._build_prefix(org, name, split)}/{task_id}/" + @staticmethod def _last_segment(prefix: str) -> str: return prefix.rstrip("/").rsplit("/", 1)[-1] - def _extract_tasks_from_split(self, bucket: oss2.Bucket, split_prefix: str) -> list[str]: + @staticmethod + def _normalize_task_path(path: str, *, allow_empty: bool = False, directory: bool = False) -> str: + raw = (path or "").replace("\\", "/") + if raw.startswith("/"): + raise ValueError("relative task path must not be absolute") + + parts = [p for p in raw.split("/") if p and p != "."] + if any(p == ".." for p in parts): + raise ValueError("relative task path must not contain '..'") + if not parts: + if allow_empty: + return "" + raise ValueError("relative task path is required") + + normalized = "/".join(parts) + if directory or raw.endswith("/"): + normalized += "/" + return normalized + + @staticmethod + def _list_objects_v2_pages(bucket: oss2.Bucket, **kwargs): + token = "" + while True: + page_kwargs = dict(kwargs) + if token: + page_kwargs["continuation_token"] = token + result = bucket.list_objects_v2(**page_kwargs) + yield result + if not getattr(result, "is_truncated", False): + break + token = getattr(result, "next_continuation_token", "") or "" + if not token: + break + + def _extract_tasks_from_split( + self, + bucket: oss2.Bucket, + split_prefix: str, + *, + max_items: int | None = None, + task_filter: str | None = None, + ) -> list[str]: """Extract tasks from a split prefix, combining directory and file tasks. Directory tasks: from prefix_list (e.g., "datasets/org/name/split/task-001/") @@ -43,49 +100,100 @@ def _extract_tasks_from_split(self, bucket: oss2.Bucket, split_prefix: str) -> l File tasks are stripped of their suffix (e.g., "task-001.json" -> "task-001"). Placeholder objects (key ending with "/") and nested objects are ignored. - """ - result = bucket.list_objects_v2(prefix=split_prefix, delimiter="/", max_keys=1000) - # Directory tasks from prefix_list - dir_tasks = [self._last_segment(p) for p in result.prefix_list] + When *task_filter* is set, only tasks whose name starts with the filter + string are returned (pushed down to the OSS prefix for efficiency). - # File tasks from object_list: direct files under split, strip suffix - file_tasks = [] - for obj in result.object_list: - key = obj.key - # Ignore directory placeholder objects (key ending with "/") - if key.endswith("/"): - continue - # Get the relative path from split_prefix - relative = key[len(split_prefix) :] - # Only direct files (no nested paths with "/") - if "/" in relative: - continue - # Strip suffix (e.g., "task-001.json" -> "task-001") - name = relative.rsplit(".", 1)[0] if "." in relative else relative - file_tasks.append(name) - - # Merge and dedupe with stable sort - all_tasks = sorted(set(dir_tasks + file_tasks)) - return all_tasks + Uses an internal pagination cache: if the same query is repeated, results + are served from cache or resumed via continuation token. + """ + query_prefix = f"{split_prefix}{task_filter}" if task_filter else split_prefix + cache = self._page_cache + + # Cache hit: same query + if cache.split_prefix == query_prefix: + if cache.is_exhausted or (max_items is not None and len(cache.tasks) >= max_items): + return cache.tasks[:max_items] if max_items else list(cache.tasks) + tasks_set: set[str] = set(cache.tasks) + token = cache.continuation_token + else: + tasks_set = set() + token = "" + + while True: + mk = 1000 + if max_items is not None: + mk = min(1000, max(max_items - len(tasks_set), 100)) + + kwargs: dict = {"prefix": query_prefix, "delimiter": "/", "max_keys": mk} + if token: + kwargs["continuation_token"] = token + + result = bucket.list_objects_v2(**kwargs) + + for p in result.prefix_list: + s = self._last_segment(p) + if not s.startswith("."): + tasks_set.add(s) + + for obj in result.object_list: + key = obj.key + if key.endswith("/"): + continue + relative = key[len(split_prefix):] + if "/" in relative or relative.startswith("."): + continue + name = relative.rsplit(".", 1)[0] if "." in relative else relative + tasks_set.add(name) + + is_truncated = getattr(result, "is_truncated", False) + next_token = getattr(result, "next_continuation_token", "") or "" + + if not is_truncated or not next_token: + sorted_tasks = sorted(tasks_set) + cache.split_prefix = query_prefix + cache.tasks = sorted_tasks + cache.continuation_token = "" + cache.is_exhausted = True + return sorted_tasks[:max_items] if max_items else sorted_tasks + + if max_items is not None and len(tasks_set) >= max_items: + sorted_tasks = sorted(tasks_set) + cache.split_prefix = query_prefix + cache.tasks = sorted_tasks + cache.continuation_token = next_token + cache.is_exhausted = False + return sorted_tasks[:max_items] + + token = next_token def list_organizations(self) -> list[str]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" - result = bucket.list_objects_v2(prefix=f"{base}/", delimiter="/", max_keys=1000) - return sorted(self._last_segment(p) for p in result.prefix_list) + prefixes = [] + for result in self._list_objects_v2_pages(bucket, prefix=f"{base}/", delimiter="/", max_keys=1000): + prefixes.extend(result.prefix_list) + return sorted(s for p in prefixes if not (s := self._last_segment(p)).startswith(".")) def list_org_datasets(self, organization: str) -> list[str]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" - result = bucket.list_objects_v2(prefix=f"{base}/{organization}/", delimiter="/", max_keys=1000) - return sorted(self._last_segment(p) for p in result.prefix_list) + prefixes = [] + for result in self._list_objects_v2_pages( + bucket, prefix=f"{base}/{organization}/", delimiter="/", max_keys=1000 + ): + prefixes.extend(result.prefix_list) + return sorted(s for p in prefixes if not (s := self._last_segment(p)).startswith(".")) def list_dataset_splits(self, organization: str, dataset: str) -> list[str]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" - result = bucket.list_objects_v2(prefix=f"{base}/{organization}/{dataset}/", delimiter="/", max_keys=1000) - return sorted(self._last_segment(p) for p in result.prefix_list) + prefixes = [] + for result in self._list_objects_v2_pages( + bucket, prefix=f"{base}/{organization}/{dataset}/", delimiter="/", max_keys=1000 + ): + prefixes.extend(result.prefix_list) + return sorted(s for p in prefixes if not (s := self._last_segment(p)).startswith(".")) def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]: orgs = self.list_organizations() @@ -107,20 +215,31 @@ def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: if organization: org_prefixes = [f"{base}/{organization}/"] else: - result = bucket.list_objects_v2(prefix=f"{base}/", delimiter="/", max_keys=1000) - org_prefixes = result.prefix_list + org_prefixes = [] + for result in self._list_objects_v2_pages(bucket, prefix=f"{base}/", delimiter="/", max_keys=1000): + org_prefixes.extend(result.prefix_list) datasets: list[DatasetSpec] = [] for org_prefix in org_prefixes: org = self._last_segment(org_prefix) + if org.startswith("."): + continue - result = bucket.list_objects_v2(prefix=org_prefix, delimiter="/", max_keys=1000) - for name_prefix in result.prefix_list: + name_prefixes = [] + for result in self._list_objects_v2_pages(bucket, prefix=org_prefix, delimiter="/", max_keys=1000): + name_prefixes.extend(result.prefix_list) + for name_prefix in name_prefixes: name = self._last_segment(name_prefix) + if name.startswith("."): + continue - result2 = bucket.list_objects_v2(prefix=name_prefix, delimiter="/", max_keys=1000) - for split_prefix in result2.prefix_list: + split_prefixes = [] + for result2 in self._list_objects_v2_pages(bucket, prefix=name_prefix, delimiter="/", max_keys=1000): + split_prefixes.extend(result2.prefix_list) + for split_prefix in split_prefixes: split = self._last_segment(split_prefix) + if split.startswith("."): + continue task_ids = self._extract_tasks_from_split(bucket, split_prefix) datasets.append( @@ -133,10 +252,22 @@ def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]: return datasets - def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test") -> DatasetSpec | None: + def list_dataset_tasks( + self, + organization: str, + dataset: str, + split: str = "test", + *, + offset: int = 0, + limit: int | None = None, + task_filter: str | None = None, + ) -> DatasetSpec | None: bucket = self._build_bucket() split_prefix = f"{self._build_prefix(organization, dataset, split)}/" - task_ids = self._extract_tasks_from_split(bucket, split_prefix) + max_items = offset + limit if limit is not None else None + task_ids = self._extract_tasks_from_split( + bucket, split_prefix, max_items=max_items, task_filter=task_filter + ) if not task_ids: return None @@ -144,13 +275,66 @@ def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test return DatasetSpec( id=f"{organization}/{dataset}", split=split, - task_ids=task_ids, + task_ids=task_ids[offset:max_items], ) + def list_task_files( + self, organization: str, dataset: str, split: str, task_id: str, path: str = "" + ) -> list[TaskFile]: + bucket = self._build_bucket() + task_prefix = self._build_task_prefix(organization, dataset, split, task_id) + relative_prefix = self._normalize_task_path(path, allow_empty=True, directory=bool(path)) if path else "" + + files: list[TaskFile] = [] + for result in self._list_objects_v2_pages(bucket, prefix=f"{task_prefix}{relative_prefix}", max_keys=1000): + for obj in result.object_list: + key = obj.key + if key.endswith("/") or not key.startswith(task_prefix): + continue + relative = key[len(task_prefix) :] + if not relative: + continue + files.append(TaskFile(path=relative, size=getattr(obj, "size", None))) + if not files and not relative_prefix: + split_prefix = f"{self._build_prefix(organization, dataset, split)}/" + for result in self._list_objects_v2_pages(bucket, prefix=f"{split_prefix}{task_id}", max_keys=1000): + for obj in result.object_list: + key = obj.key + if key.endswith("/") or not key.startswith(split_prefix): + continue + relative = key[len(split_prefix) :] + if "/" in relative: + continue + name = relative.rsplit(".", 1)[0] if "." in relative else relative + if name == task_id: + files.append(TaskFile(path=relative, size=getattr(obj, "size", None))) + return sorted(files, key=lambda f: f.path) + + def get_task_file(self, organization: str, dataset: str, split: str, task_id: str, path: str) -> bytes | None: + bucket = self._build_bucket() + relative = self._normalize_task_path(path) + key = f"{self._build_task_prefix(organization, dataset, split, task_id)}{relative}" + try: + return bucket.get_object(key).read() + except (oss2.exceptions.NoSuchKey, oss2.exceptions.NotFound): + if "/" not in relative: + direct_name = relative.rsplit(".", 1)[0] if "." in relative else relative + if direct_name == task_id: + direct_key = f"{self._build_prefix(organization, dataset, split)}/{relative}" + try: + return bucket.get_object(direct_key).read() + except (oss2.exceptions.NoSuchKey, oss2.exceptions.NotFound): + return None + return None + def _task_exists(self, bucket: oss2.Bucket, task_prefix: str) -> bool: result = bucket.list_objects_v2(prefix=task_prefix, max_keys=1) return len(result.object_list) > 0 + def _object_exists(self, bucket: oss2.Bucket, key: str) -> bool: + result = bucket.list_objects_v2(prefix=key, max_keys=1) + return any(obj.key == key for obj in result.object_list) + def _upload_task( self, bucket: oss2.Bucket, @@ -173,6 +357,23 @@ def _upload_task( bucket.put_object(key, file.read_bytes()) return len(files) + def _upload_task_file( + self, + bucket: oss2.Bucket, + org: str, + name: str, + split: str, + task_file: Path, + overwrite: bool, + ) -> int | None: + key = f"{self._build_prefix(org, name, split)}/{task_file.name}" + + if not overwrite and self._object_exists(bucket, key): + return None + + bucket.put_object(key, task_file.read_bytes()) + return 1 + def upload_dataset( self, source: LocalDatasetConfig, @@ -185,16 +386,25 @@ def upload_dataset( local_dir = source.path bucket = self._build_bucket() - task_dirs = sorted([d for d in local_dir.iterdir() if d.is_dir()]) + if local_dir.is_file(): + upload_items = [local_dir] + else: + upload_items = sorted([p for p in local_dir.iterdir() if p.is_dir() or p.is_file()]) raw: dict[str, int | None | Exception] = {} with ThreadPoolExecutor(max_workers=concurrency) as executor: - futures = {executor.submit(self._upload_task, bucket, org, name, split, d, overwrite): d for d in task_dirs} - for future, task_dir in futures.items(): + futures = {} + for item in upload_items: + if item.is_dir(): + future = executor.submit(self._upload_task, bucket, org, name, split, item, overwrite) + else: + future = executor.submit(self._upload_task_file, bucket, org, name, split, item, overwrite) + futures[future] = item + for future, item in futures.items(): try: - raw[task_dir.name] = future.result() + raw[item.name] = future.result() except Exception as exc: - raw[task_dir.name] = exc + raw[item.name] = exc uploaded = skipped = failed = 0 for task_id in sorted(raw): diff --git a/tests/unit/datasets/test_client.py b/tests/unit/datasets/test_client.py index f811763120..23eb6dbcfe 100644 --- a/tests/unit/datasets/test_client.py +++ b/tests/unit/datasets/test_client.py @@ -2,7 +2,7 @@ from rock.sdk.bench.models.job.config import LocalDatasetConfig, OssRegistryInfo, RegistryDatasetConfig from rock.sdk.envhub.datasets.client import DatasetClient -from rock.sdk.envhub.datasets.models import DatasetSpec, UploadResult +from rock.sdk.envhub.datasets.models import DatasetSpec, TaskFile, UploadResult def make_registry_info(): @@ -40,7 +40,7 @@ def test_dataset_client_list_tasks_delegates_to_registry_with_default_split(): with patch.object(client._registry, "list_dataset_tasks", return_value=expected) as mock_list_tasks: result = client.list_dataset_tasks("qwen", "bench") - mock_list_tasks.assert_called_once_with("qwen", "bench", "test") + mock_list_tasks.assert_called_once_with("qwen", "bench", "test", offset=0, limit=None, task_filter=None) assert result == expected @@ -81,3 +81,24 @@ def test_dataset_client_list_dataset_splits_delegates(): result = client.list_dataset_splits("qwen", "bench") m.assert_called_once_with("qwen", "bench") assert result == ["test", "train"] + + +def test_dataset_client_list_task_files_delegates(): + client = DatasetClient(make_registry_info()) + expected = [TaskFile(path="tests/test_api.py", size=10)] + + with patch.object(client._registry, "list_task_files", return_value=expected) as m: + result = client.list_task_files("qwen", "bench", "test", "task-001", "tests") + + m.assert_called_once_with("qwen", "bench", "test", "task-001", "tests") + assert result == expected + + +def test_dataset_client_get_task_file_delegates(): + client = DatasetClient(make_registry_info()) + + with patch.object(client._registry, "get_task_file", return_value=b"content") as m: + result = client.get_task_file("qwen", "bench", "test", "task-001", "task.yaml") + + m.assert_called_once_with("qwen", "bench", "test", "task-001", "task.yaml") + assert result == b"content" diff --git a/tests/unit/datasets/test_datasets_command.py b/tests/unit/datasets/test_datasets_command.py index 7fc17561f9..d840ab446c 100644 --- a/tests/unit/datasets/test_datasets_command.py +++ b/tests/unit/datasets/test_datasets_command.py @@ -1,12 +1,13 @@ import argparse import asyncio +import json from unittest.mock import AsyncMock, MagicMock, patch import pytest from rock.cli.command.datasets import DatasetsCommand from rock.sdk.bench.models.job.config import OssRegistryInfo -from rock.sdk.envhub.datasets.models import DatasetSpec +from rock.sdk.envhub.datasets.models import DatasetSpec, TaskFile, UploadResult def make_base_args(**kwargs): @@ -24,6 +25,7 @@ def make_base_args(**kwargs): depth=2, offset=0, limit=None, + output=None, ) for k, v in kwargs.items(): setattr(args, k, v) @@ -110,6 +112,27 @@ def test_tasks_parser_defaults_split_offset_limit(): assert ns.split == "test" assert ns.offset == 0 assert ns.limit is None + assert ns.output is None + + +@pytest.mark.parametrize( + "flag", + ["-o", "--output", "--ouput"], +) +def test_datasets_subcommands_accept_json_output(flag): + parser = _build_parser() + + ns = parser.parse_args(["datasets", "tasks", "--org", "qwen", "--dataset", "my-bench", flag, "json"]) + + assert ns.output == "json" + + +def test_datasets_parser_accepts_json_output_before_subcommand(): + parser = _build_parser() + + ns = parser.parse_args(["datasets", "-o", "json", "list"]) + + assert ns.output == "json" @pytest.mark.parametrize( @@ -156,6 +179,259 @@ def test_arun_dispatches_tasks(): mock_tasks.assert_awaited_once_with(args) +def test_arun_dispatches_fs_aliases(): + for command in ("fs", "files"): + cmd = DatasetsCommand() + args = make_base_args(datasets_command=command, fs_command="ls", org="qwen", dataset="my-bench", task="task-001") + + with patch.object(DatasetsCommand, "_fs", new_callable=AsyncMock, create=True) as mock_fs: + asyncio.run(cmd.arun(args)) + + mock_fs.assert_awaited_once_with(args) + + +@pytest.mark.parametrize("command", ["fs", "files"]) +def test_fs_parser_accepts_ls_get_download(command): + parser = _build_parser() + + ls_ns = parser.parse_args( + ["datasets", command, "ls", "--org", "qwen", "--dataset", "my-bench", "--task", "task-001"] + ) + get_ns = parser.parse_args( + [ + "datasets", + command, + "get", + "--org", + "qwen", + "--dataset", + "my-bench", + "--task", + "task-001", + "--path", + "tests/input.json", + ] + ) + download_ns = parser.parse_args( + [ + "datasets", + command, + "download", + "--org", + "qwen", + "--dataset", + "my-bench", + "--task", + "task-001", + "--path", + "tests/", + "--dest", + "./tests", + ] + ) + + assert ls_ns.datasets_command == command + assert ls_ns.fs_command == "ls" + assert ls_ns.split == "test" + assert ls_ns.path == "" + assert get_ns.fs_command == "get" + assert get_ns.path == "tests/input.json" + assert download_ns.fs_command == "download" + assert download_ns.dest == "./tests" + + +def test_fs_ls_outputs_recursive_task_files(capsys): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="fs", + fs_command="ls", + org="qwen", + dataset="my-bench", + split="test", + task="task-001", + path="tests", + ) + mock_client = MagicMock() + mock_client.list_task_files.return_value = [ + TaskFile(path="tests/test_api.py", size=10), + TaskFile(path="tests/fixtures/input.json", size=20), + ] + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._fs(args)) + + mock_client.list_task_files.assert_called_once_with("qwen", "my-bench", "test", "task-001", "tests") + out = capsys.readouterr().out + assert "tests/test_api.py" in out + assert "tests/fixtures/input.json" in out + + +def test_fs_ls_outputs_json(capsys): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="fs", + fs_command="ls", + org="qwen", + dataset="my-bench", + split="test", + task="task-001", + path="", + output="json", + ) + mock_client = MagicMock() + mock_client.list_task_files.return_value = [TaskFile(path="task.yaml", size=42)] + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._fs(args)) + + assert json.loads(capsys.readouterr().out) == { + "dataset": "qwen/my-bench", + "split": "test", + "task": "task-001", + "path": "", + "files": [{"path": "task.yaml", "size": 42}], + } + + +def test_fs_get_writes_file_content(capsys): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="fs", + fs_command="get", + org="qwen", + dataset="my-bench", + split="test", + task="task-001", + path="task.yaml", + ) + mock_client = MagicMock() + mock_client.get_task_file.return_value = b"instruction: fix bug\n" + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._fs(args)) + + mock_client.get_task_file.assert_called_once_with("qwen", "my-bench", "test", "task-001", "task.yaml") + assert capsys.readouterr().out == "instruction: fix bug\n" + + +def test_fs_get_without_path_reads_single_task_file(capsys): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="fs", + fs_command="get", + org="qwen", + dataset="my-bench", + split="test", + task="task-001", + path=None, + ) + mock_client = MagicMock() + mock_client.list_task_files.return_value = [TaskFile(path="task-001.json", size=42)] + mock_client.get_task_file.return_value = b'{"instruction": "fix bug"}\n' + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._fs(args)) + + mock_client.list_task_files.assert_called_once_with("qwen", "my-bench", "test", "task-001", "") + mock_client.get_task_file.assert_called_once_with("qwen", "my-bench", "test", "task-001", "task-001.json") + assert capsys.readouterr().out == '{"instruction": "fix bug"}\n' + + +def test_fs_get_without_path_requires_path_when_task_has_multiple_files(): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="fs", + fs_command="get", + org="qwen", + dataset="my-bench", + split="test", + task="task-001", + path=None, + ) + mock_client = MagicMock() + mock_client.list_task_files.return_value = [ + TaskFile(path="task.yaml", size=42), + TaskFile(path="tests/test_api.py", size=10), + ] + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + with pytest.raises(ValueError, match="--path is required"): + asyncio.run(cmd._fs(args)) + + +def test_fs_download_file_writes_destination(tmp_path): + cmd = DatasetsCommand() + dest = tmp_path / "task.yaml" + args = make_base_args( + datasets_command="fs", + fs_command="download", + org="qwen", + dataset="my-bench", + split="test", + task="task-001", + path="task.yaml", + dest=str(dest), + ) + mock_client = MagicMock() + mock_client.get_task_file.return_value = b"instruction: fix bug\n" + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._fs(args)) + + assert dest.read_bytes() == b"instruction: fix bug\n" + mock_client.list_task_files.assert_not_called() + + +def test_fs_download_directory_strips_requested_prefix(tmp_path): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="fs", + fs_command="download", + org="qwen", + dataset="my-bench", + split="test", + task="task-001", + path="tests/", + dest=str(tmp_path / "downloaded-tests"), + ) + mock_client = MagicMock() + mock_client.get_task_file.return_value = None + mock_client.list_task_files.return_value = [ + TaskFile(path="tests/test_api.py", size=10), + TaskFile(path="tests/fixtures/input.json", size=20), + ] + mock_client.get_task_file.side_effect = [None, b"assert True\n", b"{}\n"] + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._fs(args)) + + assert (tmp_path / "downloaded-tests" / "test_api.py").read_text() == "assert True\n" + assert (tmp_path / "downloaded-tests" / "fixtures" / "input.json").read_text() == "{}\n" + + +def test_fs_rejects_unsafe_relative_path(): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="fs", + fs_command="get", + org="qwen", + dataset="my-bench", + split="test", + task="task-001", + path="../other-task/task.yaml", + ) + + with pytest.raises(ValueError, match="relative task path"): + asyncio.run(cmd._fs(args)) + + def test_tasks_outputs_paginated_results(capsys): cmd = DatasetsCommand() args = make_base_args( @@ -170,25 +446,176 @@ def test_tasks_outputs_paginated_results(capsys): mock_client.list_dataset_tasks.return_value = DatasetSpec( id="qwen/my-bench", split="test", - task_ids=["task-001", "task-002", "task-003"], + task_ids=["task-002", "task-003"], ) with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): asyncio.run(cmd._tasks(args)) - mock_client.list_dataset_tasks.assert_called_once_with("qwen", "my-bench", "test") + mock_client.list_dataset_tasks.assert_called_once_with( + "qwen", "my-bench", "test", offset=1, limit=2, task_filter=None + ) out = capsys.readouterr().out assert "Dataset: qwen/my-bench" in out assert "Split: test" in out assert "task-002" in out assert "task-003" in out assert "task-001" not in out - assert "Total: 3" in out assert "Shown: 2" in out assert "#Task name" in out +def test_list_outputs_json(capsys): + cmd = DatasetsCommand() + args = make_base_args(datasets_command="list", output="json") + mock_client = MagicMock() + mock_client.list_datasets.return_value = [ + DatasetSpec(id="qwen/bench-b", split="test", task_ids=["b-1"]), + DatasetSpec(id="qwen/bench-a", split="train", task_ids=["a-1", "a-2"]), + ] + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._list(args)) + + data = json.loads(capsys.readouterr().out) + assert data == { + "datasets": [ + {"id": "qwen/bench-a", "split": "train", "task_ids": ["a-1", "a-2"], "task_count": 2}, + {"id": "qwen/bench-b", "split": "test", "task_ids": ["b-1"], "task_count": 1}, + ] + } + + +def test_tasks_outputs_json_with_pagination(capsys): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="tasks", + org="qwen", + dataset="my-bench", + split="test", + offset=1, + limit=2, + output="json", + ) + mock_client = MagicMock() + mock_client.list_dataset_tasks.return_value = DatasetSpec( + id="qwen/my-bench", + split="test", + task_ids=["task-002", "task-003"], + ) + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._tasks(args)) + + data = json.loads(capsys.readouterr().out) + assert data == { + "dataset": "qwen/my-bench", + "split": "test", + "total": 2, + "offset": 1, + "limit": 2, + "task_ids": ["task-002", "task-003"], + } + + +def test_tasks_outputs_empty_json_when_not_found(capsys): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="tasks", + org="qwen", + dataset="my-bench", + split="test", + offset=0, + limit=None, + output="json", + ) + mock_client = MagicMock() + mock_client.list_dataset_tasks.return_value = None + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._tasks(args)) + + data = json.loads(capsys.readouterr().out) + assert data == { + "dataset": "qwen/my-bench", + "split": "test", + "total": 0, + "offset": 0, + "limit": None, + "task_ids": [], + } + + +def test_upload_outputs_json(capsys, tmp_path): + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="upload", + org="qwen", + dataset="my-bench", + split="test", + dir=str(tmp_path), + overwrite=False, + concurrency=4, + output="json", + ) + mock_client = MagicMock() + mock_client.upload_dataset.return_value = UploadResult( + id="qwen/my-bench", + split="test", + uploaded=2, + skipped=1, + failed=0, + ) + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._upload(args)) + + data = json.loads(capsys.readouterr().out) + assert data == { + "id": "qwen/my-bench", + "split": "test", + "uploaded": 2, + "skipped": 1, + "failed": 0, + } + + +def test_upload_accepts_single_file_source(capsys, tmp_path): + task_file = tmp_path / "task-001.json" + task_file.write_text("{}") + cmd = DatasetsCommand() + args = make_base_args( + datasets_command="upload", + org="qwen", + dataset="my-bench", + split="test", + dir=str(task_file), + overwrite=False, + concurrency=4, + ) + mock_client = MagicMock() + mock_client.upload_dataset.return_value = UploadResult( + id="qwen/my-bench", + split="test", + uploaded=1, + skipped=0, + failed=0, + ) + + with patch.object(cmd, "_build_oss_registry_info", return_value=make_registry_info()): + with patch("rock.cli.command.datasets.DatasetClient", return_value=mock_client): + asyncio.run(cmd._upload(args)) + + source = mock_client.upload_dataset.call_args.args[0] + assert source.path == task_file + assert "Uploading to oss://b/datasets/qwen/my-bench/test/" in capsys.readouterr().out + + def test_tasks_prints_no_tasks_message_when_not_found(capsys): cmd = DatasetsCommand() args = make_base_args( From 2271f9fbea85175a12a0bf46295d0d1ce84d6187 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Thu, 4 Jun 2026 16:22:29 +0800 Subject: [PATCH 2/4] fix(datasets): add missing TaskFile dataclass to models.py The client.py imports TaskFile but its definition was not included in the previous commit, causing ImportError in CI. Co-Authored-By: Claude Code AI-Model: claude-opus-4-6 AI-Contributed/Feature: 6/6 AI-Contributed/UT: 0/0 --- rock/sdk/envhub/datasets/models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/rock/sdk/envhub/datasets/models.py b/rock/sdk/envhub/datasets/models.py index 264dd1e1f7..aabcb1d856 100644 --- a/rock/sdk/envhub/datasets/models.py +++ b/rock/sdk/envhub/datasets/models.py @@ -8,6 +8,12 @@ class DatasetSpec: task_ids: list[str] = field(default_factory=list) +@dataclass +class TaskFile: + path: str + size: int | None = None + + @dataclass class UploadResult: id: str # "{organization}/{dataset_name}" From bc62a224bfb1607886b2065d101e41f59625a6a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Thu, 4 Jun 2026 18:57:27 +0800 Subject: [PATCH 3/4] ci: retrigger CI From d81b94233af91b81e9e2301a21e7bb06b6faa46d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=82=E6=B4=9B?= Date: Tue, 9 Jun 2026 23:49:02 +0800 Subject: [PATCH 4/4] fix(datasets): bound OSS pagination to prevent infinite loop hang The new pagination loop in _extract_tasks_from_split used `while True` and read is_truncated/next_continuation_token via getattr, which on a MagicMock always returns a truthy value. Tests using `return_value` (and any real OSS response with a non-advancing token) spun forever, starving CPU/memory and crashing the self-hosted CI runner and local machines. Add a hard page budget (_MAX_PAGINATION_PAGES) and a non-advancing-token guard to both pagination paths; make the test helper set explicit pagination fields so mocks terminate correctly. refs #1063 AI-Model: claude-opus-4-8 AI-Contributed/Feature: 39/39 AI-Contributed/UT: 31/31 --- rock/sdk/envhub/datasets/registry/oss.py | 39 +++++++++++++++++++----- tests/unit/datasets/test_oss_registry.py | 31 ++++++++++++++++++- 2 files changed, 61 insertions(+), 9 deletions(-) diff --git a/rock/sdk/envhub/datasets/registry/oss.py b/rock/sdk/envhub/datasets/registry/oss.py index 9c676eef86..d0d34dd635 100644 --- a/rock/sdk/envhub/datasets/registry/oss.py +++ b/rock/sdk/envhub/datasets/registry/oss.py @@ -13,6 +13,11 @@ logger = init_logger(__name__) +# Hard upper bound on pagination pages. At 1000 keys/page this covers 10M keys, +# far beyond any real split, while guaranteeing the loop always terminates even +# if OSS (or a mock) keeps reporting truncation with a non-advancing token. +_MAX_PAGINATION_PAGES = 10_000 + @dataclass class _PaginationCache: @@ -73,7 +78,7 @@ def _normalize_task_path(path: str, *, allow_empty: bool = False, directory: boo @staticmethod def _list_objects_v2_pages(bucket: oss2.Bucket, **kwargs): token = "" - while True: + for _ in range(_MAX_PAGINATION_PAGES): page_kwargs = dict(kwargs) if token: page_kwargs["continuation_token"] = token @@ -81,9 +86,11 @@ def _list_objects_v2_pages(bucket: oss2.Bucket, **kwargs): yield result if not getattr(result, "is_truncated", False): break - token = getattr(result, "next_continuation_token", "") or "" - if not token: + next_token = getattr(result, "next_continuation_token", "") or "" + # Stop if the token is empty or fails to advance (would loop forever). + if not next_token or next_token == token: break + token = next_token def _extract_tasks_from_split( self, @@ -120,7 +127,7 @@ def _extract_tasks_from_split( tasks_set = set() token = "" - while True: + for _ in range(_MAX_PAGINATION_PAGES): mk = 1000 if max_items is not None: mk = min(1000, max(max_items - len(tasks_set), 100)) @@ -140,7 +147,7 @@ def _extract_tasks_from_split( key = obj.key if key.endswith("/"): continue - relative = key[len(split_prefix):] + relative = key[len(split_prefix) :] if "/" in relative or relative.startswith("."): continue name = relative.rsplit(".", 1)[0] if "." in relative else relative @@ -165,8 +172,26 @@ def _extract_tasks_from_split( cache.is_exhausted = False return sorted_tasks[:max_items] + # Guard against a non-advancing continuation token: if OSS keeps + # returning the same token we would otherwise spin forever. + if next_token == token: + break token = next_token + # Page budget exhausted (or token stopped advancing): return what we + # have rather than looping forever. + logger.warning( + "Pagination stopped after %d pages for prefix %r; returning partial results", + _MAX_PAGINATION_PAGES, + query_prefix, + ) + sorted_tasks = sorted(tasks_set) + cache.split_prefix = query_prefix + cache.tasks = sorted_tasks + cache.continuation_token = "" + cache.is_exhausted = True + return sorted_tasks[:max_items] if max_items else sorted_tasks + def list_organizations(self) -> list[str]: bucket = self._build_bucket() base = self._registry.oss_dataset_path or "datasets" @@ -265,9 +290,7 @@ def list_dataset_tasks( bucket = self._build_bucket() split_prefix = f"{self._build_prefix(organization, dataset, split)}/" max_items = offset + limit if limit is not None else None - task_ids = self._extract_tasks_from_split( - bucket, split_prefix, max_items=max_items, task_filter=task_filter - ) + task_ids = self._extract_tasks_from_split(bucket, split_prefix, max_items=max_items, task_filter=task_filter) if not task_ids: return None diff --git a/tests/unit/datasets/test_oss_registry.py b/tests/unit/datasets/test_oss_registry.py index fc3dd9be70..9f2e156e0f 100644 --- a/tests/unit/datasets/test_oss_registry.py +++ b/tests/unit/datasets/test_oss_registry.py @@ -13,10 +13,14 @@ def make_registry_info(): ) -def make_list_result(prefixes=None, objects=None): +def make_list_result(prefixes=None, objects=None, *, is_truncated=False, next_continuation_token=""): result = MagicMock() result.prefix_list = prefixes or [] result.object_list = objects or [] + # Explicitly set pagination fields: a bare MagicMock would return a truthy + # mock for these attributes, making the pagination loop never terminate. + result.is_truncated = is_truncated + result.next_continuation_token = next_continuation_token return result @@ -453,3 +457,28 @@ def test_list_all_datasets_empty_when_no_orgs(): registry = OssDatasetRegistry(make_registry_info()) with patch.object(registry, "list_organizations", return_value=[]): assert registry.list_all_datasets() == [] + + +# --------------------------------------------------------------------------- +# pagination safety: must never loop forever even if OSS keeps reporting +# truncation with a non-advancing continuation token (regression for the +# self-hosted-runner / local hang caused by an unbounded `while True`). +# --------------------------------------------------------------------------- + + +def test_extract_tasks_terminates_when_token_never_advances(): + registry = OssDatasetRegistry(make_registry_info()) + mock_bucket = MagicMock() + + page = make_list_result(prefixes=["datasets/qwen/my-bench/test/task-001/"]) + page.is_truncated = True + page.next_continuation_token = "stuck-token" # never changes -> would loop forever + mock_bucket.list_objects_v2.return_value = page + + with patch.object(registry, "_build_bucket", return_value=mock_bucket): + spec = registry.list_dataset_tasks("qwen", "my-bench", "test") + + # Must terminate quickly: a non-advancing token stops the loop on the + # second page instead of spinning forever. + assert spec is not None + assert mock_bucket.list_objects_v2.call_count <= 2