Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 235 additions & 12 deletions rock/cli/command/datasets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import argparse
import contextlib
import json
import sys
from dataclasses import asdict
from pathlib import Path

from rock.cli.command.command import Command
Expand All @@ -27,6 +30,46 @@ def _positive_int(value: str) -> int:
return ivalue


def _is_json_output(args: argparse.Namespace) -> bool:
    """Report whether the parsed CLI arguments ask for JSON output.

    A namespace that has no ``output`` attribute at all is treated the same as
    one whose ``output`` is ``None``: JSON was not requested.
    """
    # NOTE(review): a collaborator commented on this line asking for these
    # module-level helper functions to be wrapped into the class.
    requested_format = getattr(args, "output", None)
    return requested_format == "json"


def _print_json(data: dict | list) -> None:
    """Write ``data`` to stdout as JSON with a 2-space indent.

    Non-ASCII characters are emitted literally rather than as ``\\uXXXX`` escapes.
    """
    rendered = json.dumps(data, ensure_ascii=False, indent=2)
    print(rendered)


def _dataset_to_dict(dataset) -> dict:
    """Convert a dataset dataclass instance into a plain dict for JSON output.

    The result holds every dataclass field (via ``dataclasses.asdict``) plus a
    ``task_count`` entry set to the length of ``dataset.task_ids``.
    """
    return {**asdict(dataset), "task_count": len(dataset.task_ids)}


def _normalize_task_path(path: str, *, allow_empty: bool = False, directory: bool = False) -> str:
    """Normalize a task-relative path to forward-slash form.

    Backslashes are converted to ``/``; empty and ``.`` segments are dropped.

    Args:
        path: Relative path as typed by the user. A falsy value is treated as "".
        allow_empty: When true, a path with no remaining segments yields "".
        directory: When true, the result always ends with ``/``.

    Returns:
        The cleaned path. It ends with ``/`` when ``directory`` is set or when
        the input (after backslash conversion) ended with ``/``.

    Raises:
        ValueError: If the path is absolute, contains a ``..`` segment, or is
            empty while ``allow_empty`` is false.
    """
    unified = (path or "").replace("\\", "/")
    if unified[:1] == "/":
        raise ValueError("relative task path must not be absolute")

    segments = []
    for segment in unified.split("/"):
        if segment in ("", "."):
            continue
        segments.append(segment)

    if ".." in segments:
        raise ValueError("relative task path must not contain '..'")
    if not segments:
        if not allow_empty:
            raise ValueError("relative task path is required")
        return ""

    wants_trailing_slash = directory or unified.endswith("/")
    return "/".join(segments) + ("/" if wants_trailing_slash else "")


def _write_stdout_bytes(content: bytes) -> None:
    """Write file content to stdout, as text when possible and raw bytes otherwise.

    Content that decodes as UTF-8 is written through the text layer of
    ``sys.stdout``. Anything else is written to ``sys.stdout.buffer`` unchanged.

    Args:
        content: The bytes to emit.
    """
    try:
        text = content.decode("utf-8")
    except UnicodeDecodeError:
        text = None

    if text is not None:
        sys.stdout.write(text)
        return

    binary_out = getattr(sys.stdout, "buffer", None)
    if binary_out is None:
        # stdout was replaced by a text-only stream (e.g. io.StringIO); there is
        # no byte channel, so emit a lossy text rendering instead of crashing.
        sys.stdout.write(content.decode("utf-8", errors="replace"))
        return

    # Flush pending text first so the raw bytes land after anything already
    # printed, then flush the bytes so later text cannot overtake them.
    sys.stdout.flush()
    binary_out.write(content)
    binary_out.flush()


class DatasetsCommand(Command):
name = "datasets"

Expand All @@ -39,6 +82,8 @@ async def arun(self, args: argparse.Namespace) -> None:
await self._splits(args)
elif args.datasets_command == "upload":
await self._upload(args)
elif args.datasets_command in ("fs", "files"):
await self._fs(args)
else:
raise ValueError(f"Unknown datasets command: {args.datasets_command}")

Expand All @@ -62,6 +107,17 @@ async def _list(self, args: argparse.Namespace) -> None:
registry_info = self._build_oss_registry_info(args)
client = DatasetClient(registry_info)

if _is_json_output(args):
if getattr(args, "depth", None) == 1:
_print_json({"organizations": client.list_organizations()})
return
datasets = sorted(
client.list_datasets(getattr(args, "org", None)),
key=lambda d: (d.id, d.split),
)
_print_json({"datasets": [_dataset_to_dict(d) for d in datasets]})
return

if getattr(args, "org", None):
datasets = client.list_org_datasets(args.org)
pairs = [(args.org, d) for d in datasets]
Expand Down Expand Up @@ -107,24 +163,45 @@ def _render_orgs(orgs: list[str]) -> None:
async def _tasks(self, args: argparse.Namespace) -> None:
registry_info = self._build_oss_registry_info(args)
client = DatasetClient(registry_info)
spec = client.list_dataset_tasks(args.org, args.dataset, args.split)
spec = client.list_dataset_tasks(
args.org, args.dataset, args.split,
offset=args.offset, limit=args.limit, task_filter=getattr(args, "filter", None),
)

if spec is None or not spec.task_ids:
if _is_json_output(args):
_print_json({
"dataset": f"{args.org}/{args.dataset}",
"split": args.split,
"total": 0,
"offset": args.offset,
"limit": args.limit,
"task_ids": [],
})
return
print(f"No tasks found for dataset '{args.org}/{args.dataset}' split '{args.split}'.")
return

total = len(spec.task_ids)
start = args.offset
end = start + args.limit if args.limit is not None else None
shown_task_ids = spec.task_ids[start:end]
shown_task_ids = spec.task_ids

if _is_json_output(args):
_print_json({
"dataset": spec.id,
"split": spec.split,
"total": len(shown_task_ids),
"offset": args.offset,
"limit": args.limit,
"task_ids": shown_task_ids,
})
return

if not shown_task_ids:
print("No tasks found after applying offset/limit.")
return

print()
print("=" * 80)
print(f"Dataset: {spec.id} Split: {spec.split} Total: {total} Shown: {len(shown_task_ids)}")
print(f"Dataset: {spec.id} Split: {spec.split} Shown: {len(shown_task_ids)}")
print("=" * 80)
print("#Task name")
print("-" * 10)
Expand All @@ -150,8 +227,8 @@ async def _splits(self, args: argparse.Namespace) -> None:

async def _upload(self, args: argparse.Namespace) -> None:
local_dir = Path(args.dir)
if not local_dir.is_dir():
raise ValueError(f"--dir '{local_dir}' does not exist or is not a directory")
if not local_dir.exists():
raise ValueError(f"--dir '{local_dir}' does not exist")

registry_info = self._build_oss_registry_info(args)
source = LocalDatasetConfig(path=local_dir)
Expand All @@ -163,20 +240,131 @@ async def _upload(self, args: argparse.Namespace) -> None:
)

base = registry_info.oss_dataset_path or "datasets"
print(f"Uploading to oss://{registry_info.oss_bucket}/{base}/{args.org}/{args.dataset}/{args.split}/")
if not _is_json_output(args):
print(f"Uploading to oss://{registry_info.oss_bucket}/{base}/{args.org}/{args.dataset}/{args.split}/")

client = DatasetClient(registry_info)
result = client.upload_dataset(source, target, concurrency=args.concurrency)
if _is_json_output(args):
with contextlib.redirect_stdout(sys.stderr):
result = client.upload_dataset(source, target, concurrency=args.concurrency)
_print_json(asdict(result))
else:
result = client.upload_dataset(source, target, concurrency=args.concurrency)
print(f"\nDone: {result.uploaded} uploaded, {result.skipped} skipped, {result.failed} failed")

print(f"\nDone: {result.uploaded} uploaded, {result.skipped} skipped, {result.failed} failed")
if result.failed > 0:
sys.exit(1)

async def _fs(self, args: argparse.Namespace) -> None:
if args.fs_command == "ls":
await self._fs_ls(args)
elif args.fs_command == "get":
await self._fs_get(args)
elif args.fs_command == "download":
await self._fs_download(args)
else:
raise ValueError(f"Unknown datasets fs command: {args.fs_command}")

async def _fs_ls(self, args: argparse.Namespace) -> None:
    """List the files of one task, optionally restricted to ``args.path``.

    In JSON mode a single object with the dataset, split, task, normalized
    path and file entries is printed; otherwise one file path per line.
    """
    listing_path = ""
    if args.path:
        # args.path is non-empty here, so the normalized form is always
        # marked as a directory (trailing slash).
        listing_path = _normalize_task_path(args.path, allow_empty=True, directory=True)

    client = DatasetClient(self._build_oss_registry_info(args))
    # The registry is queried without the trailing slash; the JSON payload
    # below reports the slash-terminated form.
    entries = client.list_task_files(args.org, args.dataset, args.split, args.task, listing_path.rstrip("/"))

    if not _is_json_output(args):
        for entry in entries:
            print(entry.path)
        return

    _print_json({
        "dataset": f"{args.org}/{args.dataset}",
        "split": args.split,
        "task": args.task,
        "path": listing_path,
        "files": [asdict(entry) for entry in entries],
    })

async def _fs_get(self, args: argparse.Namespace) -> None:
    """Print the content of one task file to stdout.

    When ``--path`` is omitted the task must contain exactly one file, which
    is then used. In JSON mode the content is embedded in the printed object:
    UTF-8 text is embedded as-is, while content that is not valid UTF-8 is
    base64-encoded and flagged with ``"encoding": "base64"``.

    Raises:
        ValueError: If ``--path`` is omitted and the task does not contain
            exactly one file, or if the given path fails normalization.
        FileNotFoundError: If the client returns ``None`` for the file.
    """
    import base64  # local import: only needed for the binary JSON fallback

    path = _normalize_task_path(args.path) if args.path else None
    registry_info = self._build_oss_registry_info(args)
    client = DatasetClient(registry_info)
    if path is None:
        files = client.list_task_files(args.org, args.dataset, args.split, args.task, "")
        if len(files) != 1:
            raise ValueError("--path is required when task contains zero or multiple files")
        path = files[0].path
    content = client.get_task_file(args.org, args.dataset, args.split, args.task, path)
    if content is None:
        raise FileNotFoundError(f"Task file not found: {args.org}/{args.dataset}/{args.split}/{args.task}/{path}")

    if _is_json_output(args):
        payload = {
            "dataset": f"{args.org}/{args.dataset}",
            "split": args.split,
            "task": args.task,
            "path": path,
        }
        try:
            payload["content"] = content.decode("utf-8")
        except UnicodeDecodeError:
            # Binary files cannot be embedded as JSON text; the plain-output
            # path already tolerates them, so JSON mode should not crash either.
            payload["content"] = base64.b64encode(content).decode("ascii")
            payload["encoding"] = "base64"
        _print_json(payload)
        return

    _write_stdout_bytes(content)

async def _fs_download(self, args: argparse.Namespace) -> None:
    """Download one task file, or every file under a task directory, to ``--dest``.

    ``--path`` is first tried as a single file. If the client returns no
    content for it, the path is treated as a directory prefix and every
    listed file below it is written under ``--dest``, keeping the layout
    relative to that prefix. The written local paths are printed, as a JSON
    object in JSON mode or one per line otherwise.

    Raises:
        ValueError: If ``--path`` fails normalization.
        FileNotFoundError: If the path is neither a file nor a prefix with
            at least one listed file.
    """
    path = _normalize_task_path(args.path, directory=args.path.endswith("/"))
    dest = Path(args.dest)
    registry_info = self._build_oss_registry_info(args)
    client = DatasetClient(registry_info)

    content = client.get_task_file(args.org, args.dataset, args.split, args.task, path)
    if content is not None:
        # An existing directory as destination receives the file under its own name.
        target = dest / Path(path).name if dest.is_dir() else dest
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(content)
        if _is_json_output(args):
            _print_json({"downloaded": [str(target)]})
        else:
            print(str(target))
        return

    prefix = path if path.endswith("/") else f"{path}/"
    files = client.list_task_files(args.org, args.dataset, args.split, args.task, prefix.rstrip("/"))
    if not files:
        raise FileNotFoundError(f"Task path not found: {args.org}/{args.dataset}/{args.split}/{args.task}/{path}")

    downloaded: list[str] = []
    for file in files:
        relative = file.path[len(prefix) :] if file.path.startswith(prefix) else file.path
        if not relative:
            # An entry equal to the prefix itself would resolve to ``dest`` and
            # be written over the destination directory.
            # NOTE(review): presumably a directory placeholder object - confirm
            # whether list_task_files can return such entries.
            continue
        file_content = client.get_task_file(args.org, args.dataset, args.split, args.task, file.path)
        if file_content is None:
            # Listed but not retrievable: report it on stderr rather than
            # dropping it silently, and keep downloading the rest.
            print(f"warning: skipped '{file.path}': no content returned", file=sys.stderr)
            continue
        target = dest / relative
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(file_content)
        downloaded.append(str(target))

    if _is_json_output(args):
        _print_json({"downloaded": downloaded})
    else:
        for downloaded_path in downloaded:
            print(downloaded_path)

@staticmethod
async def add_parser_to(subparsers: argparse._SubParsersAction) -> None:
datasets_parser = subparsers.add_parser("datasets", description="Dataset operations on OSS")
datasets_parser.set_defaults(output=None)
datasets_subparsers = datasets_parser.add_subparsers(dest="datasets_command")

def add_output_arg(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"-o",
"--output",
"--ouput",
dest="output",
choices=["json"],
default=argparse.SUPPRESS,
help="Output format. Supported: json",
)

def add_oss_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--bucket", help="OSS bucket name (overrides config.ini)")
parser.add_argument("--endpoint", help="OSS endpoint URL (overrides config.ini)")
Expand All @@ -188,6 +376,8 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None:
)
parser.add_argument("--region", help="OSS region (overrides config.ini)")

add_output_arg(datasets_parser)

list_parser = datasets_subparsers.add_parser("list", help="List datasets in OSS registry")
list_group = list_parser.add_mutually_exclusive_group()
list_group.add_argument(
Expand All @@ -200,9 +390,11 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None:
list_group.add_argument("--org", help="List datasets under the given organization only")
add_oss_args(list_parser)
tasks_parser = datasets_subparsers.add_parser("tasks", help="List task IDs under one dataset split")
add_output_arg(tasks_parser)
tasks_parser.add_argument("--org", required=True, help="Organization name")
tasks_parser.add_argument("--dataset", required=True, help="Dataset name")
tasks_parser.add_argument("--split", default="test", help="Split name (default: test)")
tasks_parser.add_argument("--filter", default=None, help="Filter tasks by prefix (e.g. --filter 0xerr0r)")
tasks_parser.add_argument("--offset", type=_non_negative_int, default=0, help="Skip first N tasks")
tasks_parser.add_argument("--limit", type=_positive_int, default=None, help="Maximum number of tasks to show")
add_oss_args(tasks_parser)
Expand All @@ -213,10 +405,15 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None:
add_oss_args(splits_parser)

upload_parser = datasets_subparsers.add_parser("upload", help="Upload local task dirs to OSS")
add_output_arg(upload_parser)
upload_parser.add_argument("--org", required=True, help="Organization name")
upload_parser.add_argument("--dataset", required=True, help="Dataset name")
upload_parser.add_argument("--split", required=True, help="Split name (e.g. train, test, v1.0)")
upload_parser.add_argument("--dir", required=True, help="Local directory containing {task_id}/ subdirectories")
upload_parser.add_argument(
"--dir",
required=True,
help="Local dataset directory containing {task_id}/ subdirectories or direct task files, or one task file",
)
upload_parser.add_argument(
"--concurrency",
type=int,
Expand All @@ -229,3 +426,29 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None:
"--overwrite", action="store_true", help="Overwrite existing tasks in OSS (default: skip)"
)
add_oss_args(upload_parser)

fs_parser = datasets_subparsers.add_parser("fs", aliases=["files"], help="Inspect files under one task")
fs_subparsers = fs_parser.add_subparsers(dest="fs_command")

def add_task_fs_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--org", required=True, help="Organization name")
parser.add_argument("--dataset", required=True, help="Dataset name")
parser.add_argument("--split", default="test", help="Split name (default: test)")
parser.add_argument("--task", required=True, help="Task ID")
add_oss_args(parser)

ls_parser = fs_subparsers.add_parser("ls", help="List files under one task")
add_output_arg(ls_parser)
add_task_fs_args(ls_parser)
ls_parser.add_argument("--path", default="", help="Relative task directory path to list")

get_parser = fs_subparsers.add_parser("get", help="Print one task file to stdout")
add_output_arg(get_parser)
add_task_fs_args(get_parser)
get_parser.add_argument("--path", help="Relative task file path")

download_parser = fs_subparsers.add_parser("download", help="Download one task file or directory")
add_output_arg(download_parser)
add_task_fs_args(download_parser)
download_parser.add_argument("--path", required=True, help="Relative task file or directory path")
download_parser.add_argument("--dest", required=True, help="Local destination file or directory")
25 changes: 22 additions & 3 deletions rock/sdk/envhub/datasets/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from rock.sdk.bench.models.job.config import LocalDatasetConfig, OssRegistryInfo, RegistryDatasetConfig
from rock.sdk.envhub.datasets.models import DatasetSpec, UploadResult
from rock.sdk.envhub.datasets.models import DatasetSpec, TaskFile, UploadResult
from rock.sdk.envhub.datasets.registry.oss import OssDatasetRegistry


Expand All @@ -10,8 +10,19 @@ def __init__(self, registry: OssRegistryInfo) -> None:
def list_datasets(self, org: str | None = None) -> list[DatasetSpec]:
return self._registry.list_datasets(org)

def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test") -> DatasetSpec | None:
return self._registry.list_dataset_tasks(organization, dataset, split)
def list_dataset_tasks(
    self,
    organization: str,
    dataset: str,
    split: str = "test",
    *,
    offset: int = 0,
    limit: int | None = None,
    task_filter: str | None = None,
) -> DatasetSpec | None:
    """Delegate to the registry's ``list_dataset_tasks`` for one dataset split.

    ``offset``, ``limit`` and ``task_filter`` are forwarded unchanged as
    keyword arguments; the registry's result is returned as-is.
    """
    selection = {"offset": offset, "limit": limit, "task_filter": task_filter}
    return self._registry.list_dataset_tasks(organization, dataset, split, **selection)

def list_organizations(self) -> list[str]:
return self._registry.list_organizations()
Expand All @@ -25,6 +36,14 @@ def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]:
def list_dataset_splits(self, organization: str, dataset: str) -> list[str]:
return self._registry.list_dataset_splits(organization, dataset)

def list_task_files(
    self, organization: str, dataset: str, split: str, task_id: str, path: str = ""
) -> list[TaskFile]:
    """Delegate to the registry's ``list_task_files`` for one task.

    ``path`` defaults to the empty string and is forwarded unchanged.
    """
    registry = self._registry
    return registry.list_task_files(organization, dataset, split, task_id, path)

def get_task_file(self, organization: str, dataset: str, split: str, task_id: str, path: str) -> bytes | None:
    """Delegate to the registry's ``get_task_file`` and return its result unchanged."""
    registry = self._registry
    return registry.get_task_file(organization, dataset, split, task_id, path)

def upload_dataset(
self,
source: LocalDatasetConfig,
Expand Down
Loading
Loading