Skip to content
42 changes: 37 additions & 5 deletions docs/dev/envhub/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,20 +248,52 @@ class DatasetClient:
rock datasets list [OPTIONS]

Options:
--org TEXT 只列出指定 organization 的 datasets
--depth {1,2} 1: 只列出 organizations;2: 列出 organizations 和 datasets(默认)
--org TEXT 只列出指定 organization 的 datasets(与 --depth 互斥)
--bucket TEXT OSS bucket 名称(覆盖 config.ini)
--endpoint TEXT OSS endpoint(覆盖 config.ini)
--access-key-id TEXT OSS access key ID(覆盖 config.ini)
--access-key-secret TEXT OSS access key secret(覆盖 config.ini)
--region TEXT OSS region(覆盖 config.ini)
```

输出示例:

```
Dataset Split Tasks
qwen/my-bench train 42
qwen/my-bench test 10
alibaba/code-eval train 100
Organization Dataset
--------------------------
alibaba code-eval
qwen my-bench

2 datasets in 2 organizations.
```

#### rock datasets splits

```
rock datasets splits [OPTIONS]

Required:
--org TEXT Organization 名称
--dataset TEXT Dataset 名称

Options:
--bucket TEXT OSS bucket 名称(覆盖 config.ini)
--endpoint TEXT OSS endpoint(覆盖 config.ini)
--access-key-id TEXT OSS access key ID(覆盖 config.ini)
--access-key-secret TEXT OSS access key secret(覆盖 config.ini)
--region TEXT OSS region(覆盖 config.ini)
```

输出示例:

```
Split
-----
test
train

2 splits.
```

#### rock datasets upload
Expand Down
109 changes: 87 additions & 22 deletions rock/cli/command/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ async def arun(self, args: argparse.Namespace) -> None:
await self._list(args)
elif args.datasets_command == "tasks":
await self._tasks(args)
elif args.datasets_command == "splits":
await self._splits(args)
elif args.datasets_command == "upload":
await self._upload(args)
else:
Expand All @@ -59,20 +61,48 @@ def _build_oss_registry_info(self, args: argparse.Namespace) -> OssRegistryInfo:
async def _list(self, args: argparse.Namespace) -> None:
    """List organizations and/or datasets from the OSS registry.

    Behaviour depends on the parsed CLI arguments:

    * ``args.org`` set: list the datasets of that single organization.
    * ``args.depth == 1``: list organization names only.
    * otherwise (depth missing, ``None`` or 2): list every
      (organization, dataset) pair.

    Args:
        args: Parsed arguments of the ``datasets list`` sub-command.
    """
    registry_info = self._build_oss_registry_info(args)
    client = DatasetClient(registry_info)

    org = getattr(args, "org", None)
    if org:
        pairs = [(org, name) for name in client.list_org_datasets(org)]
        self._render_org_dataset_pairs(pairs)
        return

    # A missing or None depth falls back to the default of 2.
    depth = getattr(args, "depth", None) or 2
    if depth == 1:
        self._render_orgs(client.list_organizations())
        return

    self._render_org_dataset_pairs(client.list_all_datasets())

@staticmethod
def _render_org_dataset_pairs(pairs: list[tuple[str, str]]) -> None:
    """Print (organization, dataset) pairs as a two-column table.

    Prints ``No datasets found.`` for an empty input. Otherwise prints a
    header, a dashed rule as wide as the header, one row per pair in the
    order given, and a summary line whose nouns agree with the counts.

    Args:
        pairs: ``(organization, dataset)`` name tuples to display.
    """
    if not pairs:
        print("No datasets found.")
        return

    # Each column is as wide as its longest cell, header included.
    col_org = max(len("Organization"), max(len(org) for org, _ in pairs))
    col_ds = max(len("Dataset"), max(len(dataset) for _, dataset in pairs))

    header = f"{'Organization':<{col_org}} {'Dataset':<{col_ds}}"
    print(header)
    print("-" * len(header))
    for org, dataset in pairs:
        print(f"{org:<{col_org}} {dataset:<{col_ds}}")

    n_orgs = len({org for org, _ in pairs})
    dataset_word = "dataset" if len(pairs) == 1 else "datasets"
    org_word = "organization" if n_orgs == 1 else "organizations"
    print(f"\n{len(pairs)} {dataset_word} in {n_orgs} {org_word}.")

@staticmethod
def _render_orgs(orgs: list[str]) -> None:
    """Print organization names as a one-column table.

    Prints ``No organizations found.`` for an empty input. Otherwise prints
    a header, a dashed rule, one name per line in the order given, and a
    summary line whose noun agrees with the count.

    Args:
        orgs: Organization names to display.
    """
    if not orgs:
        print("No organizations found.")
        return

    width = max(len("Organization"), max(len(name) for name in orgs))
    print(f"{'Organization':<{width}}")
    print("-" * width)
    for name in orgs:
        print(name)

    word = "organization" if len(orgs) == 1 else "organizations"
    print(f"\n{len(orgs)} {word}.")

async def _tasks(self, args: argparse.Namespace) -> None:
registry_info = self._build_oss_registry_info(args)
Expand All @@ -92,8 +122,6 @@ async def _tasks(self, args: argparse.Namespace) -> None:
print("No tasks found after applying offset/limit.")
return

limit_text = str(args.limit) if args.limit is not None else "all"

print()
print("=" * 80)
print(f"Dataset: {spec.id} Split: {spec.split} Total: {total} Shown: {len(shown_task_ids)}")
Expand All @@ -103,6 +131,23 @@ async def _tasks(self, args: argparse.Namespace) -> None:
for task_id in shown_task_ids:
print(task_id)

async def _splits(self, args: argparse.Namespace) -> None:
    """Print the split names of the dataset ``args.org``/``args.dataset``.

    Output is a one-column table followed by a count line; a dedicated
    message is printed when the client reports no splits.
    """
    client = DatasetClient(self._build_oss_registry_info(args))
    split_names = client.list_dataset_splits(args.org, args.dataset)

    if not split_names:
        print(f"No splits found for dataset '{args.org}/{args.dataset}'.")
        return

    longest = max(len(name) for name in split_names)
    width = max(len("Split"), longest)

    table = [f"{'Split':<{width}}", "-" * width]
    for line in table:
        print(line)
    for name in split_names:
        print(name)

    count = len(split_names)
    noun = "split" if count == 1 else "splits"
    print(f"\n{count} {noun}.")

async def _upload(self, args: argparse.Namespace) -> None:
local_dir = Path(args.dir)
if not local_dir.is_dir():
Expand Down Expand Up @@ -135,14 +180,24 @@ async def add_parser_to(subparsers: argparse._SubParsersAction) -> None:
def add_oss_args(parser: argparse.ArgumentParser) -> None:
    """Register the OSS connection override options on ``parser``.

    Every option is optional and its help text states that it overrides
    config.ini. Each flag is registered exactly once: argparse raises a
    conflicting-option error if the same option string is added twice.

    Args:
        parser: The sub-command parser to extend.
    """
    parser.add_argument("--bucket", help="OSS bucket name (overrides config.ini)")
    parser.add_argument("--endpoint", help="OSS endpoint URL (overrides config.ini)")
    parser.add_argument(
        "--access-key-id", dest="access_key_id", help="OSS access key ID (overrides config.ini)"
    )
    parser.add_argument(
        "--access-key-secret", dest="access_key_secret", help="OSS access key secret (overrides config.ini)"
    )
    parser.add_argument("--region", help="OSS region (overrides config.ini)")

list_parser = datasets_subparsers.add_parser("list", help="List datasets in OSS registry")
list_parser.add_argument("--org", help="Filter by organization")
list_group = list_parser.add_mutually_exclusive_group()
list_group.add_argument(
"--depth",
type=int,
choices=[1, 2],
default=None,
help="1: list orgs only. 2 (default): list orgs and datasets.",
)
list_group.add_argument("--org", help="List datasets under the given organization only")
add_oss_args(list_parser)
tasks_parser = datasets_subparsers.add_parser("tasks", help="List task IDs under one dataset split")
tasks_parser.add_argument("--org", required=True, help="Organization name")
Expand All @@ -152,15 +207,25 @@ def add_oss_args(parser: argparse.ArgumentParser) -> None:
tasks_parser.add_argument("--limit", type=_positive_int, default=None, help="Maximum number of tasks to show")
add_oss_args(tasks_parser)

splits_parser = datasets_subparsers.add_parser("splits", help="List splits under one dataset")
splits_parser.add_argument("--org", required=True, help="Organization name")
splits_parser.add_argument("--dataset", required=True, help="Dataset name")
add_oss_args(splits_parser)

upload_parser = datasets_subparsers.add_parser("upload", help="Upload local task dirs to OSS")
upload_parser.add_argument("--org", required=True, help="Organization name")
upload_parser.add_argument("--dataset", required=True, help="Dataset name")
upload_parser.add_argument("--split", required=True, help="Split name (e.g. train, test, v1.0)")
upload_parser.add_argument("--dir", required=True,
help="Local directory containing {task_id}/ subdirectories")
upload_parser.add_argument("--concurrency", type=int, default=4,
choices=range(1, 17), metavar="[1-16]",
help="Upload concurrency (default: 4)")
upload_parser.add_argument("--overwrite", action="store_true",
help="Overwrite existing tasks in OSS (default: skip)")
upload_parser.add_argument("--dir", required=True, help="Local directory containing {task_id}/ subdirectories")
upload_parser.add_argument(
"--concurrency",
type=int,
default=4,
choices=range(1, 17),
metavar="[1-16]",
help="Upload concurrency (default: 4)",
)
upload_parser.add_argument(
"--overwrite", action="store_true", help="Overwrite existing tasks in OSS (default: skip)"
)
add_oss_args(upload_parser)
13 changes: 12 additions & 1 deletion rock/sdk/envhub/datasets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class DatasetClient:

def __init__(self, registry: OssRegistryInfo) -> None:
    """Create a client whose calls are served by an ``OssDatasetRegistry`` built from ``registry``."""
    backend = OssDatasetRegistry(registry)
    self._registry = backend

Expand All @@ -14,6 +13,18 @@ def list_datasets(self, org: str | None = None) -> list[DatasetSpec]:
def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test") -> DatasetSpec | None:
    """Forward to the registry's ``list_dataset_tasks`` and return its result unchanged."""
    registry = self._registry
    spec = registry.list_dataset_tasks(organization, dataset, split)
    return spec

def list_organizations(self) -> list[str]:
    """Return the organization names reported by the underlying registry."""
    registry = self._registry
    return registry.list_organizations()

def list_org_datasets(self, organization: str) -> list[str]:
    """Return the dataset names the underlying registry reports for ``organization``."""
    registry = self._registry
    return registry.list_org_datasets(organization)

def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]:
    """Return the (organization, dataset) pairs from the underlying registry.

    ``concurrency`` is passed through to the registry unchanged.
    """
    registry = self._registry
    return registry.list_all_datasets(concurrency)

def list_dataset_splits(self, organization: str, dataset: str) -> list[str]:
    """Return the split names the underlying registry reports for ``organization``/``dataset``."""
    registry = self._registry
    return registry.list_dataset_splits(organization, dataset)

def upload_dataset(
self,
source: LocalDatasetConfig,
Expand Down
21 changes: 20 additions & 1 deletion rock/sdk/envhub/datasets/registry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@


class BaseDatasetRegistry(ABC):

@abstractmethod
def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]:
"""List all datasets. Filtered to `organization` if provided."""
Expand All @@ -16,6 +15,26 @@ def list_dataset_tasks(self, organization: str, dataset: str, split: str = "test
"""List task ids for one dataset split. Returns None if dataset/split has no tasks."""
...

@abstractmethod
def list_organizations(self) -> list[str]:
    """Return the organization names found under the dataset registry.

    Contract: answered with a single backend call.
    """

@abstractmethod
def list_org_datasets(self, organization: str) -> list[str]:
    """Return the dataset names found under one organization.

    Contract: answered with a single backend call.
    """

@abstractmethod
def list_dataset_splits(self, organization: str, dataset: str) -> list[str]:
    """Return the split names found under one dataset.

    Contract: answered with a single backend call.
    """

@abstractmethod
def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]:
    """Return every (organization, dataset) pair in the registry.

    Contract: costs 1 + N_org backend calls, issued with bounded concurrency.
    """

@abstractmethod
def upload_dataset(
self,
Expand Down
53 changes: 41 additions & 12 deletions rock/sdk/envhub/datasets/registry/oss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import oss2
Expand All @@ -14,7 +14,6 @@


class OssDatasetRegistry(BaseDatasetRegistry):

def __init__(self, registry: OssRegistryInfo) -> None:
    """Store the OSS registry settings that the other methods read via ``self._registry``."""
    self._registry = registry

Expand Down Expand Up @@ -58,7 +57,7 @@ def _extract_tasks_from_split(self, bucket: oss2.Bucket, split_prefix: str) -> l
if key.endswith("/"):
continue
# Get the relative path from split_prefix
relative = key[len(split_prefix):]
relative = key[len(split_prefix) :]
# Only direct files (no nested paths with "/")
if "/" in relative:
continue
Expand All @@ -70,6 +69,37 @@ def _extract_tasks_from_split(self, bucket: oss2.Bucket, split_prefix: str) -> l
all_tasks = sorted(set(dir_tasks + file_tasks))
return all_tasks

def _list_child_names(self, prefix: str) -> list[str]:
    """Return the sorted names of the common prefixes directly under ``prefix``.

    Lists with ``delimiter="/"`` and follows continuation tokens, so a
    listing with more entries than one page (``max_keys=1000``) is not
    silently cut off. One request is made per page, i.e. a single request
    whenever the result fits in one page.

    Args:
        prefix: Key prefix to list under; callers pass it with a trailing "/".

    Returns:
        Each returned common prefix mapped through ``self._last_segment``, sorted.
    """
    bucket = self._build_bucket()
    names: list[str] = []
    token = ""
    while True:
        page = bucket.list_objects_v2(
            prefix=prefix, delimiter="/", continuation_token=token, max_keys=1000
        )
        names.extend(self._last_segment(p) for p in page.prefix_list)
        token = page.next_continuation_token
        # Stop on the final page; the empty-token check guards against looping forever.
        if not page.is_truncated or not token:
            break
    return sorted(names)

def list_organizations(self) -> list[str]:
    """Return the sorted organization names under the dataset root.

    The root is ``oss_dataset_path`` from the registry settings, or
    ``"datasets"`` when that is unset or empty.
    """
    base = self._registry.oss_dataset_path or "datasets"
    return self._list_child_names(f"{base}/")

def list_org_datasets(self, organization: str) -> list[str]:
    """Return the sorted dataset names under ``<root>/<organization>/``."""
    base = self._registry.oss_dataset_path or "datasets"
    return self._list_child_names(f"{base}/{organization}/")

def list_dataset_splits(self, organization: str, dataset: str) -> list[str]:
    """Return the sorted split names under ``<root>/<organization>/<dataset>/``."""
    base = self._registry.oss_dataset_path or "datasets"
    return self._list_child_names(f"{base}/{organization}/{dataset}/")

def list_all_datasets(self, concurrency: int = 10) -> list[tuple[str, str]]:
    """Return every (organization, dataset) pair, sorted.

    Organizations are listed first; their datasets are then fetched in a
    thread pool of ``concurrency`` workers. Returns an empty list when
    there are no organizations.
    """
    org_names = self.list_organizations()
    if not org_names:
        return []

    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        pending = {pool.submit(self.list_org_datasets, name): name for name in org_names}
        collected: list[tuple[str, str]] = [
            (pending[done], dataset)
            for done in as_completed(pending)
            for dataset in done.result()
        ]

    # Completion order is arbitrary, so sort for a stable result.
    collected.sort()
    return collected

def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]:
bucket = self._build_bucket()
base = self._registry.oss_dataset_path or "datasets"
Expand All @@ -93,11 +123,13 @@ def list_datasets(self, organization: str | None = None) -> list[DatasetSpec]:
split = self._last_segment(split_prefix)

task_ids = self._extract_tasks_from_split(bucket, split_prefix)
datasets.append(DatasetSpec(
id=f"{org}/{name}",
split=split,
task_ids=task_ids,
))
datasets.append(
DatasetSpec(
id=f"{org}/{name}",
split=split,
task_ids=task_ids,
)
)

return datasets

Expand Down Expand Up @@ -157,10 +189,7 @@ def upload_dataset(

raw: dict[str, int | None | Exception] = {}
with ThreadPoolExecutor(max_workers=concurrency) as executor:
futures = {
executor.submit(self._upload_task, bucket, org, name, split, d, overwrite): d
for d in task_dirs
}
futures = {executor.submit(self._upload_task, bucket, org, name, split, d, overwrite): d for d in task_dirs}
for future, task_dir in futures.items():
try:
raw[task_dir.name] = future.result()
Expand Down
Loading
Loading